package proxy

import (
	"bufio"
	"context"
	"io"
	"net/http"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/rs/zerolog/log"
	"github.com/tidwall/gjson"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/metric"

	"github.com/fujin/anthropic-proxy/internal/auth"
	"github.com/fujin/anthropic-proxy/internal/logging"
	"github.com/fujin/anthropic-proxy/internal/ratelimit"
	"github.com/fujin/anthropic-proxy/internal/telemetry"
)

// requestInfo bundles common request context passed to logging/telemetry helpers.
type requestInfo struct {
	model        string           // "model" field read from the sanitized request body
	stream       bool             // true when the request is served by handleStream
	cred         *auth.Credential // credential picked from the pool for this request
	body         []byte           // request body after SanitizeRequest
	originalBody []byte           // copy of the request body taken before sanitizing
}

// HandleMessages builds the gin handler that proxies a messages request upstream.
//
// For each request the handler reads the whole body, keeps a copy of it,
// sanitizes it with the Sanitizer returned by getSanitizer, picks a credential
// from pool and then dispatches to handleStream or handleNonStream depending on
// the boolean "stream" field of the sanitized body.
//
// It answers 400 when the body cannot be read and 503 (with the pool's error
// text) when no credential can be picked. tracker may be nil; both handlers
// check it before use.
func HandleMessages(pool *auth.Pool, profile *SniffedProfile, getSanitizer func() *Sanitizer, tracker *ratelimit.Tracker) gin.HandlerFunc {
	// One upstream client is created per handler and shared by every request.
	upstream := NewUpstreamClient(profile)

	return func(c *gin.Context) {
		body, err := io.ReadAll(c.Request.Body)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
			return
		}

		// Keep an independent copy so error logs can show the body as it was
		// received, regardless of how SanitizeRequest treats its argument.
		originalBody := make([]byte, len(body))
		copy(originalBody, body)

		log.Info().
			Str("method", c.Request.Method).
			Str("path", c.Request.URL.Path).
			Int("body_size", len(body)).
			Str("model", gjson.GetBytes(body, "model").String()).
			Msg("incoming request")

		// The same sanitizer instance is used for the request and, later, for
		// desanitizing the response.
		san := getSanitizer()
		body = san.SanitizeRequest(body)

		cred, err := pool.Pick()
		if err != nil {
			c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
			return
		}

		isStream := gjson.GetBytes(body, "stream").Bool()
		if isStream {
			handleStream(c, upstream, san, pool, cred, body, originalBody, tracker)
		} else {
			handleNonStream(c, upstream, san, pool, cred, body, originalBody, tracker)
		}
	}
}

