package proxy

import (
	"bufio"
	"io"
	"net/http"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/rs/zerolog/log"
	"github.com/tidwall/gjson"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/metric"

	"github.com/fujin/anthropic-proxy/internal/auth"
	"github.com/fujin/anthropic-proxy/internal/logging"
	"github.com/fujin/anthropic-proxy/internal/ratelimit"
	"github.com/fujin/anthropic-proxy/internal/telemetry"
)

// HandleMessages returns a gin handler that reads the request body, passes it
// through the sanitizer returned by getSanitizer, picks a credential from
// pool, and forwards the request upstream.
//
// Requests whose sanitized body has a true "stream" field are relayed line by
// line by handleStream; all others are relayed as one buffered response by
// handleNonStream. A single UpstreamClient built from profile is shared by
// every request the returned handler serves. tracker may be nil, in which case
// no rate-limit bookkeeping is done.
//
// The handler answers 400 when the body cannot be read and 503 when the pool
// cannot supply a credential.
func HandleMessages(pool *auth.Pool, profile *SniffedProfile, getSanitizer func() *Sanitizer, tracker *ratelimit.Tracker) gin.HandlerFunc {
	upstream := NewUpstreamClient(profile)

	return func(c *gin.Context) {
		body, err := io.ReadAll(c.Request.Body)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
			return
		}

		// Keep a private copy of the body as received; it is only used in
		// error logs. The copy guards against SanitizeRequest modifying its
		// argument in place (NOTE(review): confirm whether it does).
		originalBody := make([]byte, len(body))
		copy(originalBody, body)

		log.Info().
			Str("method", c.Request.Method).
			Str("path", c.Request.URL.Path).
			Int("body_size", len(body)).
			Str("model", gjson.GetBytes(body, "model").String()).
			Msg("incoming request")

		// Fetch the sanitizer per request so the same instance is used for
		// both the request and the matching response.
		san := getSanitizer()
		body = san.SanitizeRequest(body)

		cred, err := pool.Pick()
		if err != nil {
			c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
			return
		}

		isStream := gjson.GetBytes(body, "stream").Bool()
		if isStream {
			handleStream(c, upstream, san, pool, cred, body, originalBody, tracker)
		} else {
			handleNonStream(c, upstream, san, pool, cred, body, originalBody, tracker)
		}
	}
}

// handleNonStream forwards body upstream with cred, waits for the complete
// response, and writes it to the client.
//
// Outcomes:
//   - transport error: logs it, records error/request/duration metrics with
//     status 502, and answers 502 with a generic JSON error;
//   - upstream status >= 400: marks the credential as failed, logs the
//     upstream error, and forwards the upstream body as received (it is not
//     desanitized);
//   - otherwise: marks the credential as successful, desanitizes the body,
//     records token usage, and forwards the result.
//
// originalBody is the pre-sanitization request and is used only for logging.
// tracker may be nil.
func handleNonStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte, originalBody []byte, tracker *ratelimit.Tracker) {
	startTime := time.Now()
	model := gjson.GetBytes(body, "model").String()
	ctx := c.Request.Context()

	telemetry.RequestBodySize.Record(ctx, int64(len(body)),
		metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", false)))

	respBody, headers, statusCode, err := upstream.Execute(ctx, cred, body)
	latencyMs := float64(time.Since(startTime).Milliseconds())

	if err != nil {
		log.Error().
			Err(err).
			Str("credential", cred.Email).
			Str("model", model).
			Bool("stream", false).
			Str("request_body_original", string(originalBody)).
			Str("request_body_sanitized", string(body)).
			Int("request_body_size", len(body)).
			Float64("latency_ms", latencyMs).
			Msg("upstream connection error")

		telemetry.UpstreamErrors.Add(ctx, 1, metric.WithAttributes(
			attribute.String("error_type", "connection"),
			attribute.String("credential", cred.Email),
			attribute.Int("status_code", http.StatusBadGateway),
		))
		telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(
			attribute.String("model", model),
			attribute.Bool("stream", false),
			attribute.Int("status_code", http.StatusBadGateway),
		))
		telemetry.RequestDuration.Record(ctx, latencyMs,
			metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", false), attribute.Int("status_code", http.StatusBadGateway)))

		c.JSON(http.StatusBadGateway, gin.H{"error": "upstream request failed"})
		return
	}

	attrs := []attribute.KeyValue{
		attribute.String("model", model),
		attribute.Bool("stream", false),
		attribute.Int("status_code", statusCode),
	}
	telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...))
	telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...))

	if statusCode >= 400 {
		pool.MarkFailure(cred, statusCode)
		telemetry.CredentialCooldowns.Add(ctx, 1, metric.WithAttributes(attribute.Int("status_code", statusCode)))

		errorType := gjson.GetBytes(respBody, "error.type").String()
		errorMessage := gjson.GetBytes(respBody, "error.message").String()

		log.Error().
			Int("status", statusCode).
			Str("error_type", errorType).
			Str("error_message", errorMessage).
			Str("response_body", string(respBody)).
			Str("request_id", headers.Get("X-Request-Id")).
			Float64("latency_ms", latencyMs).
			Str("credential", cred.Email).
			Str("model", model).
			Bool("stream", false).
			Str("request_body_original", string(originalBody)).
			Str("request_body_sanitized", string(body)).
			Int("request_body_size", len(body)).
			Str("request_headers", logging.RedactHeaders(c.Request.Header)).
			Msg("upstream error")

		telemetry.UpstreamErrors.Add(ctx, 1, metric.WithAttributes(
			attribute.Int("status_code", statusCode),
			attribute.String("error_type", errorType),
			attribute.String("credential", cred.Email),
		))
	} else {
		pool.MarkSuccess(cred)
		respBody = san.DesanitizeResponse(respBody)

		inputTokens := gjson.GetBytes(respBody, "usage.input_tokens").Int()
		outputTokens := gjson.GetBytes(respBody, "usage.output_tokens").Int()

		tokenAttrs := metric.WithAttributes(
			attribute.String("model", model),
			attribute.String("credential", cred.Email),
		)
		telemetry.TokensInput.Add(ctx, inputTokens, tokenAttrs)
		telemetry.TokensOutput.Add(ctx, outputTokens, tokenAttrs)

		if tracker != nil {
			tracker.RecordTokens(inputTokens, outputTokens)
			tracker.UpdateFromHeaders(headers)
		}

		log.Info().
			Int("status", statusCode).
			Float64("latency_ms", latencyMs).
			Str("model", model).
			Int64("input_tokens", inputTokens).
			Int64("output_tokens", outputTokens).
			Msg("request completed")
	}

	// Only these two upstream headers are passed through to the client.
	for _, h := range []string{"Content-Type", "X-Request-Id"} {
		if v := headers.Get(h); v != "" {
			c.Header(h, v)
		}
	}
	c.Data(statusCode, headers.Get("Content-Type"), respBody)
}

