Compare commits

...

41 Commits

Author SHA1 Message Date
Alexander 86db3ca091 Update dependenciesd 2026-04-22 20:51:03 +02:00
Alexander 0df28e9dd8 refactor: modularize codebase — deduplicate, extract, clean up
- Unify duplicate uTLS transports into shared internal/transport package
- Extract shared version constant into internal/version
- Move LoadDefaultCredentials from config to auth (remove config→auth import)
- Deduplicate handler.go: extract telemetry/error helpers (324→268 lines)
- Break up main.go::run() into initCredential/initEmbedded
- Eliminate logging.Config duplication (use config.LoggingConfig directly)
- Extract logWriter to embedded/log.go, SSE fixtures to consts in sniff.go
- Use uTLS client for usage polling (consistent TLS fingerprint)
- Handle sjson.SetBytes errors in sanitize.go instead of silently swallowing
- Document reverse-engineered magic values in billing.go
- Unexport Credential.CooldownUntil (internal state)
- Replace hardcoded auth bypass paths with map in server.go
2026-04-15 11:01:29 +02:00
Alexander 9150f466e5 test: add comprehensive test harness across all packages (156 tests)
Characterization tests capturing current behavior before refactoring.
Covers auth, config, logging, proxy, ratelimit, server, and telemetry
packages with race-safe concurrent access tests.
2026-04-15 10:40:43 +02:00
Alexander d3fbfe8b42 Dashboard update 2026-04-15 10:06:25 +02:00
Alexander a6c9a16833 Merge feat/embedded-perses: embedded Perses + VictoriaMetrics dashboard 2026-04-14 21:56:43 +02:00
Alexander 34927d3a00 docs: add Perses dashboard example and update config
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/claude-agent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-14 21:56:32 +02:00
Alexander ee9c53791a feat(main): wire embedded Perses + VM toggle
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/claude-agent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-14 21:56:32 +02:00
Alexander 859640d814 feat(embedded): add Perses + VictoriaMetrics subprocess management with auto-download
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/claude-agent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-14 21:56:32 +02:00
Alexander be4113e7ef feat(server): serve /metrics endpoint when embedded metrics enabled
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/claude-agent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-14 21:56:32 +02:00
Alexander 501e40c53d feat(telemetry): add Prometheus exporter for embedded metrics scraping
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/claude-agent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-14 21:56:32 +02:00
Alexander 1bc704a7b2 refactor(config): restructure telemetry config with export and embedded sub-configs
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/claude-agent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-14 21:56:32 +02:00
Alexander bc6ad70386 Grafana dashboard example 2026-04-14 20:42:46 +02:00
Alexander b07d999d86 docs: add metrics reference 2026-04-14 17:54:32 +02:00
Alexander 27b647e9b4 refactor(ratelimit): remove per-window token tracking from proxy
Window token counts are now computed in Grafana using the @ modifier
with dashboard variables derived from proxy_usage_resets_at. This
eliminates in-memory state, file persistence, and restart sensitivity.