// handleNonStream executes a buffered upstream request and relays the result.
//
// A transport error is recorded via recordConnectionError and answered with
// 502. Otherwise request metrics are recorded and:
//   - status >= 400: the failure is reported to the pool, the cooldown counter
//     is incremented and the upstream error is logged; the body is relayed
//     without desanitizing;
//   - status < 400: success is reported to the pool, the body is desanitized,
//     token usage is recorded and tracker (if non-nil) is updated from the
//     response headers.
//
// In both cases Content-Type and X-Request-Id are forwarded when present and
// the upstream status code is used for the reply.
func handleNonStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte, originalBody []byte, tracker *ratelimit.Tracker) {
	startTime := time.Now()
	model := gjson.GetBytes(body, "model").String()
	ctx := c.Request.Context()
	ri := requestInfo{model: model, stream: false, cred: cred, body: body, originalBody: originalBody}

	telemetry.RequestBodySize.Record(ctx, int64(len(body)),
		metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", false)))

	respBody, headers, statusCode, err := upstream.Execute(ctx, cred, body)
	latencyMs := float64(time.Since(startTime).Milliseconds())
	if err != nil {
		// recordConnectionError also records the request metrics (as a 502).
		recordConnectionError(ctx, err, ri, latencyMs)
		c.JSON(http.StatusBadGateway, gin.H{"error": "upstream request failed"})
		return
	}

	recordRequestMetrics(ctx, ri, statusCode, latencyMs)

	if statusCode >= 400 {
		pool.MarkFailure(cred, statusCode)
		telemetry.CredentialCooldowns.Add(ctx, 1,
			metric.WithAttributes(attribute.Int("status_code", statusCode)))
		recordUpstreamError(ctx, statusCode, respBody, headers.Get("X-Request-Id"), latencyMs, ri, c.Request.Header)
	} else {
		pool.MarkSuccess(cred)
		respBody = san.DesanitizeResponse(respBody)

		inputTokens := gjson.GetBytes(respBody, "usage.input_tokens").Int()
		outputTokens := gjson.GetBytes(respBody, "usage.output_tokens").Int()
		recordTokenUsage(ctx, model, cred, inputTokens, outputTokens)

		if tracker != nil {
			tracker.UpdateFromHeaders(headers)
		}

		log.Info().
			Int("status", statusCode).
			Float64("latency_ms", latencyMs).
			Str("model", model).
			Int64("input_tokens", inputTokens).
			Int64("output_tokens", outputTokens).
			Msg("request completed")
	}

	for _, h := range []string{"Content-Type", "X-Request-Id"} {
		if v := headers.Get(h); v != "" {
			c.Header(h, v)
		}
	}
	c.Data(statusCode, headers.Get("Content-Type"), respBody)
}

// handleStream executes a streaming upstream request and relays it line by
// line to the client as a text/event-stream response.
//
// A transport error is recorded via recordConnectionError and answered with
// 502. An upstream status >= 400 is reported to the pool, logged, and relayed
// to the client with the upstream status, Content-Type and X-Request-Id.
//
// On success every scanned line is passed through DesanitizeStreamEvent,
// written followed by "\n" and flushed. Lines starting with "data:" are
// inspected for "message_start" / "message_delta" events to pick up the input
// and output token counts. Relaying stops at the end of the upstream body, on
// a scanner error, or on the first failed write to the client.
//
// Request metrics are recorded with status 200 once relaying ends. Token usage
// and the tracker update happen only when a non-zero token count was seen.
func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte, originalBody []byte, tracker *ratelimit.Tracker) {
	startTime := time.Now()
	model := gjson.GetBytes(body, "model").String()
	ctx := c.Request.Context()
	ri := requestInfo{model: model, stream: true, cred: cred, body: body, originalBody: originalBody}

	telemetry.StreamRequests.Add(ctx, 1, metric.WithAttributes(attribute.String("model", model)))
	telemetry.RequestBodySize.Record(ctx, int64(len(body)),
		metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", true)))

	resp, err := upstream.ExecuteStream(ctx, cred, body)
	if err != nil {
		latencyMs := float64(time.Since(startTime).Milliseconds())
		// recordConnectionError also records the request metrics (as a 502).
		recordConnectionError(ctx, err, ri, latencyMs)
		c.JSON(http.StatusBadGateway, gin.H{"error": "upstream stream request failed"})
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		pool.MarkFailure(cred, resp.StatusCode)
		telemetry.CredentialCooldowns.Add(ctx, 1,
			metric.WithAttributes(attribute.Int("status_code", resp.StatusCode)))

		// A failed read still leaves whatever was received in respBody, so the
		// error is logged and the (possibly partial) body is relayed anyway.
		respBody, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			log.Warn().
				Err(readErr).
				Int("status", resp.StatusCode).
				Msg("failed to read upstream error body")
		}

		latencyMs := float64(time.Since(startTime).Milliseconds())
		recordRequestMetrics(ctx, ri, resp.StatusCode, latencyMs)
		recordUpstreamError(ctx, resp.StatusCode, respBody, resp.Header.Get("X-Request-Id"), latencyMs, ri, c.Request.Header)

		// Forward the upstream request id, as handleNonStream does.
		if v := resp.Header.Get("X-Request-Id"); v != "" {
			c.Header("X-Request-Id", v)
		}
		c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), respBody)
		return
	}

	pool.MarkSuccess(cred)

	// Check for flushing support before staging the SSE headers and status:
	// the JSON error reply must not go out labelled as text/event-stream.
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		log.Error().Msg("response writer does not support flushing")
		c.JSON(http.StatusInternalServerError, gin.H{"error": "streaming not supported"})
		return
	}

	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Status(http.StatusOK)

	var inputTokens, outputTokens int64

	scanner := bufio.NewScanner(resp.Body)
	// Start with a 64 KiB buffer and allow single lines of up to 1 MiB.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)

	for scanner.Scan() {
		line := san.DesanitizeStreamEvent(scanner.Text())

		// Once a write fails there is nobody left to stream to; stop relaying
		// instead of desanitizing and writing the rest of the upstream body.
		if _, writeErr := c.Writer.WriteString(line + "\n"); writeErr != nil {
			log.Warn().
				Err(writeErr).
				Str("model", model).
				Msg("stream write to client failed")
			break
		}
		flusher.Flush()

		if len(line) > 5 && line[:5] == "data:" {
			data := line[5:]
			eventType := gjson.Get(data, "type").String()
			switch eventType {
			case "message_start":
				inputTokens = gjson.Get(data, "message.usage.input_tokens").Int()
			case "message_delta":
				// Each message_delta overwrites the previous value.
				outputTokens = gjson.Get(data, "usage.output_tokens").Int()
			}
		}
	}

	latencyMs := float64(time.Since(startTime).Milliseconds())
	recordRequestMetrics(ctx, ri, http.StatusOK, latencyMs)

	if inputTokens > 0 || outputTokens > 0 {
		recordTokenUsage(ctx, model, cred, inputTokens, outputTokens)
		if tracker != nil {
			tracker.UpdateFromHeaders(resp.Header)
		}
	}

	log.Info().
		Float64("latency_ms", latencyMs).
		Str("model", model).
		Bool("stream", true).
		Int64("input_tokens", inputTokens).
		Int64("output_tokens", outputTokens).
		Msg("stream completed")

	if err := scanner.Err(); err != nil {
		log.Error().Err(err).Msg("stream scan error")
	}
}