// streamUsageFromLine inspects one server-sent-event line and returns the
// running token counts, updated with any usage figures the line carries.
//
// Lines that do not start with "data:" leave the counts untouched. A
// "message_start" event supplies usage under "message.usage"; a
// "message_delta" event supplies it under "usage". A count is only replaced
// when the corresponding field is present, so an event that omits
// input_tokens does not reset a previously seen value to zero.
func streamUsageFromLine(line string, inputTokens, outputTokens int64) (int64, int64) {
	if len(line) <= 5 || line[:5] != "data:" {
		return inputTokens, outputTokens
	}
	data := line[5:]

	var usagePath string
	switch gjson.Get(data, "type").String() {
	case "message_start":
		usagePath = "message.usage"
	case "message_delta":
		usagePath = "usage"
	default:
		return inputTokens, outputTokens
	}

	if v := gjson.Get(data, usagePath+".input_tokens"); v.Exists() {
		inputTokens = v.Int()
	}
	if v := gjson.Get(data, usagePath+".output_tokens"); v.Exists() {
		outputTokens = v.Int()
	}
	return inputTokens, outputTokens
}

// handleStream forwards body upstream with cred and relays the response to
// the client one line at a time as a server-sent-event stream, desanitizing
// each line and flushing after every write.
//
// Outcomes:
//   - transport error: logs it, records error/request/duration metrics with
//     status 502, and answers 502 with a generic JSON error;
//   - upstream status >= 400: marks the credential as failed, reads the whole
//     error body, logs it, and forwards it with the upstream status;
//   - otherwise: marks the credential as successful and streams until the
//     upstream body ends, a scan error occurs, or a write to the client
//     fails. Request metrics for this path are always recorded with status
//     200, since that status has already been sent.
//
// originalBody is the pre-sanitization request and is used only for logging.
// tracker may be nil.
func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte, originalBody []byte, tracker *ratelimit.Tracker) {
	startTime := time.Now()
	model := gjson.GetBytes(body, "model").String()
	ctx := c.Request.Context()

	telemetry.StreamRequests.Add(ctx, 1, metric.WithAttributes(attribute.String("model", model)))
	telemetry.RequestBodySize.Record(ctx, int64(len(body)),
		metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", true)))

	resp, err := upstream.ExecuteStream(ctx, cred, body)
	if err != nil {
		latencyMs := float64(time.Since(startTime).Milliseconds())

		log.Error().
			Err(err).
			Str("credential", cred.Email).
			Str("model", model).
			Bool("stream", true).
			Str("request_body_original", string(originalBody)).
			Str("request_body_sanitized", string(body)).
			Int("request_body_size", len(body)).
			Float64("latency_ms", latencyMs).
			Msg("upstream connection error")

		telemetry.UpstreamErrors.Add(ctx, 1, metric.WithAttributes(
			attribute.String("error_type", "connection"),
			attribute.String("credential", cred.Email),
			attribute.Int("status_code", http.StatusBadGateway),
		))
		telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(
			attribute.String("model", model),
			attribute.Bool("stream", true),
			attribute.Int("status_code", http.StatusBadGateway),
		))
		telemetry.RequestDuration.Record(ctx, latencyMs,
			metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", true), attribute.Int("status_code", http.StatusBadGateway)))

		c.JSON(http.StatusBadGateway, gin.H{"error": "upstream stream request failed"})
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		pool.MarkFailure(cred, resp.StatusCode)
		telemetry.CredentialCooldowns.Add(ctx, 1, metric.WithAttributes(attribute.Int("status_code", resp.StatusCode)))

		// A partial body is still worth logging and forwarding, so a read
		// failure is reported but does not abort the error handling.
		respBody, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			log.Warn().
				Err(readErr).
				Int("status", resp.StatusCode).
				Msg("failed to read upstream error body")
		}
		latencyMs := float64(time.Since(startTime).Milliseconds())

		errorType := gjson.GetBytes(respBody, "error.type").String()
		errorMessage := gjson.GetBytes(respBody, "error.message").String()

		log.Error().
			Int("status", resp.StatusCode).
			Str("error_type", errorType).
			Str("error_message", errorMessage).
			Str("response_body", string(respBody)).
			Str("request_id", resp.Header.Get("X-Request-Id")).
			Float64("latency_ms", latencyMs).
			Str("credential", cred.Email).
			Str("model", model).
			Bool("stream", true).
			Str("request_body_original", string(originalBody)).
			Str("request_body_sanitized", string(body)).
			Int("request_body_size", len(body)).
			Str("request_headers", logging.RedactHeaders(c.Request.Header)).
			Msg("upstream error")

		attrs := []attribute.KeyValue{
			attribute.String("model", model),
			attribute.Bool("stream", true),
			attribute.Int("status_code", resp.StatusCode),
		}
		telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...))
		telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...))
		telemetry.UpstreamErrors.Add(ctx, 1, metric.WithAttributes(
			attribute.Int("status_code", resp.StatusCode),
			attribute.String("error_type", errorType),
			attribute.String("credential", cred.Email),
		))

		c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), respBody)
		return
	}

	pool.MarkSuccess(cred)

	// Verify flushing support before any SSE header or status is set, so the
	// JSON error below is not sent with an event-stream content type.
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		log.Error().Msg("response writer does not support flushing")
		c.JSON(http.StatusInternalServerError, gin.H{"error": "streaming not supported"})
		return
	}

	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Status(http.StatusOK)

	var inputTokens, outputTokens int64

	scanner := bufio.NewScanner(resp.Body)
	// Allow lines up to 1 MiB; the scanner's default limit is 64 KiB.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)

	for scanner.Scan() {
		line := san.DesanitizeStreamEvent(scanner.Text())

		// Stop relaying once the client can no longer be written to; the
		// deferred Close then releases the upstream response.
		if _, writeErr := c.Writer.WriteString(line + "\n"); writeErr != nil {
			log.Warn().Err(writeErr).Msg("client write failed, aborting stream")
			break
		}
		flusher.Flush()

		inputTokens, outputTokens = streamUsageFromLine(line, inputTokens, outputTokens)
	}
	if err := scanner.Err(); err != nil {
		log.Error().Err(err).Msg("stream scan error")
	}

	latencyMs := float64(time.Since(startTime).Milliseconds())

	attrs := []attribute.KeyValue{
		attribute.String("model", model),
		attribute.Bool("stream", true),
		attribute.Int("status_code", http.StatusOK),
	}
	telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...))
	telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...))

	hasUsage := inputTokens > 0 || outputTokens > 0
	if hasUsage {
		tokenAttrs := metric.WithAttributes(
			attribute.String("model", model),
			attribute.String("credential", cred.Email),
		)
		telemetry.TokensInput.Add(ctx, inputTokens, tokenAttrs)
		telemetry.TokensOutput.Add(ctx, outputTokens, tokenAttrs)
	}

	if tracker != nil {
		if hasUsage {
			tracker.RecordTokens(inputTokens, outputTokens)
		}
		// The response headers are available whether or not usage events
		// were seen, so the tracker is always updated, as in handleNonStream.
		tracker.UpdateFromHeaders(resp.Header)
	}

	log.Info().
		Float64("latency_ms", latencyMs).
		Str("model", model).
		Bool("stream", true).
		Int64("input_tokens", inputTokens).
		Int64("output_tokens", outputTokens).
		Msg("stream completed")
}