Files
anthropic-proxy/internal/proxy/handler.go
T
2026-04-14 13:50:34 +02:00

325 lines
11 KiB
Go

package proxy
import (
"bufio"
"io"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"github.com/tidwall/gjson"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
"github.com/fujin/anthropic-proxy/internal/auth"
"github.com/fujin/anthropic-proxy/internal/logging"
"github.com/fujin/anthropic-proxy/internal/ratelimit"
"github.com/fujin/anthropic-proxy/internal/telemetry"
)
// HandleMessages returns the gin handler for message requests.
//
// A single upstream client is created from profile when the handler is built
// and reused for every request. For each request the handler:
//  1. reads the whole request body (400 if that fails),
//  2. keeps an untouched copy of it for diagnostics,
//  3. runs the body through the sanitizer returned by getSanitizer,
//  4. picks a credential from pool (503 with the pool's error if none),
//  5. dispatches to handleStream or handleNonStream depending on the boolean
//     "stream" field of the sanitized body.
//
// tracker is passed through to the dispatch targets unchanged.
func HandleMessages(pool *auth.Pool, profile *SniffedProfile, getSanitizer func() *Sanitizer, tracker *ratelimit.Tracker) gin.HandlerFunc {
	client := NewUpstreamClient(profile)

	return func(c *gin.Context) {
		raw, readErr := io.ReadAll(c.Request.Body)
		if readErr != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
			return
		}

		// Snapshot the body before sanitizing so error logs can show what the
		// client actually sent, independent of whatever SanitizeRequest does
		// with its argument.
		pristine := append(make([]byte, 0, len(raw)), raw...)

		req := c.Request
		log.Info().
			Str("method", req.Method).
			Str("path", req.URL.Path).
			Int("body_size", len(raw)).
			Str("model", gjson.GetBytes(raw, "model").String()).
			Msg("incoming request")

		// The sanitizer is fetched per request via the getter rather than
		// captured once when the handler is built.
		sanitizer := getSanitizer()
		sanitized := sanitizer.SanitizeRequest(raw)

		cred, pickErr := pool.Pick()
		if pickErr != nil {
			c.JSON(http.StatusServiceUnavailable, gin.H{"error": pickErr.Error()})
			return
		}

		if gjson.GetBytes(sanitized, "stream").Bool() {
			handleStream(c, client, sanitizer, pool, cred, sanitized, pristine, tracker)
			return
		}
		handleNonStream(c, client, sanitizer, pool, cred, sanitized, pristine, tracker)
	}
}
// handleNonStream sends a buffered (non-streaming) request upstream and relays
// the complete response to the client.
//
// body is the sanitized payload that is actually sent; originalBody is the
// client's untouched payload and is used only in error logs.
//
// Outcomes:
//   - upstream.Execute returns an error: the failure is logged and counted and
//     the client receives a 502 JSON error.
//   - upstream status >= 400: the credential is marked as failed in pool, the
//     upstream error is logged and counted, and the upstream body is relayed
//     as received.
//   - otherwise: the credential is marked as successful, the response body is
//     desanitized, token usage is read from "usage.input_tokens" and
//     "usage.output_tokens", and tracker (when non-nil) is updated.
//
// In the last two cases the upstream Content-Type and X-Request-Id headers are
// copied to the client when present and the upstream status code is reused.
func handleNonStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte, originalBody []byte, tracker *ratelimit.Tracker) {
	began := time.Now()
	ctx := c.Request.Context()
	model := gjson.GetBytes(body, "model").String()

	telemetry.RequestBodySize.Record(ctx, int64(len(body)),
		metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", false)))

	payload, upstreamHeaders, status, err := upstream.Execute(ctx, cred, body)
	elapsedMs := float64(time.Since(began).Milliseconds())

	if err != nil {
		// No upstream response at all: report it to the client as a 502.
		gatewayAttrs := metric.WithAttributes(
			attribute.String("model", model),
			attribute.Bool("stream", false),
			attribute.Int("status_code", http.StatusBadGateway),
		)

		log.Error().
			Err(err).
			Str("credential", cred.Email).
			Str("model", model).
			Bool("stream", false).
			Str("request_body_original", string(originalBody)).
			Str("request_body_sanitized", string(body)).
			Int("request_body_size", len(body)).
			Float64("latency_ms", elapsedMs).
			Msg("upstream connection error")

		telemetry.UpstreamErrors.Add(ctx, 1,
			metric.WithAttributes(
				attribute.String("error_type", "connection"),
				attribute.String("credential", cred.Email),
				attribute.Int("status_code", http.StatusBadGateway),
			))
		telemetry.RequestCounter.Add(ctx, 1, gatewayAttrs)
		telemetry.RequestDuration.Record(ctx, elapsedMs, gatewayAttrs)

		c.JSON(http.StatusBadGateway, gin.H{"error": "upstream request failed"})
		return
	}

	// Request-level metrics are recorded for every upstream response,
	// successful or not.
	requestAttrs := metric.WithAttributes(
		attribute.String("model", model),
		attribute.Bool("stream", false),
		attribute.Int("status_code", status),
	)
	telemetry.RequestCounter.Add(ctx, 1, requestAttrs)
	telemetry.RequestDuration.Record(ctx, elapsedMs, requestAttrs)

	if status < 400 {
		pool.MarkSuccess(cred)
		payload = san.DesanitizeResponse(payload)

		// Usage is read from the desanitized body.
		tokensIn := gjson.GetBytes(payload, "usage.input_tokens").Int()
		tokensOut := gjson.GetBytes(payload, "usage.output_tokens").Int()

		usageAttrs := metric.WithAttributes(
			attribute.String("model", model),
			attribute.String("credential", cred.Email),
		)
		telemetry.TokensInput.Add(ctx, tokensIn, usageAttrs)
		telemetry.TokensOutput.Add(ctx, tokensOut, usageAttrs)

		if tracker != nil {
			tracker.RecordTokens(tokensIn, tokensOut)
			tracker.UpdateFromHeaders(upstreamHeaders)
		}

		log.Info().
			Int("status", status).
			Float64("latency_ms", elapsedMs).
			Str("model", model).
			Int64("input_tokens", tokensIn).
			Int64("output_tokens", tokensOut).
			Msg("request completed")
	} else {
		pool.MarkFailure(cred, status)
		telemetry.CredentialCooldowns.Add(ctx, 1,
			metric.WithAttributes(attribute.Int("status_code", status)))

		upstreamErrType := gjson.GetBytes(payload, "error.type").String()
		upstreamErrMsg := gjson.GetBytes(payload, "error.message").String()

		log.Error().
			Int("status", status).
			Str("error_type", upstreamErrType).
			Str("error_message", upstreamErrMsg).
			Str("response_body", string(payload)).
			Str("request_id", upstreamHeaders.Get("X-Request-Id")).
			Float64("latency_ms", elapsedMs).
			Str("credential", cred.Email).
			Str("model", model).
			Bool("stream", false).
			Str("request_body_original", string(originalBody)).
			Str("request_body_sanitized", string(body)).
			Int("request_body_size", len(body)).
			Str("request_headers", logging.RedactHeaders(c.Request.Header)).
			Msg("upstream error")

		telemetry.UpstreamErrors.Add(ctx, 1,
			metric.WithAttributes(
				attribute.Int("status_code", status),
				attribute.String("error_type", upstreamErrType),
				attribute.String("credential", cred.Email),
			))
	}

	// Only these two upstream headers are forwarded, and only when set.
	for _, name := range [...]string{"Content-Type", "X-Request-Id"} {
		value := upstreamHeaders.Get(name)
		if value == "" {
			continue
		}
		c.Header(name, value)
	}
	c.Data(status, upstreamHeaders.Get("Content-Type"), payload)
}
// handleStream sends a streaming request upstream and relays the response to
// the client line by line.
//
// body is the sanitized payload that is actually sent; originalBody is the
// client's untouched payload and is used only in error logs.
//
// Outcomes:
//   - upstream.ExecuteStream returns an error: the failure is logged and
//     counted and the client receives a 502 JSON error.
//   - upstream status >= 400: the credential is marked as failed in pool, the
//     error body is read in full, logged and counted, and relayed with the
//     upstream status code and Content-Type.
//   - otherwise: the credential is marked as successful and every line of the
//     upstream body is desanitized, written to the client and flushed. Token
//     usage is collected from the relayed events and recorded once the stream
//     ends.
//
// Relaying stops early if writing to the client fails. A successful stream is
// always counted with status code 200, even if it ends early.
func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte, originalBody []byte, tracker *ratelimit.Tracker) {
	startTime := time.Now()
	model := gjson.GetBytes(body, "model").String()
	ctx := c.Request.Context()

	telemetry.StreamRequests.Add(ctx, 1, metric.WithAttributes(attribute.String("model", model)))
	telemetry.RequestBodySize.Record(ctx, int64(len(body)),
		metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", true)))

	resp, err := upstream.ExecuteStream(ctx, cred, body)
	if err != nil {
		latencyMs := float64(time.Since(startTime).Milliseconds())
		log.Error().
			Err(err).
			Str("credential", cred.Email).
			Str("model", model).
			Bool("stream", true).
			Str("request_body_original", string(originalBody)).
			Str("request_body_sanitized", string(body)).
			Int("request_body_size", len(body)).
			Float64("latency_ms", latencyMs).
			Msg("upstream connection error")
		telemetry.UpstreamErrors.Add(ctx, 1,
			metric.WithAttributes(
				attribute.String("error_type", "connection"),
				attribute.String("credential", cred.Email),
				attribute.Int("status_code", http.StatusBadGateway),
			))
		telemetry.RequestCounter.Add(ctx, 1,
			metric.WithAttributes(
				attribute.String("model", model),
				attribute.Bool("stream", true),
				attribute.Int("status_code", http.StatusBadGateway),
			))
		telemetry.RequestDuration.Record(ctx, latencyMs,
			metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", true), attribute.Int("status_code", http.StatusBadGateway)))
		c.JSON(http.StatusBadGateway, gin.H{"error": "upstream stream request failed"})
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		pool.MarkFailure(cred, resp.StatusCode)
		telemetry.CredentialCooldowns.Add(ctx, 1,
			metric.WithAttributes(attribute.Int("status_code", resp.StatusCode)))

		// A read failure here must not hide the upstream status: log it and
		// carry on with whatever part of the body was received.
		respBody, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			log.Warn().
				Err(readErr).
				Int("status", resp.StatusCode).
				Msg("failed to read upstream error body")
		}
		latencyMs := float64(time.Since(startTime).Milliseconds())
		errorType := gjson.GetBytes(respBody, "error.type").String()
		errorMessage := gjson.GetBytes(respBody, "error.message").String()
		log.Error().
			Int("status", resp.StatusCode).
			Str("error_type", errorType).
			Str("error_message", errorMessage).
			Str("response_body", string(respBody)).
			Str("request_id", resp.Header.Get("X-Request-Id")).
			Float64("latency_ms", latencyMs).
			Str("credential", cred.Email).
			Str("model", model).
			Bool("stream", true).
			Str("request_body_original", string(originalBody)).
			Str("request_body_sanitized", string(body)).
			Int("request_body_size", len(body)).
			Str("request_headers", logging.RedactHeaders(c.Request.Header)).
			Msg("upstream error")
		attrs := []attribute.KeyValue{
			attribute.String("model", model),
			attribute.Bool("stream", true),
			attribute.Int("status_code", resp.StatusCode),
		}
		telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...))
		telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...))
		telemetry.UpstreamErrors.Add(ctx, 1,
			metric.WithAttributes(
				attribute.Int("status_code", resp.StatusCode),
				attribute.String("error_type", errorType),
				attribute.String("credential", cred.Email),
			))
		c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), respBody)
		return
	}

	pool.MarkSuccess(cred)

	// Check for flushing support before any SSE header is set, so that the
	// JSON fallback below is not sent with a text/event-stream Content-Type.
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		log.Error().Msg("response writer does not support flushing")
		c.JSON(http.StatusInternalServerError, gin.H{"error": "streaming not supported"})
		return
	}

	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Status(http.StatusOK)

	var inputTokens, outputTokens int64
	var writeErr error

	scanner := bufio.NewScanner(resp.Body)
	// Start with a 64 KiB buffer and allow lines of up to 1 MiB; a longer line
	// ends the scan with an error that is reported below.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	for scanner.Scan() {
		line := san.DesanitizeStreamEvent(scanner.Text())
		if _, writeErr = c.Writer.WriteString(line + "\n"); writeErr != nil {
			// The client can no longer be written to; stop relaying.
			break
		}
		flusher.Flush()

		// Collect token usage from the relayed events. Per the Anthropic
		// streaming documentation, message_start carries the prompt size under
		// message.usage, and message_delta carries cumulative usage (normally
		// just output_tokens). A counter is only overwritten when the field is
		// present, so a delta without input_tokens does not reset it to zero.
		if len(line) > 5 && line[:5] == "data:" {
			data := line[5:]
			switch gjson.Get(data, "type").String() {
			case "message_start":
				if v := gjson.Get(data, "message.usage.input_tokens"); v.Exists() {
					inputTokens = v.Int()
				}
			case "message_delta":
				if v := gjson.Get(data, "usage.input_tokens"); v.Exists() {
					inputTokens = v.Int()
				}
				if v := gjson.Get(data, "usage.output_tokens"); v.Exists() {
					outputTokens = v.Int()
				}
			}
		}
	}
	scanErr := scanner.Err()

	latencyMs := float64(time.Since(startTime).Milliseconds())
	attrs := []attribute.KeyValue{
		attribute.String("model", model),
		attribute.Bool("stream", true),
		attribute.Int("status_code", http.StatusOK),
	}
	telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...))
	telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...))

	hasUsage := inputTokens > 0 || outputTokens > 0
	if hasUsage {
		tokenAttrs := metric.WithAttributes(
			attribute.String("model", model),
			attribute.String("credential", cred.Email),
		)
		telemetry.TokensInput.Add(ctx, inputTokens, tokenAttrs)
		telemetry.TokensOutput.Add(ctx, outputTokens, tokenAttrs)
	}
	if tracker != nil {
		if hasUsage {
			tracker.RecordTokens(inputTokens, outputTokens)
		}
		// The response headers are available whether or not usage was parsed,
		// so hand them to the tracker on every successful stream, as
		// handleNonStream does.
		tracker.UpdateFromHeaders(resp.Header)
	}

	// Report an abnormal end before the summary line.
	if scanErr != nil {
		log.Error().Err(scanErr).Msg("stream scan error")
	}
	if writeErr != nil {
		log.Warn().
			Err(writeErr).
			Str("model", model).
			Msg("client write failed during stream")
	}

	log.Info().
		Float64("latency_ms", latencyMs).
		Str("model", model).
		Bool("stream", true).
		Int64("input_tokens", inputTokens).
		Int64("output_tokens", outputTokens).
		Msg("stream completed")
}