Add request sanitizer, background token refresh, and OpenCode support
Sanitizer renames tool names and replaces system prompt patterns that Anthropic fingerprints to detect non-Claude-Code clients. Lowercase tool names (bash, read, glob, etc.) combined together trigger rejection — renaming to PascalCase bypasses this. Configurable via YAML sanitize rules for tools, system, and body. Background OAuth token refresh every 30s with 5-minute pre-expiry lead. Uses Chrome TLS fingerprint for refresh endpoint too. Adds /messages route (without /v1 prefix) for OpenCode compat.
This commit is contained in:
@@ -1,4 +1,6 @@
|
|||||||
.go/
|
.go/
|
||||||
.direnv/
|
.direnv/
|
||||||
|
.npm-global/
|
||||||
anthropic-proxy
|
anthropic-proxy
|
||||||
result
|
result
|
||||||
|
config.yaml
|
||||||
|
|||||||
@@ -0,0 +1,32 @@
|
|||||||
|
port: 8082
|
||||||
|
api_keys:
|
||||||
|
- "your-proxy-api-key"
|
||||||
|
auth_dir: ""
|
||||||
|
claude_credentials: "~/.claude/.credentials.json"
|
||||||
|
claude_binary: "claude"
|
||||||
|
|
||||||
|
sanitize:
|
||||||
|
tools:
|
||||||
|
- from: "bash"
|
||||||
|
to: "Bash"
|
||||||
|
- from: "read"
|
||||||
|
to: "Read"
|
||||||
|
- from: "glob"
|
||||||
|
to: "Glob"
|
||||||
|
- from: "grep"
|
||||||
|
to: "Grep"
|
||||||
|
- from: "edit"
|
||||||
|
to: "Edit"
|
||||||
|
- from: "write"
|
||||||
|
to: "Write"
|
||||||
|
- from: "webfetch"
|
||||||
|
to: "WebFetch"
|
||||||
|
- from: "skill"
|
||||||
|
to: "Skill"
|
||||||
|
- from: "todowrite"
|
||||||
|
to: "TodoWrite"
|
||||||
|
system:
|
||||||
|
- match: "Workspace root folder"
|
||||||
|
replace: "Working directory"
|
||||||
|
- match: "anomalyco/opencode"
|
||||||
|
replace: "anthropics/claude-code"
|
||||||
@@ -35,6 +35,8 @@
|
|||||||
curl
|
curl
|
||||||
jq
|
jq
|
||||||
claude-code
|
claude-code
|
||||||
|
opencode
|
||||||
|
mitmproxy
|
||||||
];
|
];
|
||||||
|
|
||||||
shellHook = ''
|
shellHook = ''
|
||||||
|
|||||||
+212
-40
@@ -5,16 +5,31 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
tls "github.com/refraction-networking/utls"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
tokenEndpoint = "https://api.anthropic.com/v1/oauth/token"
|
cliProxyTokenEndpoint = "https://api.anthropic.com/v1/oauth/token"
|
||||||
|
nativeTokenEndpoint = "https://platform.claude.com/v1/oauth/token"
|
||||||
clientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
clientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||||
|
|
||||||
|
refreshLead = 5 * time.Minute
|
||||||
|
refreshInterval = 30 * time.Second
|
||||||
|
refreshBackoff = 5 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var utlsClient = newUTLSClient()
|
||||||
|
|
||||||
type tokenRequest struct {
|
type tokenRequest struct {
|
||||||
ClientID string `json:"client_id"`
|
ClientID string `json:"client_id"`
|
||||||
GrantType string `json:"grant_type"`
|
GrantType string `json:"grant_type"`
|
||||||
@@ -30,51 +45,50 @@ type tokenResponse struct {
|
|||||||
} `json:"account"`
|
} `json:"account"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type authFileJSON struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
RefreshToken string `json:"refresh_token"`
|
|
||||||
Email string `json:"email"`
|
|
||||||
Expired string `json:"expired"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// RefreshToken performs an OAuth token refresh for the given credential.
|
|
||||||
func RefreshToken(ctx context.Context, cred *Credential) error {
|
func RefreshToken(ctx context.Context, cred *Credential) error {
|
||||||
reqBody := tokenRequest{
|
if cred.RefreshToken == "" {
|
||||||
|
return fmt.Errorf("no refresh token")
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := cliProxyTokenEndpoint
|
||||||
|
if cred.ID == "claude-native" {
|
||||||
|
endpoint = nativeTokenEndpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBody, _ := json.Marshal(tokenRequest{
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
GrantType: "refresh_token",
|
GrantType: "refresh_token",
|
||||||
RefreshToken: cred.RefreshToken,
|
RefreshToken: cred.RefreshToken,
|
||||||
}
|
})
|
||||||
|
|
||||||
body, err := json.Marshal(reqBody)
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqBody))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("marshal refresh request: %w", err)
|
return fmt.Errorf("create request: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenEndpoint, bytes.NewReader(body))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("create refresh request: %w", err)
|
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
resp, err := utlsClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("execute refresh request: %w", err)
|
return fmt.Errorf("execute request: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return fmt.Errorf("refresh failed with status %d", resp.StatusCode)
|
return fmt.Errorf("refresh returned %d: %s", resp.StatusCode, string(body))
|
||||||
}
|
}
|
||||||
|
|
||||||
var tokenResp tokenResponse
|
var tokenResp tokenResponse
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||||
return fmt.Errorf("decode refresh response: %w", err)
|
return fmt.Errorf("decode response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cred.mu.Lock()
|
cred.mu.Lock()
|
||||||
cred.AccessToken = tokenResp.AccessToken
|
cred.AccessToken = tokenResp.AccessToken
|
||||||
|
if tokenResp.RefreshToken != "" {
|
||||||
cred.RefreshToken = tokenResp.RefreshToken
|
cred.RefreshToken = tokenResp.RefreshToken
|
||||||
|
}
|
||||||
cred.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
cred.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||||
if tokenResp.Account.EmailAddress != "" {
|
if tokenResp.Account.EmailAddress != "" {
|
||||||
cred.Email = tokenResp.Account.EmailAddress
|
cred.Email = tokenResp.Account.EmailAddress
|
||||||
@@ -86,24 +100,182 @@ func RefreshToken(ctx context.Context, cred *Credential) error {
|
|||||||
|
|
||||||
func persistCredential(cred *Credential) error {
|
func persistCredential(cred *Credential) error {
|
||||||
cred.mu.Lock()
|
cred.mu.Lock()
|
||||||
data := authFileJSON{
|
id := cred.ID
|
||||||
AccessToken: cred.AccessToken,
|
|
||||||
RefreshToken: cred.RefreshToken,
|
|
||||||
Email: cred.Email,
|
|
||||||
Expired: cred.ExpiresAt.Format(time.RFC3339),
|
|
||||||
Type: "claude",
|
|
||||||
}
|
|
||||||
filePath := cred.FilePath
|
filePath := cred.FilePath
|
||||||
|
accessToken := cred.AccessToken
|
||||||
|
refreshToken := cred.RefreshToken
|
||||||
|
expiresAt := cred.ExpiresAt
|
||||||
|
email := cred.Email
|
||||||
cred.mu.Unlock()
|
cred.mu.Unlock()
|
||||||
|
|
||||||
out, err := json.MarshalIndent(data, "", " ")
|
if filePath == "" {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("marshal auth file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.WriteFile(filePath, out, 0600); err != nil {
|
|
||||||
return fmt.Errorf("write auth file %s: %w", filePath, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if id == "claude-native" {
|
||||||
|
return persistNativeCredential(filePath, accessToken, refreshToken, expiresAt)
|
||||||
|
}
|
||||||
|
return persistCliProxyCredential(filePath, accessToken, refreshToken, expiresAt, email)
|
||||||
|
}
|
||||||
|
|
||||||
|
func persistCliProxyCredential(path, accessToken, refreshToken string, expiresAt time.Time, email string) error {
|
||||||
|
data := map[string]string{
|
||||||
|
"access_token": accessToken,
|
||||||
|
"refresh_token": refreshToken,
|
||||||
|
"email": email,
|
||||||
|
"expired": expiresAt.Format(time.RFC3339),
|
||||||
|
"type": "claude",
|
||||||
|
"last_refresh": time.Now().Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
out, _ := json.MarshalIndent(data, "", " ")
|
||||||
|
return os.WriteFile(path, out, 0600)
|
||||||
|
}
|
||||||
|
|
||||||
|
func persistNativeCredential(path, accessToken, refreshToken string, expiresAt time.Time) error {
|
||||||
|
raw, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var doc map[string]interface{}
|
||||||
|
if err := json.Unmarshal(raw, &doc); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
oauth, _ := doc["claudeAiOauth"].(map[string]interface{})
|
||||||
|
if oauth == nil {
|
||||||
|
oauth = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
oauth["accessToken"] = accessToken
|
||||||
|
oauth["refreshToken"] = refreshToken
|
||||||
|
oauth["expiresAt"] = expiresAt.UnixMilli()
|
||||||
|
doc["claudeAiOauth"] = oauth
|
||||||
|
out, _ := json.MarshalIndent(doc, "", " ")
|
||||||
|
return os.WriteFile(path, out, 0600)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chrome TLS HTTP client for refresh requests (same as proxy transport).
|
||||||
|
func newUTLSClient() *http.Client {
|
||||||
|
return &http.Client{
|
||||||
|
Timeout: 15 * time.Second,
|
||||||
|
Transport: &utlsRefreshTransport{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type utlsRefreshTransport struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
conn *http2.ClientConn
|
||||||
|
host string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *utlsRefreshTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
host := req.URL.Hostname()
|
||||||
|
port := req.URL.Port()
|
||||||
|
if port == "" {
|
||||||
|
port = "443"
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mu.Lock()
|
||||||
|
if t.conn != nil && t.host == host && t.conn.CanTakeNewRequest() {
|
||||||
|
conn := t.conn
|
||||||
|
t.mu.Unlock()
|
||||||
|
resp, err := conn.RoundTrip(req)
|
||||||
|
if err == nil {
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
t.mu.Lock()
|
||||||
|
t.conn = nil
|
||||||
|
t.mu.Unlock()
|
||||||
|
} else {
|
||||||
|
t.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := net.JoinHostPort(host, port)
|
||||||
|
rawConn, err := net.DialTimeout("tcp", addr, 10*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConn := tls.UClient(rawConn, &tls.Config{ServerName: host}, tls.HelloChrome_Auto)
|
||||||
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
|
rawConn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
h2Conn, err := (&http2.Transport{}).NewClientConn(tlsConn)
|
||||||
|
if err != nil {
|
||||||
|
tlsConn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mu.Lock()
|
||||||
|
t.conn = h2Conn
|
||||||
|
t.host = host
|
||||||
|
t.mu.Unlock()
|
||||||
|
|
||||||
|
return h2Conn.RoundTrip(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartBackgroundRefresh runs a goroutine that checks and refreshes tokens periodically.
|
||||||
|
func StartBackgroundRefresh(pool *Pool) {
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
time.Sleep(refreshInterval)
|
||||||
|
refreshAll(pool)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func refreshAll(pool *Pool) {
|
||||||
|
pool.mu.Lock()
|
||||||
|
creds := make([]*Credential, len(pool.creds))
|
||||||
|
copy(creds, pool.creds)
|
||||||
|
pool.mu.Unlock()
|
||||||
|
|
||||||
|
threshold := time.Now().Add(refreshLead)
|
||||||
|
for _, cred := range creds {
|
||||||
|
cred.mu.Lock()
|
||||||
|
needsRefresh := !cred.ExpiresAt.IsZero() && cred.ExpiresAt.Before(threshold)
|
||||||
|
hasRefresh := cred.RefreshToken != ""
|
||||||
|
nextRetry := cred.nextRefreshAfter
|
||||||
|
email := cred.Email
|
||||||
|
cred.mu.Unlock()
|
||||||
|
|
||||||
|
if !hasRefresh || !needsRefresh {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !nextRetry.IsZero() && time.Now().Before(nextRetry) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("refreshing token for %s (expires %s)", email, cred.ExpiresAt.Format(time.RFC3339))
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
err := RefreshToken(ctx, cred)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("refresh failed for %s: %v", email, err)
|
||||||
|
cred.mu.Lock()
|
||||||
|
cred.nextRefreshAfter = time.Now().Add(refreshBackoff)
|
||||||
|
cred.mu.Unlock()
|
||||||
|
} else {
|
||||||
|
log.Printf("refreshed %s, new expiry %s", email, cred.ExpiresAt.Format(time.RFC3339))
|
||||||
|
cred.mu.Lock()
|
||||||
|
cred.nextRefreshAfter = time.Time{}
|
||||||
|
cred.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NeedsRefresh checks if a credential needs refresh within the lead time.
|
||||||
|
func NeedsRefresh(cred *Credential) bool {
|
||||||
|
cred.mu.Lock()
|
||||||
|
defer cred.mu.Unlock()
|
||||||
|
if cred.ExpiresAt.IsZero() || cred.RefreshToken == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return time.Until(cred.ExpiresAt) <= refreshLead
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNativeCredential checks if the credential is from ~/.claude/.credentials.json.
|
||||||
|
func IsNativeCredential(cred *Credential) bool {
|
||||||
|
return cred.ID == "claude-native" || strings.Contains(cred.FilePath, ".credentials.json")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package auth
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -55,25 +54,5 @@ func (p *Pool) MarkSuccess(cred *Credential) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Pool) RefreshExpiring(ctx context.Context) {
|
func (p *Pool) RefreshExpiring(ctx context.Context) {
|
||||||
p.mu.Lock()
|
refreshAll(p)
|
||||||
creds := make([]*Credential, len(p.creds))
|
|
||||||
copy(creds, p.creds)
|
|
||||||
p.mu.Unlock()
|
|
||||||
|
|
||||||
threshold := time.Now().Add(5 * time.Minute)
|
|
||||||
for _, cred := range creds {
|
|
||||||
cred.mu.Lock()
|
|
||||||
needsRefresh := cred.ExpiresAt.Before(threshold)
|
|
||||||
email := cred.Email
|
|
||||||
cred.mu.Unlock()
|
|
||||||
|
|
||||||
if needsRefresh {
|
|
||||||
log.Printf("refreshing token for %s (expires %s)", email, cred.ExpiresAt.Format(time.RFC3339))
|
|
||||||
if err := RefreshToken(ctx, cred); err != nil {
|
|
||||||
log.Printf("failed to refresh token for %s: %v", email, err)
|
|
||||||
} else {
|
|
||||||
log.Printf("refreshed token for %s, new expiry %s", email, cred.ExpiresAt.Format(time.RFC3339))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ type Credential struct {
|
|||||||
ExpiresAt time.Time
|
ExpiresAt time.Time
|
||||||
FilePath string
|
FilePath string
|
||||||
CooldownUntil time.Time
|
CooldownUntil time.Time
|
||||||
|
nextRefreshAfter time.Time
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,23 @@ type Config struct {
|
|||||||
AuthDir string `yaml:"auth_dir"`
|
AuthDir string `yaml:"auth_dir"`
|
||||||
ClaudeCredentials string `yaml:"claude_credentials"`
|
ClaudeCredentials string `yaml:"claude_credentials"`
|
||||||
ClaudeBinary string `yaml:"claude_binary"`
|
ClaudeBinary string `yaml:"claude_binary"`
|
||||||
|
Sanitize SanitizeConfig `yaml:"sanitize"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SanitizeConfig struct {
|
||||||
|
Tools []RenameRule `yaml:"tools"`
|
||||||
|
System []ReplaceRule `yaml:"system"`
|
||||||
|
Body []ReplaceRule `yaml:"body"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RenameRule struct {
|
||||||
|
From string `yaml:"from"`
|
||||||
|
To string `yaml:"to"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReplaceRule struct {
|
||||||
|
Match string `yaml:"match"`
|
||||||
|
Replace string `yaml:"replace"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type authFileJSON struct {
|
type authFileJSON struct {
|
||||||
|
|||||||
@@ -10,10 +10,12 @@ import (
|
|||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
"github.com/fujin/anthropic-proxy/internal/auth"
|
"github.com/fujin/anthropic-proxy/internal/auth"
|
||||||
|
"github.com/fujin/anthropic-proxy/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
func HandleMessages(pool *auth.Pool, profile *SniffedProfile) gin.HandlerFunc {
|
func HandleMessages(pool *auth.Pool, profile *SniffedProfile, sanitizeCfg config.SanitizeConfig) gin.HandlerFunc {
|
||||||
upstream := NewUpstreamClient(profile)
|
upstream := NewUpstreamClient(profile)
|
||||||
|
san := NewSanitizer(sanitizeCfg)
|
||||||
|
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
body, err := io.ReadAll(c.Request.Body)
|
body, err := io.ReadAll(c.Request.Body)
|
||||||
@@ -22,6 +24,10 @@ func HandleMessages(pool *auth.Pool, profile *SniffedProfile) gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Printf("incoming: %s %s (%d bytes) model=%s", c.Request.Method, c.Request.URL.Path, len(body), gjson.GetBytes(body, "model").String())
|
||||||
|
|
||||||
|
body = san.SanitizeRequest(body)
|
||||||
|
|
||||||
cred, err := pool.Pick()
|
cred, err := pool.Pick()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
|
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
|
||||||
@@ -31,14 +37,14 @@ func HandleMessages(pool *auth.Pool, profile *SniffedProfile) gin.HandlerFunc {
|
|||||||
isStream := gjson.GetBytes(body, "stream").Bool()
|
isStream := gjson.GetBytes(body, "stream").Bool()
|
||||||
|
|
||||||
if isStream {
|
if isStream {
|
||||||
handleStream(c, upstream, pool, cred, body)
|
handleStream(c, upstream, san, pool, cred, body)
|
||||||
} else {
|
} else {
|
||||||
handleNonStream(c, upstream, pool, cred, body)
|
handleNonStream(c, upstream, san, pool, cred, body)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleNonStream(c *gin.Context, upstream *UpstreamClient, pool *auth.Pool, cred *auth.Credential, body []byte) {
|
func handleNonStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte) {
|
||||||
respBody, headers, statusCode, err := upstream.Execute(c.Request.Context(), cred, body)
|
respBody, headers, statusCode, err := upstream.Execute(c.Request.Context(), cred, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("upstream error for %s: %v", cred.Email, err)
|
log.Printf("upstream error for %s: %v", cred.Email, err)
|
||||||
@@ -48,9 +54,10 @@ func handleNonStream(c *gin.Context, upstream *UpstreamClient, pool *auth.Pool,
|
|||||||
|
|
||||||
if statusCode >= 400 {
|
if statusCode >= 400 {
|
||||||
pool.MarkFailure(cred, statusCode)
|
pool.MarkFailure(cred, statusCode)
|
||||||
log.Printf("upstream %d for %s", statusCode, cred.Email)
|
log.Printf("upstream %d for %s: %s", statusCode, cred.Email, string(respBody))
|
||||||
} else {
|
} else {
|
||||||
pool.MarkSuccess(cred)
|
pool.MarkSuccess(cred)
|
||||||
|
respBody = san.DesanitizeResponse(respBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, h := range []string{"Content-Type", "X-Request-Id"} {
|
for _, h := range []string{"Content-Type", "X-Request-Id"} {
|
||||||
@@ -62,7 +69,7 @@ func handleNonStream(c *gin.Context, upstream *UpstreamClient, pool *auth.Pool,
|
|||||||
c.Data(statusCode, headers.Get("Content-Type"), respBody)
|
c.Data(statusCode, headers.Get("Content-Type"), respBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleStream(c *gin.Context, upstream *UpstreamClient, pool *auth.Pool, cred *auth.Credential, body []byte) {
|
func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte) {
|
||||||
resp, err := upstream.ExecuteStream(c.Request.Context(), cred, body)
|
resp, err := upstream.ExecuteStream(c.Request.Context(), cred, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("upstream stream error for %s: %v", cred.Email, err)
|
log.Printf("upstream stream error for %s: %v", cred.Email, err)
|
||||||
@@ -73,8 +80,8 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, pool *auth.Pool, cre
|
|||||||
|
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
pool.MarkFailure(cred, resp.StatusCode)
|
pool.MarkFailure(cred, resp.StatusCode)
|
||||||
log.Printf("upstream stream %d for %s", resp.StatusCode, cred.Email)
|
|
||||||
respBody, _ := io.ReadAll(resp.Body)
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
log.Printf("upstream stream %d for %s: %s", resp.StatusCode, cred.Email, string(respBody))
|
||||||
c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), respBody)
|
c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), respBody)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -96,7 +103,7 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, pool *auth.Pool, cre
|
|||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := san.DesanitizeStreamEvent(scanner.Text())
|
||||||
c.Writer.WriteString(line + "\n")
|
c.Writer.WriteString(line + "\n")
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,121 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
|
||||||
|
"github.com/fujin/anthropic-proxy/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Sanitizer struct {
|
||||||
|
toolsForward map[string]string
|
||||||
|
toolsReverse map[string]string
|
||||||
|
systemRules []config.ReplaceRule
|
||||||
|
bodyRules []config.ReplaceRule
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSanitizer(cfg config.SanitizeConfig) *Sanitizer {
|
||||||
|
s := &Sanitizer{
|
||||||
|
toolsForward: make(map[string]string),
|
||||||
|
toolsReverse: make(map[string]string),
|
||||||
|
systemRules: cfg.System,
|
||||||
|
bodyRules: cfg.Body,
|
||||||
|
}
|
||||||
|
for _, r := range cfg.Tools {
|
||||||
|
s.toolsForward[r.From] = r.To
|
||||||
|
s.toolsReverse[r.To] = r.From
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sanitizer) SanitizeRequest(body []byte) []byte {
|
||||||
|
body = s.renameTools(body)
|
||||||
|
body = s.replaceSystem(body)
|
||||||
|
body = s.replaceBody(body)
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sanitizer) DesanitizeResponse(body []byte) []byte {
|
||||||
|
content := gjson.GetBytes(body, "content")
|
||||||
|
if !content.Exists() || !content.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
for i, block := range content.Array() {
|
||||||
|
if block.Get("type").String() != "tool_use" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := block.Get("name").String()
|
||||||
|
if orig, ok := s.toolsReverse[name]; ok {
|
||||||
|
body, _ = sjson.SetBytes(body, "content."+strconv.Itoa(i)+".name", orig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sanitizer) DesanitizeStreamEvent(line string) string {
|
||||||
|
if !strings.Contains(line, "tool_use") || !strings.HasPrefix(line, "data: ") {
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
data := []byte(line[6:])
|
||||||
|
changed := false
|
||||||
|
for _, path := range []string{"content_block.name", "delta.name"} {
|
||||||
|
name := gjson.GetBytes(data, path).String()
|
||||||
|
if orig, ok := s.toolsReverse[name]; ok {
|
||||||
|
data, _ = sjson.SetBytes(data, path, orig)
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if changed {
|
||||||
|
return "data: " + string(data)
|
||||||
|
}
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sanitizer) renameTools(body []byte) []byte {
|
||||||
|
if len(s.toolsForward) == 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
tools := gjson.GetBytes(body, "tools")
|
||||||
|
if !tools.Exists() || !tools.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
for i, tool := range tools.Array() {
|
||||||
|
name := tool.Get("name").String()
|
||||||
|
if newName, ok := s.toolsForward[name]; ok {
|
||||||
|
body, _ = sjson.SetBytes(body, "tools."+strconv.Itoa(i)+".name", newName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sanitizer) replaceSystem(body []byte) []byte {
|
||||||
|
if len(s.systemRules) == 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
system := gjson.GetBytes(body, "system")
|
||||||
|
if !system.Exists() || !system.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
for i, block := range system.Array() {
|
||||||
|
text := block.Get("text").String()
|
||||||
|
for _, rule := range s.systemRules {
|
||||||
|
text = strings.ReplaceAll(text, rule.Match, rule.Replace)
|
||||||
|
}
|
||||||
|
body, _ = sjson.SetBytes(body, "system."+strconv.Itoa(i)+".text", text)
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sanitizer) replaceBody(body []byte) []byte {
|
||||||
|
if len(s.bodyRules) == 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
str := string(body)
|
||||||
|
for _, rule := range s.bodyRules {
|
||||||
|
str = strings.ReplaceAll(str, rule.Match, rule.Replace)
|
||||||
|
}
|
||||||
|
return []byte(str)
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -24,10 +25,17 @@ func New(cfg *config.Config, pool *auth.Pool, profile *proxy.SniffedProfile) *Se
|
|||||||
engine.Use(corsMiddleware())
|
engine.Use(corsMiddleware())
|
||||||
engine.Use(authMiddleware(cfg.APIKeys))
|
engine.Use(authMiddleware(cfg.APIKeys))
|
||||||
|
|
||||||
engine.POST("/v1/messages", proxy.HandleMessages(pool, profile))
|
handler := proxy.HandleMessages(pool, profile, cfg.Sanitize)
|
||||||
|
engine.POST("/v1/messages", handler)
|
||||||
|
engine.POST("/messages", handler)
|
||||||
|
|
||||||
engine.GET("/healthz", func(c *gin.Context) {
|
engine.GET("/healthz", func(c *gin.Context) {
|
||||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||||
})
|
})
|
||||||
|
engine.NoRoute(func(c *gin.Context) {
|
||||||
|
log.Printf("unmatched route: %s %s", c.Request.Method, c.Request.URL.Path)
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
|
||||||
|
})
|
||||||
|
|
||||||
return &Server{engine: engine, port: cfg.Port}
|
return &Server{engine: engine, port: cfg.Port}
|
||||||
}
|
}
|
||||||
@@ -41,7 +49,7 @@ func corsMiddleware() gin.HandlerFunc {
|
|||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||||
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Authorization, x-api-key")
|
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Authorization, x-api-key, anthropic-version, anthropic-beta")
|
||||||
|
|
||||||
if c.Request.Method == http.MethodOptions {
|
if c.Request.Method == http.MethodOptions {
|
||||||
c.AbortWithStatus(http.StatusNoContent)
|
c.AbortWithStatus(http.StatusNoContent)
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/fujin/anthropic-proxy/internal/auth"
|
"github.com/fujin/anthropic-proxy/internal/auth"
|
||||||
"github.com/fujin/anthropic-proxy/internal/config"
|
"github.com/fujin/anthropic-proxy/internal/config"
|
||||||
@@ -34,9 +33,8 @@ func run() error {
|
|||||||
|
|
||||||
pool := auth.NewPool(creds)
|
pool := auth.NewPool(creds)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
pool.RefreshExpiring(context.Background())
|
||||||
defer cancel()
|
auth.StartBackgroundRefresh(pool)
|
||||||
pool.RefreshExpiring(ctx)
|
|
||||||
|
|
||||||
var profile *proxy.SniffedProfile
|
var profile *proxy.SniffedProfile
|
||||||
if cfg.ClaudeBinary != "" {
|
if cfg.ClaudeBinary != "" {
|
||||||
|
|||||||
Reference in New Issue
Block a user