0df28e9dd8
- Unify duplicate uTLS transports into shared internal/transport package - Extract shared version constant into internal/version - Move LoadDefaultCredentials from config to auth (remove config→auth import) - Deduplicate handler.go: extract telemetry/error helpers (324→268 lines) - Break up main.go::run() into initCredential/initEmbedded - Eliminate logging.Config duplication (use config.LoggingConfig directly) - Extract logWriter to embedded/log.go, SSE fixtures to consts in sniff.go - Use uTLS client for usage polling (consistent TLS fingerprint) - Handle sjson.SetBytes errors in sanitize.go instead of silently swallowing - Document reverse-engineered magic values in billing.go - Unexport Credential.CooldownUntil (internal state) - Replace hardcoded auth bypass paths with map in server.go
283 lines
9.3 KiB
Go
283 lines
9.3 KiB
Go
package proxy
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/rs/zerolog/log"
|
|
"github.com/tidwall/gjson"
|
|
"go.opentelemetry.io/otel/attribute"
|
|
"go.opentelemetry.io/otel/metric"
|
|
|
|
"github.com/fujin/anthropic-proxy/internal/auth"
|
|
"github.com/fujin/anthropic-proxy/internal/logging"
|
|
"github.com/fujin/anthropic-proxy/internal/ratelimit"
|
|
"github.com/fujin/anthropic-proxy/internal/telemetry"
|
|
)
|
|
|
|
// requestInfo bundles common request context passed to logging/telemetry helpers.
|
|
type requestInfo struct {
|
|
model string
|
|
stream bool
|
|
cred *auth.Credential
|
|
body []byte
|
|
originalBody []byte
|
|
}
|
|
|
|
func HandleMessages(pool *auth.Pool, profile *SniffedProfile, getSanitizer func() *Sanitizer, tracker *ratelimit.Tracker) gin.HandlerFunc {
|
|
upstream := NewUpstreamClient(profile)
|
|
|
|
return func(c *gin.Context) {
|
|
body, err := io.ReadAll(c.Request.Body)
|
|
if err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
|
|
return
|
|
}
|
|
|
|
originalBody := make([]byte, len(body))
|
|
copy(originalBody, body)
|
|
|
|
log.Info().
|
|
Str("method", c.Request.Method).
|
|
Str("path", c.Request.URL.Path).
|
|
Int("body_size", len(body)).
|
|
Str("model", gjson.GetBytes(body, "model").String()).
|
|
Msg("incoming request")
|
|
|
|
san := getSanitizer()
|
|
body = san.SanitizeRequest(body)
|
|
|
|
cred, err := pool.Pick()
|
|
if err != nil {
|
|
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
isStream := gjson.GetBytes(body, "stream").Bool()
|
|
|
|
if isStream {
|
|
handleStream(c, upstream, san, pool, cred, body, originalBody, tracker)
|
|
} else {
|
|
handleNonStream(c, upstream, san, pool, cred, body, originalBody, tracker)
|
|
}
|
|
}
|
|
}
|
|
|
|
func handleNonStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte, originalBody []byte, tracker *ratelimit.Tracker) {
|
|
startTime := time.Now()
|
|
model := gjson.GetBytes(body, "model").String()
|
|
ctx := c.Request.Context()
|
|
ri := requestInfo{model: model, stream: false, cred: cred, body: body, originalBody: originalBody}
|
|
|
|
telemetry.RequestBodySize.Record(ctx, int64(len(body)),
|
|
metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", false)))
|
|
|
|
respBody, headers, statusCode, err := upstream.Execute(ctx, cred, body)
|
|
latencyMs := float64(time.Since(startTime).Milliseconds())
|
|
|
|
if err != nil {
|
|
recordConnectionError(ctx, err, ri, latencyMs)
|
|
c.JSON(http.StatusBadGateway, gin.H{"error": "upstream request failed"})
|
|
return
|
|
}
|
|
|
|
recordRequestMetrics(ctx, ri, statusCode, latencyMs)
|
|
|
|
if statusCode >= 400 {
|
|
pool.MarkFailure(cred, statusCode)
|
|
telemetry.CredentialCooldowns.Add(ctx, 1,
|
|
metric.WithAttributes(attribute.Int("status_code", statusCode)))
|
|
recordUpstreamError(ctx, statusCode, respBody, headers.Get("X-Request-Id"), latencyMs, ri, c.Request.Header)
|
|
} else {
|
|
pool.MarkSuccess(cred)
|
|
respBody = san.DesanitizeResponse(respBody)
|
|
|
|
inputTokens := gjson.GetBytes(respBody, "usage.input_tokens").Int()
|
|
outputTokens := gjson.GetBytes(respBody, "usage.output_tokens").Int()
|
|
recordTokenUsage(ctx, model, cred, inputTokens, outputTokens)
|
|
if tracker != nil {
|
|
tracker.UpdateFromHeaders(headers)
|
|
}
|
|
|
|
log.Info().
|
|
Int("status", statusCode).
|
|
Float64("latency_ms", latencyMs).
|
|
Str("model", model).
|
|
Int64("input_tokens", inputTokens).
|
|
Int64("output_tokens", outputTokens).
|
|
Msg("request completed")
|
|
}
|
|
|
|
for _, h := range []string{"Content-Type", "X-Request-Id"} {
|
|
if v := headers.Get(h); v != "" {
|
|
c.Header(h, v)
|
|
}
|
|
}
|
|
|
|
c.Data(statusCode, headers.Get("Content-Type"), respBody)
|
|
}
|
|
|
|
func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte, originalBody []byte, tracker *ratelimit.Tracker) {
|
|
startTime := time.Now()
|
|
model := gjson.GetBytes(body, "model").String()
|
|
ctx := c.Request.Context()
|
|
ri := requestInfo{model: model, stream: true, cred: cred, body: body, originalBody: originalBody}
|
|
|
|
telemetry.StreamRequests.Add(ctx, 1, metric.WithAttributes(attribute.String("model", model)))
|
|
telemetry.RequestBodySize.Record(ctx, int64(len(body)),
|
|
metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", true)))
|
|
|
|
resp, err := upstream.ExecuteStream(ctx, cred, body)
|
|
if err != nil {
|
|
latencyMs := float64(time.Since(startTime).Milliseconds())
|
|
recordConnectionError(ctx, err, ri, latencyMs)
|
|
c.JSON(http.StatusBadGateway, gin.H{"error": "upstream stream request failed"})
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode >= 400 {
|
|
pool.MarkFailure(cred, resp.StatusCode)
|
|
telemetry.CredentialCooldowns.Add(ctx, 1,
|
|
metric.WithAttributes(attribute.Int("status_code", resp.StatusCode)))
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
latencyMs := float64(time.Since(startTime).Milliseconds())
|
|
recordRequestMetrics(ctx, ri, resp.StatusCode, latencyMs)
|
|
recordUpstreamError(ctx, resp.StatusCode, respBody, resp.Header.Get("X-Request-Id"), latencyMs, ri, c.Request.Header)
|
|
|
|
c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), respBody)
|
|
return
|
|
}
|
|
|
|
pool.MarkSuccess(cred)
|
|
|
|
c.Header("Content-Type", "text/event-stream")
|
|
c.Header("Cache-Control", "no-cache")
|
|
c.Header("Connection", "keep-alive")
|
|
c.Status(http.StatusOK)
|
|
|
|
flusher, ok := c.Writer.(http.Flusher)
|
|
if !ok {
|
|
log.Error().Msg("response writer does not support flushing")
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "streaming not supported"})
|
|
return
|
|
}
|
|
|
|
var inputTokens, outputTokens int64
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
|
for scanner.Scan() {
|
|
line := san.DesanitizeStreamEvent(scanner.Text())
|
|
c.Writer.WriteString(line + "\n")
|
|
flusher.Flush()
|
|
|
|
if len(line) > 5 && line[:5] == "data:" {
|
|
data := line[5:]
|
|
eventType := gjson.Get(data, "type").String()
|
|
switch eventType {
|
|
case "message_start":
|
|
inputTokens = gjson.Get(data, "message.usage.input_tokens").Int()
|
|
case "message_delta":
|
|
outputTokens = gjson.Get(data, "usage.output_tokens").Int()
|
|
}
|
|
}
|
|
}
|
|
|
|
latencyMs := float64(time.Since(startTime).Milliseconds())
|
|
recordRequestMetrics(ctx, ri, http.StatusOK, latencyMs)
|
|
|
|
if inputTokens > 0 || outputTokens > 0 {
|
|
recordTokenUsage(ctx, model, cred, inputTokens, outputTokens)
|
|
if tracker != nil {
|
|
tracker.UpdateFromHeaders(resp.Header)
|
|
}
|
|
}
|
|
|
|
log.Info().
|
|
Float64("latency_ms", latencyMs).
|
|
Str("model", model).
|
|
Bool("stream", true).
|
|
Int64("input_tokens", inputTokens).
|
|
Int64("output_tokens", outputTokens).
|
|
Msg("stream completed")
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
log.Error().Err(err).Msg("stream scan error")
|
|
}
|
|
}
|
|
|
|
// recordConnectionError logs and records metrics for upstream connection failures.
|
|
func recordConnectionError(ctx context.Context, err error, ri requestInfo, latencyMs float64) {
|
|
log.Error().
|
|
Err(err).
|
|
Str("credential", ri.cred.Email).
|
|
Str("model", ri.model).
|
|
Bool("stream", ri.stream).
|
|
Str("request_body_original", string(ri.originalBody)).
|
|
Str("request_body_sanitized", string(ri.body)).
|
|
Int("request_body_size", len(ri.body)).
|
|
Float64("latency_ms", latencyMs).
|
|
Msg("upstream connection error")
|
|
|
|
telemetry.UpstreamErrors.Add(ctx, 1,
|
|
metric.WithAttributes(
|
|
attribute.String("error_type", "connection"),
|
|
attribute.String("credential", ri.cred.Email),
|
|
attribute.Int("status_code", http.StatusBadGateway),
|
|
))
|
|
recordRequestMetrics(ctx, ri, http.StatusBadGateway, latencyMs)
|
|
}
|
|
|
|
// recordUpstreamError logs and records metrics for upstream HTTP error responses.
|
|
func recordUpstreamError(ctx context.Context, statusCode int, respBody []byte, requestID string, latencyMs float64, ri requestInfo, requestHeaders http.Header) {
|
|
errorType := gjson.GetBytes(respBody, "error.type").String()
|
|
errorMessage := gjson.GetBytes(respBody, "error.message").String()
|
|
log.Error().
|
|
Int("status", statusCode).
|
|
Str("error_type", errorType).
|
|
Str("error_message", errorMessage).
|
|
Str("response_body", string(respBody)).
|
|
Str("request_id", requestID).
|
|
Float64("latency_ms", latencyMs).
|
|
Str("credential", ri.cred.Email).
|
|
Str("model", ri.model).
|
|
Bool("stream", ri.stream).
|
|
Str("request_body_original", string(ri.originalBody)).
|
|
Str("request_body_sanitized", string(ri.body)).
|
|
Int("request_body_size", len(ri.body)).
|
|
Str("request_headers", logging.RedactHeaders(requestHeaders)).
|
|
Msg("upstream error")
|
|
|
|
telemetry.UpstreamErrors.Add(ctx, 1,
|
|
metric.WithAttributes(
|
|
attribute.Int("status_code", statusCode),
|
|
attribute.String("error_type", errorType),
|
|
attribute.String("credential", ri.cred.Email),
|
|
))
|
|
}
|
|
|
|
// recordRequestMetrics records the request counter and duration histogram.
|
|
func recordRequestMetrics(ctx context.Context, ri requestInfo, statusCode int, latencyMs float64) {
|
|
attrs := []attribute.KeyValue{
|
|
attribute.String("model", ri.model),
|
|
attribute.Bool("stream", ri.stream),
|
|
attribute.Int("status_code", statusCode),
|
|
}
|
|
telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...))
|
|
telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...))
|
|
}
|
|
|
|
// recordTokenUsage records token consumption metrics.
|
|
func recordTokenUsage(ctx context.Context, model string, cred *auth.Credential, inputTokens, outputTokens int64) {
|
|
tokenAttrs := metric.WithAttributes(
|
|
attribute.String("model", model),
|
|
attribute.String("credential", cred.Email),
|
|
)
|
|
telemetry.TokensInput.Add(ctx, inputTokens, tokenAttrs)
|
|
telemetry.TokensOutput.Add(ctx, outputTokens, tokenAttrs)
|
|
}
|