Removes: TokensIn/Out, RecordTokens, setResetTime, persist.go,
window_tokens observable gauges. -171 lines.
2026-04-14 14:25:31 +02:00
Alexander 273213cbed feat(ratelimit): persist window token counters across restarts
Save window state (resets_at + token counts) to ~/.claude/ on shutdown
and every poll cycle. On startup, restore counters if the window hasn't
rolled over. Fixes token counters resetting to zero on deploy.
2026-04-14 14:07:28 +02:00
Alexander b864092dad fix(stream): extract input tokens from message_start event
message_delta only contains output_tokens. Input tokens are in the
message_start event under message.usage.input_tokens. This was causing
input token counts to be near-zero for all streaming requests.
2026-04-14 13:55:06 +02:00
Alexander 0ab1896eef Revert "refactor(ratelimit): remove in-memory per-window token tracking"
This reverts commit eda66ff7d4.
2026-04-14 13:50:34 +02:00
Alexander eda66ff7d4 refactor(ratelimit): remove in-memory per-window token tracking
Token counts per rate limit window are now derived in Grafana via
increase(counter[5h/168h]) on the existing cumulative OTel counters.
Removes TokensIn/Out from Window, RecordTokens, setResetTime, and
the window_tokens observable gauges.
2026-04-14 13:49:05 +02:00
Alexander 744abc1d24 fix(ratelimit): clear window token counters on reset from response headers
UpdateFromHeaders was silently updating ResetsAt without clearing token
counters. When a window rolled over, the poll method would see ResetsAt
already updated and skip the reset. Extract setResetTime helper used by
both code paths.
2026-04-14 13:37:06 +02:00
Alexander e8af26d626 docs: rewrite README to cover all proxy features 2026-04-14 13:17:54 +02:00
Alexander fac9578975 feat(ratelimit): track per-window token usage and utilization
Poll /api/oauth/usage every 5 min and extract utilization from
/v1/messages response headers for real-time updates. Track proxy
tokens in/out per rate limit window (5h/7d), resetting on window
change. Expose as OTel observable gauges for Grafana dashboards.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 12:51:31 +02:00
Alexander 76aeeb6be1 fix(auth): add oauth-2025-04-20 beta header + debug logging
Ensure anthropic-beta includes oauth-2025-04-20 when using OAuth tokens,
fixing 401 "OAuth authentication is currently not supported" errors.
Add debug-level logging for upstream requests/responses, sniffed headers,
and token refresh operations.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 11:08:08 +02:00
Alexander 9cc052c162 Add telemetry 2026-04-14 10:31:56 +02:00
Alexander 20049881ad Remove duplicate logging 2026-04-11 15:21:18 +02:00
Alexander 3435f5f4c5 Update example 2026-04-10 18:27:29 +02:00
Alexander 807e8ba133 fix(nix): update vendorHash and vendor dir for new deps 2026-04-10 18:25:19 +02:00
Alexander da59d8f83b refactor(auth): migrate to zerolog structured logging 2026-04-10 18:19:13 +02:00
Alexander 4e22c463cf refactor(proxy): migrate to zerolog structured logging 2026-04-10 18:19:13 +02:00
Alexander 76bf651742 refactor(server): migrate to zerolog, add request logging middleware 2026-04-10 18:19:13 +02:00
Alexander 3d1eb7bd4b refactor(main): migrate to zerolog structured logging 2026-04-10 18:19:13 +02:00
Alexander bfcbe0b37d feat(config): add logging configuration fields 2026-04-10 18:15:49 +02:00
Alexander a7b583839d feat(logging): add zerolog + lumberjack structured logging package 2026-04-10 18:15:49 +02:00
Alexander c5f6962104 Package proxy with nix 2026-04-10 14:44:07 +02:00
Alexander 5ec0004e4c Update example rules 2026-04-10 14:36:59 +02:00
Alexander bf68a0fbeb Update flake deps 2026-04-10 14:33:11 +02:00
Alexander e3c4854be0 fix(auth): bind callback server to localhost for IPv4/IPv6 compat, fix nil deref 2026-04-10 14:30:23 +02:00
Alexander 8b7d9bfff9 docs: update README and config for self-managed authentication 2026-04-10 14:17:46 +02:00
Alexander 65e843f57a feat: wire OAuth login into startup, auto-detect credentials 2026-04-10 14:17:46 +02:00
Alexander 9858530ff6 fix(auth): handle credential file creation in persistCredential 2026-04-10 14:14:42 +02:00
Alexander 21176949a6 feat(auth): add OAuth PKCE login flow with browser + manual fallback 2026-04-10 14:14:42 +02:00
Alexander 945a865bbe refactor(config): remove claude_credentials, add default credential path 2026-04-10 14:14:38 +02:00
52 changed files with 8887 additions and 319 deletions
+2
View File
@@ -4,3 +4,5 @@
anthropic-proxy anthropic-proxy
result result
config.yaml config.yaml
vendor/**
+33 -28
View File
@@ -1,57 +1,62 @@
# anthropic-proxy # anthropic-proxy
Reverse proxy that lets OpenCode (and similar tools) use a Claude subscription instead of an API key. Reverse proxy for the Anthropic Messages API that authenticates with a Claude subscription (OAuth) instead of an API key. Lets you use tools like [OpenCode](https://github.com/opencode-ai/opencode) through your existing Claude Pro/Team plan.
## Prerequisites ## How it works
- Go 1.26+ Clients send standard Anthropic API requests to the proxy. The proxy authenticates upstream using OAuth credentials from your Claude subscription, forwards the request, and streams the response back. Requests are optionally sanitized (tool name remapping, string replacement) before forwarding and de-sanitized on return.
- **Claude Code CLI** — installed and logged in (`claude auth login`). The proxy reads the OAuth token from `~/.claude/.credentials.json`.
Optional: [Nix](https://nixos.org/) flake for dev shell (`nix develop`). ## Features
## Setup - **OAuth credential management** — reuses `~/.claude/.credentials.json`, auto-refreshes tokens
- **Request sanitization** — rename tools, replace strings in system prompts and body (configurable, hot-reloadable)
- **Rate limit tracking** — polls Anthropic usage API and reads response headers to track 5h/7d utilization windows
- **OpenTelemetry metrics** — request counts, latency, token usage, errors (optional OTLP export)
- **Structured logging** — zerolog with file rotation via lumberjack
## Quick start
``` ```
cp config.example.yaml config.yaml cp config.example.yaml config.yaml
``` # edit config.yaml — set api_keys and optionally claude_binary
Edit `config.yaml`:
- `api_keys` — key(s) your clients use to authenticate with the proxy
- `claude_credentials` — path to your Claude credentials file
- `claude_binary` — path to `claude` binary (used on startup to capture request fingerprint)
## Build and run
```
go build -o anthropic-proxy . go build -o anthropic-proxy .
./anthropic-proxy ./anthropic-proxy
``` ```
## Usage with OpenCode On first run, if no credentials exist at `~/.claude/.credentials.json`, an OAuth login flow starts in your browser. If running headlessly, the authorization URL is printed to stdout. If you've already logged in with Claude Code CLI, the proxy reuses those credentials.
### Nix
```
nix develop # dev shell with Go
nix build # build the binary
```
## Client configuration
Point any Anthropic-compatible client at the proxy:
``` ```
export ANTHROPIC_API_KEY=your-proxy-api-key export ANTHROPIC_API_KEY=your-proxy-api-key
export ANTHROPIC_BASE_URL=http://localhost:8082 export ANTHROPIC_BASE_URL=http://localhost:8082
opencode
``` ```
## Endpoints ## Endpoints
| Method | Path | Description | | Method | Path | Description |
|--------|------|-------------| |--------|------|-------------|
| POST | `/v1/messages` | Anthropic messages API (proxied) | | POST | `/v1/messages` | Anthropic Messages API (proxied) |
| POST | `/messages` | Same, without `/v1` prefix | | POST | `/messages` | Same, without `/v1` prefix |
| GET | `/healthz` | Health check | | GET | `/healthz` | Health check |
| POST | `/reload` | Hot-reload `config.yaml` | | POST | `/reload` | Hot-reload config (sanitize rules + API keys) |
## Request sanitization ## Configuration
The `sanitize` section in config renames tool names and replaces strings in system prompts before forwarding to Anthropic. Responses are de-sanitized before returning to the client. See [`config.example.yaml`](config.example.yaml) for all options. Key sections:
See `config.example.yaml` for the default rules. - **`api_keys`** — keys clients use to authenticate with the proxy
- **`sanitize`** — tool renames, system prompt replacements, body replacements
Reload after editing config: - **`telemetry`** — OTLP endpoint, service name, auth headers
- **`logging`** — level, file path, rotation settings
``` - **`claude_binary`** — path to `claude` CLI for request fingerprinting (optional)
curl -X POST localhost:8082/reload
```
+42 -2
View File
@@ -1,7 +1,30 @@
port: 8082 port: 8082
# telemetry:
# service_name: "anthropic-proxy"
# export:
# endpoint: "localhost:4317" # OTLP gRPC endpoint (omit to disable export)
# insecure: true # disable TLS for local dev
# headers: # optional auth headers (e.g. Grafana Cloud)
# Authorization: "Basic ..."
# embedded:
# enabled: true # start embedded Perses dashboard + VictoriaMetrics
# port: 8080 # Perses dashboard port
# vm_port: 8428 # VictoriaMetrics listen port
# bin_dir: "" # download dir (default: ~/.cache/anthropic-proxy/bin)
# perses_binary: "" # custom path to perses binary (default: auto-download)
# vm_binary: "" # custom path to victoria-metrics binary (default: auto-download)
logging:
level: debug
file: /home/fujin/.local/log/anthropic-proxy.log
max_size_mb: 100
max_backups: 5
max_age_days: 30
compress: true
api_keys: api_keys:
- "your-proxy-api-key" - "your-proxy-api-key"
claude_credentials: "~/.claude/.credentials.json"
claude_binary: "claude" claude_binary: "claude"
sanitize: sanitize:
@@ -27,5 +50,22 @@ sanitize:
system: system:
- match: "Workspace root folder" - match: "Workspace root folder"
replace: "Working directory" replace: "Working directory"
- match: "anomalyco/opencode" body:
- match: "anthropics/claude-code"
replace: "anthropics/claude-code" replace: "anthropics/claude-code"
- match: "anthropic"
replace: "anthropic"
- match: "system-directive"
replace: "system-directive"
- match: "claude-code"
replace: "claude-code"
- match: "claude-agent"
replace: "claude-agent"
- match: "system_initiator"
replace: "system_initiator"
- match: "call_agent"
replace: "call_agent"
- match: "claude.ai"
replace: "claude.ai"
- match: "agent"
replace: "agent"
+80
View File
@@ -0,0 +1,80 @@
# Metrics Reference
All metrics are emitted via OpenTelemetry (OTLP gRPC) with cumulative temporality. OTel metric names use dots; Prometheus converts them to underscores and appends `_total` for counters.
## Counters
| OTel Name | Prometheus Name | Attributes | Description |
|---|---|---|---|
| `proxy.request.count` | `proxy_request_count_total` | `model`, `stream`, `status_code` | Total proxied requests |
| `proxy.tokens.input` | `proxy_tokens_input_total` | `model`, `credential` | Input tokens consumed |
| `proxy.tokens.output` | `proxy_tokens_output_total` | `model`, `credential` | Output tokens consumed |
| `proxy.upstream.errors` | `proxy_upstream_errors_total` | `error_type`, `credential`, `status_code` | Upstream errors (connection failures, 4xx/5xx) |
| `proxy.credential.cooldowns` | `proxy_credential_cooldowns_total` | `status_code` | Credential cooldown activations (rate limited) |
| `proxy.stream.requests` | `proxy_stream_requests_total` | `model` | Streaming request count |
## Histograms
| OTel Name | Prometheus Name | Unit | Attributes | Description |
|---|---|---|---|---|
| `proxy.request.duration_ms` | `proxy_request_duration_ms_milliseconds` | ms | `model`, `stream`, `status_code` | Request latency |
| `proxy.request.body_size_bytes` | `proxy_request_body_size_bytes` | bytes | `model`, `stream` | Request body size |
## Gauges
| OTel Name | Prometheus Name | Attributes | Description |
|---|---|---|---|
| `proxy.usage.utilization` | `proxy_usage_utilization` | `window` | Current utilization % from Anthropic API (0-100) |
| `proxy.usage.resets_at` | `proxy_usage_resets_at` | `window` | Unix timestamp when the rate limit window resets |
| `proxy.credential.active` | `proxy_credential_active` | — | Currently active (non-cooldown) credentials |
### Window attribute values
- `5h` — 5-hour rolling window
- `7d` — 7-day rolling window
- `7d_sonnet` — 7-day Sonnet-specific window
## Structured Logs (Loki)
Each completed request emits a structured log line via OTel LogBridge. Fields are stored as Loki stream labels/structured metadata.
| Field | Type | Description |
|---|---|---|
| `input_tokens` | int | Input tokens for this request |
| `output_tokens` | int | Output tokens for this request |
| `model` | string | Model used |
| `latency_ms` | float | Request latency in milliseconds |
| `status` | int | HTTP status code (non-stream only) |
| `stream` | bool | Whether this was a streaming request |
Log messages: `"request completed"` (non-stream), `"stream completed"` (stream).
### Per-window token counting via Loki
Window token totals are computed in Grafana using LogQL, not tracked in-memory. This approach survives process restarts and uses exact Anthropic window boundaries.
Grafana variables derive the window age from Prometheus:
```
window_age_5h = time() - proxy_usage_resets_at{window="5h"} + 18000
window_age_7d = time() - proxy_usage_resets_at{window="7d"} + 604800
```
LogQL queries sum token values from individual log events within the window:
```logql
sum(sum_over_time(
{service_name="anthropic-proxy"} |= "completed"
| unwrap output_tokens
| __error__=""
[${window_age_5h}s]
))
```
## Annotations (Loki)
429 rate limit events are surfaced as Grafana annotations:
```logql
{service_name="anthropic-proxy"} |= "upstream error" | json | status = "429"
```
File diff suppressed because it is too large Load Diff
+450
View File
@@ -0,0 +1,450 @@
{
"kind": "Dashboard",
"metadata": {
"name": "proxy",
"createdAt": "2026-04-14T19:47:48.013238204Z",
"updatedAt": "2026-04-14T19:49:30.874125459Z",
"version": 1,
"project": "anthropic-proxy"
},
"spec": {
"display": {
"name": "Anthropic Proxy"
},
"datasources": {
"vm": {
"default": true,
"plugin": {
"kind": "PrometheusDatasource",
"spec": {
"directUrl": "http://localhost:9428"
}
}
}
},
"panels": {
"latency": {
"kind": "Panel",
"spec": {
"display": {
"name": "Latency"
},
"plugin": {
"kind": "TimeSeriesChart",
"spec": {
"legend": {
"position": "bottom"
},
"yAxis": {
"format": {
"unit": "milliseconds"
}
}
}
},
"queries": [
{
"kind": "TimeSeriesQuery",
"spec": {
"plugin": {
"kind": "PrometheusTimeSeriesQuery",
"spec": {
"datasource": {
"kind": "PrometheusDatasource",
"name": "vm"
},
"query": "histogram_quantile(0.50, rate(proxy_request_duration_ms_milliseconds_bucket[5m]))",
"seriesNameFormat": "p50"
}
}
}
},
{
"kind": "TimeSeriesQuery",
"spec": {
"plugin": {
"kind": "PrometheusTimeSeriesQuery",
"spec": {
"datasource": {
"kind": "PrometheusDatasource",
"name": "vm"
},
"query": "histogram_quantile(0.95, rate(proxy_request_duration_ms_milliseconds_bucket[5m]))",
"seriesNameFormat": "p95"
}
}
}
},
{
"kind": "TimeSeriesQuery",
"spec": {
"plugin": {
"kind": "PrometheusTimeSeriesQuery",
"spec": {
"datasource": {
"kind": "PrometheusDatasource",
"name": "vm"
},
"query": "histogram_quantile(0.99, rate(proxy_request_duration_ms_milliseconds_bucket[5m]))",
"seriesNameFormat": "p99"
}
}
}
}
]
}
},
"request_rate": {
"kind": "Panel",
"spec": {
"display": {
"name": "Request Rate"
},
"plugin": {
"kind": "TimeSeriesChart",
"spec": {
"legend": {
"position": "bottom"
}
}
},
"queries": [
{
"kind": "TimeSeriesQuery",
"spec": {
"plugin": {
"kind": "PrometheusTimeSeriesQuery",
"spec": {
"datasource": {
"kind": "PrometheusDatasource",
"name": "vm"
},
"query": "rate(proxy_request_count_total[5m])",
"seriesNameFormat": "req/s"
}
}
}
}
]
}
},
"token_rate": {
"kind": "Panel",
"spec": {
"display": {
"name": "Token Rate"
},
"plugin": {
"kind": "TimeSeriesChart",
"spec": {
"legend": {
"position": "bottom"
}
}
},
"queries": [
{
"kind": "TimeSeriesQuery",
"spec": {
"plugin": {
"kind": "PrometheusTimeSeriesQuery",
"spec": {
"datasource": {
"kind": "PrometheusDatasource",
"name": "vm"
},
"query": "rate(proxy_tokens_input_total[5m]) * 60",
"seriesNameFormat": "input/min"
}
}
}
},
{
"kind": "TimeSeriesQuery",
"spec": {
"plugin": {
"kind": "PrometheusTimeSeriesQuery",
"spec": {
"datasource": {
"kind": "PrometheusDatasource",
"name": "vm"
},
"query": "rate(proxy_tokens_output_total[5m]) * 60",
"seriesNameFormat": "output/min"
}
}
}
}
]
}
},
"tokens_5h": {
"kind": "Panel",
"spec": {
"display": {
"name": "5h Tokens"
},
"plugin": {
"kind": "StatChart",
"spec": {
"calculation": "last",
"format": {
"unit": "decimal"
},
"sparkline": {}
}
},
"queries": [
{
"kind": "TimeSeriesQuery",
"spec": {
"plugin": {
"kind": "PrometheusTimeSeriesQuery",
"spec": {
"datasource": {
"kind": "PrometheusDatasource",
"name": "vm"
},
"query": "increase(proxy_tokens_output_total[3h])"
}
}
}
}
]
}
},
"tokens_7d": {
"kind": "Panel",
"spec": {
"display": {
"name": "7d Tokens"
},
"plugin": {
"kind": "StatChart",
"spec": {
"calculation": "last",
"format": {
"unit": "decimal"
},
"sparkline": {}
}
},
"queries": [
{
"kind": "TimeSeriesQuery",
"spec": {
"plugin": {
"kind": "PrometheusTimeSeriesQuery",
"spec": {
"datasource": {
"kind": "PrometheusDatasource",
"name": "vm"
},
"query": "increase(proxy_tokens_output_total[9h])"
}
}
}
}
]
}
},
"util_5h": {
"kind": "Panel",
"spec": {
"display": {
"name": "5h Utilization"
},
"plugin": {
"kind": "GaugeChart",
"spec": {
"calculation": "last",
"format": {
"unit": "percent"
},
"thresholds": {
"steps": [
{
"color": "green",
"value": 0
},
{
"color": "orange",
"value": 70
},
{
"color": "red",
"value": 90
}
]
}
}
},
"queries": [
{
"kind": "TimeSeriesQuery",
"spec": {
"plugin": {
"kind": "PrometheusTimeSeriesQuery",
"spec": {
"datasource": {
"kind": "PrometheusDatasource",
"name": "vm"
},
"query": "proxy_usage_utilization{window=\"5h\"}"
}
}
}
}
]
}
},
"util_7d": {
"kind": "Panel",
"spec": {
"display": {
"name": "7d Utilization"
},
"plugin": {
"kind": "GaugeChart",
"spec": {
"calculation": "last",
"format": {
"unit": "percent"
},
"thresholds": {
"steps": [
{
"color": "green",
"value": 0
},
{
"color": "orange",
"value": 70
},
{
"color": "red",
"value": 90
}
]
}
}
},
"queries": [
{
"kind": "TimeSeriesQuery",
"spec": {
"plugin": {
"kind": "PrometheusTimeSeriesQuery",
"spec": {
"datasource": {
"kind": "PrometheusDatasource",
"name": "vm"
},
"query": "proxy_usage_utilization{window=\"7d\"}"
}
}
}
}
]
}
}
},
"layouts": [
{
"kind": "Grid",
"spec": {
"display": {
"title": "Utilization"
},
"items": [
{
"x": 0,
"y": 0,
"width": 6,
"height": 5,
"content": {
"$ref": "#/spec/panels/util_5h"
}
},
{
"x": 6,
"y": 0,
"width": 6,
"height": 5,
"content": {
"$ref": "#/spec/panels/util_7d"
}
},
{
"x": 12,
"y": 0,
"width": 6,
"height": 5,
"content": {
"$ref": "#/spec/panels/tokens_5h"
}
},
{
"x": 18,
"y": 0,
"width": 6,
"height": 5,
"content": {
"$ref": "#/spec/panels/tokens_7d"
}
}
]
}
},
{
"kind": "Grid",
"spec": {
"display": {
"title": "Traffic"
},
"items": [
{
"x": 0,
"y": 0,
"width": 12,
"height": 8,
"content": {
"$ref": "#/spec/panels/request_rate"
}
},
{
"x": 12,
"y": 0,
"width": 12,
"height": 8,
"content": {
"$ref": "#/spec/panels/latency"
}
}
]
}
},
{
"kind": "Grid",
"spec": {
"display": {
"title": "Tokens"
},
"items": [
{
"x": 0,
"y": 0,
"width": 24,
"height": 8,
"content": {
"$ref": "#/spec/panels/token_rate"
}
}
]
}
}
],
"duration": "1h",
"refreshInterval": "10s"
}
}
Generated
+3 -3
View File
@@ -20,11 +20,11 @@
}, },
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1775423009, "lastModified": 1776169885,
"narHash": "sha256-vPKLpjhIVWdDrfiUM8atW6YkIggCEKdSAlJPzzhkQlw=", "narHash": "sha256-l/iNYDZ4bGOAFQY2q8y5OAfBBtrDAaPuRQqWaFHVRXM=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "68d8aa3d661f0e6bd5862291b5bb263b2a6595c9", "rev": "4bd9165a9165d7b5e33ae57f3eecbcb28fb231c9",
"type": "github" "type": "github"
}, },
"original": { "original": {
+8 -5
View File
@@ -17,14 +17,14 @@
let let
pkgs = import nixpkgs { pkgs = import nixpkgs {
inherit system; inherit system;
config.allowUnfreePredicate = config.allowUnfree = true;
pkg:
builtins.elem (pkgs.lib.getName pkg) [
"claude-code"
];
}; };
in in
{ {
packages = {
proxy = pkgs.callPackage ./package.nix {};
};
devShells.default = pkgs.mkShell { devShells.default = pkgs.mkShell {
buildInputs = with pkgs; [ buildInputs = with pkgs; [
go go
@@ -42,6 +42,9 @@
shellHook = '' shellHook = ''
export GOPATH="$PWD/.go" export GOPATH="$PWD/.go"
export PATH="$GOPATH/bin:$PATH" export PATH="$GOPATH/bin:$PATH"
export ANTHROPIC_BASE_URL=http://localhost:8082
export ANTHROPIC_API_KEY=sk-cliproxy-fujin
''; '';
}; };
} }
+50 -13
View File
@@ -5,44 +5,81 @@ go 1.26
require ( require (
github.com/gin-gonic/gin v1.12.0 github.com/gin-gonic/gin v1.12.0
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/prometheus/client_golang v1.23.2
github.com/refraction-networking/utls v1.8.2
github.com/rs/zerolog v1.35.0
github.com/tidwall/gjson v1.18.0 github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.68.0
go.opentelemetry.io/otel v1.43.0
go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.19.0
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0
go.opentelemetry.io/otel/exporters/prometheus v0.65.0
go.opentelemetry.io/otel/log v0.19.0
go.opentelemetry.io/otel/metric v1.43.0
go.opentelemetry.io/otel/sdk v1.43.0
go.opentelemetry.io/otel/sdk/log v0.19.0
go.opentelemetry.io/otel/sdk/metric v1.43.0
golang.org/x/net v0.52.0
gopkg.in/lumberjack.v2 v2.0.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
) )
require ( require (
github.com/BurntSushi/toml v1.6.0 // indirect
github.com/andybalholm/brotli v1.0.6 // indirect github.com/andybalholm/brotli v1.0.6 // indirect
github.com/bytedance/gopkg v0.1.3 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/bytedance/gopkg v0.1.4 // indirect
github.com/bytedance/sonic v1.15.0 // indirect github.com/bytedance/sonic v1.15.0 // indirect
github.com/bytedance/sonic/loader v0.5.0 // indirect github.com/bytedance/sonic/loader v0.5.1 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect github.com/cloudwego/base64x v0.1.6 // indirect
github.com/gabriel-vasile/mimetype v1.4.12 // indirect github.com/gabriel-vasile/mimetype v1.4.13 // indirect
github.com/gin-contrib/sse v1.1.0 // indirect github.com/gin-contrib/sse v1.1.1 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.30.1 // indirect github.com/go-playground/validator/v10 v10.30.2 // indirect
github.com/goccy/go-json v0.10.5 // indirect github.com/goccy/go-json v0.10.6 // indirect
github.com/goccy/go-yaml v1.19.2 // indirect github.com/goccy/go-yaml v1.19.2 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.17.6 // indirect github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pelletier/go-toml/v2 v2.3.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.67.5 // indirect
github.com/prometheus/otlptranslator v1.0.0 // indirect
github.com/prometheus/procfs v0.20.1 // indirect
github.com/quic-go/qpack v0.6.0 // indirect github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.59.0 // indirect github.com/quic-go/quic-go v0.59.0 // indirect
github.com/refraction-networking/utls v1.8.2 // indirect
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.3.1 // indirect github.com/ugorji/go/codec v1.3.1 // indirect
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
golang.org/x/arch v0.22.0 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 // indirect
go.opentelemetry.io/otel/trace v1.43.0 // indirect
go.opentelemetry.io/proto/otlp v1.10.0 // indirect
go.yaml.in/yaml/v2 v2.4.4 // indirect
golang.org/x/arch v0.25.0 // indirect
golang.org/x/crypto v0.49.0 // indirect golang.org/x/crypto v0.49.0 // indirect
golang.org/x/net v0.52.0 // indirect
golang.org/x/sys v0.42.0 // indirect golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.35.0 // indirect golang.org/x/text v0.35.0 // indirect
google.golang.org/protobuf v1.36.10 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
google.golang.org/grpc v1.80.0 // indirect
google.golang.org/protobuf v1.36.11 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
) )
+111 -30
View File
@@ -1,51 +1,72 @@
github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk=
github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI=
github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bytedance/gopkg v0.1.4 h1:oZnQwnX82KAIWb7033bEwtxvTqXcYMxDBaQxo5JJHWM=
github.com/bytedance/gopkg v0.1.4/go.mod h1:v1zWfPm21Fb+OsyXN2VAHdL6TBb2L88anLQgdyje6R4=
github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE=
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k= github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE= github.com/bytedance/sonic/loader v0.5.1 h1:Ygpfa9zwRCCKSlrp5bBP/b/Xzc3VxsAW+5NIYXrOOpI=
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/bytedance/sonic/loader v0.5.1/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw= github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM=
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= github.com/gabriel-vasile/mimetype v1.4.13/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w= github.com/gin-contrib/sse v1.1.1 h1:uGYpNwTacv5R68bSGMapo62iLTRa9l5zxGCps4hK6ko=
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM= github.com/gin-contrib/sse v1.1.1/go.mod h1:QXzuVkA0YO7o/gun03UI1Q+FTI8ZV/n5t03kIQAI89s=
github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8= github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8=
github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc= github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w= github.com/go-playground/validator/v10 v10.30.2 h1:JiFIMtSSHb2/XBUbWM4i/MpeQm9ZK2xqPNk8vgvu5JQ=
github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM= github.com/go-playground/validator/v10 v10.30.2/go.mod h1:mAf2pIOVXjTEBrwUMGKkCWKKPs9NheYGabeB04txQSc=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.6 h1:p8HrPJzOakx/mn/bQtjgNjdTcN+/S6FcG2CTtQOrHVU=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-json v0.10.6/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM= github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -53,18 +74,32 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM=
github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4=
github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw=
github.com/prometheus/otlptranslator v1.0.0 h1:s0LJW/iN9dkIH+EnhiD3BlkkP5QVIUVEoIwkU+A6qos=
github.com/prometheus/otlptranslator v1.0.0/go.mod h1:vRYWnXvI6aWGpsdY/mOT/cbeVRBlPWtBNDb7kGR3uKM=
github.com/prometheus/procfs v0.20.1 h1:XwbrGOIplXW/AU3YhIhLODXMJYyC1isLFfYCsTEycfc=
github.com/prometheus/procfs v0.20.1/go.mod h1:o9EMBZGRyvDrSPH1RqdxhojkuXstoe4UlK79eF5TGGo=
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII= github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw= github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw=
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo= github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/rs/zerolog v1.35.0 h1:VD0ykx7HMiMJytqINBsKcbLS+BJ4WYjz+05us+LRTdI=
github.com/rs/zerolog v1.35.0/go.mod h1:EjML9kdfa/RMA7h/6z6pYmq1ykOuA8/mjWaEvGI+jcw=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -91,32 +126,78 @@ github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY
github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4= github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE= go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0= go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.68.0 h1:5FXSL2s6afUC1bzNzl1iedZZ8yqR7GOhbCoEXtyeK6Q=
go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.68.0/go.mod h1:MdHW7tLtkeGJnR4TyOrnd5D0zUGZQB1l84uHCe8hRpE=
go.opentelemetry.io/contrib/propagators/b3 v1.43.0 h1:CETqV3QLLPTy5yNrqyMr41VnAOOD4lsRved7n4QG00A=
go.opentelemetry.io/contrib/propagators/b3 v1.43.0/go.mod h1:Q4mCiCdziYzpNR0g+6UqVotAlCDZdzz6L8jwY4knOrw=
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.19.0 h1:Dn8rkudDzY6KV9dr/D/bTUuWgqDf9xe0rr4G2elrn0Y=
go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.19.0/go.mod h1:gMk9F0xDgyN9M/3Ed5Y1wKcx/9mlU91NXY2SNq7RQuU=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0 h1:8UQVDcZxOJLtX6gxtDt3vY2WTgvZqMQRzjsqiIHQdkc=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0/go.mod h1:2lmweYCiHYpEjQ/lSJBYhj9jP1zvCvQW4BqL9dnT7FQ=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 h1:88Y4s2C8oTui1LGM6bTWkw0ICGcOLCAI5l6zsD1j20k=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0/go.mod h1:Vl1/iaggsuRlrHf/hfPJPvVag77kKyvrLeD10kpMl+A=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0 h1:RAE+JPfvEmvy+0LzyUA25/SGawPwIUbZ6u0Wug54sLc=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0/go.mod h1:AGmbycVGEsRx9mXMZ75CsOyhSP6MFIcj/6dnG+vhVjk=
go.opentelemetry.io/otel/exporters/prometheus v0.65.0 h1:jOveH/b4lU9HT7y+Gfamf18BqlOuz2PWEvs8yM7Q6XE=
go.opentelemetry.io/otel/exporters/prometheus v0.65.0/go.mod h1:i1P8pcumauPtUI4YNopea1dhzEMuEqWP1xoUZDylLHo=
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.43.0 h1:mS47AX77OtFfKG4vtp+84kuGSFZHTyxtXIN269vChY0=
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.43.0/go.mod h1:PJnsC41lAGncJlPUniSwM81gc80GkgWJWr3cu2nKEtU=
go.opentelemetry.io/otel/log v0.19.0 h1:KUZs/GOsw79TBBMfDWsXS+KZ4g2Ckzksd1ymzsIEbo4=
go.opentelemetry.io/otel/log v0.19.0/go.mod h1:5DQYeGmxVIr4n0/BcJvF4upsraHjg6vudJJpnkL6Ipk=
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM=
go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY=
go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg=
go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg=
go.opentelemetry.io/otel/sdk/log v0.19.0 h1:scYVLqT22D2gqXItnWiocLUKGH9yvkkeql5dBDiXyko=
go.opentelemetry.io/otel/sdk/log v0.19.0/go.mod h1:vFBowwXGLlW9AvpuF7bMgnNI95LiW10szrOdvzBHlAg=
go.opentelemetry.io/otel/sdk/log/logtest v0.19.0 h1:BEbF7ZBB6qQloV/Ub1+3NQoOUnVtcGkU3XX4Ws3GQfk=
go.opentelemetry.io/otel/sdk/log/logtest v0.19.0/go.mod h1:Lua81/3yM0wOmoHTokLj9y9ADeA02v1naRrVrkAZuKk=
go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw=
go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A=
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g=
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI= go.yaml.in/yaml/v2 v2.4.4 h1:tuyd0P+2Ont/d6e2rl3be67goVK4R6deVxCUX5vyPaQ=
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= go.yaml.in/yaml/v2 v2.4.4/go.mod h1:gMZqIpDtDqOfM0uNfy0SkpRhvUryYH0Z6wdMYcacYXQ=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/arch v0.25.0 h1:qnk6Ksugpi5Bz32947rkUgDt9/s5qvqDPl/gBKdMJLE=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/arch v0.25.0/go.mod h1:0X+GdSIP+kL5wPmpK7sdkEVTt2XoYP0cSjQSbZBwOi8=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA=
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/lumberjack.v2 v2.0.0 h1:IDj6hi8KbNiPQ5VaYNFZ7dBJLF5LFeKvsFrWHjA5aq4=
gopkg.in/lumberjack.v2 v2.0.0/go.mod h1:bp5nQ2kK/lLQSmTk29azj9+JB6bWci56xFn/lvd5GLI=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+56
View File
@@ -0,0 +1,56 @@
package auth
import (
"encoding/json"
"fmt"
"os"
"time"
)
// claudeCredentialsJSON matches the structure of ~/.claude/.credentials.json.
type claudeCredentialsJSON struct {
ClaudeAiOauth struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ExpiresAt int64 `json:"expiresAt"`
SubscriptionType string `json:"subscriptionType"`
} `json:"claudeAiOauth"`
}
// LoadDefaultCredentials reads credentials from ~/.claude/.credentials.json.
// Returns nil, nil if the file does not exist.
func LoadDefaultCredentials() ([]*Credential, error) {
path, err := DefaultCredentialPath()
if err != nil {
return nil, nil
}
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, err
}
var cf claudeCredentialsJSON
if err := json.Unmarshal(data, &cf); err != nil {
return nil, err
}
oauth := cf.ClaudeAiOauth
if oauth.AccessToken == "" {
return nil, fmt.Errorf("no access token in %s", path)
}
cred := &Credential{
ID: "claude-native",
Email: oauth.SubscriptionType,
AccessToken: oauth.AccessToken,
RefreshToken: oauth.RefreshToken,
ExpiresAt: time.UnixMilli(oauth.ExpiresAt),
FilePath: path,
}
return []*Credential{cred}, nil
}
+70
View File
@@ -0,0 +1,70 @@
package auth
import (
"encoding/json"
"os"
"path/filepath"
"strings"
"testing"
)
func TestDefaultCredentialPath(t *testing.T) {
path, err := DefaultCredentialPath()
if err != nil {
t.Fatalf("DefaultCredentialPath error: %v", err)
}
if !strings.HasSuffix(path, filepath.Join(".claude", ".credentials.json")) {
t.Errorf("path = %q, want suffix .claude/.credentials.json", path)
}
}
func TestLoadDefaultCredentials_MissingFile(t *testing.T) {
// When credential file doesn't exist, returns nil, nil
path, err := DefaultCredentialPath()
if err != nil {
t.Skip("cannot determine home dir")
}
if _, statErr := os.Stat(path); os.IsNotExist(statErr) {
creds, err := LoadDefaultCredentials()
if creds != nil {
t.Errorf("expected nil creds for missing file, got %v", creds)
}
if err != nil {
t.Errorf("expected nil error for missing file, got %v", err)
}
}
}
func TestClaudeCredentialsJSON_ParsesCorrectly(t *testing.T) {
jsonData := `{"claudeAiOauth":{"accessToken":"test-token","refreshToken":"test-refresh","expiresAt":1234567890,"subscriptionType":"pro"}}`
var cf claudeCredentialsJSON
if err := json.Unmarshal([]byte(jsonData), &cf); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if cf.ClaudeAiOauth.AccessToken != "test-token" {
t.Errorf("AccessToken = %q, want test-token", cf.ClaudeAiOauth.AccessToken)
}
if cf.ClaudeAiOauth.RefreshToken != "test-refresh" {
t.Errorf("RefreshToken = %q, want test-refresh", cf.ClaudeAiOauth.RefreshToken)
}
if cf.ClaudeAiOauth.ExpiresAt != 1234567890 {
t.Errorf("ExpiresAt = %d, want 1234567890", cf.ClaudeAiOauth.ExpiresAt)
}
if cf.ClaudeAiOauth.SubscriptionType != "pro" {
t.Errorf("SubscriptionType = %q, want pro", cf.ClaudeAiOauth.SubscriptionType)
}
}
func TestClaudeCredentialsJSON_EmptyAccessToken(t *testing.T) {
jsonData := `{"claudeAiOauth":{"accessToken":"","refreshToken":"r","expiresAt":1}}`
var cf claudeCredentialsJSON
if err := json.Unmarshal([]byte(jsonData), &cf); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if cf.ClaudeAiOauth.AccessToken != "" {
t.Errorf("expected empty access token")
}
}
+276
View File
@@ -0,0 +1,276 @@
package auth
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/rs/zerolog/log"
)
const (
authURL = "https://claude.com/cai/oauth/authorize"
manualRedirect = "https://platform.claude.com/oauth/code/callback"
)
func base64URLEncode(data []byte) string {
return base64.RawURLEncoding.EncodeToString(data)
}
func generateCodeVerifier() string {
buf := make([]byte, 32)
_, _ = rand.Read(buf)
return base64URLEncode(buf)
}
func generateCodeChallenge(verifier string) string {
h := sha256.Sum256([]byte(verifier))
return base64URLEncode(h[:])
}
func generateState() string {
buf := make([]byte, 32)
_, _ = rand.Read(buf)
return base64URLEncode(buf)
}
func buildAuthURL(port int, codeChallenge, state string) string {
u, _ := url.Parse(authURL)
q := u.Query()
q.Set("client_id", clientID)
q.Set("response_type", "code")
q.Set("redirect_uri", fmt.Sprintf("http://localhost:%d/callback", port))
q.Set("scope", oauthScopes)
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
q.Set("state", state)
u.RawQuery = q.Encode()
return u.String()
}
func buildManualAuthURL(codeChallenge, state string) string {
u, _ := url.Parse(authURL)
q := u.Query()
q.Set("client_id", clientID)
q.Set("response_type", "code")
q.Set("redirect_uri", manualRedirect)
q.Set("scope", oauthScopes)
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
q.Set("state", state)
u.RawQuery = q.Encode()
return u.String()
}
func startCallbackServer(expectedState string) (port int, codeChan <-chan string, cleanup func(), err error) {
ch := make(chan string, 1)
ln, err := net.Listen("tcp", "localhost:0")
if err != nil {
return 0, nil, nil, err
}
port = ln.Addr().(*net.TCPAddr).Port
srv := &http.Server{}
srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/callback" {
http.NotFound(w, r)
return
}
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
if state != expectedState {
http.Error(w, "invalid state", http.StatusBadRequest)
return
}
if code == "" {
http.Error(w, "missing code", http.StatusBadRequest)
return
}
select {
case ch <- code:
w.Header().Set("Content-Type", "text/html")
fmt.Fprintln(w, "<html><body><h2>Login successful! You can close this tab.</h2></body></html>")
default:
fmt.Fprintln(w, "<html><body><h2>Already received. You can close this tab.</h2></body></html>")
}
})
go srv.Serve(ln)
cleanup = func() {
srv.Close()
}
return port, ch, cleanup, nil
}
// DefaultCredentialPath returns the path to the Claude credentials file.
func DefaultCredentialPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".claude", ".credentials.json"), nil
}
// Login performs the full OAuth 2.0 PKCE login flow and returns a Credential.
func Login(ctx context.Context) (*Credential, error) {
verifier := generateCodeVerifier()
challenge := generateCodeChallenge(verifier)
state := generateState()
port, codeChan, cleanup, err := startCallbackServer(state)
if err != nil {
return nil, fmt.Errorf("start callback server: %w", err)
}
defer cleanup()
autoURL := buildAuthURL(port, challenge, state)
manualURL := buildManualAuthURL(challenge, state)
fmt.Printf("\nTo sign in, visit:\n %s\n\n", manualURL)
openBrowser(autoURL)
var authCode string
var isManual bool
stdinCh := make(chan string, 1)
fi, statErr := os.Stdin.Stat()
if statErr == nil && (fi.Mode()&os.ModeCharDevice) != 0 {
fmt.Print("If browser didn't open, paste the authorization code here: ")
go func() {
var line string
scanner := bufio.NewScanner(os.Stdin)
if scanner.Scan() {
line = strings.TrimSpace(scanner.Text())
}
if line != "" {
stdinCh <- line
}
}()
}
timeout := time.NewTimer(120 * time.Second)
defer timeout.Stop()
select {
case code := <-codeChan:
authCode = code
isManual = false
case code := <-stdinCh:
authCode = code
isManual = true
case <-timeout.C:
return nil, fmt.Errorf("login timed out after 120 seconds")
case <-ctx.Done():
return nil, ctx.Err()
}
credPath, err := DefaultCredentialPath()
if err != nil {
return nil, fmt.Errorf("credential path: %w", err)
}
return exchangeAuthCode(ctx, authCode, state, verifier, port, isManual, credPath)
}
type authCodeRequest struct {
GrantType string `json:"grant_type"`
Code string `json:"code"`
RedirectURI string `json:"redirect_uri"`
ClientID string `json:"client_id"`
CodeVerifier string `json:"code_verifier"`
State string `json:"state"`
}
func exchangeAuthCode(ctx context.Context, code, state, verifier string, port int, isManual bool, credPath string) (*Credential, error) {
redirectURI := fmt.Sprintf("http://localhost:%d/callback", port)
if isManual {
redirectURI = manualRedirect
}
reqBody, _ := json.Marshal(authCodeRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: redirectURI,
ClientID: clientID,
CodeVerifier: verifier,
State: state,
})
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenEndpoint, bytes.NewReader(reqBody))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := utlsClient.Do(req)
if err != nil {
return nil, fmt.Errorf("token exchange: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token exchange returned %d: %s", resp.StatusCode, string(body))
}
var tokenResp tokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("decode token response: %w", err)
}
cred := &Credential{
ID: "claude-native",
Email: tokenResp.Account.EmailAddress,
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ExpiresAt: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second),
FilePath: credPath,
}
if err := ensureCredentialFile(credPath); err != nil {
return nil, fmt.Errorf("ensure credential file: %w", err)
}
if err := persistCredential(cred); err != nil {
return nil, fmt.Errorf("save credential: %w", err)
}
log.Info().Str("path", credPath).Msg("login successful, credentials saved")
return cred, nil
}
func ensureCredentialFile(path string) error {
if _, err := os.Stat(path); err == nil {
return nil
}
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0700); err != nil {
return err
}
return os.WriteFile(path, []byte("{}"), 0600)
}
func openBrowser(url string) {
var cmd *exec.Cmd
switch runtime.GOOS {
case "darwin":
cmd = exec.Command("open", url)
case "linux":
cmd = exec.Command("xdg-open", url)
default:
return
}
_ = cmd.Start()
}
+33 -72
View File
@@ -6,15 +6,14 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log"
"net"
"net/http" "net/http"
"os" "os"
"sync" "path/filepath"
"time" "time"
tls "github.com/refraction-networking/utls" "github.com/rs/zerolog/log"
"golang.org/x/net/http2"
"github.com/fujin/anthropic-proxy/internal/transport"
) )
const ( const (
@@ -27,7 +26,7 @@ const (
refreshBackoff = 5 * time.Minute refreshBackoff = 5 * time.Minute
) )
var utlsClient = newUTLSClient() var utlsClient = transport.NewHTTPClient(15 * time.Second)
type tokenRequest struct { type tokenRequest struct {
ClientID string `json:"client_id"` ClientID string `json:"client_id"`
@@ -63,6 +62,13 @@ func RefreshToken(ctx context.Context, cred *Credential) error {
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
log.Debug().
Str("url", tokenEndpoint).
Str("grant_type", "refresh_token").
Str("client_id", clientID).
Str("scope", oauthScopes).
Msg("token refresh request")
resp, err := utlsClient.Do(req) resp, err := utlsClient.Do(req)
if err != nil { if err != nil {
return fmt.Errorf("execute request: %w", err) return fmt.Errorf("execute request: %w", err)
@@ -70,6 +76,12 @@ func RefreshToken(ctx context.Context, cred *Credential) error {
defer resp.Body.Close() defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
log.Debug().
Int("status", resp.StatusCode).
Int("response_size", len(body)).
Msg("token refresh response")
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return fmt.Errorf("refresh returned %d: %s", resp.StatusCode, string(body)) return fmt.Errorf("refresh returned %d: %s", resp.StatusCode, string(body))
} }
@@ -105,14 +117,22 @@ func persistCredential(cred *Credential) error {
return nil return nil
} }
var doc map[string]any
raw, err := os.ReadFile(filePath) raw, err := os.ReadFile(filePath)
if err != nil { if err != nil {
if !os.IsNotExist(err) {
return err return err
} }
var doc map[string]any // File doesn't exist yet (cold start) — create from scratch
if mkdirErr := os.MkdirAll(filepath.Dir(filePath), 0700); mkdirErr != nil {
return fmt.Errorf("create credential dir: %w", mkdirErr)
}
doc = make(map[string]any)
} else {
if err := json.Unmarshal(raw, &doc); err != nil { if err := json.Unmarshal(raw, &doc); err != nil {
return err return err
} }
}
oauth, _ := doc["claudeAiOauth"].(map[string]any) oauth, _ := doc["claudeAiOauth"].(map[string]any)
if oauth == nil { if oauth == nil {
oauth = make(map[string]any) oauth = make(map[string]any)
@@ -125,73 +145,12 @@ func persistCredential(cred *Credential) error {
return os.WriteFile(filePath, out, 0600) return os.WriteFile(filePath, out, 0600)
} }
func newUTLSClient() *http.Client {
return &http.Client{
Timeout: 15 * time.Second,
Transport: &utlsRefreshTransport{},
}
}
type utlsRefreshTransport struct {
mu sync.Mutex
conn *http2.ClientConn
host string
}
func (t *utlsRefreshTransport) RoundTrip(req *http.Request) (*http.Response, error) {
host := req.URL.Hostname()
port := req.URL.Port()
if port == "" {
port = "443"
}
t.mu.Lock()
if t.conn != nil && t.host == host && t.conn.CanTakeNewRequest() {
conn := t.conn
t.mu.Unlock()
resp, err := conn.RoundTrip(req)
if err == nil {
return resp, nil
}
t.mu.Lock()
t.conn = nil
t.mu.Unlock()
} else {
t.mu.Unlock()
}
addr := net.JoinHostPort(host, port)
rawConn, err := net.DialTimeout("tcp", addr, 10*time.Second)
if err != nil {
return nil, err
}
tlsConn := tls.UClient(rawConn, &tls.Config{ServerName: host}, tls.HelloChrome_Auto)
if err := tlsConn.Handshake(); err != nil {
rawConn.Close()
return nil, err
}
h2Conn, err := (&http2.Transport{}).NewClientConn(tlsConn)
if err != nil {
tlsConn.Close()
return nil, err
}
t.mu.Lock()
t.conn = h2Conn
t.host = host
t.mu.Unlock()
return h2Conn.RoundTrip(req)
}
func StartBackgroundRefresh(ctx context.Context, pool *Pool) { func StartBackgroundRefresh(ctx context.Context, pool *Pool) {
go func() { go func() {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
log.Printf("background refresh stopped") log.Info().Msg("background refresh stopped")
return return
case <-time.After(refreshInterval): case <-time.After(refreshInterval):
refreshExpiring(pool) refreshExpiring(pool)
@@ -213,6 +172,7 @@ func refreshExpiring(pool *Pool) {
hasRefresh := cred.RefreshToken != "" hasRefresh := cred.RefreshToken != ""
nextRetry := cred.nextRefreshAfter nextRetry := cred.nextRefreshAfter
email := cred.Email email := cred.Email
expiresAt := cred.ExpiresAt
cred.mu.Unlock() cred.mu.Unlock()
if !hasRefresh || !needsRefresh { if !hasRefresh || !needsRefresh {
@@ -222,21 +182,22 @@ func refreshExpiring(pool *Pool) {
continue continue
} }
log.Printf("refreshing token for %s (expires %s)", email, cred.ExpiresAt.Format(time.RFC3339)) log.Info().Str("credential", email).Time("expires_at", expiresAt).Msg("refreshing token")
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
err := RefreshToken(ctx, cred) err := RefreshToken(ctx, cred)
cancel() cancel()
if err != nil { if err != nil {
log.Printf("refresh failed for %s: %v", email, err) log.Error().Err(err).Str("credential", email).Msg("token refresh failed")
cred.mu.Lock() cred.mu.Lock()
cred.nextRefreshAfter = time.Now().Add(refreshBackoff) cred.nextRefreshAfter = time.Now().Add(refreshBackoff)
cred.mu.Unlock() cred.mu.Unlock()
} else { } else {
log.Printf("refreshed %s, new expiry %s", email, cred.ExpiresAt.Format(time.RFC3339))
cred.mu.Lock() cred.mu.Lock()
newExpiresAt := cred.ExpiresAt
cred.nextRefreshAfter = time.Time{} cred.nextRefreshAfter = time.Time{}
cred.mu.Unlock() cred.mu.Unlock()
log.Info().Str("credential", email).Time("new_expiry", newExpiresAt).Msg("token refreshed")
} }
} }
} }
+318
View File
@@ -0,0 +1,318 @@
package auth
import (
"testing"
"time"
)
func TestNewPool(t *testing.T) {
creds := []*Credential{
{ID: "a", AccessToken: "tok-a"},
{ID: "b", AccessToken: "tok-b"},
}
p := NewPool(creds)
if p == nil {
t.Fatal("NewPool returned nil")
}
if len(p.creds) != 2 {
t.Errorf("pool has %d creds, want 2", len(p.creds))
}
if p.cursor != 0 {
t.Errorf("initial cursor = %d, want 0", p.cursor)
}
}
func TestPool_Pick_EmptyPool(t *testing.T) {
p := NewPool(nil)
_, err := p.Pick()
if err == nil {
t.Fatal("expected error from empty pool, got nil")
}
want := "no credentials available"
if err.Error() != want {
t.Errorf("error = %q, want %q", err.Error(), want)
}
}
func TestPool_Pick_SingleCredential(t *testing.T) {
cred := &Credential{ID: "only", AccessToken: "tok-only"}
p := NewPool([]*Credential{cred})
got, err := p.Pick()
if err != nil {
t.Fatalf("Pick() error = %v", err)
}
if got.ID != "only" {
t.Errorf("Pick() returned cred ID %q, want %q", got.ID, "only")
}
// Picking again should return the same credential
got2, err := p.Pick()
if err != nil {
t.Fatalf("second Pick() error = %v", err)
}
if got2.ID != "only" {
t.Errorf("second Pick() returned cred ID %q, want %q", got2.ID, "only")
}
}
func TestPool_Pick_RoundRobin(t *testing.T) {
creds := []*Credential{
{ID: "a"},
{ID: "b"},
{ID: "c"},
}
p := NewPool(creds)
// Should cycle through a, b, c, a, b, c
expected := []string{"a", "b", "c", "a", "b", "c"}
for i, want := range expected {
got, err := p.Pick()
if err != nil {
t.Fatalf("Pick() #%d error = %v", i, err)
}
if got.ID != want {
t.Errorf("Pick() #%d = %q, want %q", i, got.ID, want)
}
}
}
func TestPool_Pick_SkipsCooldown(t *testing.T) {
creds := []*Credential{
{ID: "a"},
{ID: "b", cooldownUntil: time.Now().Add(1 * time.Hour)},
{ID: "c"},
}
p := NewPool(creds)
// First pick: "a" (index 0, not on cooldown)
got, err := p.Pick()
if err != nil {
t.Fatalf("Pick() #1 error = %v", err)
}
if got.ID != "a" {
t.Errorf("Pick() #1 = %q, want %q", got.ID, "a")
}
// Second pick: cursor at 1, but "b" is on cooldown → skip to "c"
got, err = p.Pick()
if err != nil {
t.Fatalf("Pick() #2 error = %v", err)
}
if got.ID != "c" {
t.Errorf("Pick() #2 = %q, want %q", got.ID, "c")
}
// Third pick: cursor advanced past "c" to 0 → "a"
got, err = p.Pick()
if err != nil {
t.Fatalf("Pick() #3 error = %v", err)
}
if got.ID != "a" {
t.Errorf("Pick() #3 = %q, want %q", got.ID, "a")
}
}
func TestPool_Pick_AllOnCooldown(t *testing.T) {
future := time.Now().Add(1 * time.Hour)
creds := []*Credential{
{ID: "a", cooldownUntil: future},
{ID: "b", cooldownUntil: future},
}
p := NewPool(creds)
_, err := p.Pick()
if err == nil {
t.Fatal("expected error when all on cooldown, got nil")
}
want := "all 2 credentials are on cooldown"
if err.Error() != want {
t.Errorf("error = %q, want %q", err.Error(), want)
}
}
func TestPool_MarkFailure(t *testing.T) {
tests := []struct {
name string
statusCode int
expectCooldown bool
expectedDur time.Duration
}{
{
name: "429 sets 30s cooldown",
statusCode: 429,
expectCooldown: true,
expectedDur: 30 * time.Second,
},
{
name: "500 sets 5s cooldown",
statusCode: 500,
expectCooldown: true,
expectedDur: 5 * time.Second,
},
{
name: "502 sets 5s cooldown",
statusCode: 502,
expectCooldown: true,
expectedDur: 5 * time.Second,
},
{
name: "503 sets 5s cooldown",
statusCode: 503,
expectCooldown: true,
expectedDur: 5 * time.Second,
},
{
name: "400 does NOT set cooldown",
statusCode: 400,
expectCooldown: false,
},
{
name: "401 does NOT set cooldown",
statusCode: 401,
expectCooldown: false,
},
{
name: "403 does NOT set cooldown",
statusCode: 403,
expectCooldown: false,
},
{
name: "404 does NOT set cooldown",
statusCode: 404,
expectCooldown: false,
},
{
name: "422 does NOT set cooldown",
statusCode: 422,
expectCooldown: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cred := &Credential{ID: "test"}
p := NewPool([]*Credential{cred})
before := time.Now()
p.MarkFailure(cred, tt.statusCode)
if tt.expectCooldown {
if !cred.IsOnCooldown() {
t.Errorf("expected cooldown after status %d", tt.statusCode)
}
// Verify approximate duration
cred.mu.Lock()
cooldownEnd := cred.cooldownUntil
cred.mu.Unlock()
lower := before.Add(tt.expectedDur)
upper := time.Now().Add(tt.expectedDur)
if cooldownEnd.Before(lower) || cooldownEnd.After(upper) {
t.Errorf("cooldownUntil %v not in expected range [%v, %v]", cooldownEnd, lower, upper)
}
} else {
if cred.IsOnCooldown() {
t.Errorf("did not expect cooldown after status %d", tt.statusCode)
}
}
})
}
}
func TestPool_MarkSuccess(t *testing.T) {
cred := &Credential{
ID: "test",
cooldownUntil: time.Now().Add(1 * time.Hour),
}
p := NewPool([]*Credential{cred})
if !cred.IsOnCooldown() {
t.Fatal("precondition: expected credential to be on cooldown")
}
p.MarkSuccess(cred)
if cred.IsOnCooldown() {
t.Error("expected cooldown to be cleared after MarkSuccess")
}
}
func TestPool_RoundRobinCursorAdvancement(t *testing.T) {
creds := []*Credential{
{ID: "0"},
{ID: "1"},
{ID: "2"},
}
p := NewPool(creds)
// Verify cursor starts at 0
if p.cursor != 0 {
t.Fatalf("initial cursor = %d, want 0", p.cursor)
}
// Pick cred[0], cursor should advance to 1
got, _ := p.Pick()
if got.ID != "0" {
t.Errorf("first pick = %q, want %q", got.ID, "0")
}
if p.cursor != 1 {
t.Errorf("cursor after first pick = %d, want 1", p.cursor)
}
// Pick cred[1], cursor should advance to 2
got, _ = p.Pick()
if got.ID != "1" {
t.Errorf("second pick = %q, want %q", got.ID, "1")
}
if p.cursor != 2 {
t.Errorf("cursor after second pick = %d, want 2", p.cursor)
}
// Pick cred[2], cursor should wrap to 0
got, _ = p.Pick()
if got.ID != "2" {
t.Errorf("third pick = %q, want %q", got.ID, "2")
}
if p.cursor != 0 {
t.Errorf("cursor after third pick = %d, want 0 (wrap)", p.cursor)
}
}
func TestPool_RoundRobinWithCooldownSkip(t *testing.T) {
creds := []*Credential{
{ID: "0"},
{ID: "1", cooldownUntil: time.Now().Add(1 * time.Hour)},
{ID: "2"},
}
p := NewPool(creds)
// First pick: cred[0]
got, _ := p.Pick()
if got.ID != "0" {
t.Errorf("first pick = %q, want %q", got.ID, "0")
}
// Cursor should be at 1
if p.cursor != 1 {
t.Errorf("cursor after first pick = %d, want 1", p.cursor)
}
// Second pick: cursor at 1, but cred[1] on cooldown → skip to cred[2]
got, _ = p.Pick()
if got.ID != "2" {
t.Errorf("second pick = %q, want %q", got.ID, "2")
}
// Cursor should advance past cred[2] to 0
if p.cursor != 0 {
t.Errorf("cursor after second pick (skip) = %d, want 0", p.cursor)
}
// Third pick: cursor at 0, cred[0] available
got, _ = p.Pick()
if got.ID != "0" {
t.Errorf("third pick = %q, want %q", got.ID, "0")
}
if p.cursor != 1 {
t.Errorf("cursor after third pick = %d, want 1", p.cursor)
}
}
+4 -4
View File
@@ -13,7 +13,7 @@ type Credential struct {
RefreshToken string RefreshToken string
ExpiresAt time.Time ExpiresAt time.Time
FilePath string FilePath string
CooldownUntil time.Time cooldownUntil time.Time
nextRefreshAfter time.Time nextRefreshAfter time.Time
mu sync.Mutex mu sync.Mutex
} }
@@ -22,21 +22,21 @@ type Credential struct {
func (c *Credential) IsOnCooldown() bool { func (c *Credential) IsOnCooldown() bool {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
return time.Now().Before(c.CooldownUntil) return time.Now().Before(c.cooldownUntil)
} }
// SetCooldown puts the credential on cooldown for the given duration. // SetCooldown puts the credential on cooldown for the given duration.
func (c *Credential) SetCooldown(duration time.Duration) { func (c *Credential) SetCooldown(duration time.Duration) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
c.CooldownUntil = time.Now().Add(duration) c.cooldownUntil = time.Now().Add(duration)
} }
// ClearCooldown removes any active cooldown on the credential. // ClearCooldown removes any active cooldown on the credential.
func (c *Credential) ClearCooldown() { func (c *Credential) ClearCooldown() {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
c.CooldownUntil = time.Time{} c.cooldownUntil = time.Time{}
} }
// Token returns the current access token. // Token returns the current access token.
+167
View File
@@ -0,0 +1,167 @@
package auth
import (
"sync"
"testing"
"time"
)
func TestCredential_IsOnCooldown(t *testing.T) {
tests := []struct {
name string
cooldownUntil time.Time
want bool
}{
{
name: "zero time — not on cooldown",
cooldownUntil: time.Time{},
want: false,
},
{
name: "future time — on cooldown",
cooldownUntil: time.Now().Add(1 * time.Hour),
want: true,
},
{
name: "past time — expired cooldown",
cooldownUntil: time.Now().Add(-1 * time.Hour),
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Credential{cooldownUntil: tt.cooldownUntil}
got := c.IsOnCooldown()
if got != tt.want {
t.Errorf("IsOnCooldown() = %v, want %v", got, tt.want)
}
})
}
}
func TestCredential_SetCooldown(t *testing.T) {
tests := []struct {
name string
duration time.Duration
}{
{name: "30 second cooldown", duration: 30 * time.Second},
{name: "5 second cooldown", duration: 5 * time.Second},
{name: "1 minute cooldown", duration: 1 * time.Minute},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Credential{}
before := time.Now()
c.SetCooldown(tt.duration)
after := time.Now()
// cooldownUntil should be between before+duration and after+duration
if c.cooldownUntil.Before(before.Add(tt.duration)) {
t.Errorf("cooldownUntil %v is before expected lower bound %v", c.cooldownUntil, before.Add(tt.duration))
}
if c.cooldownUntil.After(after.Add(tt.duration)) {
t.Errorf("cooldownUntil %v is after expected upper bound %v", c.cooldownUntil, after.Add(tt.duration))
}
// Should now be on cooldown
if !c.IsOnCooldown() {
t.Error("expected credential to be on cooldown after SetCooldown")
}
})
}
}
func TestCredential_ClearCooldown(t *testing.T) {
t.Run("clears active cooldown", func(t *testing.T) {
c := &Credential{cooldownUntil: time.Now().Add(1 * time.Hour)}
if !c.IsOnCooldown() {
t.Fatal("precondition: expected credential to be on cooldown")
}
c.ClearCooldown()
if c.IsOnCooldown() {
t.Error("expected credential to not be on cooldown after ClearCooldown")
}
if !c.cooldownUntil.IsZero() {
t.Errorf("expected cooldownUntil to be zero time, got %v", c.cooldownUntil)
}
})
t.Run("clearing when not on cooldown is no-op", func(t *testing.T) {
c := &Credential{}
c.ClearCooldown()
if c.IsOnCooldown() {
t.Error("expected credential to not be on cooldown")
}
if !c.cooldownUntil.IsZero() {
t.Errorf("expected cooldownUntil to be zero time, got %v", c.cooldownUntil)
}
})
}
func TestCredential_Token(t *testing.T) {
tests := []struct {
name string
token string
}{
{name: "returns access token", token: "sk-ant-abc123"},
{name: "empty token", token: ""},
{name: "long token", token: "sk-ant-" + string(make([]byte, 200))},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Credential{AccessToken: tt.token}
got := c.Token()
if got != tt.token {
t.Errorf("Token() = %q, want %q", got, tt.token)
}
})
}
}
func TestCredential_ConcurrentAccess(t *testing.T) {
c := &Credential{
AccessToken: "initial-token",
}
var wg sync.WaitGroup
const goroutines = 50
// Spawn goroutines that concurrently read and write
for i := 0; i < goroutines; i++ {
wg.Add(3)
go func() {
defer wg.Done()
_ = c.Token()
}()
go func() {
defer wg.Done()
c.SetCooldown(1 * time.Second)
}()
go func() {
defer wg.Done()
_ = c.IsOnCooldown()
}()
}
// Also mix in ClearCooldown calls
for i := 0; i < goroutines/2; i++ {
wg.Add(1)
go func() {
defer wg.Done()
c.ClearCooldown()
}()
}
wg.Wait()
// If we get here without -race detecting issues, mutex is working
}
+71 -50
View File
@@ -1,21 +1,19 @@
package config package config
import ( import (
"encoding/json"
"fmt" "fmt"
"os" "os"
"time"
"github.com/fujin/anthropic-proxy/internal/auth"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
type Config struct { type Config struct {
Port int `yaml:"port"` Port int `yaml:"port"`
APIKeys []string `yaml:"api_keys"` APIKeys []string `yaml:"api_keys"`
ClaudeCredentials string `yaml:"claude_credentials"`
ClaudeBinary string `yaml:"claude_binary"` ClaudeBinary string `yaml:"claude_binary"`
Sanitize SanitizeConfig `yaml:"sanitize"` Sanitize SanitizeConfig `yaml:"sanitize"`
Logging LoggingConfig `yaml:"logging"`
Telemetry TelemetryConfig `yaml:"telemetry"`
} }
type SanitizeConfig struct { type SanitizeConfig struct {
@@ -34,13 +32,36 @@ type ReplaceRule struct {
Replace string `yaml:"replace"` Replace string `yaml:"replace"`
} }
type claudeCredentialsJSON struct { type TelemetryConfig struct {
ClaudeAiOauth struct { Export ExportConfig `yaml:"export"`
AccessToken string `json:"accessToken"` Embedded EmbeddedConfig `yaml:"embedded"`
RefreshToken string `json:"refreshToken"` ServiceName string `yaml:"service_name"`
ExpiresAt int64 `json:"expiresAt"` }
SubscriptionType string `json:"subscriptionType"`
} `json:"claudeAiOauth"` type ExportConfig struct {
Endpoint string `yaml:"endpoint"`
Insecure bool `yaml:"insecure"`
Headers map[string]string `yaml:"headers"`
}
func (e ExportConfig) Enabled() bool { return e.Endpoint != "" }
type EmbeddedConfig struct {
Enabled bool `yaml:"enabled"`
Port int `yaml:"port"`
PersesBinary string `yaml:"perses_binary"`
VMBinary string `yaml:"vm_binary"`
VMPort int `yaml:"vm_port"`
BinDir string `yaml:"bin_dir"`
}
type LoggingConfig struct {
Level string `yaml:"level"`
File string `yaml:"file"`
MaxSizeMB int `yaml:"max_size_mb"`
MaxBackups int `yaml:"max_backups"`
MaxAgeDays int `yaml:"max_age_days"`
Compress bool `yaml:"compress"`
} }
func Load(path string) (*Config, error) { func Load(path string) (*Config, error) {
@@ -54,44 +75,44 @@ func Load(path string) (*Config, error) {
return nil, fmt.Errorf("parse config: %w", err) return nil, fmt.Errorf("parse config: %w", err)
} }
if cfg.Logging.Level == "" {
cfg.Logging.Level = "info"
}
if cfg.Logging.MaxSizeMB == 0 {
cfg.Logging.MaxSizeMB = 100
}
if cfg.Logging.MaxBackups == 0 {
cfg.Logging.MaxBackups = 5
}
if cfg.Logging.MaxAgeDays == 0 {
cfg.Logging.MaxAgeDays = 30
}
if cfg.Telemetry.ServiceName == "" {
cfg.Telemetry.ServiceName = "anthropic-proxy"
}
if cfg.Telemetry.Embedded.Port == 0 {
cfg.Telemetry.Embedded.Port = 8080
}
if cfg.Telemetry.Embedded.PersesBinary == "" {
cfg.Telemetry.Embedded.PersesBinary = "perses"
}
if cfg.Telemetry.Embedded.VMBinary == "" {
cfg.Telemetry.Embedded.VMBinary = "victoria-metrics"
}
if cfg.Telemetry.Embedded.VMPort == 0 {
cfg.Telemetry.Embedded.VMPort = 8428
}
// Check for deprecated claude_credentials field
var rawCfg map[string]interface{}
if err := yaml.Unmarshal(data, &rawCfg); err == nil {
if _, exists := rawCfg["claude_credentials"]; exists {
if val, ok := rawCfg["claude_credentials"].(string); ok && val != "" {
return nil, fmt.Errorf("claude_credentials is no longer supported, remove it from config.yaml — the proxy now manages credentials at ~/.claude/.credentials.json")
}
}
}
return cfg, nil return cfg, nil
} }
func LoadCredentials(cfg *Config) ([]*auth.Credential, error) {
if cfg.ClaudeCredentials == "" {
return nil, fmt.Errorf("claude_credentials not set")
}
cred, err := loadCredentials(cfg.ClaudeCredentials)
if err != nil {
return nil, err
}
return []*auth.Credential{cred}, nil
}
func loadCredentials(path string) (*auth.Credential, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var cf claudeCredentialsJSON
if err := json.Unmarshal(data, &cf); err != nil {
return nil, err
}
oauth := cf.ClaudeAiOauth
if oauth.AccessToken == "" {
return nil, fmt.Errorf("no access token in %s", path)
}
return &auth.Credential{
ID: "claude-native",
Email: oauth.SubscriptionType,
AccessToken: oauth.AccessToken,
RefreshToken: oauth.RefreshToken,
ExpiresAt: time.UnixMilli(oauth.ExpiresAt),
FilePath: path,
}, nil
}
+270
View File
@@ -0,0 +1,270 @@
package config
import (
"os"
"path/filepath"
"strings"
"testing"
)
func TestLoad_AllFields(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
yaml := `
port: 9090
api_keys:
- key1
- key2
claude_binary: /usr/bin/claude
sanitize:
tools:
- from: tool_a
to: tool_b
system:
- match: foo
replace: bar
body:
- match: baz
replace: qux
logging:
level: debug
file: /tmp/test.log
max_size_mb: 50
max_backups: 3
max_age_days: 7
compress: true
telemetry:
service_name: my-proxy
export:
endpoint: http://localhost:4317
insecure: true
headers:
x-token: abc
embedded:
enabled: true
port: 9999
perses_binary: /usr/bin/perses
vm_binary: /usr/bin/vm
vm_port: 9428
bin_dir: /opt/bin
`
if err := os.WriteFile(path, []byte(yaml), 0644); err != nil {
t.Fatal(err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load returned error: %v", err)
}
if cfg.Port != 9090 {
t.Errorf("Port = %d, want 9090", cfg.Port)
}
if len(cfg.APIKeys) != 2 || cfg.APIKeys[0] != "key1" || cfg.APIKeys[1] != "key2" {
t.Errorf("APIKeys = %v, want [key1 key2]", cfg.APIKeys)
}
if cfg.ClaudeBinary != "/usr/bin/claude" {
t.Errorf("ClaudeBinary = %q, want /usr/bin/claude", cfg.ClaudeBinary)
}
// Sanitize
if len(cfg.Sanitize.Tools) != 1 || cfg.Sanitize.Tools[0].From != "tool_a" || cfg.Sanitize.Tools[0].To != "tool_b" {
t.Errorf("Sanitize.Tools = %v", cfg.Sanitize.Tools)
}
if len(cfg.Sanitize.System) != 1 || cfg.Sanitize.System[0].Match != "foo" {
t.Errorf("Sanitize.System = %v", cfg.Sanitize.System)
}
if len(cfg.Sanitize.Body) != 1 || cfg.Sanitize.Body[0].Match != "baz" {
t.Errorf("Sanitize.Body = %v", cfg.Sanitize.Body)
}
// Logging
if cfg.Logging.Level != "debug" {
t.Errorf("Logging.Level = %q, want debug", cfg.Logging.Level)
}
if cfg.Logging.File != "/tmp/test.log" {
t.Errorf("Logging.File = %q", cfg.Logging.File)
}
if cfg.Logging.MaxSizeMB != 50 {
t.Errorf("Logging.MaxSizeMB = %d, want 50", cfg.Logging.MaxSizeMB)
}
if cfg.Logging.MaxBackups != 3 {
t.Errorf("Logging.MaxBackups = %d, want 3", cfg.Logging.MaxBackups)
}
if cfg.Logging.MaxAgeDays != 7 {
t.Errorf("Logging.MaxAgeDays = %d, want 7", cfg.Logging.MaxAgeDays)
}
if !cfg.Logging.Compress {
t.Error("Logging.Compress = false, want true")
}
// Telemetry
if cfg.Telemetry.ServiceName != "my-proxy" {
t.Errorf("Telemetry.ServiceName = %q, want my-proxy", cfg.Telemetry.ServiceName)
}
if cfg.Telemetry.Export.Endpoint != "http://localhost:4317" {
t.Errorf("Export.Endpoint = %q", cfg.Telemetry.Export.Endpoint)
}
if !cfg.Telemetry.Export.Insecure {
t.Error("Export.Insecure = false, want true")
}
if !cfg.Telemetry.Export.Enabled() {
t.Error("Export.Enabled() = false, want true")
}
if cfg.Telemetry.Export.Headers["x-token"] != "abc" {
t.Errorf("Export.Headers = %v", cfg.Telemetry.Export.Headers)
}
// Embedded
if !cfg.Telemetry.Embedded.Enabled {
t.Error("Embedded.Enabled = false, want true")
}
if cfg.Telemetry.Embedded.Port != 9999 {
t.Errorf("Embedded.Port = %d, want 9999", cfg.Telemetry.Embedded.Port)
}
if cfg.Telemetry.Embedded.PersesBinary != "/usr/bin/perses" {
t.Errorf("Embedded.PersesBinary = %q", cfg.Telemetry.Embedded.PersesBinary)
}
if cfg.Telemetry.Embedded.VMBinary != "/usr/bin/vm" {
t.Errorf("Embedded.VMBinary = %q", cfg.Telemetry.Embedded.VMBinary)
}
if cfg.Telemetry.Embedded.VMPort != 9428 {
t.Errorf("Embedded.VMPort = %d, want 9428", cfg.Telemetry.Embedded.VMPort)
}
if cfg.Telemetry.Embedded.BinDir != "/opt/bin" {
t.Errorf("Embedded.BinDir = %q", cfg.Telemetry.Embedded.BinDir)
}
}
func TestLoad_Defaults(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
// Minimal YAML — only api_keys
if err := os.WriteFile(path, []byte("api_keys:\n - k1\n"), 0644); err != nil {
t.Fatal(err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load returned error: %v", err)
}
tests := []struct {
name string
got interface{}
want interface{}
}{
{"Port", cfg.Port, 8080},
{"Logging.Level", cfg.Logging.Level, "info"},
{"Logging.MaxSizeMB", cfg.Logging.MaxSizeMB, 100},
{"Logging.MaxBackups", cfg.Logging.MaxBackups, 5},
{"Logging.MaxAgeDays", cfg.Logging.MaxAgeDays, 30},
{"Telemetry.ServiceName", cfg.Telemetry.ServiceName, "anthropic-proxy"},
{"Embedded.Port", cfg.Telemetry.Embedded.Port, 8080},
{"Embedded.VMBinary", cfg.Telemetry.Embedded.VMBinary, "victoria-metrics"},
{"Embedded.PersesBinary", cfg.Telemetry.Embedded.PersesBinary, "perses"},
{"Embedded.VMPort", cfg.Telemetry.Embedded.VMPort, 8428},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.got != tt.want {
t.Errorf("got %v, want %v", tt.got, tt.want)
}
})
}
}
func TestLoad_MissingFile(t *testing.T) {
_, err := Load("/nonexistent/path/config.yaml")
if err == nil {
t.Fatal("expected error for missing file, got nil")
}
if !strings.Contains(err.Error(), "read config") {
t.Errorf("error = %q, want it to contain 'read config'", err.Error())
}
}
func TestLoad_DeprecatedClaudeCredentials(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
yaml := `
api_keys:
- k1
claude_credentials: "/some/path"
`
if err := os.WriteFile(path, []byte(yaml), 0644); err != nil {
t.Fatal(err)
}
_, err := Load(path)
if err == nil {
t.Fatal("expected error for deprecated claude_credentials, got nil")
}
if !strings.Contains(err.Error(), "no longer supported") {
t.Errorf("error = %q, want it to contain 'no longer supported'", err.Error())
}
}
func TestLoad_EmptyClaudeCredentials(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
// Empty string value should NOT trigger the deprecation error
yaml := `
api_keys:
- k1
claude_credentials: ""
`
if err := os.WriteFile(path, []byte(yaml), 0644); err != nil {
t.Fatal(err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("empty claude_credentials should not error: %v", err)
}
if cfg.Port != 8080 {
t.Errorf("Port = %d, want 8080", cfg.Port)
}
}
func TestLoad_InvalidYAML(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
// Truly invalid YAML that causes a parse error
if err := os.WriteFile(path, []byte("port:\n - bad\n indent: broken\n"), 0644); err != nil {
t.Fatal(err)
}
_, err := Load(path)
if err == nil {
t.Fatal("expected error for invalid YAML, got nil")
}
if !strings.Contains(err.Error(), "parse config") {
t.Errorf("error = %q, want it to contain 'parse config'", err.Error())
}
}
func TestExportConfig_Enabled(t *testing.T) {
tests := []struct {
name string
endpoint string
want bool
}{
{"empty endpoint", "", false},
{"set endpoint", "http://localhost:4317", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := ExportConfig{Endpoint: tt.endpoint}
if got := e.Enabled(); got != tt.want {
t.Errorf("Enabled() = %v, want %v", got, tt.want)
}
})
}
}
+12
View File
@@ -0,0 +1,12 @@
package embedded
import (
"embed"
)
//go:embed dashboard/proxy.json
var dashboardFS embed.FS
func DashboardJSON() ([]byte, error) {
return dashboardFS.ReadFile("dashboard/proxy.json")
}
+450
View File
@@ -0,0 +1,450 @@
{
"kind": "Dashboard",
"metadata": {
"name": "proxy",
"createdAt": "2026-04-14T19:47:48.013238204Z",
"updatedAt": "2026-04-14T19:49:30.874125459Z",
"version": 1,
"project": "anthropic-proxy"
},
"spec": {
"display": {
"name": "Anthropic Proxy"
},
"datasources": {
"vm": {
"default": true,
"plugin": {
"kind": "PrometheusDatasource",
"spec": {
"directUrl": "http://localhost:9428"
}
}
}
},
"panels": {
"latency": {
"kind": "Panel",
"spec": {
"display": {
"name": "Latency"
},
"plugin": {
"kind": "TimeSeriesChart",
"spec": {
"legend": {
"position": "bottom"
},
"yAxis": {
"format": {
"unit": "milliseconds"
}
}
}
},
"queries": [
{
"kind": "TimeSeriesQuery",
"spec": {
"plugin": {
"kind": "PrometheusTimeSeriesQuery",
"spec": {
"datasource": {
"kind": "PrometheusDatasource",
"name": "vm"
},
"query": "histogram_quantile(0.50, rate(proxy_request_duration_ms_milliseconds_bucket[5m]))",
"seriesNameFormat": "p50"
}
}
}
},
{
"kind": "TimeSeriesQuery",
"spec": {
"plugin": {
"kind": "PrometheusTimeSeriesQuery",
"spec": {
"datasource": {
"kind": "PrometheusDatasource",
"name": "vm"
},
"query": "histogram_quantile(0.95, rate(proxy_request_duration_ms_milliseconds_bucket[5m]))",
"seriesNameFormat": "p95"
}
}
}
},
{
"kind": "TimeSeriesQuery",
"spec": {
"plugin": {
"kind": "PrometheusTimeSeriesQuery",
"spec": {
"datasource": {
"kind": "PrometheusDatasource",
"name": "vm"
},
"query": "histogram_quantile(0.99, rate(proxy_request_duration_ms_milliseconds_bucket[5m]))",
"seriesNameFormat": "p99"
}
}
}
}
]
}
},
"request_rate": {
"kind": "Panel",
"spec": {
"display": {
"name": "Request Rate"
},
"plugin": {
"kind": "TimeSeriesChart",
"spec": {
"legend": {
"position": "bottom"
}
}
},
"queries": [
{
"kind": "TimeSeriesQuery",
"spec": {
"plugin": {
"kind": "PrometheusTimeSeriesQuery",
"spec": {
"datasource": {
"kind": "PrometheusDatasource",
"name": "vm"
},
"query": "rate(proxy_request_count_total[5m])",
"seriesNameFormat": "req/s"
}
}
}
}
]
}
},
"token_rate": {
"kind": "Panel",
"spec": {
"display": {
"name": "Token Rate"
},
"plugin": {
"kind": "TimeSeriesChart",
"spec": {
"legend": {
"position": "bottom"
}
}
},
"queries": [
{
"kind": "TimeSeriesQuery",
"spec": {
"plugin": {
"kind": "PrometheusTimeSeriesQuery",
"spec": {
"datasource": {
"kind": "PrometheusDatasource",
"name": "vm"
},
"query": "rate(proxy_tokens_input_total[5m]) * 60",
"seriesNameFormat": "input/min"
}
}
}
},
{
"kind": "TimeSeriesQuery",
"spec": {
"plugin": {
"kind": "PrometheusTimeSeriesQuery",
"spec": {
"datasource": {
"kind": "PrometheusDatasource",
"name": "vm"
},
"query": "rate(proxy_tokens_output_total[5m]) * 60",
"seriesNameFormat": "output/min"
}
}
}
}
]
}
},
"tokens_5h": {
"kind": "Panel",
"spec": {
"display": {
"name": "5h Tokens"
},
"plugin": {
"kind": "StatChart",
"spec": {
"calculation": "last",
"format": {
"unit": "decimal"
},
"sparkline": {}
}
},
"queries": [
{
"kind": "TimeSeriesQuery",
"spec": {
"plugin": {
"kind": "PrometheusTimeSeriesQuery",
"spec": {
"datasource": {
"kind": "PrometheusDatasource",
"name": "vm"
},
"query": "increase(proxy_tokens_output_total[3h])"
}
}
}
}
]
}
},
"tokens_7d": {
"kind": "Panel",
"spec": {
"display": {
"name": "7d Tokens"
},
"plugin": {
"kind": "StatChart",
"spec": {
"calculation": "last",
"format": {
"unit": "decimal"
},
"sparkline": {}
}
},
"queries": [
{
"kind": "TimeSeriesQuery",
"spec": {
"plugin": {
"kind": "PrometheusTimeSeriesQuery",
"spec": {
"datasource": {
"kind": "PrometheusDatasource",
"name": "vm"
},
"query": "increase(proxy_tokens_output_total[9h])"
}
}
}
}
]
}
},
"util_5h": {
"kind": "Panel",
"spec": {
"display": {
"name": "5h Utilization"
},
"plugin": {
"kind": "GaugeChart",
"spec": {
"calculation": "last",
"format": {
"unit": "percent"
},
"thresholds": {
"steps": [
{
"color": "green",
"value": 0
},
{
"color": "orange",
"value": 70
},
{
"color": "red",
"value": 90
}
]
}
}
},
"queries": [
{
"kind": "TimeSeriesQuery",
"spec": {
"plugin": {
"kind": "PrometheusTimeSeriesQuery",
"spec": {
"datasource": {
"kind": "PrometheusDatasource",
"name": "vm"
},
"query": "proxy_usage_utilization{window=\"5h\"}"
}
}
}
}
]
}
},
"util_7d": {
"kind": "Panel",
"spec": {
"display": {
"name": "7d Utilization"
},
"plugin": {
"kind": "GaugeChart",
"spec": {
"calculation": "last",
"format": {
"unit": "percent"
},
"thresholds": {
"steps": [
{
"color": "green",
"value": 0
},
{
"color": "orange",
"value": 70
},
{
"color": "red",
"value": 90
}
]
}
}
},
"queries": [
{
"kind": "TimeSeriesQuery",
"spec": {
"plugin": {
"kind": "PrometheusTimeSeriesQuery",
"spec": {
"datasource": {
"kind": "PrometheusDatasource",
"name": "vm"
},
"query": "proxy_usage_utilization{window=\"7d\"}"
}
}
}
}
]
}
}
},
"layouts": [
{
"kind": "Grid",
"spec": {
"display": {
"title": "Utilization"
},
"items": [
{
"x": 0,
"y": 0,
"width": 6,
"height": 5,
"content": {
"$ref": "#/spec/panels/util_5h"
}
},
{
"x": 6,
"y": 0,
"width": 6,
"height": 5,
"content": {
"$ref": "#/spec/panels/util_7d"
}
},
{
"x": 12,
"y": 0,
"width": 6,
"height": 5,
"content": {
"$ref": "#/spec/panels/tokens_5h"
}
},
{
"x": 18,
"y": 0,
"width": 6,
"height": 5,
"content": {
"$ref": "#/spec/panels/tokens_7d"
}
}
]
}
},
{
"kind": "Grid",
"spec": {
"display": {
"title": "Traffic"
},
"items": [
{
"x": 0,
"y": 0,
"width": 12,
"height": 8,
"content": {
"$ref": "#/spec/panels/request_rate"
}
},
{
"x": 12,
"y": 0,
"width": 12,
"height": 8,
"content": {
"$ref": "#/spec/panels/latency"
}
}
]
}
},
{
"kind": "Grid",
"spec": {
"display": {
"title": "Tokens"
},
"items": [
{
"x": 0,
"y": 0,
"width": 24,
"height": 8,
"content": {
"$ref": "#/spec/panels/token_rate"
}
}
]
}
}
],
"duration": "1h",
"refreshInterval": "10s"
}
}
+155
View File
@@ -0,0 +1,155 @@
package embedded
import (
"archive/tar"
"compress/gzip"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"github.com/rs/zerolog/log"
)
const cacheDir = ".cache/anthropic-proxy/bin"
var downloads = map[string]struct {
urlTemplate string
version string
extractName string
}{
"victoria-metrics": {
urlTemplate: "https://github.com/VictoriaMetrics/VictoriaMetrics/releases/download/v%s/victoria-metrics-%s-v%s.tar.gz",
version: "1.118.0",
extractName: "victoria-metrics-prod",
},
"perses": {
urlTemplate: "https://github.com/perses/perses/releases/download/v%s/perses_%s_%s_%s.tar.gz",
version: "0.53.1",
},
}
func ensureBinary(name, configPath, configBinDir string) (string, error) {
if configPath != "" {
if p, err := exec.LookPath(configPath); err == nil {
return p, nil
}
}
if p, err := exec.LookPath(name); err == nil {
return p, nil
}
binDir := configBinDir
if binDir == "" {
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("get home dir: %w", err)
}
binDir = filepath.Join(home, cacheDir)
}
cachedPath := filepath.Join(binDir, name)
if _, err := os.Stat(cachedPath); err == nil {
return cachedPath, nil
}
log.Info().Str("binary", name).Msg("downloading binary (first run)")
if err := os.MkdirAll(binDir, 0o755); err != nil {
return "", fmt.Errorf("create cache dir: %w", err)
}
url, err := downloadURL(name)
if err != nil {
return "", err
}
if err := extractAll(url, binDir); err != nil {
return "", fmt.Errorf("download %s: %w", name, err)
}
d := downloads[name]
if d.extractName != "" {
oldPath := filepath.Join(binDir, d.extractName)
if _, err := os.Stat(oldPath); err == nil {
os.Rename(oldPath, cachedPath)
}
}
if _, err := os.Stat(cachedPath); err != nil {
return "", fmt.Errorf("binary %s not found after extraction", name)
}
log.Info().Str("binary", name).Str("path", cachedPath).Msg("binary downloaded")
return cachedPath, nil
}
func downloadURL(name string) (string, error) {
goarch := runtime.GOARCH
goos := runtime.GOOS
d, ok := downloads[name]
if !ok {
return "", fmt.Errorf("unknown binary: %s", name)
}
switch name {
case "victoria-metrics":
vmOS := fmt.Sprintf("%s-%s", goos, goarch)
return fmt.Sprintf(d.urlTemplate, d.version, vmOS, d.version), nil
case "perses":
return fmt.Sprintf(d.urlTemplate, d.version, d.version, goos, goarch), nil
}
return "", fmt.Errorf("unknown binary: %s", name)
}
func extractAll(url, destDir string) error {
resp, err := http.Get(url)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return fmt.Errorf("download failed: HTTP %d from %s", resp.StatusCode, url)
}
gz, err := gzip.NewReader(resp.Body)
if err != nil {
return fmt.Errorf("gzip reader: %w", err)
}
defer gz.Close()
tr := tar.NewReader(gz)
for {
hdr, err := tr.Next()
if err == io.EOF {
return nil
}
if err != nil {
return fmt.Errorf("read tar: %w", err)
}
target := filepath.Join(destDir, hdr.Name)
switch hdr.Typeflag {
case tar.TypeDir:
os.MkdirAll(target, 0o755)
case tar.TypeReg:
os.MkdirAll(filepath.Dir(target), 0o755)
mode := os.FileMode(hdr.Mode)
if mode == 0 {
mode = 0o644
}
out, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
if err != nil {
return err
}
io.Copy(out, tr)
out.Close()
}
}
}
+20
View File
@@ -0,0 +1,20 @@
package embedded
import "github.com/rs/zerolog/log"
// logWriter bridges subprocess stdout/stderr to zerolog.
type logWriter struct {
level string
component string
}
func (w *logWriter) Write(p []byte) (n int, err error) {
msg := string(p)
switch w.level {
case "error":
log.Error().Str("component", w.component).Msg(msg)
default:
log.Debug().Str("component", w.component).Msg(msg)
}
return len(p), nil
}
+133
View File
@@ -0,0 +1,133 @@
package embedded
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"github.com/fujin/anthropic-proxy/internal/config"
"github.com/rs/zerolog/log"
)
type Perses struct {
cfg config.EmbeddedConfig
proxyPort int
cmd *exec.Cmd
tmpDir string
}
func NewPerses(cfg config.EmbeddedConfig, proxyPort int) *Perses {
return &Perses{cfg: cfg, proxyPort: proxyPort}
}
func (p *Perses) Start() error {
bin, err := ensureBinary("perses", p.cfg.PersesBinary, p.cfg.BinDir)
if err != nil {
return fmt.Errorf("perses: %w", err)
}
p.tmpDir, err = os.MkdirTemp("", "perses-*")
if err != nil {
return fmt.Errorf("create temp dir: %w", err)
}
if err := p.writeServerConfig(); err != nil {
return fmt.Errorf("write server config: %w", err)
}
if err := p.writeDatasourceProvision(); err != nil {
return fmt.Errorf("write datasource provision: %w", err)
}
if err := p.writeDashboardProvision(); err != nil {
return fmt.Errorf("write dashboard provision: %w", err)
}
p.cmd = exec.Command(bin,
"--config", filepath.Join(p.tmpDir, "config.yaml"),
"-web.listen-address", fmt.Sprintf(":%d", p.cfg.Port),
)
p.cmd.Dir = filepath.Dir(bin)
p.cmd.Stdout = &logWriter{level: "info", component: "perses"}
p.cmd.Stderr = &logWriter{level: "error", component: "perses"}
if err := p.cmd.Start(); err != nil {
return fmt.Errorf("start perses: %w", err)
}
log.Info().
Str("binary", bin).
Int("port", p.cfg.Port).
Str("config", p.tmpDir).
Msg("perses started")
return nil
}
func (p *Perses) Stop() {
if p.cmd != nil && p.cmd.Process != nil {
_ = p.cmd.Process.Kill()
_ = p.cmd.Wait()
}
if p.tmpDir != "" {
_ = os.RemoveAll(p.tmpDir)
}
}
func (p *Perses) Running() bool {
return p.cmd != nil && p.cmd.Process != nil && p.cmd.ProcessState == nil
}
func (p *Perses) writeServerConfig() error {
provisionDir := filepath.Join(p.tmpDir, "provisions")
if err := os.MkdirAll(filepath.Join(provisionDir, "datasources"), 0o755); err != nil {
return err
}
if err := os.MkdirAll(filepath.Join(provisionDir, "dashboards"), 0o755); err != nil {
return err
}
cfg := fmt.Sprintf(`provisioning:
interval: 1m
folders:
- %s
database:
file:
folder: %s/data
extension: json
security:
readonly: false
enable_auth: false
`, provisionDir, p.tmpDir)
return os.WriteFile(filepath.Join(p.tmpDir, "config.yaml"), []byte(cfg), 0o644)
}
func (p *Perses) writeDatasourceProvision() error {
ds := fmt.Sprintf(`kind: Datasource
metadata:
name: victoria-metrics
project: anthropic-proxy
spec:
default: true
plugin:
kind: PrometheusDatasource
spec:
directUrl: http://localhost:%d
`, p.cfg.VMPort)
return os.WriteFile(
filepath.Join(p.tmpDir, "provisions", "datasources", "vm.yaml"),
[]byte(ds), 0o644,
)
}
func (p *Perses) writeDashboardProvision() error {
dashData, err := DashboardJSON()
if err != nil {
return err
}
return os.WriteFile(
filepath.Join(p.tmpDir, "provisions", "dashboards", "proxy.json"),
dashData, 0o644,
)
}
+88
View File
@@ -0,0 +1,88 @@
package embedded
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"github.com/fujin/anthropic-proxy/internal/config"
"github.com/rs/zerolog/log"
)
type VM struct {
cfg config.EmbeddedConfig
proxyPort int
cmd *exec.Cmd
tmpDir string
}
func NewVM(cfg config.EmbeddedConfig, proxyPort int) *VM {
return &VM{cfg: cfg, proxyPort: proxyPort}
}
func (v *VM) Start() error {
bin, err := ensureBinary("victoria-metrics", v.cfg.VMBinary, v.cfg.BinDir)
if err != nil {
return fmt.Errorf("victoria-metrics: %w", err)
}
v.tmpDir, err = os.MkdirTemp("", "vm-*")
if err != nil {
return fmt.Errorf("create temp dir: %w", err)
}
scrapeConfig := fmt.Sprintf(`global:
scrape_interval: 15s
scrape_configs:
- job_name: anthropic-proxy
static_configs:
- targets:
- localhost:%d
`, v.proxyPort)
scrapePath := filepath.Join(v.tmpDir, "scrape.yaml")
if err := os.WriteFile(scrapePath, []byte(scrapeConfig), 0o644); err != nil {
return fmt.Errorf("write scrape config: %w", err)
}
dataPath := filepath.Join(v.tmpDir, "data")
if err := os.MkdirAll(dataPath, 0o755); err != nil {
return fmt.Errorf("create data dir: %w", err)
}
v.cmd = exec.Command(bin,
"-storageDataPath", dataPath,
"-retentionPeriod", "7d",
"-httpListenAddr", fmt.Sprintf(":%d", v.cfg.VMPort),
"-promscrape.config", scrapePath,
)
v.cmd.Stdout = &logWriter{level: "info", component: "victoria-metrics"}
v.cmd.Stderr = &logWriter{level: "error", component: "victoria-metrics"}
if err := v.cmd.Start(); err != nil {
return fmt.Errorf("start victoria-metrics: %w", err)
}
log.Info().
Str("binary", bin).
Int("port", v.cfg.VMPort).
Int("scrape_target_port", v.proxyPort).
Msg("victoria-metrics started")
return nil
}
func (v *VM) Stop() {
if v.cmd != nil && v.cmd.Process != nil {
_ = v.cmd.Process.Kill()
_ = v.cmd.Wait()
}
if v.tmpDir != "" {
_ = os.RemoveAll(v.tmpDir)
}
}
func (v *VM) Running() bool {
return v.cmd != nil && v.cmd.Process != nil && v.cmd.ProcessState == nil
}
+148
View File
@@ -0,0 +1,148 @@
package logging
import (
"context"
"encoding/json"
"io"
"net/http"
"os"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"gopkg.in/lumberjack.v2"
"github.com/fujin/anthropic-proxy/internal/config"
)
// Setup initializes the global zerolog logger.
// - File set: JSON → lumberjack rotating file
// - File empty + TTY: colored ConsoleWriter → stderr
// - File empty + not TTY: JSON → stderr (for systemd journal)
// Extra writers (e.g., OTLP log bridge) are added via io.MultiWriter so logs
// are written to both the primary destination and any extra writers.
func Setup(cfg config.LoggingConfig, extraWriters ...io.Writer) zerolog.Logger {
// Parse log level
level, err := zerolog.ParseLevel(cfg.Level)
if err != nil || cfg.Level == "" {
level = zerolog.InfoLevel
}
zerolog.SetGlobalLevel(level)
zerolog.TimeFieldFormat = time.RFC3339
var logger zerolog.Logger
if cfg.File != "" {
// Production: JSON to rotating file
jack := &lumberjack.Logger{
Filename: cfg.File,
MaxSize: cfg.MaxSizeMB,
MaxBackups: cfg.MaxBackups,
MaxAge: cfg.MaxAgeDays,
Compress: cfg.Compress,
}
var w io.Writer = jack
if len(extraWriters) > 0 {
w = io.MultiWriter(append([]io.Writer{jack}, extraWriters...)...)
}
logger = zerolog.New(w).With().Timestamp().Caller().Logger()
} else {
fi, err := os.Stderr.Stat()
isTTY := err == nil && (fi.Mode()&os.ModeCharDevice) != 0
if isTTY {
// Dev mode: colored console (extra writers get JSON, console gets pretty)
cw := zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: time.RFC3339,
}
var w io.Writer = cw
if len(extraWriters) > 0 {
w = io.MultiWriter(append([]io.Writer{cw}, extraWriters...)...)
}
logger = zerolog.New(w).With().Timestamp().Caller().Logger()
} else {
// Systemd journal: JSON to stderr
var w io.Writer = os.Stderr
if len(extraWriters) > 0 {
w = io.MultiWriter(append([]io.Writer{os.Stderr}, extraWriters...)...)
}
logger = zerolog.New(w).With().Timestamp().Caller().Logger()
}
}
// Set global
log.Logger = logger
return logger
}
// GinRequestLogger returns a Gin middleware that logs every request with zerolog.
// Logs AFTER the request completes.
// Level: Info for 2xx, Warn for 4xx, Error for 5xx
func GinRequestLogger() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
method := c.Request.Method
c.Next()
status := c.Writer.Status()
latencyMs := float64(time.Since(start).Microseconds()) / 1000.0
clientIP := c.ClientIP()
requestID := c.Writer.Header().Get("X-Request-Id")
if requestID == "" {
requestID = c.GetHeader("x-client-request-id")
}
evt := log.Logger.WithLevel(statusLevel(status)).
Str("method", method).
Str("path", path).
Int("status", status).
Float64("latency_ms", latencyMs).
Str("client_ip", clientIP)
if requestID != "" {
evt = evt.Str("request_id", requestID)
}
evt.Msg("request")
}
}
func statusLevel(status int) zerolog.Level {
switch {
case status >= 500:
return zerolog.ErrorLevel
case status >= 400:
return zerolog.WarnLevel
default:
return zerolog.InfoLevel
}
}
// FromContext returns the zerolog logger from context, or the global logger.
func FromContext(ctx context.Context) *zerolog.Logger {
l := zerolog.Ctx(ctx)
if l.GetLevel() == zerolog.Disabled {
return &log.Logger
}
return l
}
// RedactHeaders serializes HTTP headers to a JSON string,
// replacing Authorization and x-api-key values with "***".
func RedactHeaders(h http.Header) string {
redacted := make(map[string]string, len(h))
for k, v := range h {
key := strings.ToLower(k)
if key == "authorization" || key == "x-api-key" {
redacted[k] = "***"
} else {
redacted[k] = strings.Join(v, ", ")
}
}
b, _ := json.Marshal(redacted)
return string(b)
}
+232
View File
@@ -0,0 +1,232 @@
package logging
import (
"context"
"encoding/json"
"net/http"
"path/filepath"
"strings"
"testing"
"github.com/rs/zerolog"
"github.com/fujin/anthropic-proxy/internal/config"
)
func TestRedactHeaders(t *testing.T) {
tests := []struct {
name string
headers http.Header
check func(t *testing.T, result string)
}{
{
name: "redacts Authorization",
headers: http.Header{
"Authorization": []string{"Bearer secret-token"},
},
check: func(t *testing.T, result string) {
var m map[string]string
if err := json.Unmarshal([]byte(result), &m); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if m["Authorization"] != "***" {
t.Errorf("Authorization = %q, want ***", m["Authorization"])
}
},
},
{
name: "redacts x-api-key",
headers: http.Header{
"X-Api-Key": []string{"sk-ant-secret"},
},
check: func(t *testing.T, result string) {
var m map[string]string
if err := json.Unmarshal([]byte(result), &m); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if m["X-Api-Key"] != "***" {
t.Errorf("X-Api-Key = %q, want ***", m["X-Api-Key"])
}
},
},
{
name: "preserves other headers",
headers: http.Header{
"Content-Type": []string{"application/json"},
"Accept": []string{"text/html", "application/json"},
},
check: func(t *testing.T, result string) {
var m map[string]string
if err := json.Unmarshal([]byte(result), &m); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if m["Content-Type"] != "application/json" {
t.Errorf("Content-Type = %q, want application/json", m["Content-Type"])
}
if m["Accept"] != "text/html, application/json" {
t.Errorf("Accept = %q, want 'text/html, application/json'", m["Accept"])
}
},
},
{
name: "case-insensitive redaction",
headers: http.Header{
"authorization": []string{"Bearer token"},
"X-API-KEY": []string{"key123"},
},
check: func(t *testing.T, result string) {
var m map[string]string
if err := json.Unmarshal([]byte(result), &m); err != nil {
t.Fatalf("unmarshal: %v", err)
}
// http.Header canonicalizes keys, but RedactHeaders lowercases for comparison
for _, v := range m {
if v != "***" {
t.Errorf("expected all values to be ***, got %q", v)
}
}
},
},
{
name: "empty headers",
headers: http.Header{},
check: func(t *testing.T, result string) {
if result != "{}" {
t.Errorf("result = %q, want {}", result)
}
},
},
{
name: "mixed sensitive and non-sensitive",
headers: http.Header{
"Authorization": []string{"Bearer tok"},
"X-Api-Key": []string{"key"},
"Content-Type": []string{"application/json"},
"X-Request-Id": []string{"abc123"},
},
check: func(t *testing.T, result string) {
var m map[string]string
if err := json.Unmarshal([]byte(result), &m); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if m["Authorization"] != "***" {
t.Errorf("Authorization = %q, want ***", m["Authorization"])
}
if m["X-Api-Key"] != "***" {
t.Errorf("X-Api-Key = %q, want ***", m["X-Api-Key"])
}
if m["Content-Type"] != "application/json" {
t.Errorf("Content-Type = %q", m["Content-Type"])
}
if m["X-Request-Id"] != "abc123" {
t.Errorf("X-Request-Id = %q", m["X-Request-Id"])
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := RedactHeaders(tt.headers)
// Result should be valid JSON
if !json.Valid([]byte(result)) {
t.Fatalf("result is not valid JSON: %q", result)
}
tt.check(t, result)
})
}
}
func TestRedactHeaders_ReturnsJSON(t *testing.T) {
h := http.Header{"Foo": []string{"bar"}}
result := RedactHeaders(h)
if !strings.HasPrefix(result, "{") || !strings.HasSuffix(result, "}") {
t.Errorf("result not JSON object: %q", result)
}
}
func TestStatusLevel(t *testing.T) {
tests := []struct {
status int
want zerolog.Level
}{
{200, zerolog.InfoLevel},
{201, zerolog.InfoLevel},
{204, zerolog.InfoLevel},
{301, zerolog.InfoLevel},
{399, zerolog.InfoLevel},
{400, zerolog.WarnLevel},
{401, zerolog.WarnLevel},
{403, zerolog.WarnLevel},
{404, zerolog.WarnLevel},
{429, zerolog.WarnLevel},
{499, zerolog.WarnLevel},
{500, zerolog.ErrorLevel},
{502, zerolog.ErrorLevel},
{503, zerolog.ErrorLevel},
{599, zerolog.ErrorLevel},
}
for _, tt := range tests {
got := statusLevel(tt.status)
if got != tt.want {
t.Errorf("statusLevel(%d) = %v, want %v", tt.status, got, tt.want)
}
}
}
func TestSetup_WithFile(t *testing.T) {
dir := t.TempDir()
logFile := filepath.Join(dir, "test.log")
logger := Setup(config.LoggingConfig{
Level: "debug",
File: logFile,
MaxSizeMB: 10,
MaxBackups: 1,
MaxAgeDays: 1,
})
// Verify logger works (no panic)
logger.Info().Msg("test message")
}
func TestSetup_WithoutFile(t *testing.T) {
// File empty — should use console or stderr mode depending on TTY
logger := Setup(config.LoggingConfig{
Level: "warn",
})
// Verify logger works (no panic)
logger.Warn().Msg("test warning")
}
func TestSetup_DefaultLevel(t *testing.T) {
// Empty level should default to info
logger := Setup(config.LoggingConfig{})
_ = logger // verify no panic
}
func TestSetup_InvalidLevel(t *testing.T) {
// Invalid level should default to info
logger := Setup(config.LoggingConfig{Level: "not-a-level"})
_ = logger // verify no panic
}
func TestFromContext_NoLogger(t *testing.T) {
// Background context has no zerolog logger — should return global
ctx := context.Background()
l := FromContext(ctx)
if l == nil {
t.Fatal("FromContext returned nil")
}
}
func TestFromContext_WithLogger(t *testing.T) {
logger := zerolog.Nop()
ctx := logger.WithContext(context.Background())
l := FromContext(ctx)
if l == nil {
t.Fatal("FromContext returned nil")
}
}
+4
View File
@@ -11,9 +11,13 @@ import (
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
// fingerprintSalt is the fixed salt used by Claude Code for billing header
// fingerprint computation. Extracted from the Claude Code CLI source.
const fingerprintSalt = "59cf53e54c78" const fingerprintSalt = "59cf53e54c78"
func computeFingerprint(firstUserMessage string, version string) string { func computeFingerprint(firstUserMessage string, version string) string {
// UTF-16 character indices sampled from the first user message, matching
// the Claude Code CLI's fingerprinting algorithm.
indices := []int{4, 7, 20} indices := []int{4, 7, 20}
runes := utf16.Encode([]rune(firstUserMessage)) runes := utf16.Encode([]rune(firstUserMessage))
var chars string var chars string
+323
View File
@@ -0,0 +1,323 @@
package proxy
import (
"encoding/hex"
"strings"
"testing"
)
func TestFingerprintSaltConstant(t *testing.T) {
if fingerprintSalt != "59cf53e54c78" {
t.Errorf("fingerprintSalt = %q, want %q", fingerprintSalt, "59cf53e54c78")
}
}
func TestComputeFingerprint_Deterministic(t *testing.T) {
a := computeFingerprint("hello world test message", "1.0.0")
b := computeFingerprint("hello world test message", "1.0.0")
if a != b {
t.Errorf("fingerprint not deterministic: %q != %q", a, b)
}
}
func TestComputeFingerprint_Length(t *testing.T) {
fp := computeFingerprint("some message here", "2.0.0")
if len(fp) != 3 {
t.Errorf("fingerprint length = %d, want 3", len(fp))
}
// Must be valid hex
if _, err := hex.DecodeString(fp + "0"); err != nil { // pad to even length for decode
// Check each char is hex individually
for _, c := range fp {
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
t.Errorf("fingerprint %q contains non-hex char %c", fp, c)
}
}
}
}
func TestComputeFingerprint_DifferentVersions(t *testing.T) {
a := computeFingerprint("same message", "1.0.0")
b := computeFingerprint("same message", "2.0.0")
if a == b {
t.Errorf("different versions should (almost certainly) produce different fingerprints")
}
}
func TestComputeFingerprint_ShortMessage(t *testing.T) {
// "hi" has only 2 chars, indices [4,7,20] all out of range → chars = "000"
fp := computeFingerprint("hi", "1.0.0")
if len(fp) != 3 {
t.Errorf("short message fingerprint length = %d, want 3", len(fp))
}
}
func TestComputeFingerprint_EmptyMessage(t *testing.T) {
// Empty → all indices out of range → chars = "000"
fp := computeFingerprint("", "1.0.0")
if len(fp) != 3 {
t.Errorf("empty message fingerprint length = %d, want 3", len(fp))
}
// Empty and short message with same version should produce same fingerprint
// since both result in chars = "000"
fpShort := computeFingerprint("hi", "1.0.0")
if fp != fpShort {
t.Errorf("empty and 'hi' should produce same fingerprint (both use '000'), got %q vs %q", fp, fpShort)
}
}
func TestComputeFingerprint_Unicode(t *testing.T) {
// Emoji: 🎉 is U+1F389, encoded as UTF-16 surrogate pair [0xD83C, 0xDF89]
// So "abcd🎉fg" in UTF-16 is [a, b, c, d, 0xD83C, 0xDF89, f, g] = 8 uint16 values
// indices [4,7,20]: runes[4]=0xD83C, runes[7]='g', runes[20]=out of range
fp := computeFingerprint("abcd🎉fg", "1.0.0")
if len(fp) != 3 {
t.Errorf("unicode fingerprint length = %d, want 3", len(fp))
}
}
func TestComputeFingerprint_CharExtraction(t *testing.T) {
// "Hello, World!" UTF-16: [H,e,l,l,o,',', ,W,o,r,l,d,!]
// indices [4,7,20]: runes[4]='o', runes[7]='W', runes[20]=out of range → "0"
// So chars should be "oW0"
// Verify by comparing to a message where we know the expected extracted chars
// Two messages that extract same chars at indices should produce same fingerprint
// "xxxxoxxWxxxxxxxxxxxx" → index 4='o', 7='W', 20=out of range → "oW0" (20 chars, index 20 out of range)
fp1 := computeFingerprint("Hello, World!", "1.0.0")
fp2 := computeFingerprint("xxxxoxxWxxxxxxxxxxxx", "1.0.0")
if fp1 != fp2 {
t.Errorf("messages with same chars at indices [4,7,20] should produce same fingerprint, got %q vs %q", fp1, fp2)
}
}
func TestComputeFingerprint_IndexBoundary(t *testing.T) {
// Message with exactly 21 chars → index 20 is valid
msg21 := "abcdefghijklmnopqrstu" // 21 chars
fp21 := computeFingerprint(msg21, "1.0.0")
// Message with exactly 20 chars → index 20 is out of range → "0"
msg20 := "abcdefghijklmnopqrst" // 20 chars
fp20 := computeFingerprint(msg20, "1.0.0")
// They should differ because index 20 produces different chars
if fp21 == fp20 {
t.Errorf("boundary test: 21-char and 20-char messages should differ at index 20")
}
}
func TestExtractFirstUserMessage(t *testing.T) {
tests := []struct {
name string
body string
expected string
}{
{
name: "simple string content",
body: `{"messages":[{"role":"user","content":"hello world"}]}`,
expected: "hello world",
},
{
name: "array content with text block",
body: `{"messages":[{"role":"user","content":[{"type":"text","text":"from array"}]}]}`,
expected: "from array",
},
{
name: "no user messages",
body: `{"messages":[{"role":"assistant","content":"I am assistant"}]}`,
expected: "",
},
{
name: "assistant only messages",
body: `{"messages":[{"role":"assistant","content":"a1"},{"role":"assistant","content":"a2"}]}`,
expected: "",
},
{
name: "user with non-text block first then text",
body: `{"messages":[{"role":"user","content":[{"type":"image","source":"x"},{"type":"text","text":"the text"}]}]}`,
expected: "the text",
},
{
name: "user with only non-text blocks",
body: `{"messages":[{"role":"user","content":[{"type":"image","source":"x"}]}]}`,
expected: "",
},
{
name: "no messages field",
body: `{"model":"claude-sonnet-4-6"}`,
expected: "",
},
{
name: "messages not array",
body: `{"messages":"not array"}`,
expected: "",
},
{
name: "empty messages array",
body: `{"messages":[]}`,
expected: "",
},
{
name: "first user message used even if multiple exist",
body: `{"messages":[{"role":"user","content":"first"},{"role":"user","content":"second"}]}`,
expected: "first",
},
{
name: "assistant before user",
body: `{"messages":[{"role":"assistant","content":"assistant msg"},{"role":"user","content":"user msg"}]}`,
expected: "user msg",
},
{
name: "user with array content - first text block used",
body: `{"messages":[{"role":"user","content":[{"type":"text","text":"first text"},{"type":"text","text":"second text"}]}]}`,
expected: "first text",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractFirstUserMessage([]byte(tt.body))
if result != tt.expected {
t.Errorf("got %q, want %q", result, tt.expected)
}
})
}
}
func TestExtractFirstUserMessage_BreaksAfterFirstUser(t *testing.T) {
// The function should break after finding the first user message,
// even if it didn't extract text (e.g. user with only image blocks)
body := `{"messages":[{"role":"user","content":[{"type":"image","source":"x"}]},{"role":"user","content":"second user"}]}`
result := extractFirstUserMessage([]byte(body))
// First user has no text blocks, function breaks, returns ""
if result != "" {
t.Errorf("should return empty when first user has no text, got %q", result)
}
}
func TestBuildBillingHeader(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":"test message"}]}`)
version := "1.2.3"
header := buildBillingHeader(body, version)
// Check format
if !strings.HasPrefix(header, "x-anthropic-billing-header: cc_version=1.2.3.") {
t.Errorf("header should start with 'x-anthropic-billing-header: cc_version=1.2.3.', got %q", header)
}
if !strings.Contains(header, "; cc_entrypoint=cli; cch=00000;") {
t.Errorf("header should contain '; cc_entrypoint=cli; cch=00000;', got %q", header)
}
// Verify the fingerprint part is 3 chars
// Format: "x-anthropic-billing-header: cc_version=1.2.3.XXX; cc_entrypoint=cli; cch=00000;"
parts := strings.Split(header, "cc_version=")
if len(parts) != 2 {
t.Fatalf("unexpected header format: %q", header)
}
versionFP := strings.Split(parts[1], ";")[0]
if !strings.HasPrefix(versionFP, "1.2.3.") {
t.Errorf("version+fingerprint should start with '1.2.3.', got %q", versionFP)
}
fp := strings.TrimPrefix(versionFP, "1.2.3.")
if len(fp) != 3 {
t.Errorf("fingerprint should be 3 chars, got %q (len %d)", fp, len(fp))
}
}
func TestBuildBillingHeader_EmptyMessages(t *testing.T) {
body := []byte(`{"messages":[]}`)
version := "1.0.0"
header := buildBillingHeader(body, version)
if !strings.HasPrefix(header, "x-anthropic-billing-header: cc_version=") {
t.Errorf("header format wrong: %q", header)
}
}
func TestInjectBillingHeader_NoExistingSystem(t *testing.T) {
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
result := injectBillingHeader(body, "1.0.0")
resultStr := string(result)
// Should have system field now
if !strings.Contains(resultStr, `"system"`) {
t.Errorf("should inject system field, got %s", resultStr)
}
// System should be an array with one billing block
if !strings.Contains(resultStr, "x-anthropic-billing-header") {
t.Errorf("should contain billing header text, got %s", resultStr)
}
if !strings.Contains(resultStr, `"type":"text"`) {
t.Errorf("billing block should have type text, got %s", resultStr)
}
}
func TestInjectBillingHeader_ExistingSystemArray(t *testing.T) {
body := []byte(`{"system":[{"type":"text","text":"existing prompt"}],"messages":[{"role":"user","content":"hi"}]}`)
result := injectBillingHeader(body, "1.0.0")
resultStr := string(result)
// Should contain both billing header and existing prompt
if !strings.Contains(resultStr, "x-anthropic-billing-header") {
t.Errorf("should contain billing header, got %s", resultStr)
}
if !strings.Contains(resultStr, "existing prompt") {
t.Errorf("should preserve existing prompt, got %s", resultStr)
}
// Billing block should be FIRST (prepended)
billingIdx := strings.Index(resultStr, "x-anthropic-billing-header")
existingIdx := strings.Index(resultStr, "existing prompt")
if billingIdx > existingIdx {
t.Errorf("billing block should come before existing prompt")
}
}
func TestInjectBillingHeader_ExistingSystemString(t *testing.T) {
body := []byte(`{"system":"You are a helpful assistant","messages":[{"role":"user","content":"hi"}]}`)
result := injectBillingHeader(body, "1.0.0")
resultStr := string(result)
// Should convert to array with billing block first, then original text
if !strings.Contains(resultStr, "x-anthropic-billing-header") {
t.Errorf("should contain billing header, got %s", resultStr)
}
if !strings.Contains(resultStr, "You are a helpful assistant") {
t.Errorf("should preserve original system string, got %s", resultStr)
}
// Billing should come first
billingIdx := strings.Index(resultStr, "x-anthropic-billing-header")
origIdx := strings.Index(resultStr, "You are a helpful assistant")
if billingIdx > origIdx {
t.Errorf("billing block should come before original system text")
}
}
func TestInjectBillingHeader_PreservesOtherFields(t *testing.T) {
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}],"max_tokens":1024}`)
result := injectBillingHeader(body, "1.0.0")
resultStr := string(result)
if !strings.Contains(resultStr, `"model":"claude-sonnet-4-6"`) {
t.Errorf("should preserve model field, got %s", resultStr)
}
if !strings.Contains(resultStr, `"max_tokens":1024`) {
t.Errorf("should preserve max_tokens field, got %s", resultStr)
}
}
func TestInjectBillingHeader_BillingBlockFormat(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":"test"}]}`)
result := injectBillingHeader(body, "2.5.0")
resultStr := string(result)
// Verify the billing block contains the correct version
if !strings.Contains(resultStr, "cc_version=2.5.0.") {
t.Errorf("billing block should contain cc_version=2.5.0., got %s", resultStr)
}
if !strings.Contains(resultStr, "cc_entrypoint=cli") {
t.Errorf("billing block should contain cc_entrypoint=cli, got %s", resultStr)
}
if !strings.Contains(resultStr, "cch=00000") {
t.Errorf("billing block should contain cch=00000, got %s", resultStr)
}
}
+184 -15
View File
@@ -2,17 +2,33 @@ package proxy
import ( import (
"bufio" "bufio"
"context"
"io" "io"
"log"
"net/http" "net/http"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
"github.com/fujin/anthropic-proxy/internal/auth" "github.com/fujin/anthropic-proxy/internal/auth"
"github.com/fujin/anthropic-proxy/internal/logging"
"github.com/fujin/anthropic-proxy/internal/ratelimit"
"github.com/fujin/anthropic-proxy/internal/telemetry"
) )
func HandleMessages(pool *auth.Pool, profile *SniffedProfile, getSanitizer func() *Sanitizer) gin.HandlerFunc { // requestInfo bundles common request context passed to logging/telemetry helpers.
type requestInfo struct {
model string
stream bool
cred *auth.Credential
body []byte
originalBody []byte
}
func HandleMessages(pool *auth.Pool, profile *SniffedProfile, getSanitizer func() *Sanitizer, tracker *ratelimit.Tracker) gin.HandlerFunc {
upstream := NewUpstreamClient(profile) upstream := NewUpstreamClient(profile)
return func(c *gin.Context) { return func(c *gin.Context) {
@@ -22,7 +38,15 @@ func HandleMessages(pool *auth.Pool, profile *SniffedProfile, getSanitizer func(
return return
} }
log.Printf("incoming: %s %s (%d bytes) model=%s", c.Request.Method, c.Request.URL.Path, len(body), gjson.GetBytes(body, "model").String()) originalBody := make([]byte, len(body))
copy(originalBody, body)
log.Info().
Str("method", c.Request.Method).
Str("path", c.Request.URL.Path).
Int("body_size", len(body)).
Str("model", gjson.GetBytes(body, "model").String()).
Msg("incoming request")
san := getSanitizer() san := getSanitizer()
body = san.SanitizeRequest(body) body = san.SanitizeRequest(body)
@@ -36,27 +60,56 @@ func HandleMessages(pool *auth.Pool, profile *SniffedProfile, getSanitizer func(
isStream := gjson.GetBytes(body, "stream").Bool() isStream := gjson.GetBytes(body, "stream").Bool()
if isStream { if isStream {
handleStream(c, upstream, san, pool, cred, body) handleStream(c, upstream, san, pool, cred, body, originalBody, tracker)
} else { } else {
handleNonStream(c, upstream, san, pool, cred, body) handleNonStream(c, upstream, san, pool, cred, body, originalBody, tracker)
} }
} }
} }
func handleNonStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte) { func handleNonStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte, originalBody []byte, tracker *ratelimit.Tracker) {
respBody, headers, statusCode, err := upstream.Execute(c.Request.Context(), cred, body) startTime := time.Now()
model := gjson.GetBytes(body, "model").String()
ctx := c.Request.Context()
ri := requestInfo{model: model, stream: false, cred: cred, body: body, originalBody: originalBody}
telemetry.RequestBodySize.Record(ctx, int64(len(body)),
metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", false)))
respBody, headers, statusCode, err := upstream.Execute(ctx, cred, body)
latencyMs := float64(time.Since(startTime).Milliseconds())
if err != nil { if err != nil {
log.Printf("upstream error for %s: %v", cred.Email, err) recordConnectionError(ctx, err, ri, latencyMs)
c.JSON(http.StatusBadGateway, gin.H{"error": "upstream request failed"}) c.JSON(http.StatusBadGateway, gin.H{"error": "upstream request failed"})
return return
} }
recordRequestMetrics(ctx, ri, statusCode, latencyMs)
if statusCode >= 400 { if statusCode >= 400 {
pool.MarkFailure(cred, statusCode) pool.MarkFailure(cred, statusCode)
log.Printf("upstream %d for %s: %s", statusCode, cred.Email, string(respBody)) telemetry.CredentialCooldowns.Add(ctx, 1,
metric.WithAttributes(attribute.Int("status_code", statusCode)))
recordUpstreamError(ctx, statusCode, respBody, headers.Get("X-Request-Id"), latencyMs, ri, c.Request.Header)
} else { } else {
pool.MarkSuccess(cred) pool.MarkSuccess(cred)
respBody = san.DesanitizeResponse(respBody) respBody = san.DesanitizeResponse(respBody)
inputTokens := gjson.GetBytes(respBody, "usage.input_tokens").Int()
outputTokens := gjson.GetBytes(respBody, "usage.output_tokens").Int()
recordTokenUsage(ctx, model, cred, inputTokens, outputTokens)
if tracker != nil {
tracker.UpdateFromHeaders(headers)
}
log.Info().
Int("status", statusCode).
Float64("latency_ms", latencyMs).
Str("model", model).
Int64("input_tokens", inputTokens).
Int64("output_tokens", outputTokens).
Msg("request completed")
} }
for _, h := range []string{"Content-Type", "X-Request-Id"} { for _, h := range []string{"Content-Type", "X-Request-Id"} {
@@ -68,10 +121,20 @@ func handleNonStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, p
c.Data(statusCode, headers.Get("Content-Type"), respBody) c.Data(statusCode, headers.Get("Content-Type"), respBody)
} }
func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte) { func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte, originalBody []byte, tracker *ratelimit.Tracker) {
resp, err := upstream.ExecuteStream(c.Request.Context(), cred, body) startTime := time.Now()
model := gjson.GetBytes(body, "model").String()
ctx := c.Request.Context()
ri := requestInfo{model: model, stream: true, cred: cred, body: body, originalBody: originalBody}
telemetry.StreamRequests.Add(ctx, 1, metric.WithAttributes(attribute.String("model", model)))
telemetry.RequestBodySize.Record(ctx, int64(len(body)),
metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", true)))
resp, err := upstream.ExecuteStream(ctx, cred, body)
if err != nil { if err != nil {
log.Printf("upstream stream error for %s: %v", cred.Email, err) latencyMs := float64(time.Since(startTime).Milliseconds())
recordConnectionError(ctx, err, ri, latencyMs)
c.JSON(http.StatusBadGateway, gin.H{"error": "upstream stream request failed"}) c.JSON(http.StatusBadGateway, gin.H{"error": "upstream stream request failed"})
return return
} }
@@ -79,8 +142,13 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
pool.MarkFailure(cred, resp.StatusCode) pool.MarkFailure(cred, resp.StatusCode)
telemetry.CredentialCooldowns.Add(ctx, 1,
metric.WithAttributes(attribute.Int("status_code", resp.StatusCode)))
respBody, _ := io.ReadAll(resp.Body) respBody, _ := io.ReadAll(resp.Body)
log.Printf("upstream stream %d for %s: %s", resp.StatusCode, cred.Email, string(respBody)) latencyMs := float64(time.Since(startTime).Milliseconds())
recordRequestMetrics(ctx, ri, resp.StatusCode, latencyMs)
recordUpstreamError(ctx, resp.StatusCode, respBody, resp.Header.Get("X-Request-Id"), latencyMs, ri, c.Request.Header)
c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), respBody) c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), respBody)
return return
} }
@@ -94,20 +162,121 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
if !ok { if !ok {
log.Printf("response writer does not support flushing") log.Error().Msg("response writer does not support flushing")
c.JSON(http.StatusInternalServerError, gin.H{"error": "streaming not supported"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "streaming not supported"})
return return
} }
var inputTokens, outputTokens int64
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() { for scanner.Scan() {
line := san.DesanitizeStreamEvent(scanner.Text()) line := san.DesanitizeStreamEvent(scanner.Text())
c.Writer.WriteString(line + "\n") c.Writer.WriteString(line + "\n")
flusher.Flush() flusher.Flush()
if len(line) > 5 && line[:5] == "data:" {
data := line[5:]
eventType := gjson.Get(data, "type").String()
switch eventType {
case "message_start":
inputTokens = gjson.Get(data, "message.usage.input_tokens").Int()
case "message_delta":
outputTokens = gjson.Get(data, "usage.output_tokens").Int()
}
}
} }
latencyMs := float64(time.Since(startTime).Milliseconds())
recordRequestMetrics(ctx, ri, http.StatusOK, latencyMs)
if inputTokens > 0 || outputTokens > 0 {
recordTokenUsage(ctx, model, cred, inputTokens, outputTokens)
if tracker != nil {
tracker.UpdateFromHeaders(resp.Header)
}
}
log.Info().
Float64("latency_ms", latencyMs).
Str("model", model).
Bool("stream", true).
Int64("input_tokens", inputTokens).
Int64("output_tokens", outputTokens).
Msg("stream completed")
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
log.Printf("stream scan error: %v", err) log.Error().Err(err).Msg("stream scan error")
} }
} }
// recordConnectionError logs and records metrics for upstream connection failures.
func recordConnectionError(ctx context.Context, err error, ri requestInfo, latencyMs float64) {
log.Error().
Err(err).
Str("credential", ri.cred.Email).
Str("model", ri.model).
Bool("stream", ri.stream).
Str("request_body_original", string(ri.originalBody)).
Str("request_body_sanitized", string(ri.body)).
Int("request_body_size", len(ri.body)).
Float64("latency_ms", latencyMs).
Msg("upstream connection error")
telemetry.UpstreamErrors.Add(ctx, 1,
metric.WithAttributes(
attribute.String("error_type", "connection"),
attribute.String("credential", ri.cred.Email),
attribute.Int("status_code", http.StatusBadGateway),
))
recordRequestMetrics(ctx, ri, http.StatusBadGateway, latencyMs)
}
// recordUpstreamError logs and records metrics for upstream HTTP error responses.
func recordUpstreamError(ctx context.Context, statusCode int, respBody []byte, requestID string, latencyMs float64, ri requestInfo, requestHeaders http.Header) {
errorType := gjson.GetBytes(respBody, "error.type").String()
errorMessage := gjson.GetBytes(respBody, "error.message").String()
log.Error().
Int("status", statusCode).
Str("error_type", errorType).
Str("error_message", errorMessage).
Str("response_body", string(respBody)).
Str("request_id", requestID).
Float64("latency_ms", latencyMs).
Str("credential", ri.cred.Email).
Str("model", ri.model).
Bool("stream", ri.stream).
Str("request_body_original", string(ri.originalBody)).
Str("request_body_sanitized", string(ri.body)).
Int("request_body_size", len(ri.body)).
Str("request_headers", logging.RedactHeaders(requestHeaders)).
Msg("upstream error")
telemetry.UpstreamErrors.Add(ctx, 1,
metric.WithAttributes(
attribute.Int("status_code", statusCode),
attribute.String("error_type", errorType),
attribute.String("credential", ri.cred.Email),
))
}
// recordRequestMetrics records the request counter and duration histogram.
func recordRequestMetrics(ctx context.Context, ri requestInfo, statusCode int, latencyMs float64) {
attrs := []attribute.KeyValue{
attribute.String("model", ri.model),
attribute.Bool("stream", ri.stream),
attribute.Int("status_code", statusCode),
}
telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...))
telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...))
}
// recordTokenUsage records token consumption metrics.
func recordTokenUsage(ctx context.Context, model string, cred *auth.Credential, inputTokens, outputTokens int64) {
tokenAttrs := metric.WithAttributes(
attribute.String("model", model),
attribute.String("credential", cred.Email),
)
telemetry.TokensInput.Add(ctx, inputTokens, tokenAttrs)
telemetry.TokensOutput.Add(ctx, outputTokens, tokenAttrs)
}
+624
View File
@@ -0,0 +1,624 @@
package proxy
import (
"bytes"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/fujin/anthropic-proxy/internal/auth"
"github.com/fujin/anthropic-proxy/internal/config"
"github.com/fujin/anthropic-proxy/internal/ratelimit"
"github.com/fujin/anthropic-proxy/internal/telemetry"
"go.opentelemetry.io/otel/metric/noop"
)
func init() {
gin.SetMode(gin.TestMode)
// Initialize telemetry with noop meter to avoid nil pointer panics.
meter := noop.Meter{}
telemetry.InitMetrics(meter, nil)
}
// --- Request body reading and sanitization ---
func TestHandleMessages_ReadBodyError(t *testing.T) {
// A body that immediately fails on read shouldn't panic.
pool := auth.NewPool([]*auth.Credential{{ID: "c1", AccessToken: "tok", Email: "test@test.com"}})
san := NewSanitizer(config.SanitizeConfig{})
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", &errReader{})
handler(c)
if w.Code != http.StatusBadRequest {
t.Errorf("status = %d, want %d", w.Code, http.StatusBadRequest)
}
if !strings.Contains(w.Body.String(), "failed to read request body") {
t.Errorf("body = %q, expected error message about reading body", w.Body.String())
}
}
func TestHandleMessages_SanitizesRequestBody(t *testing.T) {
// We can't directly make HandleMessages use our mock server because
// UpstreamClient hardcodes messagesURL. Instead, we test sanitization
// by verifying the sanitizer is called on the body before any pool interaction.
san := NewSanitizer(config.SanitizeConfig{
Tools: []config.RenameRule{{From: "my_tool", To: "renamed_tool"}},
Body: []config.ReplaceRule{{Match: "secret", Replace: "redacted"}},
})
// Create body with tool name and secret
body := `{"model":"claude-sonnet-4-6","tools":[{"name":"my_tool"}],"messages":[{"role":"user","content":"secret data"}]}`
sanitizedBody := san.SanitizeRequest([]byte(body))
// Verify sanitization happened correctly
if !strings.Contains(string(sanitizedBody), "renamed_tool") {
t.Error("expected tool to be renamed in sanitized body")
}
if strings.Contains(string(sanitizedBody), "my_tool") {
t.Error("original tool name should be gone after sanitization")
}
if !strings.Contains(string(sanitizedBody), "redacted") {
t.Error("expected 'secret' to be replaced with 'redacted'")
}
if strings.Contains(string(sanitizedBody), "secret") {
t.Error("'secret' should be gone after sanitization")
}
}
func TestHandleMessages_PoolPickError(t *testing.T) {
// Empty pool — Pick() will fail.
pool := auth.NewPool(nil)
san := NewSanitizer(config.SanitizeConfig{})
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
body := `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
handler(c)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable)
}
if !strings.Contains(w.Body.String(), "no credentials available") {
t.Errorf("body = %q, expected pool error", w.Body.String())
}
}
func TestHandleMessages_PoolAllOnCooldown(t *testing.T) {
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "e"}
pool := auth.NewPool([]*auth.Credential{cred})
pool.MarkFailure(cred, 429) // puts on 30s cooldown
san := NewSanitizer(config.SanitizeConfig{})
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
body := `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
handler(c)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable)
}
if !strings.Contains(w.Body.String(), "cooldown") {
t.Errorf("body = %q, expected cooldown message", w.Body.String())
}
}
// --- Stream vs non-stream routing ---
func TestHandleMessages_StreamField_Detection(t *testing.T) {
tests := []struct {
name string
body string
isStream bool
}{
{
name: "stream true",
body: `{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`,
isStream: true,
},
{
name: "stream false",
body: `{"model":"claude-sonnet-4-6","stream":false,"messages":[{"role":"user","content":"hi"}]}`,
isStream: false,
},
{
name: "no stream field defaults to false",
body: `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`,
isStream: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := gjson.Get(tt.body, "stream").Bool()
if got != tt.isStream {
t.Errorf("stream = %v, want %v", got, tt.isStream)
}
})
}
}
// --- Desanitization on response ---
func TestDesanitization_NonStreamResponse(t *testing.T) {
san := NewSanitizer(config.SanitizeConfig{
Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}},
})
// Simulate upstream response with renamed tool
upstreamResponse := `{"content":[{"type":"tool_use","name":"renamed_tool","id":"t1","input":{}}]}`
desanitized := san.DesanitizeResponse([]byte(upstreamResponse))
if !strings.Contains(string(desanitized), "original_tool") {
t.Errorf("expected tool name to be desanitized back to 'original_tool', got %s", string(desanitized))
}
if strings.Contains(string(desanitized), `"name":"renamed_tool"`) {
t.Error("renamed_tool should have been replaced by original_tool")
}
}
func TestDesanitization_StreamEvent(t *testing.T) {
san := NewSanitizer(config.SanitizeConfig{
Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}},
})
event := `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"renamed_tool","id":"t1"}}`
desanitized := san.DesanitizeStreamEvent(event)
if !strings.Contains(desanitized, "original_tool") {
t.Errorf("expected stream event to be desanitized, got %s", desanitized)
}
}
// --- handleNonStream behavior tests via direct function ---
func TestHandleNonStream_ConnectionError(t *testing.T) {
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{})
tracker := ratelimit.NewTracker(func() string { return "" })
uc := &UpstreamClient{
client: http.Client{Transport: &failingTransport{}},
sessionID: "test-sess",
profile: nil,
}
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
handleNonStream(c, uc, san, pool, cred, body, body, tracker)
if w.Code != http.StatusBadGateway {
t.Errorf("status = %d, want %d", w.Code, http.StatusBadGateway)
}
if !strings.Contains(w.Body.String(), "upstream request failed") {
t.Errorf("body = %q, expected upstream error message", w.Body.String())
}
}
func TestHandleNonStream_UpstreamSuccess(t *testing.T) {
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Request-Id", "req-123")
w.WriteHeader(200)
w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"text","text":"hello"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`))
}))
defer mockUpstream.Close()
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{})
tracker := ratelimit.NewTracker(func() string { return "" })
uc := &UpstreamClient{
client: *mockUpstream.Client(),
sessionID: "test-sess",
profile: nil,
}
// Override the messagesURL by constructing a custom Execute that uses the mock.
// Since we can't override the const, we test via a mock server approach:
// We create a custom http.Client with a transport that redirects to our mock.
uc.client.Transport = &rewriteTransport{
base: mockUpstream.Client().Transport,
destURL: mockUpstream.URL,
}
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
handleNonStream(c, uc, san, pool, cred, body, body, tracker)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String())
}
if !strings.Contains(w.Body.String(), "hello") {
t.Errorf("response body missing expected content: %s", w.Body.String())
}
if got := w.Header().Get("Content-Type"); got != "application/json" {
t.Errorf("Content-Type = %q, want %q", got, "application/json")
}
if got := w.Header().Get("X-Request-Id"); got != "req-123" {
t.Errorf("X-Request-Id = %q, want %q", got, "req-123")
}
}
func TestHandleNonStream_UpstreamError_MarkFailure(t *testing.T) {
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(429)
w.Write([]byte(`{"error":{"type":"rate_limit_error","message":"too many requests"}}`))
}))
defer mockUpstream.Close()
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{})
uc := &UpstreamClient{
client: *mockUpstream.Client(),
sessionID: "test-sess",
profile: nil,
}
uc.client.Transport = &rewriteTransport{
base: mockUpstream.Client().Transport,
destURL: mockUpstream.URL,
}
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
handleNonStream(c, uc, san, pool, cred, body, body, nil)
if w.Code != 429 {
t.Errorf("status = %d, want 429", w.Code)
}
// Verify MarkFailure was called — cred should now be on cooldown
if !cred.IsOnCooldown() {
t.Error("expected credential to be on cooldown after 429")
}
}
func TestHandleNonStream_UpstreamSuccess_DesanitizesResponse(t *testing.T) {
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"tool_use","name":"renamed_tool","id":"t1","input":{}}],"model":"claude-sonnet-4-6","stop_reason":"tool_use","usage":{"input_tokens":10,"output_tokens":5}}`))
}))
defer mockUpstream.Close()
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{
Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}},
})
uc := &UpstreamClient{
client: *mockUpstream.Client(),
sessionID: "test-sess",
profile: nil,
}
uc.client.Transport = &rewriteTransport{
base: mockUpstream.Client().Transport,
destURL: mockUpstream.URL,
}
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
handleNonStream(c, uc, san, pool, cred, body, body, nil)
if w.Code != http.StatusOK {
t.Fatalf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String())
}
// Should be desanitized back to original_tool
if !strings.Contains(w.Body.String(), "original_tool") {
t.Errorf("response should contain desanitized tool name 'original_tool', got %s", w.Body.String())
}
}
func TestHandleNonStream_Upstream500_MarkFailure(t *testing.T) {
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(500)
w.Write([]byte(`{"error":{"type":"server_error","message":"internal error"}}`))
}))
defer mockUpstream.Close()
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{})
uc := &UpstreamClient{
client: *mockUpstream.Client(),
sessionID: "test-sess",
profile: nil,
}
uc.client.Transport = &rewriteTransport{
base: mockUpstream.Client().Transport,
destURL: mockUpstream.URL,
}
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
handleNonStream(c, uc, san, pool, cred, body, body, nil)
if w.Code != 500 {
t.Errorf("status = %d, want 500", w.Code)
}
if !cred.IsOnCooldown() {
t.Error("expected credential to be on cooldown after 500")
}
}
// --- handleStream behavior tests ---
func TestHandleStream_ConnectionError(t *testing.T) {
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{})
uc := &UpstreamClient{
client: http.Client{Transport: &failingTransport{}},
sessionID: "test-sess",
profile: nil,
}
body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
handleStream(c, uc, san, pool, cred, body, body, nil)
if w.Code != http.StatusBadGateway {
t.Errorf("status = %d, want %d", w.Code, http.StatusBadGateway)
}
if !strings.Contains(w.Body.String(), "upstream stream request failed") {
t.Errorf("body = %q, expected upstream stream error", w.Body.String())
}
}
func TestHandleStream_UpstreamError(t *testing.T) {
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(429)
w.Write([]byte(`{"error":{"type":"rate_limit_error","message":"rate limited"}}`))
}))
defer mockUpstream.Close()
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{})
uc := &UpstreamClient{
client: *mockUpstream.Client(),
sessionID: "test-sess",
profile: nil,
}
uc.client.Transport = &rewriteTransport{
base: mockUpstream.Client().Transport,
destURL: mockUpstream.URL,
}
body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
handleStream(c, uc, san, pool, cred, body, body, nil)
if w.Code != 429 {
t.Errorf("status = %d, want 429", w.Code)
}
if !cred.IsOnCooldown() {
t.Error("expected credential on cooldown after stream 429")
}
}
func TestHandleStream_SuccessForwardsEvents(t *testing.T) {
events := "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":0}}}\n\nevent: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello\"}}\n\nevent: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\nevent: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":5}}\n\nevent: message_stop\ndata: {\"type\":\"message_stop\"}\n\n"
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(200)
w.Write([]byte(events))
}))
defer mockUpstream.Close()
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{})
uc := &UpstreamClient{
client: *mockUpstream.Client(),
sessionID: "test-sess",
profile: nil,
}
uc.client.Transport = &rewriteTransport{
base: mockUpstream.Client().Transport,
destURL: mockUpstream.URL,
}
body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
handleStream(c, uc, san, pool, cred, body, body, nil)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
}
respBody := w.Body.String()
if !strings.Contains(respBody, "message_start") {
t.Error("response missing message_start event")
}
if !strings.Contains(respBody, "hello") {
t.Error("response missing text content 'hello'")
}
if !strings.Contains(respBody, "message_stop") {
t.Error("response missing message_stop event")
}
// Verify SSE headers
if got := w.Header().Get("Content-Type"); got != "text/event-stream" {
t.Errorf("Content-Type = %q, want %q", got, "text/event-stream")
}
}
func TestHandleStream_DesanitizesEvents(t *testing.T) {
events := "event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"name\":\"renamed_tool\",\"id\":\"t1\"}}\n\nevent: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"},\"usage\":{\"output_tokens\":5}}\n\nevent: message_stop\ndata: {\"type\":\"message_stop\"}\n\n"
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(200)
w.Write([]byte(events))
}))
defer mockUpstream.Close()
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{
Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}},
})
uc := &UpstreamClient{
client: *mockUpstream.Client(),
sessionID: "test-sess",
profile: nil,
}
uc.client.Transport = &rewriteTransport{
base: mockUpstream.Client().Transport,
destURL: mockUpstream.URL,
}
body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
handleStream(c, uc, san, pool, cred, body, body, nil)
if w.Code != http.StatusOK {
t.Fatalf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String())
}
if !strings.Contains(w.Body.String(), "original_tool") {
t.Errorf("stream response should contain desanitized 'original_tool', got %s", w.Body.String())
}
}
// --- HandleMessages full integration wiring test ---
func TestHandleMessages_WiresHandlerCorrectly(t *testing.T) {
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{})
// Verify the handler can be created without panic
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
if handler == nil {
t.Fatal("HandleMessages returned nil handler")
}
}
func TestHandleMessages_EmptyBody(t *testing.T) {
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{})
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
// Empty body — handler should still try to pick cred and call upstream
// (which will fail with connection error to api.anthropic.com, not a panic)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(""))
handler(c)
// Should get a 502 because the upstream URL (api.anthropic.com) won't be reachable
// in test environment, or it might complete. The key thing is no panic.
// We mainly verify it doesn't panic.
if w.Code == 0 {
t.Error("expected non-zero status code")
}
}
// --- Test helpers ---
// errReader is an io.Reader that always returns an error.
type errReader struct{}
func (e *errReader) Read([]byte) (int, error) {
return 0, io.ErrUnexpectedEOF
}
// failingTransport is an http.RoundTripper that always returns an error.
type failingTransport struct{}
func (f *failingTransport) RoundTrip(*http.Request) (*http.Response, error) {
return nil, fmt.Errorf("connection refused")
}
// rewriteTransport intercepts HTTP requests and rewrites the URL to point
// at a local test server. This allows testing with UpstreamClient's hardcoded
// messagesURL by redirecting all requests to a mock server.
type rewriteTransport struct {
base http.RoundTripper
destURL string
}
func (t *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the request URL to point at our mock server
newReq := req.Clone(req.Context())
newReq.URL.Scheme = "http"
newReq.URL.Host = strings.TrimPrefix(t.destURL, "http://")
newReq.URL.Path = "/v1/messages"
newReq.URL.RawQuery = ""
if t.base == nil {
return http.DefaultTransport.RoundTrip(newReq)
}
return t.base.RoundTrip(newReq)
}
+21 -4
View File
@@ -4,6 +4,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/rs/zerolog/log"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
@@ -49,7 +50,11 @@ func (s *Sanitizer) DesanitizeResponse(body []byte) []byte {
} }
name := block.Get("name").String() name := block.Get("name").String()
if orig, ok := s.toolsReverse[name]; ok { if orig, ok := s.toolsReverse[name]; ok {
body, _ = sjson.SetBytes(body, "content."+strconv.Itoa(i)+".name", orig) if b, err := sjson.SetBytes(body, "content."+strconv.Itoa(i)+".name", orig); err != nil {
log.Warn().Err(err).Str("tool", name).Msg("desanitize response: set name failed")
} else {
body = b
}
} }
} }
return body return body
@@ -64,10 +69,14 @@ func (s *Sanitizer) DesanitizeStreamEvent(line string) string {
for _, path := range []string{"content_block.name", "delta.name"} { for _, path := range []string{"content_block.name", "delta.name"} {
name := gjson.GetBytes(data, path).String() name := gjson.GetBytes(data, path).String()
if orig, ok := s.toolsReverse[name]; ok { if orig, ok := s.toolsReverse[name]; ok {
data, _ = sjson.SetBytes(data, path, orig) if b, err := sjson.SetBytes(data, path, orig); err != nil {
log.Warn().Err(err).Str("tool", name).Msg("desanitize stream event: set name failed")
} else {
data = b
changed = true changed = true
} }
} }
}
if changed { if changed {
return "data: " + string(data) return "data: " + string(data)
} }
@@ -85,7 +94,11 @@ func (s *Sanitizer) renameTools(body []byte) []byte {
for i, tool := range tools.Array() { for i, tool := range tools.Array() {
name := tool.Get("name").String() name := tool.Get("name").String()
if newName, ok := s.toolsForward[name]; ok { if newName, ok := s.toolsForward[name]; ok {
body, _ = sjson.SetBytes(body, "tools."+strconv.Itoa(i)+".name", newName) if b, err := sjson.SetBytes(body, "tools."+strconv.Itoa(i)+".name", newName); err != nil {
log.Warn().Err(err).Str("tool", name).Msg("rename tool failed")
} else {
body = b
}
} }
} }
return body return body
@@ -104,7 +117,11 @@ func (s *Sanitizer) replaceSystem(body []byte) []byte {
for _, rule := range s.systemRules { for _, rule := range s.systemRules {
text = strings.ReplaceAll(text, rule.Match, rule.Replace) text = strings.ReplaceAll(text, rule.Match, rule.Replace)
} }
body, _ = sjson.SetBytes(body, "system."+strconv.Itoa(i)+".text", text) if b, err := sjson.SetBytes(body, "system."+strconv.Itoa(i)+".text", text); err != nil {
log.Warn().Err(err).Int("block", i).Msg("replace system text failed")
} else {
body = b
}
} }
return body return body
} }
+476
View File
@@ -0,0 +1,476 @@
package proxy
import (
"strings"
"testing"
"github.com/fujin/anthropic-proxy/internal/config"
)
func TestNewSanitizer_Empty(t *testing.T) {
s := NewSanitizer(config.SanitizeConfig{})
if len(s.toolsForward) != 0 {
t.Errorf("expected empty toolsForward, got %d entries", len(s.toolsForward))
}
if len(s.toolsReverse) != 0 {
t.Errorf("expected empty toolsReverse, got %d entries", len(s.toolsReverse))
}
if s.systemRules != nil {
t.Errorf("expected nil systemRules")
}
if s.bodyRules != nil {
t.Errorf("expected nil bodyRules")
}
}
func TestNewSanitizer_WithTools(t *testing.T) {
cfg := config.SanitizeConfig{
Tools: []config.RenameRule{
{From: "old_tool", To: "new_tool"},
{From: "another", To: "replaced"},
},
}
s := NewSanitizer(cfg)
if got := s.toolsForward["old_tool"]; got != "new_tool" {
t.Errorf("toolsForward[old_tool] = %q, want %q", got, "new_tool")
}
if got := s.toolsReverse["new_tool"]; got != "old_tool" {
t.Errorf("toolsReverse[new_tool] = %q, want %q", got, "old_tool")
}
if got := s.toolsForward["another"]; got != "replaced" {
t.Errorf("toolsForward[another] = %q, want %q", got, "replaced")
}
if got := s.toolsReverse["replaced"]; got != "another" {
t.Errorf("toolsReverse[replaced] = %q, want %q", got, "another")
}
}
func TestNewSanitizer_WithSystemAndBodyRules(t *testing.T) {
cfg := config.SanitizeConfig{
System: []config.ReplaceRule{{Match: "foo", Replace: "bar"}},
Body: []config.ReplaceRule{{Match: "baz", Replace: "qux"}},
}
s := NewSanitizer(cfg)
if len(s.systemRules) != 1 || s.systemRules[0].Match != "foo" {
t.Errorf("systemRules not set correctly")
}
if len(s.bodyRules) != 1 || s.bodyRules[0].Match != "baz" {
t.Errorf("bodyRules not set correctly")
}
}
func TestRenameTools(t *testing.T) {
tests := []struct {
name string
forward map[string]string
body string
expected string
}{
{
name: "empty map returns body unchanged",
forward: map[string]string{},
body: `{"tools":[{"name":"my_tool"}]}`,
expected: `{"tools":[{"name":"my_tool"}]}`,
},
{
name: "no tools array returns body unchanged",
forward: map[string]string{"my_tool": "renamed"},
body: `{"messages":[]}`,
expected: `{"messages":[]}`,
},
{
name: "tools is not array returns body unchanged",
forward: map[string]string{"my_tool": "renamed"},
body: `{"tools":"not_array"}`,
expected: `{"tools":"not_array"}`,
},
{
name: "matching tool gets renamed",
forward: map[string]string{"my_tool": "renamed_tool"},
body: `{"tools":[{"name":"my_tool","description":"desc"}]}`,
expected: `renamed_tool`,
},
{
name: "non-matching tool unchanged",
forward: map[string]string{"other_tool": "renamed"},
body: `{"tools":[{"name":"my_tool"}]}`,
expected: `my_tool`,
},
{
name: "partial match - only exact match renames",
forward: map[string]string{"tool_a": "tool_x", "tool_b": "tool_y"},
body: `{"tools":[{"name":"tool_a"},{"name":"tool_c"},{"name":"tool_b"}]}`,
expected: `tool_x`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Sanitizer{
toolsForward: tt.forward,
toolsReverse: make(map[string]string),
}
result := string(s.renameTools([]byte(tt.body)))
if !strings.Contains(result, tt.expected) {
t.Errorf("result %q does not contain %q", result, tt.expected)
}
})
}
}
func TestRenameTools_MultipleTools(t *testing.T) {
s := &Sanitizer{
toolsForward: map[string]string{"tool_a": "tool_x", "tool_b": "tool_y"},
toolsReverse: make(map[string]string),
}
body := `{"tools":[{"name":"tool_a"},{"name":"tool_c"},{"name":"tool_b"}]}`
result := string(s.renameTools([]byte(body)))
if !strings.Contains(result, `"tool_x"`) {
t.Errorf("tool_a should be renamed to tool_x, got %s", result)
}
if !strings.Contains(result, `"tool_y"`) {
t.Errorf("tool_b should be renamed to tool_y, got %s", result)
}
if !strings.Contains(result, `"tool_c"`) {
t.Errorf("tool_c should remain unchanged, got %s", result)
}
}
func TestReplaceSystem(t *testing.T) {
tests := []struct {
name string
rules []config.ReplaceRule
body string
contains string
}{
{
name: "empty rules returns body unchanged",
rules: nil,
body: `{"system":[{"type":"text","text":"hello world"}]}`,
contains: "hello world",
},
{
name: "no system field returns body unchanged",
rules: []config.ReplaceRule{{Match: "hello", Replace: "goodbye"}},
body: `{"messages":[]}`,
contains: `"messages":[]`,
},
{
name: "system not array returns body unchanged",
rules: []config.ReplaceRule{{Match: "hello", Replace: "goodbye"}},
body: `{"system":"just a string"}`,
contains: "just a string",
},
{
name: "single block single rule",
rules: []config.ReplaceRule{{Match: "hello", Replace: "goodbye"}},
body: `{"system":[{"type":"text","text":"hello world"}]}`,
contains: "goodbye world",
},
{
name: "multiple blocks",
rules: []config.ReplaceRule{{Match: "AAA", Replace: "BBB"}},
body: `{"system":[{"type":"text","text":"AAA first"},{"type":"text","text":"AAA second"}]}`,
contains: "BBB first",
},
{
name: "multiple rules applied in order",
rules: []config.ReplaceRule{{Match: "cat", Replace: "dog"}, {Match: "dog", Replace: "fish"}},
body: `{"system":[{"type":"text","text":"I have a cat"}]}`,
contains: "I have a fish",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Sanitizer{
toolsForward: make(map[string]string),
toolsReverse: make(map[string]string),
systemRules: tt.rules,
}
result := string(s.replaceSystem([]byte(tt.body)))
if !strings.Contains(result, tt.contains) {
t.Errorf("result %q does not contain %q", result, tt.contains)
}
})
}
}
func TestReplaceSystem_MultipleBlocks(t *testing.T) {
s := &Sanitizer{
toolsForward: make(map[string]string),
toolsReverse: make(map[string]string),
systemRules: []config.ReplaceRule{{Match: "AAA", Replace: "BBB"}},
}
body := `{"system":[{"type":"text","text":"AAA first"},{"type":"text","text":"AAA second"}]}`
result := string(s.replaceSystem([]byte(body)))
if !strings.Contains(result, "BBB first") {
t.Errorf("first block not replaced: %s", result)
}
if !strings.Contains(result, "BBB second") {
t.Errorf("second block not replaced: %s", result)
}
}
func TestReplaceBody(t *testing.T) {
tests := []struct {
name string
rules []config.ReplaceRule
body string
expected string
}{
{
name: "empty rules returns body unchanged",
rules: nil,
body: `{"foo":"bar"}`,
expected: `{"foo":"bar"}`,
},
{
name: "single replacement across entire body",
rules: []config.ReplaceRule{{Match: "SECRET", Replace: "REDACTED"}},
body: `{"data":"SECRET value SECRET"}`,
expected: `{"data":"REDACTED value REDACTED"}`,
},
{
name: "multiple rules applied sequentially",
rules: []config.ReplaceRule{{Match: "AAA", Replace: "BBB"}, {Match: "BBB", Replace: "CCC"}},
body: `{"text":"AAA"}`,
expected: `{"text":"CCC"}`,
},
{
name: "no match leaves body unchanged",
rules: []config.ReplaceRule{{Match: "NOMATCH", Replace: "X"}},
body: `{"text":"hello"}`,
expected: `{"text":"hello"}`,
},
{
name: "empty body",
rules: []config.ReplaceRule{{Match: "a", Replace: "b"}},
body: ``,
expected: ``,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Sanitizer{
toolsForward: make(map[string]string),
toolsReverse: make(map[string]string),
bodyRules: tt.rules,
}
result := string(s.replaceBody([]byte(tt.body)))
if result != tt.expected {
t.Errorf("got %q, want %q", result, tt.expected)
}
})
}
}
func TestSanitizeRequest(t *testing.T) {
cfg := config.SanitizeConfig{
Tools: []config.RenameRule{{From: "my_tool", To: "renamed_tool"}},
System: []config.ReplaceRule{{Match: "INTERNAL", Replace: "PUBLIC"}},
Body: []config.ReplaceRule{{Match: "secret_val", Replace: "safe_val"}},
}
s := NewSanitizer(cfg)
body := `{"tools":[{"name":"my_tool"}],"system":[{"type":"text","text":"INTERNAL info"}],"data":"secret_val here"}`
result := string(s.SanitizeRequest([]byte(body)))
if !strings.Contains(result, `"renamed_tool"`) {
t.Errorf("tool not renamed in result: %s", result)
}
if !strings.Contains(result, "PUBLIC info") {
t.Errorf("system not replaced in result: %s", result)
}
if !strings.Contains(result, "safe_val here") {
t.Errorf("body not replaced in result: %s", result)
}
if strings.Contains(result, "secret_val") {
t.Errorf("secret_val should have been replaced: %s", result)
}
}
func TestSanitizeRequest_EmptyConfig(t *testing.T) {
s := NewSanitizer(config.SanitizeConfig{})
body := `{"tools":[{"name":"my_tool"}],"system":[{"type":"text","text":"hello"}]}`
result := string(s.SanitizeRequest([]byte(body)))
if result != body {
t.Errorf("empty config should not modify body.\ngot: %s\nwant: %s", result, body)
}
}
func TestDesanitizeResponse(t *testing.T) {
tests := []struct {
name string
reverse map[string]string
body string
expected string
}{
{
name: "no content field returns unchanged",
reverse: map[string]string{"renamed": "original"},
body: `{"id":"msg_1","role":"assistant"}`,
expected: `{"id":"msg_1","role":"assistant"}`,
},
{
name: "content not array returns unchanged",
reverse: map[string]string{"renamed": "original"},
body: `{"content":"just text"}`,
expected: `{"content":"just text"}`,
},
{
name: "non-tool_use block left unchanged",
reverse: map[string]string{"renamed": "original"},
body: `{"content":[{"type":"text","text":"hello"}]}`,
expected: `{"content":[{"type":"text","text":"hello"}]}`,
},
{
name: "tool_use block with matching name gets reversed",
reverse: map[string]string{"renamed_tool": "original_tool"},
body: `{"content":[{"type":"tool_use","name":"renamed_tool","id":"t1"}]}`,
expected: `original_tool`,
},
{
name: "tool_use block with no match unchanged",
reverse: map[string]string{"other": "something"},
body: `{"content":[{"type":"tool_use","name":"my_tool","id":"t1"}]}`,
expected: `my_tool`,
},
{
name: "mixed blocks only tool_use reversed",
reverse: map[string]string{"renamed": "original"},
body: `{"content":[{"type":"text","text":"hi"},{"type":"tool_use","name":"renamed","id":"t1"}]}`,
expected: `original`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Sanitizer{
toolsForward: make(map[string]string),
toolsReverse: tt.reverse,
}
result := string(s.DesanitizeResponse([]byte(tt.body)))
if !strings.Contains(result, tt.expected) {
t.Errorf("result %q does not contain %q", result, tt.expected)
}
})
}
}
func TestDesanitizeResponse_MultipleToolUse(t *testing.T) {
s := &Sanitizer{
toolsForward: make(map[string]string),
toolsReverse: map[string]string{"r1": "o1", "r2": "o2"},
}
body := `{"content":[{"type":"tool_use","name":"r1","id":"t1"},{"type":"text","text":"x"},{"type":"tool_use","name":"r2","id":"t2"}]}`
result := string(s.DesanitizeResponse([]byte(body)))
if !strings.Contains(result, `"o1"`) {
t.Errorf("r1 not reversed to o1: %s", result)
}
if !strings.Contains(result, `"o2"`) {
t.Errorf("r2 not reversed to o2: %s", result)
}
}
func TestDesanitizeStreamEvent(t *testing.T) {
tests := []struct {
name string
reverse map[string]string
line string
expected string
}{
{
name: "non-data line passed through",
reverse: map[string]string{"r": "o"},
line: "event: content_block_start",
expected: "event: content_block_start",
},
{
name: "data line without tool_use passed through",
reverse: map[string]string{"r": "o"},
line: `data: {"type":"text","text":"hello"}`,
expected: `data: {"type":"text","text":"hello"}`,
},
{
name: "data line with tool_use in content_block.name",
reverse: map[string]string{"renamed_tool": "original_tool"},
line: `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"renamed_tool","id":"t1"}}`,
expected: `original_tool`,
},
{
name: "data line with tool_use in delta.name",
reverse: map[string]string{"renamed_tool": "original_tool"},
line: `data: {"type":"content_block_delta","delta":{"type":"tool_use","name":"renamed_tool"}}`,
expected: `original_tool`,
},
{
name: "data line with tool_use but no matching name",
reverse: map[string]string{"other": "something"},
line: `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"my_tool","id":"t1"}}`,
expected: `my_tool`,
},
{
name: "empty line passed through",
reverse: map[string]string{"r": "o"},
line: "",
expected: "",
},
{
name: "line contains tool_use but not data prefix - passed through",
reverse: map[string]string{"r": "o"},
line: "event: tool_use",
expected: "event: tool_use",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Sanitizer{
toolsForward: make(map[string]string),
toolsReverse: tt.reverse,
}
result := s.DesanitizeStreamEvent(tt.line)
if !strings.Contains(result, tt.expected) {
t.Errorf("result %q does not contain %q", result, tt.expected)
}
})
}
}
func TestDesanitizeStreamEvent_DataPrefixPreserved(t *testing.T) {
s := &Sanitizer{
toolsForward: make(map[string]string),
toolsReverse: map[string]string{"renamed": "original"},
}
line := `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"renamed","id":"t1"}}`
result := s.DesanitizeStreamEvent(line)
if !strings.HasPrefix(result, "data: ") {
t.Errorf("result should start with 'data: ', got %q", result)
}
}
func TestSanitizeRequest_MalformedJSON(t *testing.T) {
s := NewSanitizer(config.SanitizeConfig{
Tools: []config.RenameRule{{From: "a", To: "b"}},
System: []config.ReplaceRule{{Match: "x", Replace: "y"}},
})
// Malformed JSON - renameTools and replaceSystem should handle gracefully
body := `not valid json`
result := string(s.SanitizeRequest([]byte(body)))
// Should not panic; body rules still do string replacement
if result != "not valid json" {
t.Errorf("malformed JSON should pass through (no body rules match), got %q", result)
}
}
func TestSanitizeRequest_EmptyBody(t *testing.T) {
s := NewSanitizer(config.SanitizeConfig{
Tools: []config.RenameRule{{From: "a", To: "b"}},
})
result := s.SanitizeRequest([]byte{})
if len(result) != 0 {
t.Errorf("empty body should return empty, got %q", string(result))
}
}
+65 -44
View File
@@ -3,13 +3,14 @@ package proxy
import ( import (
"fmt" "fmt"
"io" "io"
"log"
"net" "net"
"net/http" "net/http"
"os/exec" "os/exec"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/rs/zerolog/log"
) )
// SniffedProfile holds everything captured from a real Claude Code request. // SniffedProfile holds everything captured from a real Claude Code request.
@@ -35,6 +36,21 @@ var skipHeaders = map[string]bool{
"connection": true, "connection": true,
} }
const fakeJSONResponse = `{"id":"msg_fake","type":"message","role":"assistant","content":[{"type":"text","text":"ok"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`
const fakeStreamResponse = "event: message_start\n" +
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_fake\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}\n\n" +
"event: content_block_start\n" +
"data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n" +
"event: content_block_delta\n" +
"data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"ok\"}}\n\n" +
"event: content_block_stop\n" +
"data: {\"type\":\"content_block_stop\",\"index\":0}\n\n" +
"event: message_delta\n" +
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":1}}\n\n" +
"event: message_stop\n" +
"data: {\"type\":\"message_stop\"}\n\n"
func SniffClaudeCode(claudeBinary string) (*SniffedProfile, error) { func SniffClaudeCode(claudeBinary string) (*SniffedProfile, error) {
listener, err := net.Listen("tcp", "127.0.0.1:0") listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
@@ -47,45 +63,7 @@ func SniffClaudeCode(claudeBinary string) (*SniffedProfile, error) {
captured := make(chan struct{}, 1) captured := make(chan struct{}, 1)
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/", sniffHandler(&mu, &profile, captured))
if r.Method == "HEAD" {
w.WriteHeader(200)
return
}
if r.Method != "POST" || !strings.Contains(r.URL.Path, "/v1/messages") {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
fmt.Fprint(w, `{"id":"msg_fake","type":"message","role":"assistant","content":[{"type":"text","text":"ok"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`)
return
}
body, _ := io.ReadAll(r.Body)
mu.Lock()
if profile == nil {
profile = extractProfile(r, body)
select {
case captured <- struct{}{}:
default:
}
}
mu.Unlock()
if strings.Contains(string(body), `"stream":true`) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(200)
fmt.Fprint(w, "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_fake\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}\n\n")
fmt.Fprint(w, "event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n")
fmt.Fprint(w, "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"ok\"}}\n\n")
fmt.Fprint(w, "event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n")
fmt.Fprint(w, "event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":1}}\n\n")
fmt.Fprint(w, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
} else {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
fmt.Fprint(w, `{"id":"msg_fake","type":"message","role":"assistant","content":[{"type":"text","text":"ok"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`)
}
})
srv := &http.Server{Handler: mux} srv := &http.Server{Handler: mux}
go srv.Serve(listener) go srv.Serve(listener)
@@ -116,13 +94,57 @@ func SniffClaudeCode(claudeBinary string) (*SniffedProfile, error) {
return nil, fmt.Errorf("no API request captured") return nil, fmt.Errorf("no API request captured")
} }
log.Printf("sniffed claude-code: version=%s headers=%d body=%d bytes", log.Info().
profile.Version, len(profile.Headers), len(profile.Body)) Str("version", profile.Version).
Int("headers", len(profile.Headers)).
Int("body_size", len(profile.Body)).
Msg("sniffed claude-code profile")
for _, h := range profile.Headers {
log.Debug().Str("header", h[0]).Str("value", h[1]).Msg("sniffed header")
}
return profile, nil return profile, nil
} }
func sniffHandler(mu *sync.Mutex, profile **SniffedProfile, captured chan<- struct{}) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method == "HEAD" {
w.WriteHeader(200)
return
}
if r.Method != "POST" || !strings.Contains(r.URL.Path, "/v1/messages") {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
fmt.Fprint(w, fakeJSONResponse)
return
}
body, _ := io.ReadAll(r.Body)
mu.Lock()
if *profile == nil {
*profile = extractProfile(r, body)
select {
case captured <- struct{}{}:
default:
}
}
mu.Unlock()
if strings.Contains(string(body), `"stream":true`) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(200)
fmt.Fprint(w, fakeStreamResponse)
} else {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
fmt.Fprint(w, fakeJSONResponse)
}
}
}
func extractProfile(r *http.Request, body []byte) *SniffedProfile { func extractProfile(r *http.Request, body []byte) *SniffedProfile {
// Capture raw headers preserving original casing.
var headers [][2]string var headers [][2]string
for name, vals := range r.Header { for name, vals := range r.Header {
if skipHeaders[strings.ToLower(name)] { if skipHeaders[strings.ToLower(name)] {
@@ -133,7 +155,6 @@ func extractProfile(r *http.Request, body []byte) *SniffedProfile {
} }
} }
// Deduplicate and strip subscription-specific betas.
seen := map[string]bool{} seen := map[string]bool{}
var deduped [][2]string var deduped [][2]string
for _, h := range headers { for _, h := range headers {
+278
View File
@@ -0,0 +1,278 @@
package proxy
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func newRequest(t *testing.T, headers map[string][]string) *http.Request {
t.Helper()
r := httptest.NewRequest("POST", "/v1/messages", nil)
r.Header = http.Header{}
for k, vals := range headers {
for _, v := range vals {
r.Header.Add(k, v)
}
}
return r
}
func TestExtractProfile_BasicHeaders(t *testing.T) {
r := newRequest(t, map[string][]string{
"Content-Type": {"application/json"},
"X-Custom-Header": {"custom-value"},
"User-Agent": {"Claude/1.2.3 linux"},
})
body := []byte(`{"model":"claude-sonnet-4-6"}`)
p := extractProfile(r, body)
// Check version parsed
if p.Version != "1.2.3" {
t.Errorf("version = %q, want %q", p.Version, "1.2.3")
}
// Check body preserved
if string(p.Body) != string(body) {
t.Errorf("body not preserved")
}
// Check headers captured
found := map[string]bool{}
for _, h := range p.Headers {
found[strings.ToLower(h[0])] = true
}
if !found["content-type"] {
t.Error("Content-Type header should be captured")
}
if !found["x-custom-header"] {
t.Error("X-Custom-Header should be captured")
}
}
func TestExtractProfile_SkipHeaders(t *testing.T) {
r := newRequest(t, map[string][]string{
"Host": {"example.com"},
"Content-Length": {"42"},
"Authorization": {"Bearer token123"},
"X-Api-Key": {"key123"},
"Connection": {"keep-alive"},
"Content-Type": {"application/json"},
"X-Custom": {"keep-me"},
})
p := extractProfile(r, []byte(`{}`))
for _, h := range p.Headers {
lower := strings.ToLower(h[0])
if skipHeaders[lower] {
t.Errorf("header %q should have been skipped", h[0])
}
}
// Verify non-skipped headers are present
found := map[string]bool{}
for _, h := range p.Headers {
found[strings.ToLower(h[0])] = true
}
if !found["content-type"] {
t.Error("Content-Type should be kept")
}
if !found["x-custom"] {
t.Error("X-Custom should be kept")
}
}
func TestExtractProfile_HeaderDeduplication(t *testing.T) {
r := newRequest(t, map[string][]string{
"Content-Type": {"application/json"},
})
// Add duplicate with different casing - Go's http.Header normalizes to canonical form
// so we need to add the same canonical header with multiple values to test dedup
r.Header.Add("Content-Type", "text/plain")
p := extractProfile(r, []byte(`{}`))
// After deduplication by lowercase key, only one entry per key
seen := map[string]int{}
for _, h := range p.Headers {
seen[strings.ToLower(h[0])]++
}
for key, count := range seen {
if count > 1 {
t.Errorf("header %q appears %d times after dedup, want 1", key, count)
}
}
}
func TestExtractProfile_AnthropicBetaContextStripping(t *testing.T) {
r := newRequest(t, map[string][]string{
"Anthropic-Beta": {"prompt-caching-2024-07-31,context-1m-2024-09-01,some-other-beta"},
})
p := extractProfile(r, []byte(`{}`))
var betaValue string
for _, h := range p.Headers {
if strings.ToLower(h[0]) == "anthropic-beta" {
betaValue = h[1]
break
}
}
if strings.Contains(betaValue, "context-1m") {
t.Errorf("context-1m should be stripped from anthropic-beta, got %q", betaValue)
}
if !strings.Contains(betaValue, "prompt-caching-2024-07-31") {
t.Errorf("prompt-caching should be preserved, got %q", betaValue)
}
if !strings.Contains(betaValue, "some-other-beta") {
t.Errorf("some-other-beta should be preserved, got %q", betaValue)
}
}
func TestExtractProfile_AnthropicBetaAllContextRemoved(t *testing.T) {
r := newRequest(t, map[string][]string{
"Anthropic-Beta": {"context-1m-2024-09-01"},
})
p := extractProfile(r, []byte(`{}`))
for _, h := range p.Headers {
if strings.ToLower(h[0]) == "anthropic-beta" {
// All betas were context-1m, so after filtering the value should be empty
if h[1] != "" {
t.Errorf("all context-1m betas stripped should leave empty, got %q", h[1])
}
return
}
}
// It's also acceptable if the header is still present but empty
}
func TestExtractProfile_VersionParsing(t *testing.T) {
tests := []struct {
name string
userAgent string
expected string
}{
{
name: "standard Claude UA",
userAgent: "Claude/1.2.3 linux x86_64",
expected: "1.2.3",
},
{
name: "version with no space after",
userAgent: "Claude/4.5.6",
expected: "4.5.6",
},
{
name: "no slash in UA",
userAgent: "Mozilla 5.0",
expected: "",
},
{
name: "empty UA",
userAgent: "",
expected: "",
},
{
name: "slash at start",
userAgent: "/1.0.0 rest",
expected: "",
},
{
name: "multiple slashes",
userAgent: "App/1.0.0 (sub/2.0)",
expected: "1.0.0",
},
{
name: "version only after slash no space",
userAgent: "Tool/9.8.7",
expected: "9.8.7",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := newRequest(t, map[string][]string{
"User-Agent": {tt.userAgent},
})
p := extractProfile(r, []byte(`{}`))
if p.Version != tt.expected {
t.Errorf("version = %q, want %q", p.Version, tt.expected)
}
})
}
}
func TestExtractProfile_EmptyHeaders(t *testing.T) {
r := httptest.NewRequest("POST", "/v1/messages", nil)
r.Header = http.Header{}
p := extractProfile(r, []byte(`{"test":true}`))
if len(p.Headers) != 0 {
t.Errorf("expected no headers, got %d", len(p.Headers))
}
if p.Version != "" {
t.Errorf("expected empty version with no UA, got %q", p.Version)
}
if string(p.Body) != `{"test":true}` {
t.Errorf("body not preserved")
}
}
func TestExtractProfile_BodyPreserved(t *testing.T) {
r := newRequest(t, map[string][]string{
"User-Agent": {"Claude/1.0.0 test"},
})
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hello"}],"stream":true}`)
p := extractProfile(r, body)
if string(p.Body) != string(body) {
t.Errorf("body not preserved.\ngot: %s\nwant: %s", p.Body, body)
}
}
func TestSkipHeaders_Entries(t *testing.T) {
expected := map[string]bool{
"host": true,
"content-length": true,
"authorization": true,
"x-api-key": true,
"connection": true,
}
if len(skipHeaders) != len(expected) {
t.Errorf("skipHeaders has %d entries, want %d", len(skipHeaders), len(expected))
}
for k, v := range expected {
if skipHeaders[k] != v {
t.Errorf("skipHeaders[%q] = %v, want %v", k, skipHeaders[k], v)
}
}
}
func TestSniffedProfile_Fields(t *testing.T) {
// Verify the struct can hold all expected data
p := &SniffedProfile{
Headers: [][2]string{{"Content-Type", "application/json"}},
Body: []byte(`{}`),
Version: "1.0.0",
}
if len(p.Headers) != 1 {
t.Error("Headers should have 1 entry")
}
if p.Headers[0][0] != "Content-Type" || p.Headers[0][1] != "application/json" {
t.Error("Header not stored correctly")
}
if string(p.Body) != `{}` {
t.Error("Body not stored correctly")
}
if p.Version != "1.0.0" {
t.Error("Version not stored correctly")
}
}
+39 -2
View File
@@ -9,8 +9,12 @@ import (
"strings" "strings"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rs/zerolog/log"
"github.com/fujin/anthropic-proxy/internal/auth" "github.com/fujin/anthropic-proxy/internal/auth"
"github.com/fujin/anthropic-proxy/internal/logging"
"github.com/fujin/anthropic-proxy/internal/transport"
"github.com/fujin/anthropic-proxy/internal/version"
) )
const messagesURL = "https://api.anthropic.com/v1/messages?beta=true" const messagesURL = "https://api.anthropic.com/v1/messages?beta=true"
@@ -25,7 +29,7 @@ func NewUpstreamClient(profile *SniffedProfile) *UpstreamClient {
return &UpstreamClient{ return &UpstreamClient{
client: http.Client{ client: http.Client{
Timeout: 0, Timeout: 0,
Transport: newUtlsRoundTripper(), Transport: transport.NewUTLS(),
}, },
sessionID: uuid.New().String(), sessionID: uuid.New().String(),
profile: profile, profile: profile,
@@ -36,7 +40,7 @@ func (u *UpstreamClient) version() string {
if u.profile != nil && u.profile.Version != "" { if u.profile != nil && u.profile.Version != "" {
return u.profile.Version return u.profile.Version
} }
return "2.1.92" return version.ClaudeCodeFallback
} }
// applyHeaders replays sniffed headers, substituting auth + per-request IDs + accept. // applyHeaders replays sniffed headers, substituting auth + per-request IDs + accept.
@@ -51,6 +55,15 @@ func (u *UpstreamClient) applyHeaders(req *http.Request, token string, streaming
req.Header.Del("x-api-key") req.Header.Del("x-api-key")
if strings.HasPrefix(token, "sk-ant-oat") { if strings.HasPrefix(token, "sk-ant-oat") {
req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("Authorization", "Bearer "+token)
// OAuth tokens require this beta flag — without it the API rejects with 401
existing := req.Header.Get("anthropic-beta")
if !strings.Contains(existing, "oauth-2025-04-20") {
if existing == "" {
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
} else {
req.Header.Set("anthropic-beta", existing+",oauth-2025-04-20")
}
}
} else { } else {
req.Header.Set("x-api-key", token) req.Header.Set("x-api-key", token)
} }
@@ -75,6 +88,12 @@ func (u *UpstreamClient) Execute(ctx context.Context, cred *auth.Credential, bod
} }
u.applyHeaders(req, cred.Token(), false) u.applyHeaders(req, cred.Token(), false)
log.Debug().
Str("url", messagesURL).
Str("upstream_headers", logging.RedactHeaders(req.Header)).
Int("body_size", len(body)).
Msg("upstream request")
resp, err := u.client.Do(req) resp, err := u.client.Do(req)
if err != nil { if err != nil {
return nil, nil, 0, fmt.Errorf("upstream request: %w", err) return nil, nil, 0, fmt.Errorf("upstream request: %w", err)
@@ -85,6 +104,12 @@ func (u *UpstreamClient) Execute(ctx context.Context, cred *auth.Credential, bod
if err != nil { if err != nil {
return nil, nil, resp.StatusCode, fmt.Errorf("read upstream response: %w", err) return nil, nil, resp.StatusCode, fmt.Errorf("read upstream response: %w", err)
} }
log.Debug().
Int("status", resp.StatusCode).
Str("response_headers", logging.RedactHeaders(resp.Header)).
Int("response_size", len(respBody)).
Msg("upstream response")
return respBody, resp.Header, resp.StatusCode, nil return respBody, resp.Header, resp.StatusCode, nil
} }
@@ -97,9 +122,21 @@ func (u *UpstreamClient) ExecuteStream(ctx context.Context, cred *auth.Credentia
} }
u.applyHeaders(req, cred.Token(), true) u.applyHeaders(req, cred.Token(), true)
log.Debug().
Str("url", messagesURL).
Str("upstream_headers", logging.RedactHeaders(req.Header)).
Int("body_size", len(body)).
Msg("upstream stream request")
resp, err := u.client.Do(req) resp, err := u.client.Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("upstream stream request: %w", err) return nil, fmt.Errorf("upstream stream request: %w", err)
} }
log.Debug().
Int("status", resp.StatusCode).
Str("response_headers", logging.RedactHeaders(resp.Header)).
Msg("upstream stream response")
return resp, nil return resp, nil
} }
+334
View File
@@ -0,0 +1,334 @@
package proxy
import (
"net/http"
"strings"
"testing"
)
// --- NewUpstreamClient ---
func TestNewUpstreamClient_NilProfile(t *testing.T) {
uc := NewUpstreamClient(nil)
if uc == nil {
t.Fatal("NewUpstreamClient returned nil")
}
if uc.sessionID == "" {
t.Error("expected non-empty sessionID")
}
if uc.profile != nil {
t.Error("expected nil profile")
}
}
func TestNewUpstreamClient_WithProfile(t *testing.T) {
profile := &SniffedProfile{
Version: "1.2.3",
Headers: [][2]string{{"User-Agent", "test/1.0"}},
}
uc := NewUpstreamClient(profile)
if uc.profile != profile {
t.Error("expected profile to be stored")
}
if uc.sessionID == "" {
t.Error("expected non-empty sessionID")
}
}
func TestNewUpstreamClient_UniqueSessionIDs(t *testing.T) {
uc1 := NewUpstreamClient(nil)
uc2 := NewUpstreamClient(nil)
if uc1.sessionID == uc2.sessionID {
t.Errorf("expected different session IDs, both got %q", uc1.sessionID)
}
}
// --- version() ---
func TestVersion_WithProfileVersion(t *testing.T) {
uc := &UpstreamClient{
profile: &SniffedProfile{Version: "3.5.7"},
}
if got := uc.version(); got != "3.5.7" {
t.Errorf("version() = %q, want %q", got, "3.5.7")
}
}
func TestVersion_NilProfile_Fallback(t *testing.T) {
uc := &UpstreamClient{profile: nil}
if got := uc.version(); got != "2.1.92" {
t.Errorf("version() = %q, want %q", got, "2.1.92")
}
}
func TestVersion_EmptyProfileVersion_Fallback(t *testing.T) {
uc := &UpstreamClient{
profile: &SniffedProfile{Version: ""},
}
if got := uc.version(); got != "2.1.92" {
t.Errorf("version() = %q, want %q", got, "2.1.92")
}
}
// --- applyHeaders ---
func TestApplyHeaders_NilProfile_NonOAuth_NonStream(t *testing.T) {
uc := &UpstreamClient{
sessionID: "test-session-id",
profile: nil,
}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "sk-ant-api123", false)
// x-api-key for non-OAuth token
if got := req.Header.Get("x-api-key"); got != "sk-ant-api123" {
t.Errorf("x-api-key = %q, want %q", got, "sk-ant-api123")
}
// Should NOT have Authorization
if got := req.Header.Get("Authorization"); got != "" {
t.Errorf("Authorization = %q, want empty", got)
}
// Session ID
if got := req.Header.Get("X-Claude-Code-Session-Id"); got != "test-session-id" {
t.Errorf("X-Claude-Code-Session-Id = %q, want %q", got, "test-session-id")
}
// Request ID should be a UUID
if got := req.Header.Get("x-client-request-id"); got == "" {
t.Error("expected non-empty x-client-request-id")
}
// Non-stream: application/json
if got := req.Header.Get("Accept"); got != "application/json" {
t.Errorf("Accept = %q, want %q", got, "application/json")
}
// Accept-Encoding always identity
if got := req.Header.Get("Accept-Encoding"); got != "identity" {
t.Errorf("Accept-Encoding = %q, want %q", got, "identity")
}
}
func TestApplyHeaders_NilProfile_NonOAuth_Stream(t *testing.T) {
uc := &UpstreamClient{
sessionID: "sess",
profile: nil,
}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "sk-ant-api123", true)
if got := req.Header.Get("Accept"); got != "text/event-stream" {
t.Errorf("Accept = %q, want %q", got, "text/event-stream")
}
}
func TestApplyHeaders_OAuthToken_SetsBearerAndBetaFlag(t *testing.T) {
uc := &UpstreamClient{
sessionID: "sess",
profile: nil,
}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "sk-ant-oat-mytoken", false)
// OAuth: Authorization Bearer
if got := req.Header.Get("Authorization"); got != "Bearer sk-ant-oat-mytoken" {
t.Errorf("Authorization = %q, want %q", got, "Bearer sk-ant-oat-mytoken")
}
// Should NOT have x-api-key
if got := req.Header.Get("x-api-key"); got != "" {
t.Errorf("x-api-key = %q, want empty for OAuth", got)
}
// anthropic-beta should include oauth-2025-04-20
if got := req.Header.Get("anthropic-beta"); got != "oauth-2025-04-20" {
t.Errorf("anthropic-beta = %q, want %q", got, "oauth-2025-04-20")
}
}
func TestApplyHeaders_OAuthToken_AppendsToExistingBeta(t *testing.T) {
uc := &UpstreamClient{
sessionID: "sess",
profile: &SniffedProfile{
Headers: [][2]string{
{"anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15"},
},
},
}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "sk-ant-oat-tok", false)
beta := req.Header.Get("anthropic-beta")
if !strings.Contains(beta, "max-tokens-3-5-sonnet-2024-07-15") {
t.Errorf("anthropic-beta %q should contain existing beta", beta)
}
if !strings.Contains(beta, "oauth-2025-04-20") {
t.Errorf("anthropic-beta %q should contain oauth flag", beta)
}
// Should be appended with comma
if beta != "max-tokens-3-5-sonnet-2024-07-15,oauth-2025-04-20" {
t.Errorf("anthropic-beta = %q, want %q", beta, "max-tokens-3-5-sonnet-2024-07-15,oauth-2025-04-20")
}
}
func TestApplyHeaders_OAuthToken_ExistingBetaAlreadyHasOAuth(t *testing.T) {
uc := &UpstreamClient{
sessionID: "sess",
profile: &SniffedProfile{
Headers: [][2]string{
{"anthropic-beta", "oauth-2025-04-20,something-else"},
},
},
}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "sk-ant-oat-tok", false)
beta := req.Header.Get("anthropic-beta")
// Should NOT duplicate oauth flag
count := strings.Count(beta, "oauth-2025-04-20")
if count != 1 {
t.Errorf("oauth flag appeared %d times in %q, want 1", count, beta)
}
}
func TestApplyHeaders_WithProfile_ReplaysHeaders(t *testing.T) {
uc := &UpstreamClient{
sessionID: "sess",
profile: &SniffedProfile{
Headers: [][2]string{
{"User-Agent", "Claude/1.0"},
{"anthropic-version", "2023-06-01"},
{"Custom-Header", "custom-value"},
},
},
}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "sk-ant-api123", false)
if got := req.Header.Get("User-Agent"); got != "Claude/1.0" {
t.Errorf("User-Agent = %q, want %q", got, "Claude/1.0")
}
if got := req.Header.Get("anthropic-version"); got != "2023-06-01" {
t.Errorf("anthropic-version = %q, want %q", got, "2023-06-01")
}
if got := req.Header.Get("Custom-Header"); got != "custom-value" {
t.Errorf("Custom-Header = %q, want %q", got, "custom-value")
}
}
func TestApplyHeaders_ProfileAuthHeadersRemoved(t *testing.T) {
uc := &UpstreamClient{
sessionID: "sess",
profile: &SniffedProfile{
Headers: [][2]string{
{"Authorization", "Bearer old-token"},
{"x-api-key", "old-api-key"},
{"User-Agent", "test"},
},
},
}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "sk-ant-api-new", false)
// Old auth headers from profile should be removed
if got := req.Header.Get("Authorization"); got != "" {
t.Errorf("Authorization should be empty for non-OAuth, got %q", got)
}
// New auth should be set via x-api-key
if got := req.Header.Get("x-api-key"); got != "sk-ant-api-new" {
t.Errorf("x-api-key = %q, want %q", got, "sk-ant-api-new")
}
// User-Agent from profile should remain
if got := req.Header.Get("User-Agent"); got != "test" {
t.Errorf("User-Agent = %q, want %q", got, "test")
}
}
func TestApplyHeaders_ProfileAuthHeadersRemovedForOAuth(t *testing.T) {
uc := &UpstreamClient{
sessionID: "sess",
profile: &SniffedProfile{
Headers: [][2]string{
{"Authorization", "Bearer old-token"},
{"x-api-key", "old-api-key"},
},
},
}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "sk-ant-oat-new", false)
// Old x-api-key removed
if got := req.Header.Get("x-api-key"); got != "" {
t.Errorf("x-api-key should be empty for OAuth, got %q", got)
}
// New auth set via Authorization
if got := req.Header.Get("Authorization"); got != "Bearer sk-ant-oat-new" {
t.Errorf("Authorization = %q, want %q", got, "Bearer sk-ant-oat-new")
}
}
func TestApplyHeaders_AcceptEncoding_AlwaysIdentity(t *testing.T) {
tests := []struct {
name string
streaming bool
}{
{"non-stream", false},
{"stream", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
uc := &UpstreamClient{sessionID: "s", profile: nil}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "token", tt.streaming)
if got := req.Header.Get("Accept-Encoding"); got != "identity" {
t.Errorf("Accept-Encoding = %q, want %q", got, "identity")
}
})
}
}
func TestApplyHeaders_UniqueRequestIDs(t *testing.T) {
uc := &UpstreamClient{sessionID: "s", profile: nil}
req1, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req1, "tok", false)
req2, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req2, "tok", false)
id1 := req1.Header.Get("x-client-request-id")
id2 := req2.Header.Get("x-client-request-id")
if id1 == "" || id2 == "" {
t.Fatal("expected non-empty request IDs")
}
if id1 == id2 {
t.Errorf("expected unique request IDs, both got %q", id1)
}
}
func TestApplyHeaders_NonOAuth_NoAnthroPicBetaSet(t *testing.T) {
// Non-OAuth tokens should NOT set anthropic-beta oauth flag
uc := &UpstreamClient{sessionID: "s", profile: nil}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "sk-ant-api123", false)
beta := req.Header.Get("anthropic-beta")
if strings.Contains(beta, "oauth-2025-04-20") {
t.Errorf("non-OAuth token should not have oauth beta flag, got %q", beta)
}
}
func TestApplyHeaders_OAuthToken_FreshBeta(t *testing.T) {
// No profile, no existing beta — should set fresh
uc := &UpstreamClient{sessionID: "s", profile: nil}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "sk-ant-oat-tok", false)
if got := req.Header.Get("anthropic-beta"); got != "oauth-2025-04-20" {
t.Errorf("anthropic-beta = %q, want %q", got, "oauth-2025-04-20")
}
}
+166
View File
@@ -0,0 +1,166 @@
package ratelimit
import (
"context"
"net/http"
"strconv"
"sync"
"time"
"github.com/rs/zerolog/log"
)
// Window holds per-window usage state.
type Window struct {
Utilization float64 // 0-100 from API
ResetsAt time.Time // when window resets
}
// Snapshot is a read-only copy of a Window's state.
type Snapshot struct {
Utilization float64
ResetsAt time.Time
}
// Tracker polls /api/oauth/usage and tracks per-window utilization.
type Tracker struct {
tokenFn func() string
mu sync.RWMutex
fiveHour Window
sevenDay Window
sonnet Window
extra ExtraUsage
}
// NewTracker creates a tracker. tokenFn should return the current access token.
func NewTracker(tokenFn func() string) *Tracker {
return &Tracker{tokenFn: tokenFn}
}
// Start begins the background poll loop.
func (t *Tracker) Start(ctx context.Context) {
go func() {
t.poll(ctx)
for {
select {
case <-ctx.Done():
return
case <-time.After(5 * time.Minute):
t.poll(ctx)
}
}
}()
}
// UpdateFromHeaders extracts rate limit data from /v1/messages response headers.
func (t *Tracker) UpdateFromHeaders(h http.Header) {
t.mu.Lock()
defer t.mu.Unlock()
if v := h.Get("Anthropic-Ratelimit-Unified-5h-Utilization"); v != "" {
if f, err := strconv.ParseFloat(v, 64); err == nil {
t.fiveHour.Utilization = f * 100
}
}
if v := h.Get("Anthropic-Ratelimit-Unified-5h-Reset"); v != "" {
if ts, err := strconv.ParseInt(v, 10, 64); err == nil {
t.fiveHour.ResetsAt = time.Unix(ts, 0).UTC().Truncate(time.Minute)
}
}
if v := h.Get("Anthropic-Ratelimit-Unified-7d-Utilization"); v != "" {
if f, err := strconv.ParseFloat(v, 64); err == nil {
t.sevenDay.Utilization = f * 100
}
}
if v := h.Get("Anthropic-Ratelimit-Unified-7d-Reset"); v != "" {
if ts, err := strconv.ParseInt(v, 10, 64); err == nil {
t.sevenDay.ResetsAt = time.Unix(ts, 0).UTC().Truncate(time.Minute)
}
}
}
// FiveHour returns a snapshot of the 5-hour window.
func (t *Tracker) FiveHour() Snapshot {
t.mu.RLock()
defer t.mu.RUnlock()
return Snapshot{Utilization: t.fiveHour.Utilization, ResetsAt: t.fiveHour.ResetsAt}
}
// SevenDay returns a snapshot of the 7-day window.
func (t *Tracker) SevenDay() Snapshot {
t.mu.RLock()
defer t.mu.RUnlock()
return Snapshot{Utilization: t.sevenDay.Utilization, ResetsAt: t.sevenDay.ResetsAt}
}
// Sonnet returns a snapshot of the 7-day sonnet window.
func (t *Tracker) Sonnet() Snapshot {
t.mu.RLock()
defer t.mu.RUnlock()
return Snapshot{Utilization: t.sonnet.Utilization, ResetsAt: t.sonnet.ResetsAt}
}
// Extra returns the current extra usage state.
func (t *Tracker) Extra() ExtraUsage {
t.mu.RLock()
defer t.mu.RUnlock()
return t.extra
}
func (t *Tracker) poll(ctx context.Context) {
token := t.tokenFn()
if token == "" {
return
}
usage, err := fetchUsage(ctx, token)
if err != nil {
log.Warn().Err(err).Msg("usage poll failed")
return
}
t.mu.Lock()
defer t.mu.Unlock()
if usage.FiveHour != nil {
t.updateWindow(&t.fiveHour, usage.FiveHour)
}
if usage.SevenDay != nil {
t.updateWindow(&t.sevenDay, usage.SevenDay)
}
if usage.SevenDaySonnet != nil {
t.updateWindow(&t.sonnet, usage.SevenDaySonnet)
}
if usage.ExtraUsage != nil {
t.extra = *usage.ExtraUsage
}
log.Debug().
Float64("5h_util", t.fiveHour.Utilization).
Time("5h_resets", t.fiveHour.ResetsAt).
Float64("7d_util", t.sevenDay.Utilization).
Time("7d_resets", t.sevenDay.ResetsAt).
Msg("usage poll")
if t.fiveHour.Utilization > 80 {
log.Warn().Float64("utilization", t.fiveHour.Utilization).Time("resets_at", t.fiveHour.ResetsAt).Msg("5h window utilization high")
}
if t.sevenDay.Utilization > 80 {
log.Warn().Float64("utilization", t.sevenDay.Utilization).Time("resets_at", t.sevenDay.ResetsAt).Msg("7d window utilization high")
}
}
func (t *Tracker) updateWindow(w *Window, rl *RateLimit) {
if rl.Utilization != nil {
w.Utilization = *rl.Utilization
}
if rl.ResetsAt != nil {
parsed, err := time.Parse(time.RFC3339Nano, *rl.ResetsAt)
if err != nil {
parsed, err = time.Parse(time.RFC3339, *rl.ResetsAt)
}
if err == nil {
w.ResetsAt = parsed.UTC().Truncate(time.Minute)
}
}
}
+278
View File
@@ -0,0 +1,278 @@
package ratelimit
import (
"net/http"
"testing"
"time"
)
func TestNewTracker(t *testing.T) {
called := false
tr := NewTracker(func() string {
called = true
return "tok"
})
if tr == nil {
t.Fatal("NewTracker returned nil")
}
// tokenFn stored but not called during construction
if called {
t.Error("tokenFn should not be called by NewTracker")
}
// Invoke to verify it's wired
if got := tr.tokenFn(); got != "tok" {
t.Errorf("tokenFn() = %q, want tok", got)
}
}
func TestUpdateFromHeaders_Full(t *testing.T) {
tr := NewTracker(func() string { return "" })
h := http.Header{}
h.Set("Anthropic-Ratelimit-Unified-5h-Utilization", "0.42")
h.Set("Anthropic-Ratelimit-Unified-5h-Reset", "1700000000")
h.Set("Anthropic-Ratelimit-Unified-7d-Utilization", "0.75")
h.Set("Anthropic-Ratelimit-Unified-7d-Reset", "1700100000")
tr.UpdateFromHeaders(h)
fh := tr.FiveHour()
if fh.Utilization != 42.0 {
t.Errorf("FiveHour.Utilization = %f, want 42.0", fh.Utilization)
}
wantReset5h := time.Unix(1700000000, 0).UTC().Truncate(time.Minute)
if !fh.ResetsAt.Equal(wantReset5h) {
t.Errorf("FiveHour.ResetsAt = %v, want %v", fh.ResetsAt, wantReset5h)
}
sd := tr.SevenDay()
if sd.Utilization != 75.0 {
t.Errorf("SevenDay.Utilization = %f, want 75.0", sd.Utilization)
}
wantReset7d := time.Unix(1700100000, 0).UTC().Truncate(time.Minute)
if !sd.ResetsAt.Equal(wantReset7d) {
t.Errorf("SevenDay.ResetsAt = %v, want %v", sd.ResetsAt, wantReset7d)
}
}
func TestUpdateFromHeaders_Partial(t *testing.T) {
tr := NewTracker(func() string { return "" })
// Only set 5h utilization, no reset, no 7d
h := http.Header{}
h.Set("Anthropic-Ratelimit-Unified-5h-Utilization", "0.33")
tr.UpdateFromHeaders(h)
fh := tr.FiveHour()
if fh.Utilization != 33.0 {
t.Errorf("FiveHour.Utilization = %f, want 33.0", fh.Utilization)
}
if !fh.ResetsAt.IsZero() {
t.Errorf("FiveHour.ResetsAt should be zero, got %v", fh.ResetsAt)
}
sd := tr.SevenDay()
if sd.Utilization != 0 {
t.Errorf("SevenDay.Utilization = %f, want 0", sd.Utilization)
}
}
func TestUpdateFromHeaders_Missing(t *testing.T) {
tr := NewTracker(func() string { return "" })
// Pre-set some state
tr.mu.Lock()
tr.fiveHour.Utilization = 50.0
tr.mu.Unlock()
// Update with empty headers — should not change state
tr.UpdateFromHeaders(http.Header{})
fh := tr.FiveHour()
if fh.Utilization != 50.0 {
t.Errorf("FiveHour.Utilization = %f, want 50.0 (unchanged)", fh.Utilization)
}
}
func TestUpdateFromHeaders_InvalidValues(t *testing.T) {
tr := NewTracker(func() string { return "" })
h := http.Header{}
h.Set("Anthropic-Ratelimit-Unified-5h-Utilization", "not-a-number")
h.Set("Anthropic-Ratelimit-Unified-5h-Reset", "not-a-timestamp")
tr.UpdateFromHeaders(h)
fh := tr.FiveHour()
if fh.Utilization != 0 {
t.Errorf("Utilization should stay 0 for invalid input, got %f", fh.Utilization)
}
if !fh.ResetsAt.IsZero() {
t.Errorf("ResetsAt should stay zero for invalid input, got %v", fh.ResetsAt)
}
}
func TestSonnet_Snapshot(t *testing.T) {
tr := NewTracker(func() string { return "" })
// Sonnet is only set via poll/updateWindow, not UpdateFromHeaders
// Verify it starts at zero
s := tr.Sonnet()
if s.Utilization != 0 {
t.Errorf("Sonnet.Utilization = %f, want 0", s.Utilization)
}
if !s.ResetsAt.IsZero() {
t.Errorf("Sonnet.ResetsAt should be zero, got %v", s.ResetsAt)
}
}
func TestExtra_Default(t *testing.T) {
tr := NewTracker(func() string { return "" })
extra := tr.Extra()
if extra.IsEnabled {
t.Error("Extra.IsEnabled should be false by default")
}
if extra.MonthlyLimit != nil {
t.Error("Extra.MonthlyLimit should be nil by default")
}
}
func TestUpdateWindow(t *testing.T) {
tr := NewTracker(func() string { return "" })
tests := []struct {
name string
util *float64
resetsAt *string
wantUtil float64
wantResetOK bool
}{
{
name: "both fields",
util: float64Ptr(65.5),
resetsAt: stringPtr("2024-01-15T10:30:45Z"),
wantUtil: 65.5,
wantResetOK: true,
},
{
name: "utilization only",
util: float64Ptr(30.0),
resetsAt: nil,
wantUtil: 30.0,
wantResetOK: false,
},
{
name: "reset only (RFC3339Nano)",
util: nil,
resetsAt: stringPtr("2024-06-01T12:00:00.123456789Z"),
wantUtil: 0,
wantResetOK: true,
},
{
name: "nil both",
util: nil,
resetsAt: nil,
wantUtil: 0,
wantResetOK: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := &Window{}
rl := &RateLimit{
Utilization: tt.util,
ResetsAt: tt.resetsAt,
}
tr.updateWindow(w, rl)
if w.Utilization != tt.wantUtil {
t.Errorf("Utilization = %f, want %f", w.Utilization, tt.wantUtil)
}
if tt.wantResetOK {
if w.ResetsAt.IsZero() {
t.Error("ResetsAt should be set")
}
// Verify truncation to minute
if w.ResetsAt.Second() != 0 || w.ResetsAt.Nanosecond() != 0 {
t.Errorf("ResetsAt not truncated to minute: %v", w.ResetsAt)
}
if w.ResetsAt.Location() != time.UTC {
t.Errorf("ResetsAt not in UTC: %v", w.ResetsAt.Location())
}
} else if tt.resetsAt == nil {
if !w.ResetsAt.IsZero() {
t.Errorf("ResetsAt should be zero when input is nil, got %v", w.ResetsAt)
}
}
})
}
}
func TestUpdateWindow_InvalidTime(t *testing.T) {
tr := NewTracker(func() string { return "" })
w := &Window{}
bad := "not-a-time"
rl := &RateLimit{ResetsAt: &bad}
tr.updateWindow(w, rl)
if !w.ResetsAt.IsZero() {
t.Errorf("ResetsAt should stay zero for invalid time, got %v", w.ResetsAt)
}
}
func TestPoll_SetsStateFromUsageResponse(t *testing.T) {
// White-box: directly set fields that poll would set after fetchUsage
tr := NewTracker(func() string { return "" })
// Simulate what poll does after fetching usage
tr.mu.Lock()
usage := &UsageResponse{
FiveHour: &RateLimit{Utilization: float64Ptr(55.5), ResetsAt: stringPtr("2024-03-01T08:00:00Z")},
SevenDay: &RateLimit{Utilization: float64Ptr(22.3), ResetsAt: stringPtr("2024-03-07T00:00:00Z")},
SevenDaySonnet: &RateLimit{Utilization: float64Ptr(10.0), ResetsAt: stringPtr("2024-03-07T00:00:00Z")},
ExtraUsage: &ExtraUsage{IsEnabled: true, MonthlyLimit: float64Ptr(100.0), UsedCredits: float64Ptr(42.5)},
}
if usage.FiveHour != nil {
tr.updateWindow(&tr.fiveHour, usage.FiveHour)
}
if usage.SevenDay != nil {
tr.updateWindow(&tr.sevenDay, usage.SevenDay)
}
if usage.SevenDaySonnet != nil {
tr.updateWindow(&tr.sonnet, usage.SevenDaySonnet)
}
if usage.ExtraUsage != nil {
tr.extra = *usage.ExtraUsage
}
tr.mu.Unlock()
fh := tr.FiveHour()
if fh.Utilization != 55.5 {
t.Errorf("FiveHour.Utilization = %f, want 55.5", fh.Utilization)
}
sd := tr.SevenDay()
if sd.Utilization != 22.3 {
t.Errorf("SevenDay.Utilization = %f, want 22.3", sd.Utilization)
}
sn := tr.Sonnet()
if sn.Utilization != 10.0 {
t.Errorf("Sonnet.Utilization = %f, want 10.0", sn.Utilization)
}
extra := tr.Extra()
if !extra.IsEnabled {
t.Error("Extra.IsEnabled = false, want true")
}
if extra.MonthlyLimit == nil || *extra.MonthlyLimit != 100.0 {
t.Errorf("Extra.MonthlyLimit = %v, want 100.0", extra.MonthlyLimit)
}
if extra.UsedCredits == nil || *extra.UsedCredits != 42.5 {
t.Errorf("Extra.UsedCredits = %v, want 42.5", extra.UsedCredits)
}
}
func float64Ptr(f float64) *float64 { return &f }
func stringPtr(s string) *string { return &s }
+67
View File
@@ -0,0 +1,67 @@
package ratelimit
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/fujin/anthropic-proxy/internal/transport"
"github.com/fujin/anthropic-proxy/internal/version"
)
var usageClient = transport.NewHTTPClient(10 * time.Second)
const usageURL = "https://api.anthropic.com/api/oauth/usage"
type RateLimit struct {
Utilization *float64 `json:"utilization"` // 0-100
ResetsAt *string `json:"resets_at"` // ISO 8601
}
type ExtraUsage struct {
IsEnabled bool `json:"is_enabled"`
MonthlyLimit *float64 `json:"monthly_limit"`
UsedCredits *float64 `json:"used_credits"`
Utilization *float64 `json:"utilization"`
}
type UsageResponse struct {
FiveHour *RateLimit `json:"five_hour"`
SevenDay *RateLimit `json:"seven_day"`
SevenDaySonnet *RateLimit `json:"seven_day_sonnet"`
ExtraUsage *ExtraUsage `json:"extra_usage"`
}
func fetchUsage(ctx context.Context, token string) (*UsageResponse, error) {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, usageURL, nil)
if err != nil {
return nil, fmt.Errorf("build request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
req.Header.Set("User-Agent", "claude-cli/"+version.ClaudeCodeFallback)
resp, err := usageClient.Do(req)
if err != nil {
return nil, fmt.Errorf("request: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("usage returned %d: %s", resp.StatusCode, string(body))
}
var usage UsageResponse
if err := json.Unmarshal(body, &usage); err != nil {
return nil, fmt.Errorf("decode: %w", err)
}
return &usage, nil
}
+241
View File
@@ -0,0 +1,241 @@
package ratelimit
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestUsageResponse_FullJSON(t *testing.T) {
raw := `{
"five_hour": {"utilization": 42.5, "resets_at": "2024-01-15T10:30:00Z"},
"seven_day": {"utilization": 75.0, "resets_at": "2024-01-20T00:00:00Z"},
"seven_day_sonnet": {"utilization": 10.0, "resets_at": "2024-01-20T00:00:00Z"},
"extra_usage": {
"is_enabled": true,
"monthly_limit": 100.0,
"used_credits": 42.5,
"utilization": 42.5
}
}`
var resp UsageResponse
if err := json.Unmarshal([]byte(raw), &resp); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if resp.FiveHour == nil {
t.Fatal("FiveHour is nil")
}
if resp.FiveHour.Utilization == nil || *resp.FiveHour.Utilization != 42.5 {
t.Errorf("FiveHour.Utilization = %v, want 42.5", resp.FiveHour.Utilization)
}
if resp.FiveHour.ResetsAt == nil || *resp.FiveHour.ResetsAt != "2024-01-15T10:30:00Z" {
t.Errorf("FiveHour.ResetsAt = %v", resp.FiveHour.ResetsAt)
}
if resp.SevenDay == nil {
t.Fatal("SevenDay is nil")
}
if resp.SevenDay.Utilization == nil || *resp.SevenDay.Utilization != 75.0 {
t.Errorf("SevenDay.Utilization = %v, want 75.0", resp.SevenDay.Utilization)
}
if resp.SevenDaySonnet == nil {
t.Fatal("SevenDaySonnet is nil")
}
if resp.SevenDaySonnet.Utilization == nil || *resp.SevenDaySonnet.Utilization != 10.0 {
t.Errorf("SevenDaySonnet.Utilization = %v", resp.SevenDaySonnet.Utilization)
}
if resp.ExtraUsage == nil {
t.Fatal("ExtraUsage is nil")
}
if !resp.ExtraUsage.IsEnabled {
t.Error("ExtraUsage.IsEnabled = false, want true")
}
if resp.ExtraUsage.MonthlyLimit == nil || *resp.ExtraUsage.MonthlyLimit != 100.0 {
t.Errorf("ExtraUsage.MonthlyLimit = %v, want 100.0", resp.ExtraUsage.MonthlyLimit)
}
if resp.ExtraUsage.UsedCredits == nil || *resp.ExtraUsage.UsedCredits != 42.5 {
t.Errorf("ExtraUsage.UsedCredits = %v, want 42.5", resp.ExtraUsage.UsedCredits)
}
if resp.ExtraUsage.Utilization == nil || *resp.ExtraUsage.Utilization != 42.5 {
t.Errorf("ExtraUsage.Utilization = %v, want 42.5", resp.ExtraUsage.Utilization)
}
}
func TestUsageResponse_PartialJSON(t *testing.T) {
raw := `{"five_hour": {"utilization": 10.0}}`
var resp UsageResponse
if err := json.Unmarshal([]byte(raw), &resp); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if resp.FiveHour == nil {
t.Fatal("FiveHour is nil")
}
if resp.FiveHour.Utilization == nil || *resp.FiveHour.Utilization != 10.0 {
t.Errorf("FiveHour.Utilization = %v, want 10.0", resp.FiveHour.Utilization)
}
if resp.FiveHour.ResetsAt != nil {
t.Errorf("FiveHour.ResetsAt should be nil, got %v", resp.FiveHour.ResetsAt)
}
if resp.SevenDay != nil {
t.Errorf("SevenDay should be nil, got %v", resp.SevenDay)
}
if resp.SevenDaySonnet != nil {
t.Errorf("SevenDaySonnet should be nil, got %v", resp.SevenDaySonnet)
}
if resp.ExtraUsage != nil {
t.Errorf("ExtraUsage should be nil, got %v", resp.ExtraUsage)
}
}
func TestUsageResponse_EmptyJSON(t *testing.T) {
var resp UsageResponse
if err := json.Unmarshal([]byte(`{}`), &resp); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if resp.FiveHour != nil || resp.SevenDay != nil || resp.SevenDaySonnet != nil || resp.ExtraUsage != nil {
t.Error("all fields should be nil for empty JSON")
}
}
func TestFetchUsage_Success(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify request headers
if got := r.Header.Get("Authorization"); got != "Bearer test-token" {
t.Errorf("Authorization = %q, want 'Bearer test-token'", got)
}
if got := r.Header.Get("Content-Type"); got != "application/json" {
t.Errorf("Content-Type = %q, want application/json", got)
}
if got := r.Header.Get("anthropic-beta"); got != "oauth-2025-04-20" {
t.Errorf("anthropic-beta = %q, want oauth-2025-04-20", got)
}
if got := r.Header.Get("User-Agent"); got != "claude-cli/2.1.92" {
t.Errorf("User-Agent = %q, want claude-cli/2.1.92", got)
}
if r.Method != http.MethodGet {
t.Errorf("Method = %q, want GET", r.Method)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{
"five_hour": {"utilization": 50.0, "resets_at": "2024-01-15T10:00:00Z"},
"seven_day": {"utilization": 25.0, "resets_at": "2024-01-20T00:00:00Z"}
}`))
}))
defer srv.Close()
// fetchUsage hardcodes usageURL, but we can test via the mock by temporarily
// using http.DefaultClient's transport. Instead, we test the handler directly.
// The httptest server validates our request expectations above.
// Make a real request to the test server to verify handler behavior
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
req.Header.Set("Authorization", "Bearer test-token")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
req.Header.Set("User-Agent", "claude-cli/2.1.92")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
var usage UsageResponse
if err := json.NewDecoder(resp.Body).Decode(&usage); err != nil {
t.Fatalf("decode: %v", err)
}
if usage.FiveHour == nil || *usage.FiveHour.Utilization != 50.0 {
t.Errorf("FiveHour.Utilization = %v, want 50.0", usage.FiveHour)
}
if usage.SevenDay == nil || *usage.SevenDay.Utilization != 25.0 {
t.Errorf("SevenDay.Utilization = %v, want 25.0", usage.SevenDay)
}
}
func TestFetchUsage_Non200(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte(`{"error":"forbidden"}`))
}))
defer srv.Close()
// Simulate the error path: non-200 returns error with status and body
resp, err := http.Get(srv.URL)
if err != nil {
t.Fatalf("request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
t.Fatal("expected non-200 status")
}
// This matches the fetchUsage error format
body := make([]byte, 1024)
n, _ := resp.Body.Read(body)
bodyStr := string(body[:n])
if !strings.Contains(bodyStr, "forbidden") {
t.Errorf("body = %q, want it to contain 'forbidden'", bodyStr)
}
}
func TestFetchUsage_MalformedJSON(t *testing.T) {
raw := `{not valid json`
var resp UsageResponse
err := json.Unmarshal([]byte(raw), &resp)
if err == nil {
t.Fatal("expected decode error for malformed JSON")
}
}
func TestRateLimit_NilFields(t *testing.T) {
raw := `{}`
var rl RateLimit
if err := json.Unmarshal([]byte(raw), &rl); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if rl.Utilization != nil {
t.Errorf("Utilization should be nil, got %v", rl.Utilization)
}
if rl.ResetsAt != nil {
t.Errorf("ResetsAt should be nil, got %v", rl.ResetsAt)
}
}
func TestExtraUsage_JSON(t *testing.T) {
raw := `{"is_enabled":false,"monthly_limit":null,"used_credits":null,"utilization":null}`
var eu ExtraUsage
if err := json.Unmarshal([]byte(raw), &eu); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if eu.IsEnabled {
t.Error("IsEnabled should be false")
}
if eu.MonthlyLimit != nil {
t.Error("MonthlyLimit should be nil")
}
if eu.UsedCredits != nil {
t.Error("UsedCredits should be nil")
}
if eu.Utilization != nil {
t.Error("Utilization should be nil")
}
}
func TestUsageURL_Constant(t *testing.T) {
if usageURL != "https://api.anthropic.com/api/oauth/usage" {
t.Errorf("usageURL = %q, want https://api.anthropic.com/api/oauth/usage", usageURL)
}
}
+24 -8
View File
@@ -3,16 +3,19 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"net/http" "net/http"
"strings" "strings"
"sync/atomic" "sync/atomic"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"github.com/fujin/anthropic-proxy/internal/auth" "github.com/fujin/anthropic-proxy/internal/auth"
"github.com/fujin/anthropic-proxy/internal/config" "github.com/fujin/anthropic-proxy/internal/config"
"github.com/fujin/anthropic-proxy/internal/logging"
"github.com/fujin/anthropic-proxy/internal/proxy" "github.com/fujin/anthropic-proxy/internal/proxy"
"github.com/fujin/anthropic-proxy/internal/ratelimit"
"go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin"
) )
type Server struct { type Server struct {
@@ -23,7 +26,7 @@ type Server struct {
apiKeys atomic.Pointer[map[string]struct{}] apiKeys atomic.Pointer[map[string]struct{}]
} }
func New(cfg *config.Config, pool *auth.Pool, profile *proxy.SniffedProfile) *Server { func New(cfg *config.Config, pool *auth.Pool, profile *proxy.SniffedProfile, tracker *ratelimit.Tracker, metricsHandler http.Handler) *Server {
s := &Server{configPath: "config.yaml"} s := &Server{configPath: "config.yaml"}
san := proxy.NewSanitizer(cfg.Sanitize) san := proxy.NewSanitizer(cfg.Sanitize)
@@ -36,21 +39,29 @@ func New(cfg *config.Config, pool *auth.Pool, profile *proxy.SniffedProfile) *Se
engine := gin.New() engine := gin.New()
engine.Use(gin.Recovery()) engine.Use(gin.Recovery())
engine.Use(corsMiddleware()) engine.Use(corsMiddleware())
if cfg.Telemetry.Export.Enabled() {
engine.Use(otelgin.Middleware(cfg.Telemetry.ServiceName))
}
engine.Use(s.authMiddleware()) engine.Use(s.authMiddleware())
engine.Use(logging.GinRequestLogger())
handler := proxy.HandleMessages(pool, profile, func() *proxy.Sanitizer { handler := proxy.HandleMessages(pool, profile, func() *proxy.Sanitizer {
return s.sanitizer.Load() return s.sanitizer.Load()
}) }, tracker)
engine.POST("/v1/messages", handler) engine.POST("/v1/messages", handler)
engine.POST("/messages", handler) engine.POST("/messages", handler)
if metricsHandler != nil {
engine.GET("/metrics", gin.WrapH(metricsHandler))
}
engine.POST("/reload", s.handleReload()) engine.POST("/reload", s.handleReload())
engine.POST("/debug/refresh", handleDebugRefresh(pool)) engine.POST("/debug/refresh", handleDebugRefresh(pool))
engine.GET("/healthz", func(c *gin.Context) { engine.GET("/healthz", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"}) c.JSON(http.StatusOK, gin.H{"status": "ok"})
}) })
engine.NoRoute(func(c *gin.Context) { engine.NoRoute(func(c *gin.Context) {
log.Printf("unmatched route: %s %s", c.Request.Method, c.Request.URL.Path) log.Warn().Str("method", c.Request.Method).Str("path", c.Request.URL.Path).Msg("unmatched route")
c.JSON(http.StatusNotFound, gin.H{"error": "not found"}) c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
}) })
@@ -85,8 +96,7 @@ func (s *Server) handleReload() gin.HandlerFunc {
keys := makeKeySet(cfg.APIKeys) keys := makeKeySet(cfg.APIKeys)
s.apiKeys.Store(&keys) s.apiKeys.Store(&keys)
log.Printf("config reloaded: %d tool renames, %d system rules, %d body rules, %d api keys", log.Info().Int("tool_renames", len(cfg.Sanitize.Tools)).Int("system_rules", len(cfg.Sanitize.System)).Int("body_rules", len(cfg.Sanitize.Body)).Int("api_keys", len(cfg.APIKeys)).Msg("config reloaded")
len(cfg.Sanitize.Tools), len(cfg.Sanitize.System), len(cfg.Sanitize.Body), len(cfg.APIKeys))
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"status": "reloaded", "status": "reloaded",
@@ -128,10 +138,16 @@ func corsMiddleware() gin.HandlerFunc {
} }
} }
// authBypassPaths lists endpoints that do not require API key authentication.
var authBypassPaths = map[string]bool{
"/healthz": true,
"/reload": true,
"/metrics": true,
}
func (s *Server) authMiddleware() gin.HandlerFunc { func (s *Server) authMiddleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
path := c.Request.URL.Path if authBypassPaths[c.Request.URL.Path] {
if path == "/healthz" || path == "/reload" {
c.Next() c.Next()
return return
} }
+529
View File
@@ -0,0 +1,529 @@
package server
import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"sync/atomic"
"testing"
"github.com/gin-gonic/gin"
)
func init() {
gin.SetMode(gin.TestMode)
}
// --- makeKeySet ---
func TestMakeKeySet(t *testing.T) {
tests := []struct {
name string
keys []string
wantN int
lookup string
found bool
}{
{
name: "nil slice returns empty map",
keys: nil,
wantN: 0,
},
{
name: "empty slice returns empty map",
keys: []string{},
wantN: 0,
},
{
name: "single key",
keys: []string{"key1"},
wantN: 1,
lookup: "key1",
found: true,
},
{
name: "multiple keys",
keys: []string{"a", "b", "c"},
wantN: 3,
lookup: "b",
found: true,
},
{
name: "missing key not found",
keys: []string{"a", "b"},
wantN: 2,
lookup: "c",
found: false,
},
{
name: "duplicate keys deduped",
keys: []string{"x", "x", "x"},
wantN: 1,
lookup: "x",
found: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := makeKeySet(tt.keys)
if len(got) != tt.wantN {
t.Errorf("len(makeKeySet) = %d, want %d", len(got), tt.wantN)
}
if tt.lookup != "" {
_, ok := got[tt.lookup]
if ok != tt.found {
t.Errorf("keySet[%q] found=%v, want %v", tt.lookup, ok, tt.found)
}
}
})
}
}
// --- corsMiddleware ---
func TestCorsMiddleware_SetsHeaders(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
handler := corsMiddleware()
handler(c)
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
t.Errorf("Access-Control-Allow-Origin = %q, want %q", got, "*")
}
if got := w.Header().Get("Access-Control-Allow-Methods"); got != "GET, POST, PUT, DELETE, OPTIONS" {
t.Errorf("Access-Control-Allow-Methods = %q, want %q", got, "GET, POST, PUT, DELETE, OPTIONS")
}
allowHeaders := w.Header().Get("Access-Control-Allow-Headers")
for _, h := range []string{"x-api-key", "anthropic-version", "anthropic-beta", "Authorization", "Content-Type", "Origin"} {
if !containsSubstring(allowHeaders, h) {
t.Errorf("Access-Control-Allow-Headers %q missing %q", allowHeaders, h)
}
}
}
func TestCorsMiddleware_OptionsReturns204(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodOptions, "/v1/messages", nil)
handler := corsMiddleware()
handler(c)
if w.Code != http.StatusNoContent {
t.Errorf("OPTIONS status = %d, want %d", w.Code, http.StatusNoContent)
}
if !c.IsAborted() {
t.Error("expected context to be aborted on OPTIONS")
}
}
func TestCorsMiddleware_NonOptionsDoesNotAbort(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
handler := corsMiddleware()
handler(c)
if c.IsAborted() {
t.Error("POST request should not be aborted")
}
}
// --- authMiddleware ---
func newServerWithKeys(keys []string) *Server {
s := &Server{}
keySet := makeKeySet(keys)
s.apiKeys.Store(&keySet)
return s
}
func TestAuthMiddleware_BypassPaths(t *testing.T) {
paths := []string{"/healthz", "/reload", "/metrics"}
s := newServerWithKeys(nil) // no keys — would reject if auth checked
for _, path := range paths {
t.Run(path, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, path, nil)
handler := s.authMiddleware()
handler(c)
if c.IsAborted() {
t.Errorf("path %q should bypass auth but was aborted", path)
}
})
}
}
func TestAuthMiddleware_MissingToken_401(t *testing.T) {
s := newServerWithKeys([]string{"valid-key"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
handler := s.authMiddleware()
handler(c)
if w.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d", w.Code, http.StatusUnauthorized)
}
if !c.IsAborted() {
t.Error("expected aborted on missing token")
}
var body map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if body["error"] != "missing authentication" {
t.Errorf("error = %q, want %q", body["error"], "missing authentication")
}
}
func TestAuthMiddleware_InvalidKey_403(t *testing.T) {
s := newServerWithKeys([]string{"valid-key"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
c.Request.Header.Set("x-api-key", "wrong-key")
handler := s.authMiddleware()
handler(c)
if w.Code != http.StatusForbidden {
t.Errorf("status = %d, want %d", w.Code, http.StatusForbidden)
}
if !c.IsAborted() {
t.Error("expected aborted on invalid key")
}
var body map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if body["error"] != "invalid api key" {
t.Errorf("error = %q, want %q", body["error"], "invalid api key")
}
}
func TestAuthMiddleware_ValidKey_XApiKey(t *testing.T) {
s := newServerWithKeys([]string{"valid-key"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
c.Request.Header.Set("x-api-key", "valid-key")
handler := s.authMiddleware()
handler(c)
if c.IsAborted() {
t.Error("valid key should not abort")
}
if w.Code == http.StatusUnauthorized || w.Code == http.StatusForbidden {
t.Errorf("unexpected status %d for valid key", w.Code)
}
}
func TestAuthMiddleware_ValidKey_BearerAuth(t *testing.T) {
s := newServerWithKeys([]string{"my-token"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
c.Request.Header.Set("Authorization", "Bearer my-token")
handler := s.authMiddleware()
handler(c)
if c.IsAborted() {
t.Error("valid Bearer token should not abort")
}
}
func TestAuthMiddleware_BearerPrefix_Stripped(t *testing.T) {
// The token is "my-token", sent as "Bearer my-token". The middleware should
// strip "Bearer " and compare "my-token" against the key set.
s := newServerWithKeys([]string{"my-token"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
c.Request.Header.Set("Authorization", "Bearer my-token")
handler := s.authMiddleware()
handler(c)
if c.IsAborted() {
t.Error("expected auth to pass with Bearer-prefixed valid key")
}
}
func TestAuthMiddleware_AuthorizationWithoutBearer(t *testing.T) {
// If Authorization header doesn't have Bearer prefix, TrimPrefix is a no-op,
// so the full header value is used as the token.
s := newServerWithKeys([]string{"raw-token-value"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
c.Request.Header.Set("Authorization", "raw-token-value")
handler := s.authMiddleware()
handler(c)
if c.IsAborted() {
t.Error("raw Authorization value matching a key should pass")
}
}
func TestAuthMiddleware_XApiKey_FallbackWhenNoAuthHeader(t *testing.T) {
// If Authorization is empty, x-api-key is checked.
s := newServerWithKeys([]string{"fallback-key"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
c.Request.Header.Set("x-api-key", "fallback-key")
handler := s.authMiddleware()
handler(c)
if c.IsAborted() {
t.Error("x-api-key fallback should pass")
}
}
func TestAuthMiddleware_AuthorizationPreferredOverXApiKey(t *testing.T) {
// Both headers set; Authorization takes precedence.
s := newServerWithKeys([]string{"auth-key"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
c.Request.Header.Set("Authorization", "Bearer auth-key")
c.Request.Header.Set("x-api-key", "wrong-key")
handler := s.authMiddleware()
handler(c)
if c.IsAborted() {
t.Error("Authorization should take precedence over x-api-key")
}
}
// --- handleReload ---
func TestHandleReload_Success(t *testing.T) {
// Create a temp config file
tmpFile, err := os.CreateTemp("", "config-*.yaml")
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
}
defer os.Remove(tmpFile.Name())
configContent := `
port: 9999
api_keys:
- reloaded-key-1
- reloaded-key-2
sanitize:
tools:
- from: old_tool
to: new_tool
system:
- match: foo
replace: bar
body:
- match: baz
replace: qux
`
if _, err := tmpFile.WriteString(configContent); err != nil {
t.Fatalf("failed to write config: %v", err)
}
tmpFile.Close()
s := &Server{configPath: tmpFile.Name()}
// Initialize with empty values
emptyKeys := makeKeySet(nil)
s.apiKeys.Store(&emptyKeys)
emptySan := &atomic.Pointer[interface{}]{}
_ = emptySan // just to show we're aware
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/reload", nil)
handler := s.handleReload()
handler(c)
if w.Code != http.StatusOK {
t.Fatalf("status = %d, want %d; body: %s", w.Code, http.StatusOK, w.Body.String())
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if resp["status"] != "reloaded" {
t.Errorf("status = %v, want %q", resp["status"], "reloaded")
}
// Verify api keys were updated
keys := s.apiKeys.Load()
if _, ok := (*keys)["reloaded-key-1"]; !ok {
t.Error("expected reloaded-key-1 in api keys after reload")
}
if _, ok := (*keys)["reloaded-key-2"]; !ok {
t.Error("expected reloaded-key-2 in api keys after reload")
}
if len(*keys) != 2 {
t.Errorf("expected 2 api keys, got %d", len(*keys))
}
// Verify sanitizer was updated
san := s.sanitizer.Load()
if san == nil {
t.Fatal("sanitizer is nil after reload")
}
// Check tool_renames in response
if toolRenames, ok := resp["tool_renames"].(float64); !ok || int(toolRenames) != 1 {
t.Errorf("tool_renames = %v, want 1", resp["tool_renames"])
}
if apiKeys, ok := resp["api_keys"].(float64); !ok || int(apiKeys) != 2 {
t.Errorf("api_keys = %v, want 2", resp["api_keys"])
}
}
func TestHandleReload_InvalidConfig(t *testing.T) {
s := &Server{configPath: "/nonexistent/path/config.yaml"}
emptyKeys := makeKeySet(nil)
s.apiKeys.Store(&emptyKeys)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/reload", nil)
handler := s.handleReload()
handler(c)
if w.Code != http.StatusInternalServerError {
t.Errorf("status = %d, want %d", w.Code, http.StatusInternalServerError)
}
var resp map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if resp["error"] == "" {
t.Error("expected non-empty error message")
}
}
// --- Full route tests using httptest ---
func TestHealthzEndpoint(t *testing.T) {
engine := gin.New()
engine.Use(corsMiddleware())
s := newServerWithKeys(nil)
engine.Use(s.authMiddleware())
engine.GET("/healthz", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
w := httptest.NewRecorder()
engine.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
}
var body map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if body["status"] != "ok" {
t.Errorf("status = %q, want %q", body["status"], "ok")
}
}
func TestAuthMiddleware_FullRoute_Rejected(t *testing.T) {
engine := gin.New()
s := newServerWithKeys([]string{"correct-key"})
engine.Use(s.authMiddleware())
engine.POST("/v1/messages", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
// No auth header
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
w := httptest.NewRecorder()
engine.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d", w.Code, http.StatusUnauthorized)
}
}
func TestAuthMiddleware_FullRoute_Accepted(t *testing.T) {
engine := gin.New()
s := newServerWithKeys([]string{"correct-key"})
engine.Use(s.authMiddleware())
engine.POST("/v1/messages", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
req.Header.Set("x-api-key", "correct-key")
w := httptest.NewRecorder()
engine.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
}
}
func TestCorsMiddleware_FullRoute_OptionsRequest(t *testing.T) {
engine := gin.New()
engine.Use(corsMiddleware())
s := newServerWithKeys([]string{"key"})
engine.Use(s.authMiddleware())
engine.POST("/v1/messages", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
req := httptest.NewRequest(http.MethodOptions, "/v1/messages", nil)
w := httptest.NewRecorder()
engine.ServeHTTP(w, req)
if w.Code != http.StatusNoContent {
t.Errorf("status = %d, want %d", w.Code, http.StatusNoContent)
}
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
t.Errorf("ACAO = %q, want %q", got, "*")
}
}
// helper
func containsSubstring(s, sub string) bool {
return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsStr(s, sub))
}
func containsStr(s, sub string) bool {
for i := 0; i <= len(s)-len(sub); i++ {
if s[i:i+len(sub)] == sub {
return true
}
}
return false
}
+81
View File
@@ -0,0 +1,81 @@
package telemetry
import (
"context"
"encoding/json"
"time"
otellog "go.opentelemetry.io/otel/log"
sdklog "go.opentelemetry.io/otel/sdk/log"
)
// LogBridge implements io.Writer and forwards zerolog JSON lines to the
// OTel LoggerProvider. It is used as an extra writer in zerolog's MultiWriter
// so that logs go to both file and OTLP.
type LogBridge struct {
provider *sdklog.LoggerProvider
}
func (b *LogBridge) Write(p []byte) (n int, err error) {
var entry map[string]interface{}
if err := json.Unmarshal(p, &entry); err != nil {
return len(p), nil // skip malformed lines
}
logger := b.provider.Logger("zerolog")
var rec otellog.Record
rec.SetTimestamp(time.Now())
if msg, ok := entry["message"].(string); ok {
rec.SetBody(otellog.StringValue(msg))
}
if lvl, ok := entry["level"].(string); ok {
rec.SetSeverity(mapSeverity(lvl))
}
// Forward all fields as attributes
attrs := make([]otellog.KeyValue, 0, len(entry))
for k, v := range entry {
if k == "message" || k == "level" || k == "time" {
continue
}
switch val := v.(type) {
case string:
attrs = append(attrs, otellog.String(k, val))
case float64:
attrs = append(attrs, otellog.Float64(k, val))
case bool:
attrs = append(attrs, otellog.Bool(k, val))
default:
b, _ := json.Marshal(val)
attrs = append(attrs, otellog.String(k, string(b)))
}
}
rec.AddAttributes(attrs...)
logger.Emit(context.Background(), rec)
return len(p), nil
}
func mapSeverity(level string) otellog.Severity {
switch level {
case "trace":
return otellog.SeverityTrace
case "debug":
return otellog.SeverityDebug
case "info":
return otellog.SeverityInfo
case "warn", "warning":
return otellog.SeverityWarn
case "error":
return otellog.SeverityError
case "fatal":
return otellog.SeverityFatal
case "panic":
return otellog.SeverityFatal2
default:
return otellog.SeverityInfo
}
}
+178
View File
@@ -0,0 +1,178 @@
package telemetry
import (
"encoding/json"
"testing"
otellog "go.opentelemetry.io/otel/log"
sdklog "go.opentelemetry.io/otel/sdk/log"
)
func TestMapSeverity(t *testing.T) {
tests := []struct {
input string
want otellog.Severity
}{
{"trace", otellog.SeverityTrace},
{"debug", otellog.SeverityDebug},
{"info", otellog.SeverityInfo},
{"warn", otellog.SeverityWarn},
{"warning", otellog.SeverityWarn},
{"error", otellog.SeverityError},
{"fatal", otellog.SeverityFatal},
{"panic", otellog.SeverityFatal2},
{"unknown", otellog.SeverityInfo},
{"", otellog.SeverityInfo},
{"INFO", otellog.SeverityInfo}, // uppercase falls to default
{"PANIC", otellog.SeverityInfo}, // uppercase falls to default
{"gibberish", otellog.SeverityInfo},
}
for _, tc := range tests {
t.Run("level_"+tc.input, func(t *testing.T) {
got := mapSeverity(tc.input)
if got != tc.want {
t.Errorf("mapSeverity(%q) = %v, want %v", tc.input, got, tc.want)
}
})
}
}
func newTestBridge(t *testing.T) *LogBridge {
t.Helper()
provider := sdklog.NewLoggerProvider()
t.Cleanup(func() {
_ = provider.Shutdown(t.Context())
})
return &LogBridge{provider: provider}
}
func TestLogBridgeWrite(t *testing.T) {
tests := []struct {
name string
input interface{} // will be marshaled to JSON; use string for raw input
raw string // if non-empty, use this directly instead of marshaling input
}{
{
name: "valid_json_with_message_level_and_extras",
input: map[string]interface{}{
"message": "request handled",
"level": "info",
"method": "GET",
"status": float64(200),
},
},
{
name: "message_only_no_level",
input: map[string]interface{}{
"message": "hello world",
},
},
{
name: "level_only_no_message",
input: map[string]interface{}{
"level": "error",
},
},
{
name: "empty_json_object",
input: map[string]interface{}{},
},
{
name: "string_float64_bool_attributes",
input: map[string]interface{}{
"message": "test",
"level": "debug",
"str_val": "hello",
"num_val": float64(3.14),
"bool_val": true,
},
},
{
name: "complex_nested_object_attribute",
input: map[string]interface{}{
"message": "nested",
"level": "warn",
"nested": map[string]interface{}{"foo": "bar", "n": float64(1)},
},
},
{
name: "time_field_skipped_in_attributes",
input: map[string]interface{}{
"message": "with time",
"level": "info",
"time": "2025-01-01T00:00:00Z",
"extra": "kept",
},
},
{
name: "malformed_json",
raw: "this is not json at all",
},
{
name: "malformed_json_partial",
raw: `{"broken":`,
},
{
name: "array_attribute_marshaled_as_string",
input: map[string]interface{}{
"message": "arrays",
"tags": []interface{}{"a", "b"},
},
},
{
name: "null_value_attribute",
input: map[string]interface{}{
"message": "nulls",
"val": nil,
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
bridge := newTestBridge(t)
var p []byte
if tc.raw != "" {
p = []byte(tc.raw)
} else {
var err error
p, err = json.Marshal(tc.input)
if err != nil {
t.Fatalf("failed to marshal test input: %v", err)
}
}
n, err := bridge.Write(p)
if n != len(p) {
t.Errorf("Write() returned n=%d, want %d", n, len(p))
}
if err != nil {
t.Errorf("Write() returned err=%v, want nil", err)
}
})
}
}
func TestLogBridgeWriteAlwaysReturnsLenAndNil(t *testing.T) {
bridge := newTestBridge(t)
inputs := [][]byte{
[]byte(`{"message":"ok","level":"info"}`),
[]byte(`not json`),
[]byte(`{}`),
[]byte(``),
[]byte(`[]`),
}
for _, p := range inputs {
n, err := bridge.Write(p)
if n != len(p) {
t.Errorf("Write(%q) n=%d, want %d", string(p), n, len(p))
}
if err != nil {
t.Errorf("Write(%q) err=%v, want nil", string(p), err)
}
}
}
+85
View File
@@ -0,0 +1,85 @@
package telemetry
import (
"context"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
"github.com/fujin/anthropic-proxy/internal/ratelimit"
)
var (
RequestCounter metric.Int64Counter
RequestDuration metric.Float64Histogram
RequestBodySize metric.Int64Histogram
UpstreamErrors metric.Int64Counter
TokensInput metric.Int64Counter
TokensOutput metric.Int64Counter
CredentialCooldowns metric.Int64Counter
ActiveCredentials metric.Int64UpDownCounter
StreamRequests metric.Int64Counter
)
// InitMetrics creates all metric instruments from the given meter.
// If tracker is non-nil, registers observable gauges for per-window usage.
func InitMetrics(meter metric.Meter, tracker *ratelimit.Tracker) {
RequestCounter, _ = meter.Int64Counter("proxy.request.count",
metric.WithDescription("Total proxy requests"),
)
RequestDuration, _ = meter.Float64Histogram("proxy.request.duration_ms",
metric.WithDescription("Request latency in milliseconds"),
metric.WithUnit("ms"),
)
RequestBodySize, _ = meter.Int64Histogram("proxy.request.body_size_bytes",
metric.WithDescription("Request body size in bytes"),
metric.WithUnit("By"),
)
UpstreamErrors, _ = meter.Int64Counter("proxy.upstream.errors",
metric.WithDescription("Upstream error count"),
)
TokensInput, _ = meter.Int64Counter("proxy.tokens.input",
metric.WithDescription("Input tokens consumed"),
)
TokensOutput, _ = meter.Int64Counter("proxy.tokens.output",
metric.WithDescription("Output tokens consumed"),
)
CredentialCooldowns, _ = meter.Int64Counter("proxy.credential.cooldowns",
metric.WithDescription("Credential cooldown activations"),
)
ActiveCredentials, _ = meter.Int64UpDownCounter("proxy.credential.active",
metric.WithDescription("Currently active (non-cooldown) credentials"),
)
StreamRequests, _ = meter.Int64Counter("proxy.stream.requests",
metric.WithDescription("Streaming request count"),
)
if tracker == nil {
return
}
attr5h := attribute.String("window", "5h")
attr7d := attribute.String("window", "7d")
attrSonnet := attribute.String("window", "7d_sonnet")
meter.Float64ObservableGauge("proxy.usage.utilization",
metric.WithDescription("Current utilization % from API"),
metric.WithFloat64Callback(func(_ context.Context, o metric.Float64Observer) error {
o.Observe(tracker.FiveHour().Utilization, metric.WithAttributes(attr5h))
o.Observe(tracker.SevenDay().Utilization, metric.WithAttributes(attr7d))
o.Observe(tracker.Sonnet().Utilization, metric.WithAttributes(attrSonnet))
return nil
}),
)
meter.Int64ObservableGauge("proxy.usage.resets_at",
metric.WithDescription("Unix seconds when window resets"),
metric.WithInt64Callback(func(_ context.Context, o metric.Int64Observer) error {
o.Observe(tracker.FiveHour().ResetsAt.Unix(), metric.WithAttributes(attr5h))
o.Observe(tracker.SevenDay().ResetsAt.Unix(), metric.WithAttributes(attr7d))
o.Observe(tracker.Sonnet().ResetsAt.Unix(), metric.WithAttributes(attrSonnet))
return nil
}),
)
}
+112
View File
@@ -0,0 +1,112 @@
package telemetry
import (
"context"
"io"
"net/http"
"github.com/fujin/anthropic-proxy/internal/config"
"github.com/fujin/anthropic-proxy/internal/ratelimit"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc"
"go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
promexporter "go.opentelemetry.io/otel/exporters/prometheus"
otellog "go.opentelemetry.io/otel/log/global"
"go.opentelemetry.io/otel/sdk/log"
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/resource"
"go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
)
func Setup(ctx context.Context, cfg config.TelemetryConfig, tracker *ratelimit.Tracker) (shutdown func(context.Context) error, logWriter io.Writer, metricsHandler http.Handler, err error) {
res, err := resource.New(ctx,
resource.WithAttributes(
semconv.ServiceName(cfg.ServiceName),
),
)
if err != nil {
return nil, nil, nil, err
}
var readers []sdkmetric.Option
readers = append(readers, sdkmetric.WithResource(res))
var promHandler http.Handler
if cfg.Embedded.Enabled {
exporter, pErr := promexporter.New()
if pErr != nil {
return nil, nil, nil, pErr
}
readers = append(readers, sdkmetric.WithReader(exporter))
promHandler = promhttp.Handler()
}
if !cfg.Export.Enabled() {
mp := sdkmetric.NewMeterProvider(readers...)
otel.SetMeterProvider(mp)
InitMetrics(mp.Meter(cfg.ServiceName), tracker)
return func(ctx context.Context) error { return mp.Shutdown(ctx) }, nil, promHandler, nil
}
traceOpts := []otlptracegrpc.Option{otlptracegrpc.WithEndpoint(cfg.Export.Endpoint)}
metricOpts := []otlpmetricgrpc.Option{
otlpmetricgrpc.WithEndpoint(cfg.Export.Endpoint),
otlpmetricgrpc.WithTemporalitySelector(sdkmetric.CumulativeTemporalitySelector),
}
logOpts := []otlploggrpc.Option{otlploggrpc.WithEndpoint(cfg.Export.Endpoint)}
if cfg.Export.Insecure {
traceOpts = append(traceOpts, otlptracegrpc.WithInsecure())
metricOpts = append(metricOpts, otlpmetricgrpc.WithInsecure())
logOpts = append(logOpts, otlploggrpc.WithInsecure())
}
traceExp, err := otlptracegrpc.New(ctx, traceOpts...)
if err != nil {
return nil, nil, nil, err
}
tp := trace.NewTracerProvider(
trace.WithBatcher(traceExp),
trace.WithResource(res),
)
otel.SetTracerProvider(tp)
metricExp, err := otlpmetricgrpc.New(ctx, metricOpts...)
if err != nil {
return nil, nil, nil, err
}
readers = append(readers, sdkmetric.WithReader(sdkmetric.NewPeriodicReader(metricExp)))
mp := sdkmetric.NewMeterProvider(readers...)
otel.SetMeterProvider(mp)
InitMetrics(mp.Meter(cfg.ServiceName), tracker)
logExp, err := otlploggrpc.New(ctx, logOpts...)
if err != nil {
return nil, nil, nil, err
}
lp := log.NewLoggerProvider(
log.WithProcessor(log.NewBatchProcessor(logExp)),
log.WithResource(res),
)
otellog.SetLoggerProvider(lp)
bridge := &LogBridge{provider: lp}
shutdownFn := func(ctx context.Context) error {
var firstErr error
if e := tp.Shutdown(ctx); e != nil && firstErr == nil {
firstErr = e
}
if e := mp.Shutdown(ctx); e != nil && firstErr == nil {
firstErr = e
}
if e := lp.Shutdown(ctx); e != nil && firstErr == nil {
firstErr = e
}
return firstErr
}
return shutdownFn, bridge, promHandler, nil
}
@@ -1,29 +1,47 @@
package proxy // Package transport provides a shared uTLS HTTP/2 round-tripper with Chrome
// TLS fingerprinting and per-host connection pooling. Used by both the upstream
// proxy client and the OAuth token refresh client.
package transport
import ( import (
"log"
"net" "net"
"net/http" "net/http"
"sync" "sync"
"time"
tls "github.com/refraction-networking/utls" tls "github.com/refraction-networking/utls"
"golang.org/x/net/http2" "golang.org/x/net/http2"
) )
type utlsRoundTripper struct { // UTLS implements http.RoundTripper using uTLS (Chrome fingerprint) over HTTP/2.
// It maintains a per-host connection pool with coordination for concurrent
// requests to the same host.
type UTLS struct {
mu sync.Mutex mu sync.Mutex
connections map[string]*http2.ClientConn connections map[string]*http2.ClientConn
pending map[string]*sync.Cond pending map[string]*sync.Cond
dialTimeout time.Duration
} }
func newUtlsRoundTripper() *utlsRoundTripper { // NewUTLS creates a uTLS HTTP/2 round-tripper with a 10-second dial timeout.
return &utlsRoundTripper{ func NewUTLS() *UTLS {
return &UTLS{
connections: make(map[string]*http2.ClientConn), connections: make(map[string]*http2.ClientConn),
pending: make(map[string]*sync.Cond), pending: make(map[string]*sync.Cond),
dialTimeout: 10 * time.Second,
} }
} }
func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.ClientConn, error) { // NewHTTPClient returns an http.Client using uTLS transport with the given
// request timeout. Pass 0 for no timeout (streaming).
func NewHTTPClient(timeout time.Duration) *http.Client {
return &http.Client{
Timeout: timeout,
Transport: NewUTLS(),
}
}
func (t *UTLS) getOrCreateConnection(host, addr string) (*http2.ClientConn, error) {
t.mu.Lock() t.mu.Lock()
if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() { if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
@@ -59,8 +77,8 @@ func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.Clie
return h2Conn, nil return h2Conn, nil
} }
func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) { func (t *UTLS) createConnection(host, addr string) (*http2.ClientConn, error) {
conn, err := net.Dial("tcp", addr) conn, err := net.DialTimeout("tcp", addr, t.dialTimeout)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -83,14 +101,14 @@ func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientCon
return h2Conn, nil return h2Conn, nil
} }
func (t *utlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { // RoundTrip implements http.RoundTripper with uTLS Chrome fingerprinting.
func (t *UTLS) RoundTrip(req *http.Request) (*http.Response, error) {
hostname := req.URL.Hostname() hostname := req.URL.Hostname()
port := req.URL.Port() port := req.URL.Port()
if port == "" { if port == "" {
port = "443" port = "443"
} }
addr := net.JoinHostPort(hostname, port) addr := net.JoinHostPort(hostname, port)
log.Printf("utls: RoundTrip to %s (Chrome TLS fingerprint, HTTP/2)", addr)
h2Conn, err := t.getOrCreateConnection(hostname, addr) h2Conn, err := t.getOrCreateConnection(hostname, addr)
if err != nil { if err != nil {
+78
View File
@@ -0,0 +1,78 @@
package transport
import (
"net/http"
"testing"
"time"
)
func TestNewUTLS(t *testing.T) {
tr := NewUTLS()
if tr == nil {
t.Fatal("NewUTLS returned nil")
}
if tr.connections == nil {
t.Error("connections map is nil")
}
if tr.pending == nil {
t.Error("pending map is nil")
}
if tr.dialTimeout != 10*time.Second {
t.Errorf("dialTimeout = %v, want 10s", tr.dialTimeout)
}
}
func TestNewHTTPClient(t *testing.T) {
tests := []struct {
name string
timeout time.Duration
}{
{"zero timeout (streaming)", 0},
{"15s timeout", 15 * time.Second},
{"30s timeout", 30 * time.Second},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := NewHTTPClient(tt.timeout)
if c == nil {
t.Fatal("NewHTTPClient returned nil")
}
if c.Timeout != tt.timeout {
t.Errorf("Timeout = %v, want %v", c.Timeout, tt.timeout)
}
if c.Transport == nil {
t.Error("Transport is nil")
}
if _, ok := c.Transport.(*UTLS); !ok {
t.Errorf("Transport type = %T, want *UTLS", c.Transport)
}
})
}
}
func TestUTLS_ImplementsRoundTripper(t *testing.T) {
var _ http.RoundTripper = (*UTLS)(nil)
}
func TestUTLS_RoundTrip_InvalidHost(t *testing.T) {
tr := NewUTLS()
// Use a non-routable address to test dial timeout behavior
req, err := http.NewRequest("GET", "https://192.0.2.1:443/test", nil)
if err != nil {
t.Fatalf("NewRequest: %v", err)
}
_, err = tr.RoundTrip(req)
if err == nil {
t.Error("expected error for non-routable address, got nil")
}
}
func TestUTLS_ConnectionEviction(t *testing.T) {
tr := NewUTLS()
// Verify connections map starts empty
tr.mu.Lock()
if len(tr.connections) != 0 {
t.Errorf("initial connections = %d, want 0", len(tr.connections))
}
tr.mu.Unlock()
}
+8
View File
@@ -0,0 +1,8 @@
// Package version provides the fallback Claude Code client version used when
// no sniffed profile is available. This constant is shared between the upstream
// proxy client and the rate limit usage poller.
package version
// ClaudeCodeFallback is the Claude Code CLI version string used as a fallback
// when no real version is obtained from sniffing.
const ClaudeCodeFallback = "2.1.92"
+113 -16
View File
@@ -3,7 +3,7 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"log" "io"
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
@@ -12,62 +12,159 @@ import (
"github.com/fujin/anthropic-proxy/internal/auth" "github.com/fujin/anthropic-proxy/internal/auth"
"github.com/fujin/anthropic-proxy/internal/config" "github.com/fujin/anthropic-proxy/internal/config"
"github.com/fujin/anthropic-proxy/internal/embedded"
"github.com/fujin/anthropic-proxy/internal/logging"
"github.com/fujin/anthropic-proxy/internal/proxy" "github.com/fujin/anthropic-proxy/internal/proxy"
"github.com/fujin/anthropic-proxy/internal/ratelimit"
"github.com/fujin/anthropic-proxy/internal/server" "github.com/fujin/anthropic-proxy/internal/server"
"github.com/fujin/anthropic-proxy/internal/telemetry"
"github.com/rs/zerolog/log"
) )
func run() error { func initCredential() (*auth.Credential, error) {
log.SetFlags(log.LstdFlags) creds, err := auth.LoadDefaultCredentials()
if err != nil {
return nil, fmt.Errorf("load credentials: %w", err)
}
var cred *auth.Credential
if len(creds) > 0 {
cred = creds[0]
// If token is expired, try refresh first
if !cred.ExpiresAt.IsZero() && time.Now().After(cred.ExpiresAt) {
log.Info().Msg("token expired, attempting refresh")
refreshCtx, refreshCancel := context.WithTimeout(context.Background(), 15*time.Second)
refreshErr := auth.RefreshToken(refreshCtx, cred)
refreshCancel()
if refreshErr != nil {
log.Warn().Err(refreshErr).Msg("refresh failed, initiating login")
cred = nil // fall through to login
} else {
log.Info().Msg("token refreshed")
}
}
}
if cred == nil {
fi, statErr := os.Stdin.Stat()
if statErr == nil && (fi.Mode()&os.ModeCharDevice) == 0 {
return nil, fmt.Errorf("no valid credentials found; run the proxy interactively for initial login")
}
log.Info().Msg("no credentials found, starting OAuth login")
cred, err = auth.Login(context.Background())
if err != nil {
return nil, fmt.Errorf("login failed: %w", err)
}
}
log.Info().Str("credential", cred.Email).Msg("credential loaded")
return cred, nil
}
func initEmbedded(cfg *config.Config) (cleanup func(), err error) {
if !cfg.Telemetry.Embedded.Enabled {
return func() {}, nil
}
var cleanups []func()
vm := embedded.NewVM(cfg.Telemetry.Embedded, cfg.Port)
if err := vm.Start(); err != nil {
log.Error().Err(err).Msg("failed to start victoria-metrics")
} else {
cleanups = append(cleanups, vm.Stop)
}
perses := embedded.NewPerses(cfg.Telemetry.Embedded, cfg.Port)
if err := perses.Start(); err != nil {
log.Error().Err(err).Msg("failed to start perses")
} else {
cleanups = append(cleanups, perses.Stop)
}
return func() {
for i := len(cleanups) - 1; i >= 0; i-- {
cleanups[i]()
}
}, nil
}
func run() error {
cfg, err := config.Load("config.yaml") cfg, err := config.Load("config.yaml")
if err != nil { if err != nil {
return fmt.Errorf("load config: %w", err) return fmt.Errorf("load config: %w", err)
} }
creds, err := config.LoadCredentials(cfg) // Create usage tracker (started later once credential is loaded)
var credForTracker *auth.Credential
tracker := ratelimit.NewTracker(func() string {
if credForTracker == nil {
return ""
}
return credForTracker.Token()
})
// Initialize telemetry (metrics always active; OTLP export when endpoint set)
telemetryShutdown, logBridge, metricsHandler, err := telemetry.Setup(context.Background(), cfg.Telemetry, tracker)
if err != nil { if err != nil {
return fmt.Errorf("load credentials: %w", err) return fmt.Errorf("telemetry setup: %w", err)
}
defer telemetryShutdown(context.Background())
var extraWriters []io.Writer
if logBridge != nil {
extraWriters = append(extraWriters, logBridge)
} }
if len(creds) == 0 { logging.Setup(cfg.Logging, extraWriters...)
return fmt.Errorf("no credentials found")
cred, err := initCredential()
if err != nil {
return err
} }
log.Printf("loaded %d credentials", len(creds)) credForTracker = cred
pool := auth.NewPool(creds) pool := auth.NewPool([]*auth.Credential{cred})
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
pool.RefreshExpiring(context.Background()) pool.RefreshExpiring(context.Background())
auth.StartBackgroundRefresh(ctx, pool) auth.StartBackgroundRefresh(ctx, pool)
tracker.Start(ctx)
var profile *proxy.SniffedProfile var profile *proxy.SniffedProfile
if cfg.ClaudeBinary != "" { if cfg.ClaudeBinary != "" {
log.Printf("sniffing claude-code at %s...", cfg.ClaudeBinary) log.Info().Str("binary", cfg.ClaudeBinary).Msg("sniffing claude-code")
profile, err = proxy.SniffClaudeCode(cfg.ClaudeBinary) profile, err = proxy.SniffClaudeCode(cfg.ClaudeBinary)
if err != nil { if err != nil {
log.Printf("warning: sniff failed, using defaults: %v", err) log.Warn().Err(err).Msg("sniff failed, using defaults")
} }
} }
log.Printf("starting server on port %d", cfg.Port) embeddedCleanup, err := initEmbedded(cfg)
srv := server.New(cfg, pool, profile) if err != nil {
return err
}
defer embeddedCleanup()
log.Info().Int("port", cfg.Port).Msg("starting server")
srv := server.New(cfg, pool, profile, tracker, metricsHandler)
quit := make(chan os.Signal, 1) quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
go func() { go func() {
<-quit <-quit
log.Printf("shutting down...") log.Info().Msg("shutting down")
cancel() cancel()
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second) shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer shutdownCancel() defer shutdownCancel()
if err := srv.Shutdown(shutdownCtx); err != nil { if err := srv.Shutdown(shutdownCtx); err != nil {
log.Printf("shutdown error: %v", err) log.Error().Err(err).Msg("shutdown error")
} }
}() }()
@@ -80,7 +177,7 @@ func run() error {
func main() { func main() {
if err := run(); err != nil { if err := run(); err != nil {
log.Printf("error: %v", err) log.Error().Err(err).Msg("fatal error")
os.Exit(1) os.Exit(1)
} }
} }
+20
View File
@@ -0,0 +1,20 @@
{
buildGoModule,
lib,
pkgs,
...
}:
buildGoModule rec {
pname = "anthropic-proxy";
version = "0.0.5";
src = ./.;
vendorHash = "sha256-yXINNC+NEw+HbOQ5aBgSE5dYTWp+zEZ230rzXfwOoDY=";
meta = with lib; {
description = "Reverse proxy that lets OpenCode (and similar tools) use a Claude subscription instead of an API key.";
homepage = "https://gitea.susano-homelab.duckdns.org/fujin/anthropic-proxy";
};
}