refactor: modularize codebase — deduplicate, extract, clean up
- Unify duplicate uTLS transports into shared internal/transport package - Extract shared version constant into internal/version - Move LoadDefaultCredentials from config to auth (remove config→auth import) - Deduplicate handler.go: extract telemetry/error helpers (324→268 lines) - Break up main.go::run() into initCredential/initEmbedded - Eliminate logging.Config duplication (use config.LoggingConfig directly) - Extract logWriter to embedded/log.go, SSE fixtures to consts in sniff.go - Use uTLS client for usage polling (consistent TLS fingerprint) - Handle sjson.SetBytes errors in sanitize.go instead of silently swallowing - Document reverse-engineered magic values in billing.go - Unexport Credential.CooldownUntil (internal state) - Replace hardcoded auth bypass paths with map in server.go
This commit is contained in:
@@ -11,9 +11,13 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// fingerprintSalt is the fixed salt used by Claude Code for billing header
|
||||
// fingerprint computation. Extracted from the Claude Code CLI source.
|
||||
const fingerprintSalt = "59cf53e54c78"
|
||||
|
||||
func computeFingerprint(firstUserMessage string, version string) string {
|
||||
// UTF-16 character indices sampled from the first user message, matching
|
||||
// the Claude Code CLI's fingerprinting algorithm.
|
||||
indices := []int{4, 7, 20}
|
||||
runes := utf16.Encode([]rune(firstUserMessage))
|
||||
var chars string
|
||||
|
||||
+92
-134
@@ -2,6 +2,7 @@ package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
@@ -18,6 +19,15 @@ import (
|
||||
"github.com/fujin/anthropic-proxy/internal/telemetry"
|
||||
)
|
||||
|
||||
// requestInfo bundles common request context passed to logging/telemetry helpers.
|
||||
type requestInfo struct {
|
||||
model string
|
||||
stream bool
|
||||
cred *auth.Credential
|
||||
body []byte
|
||||
originalBody []byte
|
||||
}
|
||||
|
||||
func HandleMessages(pool *auth.Pool, profile *SniffedProfile, getSanitizer func() *Sanitizer, tracker *ratelimit.Tracker) gin.HandlerFunc {
|
||||
upstream := NewUpstreamClient(profile)
|
||||
|
||||
@@ -61,6 +71,7 @@ func handleNonStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, p
|
||||
startTime := time.Now()
|
||||
model := gjson.GetBytes(body, "model").String()
|
||||
ctx := c.Request.Context()
|
||||
ri := requestInfo{model: model, stream: false, cred: cred, body: body, originalBody: originalBody}
|
||||
|
||||
telemetry.RequestBodySize.Record(ctx, int64(len(body)),
|
||||
metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", false)))
|
||||
@@ -69,85 +80,25 @@ func handleNonStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, p
|
||||
latencyMs := float64(time.Since(startTime).Milliseconds())
|
||||
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("credential", cred.Email).
|
||||
Str("model", model).
|
||||
Bool("stream", false).
|
||||
Str("request_body_original", string(originalBody)).
|
||||
Str("request_body_sanitized", string(body)).
|
||||
Int("request_body_size", len(body)).
|
||||
Float64("latency_ms", latencyMs).
|
||||
Msg("upstream connection error")
|
||||
|
||||
telemetry.UpstreamErrors.Add(ctx, 1,
|
||||
metric.WithAttributes(
|
||||
attribute.String("error_type", "connection"),
|
||||
attribute.String("credential", cred.Email),
|
||||
attribute.Int("status_code", http.StatusBadGateway),
|
||||
))
|
||||
telemetry.RequestCounter.Add(ctx, 1,
|
||||
metric.WithAttributes(
|
||||
attribute.String("model", model),
|
||||
attribute.Bool("stream", false),
|
||||
attribute.Int("status_code", http.StatusBadGateway),
|
||||
))
|
||||
telemetry.RequestDuration.Record(ctx, latencyMs,
|
||||
metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", false), attribute.Int("status_code", http.StatusBadGateway)))
|
||||
|
||||
recordConnectionError(ctx, err, ri, latencyMs)
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "upstream request failed"})
|
||||
return
|
||||
}
|
||||
|
||||
attrs := []attribute.KeyValue{
|
||||
attribute.String("model", model),
|
||||
attribute.Bool("stream", false),
|
||||
attribute.Int("status_code", statusCode),
|
||||
}
|
||||
|
||||
telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...))
|
||||
telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...))
|
||||
recordRequestMetrics(ctx, ri, statusCode, latencyMs)
|
||||
|
||||
if statusCode >= 400 {
|
||||
pool.MarkFailure(cred, statusCode)
|
||||
telemetry.CredentialCooldowns.Add(ctx, 1,
|
||||
metric.WithAttributes(attribute.Int("status_code", statusCode)))
|
||||
errorType := gjson.GetBytes(respBody, "error.type").String()
|
||||
errorMessage := gjson.GetBytes(respBody, "error.message").String()
|
||||
log.Error().
|
||||
Int("status", statusCode).
|
||||
Str("error_type", errorType).
|
||||
Str("error_message", errorMessage).
|
||||
Str("response_body", string(respBody)).
|
||||
Str("request_id", headers.Get("X-Request-Id")).
|
||||
Float64("latency_ms", latencyMs).
|
||||
Str("credential", cred.Email).
|
||||
Str("model", model).
|
||||
Bool("stream", false).
|
||||
Str("request_body_original", string(originalBody)).
|
||||
Str("request_body_sanitized", string(body)).
|
||||
Int("request_body_size", len(body)).
|
||||
Str("request_headers", logging.RedactHeaders(c.Request.Header)).
|
||||
Msg("upstream error")
|
||||
|
||||
telemetry.UpstreamErrors.Add(ctx, 1,
|
||||
metric.WithAttributes(
|
||||
attribute.Int("status_code", statusCode),
|
||||
attribute.String("error_type", errorType),
|
||||
attribute.String("credential", cred.Email),
|
||||
))
|
||||
recordUpstreamError(ctx, statusCode, respBody, headers.Get("X-Request-Id"), latencyMs, ri, c.Request.Header)
|
||||
} else {
|
||||
pool.MarkSuccess(cred)
|
||||
respBody = san.DesanitizeResponse(respBody)
|
||||
|
||||
inputTokens := gjson.GetBytes(respBody, "usage.input_tokens").Int()
|
||||
outputTokens := gjson.GetBytes(respBody, "usage.output_tokens").Int()
|
||||
tokenAttrs := metric.WithAttributes(
|
||||
attribute.String("model", model),
|
||||
attribute.String("credential", cred.Email),
|
||||
)
|
||||
telemetry.TokensInput.Add(ctx, inputTokens, tokenAttrs)
|
||||
telemetry.TokensOutput.Add(ctx, outputTokens, tokenAttrs)
|
||||
recordTokenUsage(ctx, model, cred, inputTokens, outputTokens)
|
||||
if tracker != nil {
|
||||
tracker.UpdateFromHeaders(headers)
|
||||
}
|
||||
@@ -174,6 +125,7 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool
|
||||
startTime := time.Now()
|
||||
model := gjson.GetBytes(body, "model").String()
|
||||
ctx := c.Request.Context()
|
||||
ri := requestInfo{model: model, stream: true, cred: cred, body: body, originalBody: originalBody}
|
||||
|
||||
telemetry.StreamRequests.Add(ctx, 1, metric.WithAttributes(attribute.String("model", model)))
|
||||
telemetry.RequestBodySize.Record(ctx, int64(len(body)),
|
||||
@@ -182,32 +134,7 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool
|
||||
resp, err := upstream.ExecuteStream(ctx, cred, body)
|
||||
if err != nil {
|
||||
latencyMs := float64(time.Since(startTime).Milliseconds())
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("credential", cred.Email).
|
||||
Str("model", model).
|
||||
Bool("stream", true).
|
||||
Str("request_body_original", string(originalBody)).
|
||||
Str("request_body_sanitized", string(body)).
|
||||
Int("request_body_size", len(body)).
|
||||
Float64("latency_ms", latencyMs).
|
||||
Msg("upstream connection error")
|
||||
|
||||
telemetry.UpstreamErrors.Add(ctx, 1,
|
||||
metric.WithAttributes(
|
||||
attribute.String("error_type", "connection"),
|
||||
attribute.String("credential", cred.Email),
|
||||
attribute.Int("status_code", http.StatusBadGateway),
|
||||
))
|
||||
telemetry.RequestCounter.Add(ctx, 1,
|
||||
metric.WithAttributes(
|
||||
attribute.String("model", model),
|
||||
attribute.Bool("stream", true),
|
||||
attribute.Int("status_code", http.StatusBadGateway),
|
||||
))
|
||||
telemetry.RequestDuration.Record(ctx, latencyMs,
|
||||
metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", true), attribute.Int("status_code", http.StatusBadGateway)))
|
||||
|
||||
recordConnectionError(ctx, err, ri, latencyMs)
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "upstream stream request failed"})
|
||||
return
|
||||
}
|
||||
@@ -219,37 +146,8 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool
|
||||
metric.WithAttributes(attribute.Int("status_code", resp.StatusCode)))
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
latencyMs := float64(time.Since(startTime).Milliseconds())
|
||||
errorType := gjson.GetBytes(respBody, "error.type").String()
|
||||
errorMessage := gjson.GetBytes(respBody, "error.message").String()
|
||||
log.Error().
|
||||
Int("status", resp.StatusCode).
|
||||
Str("error_type", errorType).
|
||||
Str("error_message", errorMessage).
|
||||
Str("response_body", string(respBody)).
|
||||
Str("request_id", resp.Header.Get("X-Request-Id")).
|
||||
Float64("latency_ms", latencyMs).
|
||||
Str("credential", cred.Email).
|
||||
Str("model", model).
|
||||
Bool("stream", true).
|
||||
Str("request_body_original", string(originalBody)).
|
||||
Str("request_body_sanitized", string(body)).
|
||||
Int("request_body_size", len(body)).
|
||||
Str("request_headers", logging.RedactHeaders(c.Request.Header)).
|
||||
Msg("upstream error")
|
||||
|
||||
attrs := []attribute.KeyValue{
|
||||
attribute.String("model", model),
|
||||
attribute.Bool("stream", true),
|
||||
attribute.Int("status_code", resp.StatusCode),
|
||||
}
|
||||
telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...))
|
||||
telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...))
|
||||
telemetry.UpstreamErrors.Add(ctx, 1,
|
||||
metric.WithAttributes(
|
||||
attribute.Int("status_code", resp.StatusCode),
|
||||
attribute.String("error_type", errorType),
|
||||
attribute.String("credential", cred.Email),
|
||||
))
|
||||
recordRequestMetrics(ctx, ri, resp.StatusCode, latencyMs)
|
||||
recordUpstreamError(ctx, resp.StatusCode, respBody, resp.Header.Get("X-Request-Id"), latencyMs, ri, c.Request.Header)
|
||||
|
||||
c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), respBody)
|
||||
return
|
||||
@@ -290,21 +188,10 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool
|
||||
}
|
||||
|
||||
latencyMs := float64(time.Since(startTime).Milliseconds())
|
||||
attrs := []attribute.KeyValue{
|
||||
attribute.String("model", model),
|
||||
attribute.Bool("stream", true),
|
||||
attribute.Int("status_code", http.StatusOK),
|
||||
}
|
||||
telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...))
|
||||
telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...))
|
||||
recordRequestMetrics(ctx, ri, http.StatusOK, latencyMs)
|
||||
|
||||
if inputTokens > 0 || outputTokens > 0 {
|
||||
tokenAttrs := metric.WithAttributes(
|
||||
attribute.String("model", model),
|
||||
attribute.String("credential", cred.Email),
|
||||
)
|
||||
telemetry.TokensInput.Add(ctx, inputTokens, tokenAttrs)
|
||||
telemetry.TokensOutput.Add(ctx, outputTokens, tokenAttrs)
|
||||
recordTokenUsage(ctx, model, cred, inputTokens, outputTokens)
|
||||
if tracker != nil {
|
||||
tracker.UpdateFromHeaders(resp.Header)
|
||||
}
|
||||
@@ -322,3 +209,74 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool
|
||||
log.Error().Err(err).Msg("stream scan error")
|
||||
}
|
||||
}
|
||||
|
||||
// recordConnectionError logs and records metrics for upstream connection failures.
|
||||
func recordConnectionError(ctx context.Context, err error, ri requestInfo, latencyMs float64) {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("credential", ri.cred.Email).
|
||||
Str("model", ri.model).
|
||||
Bool("stream", ri.stream).
|
||||
Str("request_body_original", string(ri.originalBody)).
|
||||
Str("request_body_sanitized", string(ri.body)).
|
||||
Int("request_body_size", len(ri.body)).
|
||||
Float64("latency_ms", latencyMs).
|
||||
Msg("upstream connection error")
|
||||
|
||||
telemetry.UpstreamErrors.Add(ctx, 1,
|
||||
metric.WithAttributes(
|
||||
attribute.String("error_type", "connection"),
|
||||
attribute.String("credential", ri.cred.Email),
|
||||
attribute.Int("status_code", http.StatusBadGateway),
|
||||
))
|
||||
recordRequestMetrics(ctx, ri, http.StatusBadGateway, latencyMs)
|
||||
}
|
||||
|
||||
// recordUpstreamError logs and records metrics for upstream HTTP error responses.
|
||||
func recordUpstreamError(ctx context.Context, statusCode int, respBody []byte, requestID string, latencyMs float64, ri requestInfo, requestHeaders http.Header) {
|
||||
errorType := gjson.GetBytes(respBody, "error.type").String()
|
||||
errorMessage := gjson.GetBytes(respBody, "error.message").String()
|
||||
log.Error().
|
||||
Int("status", statusCode).
|
||||
Str("error_type", errorType).
|
||||
Str("error_message", errorMessage).
|
||||
Str("response_body", string(respBody)).
|
||||
Str("request_id", requestID).
|
||||
Float64("latency_ms", latencyMs).
|
||||
Str("credential", ri.cred.Email).
|
||||
Str("model", ri.model).
|
||||
Bool("stream", ri.stream).
|
||||
Str("request_body_original", string(ri.originalBody)).
|
||||
Str("request_body_sanitized", string(ri.body)).
|
||||
Int("request_body_size", len(ri.body)).
|
||||
Str("request_headers", logging.RedactHeaders(requestHeaders)).
|
||||
Msg("upstream error")
|
||||
|
||||
telemetry.UpstreamErrors.Add(ctx, 1,
|
||||
metric.WithAttributes(
|
||||
attribute.Int("status_code", statusCode),
|
||||
attribute.String("error_type", errorType),
|
||||
attribute.String("credential", ri.cred.Email),
|
||||
))
|
||||
}
|
||||
|
||||
// recordRequestMetrics records the request counter and duration histogram.
|
||||
func recordRequestMetrics(ctx context.Context, ri requestInfo, statusCode int, latencyMs float64) {
|
||||
attrs := []attribute.KeyValue{
|
||||
attribute.String("model", ri.model),
|
||||
attribute.Bool("stream", ri.stream),
|
||||
attribute.Int("status_code", statusCode),
|
||||
}
|
||||
telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...))
|
||||
telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...))
|
||||
}
|
||||
|
||||
// recordTokenUsage records token consumption metrics.
|
||||
func recordTokenUsage(ctx context.Context, model string, cred *auth.Credential, inputTokens, outputTokens int64) {
|
||||
tokenAttrs := metric.WithAttributes(
|
||||
attribute.String("model", model),
|
||||
attribute.String("credential", cred.Email),
|
||||
)
|
||||
telemetry.TokensInput.Add(ctx, inputTokens, tokenAttrs)
|
||||
telemetry.TokensOutput.Add(ctx, outputTokens, tokenAttrs)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
@@ -11,10 +12,10 @@ import (
|
||||
)
|
||||
|
||||
type Sanitizer struct {
|
||||
toolsForward map[string]string
|
||||
toolsReverse map[string]string
|
||||
systemRules []config.ReplaceRule
|
||||
bodyRules []config.ReplaceRule
|
||||
toolsForward map[string]string
|
||||
toolsReverse map[string]string
|
||||
systemRules []config.ReplaceRule
|
||||
bodyRules []config.ReplaceRule
|
||||
}
|
||||
|
||||
func NewSanitizer(cfg config.SanitizeConfig) *Sanitizer {
|
||||
@@ -49,7 +50,11 @@ func (s *Sanitizer) DesanitizeResponse(body []byte) []byte {
|
||||
}
|
||||
name := block.Get("name").String()
|
||||
if orig, ok := s.toolsReverse[name]; ok {
|
||||
body, _ = sjson.SetBytes(body, "content."+strconv.Itoa(i)+".name", orig)
|
||||
if b, err := sjson.SetBytes(body, "content."+strconv.Itoa(i)+".name", orig); err != nil {
|
||||
log.Warn().Err(err).Str("tool", name).Msg("desanitize response: set name failed")
|
||||
} else {
|
||||
body = b
|
||||
}
|
||||
}
|
||||
}
|
||||
return body
|
||||
@@ -64,8 +69,12 @@ func (s *Sanitizer) DesanitizeStreamEvent(line string) string {
|
||||
for _, path := range []string{"content_block.name", "delta.name"} {
|
||||
name := gjson.GetBytes(data, path).String()
|
||||
if orig, ok := s.toolsReverse[name]; ok {
|
||||
data, _ = sjson.SetBytes(data, path, orig)
|
||||
changed = true
|
||||
if b, err := sjson.SetBytes(data, path, orig); err != nil {
|
||||
log.Warn().Err(err).Str("tool", name).Msg("desanitize stream event: set name failed")
|
||||
} else {
|
||||
data = b
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if changed {
|
||||
@@ -85,7 +94,11 @@ func (s *Sanitizer) renameTools(body []byte) []byte {
|
||||
for i, tool := range tools.Array() {
|
||||
name := tool.Get("name").String()
|
||||
if newName, ok := s.toolsForward[name]; ok {
|
||||
body, _ = sjson.SetBytes(body, "tools."+strconv.Itoa(i)+".name", newName)
|
||||
if b, err := sjson.SetBytes(body, "tools."+strconv.Itoa(i)+".name", newName); err != nil {
|
||||
log.Warn().Err(err).Str("tool", name).Msg("rename tool failed")
|
||||
} else {
|
||||
body = b
|
||||
}
|
||||
}
|
||||
}
|
||||
return body
|
||||
@@ -104,7 +117,11 @@ func (s *Sanitizer) replaceSystem(body []byte) []byte {
|
||||
for _, rule := range s.systemRules {
|
||||
text = strings.ReplaceAll(text, rule.Match, rule.Replace)
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "system."+strconv.Itoa(i)+".text", text)
|
||||
if b, err := sjson.SetBytes(body, "system."+strconv.Itoa(i)+".text", text); err != nil {
|
||||
log.Warn().Err(err).Int("block", i).Msg("replace system text failed")
|
||||
} else {
|
||||
body = b
|
||||
}
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
+53
-41
@@ -36,6 +36,21 @@ var skipHeaders = map[string]bool{
|
||||
"connection": true,
|
||||
}
|
||||
|
||||
const fakeJSONResponse = `{"id":"msg_fake","type":"message","role":"assistant","content":[{"type":"text","text":"ok"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`
|
||||
|
||||
const fakeStreamResponse = "event: message_start\n" +
|
||||
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_fake\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}\n\n" +
|
||||
"event: content_block_start\n" +
|
||||
"data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n" +
|
||||
"event: content_block_delta\n" +
|
||||
"data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"ok\"}}\n\n" +
|
||||
"event: content_block_stop\n" +
|
||||
"data: {\"type\":\"content_block_stop\",\"index\":0}\n\n" +
|
||||
"event: message_delta\n" +
|
||||
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":1}}\n\n" +
|
||||
"event: message_stop\n" +
|
||||
"data: {\"type\":\"message_stop\"}\n\n"
|
||||
|
||||
func SniffClaudeCode(claudeBinary string) (*SniffedProfile, error) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
@@ -48,45 +63,7 @@ func SniffClaudeCode(claudeBinary string) (*SniffedProfile, error) {
|
||||
captured := make(chan struct{}, 1)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "HEAD" {
|
||||
w.WriteHeader(200)
|
||||
return
|
||||
}
|
||||
if r.Method != "POST" || !strings.Contains(r.URL.Path, "/v1/messages") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprint(w, `{"id":"msg_fake","type":"message","role":"assistant","content":[{"type":"text","text":"ok"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`)
|
||||
return
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
|
||||
mu.Lock()
|
||||
if profile == nil {
|
||||
profile = extractProfile(r, body)
|
||||
select {
|
||||
case captured <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
if strings.Contains(string(body), `"stream":true`) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprint(w, "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_fake\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}\n\n")
|
||||
fmt.Fprint(w, "event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n")
|
||||
fmt.Fprint(w, "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"ok\"}}\n\n")
|
||||
fmt.Fprint(w, "event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n")
|
||||
fmt.Fprint(w, "event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":1}}\n\n")
|
||||
fmt.Fprint(w, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprint(w, `{"id":"msg_fake","type":"message","role":"assistant","content":[{"type":"text","text":"ok"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`)
|
||||
}
|
||||
})
|
||||
mux.HandleFunc("/", sniffHandler(&mu, &profile, captured))
|
||||
|
||||
srv := &http.Server{Handler: mux}
|
||||
go srv.Serve(listener)
|
||||
@@ -130,8 +107,44 @@ func SniffClaudeCode(claudeBinary string) (*SniffedProfile, error) {
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
func sniffHandler(mu *sync.Mutex, profile **SniffedProfile, captured chan<- struct{}) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "HEAD" {
|
||||
w.WriteHeader(200)
|
||||
return
|
||||
}
|
||||
if r.Method != "POST" || !strings.Contains(r.URL.Path, "/v1/messages") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprint(w, fakeJSONResponse)
|
||||
return
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
|
||||
mu.Lock()
|
||||
if *profile == nil {
|
||||
*profile = extractProfile(r, body)
|
||||
select {
|
||||
case captured <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
if strings.Contains(string(body), `"stream":true`) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprint(w, fakeStreamResponse)
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprint(w, fakeJSONResponse)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func extractProfile(r *http.Request, body []byte) *SniffedProfile {
|
||||
// Capture raw headers preserving original casing.
|
||||
var headers [][2]string
|
||||
for name, vals := range r.Header {
|
||||
if skipHeaders[strings.ToLower(name)] {
|
||||
@@ -142,7 +155,6 @@ func extractProfile(r *http.Request, body []byte) *SniffedProfile {
|
||||
}
|
||||
}
|
||||
|
||||
// Deduplicate and strip subscription-specific betas.
|
||||
seen := map[string]bool{}
|
||||
var deduped [][2]string
|
||||
for _, h := range headers {
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
tls "github.com/refraction-networking/utls"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
type utlsRoundTripper struct {
|
||||
mu sync.Mutex
|
||||
connections map[string]*http2.ClientConn
|
||||
pending map[string]*sync.Cond
|
||||
}
|
||||
|
||||
func newUtlsRoundTripper() *utlsRoundTripper {
|
||||
return &utlsRoundTripper{
|
||||
connections: make(map[string]*http2.ClientConn),
|
||||
pending: make(map[string]*sync.Cond),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.ClientConn, error) {
|
||||
t.mu.Lock()
|
||||
|
||||
if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
|
||||
t.mu.Unlock()
|
||||
return h2Conn, nil
|
||||
}
|
||||
|
||||
if cond, ok := t.pending[host]; ok {
|
||||
cond.Wait()
|
||||
if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
|
||||
t.mu.Unlock()
|
||||
return h2Conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
cond := sync.NewCond(&t.mu)
|
||||
t.pending[host] = cond
|
||||
t.mu.Unlock()
|
||||
|
||||
h2Conn, err := t.createConnection(host, addr)
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
delete(t.pending, host)
|
||||
cond.Broadcast()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.connections[host] = h2Conn
|
||||
return h2Conn, nil
|
||||
}
|
||||
|
||||
func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) {
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{ServerName: host}
|
||||
tlsConn := tls.UClient(conn, tlsConfig, tls.HelloChrome_Auto)
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tr := &http2.Transport{}
|
||||
h2Conn, err := tr.NewClientConn(tlsConn)
|
||||
if err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return h2Conn, nil
|
||||
}
|
||||
|
||||
func (t *utlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
hostname := req.URL.Hostname()
|
||||
port := req.URL.Port()
|
||||
if port == "" {
|
||||
port = "443"
|
||||
}
|
||||
addr := net.JoinHostPort(hostname, port)
|
||||
log.Debug().Str("addr", addr).Msg("uTLS round trip")
|
||||
|
||||
h2Conn, err := t.getOrCreateConnection(hostname, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := h2Conn.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.mu.Lock()
|
||||
if cached, ok := t.connections[hostname]; ok && cached == h2Conn {
|
||||
delete(t.connections, hostname)
|
||||
}
|
||||
t.mu.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/auth"
|
||||
"github.com/fujin/anthropic-proxy/internal/logging"
|
||||
"github.com/fujin/anthropic-proxy/internal/transport"
|
||||
"github.com/fujin/anthropic-proxy/internal/version"
|
||||
)
|
||||
|
||||
const messagesURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||
@@ -27,7 +29,7 @@ func NewUpstreamClient(profile *SniffedProfile) *UpstreamClient {
|
||||
return &UpstreamClient{
|
||||
client: http.Client{
|
||||
Timeout: 0,
|
||||
Transport: newUtlsRoundTripper(),
|
||||
Transport: transport.NewUTLS(),
|
||||
},
|
||||
sessionID: uuid.New().String(),
|
||||
profile: profile,
|
||||
@@ -38,7 +40,7 @@ func (u *UpstreamClient) version() string {
|
||||
if u.profile != nil && u.profile.Version != "" {
|
||||
return u.profile.Version
|
||||
}
|
||||
return "2.1.92"
|
||||
return version.ClaudeCodeFallback
|
||||
}
|
||||
|
||||
// applyHeaders replays sniffed headers, substituting auth + per-request IDs + accept.
|
||||
|
||||
Reference in New Issue
Block a user