refactor: modularize codebase — deduplicate, extract, clean up
- Unify duplicate uTLS transports into shared internal/transport package - Extract shared version constant into internal/version - Move LoadDefaultCredentials from config to auth (remove config→auth import) - Deduplicate handler.go: extract telemetry/error helpers (324→268 lines) - Break up main.go::run() into initCredential/initEmbedded - Eliminate logging.Config duplication (use config.LoggingConfig directly) - Extract logWriter to embedded/log.go, SSE fixtures to consts in sniff.go - Use uTLS client for usage polling (consistent TLS fingerprint) - Handle sjson.SetBytes errors in sanitize.go instead of silently swallowing - Document reverse-engineered magic values in billing.go - Unexport Credential.CooldownUntil (internal state) - Replace hardcoded auth bypass paths with map in server.go
This commit is contained in:
@@ -0,0 +1,56 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// claudeCredentialsJSON matches the structure of ~/.claude/.credentials.json.
|
||||||
|
type claudeCredentialsJSON struct {
|
||||||
|
ClaudeAiOauth struct {
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
RefreshToken string `json:"refreshToken"`
|
||||||
|
ExpiresAt int64 `json:"expiresAt"`
|
||||||
|
SubscriptionType string `json:"subscriptionType"`
|
||||||
|
} `json:"claudeAiOauth"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadDefaultCredentials reads credentials from ~/.claude/.credentials.json.
|
||||||
|
// Returns nil, nil if the file does not exist.
|
||||||
|
func LoadDefaultCredentials() ([]*Credential, error) {
|
||||||
|
path, err := DefaultCredentialPath()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var cf claudeCredentialsJSON
|
||||||
|
if err := json.Unmarshal(data, &cf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
oauth := cf.ClaudeAiOauth
|
||||||
|
if oauth.AccessToken == "" {
|
||||||
|
return nil, fmt.Errorf("no access token in %s", path)
|
||||||
|
}
|
||||||
|
|
||||||
|
cred := &Credential{
|
||||||
|
ID: "claude-native",
|
||||||
|
Email: oauth.SubscriptionType,
|
||||||
|
AccessToken: oauth.AccessToken,
|
||||||
|
RefreshToken: oauth.RefreshToken,
|
||||||
|
ExpiresAt: time.UnixMilli(oauth.ExpiresAt),
|
||||||
|
FilePath: path,
|
||||||
|
}
|
||||||
|
|
||||||
|
return []*Credential{cred}, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,70 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDefaultCredentialPath(t *testing.T) {
|
||||||
|
path, err := DefaultCredentialPath()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DefaultCredentialPath error: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.HasSuffix(path, filepath.Join(".claude", ".credentials.json")) {
|
||||||
|
t.Errorf("path = %q, want suffix .claude/.credentials.json", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadDefaultCredentials_MissingFile(t *testing.T) {
|
||||||
|
// When credential file doesn't exist, returns nil, nil
|
||||||
|
path, err := DefaultCredentialPath()
|
||||||
|
if err != nil {
|
||||||
|
t.Skip("cannot determine home dir")
|
||||||
|
}
|
||||||
|
if _, statErr := os.Stat(path); os.IsNotExist(statErr) {
|
||||||
|
creds, err := LoadDefaultCredentials()
|
||||||
|
if creds != nil {
|
||||||
|
t.Errorf("expected nil creds for missing file, got %v", creds)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expected nil error for missing file, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeCredentialsJSON_ParsesCorrectly(t *testing.T) {
|
||||||
|
jsonData := `{"claudeAiOauth":{"accessToken":"test-token","refreshToken":"test-refresh","expiresAt":1234567890,"subscriptionType":"pro"}}`
|
||||||
|
|
||||||
|
var cf claudeCredentialsJSON
|
||||||
|
if err := json.Unmarshal([]byte(jsonData), &cf); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cf.ClaudeAiOauth.AccessToken != "test-token" {
|
||||||
|
t.Errorf("AccessToken = %q, want test-token", cf.ClaudeAiOauth.AccessToken)
|
||||||
|
}
|
||||||
|
if cf.ClaudeAiOauth.RefreshToken != "test-refresh" {
|
||||||
|
t.Errorf("RefreshToken = %q, want test-refresh", cf.ClaudeAiOauth.RefreshToken)
|
||||||
|
}
|
||||||
|
if cf.ClaudeAiOauth.ExpiresAt != 1234567890 {
|
||||||
|
t.Errorf("ExpiresAt = %d, want 1234567890", cf.ClaudeAiOauth.ExpiresAt)
|
||||||
|
}
|
||||||
|
if cf.ClaudeAiOauth.SubscriptionType != "pro" {
|
||||||
|
t.Errorf("SubscriptionType = %q, want pro", cf.ClaudeAiOauth.SubscriptionType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeCredentialsJSON_EmptyAccessToken(t *testing.T) {
|
||||||
|
jsonData := `{"claudeAiOauth":{"accessToken":"","refreshToken":"r","expiresAt":1}}`
|
||||||
|
|
||||||
|
var cf claudeCredentialsJSON
|
||||||
|
if err := json.Unmarshal([]byte(jsonData), &cf); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if cf.ClaudeAiOauth.AccessToken != "" {
|
||||||
|
t.Errorf("expected empty access token")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,16 +6,14 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
tls "github.com/refraction-networking/utls"
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"golang.org/x/net/http2"
|
|
||||||
|
"github.com/fujin/anthropic-proxy/internal/transport"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -28,7 +26,7 @@ const (
|
|||||||
refreshBackoff = 5 * time.Minute
|
refreshBackoff = 5 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
var utlsClient = newUTLSClient()
|
var utlsClient = transport.NewHTTPClient(15 * time.Second)
|
||||||
|
|
||||||
type tokenRequest struct {
|
type tokenRequest struct {
|
||||||
ClientID string `json:"client_id"`
|
ClientID string `json:"client_id"`
|
||||||
@@ -147,67 +145,6 @@ func persistCredential(cred *Credential) error {
|
|||||||
return os.WriteFile(filePath, out, 0600)
|
return os.WriteFile(filePath, out, 0600)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUTLSClient() *http.Client {
|
|
||||||
return &http.Client{
|
|
||||||
Timeout: 15 * time.Second,
|
|
||||||
Transport: &utlsRefreshTransport{},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type utlsRefreshTransport struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
conn *http2.ClientConn
|
|
||||||
host string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *utlsRefreshTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
||||||
host := req.URL.Hostname()
|
|
||||||
port := req.URL.Port()
|
|
||||||
if port == "" {
|
|
||||||
port = "443"
|
|
||||||
}
|
|
||||||
|
|
||||||
t.mu.Lock()
|
|
||||||
if t.conn != nil && t.host == host && t.conn.CanTakeNewRequest() {
|
|
||||||
conn := t.conn
|
|
||||||
t.mu.Unlock()
|
|
||||||
resp, err := conn.RoundTrip(req)
|
|
||||||
if err == nil {
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
t.mu.Lock()
|
|
||||||
t.conn = nil
|
|
||||||
t.mu.Unlock()
|
|
||||||
} else {
|
|
||||||
t.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
addr := net.JoinHostPort(host, port)
|
|
||||||
rawConn, err := net.DialTimeout("tcp", addr, 10*time.Second)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsConn := tls.UClient(rawConn, &tls.Config{ServerName: host}, tls.HelloChrome_Auto)
|
|
||||||
if err := tlsConn.Handshake(); err != nil {
|
|
||||||
rawConn.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
h2Conn, err := (&http2.Transport{}).NewClientConn(tlsConn)
|
|
||||||
if err != nil {
|
|
||||||
tlsConn.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
t.mu.Lock()
|
|
||||||
t.conn = h2Conn
|
|
||||||
t.host = host
|
|
||||||
t.mu.Unlock()
|
|
||||||
|
|
||||||
return h2Conn.RoundTrip(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
func StartBackgroundRefresh(ctx context.Context, pool *Pool) {
|
func StartBackgroundRefresh(ctx context.Context, pool *Pool) {
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ func TestPool_Pick_RoundRobin(t *testing.T) {
|
|||||||
func TestPool_Pick_SkipsCooldown(t *testing.T) {
|
func TestPool_Pick_SkipsCooldown(t *testing.T) {
|
||||||
creds := []*Credential{
|
creds := []*Credential{
|
||||||
{ID: "a"},
|
{ID: "a"},
|
||||||
{ID: "b", CooldownUntil: time.Now().Add(1 * time.Hour)},
|
{ID: "b", cooldownUntil: time.Now().Add(1 * time.Hour)},
|
||||||
{ID: "c"},
|
{ID: "c"},
|
||||||
}
|
}
|
||||||
p := NewPool(creds)
|
p := NewPool(creds)
|
||||||
@@ -116,8 +116,8 @@ func TestPool_Pick_SkipsCooldown(t *testing.T) {
|
|||||||
func TestPool_Pick_AllOnCooldown(t *testing.T) {
|
func TestPool_Pick_AllOnCooldown(t *testing.T) {
|
||||||
future := time.Now().Add(1 * time.Hour)
|
future := time.Now().Add(1 * time.Hour)
|
||||||
creds := []*Credential{
|
creds := []*Credential{
|
||||||
{ID: "a", CooldownUntil: future},
|
{ID: "a", cooldownUntil: future},
|
||||||
{ID: "b", CooldownUntil: future},
|
{ID: "b", cooldownUntil: future},
|
||||||
}
|
}
|
||||||
p := NewPool(creds)
|
p := NewPool(creds)
|
||||||
|
|
||||||
@@ -203,13 +203,13 @@ func TestPool_MarkFailure(t *testing.T) {
|
|||||||
}
|
}
|
||||||
// Verify approximate duration
|
// Verify approximate duration
|
||||||
cred.mu.Lock()
|
cred.mu.Lock()
|
||||||
cooldownEnd := cred.CooldownUntil
|
cooldownEnd := cred.cooldownUntil
|
||||||
cred.mu.Unlock()
|
cred.mu.Unlock()
|
||||||
|
|
||||||
lower := before.Add(tt.expectedDur)
|
lower := before.Add(tt.expectedDur)
|
||||||
upper := time.Now().Add(tt.expectedDur)
|
upper := time.Now().Add(tt.expectedDur)
|
||||||
if cooldownEnd.Before(lower) || cooldownEnd.After(upper) {
|
if cooldownEnd.Before(lower) || cooldownEnd.After(upper) {
|
||||||
t.Errorf("CooldownUntil %v not in expected range [%v, %v]", cooldownEnd, lower, upper)
|
t.Errorf("cooldownUntil %v not in expected range [%v, %v]", cooldownEnd, lower, upper)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if cred.IsOnCooldown() {
|
if cred.IsOnCooldown() {
|
||||||
@@ -223,7 +223,7 @@ func TestPool_MarkFailure(t *testing.T) {
|
|||||||
func TestPool_MarkSuccess(t *testing.T) {
|
func TestPool_MarkSuccess(t *testing.T) {
|
||||||
cred := &Credential{
|
cred := &Credential{
|
||||||
ID: "test",
|
ID: "test",
|
||||||
CooldownUntil: time.Now().Add(1 * time.Hour),
|
cooldownUntil: time.Now().Add(1 * time.Hour),
|
||||||
}
|
}
|
||||||
p := NewPool([]*Credential{cred})
|
p := NewPool([]*Credential{cred})
|
||||||
|
|
||||||
@@ -282,7 +282,7 @@ func TestPool_RoundRobinCursorAdvancement(t *testing.T) {
|
|||||||
func TestPool_RoundRobinWithCooldownSkip(t *testing.T) {
|
func TestPool_RoundRobinWithCooldownSkip(t *testing.T) {
|
||||||
creds := []*Credential{
|
creds := []*Credential{
|
||||||
{ID: "0"},
|
{ID: "0"},
|
||||||
{ID: "1", CooldownUntil: time.Now().Add(1 * time.Hour)},
|
{ID: "1", cooldownUntil: time.Now().Add(1 * time.Hour)},
|
||||||
{ID: "2"},
|
{ID: "2"},
|
||||||
}
|
}
|
||||||
p := NewPool(creds)
|
p := NewPool(creds)
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ type Credential struct {
|
|||||||
RefreshToken string
|
RefreshToken string
|
||||||
ExpiresAt time.Time
|
ExpiresAt time.Time
|
||||||
FilePath string
|
FilePath string
|
||||||
CooldownUntil time.Time
|
cooldownUntil time.Time
|
||||||
nextRefreshAfter time.Time
|
nextRefreshAfter time.Time
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
@@ -22,21 +22,21 @@ type Credential struct {
|
|||||||
func (c *Credential) IsOnCooldown() bool {
|
func (c *Credential) IsOnCooldown() bool {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
return time.Now().Before(c.CooldownUntil)
|
return time.Now().Before(c.cooldownUntil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetCooldown puts the credential on cooldown for the given duration.
|
// SetCooldown puts the credential on cooldown for the given duration.
|
||||||
func (c *Credential) SetCooldown(duration time.Duration) {
|
func (c *Credential) SetCooldown(duration time.Duration) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
c.CooldownUntil = time.Now().Add(duration)
|
c.cooldownUntil = time.Now().Add(duration)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClearCooldown removes any active cooldown on the credential.
|
// ClearCooldown removes any active cooldown on the credential.
|
||||||
func (c *Credential) ClearCooldown() {
|
func (c *Credential) ClearCooldown() {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
c.CooldownUntil = time.Time{}
|
c.cooldownUntil = time.Time{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Token returns the current access token.
|
// Token returns the current access token.
|
||||||
|
|||||||
+11
-11
@@ -31,7 +31,7 @@ func TestCredential_IsOnCooldown(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
c := &Credential{CooldownUntil: tt.cooldownUntil}
|
c := &Credential{cooldownUntil: tt.cooldownUntil}
|
||||||
got := c.IsOnCooldown()
|
got := c.IsOnCooldown()
|
||||||
if got != tt.want {
|
if got != tt.want {
|
||||||
t.Errorf("IsOnCooldown() = %v, want %v", got, tt.want)
|
t.Errorf("IsOnCooldown() = %v, want %v", got, tt.want)
|
||||||
@@ -57,12 +57,12 @@ func TestCredential_SetCooldown(t *testing.T) {
|
|||||||
c.SetCooldown(tt.duration)
|
c.SetCooldown(tt.duration)
|
||||||
after := time.Now()
|
after := time.Now()
|
||||||
|
|
||||||
// CooldownUntil should be between before+duration and after+duration
|
// cooldownUntil should be between before+duration and after+duration
|
||||||
if c.CooldownUntil.Before(before.Add(tt.duration)) {
|
if c.cooldownUntil.Before(before.Add(tt.duration)) {
|
||||||
t.Errorf("CooldownUntil %v is before expected lower bound %v", c.CooldownUntil, before.Add(tt.duration))
|
t.Errorf("cooldownUntil %v is before expected lower bound %v", c.cooldownUntil, before.Add(tt.duration))
|
||||||
}
|
}
|
||||||
if c.CooldownUntil.After(after.Add(tt.duration)) {
|
if c.cooldownUntil.After(after.Add(tt.duration)) {
|
||||||
t.Errorf("CooldownUntil %v is after expected upper bound %v", c.CooldownUntil, after.Add(tt.duration))
|
t.Errorf("cooldownUntil %v is after expected upper bound %v", c.cooldownUntil, after.Add(tt.duration))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should now be on cooldown
|
// Should now be on cooldown
|
||||||
@@ -75,7 +75,7 @@ func TestCredential_SetCooldown(t *testing.T) {
|
|||||||
|
|
||||||
func TestCredential_ClearCooldown(t *testing.T) {
|
func TestCredential_ClearCooldown(t *testing.T) {
|
||||||
t.Run("clears active cooldown", func(t *testing.T) {
|
t.Run("clears active cooldown", func(t *testing.T) {
|
||||||
c := &Credential{CooldownUntil: time.Now().Add(1 * time.Hour)}
|
c := &Credential{cooldownUntil: time.Now().Add(1 * time.Hour)}
|
||||||
if !c.IsOnCooldown() {
|
if !c.IsOnCooldown() {
|
||||||
t.Fatal("precondition: expected credential to be on cooldown")
|
t.Fatal("precondition: expected credential to be on cooldown")
|
||||||
}
|
}
|
||||||
@@ -85,8 +85,8 @@ func TestCredential_ClearCooldown(t *testing.T) {
|
|||||||
if c.IsOnCooldown() {
|
if c.IsOnCooldown() {
|
||||||
t.Error("expected credential to not be on cooldown after ClearCooldown")
|
t.Error("expected credential to not be on cooldown after ClearCooldown")
|
||||||
}
|
}
|
||||||
if !c.CooldownUntil.IsZero() {
|
if !c.cooldownUntil.IsZero() {
|
||||||
t.Errorf("expected CooldownUntil to be zero time, got %v", c.CooldownUntil)
|
t.Errorf("expected cooldownUntil to be zero time, got %v", c.cooldownUntil)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -97,8 +97,8 @@ func TestCredential_ClearCooldown(t *testing.T) {
|
|||||||
if c.IsOnCooldown() {
|
if c.IsOnCooldown() {
|
||||||
t.Error("expected credential to not be on cooldown")
|
t.Error("expected credential to not be on cooldown")
|
||||||
}
|
}
|
||||||
if !c.CooldownUntil.IsZero() {
|
if !c.cooldownUntil.IsZero() {
|
||||||
t.Errorf("expected CooldownUntil to be zero time, got %v", c.CooldownUntil)
|
t.Errorf("expected cooldownUntil to be zero time, got %v", c.cooldownUntil)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/fujin/anthropic-proxy/internal/auth"
|
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -67,15 +64,6 @@ type LoggingConfig struct {
|
|||||||
Compress bool `yaml:"compress"`
|
Compress bool `yaml:"compress"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type claudeCredentialsJSON struct {
|
|
||||||
ClaudeAiOauth struct {
|
|
||||||
AccessToken string `json:"accessToken"`
|
|
||||||
RefreshToken string `json:"refreshToken"`
|
|
||||||
ExpiresAt int64 `json:"expiresAt"`
|
|
||||||
SubscriptionType string `json:"subscriptionType"`
|
|
||||||
} `json:"claudeAiOauth"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func Load(path string) (*Config, error) {
|
func Load(path string) (*Config, error) {
|
||||||
data, err := os.ReadFile(path)
|
data, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -128,47 +116,3 @@ func Load(path string) (*Config, error) {
|
|||||||
|
|
||||||
return cfg, nil
|
return cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DefaultCredentialPath() string {
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return home + "/.claude/.credentials.json"
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadDefaultCredentials() ([]*auth.Credential, error) {
|
|
||||||
path := DefaultCredentialPath()
|
|
||||||
if path == "" {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var cf claudeCredentialsJSON
|
|
||||||
if err := json.Unmarshal(data, &cf); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
oauth := cf.ClaudeAiOauth
|
|
||||||
if oauth.AccessToken == "" {
|
|
||||||
return nil, fmt.Errorf("no access token in %s", path)
|
|
||||||
}
|
|
||||||
|
|
||||||
cred := &auth.Credential{
|
|
||||||
ID: "claude-native",
|
|
||||||
Email: oauth.SubscriptionType,
|
|
||||||
AccessToken: oauth.AccessToken,
|
|
||||||
RefreshToken: oauth.RefreshToken,
|
|
||||||
ExpiresAt: time.UnixMilli(oauth.ExpiresAt),
|
|
||||||
FilePath: path,
|
|
||||||
}
|
|
||||||
|
|
||||||
return []*auth.Credential{cred}, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -269,81 +268,3 @@ func TestExportConfig_Enabled(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultCredentialPath(t *testing.T) {
|
|
||||||
path := DefaultCredentialPath()
|
|
||||||
if path == "" {
|
|
||||||
t.Skip("could not determine home directory")
|
|
||||||
}
|
|
||||||
if !strings.HasSuffix(path, "/.claude/.credentials.json") {
|
|
||||||
t.Errorf("DefaultCredentialPath() = %q, want suffix /.claude/.credentials.json", path)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadDefaultCredentials_ValidFile(t *testing.T) {
|
|
||||||
// We can't easily override DefaultCredentialPath, so test the JSON parsing
|
|
||||||
// logic by creating a file at a temp location and calling the internal parsing
|
|
||||||
// directly. Instead, we test LoadDefaultCredentials indirectly by verifying
|
|
||||||
// it returns nil,nil when the default path doesn't exist (common in CI).
|
|
||||||
// For a full test, we create the credential file at the expected path.
|
|
||||||
|
|
||||||
// Test with the actual function — if the default credential file doesn't
|
|
||||||
// exist, it should return nil, nil.
|
|
||||||
creds, err := LoadDefaultCredentials()
|
|
||||||
path := DefaultCredentialPath()
|
|
||||||
if path == "" {
|
|
||||||
if creds != nil || err != nil {
|
|
||||||
t.Errorf("expected nil,nil when home dir unavailable, got %v, %v", creds, err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, statErr := os.Stat(path); os.IsNotExist(statErr) {
|
|
||||||
// File doesn't exist — should return nil, nil
|
|
||||||
if creds != nil {
|
|
||||||
t.Errorf("expected nil creds for missing file, got %v", creds)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("expected nil error for missing file, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadDefaultCredentials_ParsesJSON(t *testing.T) {
|
|
||||||
// Test the JSON parsing by creating a temp credential file and using
|
|
||||||
// the claudeCredentialsJSON struct directly (white-box test).
|
|
||||||
jsonData := `{"claudeAiOauth":{"accessToken":"test-token","refreshToken":"test-refresh","expiresAt":1234567890,"subscriptionType":"pro"}}`
|
|
||||||
|
|
||||||
var cf claudeCredentialsJSON
|
|
||||||
if err := json.Unmarshal([]byte(jsonData), &cf); err != nil {
|
|
||||||
t.Fatalf("unmarshal: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cf.ClaudeAiOauth.AccessToken != "test-token" {
|
|
||||||
t.Errorf("AccessToken = %q, want test-token", cf.ClaudeAiOauth.AccessToken)
|
|
||||||
}
|
|
||||||
if cf.ClaudeAiOauth.RefreshToken != "test-refresh" {
|
|
||||||
t.Errorf("RefreshToken = %q, want test-refresh", cf.ClaudeAiOauth.RefreshToken)
|
|
||||||
}
|
|
||||||
if cf.ClaudeAiOauth.ExpiresAt != 1234567890 {
|
|
||||||
t.Errorf("ExpiresAt = %d, want 1234567890", cf.ClaudeAiOauth.ExpiresAt)
|
|
||||||
}
|
|
||||||
if cf.ClaudeAiOauth.SubscriptionType != "pro" {
|
|
||||||
t.Errorf("SubscriptionType = %q, want pro", cf.ClaudeAiOauth.SubscriptionType)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadDefaultCredentials_EmptyAccessToken(t *testing.T) {
|
|
||||||
// Verify that an empty access token in the JSON produces an error.
|
|
||||||
// We test the parsing struct and logic path.
|
|
||||||
jsonData := `{"claudeAiOauth":{"accessToken":"","refreshToken":"r","expiresAt":1}}`
|
|
||||||
|
|
||||||
var cf claudeCredentialsJSON
|
|
||||||
if err := json.Unmarshal([]byte(jsonData), &cf); err != nil {
|
|
||||||
t.Fatalf("unmarshal: %v", err)
|
|
||||||
}
|
|
||||||
if cf.ClaudeAiOauth.AccessToken != "" {
|
|
||||||
t.Errorf("expected empty access token")
|
|
||||||
}
|
|
||||||
// The actual LoadDefaultCredentials would return an error here.
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -0,0 +1,20 @@
|
|||||||
|
package embedded
|
||||||
|
|
||||||
|
import "github.com/rs/zerolog/log"
|
||||||
|
|
||||||
|
// logWriter bridges subprocess stdout/stderr to zerolog.
|
||||||
|
type logWriter struct {
|
||||||
|
level string
|
||||||
|
component string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *logWriter) Write(p []byte) (n int, err error) {
|
||||||
|
msg := string(p)
|
||||||
|
switch w.level {
|
||||||
|
case "error":
|
||||||
|
log.Error().Str("component", w.component).Msg(msg)
|
||||||
|
default:
|
||||||
|
log.Debug().Str("component", w.component).Msg(msg)
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
@@ -131,19 +131,3 @@ func (p *Perses) writeDashboardProvision() error {
|
|||||||
dashData, 0o644,
|
dashData, 0o644,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
type logWriter struct {
|
|
||||||
level string
|
|
||||||
component string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *logWriter) Write(p []byte) (n int, err error) {
|
|
||||||
msg := string(p)
|
|
||||||
switch w.level {
|
|
||||||
case "error":
|
|
||||||
log.Error().Str("component", w.component).Msg(msg)
|
|
||||||
default:
|
|
||||||
log.Debug().Str("component", w.component).Msg(msg)
|
|
||||||
}
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -13,17 +13,9 @@ import (
|
|||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"gopkg.in/lumberjack.v2"
|
"gopkg.in/lumberjack.v2"
|
||||||
)
|
|
||||||
|
|
||||||
// Config holds logging configuration, mirrors config.LoggingConfig.
|
"github.com/fujin/anthropic-proxy/internal/config"
|
||||||
type Config struct {
|
)
|
||||||
Level string
|
|
||||||
File string
|
|
||||||
MaxSizeMB int
|
|
||||||
MaxBackups int
|
|
||||||
MaxAgeDays int
|
|
||||||
Compress bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup initializes the global zerolog logger.
|
// Setup initializes the global zerolog logger.
|
||||||
// - File set: JSON → lumberjack rotating file
|
// - File set: JSON → lumberjack rotating file
|
||||||
@@ -31,7 +23,7 @@ type Config struct {
|
|||||||
// - File empty + not TTY: JSON → stderr (for systemd journal)
|
// - File empty + not TTY: JSON → stderr (for systemd journal)
|
||||||
// Extra writers (e.g., OTLP log bridge) are added via io.MultiWriter so logs
|
// Extra writers (e.g., OTLP log bridge) are added via io.MultiWriter so logs
|
||||||
// are written to both the primary destination and any extra writers.
|
// are written to both the primary destination and any extra writers.
|
||||||
func Setup(cfg Config, extraWriters ...io.Writer) zerolog.Logger {
|
func Setup(cfg config.LoggingConfig, extraWriters ...io.Writer) zerolog.Logger {
|
||||||
// Parse log level
|
// Parse log level
|
||||||
level, err := zerolog.ParseLevel(cfg.Level)
|
level, err := zerolog.ParseLevel(cfg.Level)
|
||||||
if err != nil || cfg.Level == "" {
|
if err != nil || cfg.Level == "" {
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"github.com/fujin/anthropic-proxy/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRedactHeaders(t *testing.T) {
|
func TestRedactHeaders(t *testing.T) {
|
||||||
@@ -177,7 +179,7 @@ func TestSetup_WithFile(t *testing.T) {
|
|||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
logFile := filepath.Join(dir, "test.log")
|
logFile := filepath.Join(dir, "test.log")
|
||||||
|
|
||||||
logger := Setup(Config{
|
logger := Setup(config.LoggingConfig{
|
||||||
Level: "debug",
|
Level: "debug",
|
||||||
File: logFile,
|
File: logFile,
|
||||||
MaxSizeMB: 10,
|
MaxSizeMB: 10,
|
||||||
@@ -191,7 +193,7 @@ func TestSetup_WithFile(t *testing.T) {
|
|||||||
|
|
||||||
func TestSetup_WithoutFile(t *testing.T) {
|
func TestSetup_WithoutFile(t *testing.T) {
|
||||||
// File empty — should use console or stderr mode depending on TTY
|
// File empty — should use console or stderr mode depending on TTY
|
||||||
logger := Setup(Config{
|
logger := Setup(config.LoggingConfig{
|
||||||
Level: "warn",
|
Level: "warn",
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -201,13 +203,13 @@ func TestSetup_WithoutFile(t *testing.T) {
|
|||||||
|
|
||||||
func TestSetup_DefaultLevel(t *testing.T) {
|
func TestSetup_DefaultLevel(t *testing.T) {
|
||||||
// Empty level should default to info
|
// Empty level should default to info
|
||||||
logger := Setup(Config{})
|
logger := Setup(config.LoggingConfig{})
|
||||||
_ = logger // verify no panic
|
_ = logger // verify no panic
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSetup_InvalidLevel(t *testing.T) {
|
func TestSetup_InvalidLevel(t *testing.T) {
|
||||||
// Invalid level should default to info
|
// Invalid level should default to info
|
||||||
logger := Setup(Config{Level: "not-a-level"})
|
logger := Setup(config.LoggingConfig{Level: "not-a-level"})
|
||||||
_ = logger // verify no panic
|
_ = logger // verify no panic
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,9 +11,13 @@ import (
|
|||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// fingerprintSalt is the fixed salt used by Claude Code for billing header
|
||||||
|
// fingerprint computation. Extracted from the Claude Code CLI source.
|
||||||
const fingerprintSalt = "59cf53e54c78"
|
const fingerprintSalt = "59cf53e54c78"
|
||||||
|
|
||||||
func computeFingerprint(firstUserMessage string, version string) string {
|
func computeFingerprint(firstUserMessage string, version string) string {
|
||||||
|
// UTF-16 character indices sampled from the first user message, matching
|
||||||
|
// the Claude Code CLI's fingerprinting algorithm.
|
||||||
indices := []int{4, 7, 20}
|
indices := []int{4, 7, 20}
|
||||||
runes := utf16.Encode([]rune(firstUserMessage))
|
runes := utf16.Encode([]rune(firstUserMessage))
|
||||||
var chars string
|
var chars string
|
||||||
|
|||||||
+92
-134
@@ -2,6 +2,7 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
@@ -18,6 +19,15 @@ import (
|
|||||||
"github.com/fujin/anthropic-proxy/internal/telemetry"
|
"github.com/fujin/anthropic-proxy/internal/telemetry"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// requestInfo bundles common request context passed to logging/telemetry helpers.
|
||||||
|
type requestInfo struct {
|
||||||
|
model string
|
||||||
|
stream bool
|
||||||
|
cred *auth.Credential
|
||||||
|
body []byte
|
||||||
|
originalBody []byte
|
||||||
|
}
|
||||||
|
|
||||||
func HandleMessages(pool *auth.Pool, profile *SniffedProfile, getSanitizer func() *Sanitizer, tracker *ratelimit.Tracker) gin.HandlerFunc {
|
func HandleMessages(pool *auth.Pool, profile *SniffedProfile, getSanitizer func() *Sanitizer, tracker *ratelimit.Tracker) gin.HandlerFunc {
|
||||||
upstream := NewUpstreamClient(profile)
|
upstream := NewUpstreamClient(profile)
|
||||||
|
|
||||||
@@ -61,6 +71,7 @@ func handleNonStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, p
|
|||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
model := gjson.GetBytes(body, "model").String()
|
model := gjson.GetBytes(body, "model").String()
|
||||||
ctx := c.Request.Context()
|
ctx := c.Request.Context()
|
||||||
|
ri := requestInfo{model: model, stream: false, cred: cred, body: body, originalBody: originalBody}
|
||||||
|
|
||||||
telemetry.RequestBodySize.Record(ctx, int64(len(body)),
|
telemetry.RequestBodySize.Record(ctx, int64(len(body)),
|
||||||
metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", false)))
|
metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", false)))
|
||||||
@@ -69,85 +80,25 @@ func handleNonStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, p
|
|||||||
latencyMs := float64(time.Since(startTime).Milliseconds())
|
latencyMs := float64(time.Since(startTime).Milliseconds())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
recordConnectionError(ctx, err, ri, latencyMs)
|
||||||
Err(err).
|
|
||||||
Str("credential", cred.Email).
|
|
||||||
Str("model", model).
|
|
||||||
Bool("stream", false).
|
|
||||||
Str("request_body_original", string(originalBody)).
|
|
||||||
Str("request_body_sanitized", string(body)).
|
|
||||||
Int("request_body_size", len(body)).
|
|
||||||
Float64("latency_ms", latencyMs).
|
|
||||||
Msg("upstream connection error")
|
|
||||||
|
|
||||||
telemetry.UpstreamErrors.Add(ctx, 1,
|
|
||||||
metric.WithAttributes(
|
|
||||||
attribute.String("error_type", "connection"),
|
|
||||||
attribute.String("credential", cred.Email),
|
|
||||||
attribute.Int("status_code", http.StatusBadGateway),
|
|
||||||
))
|
|
||||||
telemetry.RequestCounter.Add(ctx, 1,
|
|
||||||
metric.WithAttributes(
|
|
||||||
attribute.String("model", model),
|
|
||||||
attribute.Bool("stream", false),
|
|
||||||
attribute.Int("status_code", http.StatusBadGateway),
|
|
||||||
))
|
|
||||||
telemetry.RequestDuration.Record(ctx, latencyMs,
|
|
||||||
metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", false), attribute.Int("status_code", http.StatusBadGateway)))
|
|
||||||
|
|
||||||
c.JSON(http.StatusBadGateway, gin.H{"error": "upstream request failed"})
|
c.JSON(http.StatusBadGateway, gin.H{"error": "upstream request failed"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
attrs := []attribute.KeyValue{
|
recordRequestMetrics(ctx, ri, statusCode, latencyMs)
|
||||||
attribute.String("model", model),
|
|
||||||
attribute.Bool("stream", false),
|
|
||||||
attribute.Int("status_code", statusCode),
|
|
||||||
}
|
|
||||||
|
|
||||||
telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...))
|
|
||||||
telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...))
|
|
||||||
|
|
||||||
if statusCode >= 400 {
|
if statusCode >= 400 {
|
||||||
pool.MarkFailure(cred, statusCode)
|
pool.MarkFailure(cred, statusCode)
|
||||||
telemetry.CredentialCooldowns.Add(ctx, 1,
|
telemetry.CredentialCooldowns.Add(ctx, 1,
|
||||||
metric.WithAttributes(attribute.Int("status_code", statusCode)))
|
metric.WithAttributes(attribute.Int("status_code", statusCode)))
|
||||||
errorType := gjson.GetBytes(respBody, "error.type").String()
|
recordUpstreamError(ctx, statusCode, respBody, headers.Get("X-Request-Id"), latencyMs, ri, c.Request.Header)
|
||||||
errorMessage := gjson.GetBytes(respBody, "error.message").String()
|
|
||||||
log.Error().
|
|
||||||
Int("status", statusCode).
|
|
||||||
Str("error_type", errorType).
|
|
||||||
Str("error_message", errorMessage).
|
|
||||||
Str("response_body", string(respBody)).
|
|
||||||
Str("request_id", headers.Get("X-Request-Id")).
|
|
||||||
Float64("latency_ms", latencyMs).
|
|
||||||
Str("credential", cred.Email).
|
|
||||||
Str("model", model).
|
|
||||||
Bool("stream", false).
|
|
||||||
Str("request_body_original", string(originalBody)).
|
|
||||||
Str("request_body_sanitized", string(body)).
|
|
||||||
Int("request_body_size", len(body)).
|
|
||||||
Str("request_headers", logging.RedactHeaders(c.Request.Header)).
|
|
||||||
Msg("upstream error")
|
|
||||||
|
|
||||||
telemetry.UpstreamErrors.Add(ctx, 1,
|
|
||||||
metric.WithAttributes(
|
|
||||||
attribute.Int("status_code", statusCode),
|
|
||||||
attribute.String("error_type", errorType),
|
|
||||||
attribute.String("credential", cred.Email),
|
|
||||||
))
|
|
||||||
} else {
|
} else {
|
||||||
pool.MarkSuccess(cred)
|
pool.MarkSuccess(cred)
|
||||||
respBody = san.DesanitizeResponse(respBody)
|
respBody = san.DesanitizeResponse(respBody)
|
||||||
|
|
||||||
inputTokens := gjson.GetBytes(respBody, "usage.input_tokens").Int()
|
inputTokens := gjson.GetBytes(respBody, "usage.input_tokens").Int()
|
||||||
outputTokens := gjson.GetBytes(respBody, "usage.output_tokens").Int()
|
outputTokens := gjson.GetBytes(respBody, "usage.output_tokens").Int()
|
||||||
tokenAttrs := metric.WithAttributes(
|
recordTokenUsage(ctx, model, cred, inputTokens, outputTokens)
|
||||||
attribute.String("model", model),
|
|
||||||
attribute.String("credential", cred.Email),
|
|
||||||
)
|
|
||||||
telemetry.TokensInput.Add(ctx, inputTokens, tokenAttrs)
|
|
||||||
telemetry.TokensOutput.Add(ctx, outputTokens, tokenAttrs)
|
|
||||||
if tracker != nil {
|
if tracker != nil {
|
||||||
tracker.UpdateFromHeaders(headers)
|
tracker.UpdateFromHeaders(headers)
|
||||||
}
|
}
|
||||||
@@ -174,6 +125,7 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool
|
|||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
model := gjson.GetBytes(body, "model").String()
|
model := gjson.GetBytes(body, "model").String()
|
||||||
ctx := c.Request.Context()
|
ctx := c.Request.Context()
|
||||||
|
ri := requestInfo{model: model, stream: true, cred: cred, body: body, originalBody: originalBody}
|
||||||
|
|
||||||
telemetry.StreamRequests.Add(ctx, 1, metric.WithAttributes(attribute.String("model", model)))
|
telemetry.StreamRequests.Add(ctx, 1, metric.WithAttributes(attribute.String("model", model)))
|
||||||
telemetry.RequestBodySize.Record(ctx, int64(len(body)),
|
telemetry.RequestBodySize.Record(ctx, int64(len(body)),
|
||||||
@@ -182,32 +134,7 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool
|
|||||||
resp, err := upstream.ExecuteStream(ctx, cred, body)
|
resp, err := upstream.ExecuteStream(ctx, cred, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
latencyMs := float64(time.Since(startTime).Milliseconds())
|
latencyMs := float64(time.Since(startTime).Milliseconds())
|
||||||
log.Error().
|
recordConnectionError(ctx, err, ri, latencyMs)
|
||||||
Err(err).
|
|
||||||
Str("credential", cred.Email).
|
|
||||||
Str("model", model).
|
|
||||||
Bool("stream", true).
|
|
||||||
Str("request_body_original", string(originalBody)).
|
|
||||||
Str("request_body_sanitized", string(body)).
|
|
||||||
Int("request_body_size", len(body)).
|
|
||||||
Float64("latency_ms", latencyMs).
|
|
||||||
Msg("upstream connection error")
|
|
||||||
|
|
||||||
telemetry.UpstreamErrors.Add(ctx, 1,
|
|
||||||
metric.WithAttributes(
|
|
||||||
attribute.String("error_type", "connection"),
|
|
||||||
attribute.String("credential", cred.Email),
|
|
||||||
attribute.Int("status_code", http.StatusBadGateway),
|
|
||||||
))
|
|
||||||
telemetry.RequestCounter.Add(ctx, 1,
|
|
||||||
metric.WithAttributes(
|
|
||||||
attribute.String("model", model),
|
|
||||||
attribute.Bool("stream", true),
|
|
||||||
attribute.Int("status_code", http.StatusBadGateway),
|
|
||||||
))
|
|
||||||
telemetry.RequestDuration.Record(ctx, latencyMs,
|
|
||||||
metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", true), attribute.Int("status_code", http.StatusBadGateway)))
|
|
||||||
|
|
||||||
c.JSON(http.StatusBadGateway, gin.H{"error": "upstream stream request failed"})
|
c.JSON(http.StatusBadGateway, gin.H{"error": "upstream stream request failed"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -219,37 +146,8 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool
|
|||||||
metric.WithAttributes(attribute.Int("status_code", resp.StatusCode)))
|
metric.WithAttributes(attribute.Int("status_code", resp.StatusCode)))
|
||||||
respBody, _ := io.ReadAll(resp.Body)
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
latencyMs := float64(time.Since(startTime).Milliseconds())
|
latencyMs := float64(time.Since(startTime).Milliseconds())
|
||||||
errorType := gjson.GetBytes(respBody, "error.type").String()
|
recordRequestMetrics(ctx, ri, resp.StatusCode, latencyMs)
|
||||||
errorMessage := gjson.GetBytes(respBody, "error.message").String()
|
recordUpstreamError(ctx, resp.StatusCode, respBody, resp.Header.Get("X-Request-Id"), latencyMs, ri, c.Request.Header)
|
||||||
log.Error().
|
|
||||||
Int("status", resp.StatusCode).
|
|
||||||
Str("error_type", errorType).
|
|
||||||
Str("error_message", errorMessage).
|
|
||||||
Str("response_body", string(respBody)).
|
|
||||||
Str("request_id", resp.Header.Get("X-Request-Id")).
|
|
||||||
Float64("latency_ms", latencyMs).
|
|
||||||
Str("credential", cred.Email).
|
|
||||||
Str("model", model).
|
|
||||||
Bool("stream", true).
|
|
||||||
Str("request_body_original", string(originalBody)).
|
|
||||||
Str("request_body_sanitized", string(body)).
|
|
||||||
Int("request_body_size", len(body)).
|
|
||||||
Str("request_headers", logging.RedactHeaders(c.Request.Header)).
|
|
||||||
Msg("upstream error")
|
|
||||||
|
|
||||||
attrs := []attribute.KeyValue{
|
|
||||||
attribute.String("model", model),
|
|
||||||
attribute.Bool("stream", true),
|
|
||||||
attribute.Int("status_code", resp.StatusCode),
|
|
||||||
}
|
|
||||||
telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...))
|
|
||||||
telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...))
|
|
||||||
telemetry.UpstreamErrors.Add(ctx, 1,
|
|
||||||
metric.WithAttributes(
|
|
||||||
attribute.Int("status_code", resp.StatusCode),
|
|
||||||
attribute.String("error_type", errorType),
|
|
||||||
attribute.String("credential", cred.Email),
|
|
||||||
))
|
|
||||||
|
|
||||||
c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), respBody)
|
c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), respBody)
|
||||||
return
|
return
|
||||||
@@ -290,21 +188,10 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool
|
|||||||
}
|
}
|
||||||
|
|
||||||
latencyMs := float64(time.Since(startTime).Milliseconds())
|
latencyMs := float64(time.Since(startTime).Milliseconds())
|
||||||
attrs := []attribute.KeyValue{
|
recordRequestMetrics(ctx, ri, http.StatusOK, latencyMs)
|
||||||
attribute.String("model", model),
|
|
||||||
attribute.Bool("stream", true),
|
|
||||||
attribute.Int("status_code", http.StatusOK),
|
|
||||||
}
|
|
||||||
telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...))
|
|
||||||
telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...))
|
|
||||||
|
|
||||||
if inputTokens > 0 || outputTokens > 0 {
|
if inputTokens > 0 || outputTokens > 0 {
|
||||||
tokenAttrs := metric.WithAttributes(
|
recordTokenUsage(ctx, model, cred, inputTokens, outputTokens)
|
||||||
attribute.String("model", model),
|
|
||||||
attribute.String("credential", cred.Email),
|
|
||||||
)
|
|
||||||
telemetry.TokensInput.Add(ctx, inputTokens, tokenAttrs)
|
|
||||||
telemetry.TokensOutput.Add(ctx, outputTokens, tokenAttrs)
|
|
||||||
if tracker != nil {
|
if tracker != nil {
|
||||||
tracker.UpdateFromHeaders(resp.Header)
|
tracker.UpdateFromHeaders(resp.Header)
|
||||||
}
|
}
|
||||||
@@ -322,3 +209,74 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool
|
|||||||
log.Error().Err(err).Msg("stream scan error")
|
log.Error().Err(err).Msg("stream scan error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// recordConnectionError logs and records metrics for upstream connection failures.
|
||||||
|
func recordConnectionError(ctx context.Context, err error, ri requestInfo, latencyMs float64) {
|
||||||
|
log.Error().
|
||||||
|
Err(err).
|
||||||
|
Str("credential", ri.cred.Email).
|
||||||
|
Str("model", ri.model).
|
||||||
|
Bool("stream", ri.stream).
|
||||||
|
Str("request_body_original", string(ri.originalBody)).
|
||||||
|
Str("request_body_sanitized", string(ri.body)).
|
||||||
|
Int("request_body_size", len(ri.body)).
|
||||||
|
Float64("latency_ms", latencyMs).
|
||||||
|
Msg("upstream connection error")
|
||||||
|
|
||||||
|
telemetry.UpstreamErrors.Add(ctx, 1,
|
||||||
|
metric.WithAttributes(
|
||||||
|
attribute.String("error_type", "connection"),
|
||||||
|
attribute.String("credential", ri.cred.Email),
|
||||||
|
attribute.Int("status_code", http.StatusBadGateway),
|
||||||
|
))
|
||||||
|
recordRequestMetrics(ctx, ri, http.StatusBadGateway, latencyMs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// recordUpstreamError logs and records metrics for upstream HTTP error responses.
|
||||||
|
func recordUpstreamError(ctx context.Context, statusCode int, respBody []byte, requestID string, latencyMs float64, ri requestInfo, requestHeaders http.Header) {
|
||||||
|
errorType := gjson.GetBytes(respBody, "error.type").String()
|
||||||
|
errorMessage := gjson.GetBytes(respBody, "error.message").String()
|
||||||
|
log.Error().
|
||||||
|
Int("status", statusCode).
|
||||||
|
Str("error_type", errorType).
|
||||||
|
Str("error_message", errorMessage).
|
||||||
|
Str("response_body", string(respBody)).
|
||||||
|
Str("request_id", requestID).
|
||||||
|
Float64("latency_ms", latencyMs).
|
||||||
|
Str("credential", ri.cred.Email).
|
||||||
|
Str("model", ri.model).
|
||||||
|
Bool("stream", ri.stream).
|
||||||
|
Str("request_body_original", string(ri.originalBody)).
|
||||||
|
Str("request_body_sanitized", string(ri.body)).
|
||||||
|
Int("request_body_size", len(ri.body)).
|
||||||
|
Str("request_headers", logging.RedactHeaders(requestHeaders)).
|
||||||
|
Msg("upstream error")
|
||||||
|
|
||||||
|
telemetry.UpstreamErrors.Add(ctx, 1,
|
||||||
|
metric.WithAttributes(
|
||||||
|
attribute.Int("status_code", statusCode),
|
||||||
|
attribute.String("error_type", errorType),
|
||||||
|
attribute.String("credential", ri.cred.Email),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
// recordRequestMetrics records the request counter and duration histogram.
|
||||||
|
func recordRequestMetrics(ctx context.Context, ri requestInfo, statusCode int, latencyMs float64) {
|
||||||
|
attrs := []attribute.KeyValue{
|
||||||
|
attribute.String("model", ri.model),
|
||||||
|
attribute.Bool("stream", ri.stream),
|
||||||
|
attribute.Int("status_code", statusCode),
|
||||||
|
}
|
||||||
|
telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...))
|
||||||
|
telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// recordTokenUsage records token consumption metrics.
|
||||||
|
func recordTokenUsage(ctx context.Context, model string, cred *auth.Credential, inputTokens, outputTokens int64) {
|
||||||
|
tokenAttrs := metric.WithAttributes(
|
||||||
|
attribute.String("model", model),
|
||||||
|
attribute.String("credential", cred.Email),
|
||||||
|
)
|
||||||
|
telemetry.TokensInput.Add(ctx, inputTokens, tokenAttrs)
|
||||||
|
telemetry.TokensOutput.Add(ctx, outputTokens, tokenAttrs)
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
|
|
||||||
@@ -11,10 +12,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Sanitizer struct {
|
type Sanitizer struct {
|
||||||
toolsForward map[string]string
|
toolsForward map[string]string
|
||||||
toolsReverse map[string]string
|
toolsReverse map[string]string
|
||||||
systemRules []config.ReplaceRule
|
systemRules []config.ReplaceRule
|
||||||
bodyRules []config.ReplaceRule
|
bodyRules []config.ReplaceRule
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSanitizer(cfg config.SanitizeConfig) *Sanitizer {
|
func NewSanitizer(cfg config.SanitizeConfig) *Sanitizer {
|
||||||
@@ -49,7 +50,11 @@ func (s *Sanitizer) DesanitizeResponse(body []byte) []byte {
|
|||||||
}
|
}
|
||||||
name := block.Get("name").String()
|
name := block.Get("name").String()
|
||||||
if orig, ok := s.toolsReverse[name]; ok {
|
if orig, ok := s.toolsReverse[name]; ok {
|
||||||
body, _ = sjson.SetBytes(body, "content."+strconv.Itoa(i)+".name", orig)
|
if b, err := sjson.SetBytes(body, "content."+strconv.Itoa(i)+".name", orig); err != nil {
|
||||||
|
log.Warn().Err(err).Str("tool", name).Msg("desanitize response: set name failed")
|
||||||
|
} else {
|
||||||
|
body = b
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return body
|
return body
|
||||||
@@ -64,8 +69,12 @@ func (s *Sanitizer) DesanitizeStreamEvent(line string) string {
|
|||||||
for _, path := range []string{"content_block.name", "delta.name"} {
|
for _, path := range []string{"content_block.name", "delta.name"} {
|
||||||
name := gjson.GetBytes(data, path).String()
|
name := gjson.GetBytes(data, path).String()
|
||||||
if orig, ok := s.toolsReverse[name]; ok {
|
if orig, ok := s.toolsReverse[name]; ok {
|
||||||
data, _ = sjson.SetBytes(data, path, orig)
|
if b, err := sjson.SetBytes(data, path, orig); err != nil {
|
||||||
changed = true
|
log.Warn().Err(err).Str("tool", name).Msg("desanitize stream event: set name failed")
|
||||||
|
} else {
|
||||||
|
data = b
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if changed {
|
if changed {
|
||||||
@@ -85,7 +94,11 @@ func (s *Sanitizer) renameTools(body []byte) []byte {
|
|||||||
for i, tool := range tools.Array() {
|
for i, tool := range tools.Array() {
|
||||||
name := tool.Get("name").String()
|
name := tool.Get("name").String()
|
||||||
if newName, ok := s.toolsForward[name]; ok {
|
if newName, ok := s.toolsForward[name]; ok {
|
||||||
body, _ = sjson.SetBytes(body, "tools."+strconv.Itoa(i)+".name", newName)
|
if b, err := sjson.SetBytes(body, "tools."+strconv.Itoa(i)+".name", newName); err != nil {
|
||||||
|
log.Warn().Err(err).Str("tool", name).Msg("rename tool failed")
|
||||||
|
} else {
|
||||||
|
body = b
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return body
|
return body
|
||||||
@@ -104,7 +117,11 @@ func (s *Sanitizer) replaceSystem(body []byte) []byte {
|
|||||||
for _, rule := range s.systemRules {
|
for _, rule := range s.systemRules {
|
||||||
text = strings.ReplaceAll(text, rule.Match, rule.Replace)
|
text = strings.ReplaceAll(text, rule.Match, rule.Replace)
|
||||||
}
|
}
|
||||||
body, _ = sjson.SetBytes(body, "system."+strconv.Itoa(i)+".text", text)
|
if b, err := sjson.SetBytes(body, "system."+strconv.Itoa(i)+".text", text); err != nil {
|
||||||
|
log.Warn().Err(err).Int("block", i).Msg("replace system text failed")
|
||||||
|
} else {
|
||||||
|
body = b
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|||||||
+53
-41
@@ -36,6 +36,21 @@ var skipHeaders = map[string]bool{
|
|||||||
"connection": true,
|
"connection": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const fakeJSONResponse = `{"id":"msg_fake","type":"message","role":"assistant","content":[{"type":"text","text":"ok"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`
|
||||||
|
|
||||||
|
const fakeStreamResponse = "event: message_start\n" +
|
||||||
|
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_fake\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}\n\n" +
|
||||||
|
"event: content_block_start\n" +
|
||||||
|
"data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n" +
|
||||||
|
"event: content_block_delta\n" +
|
||||||
|
"data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"ok\"}}\n\n" +
|
||||||
|
"event: content_block_stop\n" +
|
||||||
|
"data: {\"type\":\"content_block_stop\",\"index\":0}\n\n" +
|
||||||
|
"event: message_delta\n" +
|
||||||
|
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":1}}\n\n" +
|
||||||
|
"event: message_stop\n" +
|
||||||
|
"data: {\"type\":\"message_stop\"}\n\n"
|
||||||
|
|
||||||
func SniffClaudeCode(claudeBinary string) (*SniffedProfile, error) {
|
func SniffClaudeCode(claudeBinary string) (*SniffedProfile, error) {
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -48,45 +63,7 @@ func SniffClaudeCode(claudeBinary string) (*SniffedProfile, error) {
|
|||||||
captured := make(chan struct{}, 1)
|
captured := make(chan struct{}, 1)
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/", sniffHandler(&mu, &profile, captured))
|
||||||
if r.Method == "HEAD" {
|
|
||||||
w.WriteHeader(200)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if r.Method != "POST" || !strings.Contains(r.URL.Path, "/v1/messages") {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(200)
|
|
||||||
fmt.Fprint(w, `{"id":"msg_fake","type":"message","role":"assistant","content":[{"type":"text","text":"ok"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
body, _ := io.ReadAll(r.Body)
|
|
||||||
|
|
||||||
mu.Lock()
|
|
||||||
if profile == nil {
|
|
||||||
profile = extractProfile(r, body)
|
|
||||||
select {
|
|
||||||
case captured <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
mu.Unlock()
|
|
||||||
|
|
||||||
if strings.Contains(string(body), `"stream":true`) {
|
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
|
||||||
w.WriteHeader(200)
|
|
||||||
fmt.Fprint(w, "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_fake\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}\n\n")
|
|
||||||
fmt.Fprint(w, "event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n")
|
|
||||||
fmt.Fprint(w, "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"ok\"}}\n\n")
|
|
||||||
fmt.Fprint(w, "event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n")
|
|
||||||
fmt.Fprint(w, "event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":1}}\n\n")
|
|
||||||
fmt.Fprint(w, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
|
|
||||||
} else {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(200)
|
|
||||||
fmt.Fprint(w, `{"id":"msg_fake","type":"message","role":"assistant","content":[{"type":"text","text":"ok"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
srv := &http.Server{Handler: mux}
|
srv := &http.Server{Handler: mux}
|
||||||
go srv.Serve(listener)
|
go srv.Serve(listener)
|
||||||
@@ -130,8 +107,44 @@ func SniffClaudeCode(claudeBinary string) (*SniffedProfile, error) {
|
|||||||
return profile, nil
|
return profile, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sniffHandler(mu *sync.Mutex, profile **SniffedProfile, captured chan<- struct{}) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method == "HEAD" {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Method != "POST" || !strings.Contains(r.URL.Path, "/v1/messages") {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(200)
|
||||||
|
fmt.Fprint(w, fakeJSONResponse)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
if *profile == nil {
|
||||||
|
*profile = extractProfile(r, body)
|
||||||
|
select {
|
||||||
|
case captured <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
if strings.Contains(string(body), `"stream":true`) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.WriteHeader(200)
|
||||||
|
fmt.Fprint(w, fakeStreamResponse)
|
||||||
|
} else {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(200)
|
||||||
|
fmt.Fprint(w, fakeJSONResponse)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func extractProfile(r *http.Request, body []byte) *SniffedProfile {
|
func extractProfile(r *http.Request, body []byte) *SniffedProfile {
|
||||||
// Capture raw headers preserving original casing.
|
|
||||||
var headers [][2]string
|
var headers [][2]string
|
||||||
for name, vals := range r.Header {
|
for name, vals := range r.Header {
|
||||||
if skipHeaders[strings.ToLower(name)] {
|
if skipHeaders[strings.ToLower(name)] {
|
||||||
@@ -142,7 +155,6 @@ func extractProfile(r *http.Request, body []byte) *SniffedProfile {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deduplicate and strip subscription-specific betas.
|
|
||||||
seen := map[string]bool{}
|
seen := map[string]bool{}
|
||||||
var deduped [][2]string
|
var deduped [][2]string
|
||||||
for _, h := range headers {
|
for _, h := range headers {
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ import (
|
|||||||
|
|
||||||
"github.com/fujin/anthropic-proxy/internal/auth"
|
"github.com/fujin/anthropic-proxy/internal/auth"
|
||||||
"github.com/fujin/anthropic-proxy/internal/logging"
|
"github.com/fujin/anthropic-proxy/internal/logging"
|
||||||
|
"github.com/fujin/anthropic-proxy/internal/transport"
|
||||||
|
"github.com/fujin/anthropic-proxy/internal/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
const messagesURL = "https://api.anthropic.com/v1/messages?beta=true"
|
const messagesURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||||
@@ -27,7 +29,7 @@ func NewUpstreamClient(profile *SniffedProfile) *UpstreamClient {
|
|||||||
return &UpstreamClient{
|
return &UpstreamClient{
|
||||||
client: http.Client{
|
client: http.Client{
|
||||||
Timeout: 0,
|
Timeout: 0,
|
||||||
Transport: newUtlsRoundTripper(),
|
Transport: transport.NewUTLS(),
|
||||||
},
|
},
|
||||||
sessionID: uuid.New().String(),
|
sessionID: uuid.New().String(),
|
||||||
profile: profile,
|
profile: profile,
|
||||||
@@ -38,7 +40,7 @@ func (u *UpstreamClient) version() string {
|
|||||||
if u.profile != nil && u.profile.Version != "" {
|
if u.profile != nil && u.profile.Version != "" {
|
||||||
return u.profile.Version
|
return u.profile.Version
|
||||||
}
|
}
|
||||||
return "2.1.92"
|
return version.ClaudeCodeFallback
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyHeaders replays sniffed headers, substituting auth + per-request IDs + accept.
|
// applyHeaders replays sniffed headers, substituting auth + per-request IDs + accept.
|
||||||
|
|||||||
@@ -7,8 +7,13 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/fujin/anthropic-proxy/internal/transport"
|
||||||
|
"github.com/fujin/anthropic-proxy/internal/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var usageClient = transport.NewHTTPClient(10 * time.Second)
|
||||||
|
|
||||||
const usageURL = "https://api.anthropic.com/api/oauth/usage"
|
const usageURL = "https://api.anthropic.com/api/oauth/usage"
|
||||||
|
|
||||||
type RateLimit struct {
|
type RateLimit struct {
|
||||||
@@ -17,17 +22,17 @@ type RateLimit struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ExtraUsage struct {
|
type ExtraUsage struct {
|
||||||
IsEnabled bool `json:"is_enabled"`
|
IsEnabled bool `json:"is_enabled"`
|
||||||
MonthlyLimit *float64 `json:"monthly_limit"`
|
MonthlyLimit *float64 `json:"monthly_limit"`
|
||||||
UsedCredits *float64 `json:"used_credits"`
|
UsedCredits *float64 `json:"used_credits"`
|
||||||
Utilization *float64 `json:"utilization"`
|
Utilization *float64 `json:"utilization"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UsageResponse struct {
|
type UsageResponse struct {
|
||||||
FiveHour *RateLimit `json:"five_hour"`
|
FiveHour *RateLimit `json:"five_hour"`
|
||||||
SevenDay *RateLimit `json:"seven_day"`
|
SevenDay *RateLimit `json:"seven_day"`
|
||||||
SevenDaySonnet *RateLimit `json:"seven_day_sonnet"`
|
SevenDaySonnet *RateLimit `json:"seven_day_sonnet"`
|
||||||
ExtraUsage *ExtraUsage `json:"extra_usage"`
|
ExtraUsage *ExtraUsage `json:"extra_usage"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchUsage(ctx context.Context, token string) (*UsageResponse, error) {
|
func fetchUsage(ctx context.Context, token string) (*UsageResponse, error) {
|
||||||
@@ -41,9 +46,9 @@ func fetchUsage(ctx context.Context, token string) (*UsageResponse, error) {
|
|||||||
req.Header.Set("Authorization", "Bearer "+token)
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||||
req.Header.Set("User-Agent", "claude-cli/2.1.92")
|
req.Header.Set("User-Agent", "claude-cli/"+version.ClaudeCodeFallback)
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
resp, err := usageClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("request: %w", err)
|
return nil, fmt.Errorf("request: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -138,10 +138,16 @@ func corsMiddleware() gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// authBypassPaths lists endpoints that do not require API key authentication.
|
||||||
|
var authBypassPaths = map[string]bool{
|
||||||
|
"/healthz": true,
|
||||||
|
"/reload": true,
|
||||||
|
"/metrics": true,
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) authMiddleware() gin.HandlerFunc {
|
func (s *Server) authMiddleware() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
path := c.Request.URL.Path
|
if authBypassPaths[c.Request.URL.Path] {
|
||||||
if path == "/healthz" || path == "/reload" || path == "/metrics" {
|
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,29 +1,47 @@
|
|||||||
package proxy
|
// Package transport provides a shared uTLS HTTP/2 round-tripper with Chrome
|
||||||
|
// TLS fingerprinting and per-host connection pooling. Used by both the upstream
|
||||||
|
// proxy client and the OAuth token refresh client.
|
||||||
|
package transport
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
tls "github.com/refraction-networking/utls"
|
tls "github.com/refraction-networking/utls"
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type utlsRoundTripper struct {
|
// UTLS implements http.RoundTripper using uTLS (Chrome fingerprint) over HTTP/2.
|
||||||
|
// It maintains a per-host connection pool with coordination for concurrent
|
||||||
|
// requests to the same host.
|
||||||
|
type UTLS struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
connections map[string]*http2.ClientConn
|
connections map[string]*http2.ClientConn
|
||||||
pending map[string]*sync.Cond
|
pending map[string]*sync.Cond
|
||||||
|
dialTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUtlsRoundTripper() *utlsRoundTripper {
|
// NewUTLS creates a uTLS HTTP/2 round-tripper with a 10-second dial timeout.
|
||||||
return &utlsRoundTripper{
|
func NewUTLS() *UTLS {
|
||||||
|
return &UTLS{
|
||||||
connections: make(map[string]*http2.ClientConn),
|
connections: make(map[string]*http2.ClientConn),
|
||||||
pending: make(map[string]*sync.Cond),
|
pending: make(map[string]*sync.Cond),
|
||||||
|
dialTimeout: 10 * time.Second,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.ClientConn, error) {
|
// NewHTTPClient returns an http.Client using uTLS transport with the given
|
||||||
|
// request timeout. Pass 0 for no timeout (streaming).
|
||||||
|
func NewHTTPClient(timeout time.Duration) *http.Client {
|
||||||
|
return &http.Client{
|
||||||
|
Timeout: timeout,
|
||||||
|
Transport: NewUTLS(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UTLS) getOrCreateConnection(host, addr string) (*http2.ClientConn, error) {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
|
|
||||||
if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
|
if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
|
||||||
@@ -59,8 +77,8 @@ func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.Clie
|
|||||||
return h2Conn, nil
|
return h2Conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) {
|
func (t *UTLS) createConnection(host, addr string) (*http2.ClientConn, error) {
|
||||||
conn, err := net.Dial("tcp", addr)
|
conn, err := net.DialTimeout("tcp", addr, t.dialTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -83,14 +101,14 @@ func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientCon
|
|||||||
return h2Conn, nil
|
return h2Conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *utlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
// RoundTrip implements http.RoundTripper with uTLS Chrome fingerprinting.
|
||||||
|
func (t *UTLS) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
hostname := req.URL.Hostname()
|
hostname := req.URL.Hostname()
|
||||||
port := req.URL.Port()
|
port := req.URL.Port()
|
||||||
if port == "" {
|
if port == "" {
|
||||||
port = "443"
|
port = "443"
|
||||||
}
|
}
|
||||||
addr := net.JoinHostPort(hostname, port)
|
addr := net.JoinHostPort(hostname, port)
|
||||||
log.Debug().Str("addr", addr).Msg("uTLS round trip")
|
|
||||||
|
|
||||||
h2Conn, err := t.getOrCreateConnection(hostname, addr)
|
h2Conn, err := t.getOrCreateConnection(hostname, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -0,0 +1,78 @@
|
|||||||
|
package transport
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewUTLS(t *testing.T) {
|
||||||
|
tr := NewUTLS()
|
||||||
|
if tr == nil {
|
||||||
|
t.Fatal("NewUTLS returned nil")
|
||||||
|
}
|
||||||
|
if tr.connections == nil {
|
||||||
|
t.Error("connections map is nil")
|
||||||
|
}
|
||||||
|
if tr.pending == nil {
|
||||||
|
t.Error("pending map is nil")
|
||||||
|
}
|
||||||
|
if tr.dialTimeout != 10*time.Second {
|
||||||
|
t.Errorf("dialTimeout = %v, want 10s", tr.dialTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewHTTPClient(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
timeout time.Duration
|
||||||
|
}{
|
||||||
|
{"zero timeout (streaming)", 0},
|
||||||
|
{"15s timeout", 15 * time.Second},
|
||||||
|
{"30s timeout", 30 * time.Second},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := NewHTTPClient(tt.timeout)
|
||||||
|
if c == nil {
|
||||||
|
t.Fatal("NewHTTPClient returned nil")
|
||||||
|
}
|
||||||
|
if c.Timeout != tt.timeout {
|
||||||
|
t.Errorf("Timeout = %v, want %v", c.Timeout, tt.timeout)
|
||||||
|
}
|
||||||
|
if c.Transport == nil {
|
||||||
|
t.Error("Transport is nil")
|
||||||
|
}
|
||||||
|
if _, ok := c.Transport.(*UTLS); !ok {
|
||||||
|
t.Errorf("Transport type = %T, want *UTLS", c.Transport)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUTLS_ImplementsRoundTripper(t *testing.T) {
|
||||||
|
var _ http.RoundTripper = (*UTLS)(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUTLS_RoundTrip_InvalidHost(t *testing.T) {
|
||||||
|
tr := NewUTLS()
|
||||||
|
// Use a non-routable address to test dial timeout behavior
|
||||||
|
req, err := http.NewRequest("GET", "https://192.0.2.1:443/test", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest: %v", err)
|
||||||
|
}
|
||||||
|
_, err = tr.RoundTrip(req)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for non-routable address, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUTLS_ConnectionEviction(t *testing.T) {
|
||||||
|
tr := NewUTLS()
|
||||||
|
// Verify connections map starts empty
|
||||||
|
tr.mu.Lock()
|
||||||
|
if len(tr.connections) != 0 {
|
||||||
|
t.Errorf("initial connections = %d, want 0", len(tr.connections))
|
||||||
|
}
|
||||||
|
tr.mu.Unlock()
|
||||||
|
}
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
// Package version provides the fallback Claude Code client version used when
|
||||||
|
// no sniffed profile is available. This constant is shared between the upstream
|
||||||
|
// proxy client and the rate limit usage poller.
|
||||||
|
package version
|
||||||
|
|
||||||
|
// ClaudeCodeFallback is the Claude Code CLI version string used as a fallback
|
||||||
|
// when no real version is obtained from sniffing.
|
||||||
|
const ClaudeCodeFallback = "2.1.92"
|
||||||
@@ -21,6 +21,74 @@ import (
|
|||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func initCredential() (*auth.Credential, error) {
|
||||||
|
creds, err := auth.LoadDefaultCredentials()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("load credentials: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cred *auth.Credential
|
||||||
|
if len(creds) > 0 {
|
||||||
|
cred = creds[0]
|
||||||
|
// If token is expired, try refresh first
|
||||||
|
if !cred.ExpiresAt.IsZero() && time.Now().After(cred.ExpiresAt) {
|
||||||
|
log.Info().Msg("token expired, attempting refresh")
|
||||||
|
refreshCtx, refreshCancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
refreshErr := auth.RefreshToken(refreshCtx, cred)
|
||||||
|
refreshCancel()
|
||||||
|
if refreshErr != nil {
|
||||||
|
log.Warn().Err(refreshErr).Msg("refresh failed, initiating login")
|
||||||
|
cred = nil // fall through to login
|
||||||
|
} else {
|
||||||
|
log.Info().Msg("token refreshed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cred == nil {
|
||||||
|
fi, statErr := os.Stdin.Stat()
|
||||||
|
if statErr == nil && (fi.Mode()&os.ModeCharDevice) == 0 {
|
||||||
|
return nil, fmt.Errorf("no valid credentials found; run the proxy interactively for initial login")
|
||||||
|
}
|
||||||
|
log.Info().Msg("no credentials found, starting OAuth login")
|
||||||
|
cred, err = auth.Login(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("login failed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info().Str("credential", cred.Email).Msg("credential loaded")
|
||||||
|
return cred, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func initEmbedded(cfg *config.Config) (cleanup func(), err error) {
|
||||||
|
if !cfg.Telemetry.Embedded.Enabled {
|
||||||
|
return func() {}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var cleanups []func()
|
||||||
|
|
||||||
|
vm := embedded.NewVM(cfg.Telemetry.Embedded, cfg.Port)
|
||||||
|
if err := vm.Start(); err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to start victoria-metrics")
|
||||||
|
} else {
|
||||||
|
cleanups = append(cleanups, vm.Stop)
|
||||||
|
}
|
||||||
|
|
||||||
|
perses := embedded.NewPerses(cfg.Telemetry.Embedded, cfg.Port)
|
||||||
|
if err := perses.Start(); err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to start perses")
|
||||||
|
} else {
|
||||||
|
cleanups = append(cleanups, perses.Stop)
|
||||||
|
}
|
||||||
|
|
||||||
|
return func() {
|
||||||
|
for i := len(cleanups) - 1; i >= 0; i-- {
|
||||||
|
cleanups[i]()
|
||||||
|
}
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func run() error {
|
func run() error {
|
||||||
cfg, err := config.Load("config.yaml")
|
cfg, err := config.Load("config.yaml")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -48,54 +116,13 @@ func run() error {
|
|||||||
extraWriters = append(extraWriters, logBridge)
|
extraWriters = append(extraWriters, logBridge)
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.Setup(logging.Config{
|
logging.Setup(cfg.Logging, extraWriters...)
|
||||||
Level: cfg.Logging.Level,
|
|
||||||
File: cfg.Logging.File,
|
|
||||||
MaxSizeMB: cfg.Logging.MaxSizeMB,
|
|
||||||
MaxBackups: cfg.Logging.MaxBackups,
|
|
||||||
MaxAgeDays: cfg.Logging.MaxAgeDays,
|
|
||||||
Compress: cfg.Logging.Compress,
|
|
||||||
}, extraWriters...)
|
|
||||||
|
|
||||||
// Load credentials from ~/.claude/.credentials.json
|
cred, err := initCredential()
|
||||||
creds, err := config.LoadDefaultCredentials()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("load credentials: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var cred *auth.Credential
|
|
||||||
if len(creds) > 0 {
|
|
||||||
cred = creds[0]
|
|
||||||
// If token is expired, try refresh first
|
|
||||||
if !cred.ExpiresAt.IsZero() && time.Now().After(cred.ExpiresAt) {
|
|
||||||
log.Info().Msg("token expired, attempting refresh")
|
|
||||||
refreshCtx, refreshCancel := context.WithTimeout(context.Background(), 15*time.Second)
|
|
||||||
refreshErr := auth.RefreshToken(refreshCtx, cred)
|
|
||||||
refreshCancel()
|
|
||||||
if refreshErr != nil {
|
|
||||||
log.Warn().Err(refreshErr).Msg("refresh failed, initiating login")
|
|
||||||
cred = nil // fall through to login
|
|
||||||
} else {
|
|
||||||
log.Info().Msg("token refreshed")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if cred == nil {
|
|
||||||
// Non-TTY check: if stdin is not a terminal, can't do interactive login
|
|
||||||
fi, statErr := os.Stdin.Stat()
|
|
||||||
if statErr == nil && (fi.Mode()&os.ModeCharDevice) == 0 {
|
|
||||||
return fmt.Errorf("no valid credentials found; run the proxy interactively for initial login")
|
|
||||||
}
|
|
||||||
log.Info().Msg("no credentials found, starting OAuth login")
|
|
||||||
cred, err = auth.Login(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("login failed: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info().Str("credential", cred.Email).Msg("credential loaded")
|
|
||||||
|
|
||||||
credForTracker = cred
|
credForTracker = cred
|
||||||
|
|
||||||
pool := auth.NewPool([]*auth.Credential{cred})
|
pool := auth.NewPool([]*auth.Credential{cred})
|
||||||
@@ -116,24 +143,11 @@ func run() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start embedded observability stack (VM + Perses) if enabled
|
embeddedCleanup, err := initEmbedded(cfg)
|
||||||
var vm *embedded.VM
|
if err != nil {
|
||||||
var perses *embedded.Perses
|
return err
|
||||||
if cfg.Telemetry.Embedded.Enabled {
|
|
||||||
vm = embedded.NewVM(cfg.Telemetry.Embedded, cfg.Port)
|
|
||||||
if err := vm.Start(); err != nil {
|
|
||||||
log.Error().Err(err).Msg("failed to start victoria-metrics")
|
|
||||||
} else {
|
|
||||||
defer vm.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
perses = embedded.NewPerses(cfg.Telemetry.Embedded, cfg.Port)
|
|
||||||
if err := perses.Start(); err != nil {
|
|
||||||
log.Error().Err(err).Msg("failed to start perses")
|
|
||||||
} else {
|
|
||||||
defer perses.Stop()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
defer embeddedCleanup()
|
||||||
|
|
||||||
log.Info().Int("port", cfg.Port).Msg("starting server")
|
log.Info().Int("port", cfg.Port).Msg("starting server")
|
||||||
srv := server.New(cfg, pool, profile, tracker, metricsHandler)
|
srv := server.New(cfg, pool, profile, tracker, metricsHandler)
|
||||||
|
|||||||
Reference in New Issue
Block a user