// recordConnectionError logs and records metrics for upstream connection failures.
//
// It writes an error-level log entry that includes both the original and the
// sanitized request body, increments telemetry.UpstreamErrors with
// error_type "connection" and status 502, and records the request metrics
// with status 502, so callers must not record them again.
func recordConnectionError(ctx context.Context, err error, ri requestInfo, latencyMs float64) {
	log.Error().
		Err(err).
		Str("credential", ri.cred.Email).
		Str("model", ri.model).
		Bool("stream", ri.stream).
		Str("request_body_original", string(ri.originalBody)).
		Str("request_body_sanitized", string(ri.body)).
		Int("request_body_size", len(ri.body)).
		Float64("latency_ms", latencyMs).
		Msg("upstream connection error")

	telemetry.UpstreamErrors.Add(ctx, 1, metric.WithAttributes(
		attribute.String("error_type", "connection"),
		attribute.String("credential", ri.cred.Email),
		attribute.Int("status_code", http.StatusBadGateway),
	))

	recordRequestMetrics(ctx, ri, http.StatusBadGateway, latencyMs)
}

// recordUpstreamError logs and records metrics for upstream HTTP error responses.
//
// The error type and message are read from the "error.type" and
// "error.message" fields of respBody. The log entry also carries the raw
// response body, both request bodies and the request headers as rendered by
// logging.RedactHeaders. Unlike recordConnectionError it does not record the
// request metrics; callers do that themselves.
func recordUpstreamError(ctx context.Context, statusCode int, respBody []byte, requestID string, latencyMs float64, ri requestInfo, requestHeaders http.Header) {
	errorType := gjson.GetBytes(respBody, "error.type").String()
	errorMessage := gjson.GetBytes(respBody, "error.message").String()

	log.Error().
		Int("status", statusCode).
		Str("error_type", errorType).
		Str("error_message", errorMessage).
		Str("response_body", string(respBody)).
		Str("request_id", requestID).
		Float64("latency_ms", latencyMs).
		Str("credential", ri.cred.Email).
		Str("model", ri.model).
		Bool("stream", ri.stream).
		Str("request_body_original", string(ri.originalBody)).
		Str("request_body_sanitized", string(ri.body)).
		Int("request_body_size", len(ri.body)).
		Str("request_headers", logging.RedactHeaders(requestHeaders)).
		Msg("upstream error")

	telemetry.UpstreamErrors.Add(ctx, 1, metric.WithAttributes(
		attribute.Int("status_code", statusCode),
		attribute.String("error_type", errorType),
		attribute.String("credential", ri.cred.Email),
	))
}

// recordRequestMetrics records the request counter and duration histogram.
//
// Both instruments receive the same attribute set: model, stream and
// status_code. latencyMs is recorded as given.
func recordRequestMetrics(ctx context.Context, ri requestInfo, statusCode int, latencyMs float64) {
	attrs := []attribute.KeyValue{
		attribute.String("model", ri.model),
		attribute.Bool("stream", ri.stream),
		attribute.Int("status_code", statusCode),
	}
	telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...))
	telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...))
}

// recordTokenUsage records token consumption metrics.
func recordTokenUsage(ctx context.Context, model string, cred *auth.Credential, inputTokens, outputTokens int64) { tokenAttrs := metric.WithAttributes( attribute.String("model", model), attribute.String("credential", cred.Email), ) telemetry.TokensInput.Add(ctx, inputTokens, tokenAttrs) telemetry.TokensOutput.Add(ctx, outputTokens, tokenAttrs) }