Compare commits
41 Commits
17cde479c3
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 86db3ca091 | |||
| 0df28e9dd8 | |||
| 9150f466e5 | |||
| d3fbfe8b42 | |||
| a6c9a16833 | |||
| 34927d3a00 | |||
| ee9c53791a | |||
| 859640d814 | |||
| be4113e7ef | |||
| 501e40c53d | |||
| 1bc704a7b2 | |||
| bc6ad70386 | |||
| b07d999d86 | |||
| 27b647e9b4 | |||
| 273213cbed | |||
| b864092dad | |||
| 0ab1896eef | |||
| eda66ff7d4 | |||
| 744abc1d24 | |||
| e8af26d626 | |||
| fac9578975 | |||
| 76aeeb6be1 | |||
| 9cc052c162 | |||
| 20049881ad | |||
| 3435f5f4c5 | |||
| 807e8ba133 | |||
| da59d8f83b | |||
| 4e22c463cf | |||
| 76bf651742 | |||
| 3d1eb7bd4b | |||
| bfcbe0b37d | |||
| a7b583839d | |||
| c5f6962104 | |||
| 5ec0004e4c | |||
| bf68a0fbeb | |||
| e3c4854be0 | |||
| 8b7d9bfff9 | |||
| 65e843f57a | |||
| 9858530ff6 | |||
| 21176949a6 | |||
| 945a865bbe |
@@ -4,3 +4,5 @@
|
||||
anthropic-proxy
|
||||
result
|
||||
config.yaml
|
||||
|
||||
vendor/**
|
||||
|
||||
@@ -1,57 +1,62 @@
|
||||
# anthropic-proxy
|
||||
|
||||
Reverse proxy that lets OpenCode (and similar tools) use a Claude subscription instead of an API key.
|
||||
Reverse proxy for the Anthropic Messages API that authenticates with a Claude subscription (OAuth) instead of an API key. Lets you use tools like [OpenCode](https://github.com/opencode-ai/opencode) through your existing Claude Pro/Team plan.
|
||||
|
||||
## Prerequisites
|
||||
## How it works
|
||||
|
||||
- Go 1.26+
|
||||
- **Claude Code CLI** — installed and logged in (`claude auth login`). The proxy reads the OAuth token from `~/.claude/.credentials.json`.
|
||||
Clients send standard Anthropic API requests to the proxy. The proxy authenticates upstream using OAuth credentials from your Claude subscription, forwards the request, and streams the response back. Requests are optionally sanitized (tool name remapping, string replacement) before forwarding and de-sanitized on return.
|
||||
|
||||
Optional: [Nix](https://nixos.org/) flake for dev shell (`nix develop`).
|
||||
## Features
|
||||
|
||||
## Setup
|
||||
- **OAuth credential management** — reuses `~/.claude/.credentials.json`, auto-refreshes tokens
|
||||
- **Request sanitization** — rename tools, replace strings in system prompts and body (configurable, hot-reloadable)
|
||||
- **Rate limit tracking** — polls Anthropic usage API and reads response headers to track 5h/7d utilization windows
|
||||
- **OpenTelemetry metrics** — request counts, latency, token usage, errors (optional OTLP export)
|
||||
- **Structured logging** — zerolog with file rotation via lumberjack
|
||||
|
||||
## Quick start
|
||||
|
||||
```
|
||||
cp config.example.yaml config.yaml
|
||||
```
|
||||
# edit config.yaml — set api_keys and optionally claude_binary
|
||||
|
||||
Edit `config.yaml`:
|
||||
- `api_keys` — key(s) your clients use to authenticate with the proxy
|
||||
- `claude_credentials` — path to your Claude credentials file
|
||||
- `claude_binary` — path to `claude` binary (used on startup to capture request fingerprint)
|
||||
|
||||
## Build and run
|
||||
|
||||
```
|
||||
go build -o anthropic-proxy .
|
||||
./anthropic-proxy
|
||||
```
|
||||
|
||||
## Usage with OpenCode
|
||||
On first run, if no credentials exist at `~/.claude/.credentials.json`, an OAuth login flow starts in your browser. If running headlessly, the authorization URL is printed to stdout. If you've already logged in with Claude Code CLI, the proxy reuses those credentials.
|
||||
|
||||
### Nix
|
||||
|
||||
```
|
||||
nix develop # dev shell with Go
|
||||
nix build # build the binary
|
||||
```
|
||||
|
||||
## Client configuration
|
||||
|
||||
Point any Anthropic-compatible client at the proxy:
|
||||
|
||||
```
|
||||
export ANTHROPIC_API_KEY=your-proxy-api-key
|
||||
export ANTHROPIC_BASE_URL=http://localhost:8082
|
||||
opencode
|
||||
```
|
||||
|
||||
## Endpoints
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| POST | `/v1/messages` | Anthropic messages API (proxied) |
|
||||
| POST | `/v1/messages` | Anthropic Messages API (proxied) |
|
||||
| POST | `/messages` | Same, without `/v1` prefix |
|
||||
| GET | `/healthz` | Health check |
|
||||
| POST | `/reload` | Hot-reload `config.yaml` |
|
||||
| POST | `/reload` | Hot-reload config (sanitize rules + API keys) |
|
||||
|
||||
## Request sanitization
|
||||
## Configuration
|
||||
|
||||
The `sanitize` section in config renames tool names and replaces strings in system prompts before forwarding to Anthropic. Responses are de-sanitized before returning to the client.
|
||||
See [`config.example.yaml`](config.example.yaml) for all options. Key sections:
|
||||
|
||||
See `config.example.yaml` for the default rules.
|
||||
|
||||
Reload after editing config:
|
||||
|
||||
```
|
||||
curl -X POST localhost:8082/reload
|
||||
```
|
||||
- **`api_keys`** — keys clients use to authenticate with the proxy
|
||||
- **`sanitize`** — tool renames, system prompt replacements, body replacements
|
||||
- **`telemetry`** — OTLP endpoint, service name, auth headers
|
||||
- **`logging`** — level, file path, rotation settings
|
||||
- **`claude_binary`** — path to `claude` CLI for request fingerprinting (optional)
|
||||
|
||||
+42
-2
@@ -1,7 +1,30 @@
|
||||
port: 8082
|
||||
|
||||
# telemetry:
|
||||
# service_name: "anthropic-proxy"
|
||||
# export:
|
||||
# endpoint: "localhost:4317" # OTLP gRPC endpoint (omit to disable export)
|
||||
# insecure: true # disable TLS for local dev
|
||||
# headers: # optional auth headers (e.g. Grafana Cloud)
|
||||
# Authorization: "Basic ..."
|
||||
# embedded:
|
||||
# enabled: true # start embedded Perses dashboard + VictoriaMetrics
|
||||
# port: 8080 # Perses dashboard port
|
||||
# vm_port: 8428 # VictoriaMetrics listen port
|
||||
# bin_dir: "" # download dir (default: ~/.cache/anthropic-proxy/bin)
|
||||
# perses_binary: "" # custom path to perses binary (default: auto-download)
|
||||
# vm_binary: "" # custom path to victoria-metrics binary (default: auto-download)
|
||||
|
||||
logging:
|
||||
level: debug
|
||||
file: /home/fujin/.local/log/anthropic-proxy.log
|
||||
max_size_mb: 100
|
||||
max_backups: 5
|
||||
max_age_days: 30
|
||||
compress: true
|
||||
|
||||
api_keys:
|
||||
- "your-proxy-api-key"
|
||||
claude_credentials: "~/.claude/.credentials.json"
|
||||
claude_binary: "claude"
|
||||
|
||||
sanitize:
|
||||
@@ -27,5 +50,22 @@ sanitize:
|
||||
system:
|
||||
- match: "Workspace root folder"
|
||||
replace: "Working directory"
|
||||
- match: "anomalyco/opencode"
|
||||
body:
|
||||
- match: "anthropics/claude-code"
|
||||
replace: "anthropics/claude-code"
|
||||
- match: "anthropic"
|
||||
replace: "anthropic"
|
||||
- match: "system-directive"
|
||||
replace: "system-directive"
|
||||
- match: "claude-code"
|
||||
replace: "claude-code"
|
||||
- match: "claude-agent"
|
||||
replace: "claude-agent"
|
||||
- match: "system_initiator"
|
||||
replace: "system_initiator"
|
||||
- match: "call_agent"
|
||||
replace: "call_agent"
|
||||
- match: "claude.ai"
|
||||
replace: "claude.ai"
|
||||
- match: "agent"
|
||||
replace: "agent"
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
# Metrics Reference
|
||||
|
||||
All metrics are emitted via OpenTelemetry (OTLP gRPC) with cumulative temporality. OTel metric names use dots; Prometheus converts them to underscores and appends `_total` for counters.
|
||||
|
||||
## Counters
|
||||
|
||||
| OTel Name | Prometheus Name | Attributes | Description |
|
||||
|---|---|---|---|
|
||||
| `proxy.request.count` | `proxy_request_count_total` | `model`, `stream`, `status_code` | Total proxied requests |
|
||||
| `proxy.tokens.input` | `proxy_tokens_input_total` | `model`, `credential` | Input tokens consumed |
|
||||
| `proxy.tokens.output` | `proxy_tokens_output_total` | `model`, `credential` | Output tokens consumed |
|
||||
| `proxy.upstream.errors` | `proxy_upstream_errors_total` | `error_type`, `credential`, `status_code` | Upstream errors (connection failures, 4xx/5xx) |
|
||||
| `proxy.credential.cooldowns` | `proxy_credential_cooldowns_total` | `status_code` | Credential cooldown activations (rate limited) |
|
||||
| `proxy.stream.requests` | `proxy_stream_requests_total` | `model` | Streaming request count |
|
||||
|
||||
## Histograms
|
||||
|
||||
| OTel Name | Prometheus Name | Unit | Attributes | Description |
|
||||
|---|---|---|---|---|
|
||||
| `proxy.request.duration_ms` | `proxy_request_duration_ms_milliseconds` | ms | `model`, `stream`, `status_code` | Request latency |
|
||||
| `proxy.request.body_size_bytes` | `proxy_request_body_size_bytes` | bytes | `model`, `stream` | Request body size |
|
||||
|
||||
## Gauges
|
||||
|
||||
| OTel Name | Prometheus Name | Attributes | Description |
|
||||
|---|---|---|---|
|
||||
| `proxy.usage.utilization` | `proxy_usage_utilization` | `window` | Current utilization % from Anthropic API (0-100) |
|
||||
| `proxy.usage.resets_at` | `proxy_usage_resets_at` | `window` | Unix timestamp when the rate limit window resets |
|
||||
| `proxy.credential.active` | `proxy_credential_active` | — | Currently active (non-cooldown) credentials |
|
||||
|
||||
### Window attribute values
|
||||
|
||||
- `5h` — 5-hour rolling window
|
||||
- `7d` — 7-day rolling window
|
||||
- `7d_sonnet` — 7-day Sonnet-specific window
|
||||
|
||||
## Structured Logs (Loki)
|
||||
|
||||
Each completed request emits a structured log line via OTel LogBridge. Fields are stored as Loki stream labels/structured metadata.
|
||||
|
||||
| Field | Type | Description |
|
||||
|---|---|---|
|
||||
| `input_tokens` | int | Input tokens for this request |
|
||||
| `output_tokens` | int | Output tokens for this request |
|
||||
| `model` | string | Model used |
|
||||
| `latency_ms` | float | Request latency in milliseconds |
|
||||
| `status` | int | HTTP status code (non-stream only) |
|
||||
| `stream` | bool | Whether this was a streaming request |
|
||||
|
||||
Log messages: `"request completed"` (non-stream), `"stream completed"` (stream).
|
||||
|
||||
### Per-window token counting via Loki
|
||||
|
||||
Window token totals are computed in Grafana using LogQL, not tracked in-memory. This approach survives process restarts and uses exact Anthropic window boundaries.
|
||||
|
||||
Grafana variables derive the window age from Prometheus:
|
||||
|
||||
```
|
||||
window_age_5h = time() - proxy_usage_resets_at{window="5h"} + 18000
|
||||
window_age_7d = time() - proxy_usage_resets_at{window="7d"} + 604800
|
||||
```
|
||||
|
||||
LogQL queries sum token values from individual log events within the window:
|
||||
|
||||
```logql
|
||||
sum(sum_over_time(
|
||||
{service_name="anthropic-proxy"} |= "completed"
|
||||
| unwrap output_tokens
|
||||
| __error__=""
|
||||
[${window_age_5h}s]
|
||||
))
|
||||
```
|
||||
|
||||
## Annotations (Loki)
|
||||
|
||||
429 rate limit events are surfaced as Grafana annotations:
|
||||
|
||||
```logql
|
||||
{service_name="anthropic-proxy"} |= "upstream error" | json | status = "429"
|
||||
```
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,450 @@
|
||||
{
|
||||
"kind": "Dashboard",
|
||||
"metadata": {
|
||||
"name": "proxy",
|
||||
"createdAt": "2026-04-14T19:47:48.013238204Z",
|
||||
"updatedAt": "2026-04-14T19:49:30.874125459Z",
|
||||
"version": 1,
|
||||
"project": "anthropic-proxy"
|
||||
},
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "Anthropic Proxy"
|
||||
},
|
||||
"datasources": {
|
||||
"vm": {
|
||||
"default": true,
|
||||
"plugin": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"spec": {
|
||||
"directUrl": "http://localhost:9428"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"panels": {
|
||||
"latency": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "Latency"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "TimeSeriesChart",
|
||||
"spec": {
|
||||
"legend": {
|
||||
"position": "bottom"
|
||||
},
|
||||
"yAxis": {
|
||||
"format": {
|
||||
"unit": "milliseconds"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "histogram_quantile(0.50, rate(proxy_request_duration_ms_milliseconds_bucket[5m]))",
|
||||
"seriesNameFormat": "p50"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "histogram_quantile(0.95, rate(proxy_request_duration_ms_milliseconds_bucket[5m]))",
|
||||
"seriesNameFormat": "p95"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "histogram_quantile(0.99, rate(proxy_request_duration_ms_milliseconds_bucket[5m]))",
|
||||
"seriesNameFormat": "p99"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"request_rate": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "Request Rate"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "TimeSeriesChart",
|
||||
"spec": {
|
||||
"legend": {
|
||||
"position": "bottom"
|
||||
}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "rate(proxy_request_count_total[5m])",
|
||||
"seriesNameFormat": "req/s"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"token_rate": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "Token Rate"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "TimeSeriesChart",
|
||||
"spec": {
|
||||
"legend": {
|
||||
"position": "bottom"
|
||||
}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "rate(proxy_tokens_input_total[5m]) * 60",
|
||||
"seriesNameFormat": "input/min"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "rate(proxy_tokens_output_total[5m]) * 60",
|
||||
"seriesNameFormat": "output/min"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"tokens_5h": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "5h Tokens"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "StatChart",
|
||||
"spec": {
|
||||
"calculation": "last",
|
||||
"format": {
|
||||
"unit": "decimal"
|
||||
},
|
||||
"sparkline": {}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "increase(proxy_tokens_output_total[3h])"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"tokens_7d": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "7d Tokens"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "StatChart",
|
||||
"spec": {
|
||||
"calculation": "last",
|
||||
"format": {
|
||||
"unit": "decimal"
|
||||
},
|
||||
"sparkline": {}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "increase(proxy_tokens_output_total[9h])"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"util_5h": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "5h Utilization"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "GaugeChart",
|
||||
"spec": {
|
||||
"calculation": "last",
|
||||
"format": {
|
||||
"unit": "percent"
|
||||
},
|
||||
"thresholds": {
|
||||
"steps": [
|
||||
{
|
||||
"color": "green",
|
||||
"value": 0
|
||||
},
|
||||
{
|
||||
"color": "orange",
|
||||
"value": 70
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 90
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "proxy_usage_utilization{window=\"5h\"}"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"util_7d": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "7d Utilization"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "GaugeChart",
|
||||
"spec": {
|
||||
"calculation": "last",
|
||||
"format": {
|
||||
"unit": "percent"
|
||||
},
|
||||
"thresholds": {
|
||||
"steps": [
|
||||
{
|
||||
"color": "green",
|
||||
"value": 0
|
||||
},
|
||||
{
|
||||
"color": "orange",
|
||||
"value": 70
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 90
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "proxy_usage_utilization{window=\"7d\"}"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"layouts": [
|
||||
{
|
||||
"kind": "Grid",
|
||||
"spec": {
|
||||
"display": {
|
||||
"title": "Utilization"
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"x": 0,
|
||||
"y": 0,
|
||||
"width": 6,
|
||||
"height": 5,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/util_5h"
|
||||
}
|
||||
},
|
||||
{
|
||||
"x": 6,
|
||||
"y": 0,
|
||||
"width": 6,
|
||||
"height": 5,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/util_7d"
|
||||
}
|
||||
},
|
||||
{
|
||||
"x": 12,
|
||||
"y": 0,
|
||||
"width": 6,
|
||||
"height": 5,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/tokens_5h"
|
||||
}
|
||||
},
|
||||
{
|
||||
"x": 18,
|
||||
"y": 0,
|
||||
"width": 6,
|
||||
"height": 5,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/tokens_7d"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"kind": "Grid",
|
||||
"spec": {
|
||||
"display": {
|
||||
"title": "Traffic"
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"x": 0,
|
||||
"y": 0,
|
||||
"width": 12,
|
||||
"height": 8,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/request_rate"
|
||||
}
|
||||
},
|
||||
{
|
||||
"x": 12,
|
||||
"y": 0,
|
||||
"width": 12,
|
||||
"height": 8,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/latency"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"kind": "Grid",
|
||||
"spec": {
|
||||
"display": {
|
||||
"title": "Tokens"
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"x": 0,
|
||||
"y": 0,
|
||||
"width": 24,
|
||||
"height": 8,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/token_rate"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"duration": "1h",
|
||||
"refreshInterval": "10s"
|
||||
}
|
||||
}
|
||||
Generated
+3
-3
@@ -20,11 +20,11 @@
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1775423009,
|
||||
"narHash": "sha256-vPKLpjhIVWdDrfiUM8atW6YkIggCEKdSAlJPzzhkQlw=",
|
||||
"lastModified": 1776169885,
|
||||
"narHash": "sha256-l/iNYDZ4bGOAFQY2q8y5OAfBBtrDAaPuRQqWaFHVRXM=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "68d8aa3d661f0e6bd5862291b5bb263b2a6595c9",
|
||||
"rev": "4bd9165a9165d7b5e33ae57f3eecbcb28fb231c9",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
||||
@@ -17,14 +17,14 @@
|
||||
let
|
||||
pkgs = import nixpkgs {
|
||||
inherit system;
|
||||
config.allowUnfreePredicate =
|
||||
pkg:
|
||||
builtins.elem (pkgs.lib.getName pkg) [
|
||||
"claude-code"
|
||||
];
|
||||
config.allowUnfree = true;
|
||||
};
|
||||
in
|
||||
{
|
||||
packages = {
|
||||
proxy = pkgs.callPackage ./package.nix {};
|
||||
};
|
||||
|
||||
devShells.default = pkgs.mkShell {
|
||||
buildInputs = with pkgs; [
|
||||
go
|
||||
@@ -42,6 +42,9 @@
|
||||
shellHook = ''
|
||||
export GOPATH="$PWD/.go"
|
||||
export PATH="$GOPATH/bin:$PATH"
|
||||
|
||||
export ANTHROPIC_BASE_URL=http://localhost:8082
|
||||
export ANTHROPIC_API_KEY=sk-cliproxy-fujin
|
||||
'';
|
||||
};
|
||||
}
|
||||
|
||||
@@ -5,44 +5,81 @@ go 1.26
|
||||
require (
|
||||
github.com/gin-gonic/gin v1.12.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/refraction-networking/utls v1.8.2
|
||||
github.com/rs/zerolog v1.35.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.68.0
|
||||
go.opentelemetry.io/otel v1.43.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.19.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.65.0
|
||||
go.opentelemetry.io/otel/log v0.19.0
|
||||
go.opentelemetry.io/otel/metric v1.43.0
|
||||
go.opentelemetry.io/otel/sdk v1.43.0
|
||||
go.opentelemetry.io/otel/sdk/log v0.19.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0
|
||||
golang.org/x/net v0.52.0
|
||||
gopkg.in/lumberjack.v2 v2.0.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/BurntSushi/toml v1.6.0 // indirect
|
||||
github.com/andybalholm/brotli v1.0.6 // indirect
|
||||
github.com/bytedance/gopkg v0.1.3 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/bytedance/gopkg v0.1.4 // indirect
|
||||
github.com/bytedance/sonic v1.15.0 // indirect
|
||||
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||
github.com/bytedance/sonic/loader v0.5.1 // indirect
|
||||
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
|
||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.13 // indirect
|
||||
github.com/gin-contrib/sse v1.1.1 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.30.1 // indirect
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
github.com/go-playground/validator/v10 v10.30.2 // indirect
|
||||
github.com/goccy/go-json v0.10.6 // indirect
|
||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/compress v1.17.6 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 // indirect
|
||||
github.com/prometheus/client_model v0.6.2 // indirect
|
||||
github.com/prometheus/common v0.67.5 // indirect
|
||||
github.com/prometheus/otlptranslator v1.0.0 // indirect
|
||||
github.com/prometheus/procfs v0.20.1 // indirect
|
||||
github.com/quic-go/qpack v0.6.0 // indirect
|
||||
github.com/quic-go/quic-go v0.59.0 // indirect
|
||||
github.com/refraction-networking/utls v1.8.2 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.3.1 // indirect
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
|
||||
golang.org/x/arch v0.22.0 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.43.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.10.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.4 // indirect
|
||||
golang.org/x/arch v0.25.0 // indirect
|
||||
golang.org/x/crypto v0.49.0 // indirect
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/text v0.35.0 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
|
||||
google.golang.org/grpc v1.80.0 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
)
|
||||
|
||||
@@ -1,51 +1,72 @@
|
||||
github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk=
|
||||
github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||
github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI=
|
||||
github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
|
||||
github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/bytedance/gopkg v0.1.4 h1:oZnQwnX82KAIWb7033bEwtxvTqXcYMxDBaQxo5JJHWM=
|
||||
github.com/bytedance/gopkg v0.1.4/go.mod h1:v1zWfPm21Fb+OsyXN2VAHdL6TBb2L88anLQgdyje6R4=
|
||||
github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE=
|
||||
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
|
||||
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
|
||||
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
|
||||
github.com/bytedance/sonic/loader v0.5.1 h1:Ygpfa9zwRCCKSlrp5bBP/b/Xzc3VxsAW+5NIYXrOOpI=
|
||||
github.com/bytedance/sonic/loader v0.5.1/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
|
||||
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
|
||||
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
|
||||
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
||||
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
|
||||
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
|
||||
github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM=
|
||||
github.com/gabriel-vasile/mimetype v1.4.13/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
||||
github.com/gin-contrib/sse v1.1.1 h1:uGYpNwTacv5R68bSGMapo62iLTRa9l5zxGCps4hK6ko=
|
||||
github.com/gin-contrib/sse v1.1.1/go.mod h1:QXzuVkA0YO7o/gun03UI1Q+FTI8ZV/n5t03kIQAI89s=
|
||||
github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8=
|
||||
github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w=
|
||||
github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM=
|
||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/go-playground/validator/v10 v10.30.2 h1:JiFIMtSSHb2/XBUbWM4i/MpeQm9ZK2xqPNk8vgvu5JQ=
|
||||
github.com/go-playground/validator/v10 v10.30.2/go.mod h1:mAf2pIOVXjTEBrwUMGKkCWKKPs9NheYGabeB04txQSc=
|
||||
github.com/goccy/go-json v0.10.6 h1:p8HrPJzOakx/mn/bQtjgNjdTcN+/S6FcG2CTtQOrHVU=
|
||||
github.com/goccy/go-json v0.10.6/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
|
||||
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI=
|
||||
github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@@ -53,18 +74,32 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
|
||||
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||
github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4=
|
||||
github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw=
|
||||
github.com/prometheus/otlptranslator v1.0.0 h1:s0LJW/iN9dkIH+EnhiD3BlkkP5QVIUVEoIwkU+A6qos=
|
||||
github.com/prometheus/otlptranslator v1.0.0/go.mod h1:vRYWnXvI6aWGpsdY/mOT/cbeVRBlPWtBNDb7kGR3uKM=
|
||||
github.com/prometheus/procfs v0.20.1 h1:XwbrGOIplXW/AU3YhIhLODXMJYyC1isLFfYCsTEycfc=
|
||||
github.com/prometheus/procfs v0.20.1/go.mod h1:o9EMBZGRyvDrSPH1RqdxhojkuXstoe4UlK79eF5TGGo=
|
||||
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
|
||||
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
|
||||
github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw=
|
||||
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
|
||||
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
|
||||
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/rs/zerolog v1.35.0 h1:VD0ykx7HMiMJytqINBsKcbLS+BJ4WYjz+05us+LRTdI=
|
||||
github.com/rs/zerolog v1.35.0/go.mod h1:EjML9kdfa/RMA7h/6z6pYmq1ykOuA8/mjWaEvGI+jcw=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
@@ -91,32 +126,78 @@ github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY
|
||||
github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.68.0 h1:5FXSL2s6afUC1bzNzl1iedZZ8yqR7GOhbCoEXtyeK6Q=
|
||||
go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.68.0/go.mod h1:MdHW7tLtkeGJnR4TyOrnd5D0zUGZQB1l84uHCe8hRpE=
|
||||
go.opentelemetry.io/contrib/propagators/b3 v1.43.0 h1:CETqV3QLLPTy5yNrqyMr41VnAOOD4lsRved7n4QG00A=
|
||||
go.opentelemetry.io/contrib/propagators/b3 v1.43.0/go.mod h1:Q4mCiCdziYzpNR0g+6UqVotAlCDZdzz6L8jwY4knOrw=
|
||||
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
|
||||
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.19.0 h1:Dn8rkudDzY6KV9dr/D/bTUuWgqDf9xe0rr4G2elrn0Y=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.19.0/go.mod h1:gMk9F0xDgyN9M/3Ed5Y1wKcx/9mlU91NXY2SNq7RQuU=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0 h1:8UQVDcZxOJLtX6gxtDt3vY2WTgvZqMQRzjsqiIHQdkc=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0/go.mod h1:2lmweYCiHYpEjQ/lSJBYhj9jP1zvCvQW4BqL9dnT7FQ=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 h1:88Y4s2C8oTui1LGM6bTWkw0ICGcOLCAI5l6zsD1j20k=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0/go.mod h1:Vl1/iaggsuRlrHf/hfPJPvVag77kKyvrLeD10kpMl+A=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0 h1:RAE+JPfvEmvy+0LzyUA25/SGawPwIUbZ6u0Wug54sLc=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0/go.mod h1:AGmbycVGEsRx9mXMZ75CsOyhSP6MFIcj/6dnG+vhVjk=
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.65.0 h1:jOveH/b4lU9HT7y+Gfamf18BqlOuz2PWEvs8yM7Q6XE=
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.65.0/go.mod h1:i1P8pcumauPtUI4YNopea1dhzEMuEqWP1xoUZDylLHo=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.43.0 h1:mS47AX77OtFfKG4vtp+84kuGSFZHTyxtXIN269vChY0=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.43.0/go.mod h1:PJnsC41lAGncJlPUniSwM81gc80GkgWJWr3cu2nKEtU=
|
||||
go.opentelemetry.io/otel/log v0.19.0 h1:KUZs/GOsw79TBBMfDWsXS+KZ4g2Ckzksd1ymzsIEbo4=
|
||||
go.opentelemetry.io/otel/log v0.19.0/go.mod h1:5DQYeGmxVIr4n0/BcJvF4upsraHjg6vudJJpnkL6Ipk=
|
||||
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM=
|
||||
go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg=
|
||||
go.opentelemetry.io/otel/sdk/log v0.19.0 h1:scYVLqT22D2gqXItnWiocLUKGH9yvkkeql5dBDiXyko=
|
||||
go.opentelemetry.io/otel/sdk/log v0.19.0/go.mod h1:vFBowwXGLlW9AvpuF7bMgnNI95LiW10szrOdvzBHlAg=
|
||||
go.opentelemetry.io/otel/sdk/log/logtest v0.19.0 h1:BEbF7ZBB6qQloV/Ub1+3NQoOUnVtcGkU3XX4Ws3GQfk=
|
||||
go.opentelemetry.io/otel/sdk/log/logtest v0.19.0/go.mod h1:Lua81/3yM0wOmoHTokLj9y9ADeA02v1naRrVrkAZuKk=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||
go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g=
|
||||
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
||||
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||
go.yaml.in/yaml/v2 v2.4.4 h1:tuyd0P+2Ont/d6e2rl3be67goVK4R6deVxCUX5vyPaQ=
|
||||
go.yaml.in/yaml/v2 v2.4.4/go.mod h1:gMZqIpDtDqOfM0uNfy0SkpRhvUryYH0Z6wdMYcacYXQ=
|
||||
golang.org/x/arch v0.25.0 h1:qnk6Ksugpi5Bz32947rkUgDt9/s5qvqDPl/gBKdMJLE=
|
||||
golang.org/x/arch v0.25.0/go.mod h1:0X+GdSIP+kL5wPmpK7sdkEVTt2XoYP0cSjQSbZBwOi8=
|
||||
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
||||
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
|
||||
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
|
||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
||||
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/lumberjack.v2 v2.0.0 h1:IDj6hi8KbNiPQ5VaYNFZ7dBJLF5LFeKvsFrWHjA5aq4=
|
||||
gopkg.in/lumberjack.v2 v2.0.0/go.mod h1:bp5nQ2kK/lLQSmTk29azj9+JB6bWci56xFn/lvd5GLI=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// claudeCredentialsJSON matches the structure of ~/.claude/.credentials.json.
|
||||
type claudeCredentialsJSON struct {
|
||||
ClaudeAiOauth struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ExpiresAt int64 `json:"expiresAt"`
|
||||
SubscriptionType string `json:"subscriptionType"`
|
||||
} `json:"claudeAiOauth"`
|
||||
}
|
||||
|
||||
// LoadDefaultCredentials reads credentials from ~/.claude/.credentials.json.
|
||||
// Returns nil, nil if the file does not exist.
|
||||
func LoadDefaultCredentials() ([]*Credential, error) {
|
||||
path, err := DefaultCredentialPath()
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cf claudeCredentialsJSON
|
||||
if err := json.Unmarshal(data, &cf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oauth := cf.ClaudeAiOauth
|
||||
if oauth.AccessToken == "" {
|
||||
return nil, fmt.Errorf("no access token in %s", path)
|
||||
}
|
||||
|
||||
cred := &Credential{
|
||||
ID: "claude-native",
|
||||
Email: oauth.SubscriptionType,
|
||||
AccessToken: oauth.AccessToken,
|
||||
RefreshToken: oauth.RefreshToken,
|
||||
ExpiresAt: time.UnixMilli(oauth.ExpiresAt),
|
||||
FilePath: path,
|
||||
}
|
||||
|
||||
return []*Credential{cred}, nil
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefaultCredentialPath(t *testing.T) {
|
||||
path, err := DefaultCredentialPath()
|
||||
if err != nil {
|
||||
t.Fatalf("DefaultCredentialPath error: %v", err)
|
||||
}
|
||||
if !strings.HasSuffix(path, filepath.Join(".claude", ".credentials.json")) {
|
||||
t.Errorf("path = %q, want suffix .claude/.credentials.json", path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDefaultCredentials_MissingFile(t *testing.T) {
|
||||
// When credential file doesn't exist, returns nil, nil
|
||||
path, err := DefaultCredentialPath()
|
||||
if err != nil {
|
||||
t.Skip("cannot determine home dir")
|
||||
}
|
||||
if _, statErr := os.Stat(path); os.IsNotExist(statErr) {
|
||||
creds, err := LoadDefaultCredentials()
|
||||
if creds != nil {
|
||||
t.Errorf("expected nil creds for missing file, got %v", creds)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("expected nil error for missing file, got %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeCredentialsJSON_ParsesCorrectly(t *testing.T) {
|
||||
jsonData := `{"claudeAiOauth":{"accessToken":"test-token","refreshToken":"test-refresh","expiresAt":1234567890,"subscriptionType":"pro"}}`
|
||||
|
||||
var cf claudeCredentialsJSON
|
||||
if err := json.Unmarshal([]byte(jsonData), &cf); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if cf.ClaudeAiOauth.AccessToken != "test-token" {
|
||||
t.Errorf("AccessToken = %q, want test-token", cf.ClaudeAiOauth.AccessToken)
|
||||
}
|
||||
if cf.ClaudeAiOauth.RefreshToken != "test-refresh" {
|
||||
t.Errorf("RefreshToken = %q, want test-refresh", cf.ClaudeAiOauth.RefreshToken)
|
||||
}
|
||||
if cf.ClaudeAiOauth.ExpiresAt != 1234567890 {
|
||||
t.Errorf("ExpiresAt = %d, want 1234567890", cf.ClaudeAiOauth.ExpiresAt)
|
||||
}
|
||||
if cf.ClaudeAiOauth.SubscriptionType != "pro" {
|
||||
t.Errorf("SubscriptionType = %q, want pro", cf.ClaudeAiOauth.SubscriptionType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeCredentialsJSON_EmptyAccessToken(t *testing.T) {
|
||||
jsonData := `{"claudeAiOauth":{"accessToken":"","refreshToken":"r","expiresAt":1}}`
|
||||
|
||||
var cf claudeCredentialsJSON
|
||||
if err := json.Unmarshal([]byte(jsonData), &cf); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if cf.ClaudeAiOauth.AccessToken != "" {
|
||||
t.Errorf("expected empty access token")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,276 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
const (
|
||||
authURL = "https://claude.com/cai/oauth/authorize"
|
||||
manualRedirect = "https://platform.claude.com/oauth/code/callback"
|
||||
)
|
||||
|
||||
func base64URLEncode(data []byte) string {
|
||||
return base64.RawURLEncoding.EncodeToString(data)
|
||||
}
|
||||
|
||||
func generateCodeVerifier() string {
|
||||
buf := make([]byte, 32)
|
||||
_, _ = rand.Read(buf)
|
||||
return base64URLEncode(buf)
|
||||
}
|
||||
|
||||
func generateCodeChallenge(verifier string) string {
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
return base64URLEncode(h[:])
|
||||
}
|
||||
|
||||
func generateState() string {
|
||||
buf := make([]byte, 32)
|
||||
_, _ = rand.Read(buf)
|
||||
return base64URLEncode(buf)
|
||||
}
|
||||
|
||||
func buildAuthURL(port int, codeChallenge, state string) string {
|
||||
u, _ := url.Parse(authURL)
|
||||
q := u.Query()
|
||||
q.Set("client_id", clientID)
|
||||
q.Set("response_type", "code")
|
||||
q.Set("redirect_uri", fmt.Sprintf("http://localhost:%d/callback", port))
|
||||
q.Set("scope", oauthScopes)
|
||||
q.Set("code_challenge", codeChallenge)
|
||||
q.Set("code_challenge_method", "S256")
|
||||
q.Set("state", state)
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func buildManualAuthURL(codeChallenge, state string) string {
|
||||
u, _ := url.Parse(authURL)
|
||||
q := u.Query()
|
||||
q.Set("client_id", clientID)
|
||||
q.Set("response_type", "code")
|
||||
q.Set("redirect_uri", manualRedirect)
|
||||
q.Set("scope", oauthScopes)
|
||||
q.Set("code_challenge", codeChallenge)
|
||||
q.Set("code_challenge_method", "S256")
|
||||
q.Set("state", state)
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func startCallbackServer(expectedState string) (port int, codeChan <-chan string, cleanup func(), err error) {
|
||||
ch := make(chan string, 1)
|
||||
ln, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return 0, nil, nil, err
|
||||
}
|
||||
port = ln.Addr().(*net.TCPAddr).Port
|
||||
srv := &http.Server{}
|
||||
srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/callback" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
if state != expectedState {
|
||||
http.Error(w, "invalid state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if code == "" {
|
||||
http.Error(w, "missing code", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case ch <- code:
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
fmt.Fprintln(w, "<html><body><h2>Login successful! You can close this tab.</h2></body></html>")
|
||||
default:
|
||||
fmt.Fprintln(w, "<html><body><h2>Already received. You can close this tab.</h2></body></html>")
|
||||
}
|
||||
})
|
||||
go srv.Serve(ln)
|
||||
cleanup = func() {
|
||||
srv.Close()
|
||||
}
|
||||
return port, ch, cleanup, nil
|
||||
}
|
||||
|
||||
// DefaultCredentialPath returns the path to the Claude credentials file.
|
||||
func DefaultCredentialPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".claude", ".credentials.json"), nil
|
||||
}
|
||||
|
||||
// Login performs the full OAuth 2.0 PKCE login flow and returns a Credential.
|
||||
func Login(ctx context.Context) (*Credential, error) {
|
||||
verifier := generateCodeVerifier()
|
||||
challenge := generateCodeChallenge(verifier)
|
||||
state := generateState()
|
||||
|
||||
port, codeChan, cleanup, err := startCallbackServer(state)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("start callback server: %w", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
autoURL := buildAuthURL(port, challenge, state)
|
||||
manualURL := buildManualAuthURL(challenge, state)
|
||||
|
||||
fmt.Printf("\nTo sign in, visit:\n %s\n\n", manualURL)
|
||||
openBrowser(autoURL)
|
||||
|
||||
var authCode string
|
||||
var isManual bool
|
||||
|
||||
stdinCh := make(chan string, 1)
|
||||
fi, statErr := os.Stdin.Stat()
|
||||
if statErr == nil && (fi.Mode()&os.ModeCharDevice) != 0 {
|
||||
fmt.Print("If browser didn't open, paste the authorization code here: ")
|
||||
go func() {
|
||||
var line string
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
if scanner.Scan() {
|
||||
line = strings.TrimSpace(scanner.Text())
|
||||
}
|
||||
if line != "" {
|
||||
stdinCh <- line
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
timeout := time.NewTimer(120 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
select {
|
||||
case code := <-codeChan:
|
||||
authCode = code
|
||||
isManual = false
|
||||
case code := <-stdinCh:
|
||||
authCode = code
|
||||
isManual = true
|
||||
case <-timeout.C:
|
||||
return nil, fmt.Errorf("login timed out after 120 seconds")
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
credPath, err := DefaultCredentialPath()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("credential path: %w", err)
|
||||
}
|
||||
return exchangeAuthCode(ctx, authCode, state, verifier, port, isManual, credPath)
|
||||
}
|
||||
|
||||
type authCodeRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
Code string `json:"code"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ClientID string `json:"client_id"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
func exchangeAuthCode(ctx context.Context, code, state, verifier string, port int, isManual bool, credPath string) (*Credential, error) {
|
||||
redirectURI := fmt.Sprintf("http://localhost:%d/callback", port)
|
||||
if isManual {
|
||||
redirectURI = manualRedirect
|
||||
}
|
||||
|
||||
reqBody, _ := json.Marshal(authCodeRequest{
|
||||
GrantType: "authorization_code",
|
||||
Code: code,
|
||||
RedirectURI: redirectURI,
|
||||
ClientID: clientID,
|
||||
CodeVerifier: verifier,
|
||||
State: state,
|
||||
})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenEndpoint, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := utlsClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token exchange: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token exchange returned %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tokenResp tokenResponse
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("decode token response: %w", err)
|
||||
}
|
||||
|
||||
cred := &Credential{
|
||||
ID: "claude-native",
|
||||
Email: tokenResp.Account.EmailAddress,
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ExpiresAt: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second),
|
||||
FilePath: credPath,
|
||||
}
|
||||
|
||||
if err := ensureCredentialFile(credPath); err != nil {
|
||||
return nil, fmt.Errorf("ensure credential file: %w", err)
|
||||
}
|
||||
|
||||
if err := persistCredential(cred); err != nil {
|
||||
return nil, fmt.Errorf("save credential: %w", err)
|
||||
}
|
||||
|
||||
log.Info().Str("path", credPath).Msg("login successful, credentials saved")
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
func ensureCredentialFile(path string) error {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return nil
|
||||
}
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, []byte("{}"), 0600)
|
||||
}
|
||||
|
||||
func openBrowser(url string) {
|
||||
var cmd *exec.Cmd
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
cmd = exec.Command("open", url)
|
||||
case "linux":
|
||||
cmd = exec.Command("xdg-open", url)
|
||||
default:
|
||||
return
|
||||
}
|
||||
_ = cmd.Start()
|
||||
}
|
||||
+37
-76
@@ -6,15 +6,14 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
tls "github.com/refraction-networking/utls"
|
||||
"golang.org/x/net/http2"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/transport"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -27,7 +26,7 @@ const (
|
||||
refreshBackoff = 5 * time.Minute
|
||||
)
|
||||
|
||||
var utlsClient = newUTLSClient()
|
||||
var utlsClient = transport.NewHTTPClient(15 * time.Second)
|
||||
|
||||
type tokenRequest struct {
|
||||
ClientID string `json:"client_id"`
|
||||
@@ -63,6 +62,13 @@ func RefreshToken(ctx context.Context, cred *Credential) error {
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
log.Debug().
|
||||
Str("url", tokenEndpoint).
|
||||
Str("grant_type", "refresh_token").
|
||||
Str("client_id", clientID).
|
||||
Str("scope", oauthScopes).
|
||||
Msg("token refresh request")
|
||||
|
||||
resp, err := utlsClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("execute request: %w", err)
|
||||
@@ -70,6 +76,12 @@ func RefreshToken(ctx context.Context, cred *Credential) error {
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
log.Debug().
|
||||
Int("status", resp.StatusCode).
|
||||
Int("response_size", len(body)).
|
||||
Msg("token refresh response")
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("refresh returned %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
@@ -105,13 +117,21 @@ func persistCredential(cred *Credential) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var doc map[string]any
|
||||
raw, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var doc map[string]any
|
||||
if err := json.Unmarshal(raw, &doc); err != nil {
|
||||
return err
|
||||
if !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
// File doesn't exist yet (cold start) — create from scratch
|
||||
if mkdirErr := os.MkdirAll(filepath.Dir(filePath), 0700); mkdirErr != nil {
|
||||
return fmt.Errorf("create credential dir: %w", mkdirErr)
|
||||
}
|
||||
doc = make(map[string]any)
|
||||
} else {
|
||||
if err := json.Unmarshal(raw, &doc); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
oauth, _ := doc["claudeAiOauth"].(map[string]any)
|
||||
if oauth == nil {
|
||||
@@ -125,73 +145,12 @@ func persistCredential(cred *Credential) error {
|
||||
return os.WriteFile(filePath, out, 0600)
|
||||
}
|
||||
|
||||
func newUTLSClient() *http.Client {
|
||||
return &http.Client{
|
||||
Timeout: 15 * time.Second,
|
||||
Transport: &utlsRefreshTransport{},
|
||||
}
|
||||
}
|
||||
|
||||
type utlsRefreshTransport struct {
|
||||
mu sync.Mutex
|
||||
conn *http2.ClientConn
|
||||
host string
|
||||
}
|
||||
|
||||
func (t *utlsRefreshTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
host := req.URL.Hostname()
|
||||
port := req.URL.Port()
|
||||
if port == "" {
|
||||
port = "443"
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
if t.conn != nil && t.host == host && t.conn.CanTakeNewRequest() {
|
||||
conn := t.conn
|
||||
t.mu.Unlock()
|
||||
resp, err := conn.RoundTrip(req)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
t.mu.Lock()
|
||||
t.conn = nil
|
||||
t.mu.Unlock()
|
||||
} else {
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
addr := net.JoinHostPort(host, port)
|
||||
rawConn, err := net.DialTimeout("tcp", addr, 10*time.Second)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConn := tls.UClient(rawConn, &tls.Config{ServerName: host}, tls.HelloChrome_Auto)
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
rawConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h2Conn, err := (&http2.Transport{}).NewClientConn(tlsConn)
|
||||
if err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
t.conn = h2Conn
|
||||
t.host = host
|
||||
t.mu.Unlock()
|
||||
|
||||
return h2Conn.RoundTrip(req)
|
||||
}
|
||||
|
||||
func StartBackgroundRefresh(ctx context.Context, pool *Pool) {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("background refresh stopped")
|
||||
log.Info().Msg("background refresh stopped")
|
||||
return
|
||||
case <-time.After(refreshInterval):
|
||||
refreshExpiring(pool)
|
||||
@@ -213,6 +172,7 @@ func refreshExpiring(pool *Pool) {
|
||||
hasRefresh := cred.RefreshToken != ""
|
||||
nextRetry := cred.nextRefreshAfter
|
||||
email := cred.Email
|
||||
expiresAt := cred.ExpiresAt
|
||||
cred.mu.Unlock()
|
||||
|
||||
if !hasRefresh || !needsRefresh {
|
||||
@@ -222,21 +182,22 @@ func refreshExpiring(pool *Pool) {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("refreshing token for %s (expires %s)", email, cred.ExpiresAt.Format(time.RFC3339))
|
||||
log.Info().Str("credential", email).Time("expires_at", expiresAt).Msg("refreshing token")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
err := RefreshToken(ctx, cred)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
log.Printf("refresh failed for %s: %v", email, err)
|
||||
log.Error().Err(err).Str("credential", email).Msg("token refresh failed")
|
||||
cred.mu.Lock()
|
||||
cred.nextRefreshAfter = time.Now().Add(refreshBackoff)
|
||||
cred.mu.Unlock()
|
||||
} else {
|
||||
log.Printf("refreshed %s, new expiry %s", email, cred.ExpiresAt.Format(time.RFC3339))
|
||||
cred.mu.Lock()
|
||||
newExpiresAt := cred.ExpiresAt
|
||||
cred.nextRefreshAfter = time.Time{}
|
||||
cred.mu.Unlock()
|
||||
log.Info().Str("credential", email).Time("new_expiry", newExpiresAt).Msg("token refreshed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,318 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewPool(t *testing.T) {
|
||||
creds := []*Credential{
|
||||
{ID: "a", AccessToken: "tok-a"},
|
||||
{ID: "b", AccessToken: "tok-b"},
|
||||
}
|
||||
p := NewPool(creds)
|
||||
if p == nil {
|
||||
t.Fatal("NewPool returned nil")
|
||||
}
|
||||
if len(p.creds) != 2 {
|
||||
t.Errorf("pool has %d creds, want 2", len(p.creds))
|
||||
}
|
||||
if p.cursor != 0 {
|
||||
t.Errorf("initial cursor = %d, want 0", p.cursor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_Pick_EmptyPool(t *testing.T) {
|
||||
p := NewPool(nil)
|
||||
_, err := p.Pick()
|
||||
if err == nil {
|
||||
t.Fatal("expected error from empty pool, got nil")
|
||||
}
|
||||
want := "no credentials available"
|
||||
if err.Error() != want {
|
||||
t.Errorf("error = %q, want %q", err.Error(), want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_Pick_SingleCredential(t *testing.T) {
|
||||
cred := &Credential{ID: "only", AccessToken: "tok-only"}
|
||||
p := NewPool([]*Credential{cred})
|
||||
|
||||
got, err := p.Pick()
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() error = %v", err)
|
||||
}
|
||||
if got.ID != "only" {
|
||||
t.Errorf("Pick() returned cred ID %q, want %q", got.ID, "only")
|
||||
}
|
||||
|
||||
// Picking again should return the same credential
|
||||
got2, err := p.Pick()
|
||||
if err != nil {
|
||||
t.Fatalf("second Pick() error = %v", err)
|
||||
}
|
||||
if got2.ID != "only" {
|
||||
t.Errorf("second Pick() returned cred ID %q, want %q", got2.ID, "only")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_Pick_RoundRobin(t *testing.T) {
|
||||
creds := []*Credential{
|
||||
{ID: "a"},
|
||||
{ID: "b"},
|
||||
{ID: "c"},
|
||||
}
|
||||
p := NewPool(creds)
|
||||
|
||||
// Should cycle through a, b, c, a, b, c
|
||||
expected := []string{"a", "b", "c", "a", "b", "c"}
|
||||
for i, want := range expected {
|
||||
got, err := p.Pick()
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||
}
|
||||
if got.ID != want {
|
||||
t.Errorf("Pick() #%d = %q, want %q", i, got.ID, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_Pick_SkipsCooldown(t *testing.T) {
|
||||
creds := []*Credential{
|
||||
{ID: "a"},
|
||||
{ID: "b", cooldownUntil: time.Now().Add(1 * time.Hour)},
|
||||
{ID: "c"},
|
||||
}
|
||||
p := NewPool(creds)
|
||||
|
||||
// First pick: "a" (index 0, not on cooldown)
|
||||
got, err := p.Pick()
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #1 error = %v", err)
|
||||
}
|
||||
if got.ID != "a" {
|
||||
t.Errorf("Pick() #1 = %q, want %q", got.ID, "a")
|
||||
}
|
||||
|
||||
// Second pick: cursor at 1, but "b" is on cooldown → skip to "c"
|
||||
got, err = p.Pick()
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #2 error = %v", err)
|
||||
}
|
||||
if got.ID != "c" {
|
||||
t.Errorf("Pick() #2 = %q, want %q", got.ID, "c")
|
||||
}
|
||||
|
||||
// Third pick: cursor advanced past "c" to 0 → "a"
|
||||
got, err = p.Pick()
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #3 error = %v", err)
|
||||
}
|
||||
if got.ID != "a" {
|
||||
t.Errorf("Pick() #3 = %q, want %q", got.ID, "a")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_Pick_AllOnCooldown(t *testing.T) {
|
||||
future := time.Now().Add(1 * time.Hour)
|
||||
creds := []*Credential{
|
||||
{ID: "a", cooldownUntil: future},
|
||||
{ID: "b", cooldownUntil: future},
|
||||
}
|
||||
p := NewPool(creds)
|
||||
|
||||
_, err := p.Pick()
|
||||
if err == nil {
|
||||
t.Fatal("expected error when all on cooldown, got nil")
|
||||
}
|
||||
want := "all 2 credentials are on cooldown"
|
||||
if err.Error() != want {
|
||||
t.Errorf("error = %q, want %q", err.Error(), want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_MarkFailure(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
expectCooldown bool
|
||||
expectedDur time.Duration
|
||||
}{
|
||||
{
|
||||
name: "429 sets 30s cooldown",
|
||||
statusCode: 429,
|
||||
expectCooldown: true,
|
||||
expectedDur: 30 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "500 sets 5s cooldown",
|
||||
statusCode: 500,
|
||||
expectCooldown: true,
|
||||
expectedDur: 5 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "502 sets 5s cooldown",
|
||||
statusCode: 502,
|
||||
expectCooldown: true,
|
||||
expectedDur: 5 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "503 sets 5s cooldown",
|
||||
statusCode: 503,
|
||||
expectCooldown: true,
|
||||
expectedDur: 5 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "400 does NOT set cooldown",
|
||||
statusCode: 400,
|
||||
expectCooldown: false,
|
||||
},
|
||||
{
|
||||
name: "401 does NOT set cooldown",
|
||||
statusCode: 401,
|
||||
expectCooldown: false,
|
||||
},
|
||||
{
|
||||
name: "403 does NOT set cooldown",
|
||||
statusCode: 403,
|
||||
expectCooldown: false,
|
||||
},
|
||||
{
|
||||
name: "404 does NOT set cooldown",
|
||||
statusCode: 404,
|
||||
expectCooldown: false,
|
||||
},
|
||||
{
|
||||
name: "422 does NOT set cooldown",
|
||||
statusCode: 422,
|
||||
expectCooldown: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cred := &Credential{ID: "test"}
|
||||
p := NewPool([]*Credential{cred})
|
||||
|
||||
before := time.Now()
|
||||
p.MarkFailure(cred, tt.statusCode)
|
||||
|
||||
if tt.expectCooldown {
|
||||
if !cred.IsOnCooldown() {
|
||||
t.Errorf("expected cooldown after status %d", tt.statusCode)
|
||||
}
|
||||
// Verify approximate duration
|
||||
cred.mu.Lock()
|
||||
cooldownEnd := cred.cooldownUntil
|
||||
cred.mu.Unlock()
|
||||
|
||||
lower := before.Add(tt.expectedDur)
|
||||
upper := time.Now().Add(tt.expectedDur)
|
||||
if cooldownEnd.Before(lower) || cooldownEnd.After(upper) {
|
||||
t.Errorf("cooldownUntil %v not in expected range [%v, %v]", cooldownEnd, lower, upper)
|
||||
}
|
||||
} else {
|
||||
if cred.IsOnCooldown() {
|
||||
t.Errorf("did not expect cooldown after status %d", tt.statusCode)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_MarkSuccess(t *testing.T) {
|
||||
cred := &Credential{
|
||||
ID: "test",
|
||||
cooldownUntil: time.Now().Add(1 * time.Hour),
|
||||
}
|
||||
p := NewPool([]*Credential{cred})
|
||||
|
||||
if !cred.IsOnCooldown() {
|
||||
t.Fatal("precondition: expected credential to be on cooldown")
|
||||
}
|
||||
|
||||
p.MarkSuccess(cred)
|
||||
|
||||
if cred.IsOnCooldown() {
|
||||
t.Error("expected cooldown to be cleared after MarkSuccess")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_RoundRobinCursorAdvancement(t *testing.T) {
|
||||
creds := []*Credential{
|
||||
{ID: "0"},
|
||||
{ID: "1"},
|
||||
{ID: "2"},
|
||||
}
|
||||
p := NewPool(creds)
|
||||
|
||||
// Verify cursor starts at 0
|
||||
if p.cursor != 0 {
|
||||
t.Fatalf("initial cursor = %d, want 0", p.cursor)
|
||||
}
|
||||
|
||||
// Pick cred[0], cursor should advance to 1
|
||||
got, _ := p.Pick()
|
||||
if got.ID != "0" {
|
||||
t.Errorf("first pick = %q, want %q", got.ID, "0")
|
||||
}
|
||||
if p.cursor != 1 {
|
||||
t.Errorf("cursor after first pick = %d, want 1", p.cursor)
|
||||
}
|
||||
|
||||
// Pick cred[1], cursor should advance to 2
|
||||
got, _ = p.Pick()
|
||||
if got.ID != "1" {
|
||||
t.Errorf("second pick = %q, want %q", got.ID, "1")
|
||||
}
|
||||
if p.cursor != 2 {
|
||||
t.Errorf("cursor after second pick = %d, want 2", p.cursor)
|
||||
}
|
||||
|
||||
// Pick cred[2], cursor should wrap to 0
|
||||
got, _ = p.Pick()
|
||||
if got.ID != "2" {
|
||||
t.Errorf("third pick = %q, want %q", got.ID, "2")
|
||||
}
|
||||
if p.cursor != 0 {
|
||||
t.Errorf("cursor after third pick = %d, want 0 (wrap)", p.cursor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_RoundRobinWithCooldownSkip(t *testing.T) {
|
||||
creds := []*Credential{
|
||||
{ID: "0"},
|
||||
{ID: "1", cooldownUntil: time.Now().Add(1 * time.Hour)},
|
||||
{ID: "2"},
|
||||
}
|
||||
p := NewPool(creds)
|
||||
|
||||
// First pick: cred[0]
|
||||
got, _ := p.Pick()
|
||||
if got.ID != "0" {
|
||||
t.Errorf("first pick = %q, want %q", got.ID, "0")
|
||||
}
|
||||
// Cursor should be at 1
|
||||
if p.cursor != 1 {
|
||||
t.Errorf("cursor after first pick = %d, want 1", p.cursor)
|
||||
}
|
||||
|
||||
// Second pick: cursor at 1, but cred[1] on cooldown → skip to cred[2]
|
||||
got, _ = p.Pick()
|
||||
if got.ID != "2" {
|
||||
t.Errorf("second pick = %q, want %q", got.ID, "2")
|
||||
}
|
||||
// Cursor should advance past cred[2] to 0
|
||||
if p.cursor != 0 {
|
||||
t.Errorf("cursor after second pick (skip) = %d, want 0", p.cursor)
|
||||
}
|
||||
|
||||
// Third pick: cursor at 0, cred[0] available
|
||||
got, _ = p.Pick()
|
||||
if got.ID != "0" {
|
||||
t.Errorf("third pick = %q, want %q", got.ID, "0")
|
||||
}
|
||||
if p.cursor != 1 {
|
||||
t.Errorf("cursor after third pick = %d, want 1", p.cursor)
|
||||
}
|
||||
}
|
||||
@@ -13,7 +13,7 @@ type Credential struct {
|
||||
RefreshToken string
|
||||
ExpiresAt time.Time
|
||||
FilePath string
|
||||
CooldownUntil time.Time
|
||||
cooldownUntil time.Time
|
||||
nextRefreshAfter time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
@@ -22,21 +22,21 @@ type Credential struct {
|
||||
func (c *Credential) IsOnCooldown() bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return time.Now().Before(c.CooldownUntil)
|
||||
return time.Now().Before(c.cooldownUntil)
|
||||
}
|
||||
|
||||
// SetCooldown puts the credential on cooldown for the given duration.
|
||||
func (c *Credential) SetCooldown(duration time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.CooldownUntil = time.Now().Add(duration)
|
||||
c.cooldownUntil = time.Now().Add(duration)
|
||||
}
|
||||
|
||||
// ClearCooldown removes any active cooldown on the credential.
|
||||
func (c *Credential) ClearCooldown() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.CooldownUntil = time.Time{}
|
||||
c.cooldownUntil = time.Time{}
|
||||
}
|
||||
|
||||
// Token returns the current access token.
|
||||
|
||||
@@ -0,0 +1,167 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCredential_IsOnCooldown(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cooldownUntil time.Time
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "zero time — not on cooldown",
|
||||
cooldownUntil: time.Time{},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "future time — on cooldown",
|
||||
cooldownUntil: time.Now().Add(1 * time.Hour),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "past time — expired cooldown",
|
||||
cooldownUntil: time.Now().Add(-1 * time.Hour),
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Credential{cooldownUntil: tt.cooldownUntil}
|
||||
got := c.IsOnCooldown()
|
||||
if got != tt.want {
|
||||
t.Errorf("IsOnCooldown() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCredential_SetCooldown(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
duration time.Duration
|
||||
}{
|
||||
{name: "30 second cooldown", duration: 30 * time.Second},
|
||||
{name: "5 second cooldown", duration: 5 * time.Second},
|
||||
{name: "1 minute cooldown", duration: 1 * time.Minute},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Credential{}
|
||||
before := time.Now()
|
||||
c.SetCooldown(tt.duration)
|
||||
after := time.Now()
|
||||
|
||||
// cooldownUntil should be between before+duration and after+duration
|
||||
if c.cooldownUntil.Before(before.Add(tt.duration)) {
|
||||
t.Errorf("cooldownUntil %v is before expected lower bound %v", c.cooldownUntil, before.Add(tt.duration))
|
||||
}
|
||||
if c.cooldownUntil.After(after.Add(tt.duration)) {
|
||||
t.Errorf("cooldownUntil %v is after expected upper bound %v", c.cooldownUntil, after.Add(tt.duration))
|
||||
}
|
||||
|
||||
// Should now be on cooldown
|
||||
if !c.IsOnCooldown() {
|
||||
t.Error("expected credential to be on cooldown after SetCooldown")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCredential_ClearCooldown(t *testing.T) {
|
||||
t.Run("clears active cooldown", func(t *testing.T) {
|
||||
c := &Credential{cooldownUntil: time.Now().Add(1 * time.Hour)}
|
||||
if !c.IsOnCooldown() {
|
||||
t.Fatal("precondition: expected credential to be on cooldown")
|
||||
}
|
||||
|
||||
c.ClearCooldown()
|
||||
|
||||
if c.IsOnCooldown() {
|
||||
t.Error("expected credential to not be on cooldown after ClearCooldown")
|
||||
}
|
||||
if !c.cooldownUntil.IsZero() {
|
||||
t.Errorf("expected cooldownUntil to be zero time, got %v", c.cooldownUntil)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("clearing when not on cooldown is no-op", func(t *testing.T) {
|
||||
c := &Credential{}
|
||||
c.ClearCooldown()
|
||||
|
||||
if c.IsOnCooldown() {
|
||||
t.Error("expected credential to not be on cooldown")
|
||||
}
|
||||
if !c.cooldownUntil.IsZero() {
|
||||
t.Errorf("expected cooldownUntil to be zero time, got %v", c.cooldownUntil)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCredential_Token(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{name: "returns access token", token: "sk-ant-abc123"},
|
||||
{name: "empty token", token: ""},
|
||||
{name: "long token", token: "sk-ant-" + string(make([]byte, 200))},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Credential{AccessToken: tt.token}
|
||||
got := c.Token()
|
||||
if got != tt.token {
|
||||
t.Errorf("Token() = %q, want %q", got, tt.token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCredential_ConcurrentAccess(t *testing.T) {
|
||||
c := &Credential{
|
||||
AccessToken: "initial-token",
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 50
|
||||
|
||||
// Spawn goroutines that concurrently read and write
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(3)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = c.Token()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
c.SetCooldown(1 * time.Second)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = c.IsOnCooldown()
|
||||
}()
|
||||
}
|
||||
|
||||
// Also mix in ClearCooldown calls
|
||||
for i := 0; i < goroutines/2; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
c.ClearCooldown()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// If we get here without -race detecting issues, mutex is working
|
||||
}
|
||||
+75
-54
@@ -1,21 +1,19 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/auth"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Port int `yaml:"port"`
|
||||
APIKeys []string `yaml:"api_keys"`
|
||||
ClaudeCredentials string `yaml:"claude_credentials"`
|
||||
ClaudeBinary string `yaml:"claude_binary"`
|
||||
Sanitize SanitizeConfig `yaml:"sanitize"`
|
||||
Port int `yaml:"port"`
|
||||
APIKeys []string `yaml:"api_keys"`
|
||||
ClaudeBinary string `yaml:"claude_binary"`
|
||||
Sanitize SanitizeConfig `yaml:"sanitize"`
|
||||
Logging LoggingConfig `yaml:"logging"`
|
||||
Telemetry TelemetryConfig `yaml:"telemetry"`
|
||||
}
|
||||
|
||||
type SanitizeConfig struct {
|
||||
@@ -34,13 +32,36 @@ type ReplaceRule struct {
|
||||
Replace string `yaml:"replace"`
|
||||
}
|
||||
|
||||
type claudeCredentialsJSON struct {
|
||||
ClaudeAiOauth struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ExpiresAt int64 `json:"expiresAt"`
|
||||
SubscriptionType string `json:"subscriptionType"`
|
||||
} `json:"claudeAiOauth"`
|
||||
type TelemetryConfig struct {
|
||||
Export ExportConfig `yaml:"export"`
|
||||
Embedded EmbeddedConfig `yaml:"embedded"`
|
||||
ServiceName string `yaml:"service_name"`
|
||||
}
|
||||
|
||||
type ExportConfig struct {
|
||||
Endpoint string `yaml:"endpoint"`
|
||||
Insecure bool `yaml:"insecure"`
|
||||
Headers map[string]string `yaml:"headers"`
|
||||
}
|
||||
|
||||
func (e ExportConfig) Enabled() bool { return e.Endpoint != "" }
|
||||
|
||||
type EmbeddedConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Port int `yaml:"port"`
|
||||
PersesBinary string `yaml:"perses_binary"`
|
||||
VMBinary string `yaml:"vm_binary"`
|
||||
VMPort int `yaml:"vm_port"`
|
||||
BinDir string `yaml:"bin_dir"`
|
||||
}
|
||||
|
||||
type LoggingConfig struct {
|
||||
Level string `yaml:"level"`
|
||||
File string `yaml:"file"`
|
||||
MaxSizeMB int `yaml:"max_size_mb"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
MaxAgeDays int `yaml:"max_age_days"`
|
||||
Compress bool `yaml:"compress"`
|
||||
}
|
||||
|
||||
func Load(path string) (*Config, error) {
|
||||
@@ -54,44 +75,44 @@ func Load(path string) (*Config, error) {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
if cfg.Logging.Level == "" {
|
||||
cfg.Logging.Level = "info"
|
||||
}
|
||||
if cfg.Logging.MaxSizeMB == 0 {
|
||||
cfg.Logging.MaxSizeMB = 100
|
||||
}
|
||||
if cfg.Logging.MaxBackups == 0 {
|
||||
cfg.Logging.MaxBackups = 5
|
||||
}
|
||||
if cfg.Logging.MaxAgeDays == 0 {
|
||||
cfg.Logging.MaxAgeDays = 30
|
||||
}
|
||||
|
||||
if cfg.Telemetry.ServiceName == "" {
|
||||
cfg.Telemetry.ServiceName = "anthropic-proxy"
|
||||
}
|
||||
if cfg.Telemetry.Embedded.Port == 0 {
|
||||
cfg.Telemetry.Embedded.Port = 8080
|
||||
}
|
||||
if cfg.Telemetry.Embedded.PersesBinary == "" {
|
||||
cfg.Telemetry.Embedded.PersesBinary = "perses"
|
||||
}
|
||||
if cfg.Telemetry.Embedded.VMBinary == "" {
|
||||
cfg.Telemetry.Embedded.VMBinary = "victoria-metrics"
|
||||
}
|
||||
if cfg.Telemetry.Embedded.VMPort == 0 {
|
||||
cfg.Telemetry.Embedded.VMPort = 8428
|
||||
}
|
||||
|
||||
// Check for deprecated claude_credentials field
|
||||
var rawCfg map[string]interface{}
|
||||
if err := yaml.Unmarshal(data, &rawCfg); err == nil {
|
||||
if _, exists := rawCfg["claude_credentials"]; exists {
|
||||
if val, ok := rawCfg["claude_credentials"].(string); ok && val != "" {
|
||||
return nil, fmt.Errorf("claude_credentials is no longer supported, remove it from config.yaml — the proxy now manages credentials at ~/.claude/.credentials.json")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func LoadCredentials(cfg *Config) ([]*auth.Credential, error) {
|
||||
if cfg.ClaudeCredentials == "" {
|
||||
return nil, fmt.Errorf("claude_credentials not set")
|
||||
}
|
||||
|
||||
cred, err := loadCredentials(cfg.ClaudeCredentials)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []*auth.Credential{cred}, nil
|
||||
}
|
||||
|
||||
func loadCredentials(path string) (*auth.Credential, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cf claudeCredentialsJSON
|
||||
if err := json.Unmarshal(data, &cf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oauth := cf.ClaudeAiOauth
|
||||
if oauth.AccessToken == "" {
|
||||
return nil, fmt.Errorf("no access token in %s", path)
|
||||
}
|
||||
|
||||
return &auth.Credential{
|
||||
ID: "claude-native",
|
||||
Email: oauth.SubscriptionType,
|
||||
AccessToken: oauth.AccessToken,
|
||||
RefreshToken: oauth.RefreshToken,
|
||||
ExpiresAt: time.UnixMilli(oauth.ExpiresAt),
|
||||
FilePath: path,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,270 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoad_AllFields(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
|
||||
yaml := `
|
||||
port: 9090
|
||||
api_keys:
|
||||
- key1
|
||||
- key2
|
||||
claude_binary: /usr/bin/claude
|
||||
sanitize:
|
||||
tools:
|
||||
- from: tool_a
|
||||
to: tool_b
|
||||
system:
|
||||
- match: foo
|
||||
replace: bar
|
||||
body:
|
||||
- match: baz
|
||||
replace: qux
|
||||
logging:
|
||||
level: debug
|
||||
file: /tmp/test.log
|
||||
max_size_mb: 50
|
||||
max_backups: 3
|
||||
max_age_days: 7
|
||||
compress: true
|
||||
telemetry:
|
||||
service_name: my-proxy
|
||||
export:
|
||||
endpoint: http://localhost:4317
|
||||
insecure: true
|
||||
headers:
|
||||
x-token: abc
|
||||
embedded:
|
||||
enabled: true
|
||||
port: 9999
|
||||
perses_binary: /usr/bin/perses
|
||||
vm_binary: /usr/bin/vm
|
||||
vm_port: 9428
|
||||
bin_dir: /opt/bin
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(yaml), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load returned error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Port != 9090 {
|
||||
t.Errorf("Port = %d, want 9090", cfg.Port)
|
||||
}
|
||||
if len(cfg.APIKeys) != 2 || cfg.APIKeys[0] != "key1" || cfg.APIKeys[1] != "key2" {
|
||||
t.Errorf("APIKeys = %v, want [key1 key2]", cfg.APIKeys)
|
||||
}
|
||||
if cfg.ClaudeBinary != "/usr/bin/claude" {
|
||||
t.Errorf("ClaudeBinary = %q, want /usr/bin/claude", cfg.ClaudeBinary)
|
||||
}
|
||||
|
||||
// Sanitize
|
||||
if len(cfg.Sanitize.Tools) != 1 || cfg.Sanitize.Tools[0].From != "tool_a" || cfg.Sanitize.Tools[0].To != "tool_b" {
|
||||
t.Errorf("Sanitize.Tools = %v", cfg.Sanitize.Tools)
|
||||
}
|
||||
if len(cfg.Sanitize.System) != 1 || cfg.Sanitize.System[0].Match != "foo" {
|
||||
t.Errorf("Sanitize.System = %v", cfg.Sanitize.System)
|
||||
}
|
||||
if len(cfg.Sanitize.Body) != 1 || cfg.Sanitize.Body[0].Match != "baz" {
|
||||
t.Errorf("Sanitize.Body = %v", cfg.Sanitize.Body)
|
||||
}
|
||||
|
||||
// Logging
|
||||
if cfg.Logging.Level != "debug" {
|
||||
t.Errorf("Logging.Level = %q, want debug", cfg.Logging.Level)
|
||||
}
|
||||
if cfg.Logging.File != "/tmp/test.log" {
|
||||
t.Errorf("Logging.File = %q", cfg.Logging.File)
|
||||
}
|
||||
if cfg.Logging.MaxSizeMB != 50 {
|
||||
t.Errorf("Logging.MaxSizeMB = %d, want 50", cfg.Logging.MaxSizeMB)
|
||||
}
|
||||
if cfg.Logging.MaxBackups != 3 {
|
||||
t.Errorf("Logging.MaxBackups = %d, want 3", cfg.Logging.MaxBackups)
|
||||
}
|
||||
if cfg.Logging.MaxAgeDays != 7 {
|
||||
t.Errorf("Logging.MaxAgeDays = %d, want 7", cfg.Logging.MaxAgeDays)
|
||||
}
|
||||
if !cfg.Logging.Compress {
|
||||
t.Error("Logging.Compress = false, want true")
|
||||
}
|
||||
|
||||
// Telemetry
|
||||
if cfg.Telemetry.ServiceName != "my-proxy" {
|
||||
t.Errorf("Telemetry.ServiceName = %q, want my-proxy", cfg.Telemetry.ServiceName)
|
||||
}
|
||||
if cfg.Telemetry.Export.Endpoint != "http://localhost:4317" {
|
||||
t.Errorf("Export.Endpoint = %q", cfg.Telemetry.Export.Endpoint)
|
||||
}
|
||||
if !cfg.Telemetry.Export.Insecure {
|
||||
t.Error("Export.Insecure = false, want true")
|
||||
}
|
||||
if !cfg.Telemetry.Export.Enabled() {
|
||||
t.Error("Export.Enabled() = false, want true")
|
||||
}
|
||||
if cfg.Telemetry.Export.Headers["x-token"] != "abc" {
|
||||
t.Errorf("Export.Headers = %v", cfg.Telemetry.Export.Headers)
|
||||
}
|
||||
|
||||
// Embedded
|
||||
if !cfg.Telemetry.Embedded.Enabled {
|
||||
t.Error("Embedded.Enabled = false, want true")
|
||||
}
|
||||
if cfg.Telemetry.Embedded.Port != 9999 {
|
||||
t.Errorf("Embedded.Port = %d, want 9999", cfg.Telemetry.Embedded.Port)
|
||||
}
|
||||
if cfg.Telemetry.Embedded.PersesBinary != "/usr/bin/perses" {
|
||||
t.Errorf("Embedded.PersesBinary = %q", cfg.Telemetry.Embedded.PersesBinary)
|
||||
}
|
||||
if cfg.Telemetry.Embedded.VMBinary != "/usr/bin/vm" {
|
||||
t.Errorf("Embedded.VMBinary = %q", cfg.Telemetry.Embedded.VMBinary)
|
||||
}
|
||||
if cfg.Telemetry.Embedded.VMPort != 9428 {
|
||||
t.Errorf("Embedded.VMPort = %d, want 9428", cfg.Telemetry.Embedded.VMPort)
|
||||
}
|
||||
if cfg.Telemetry.Embedded.BinDir != "/opt/bin" {
|
||||
t.Errorf("Embedded.BinDir = %q", cfg.Telemetry.Embedded.BinDir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_Defaults(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
|
||||
// Minimal YAML — only api_keys
|
||||
if err := os.WriteFile(path, []byte("api_keys:\n - k1\n"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load returned error: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
got interface{}
|
||||
want interface{}
|
||||
}{
|
||||
{"Port", cfg.Port, 8080},
|
||||
{"Logging.Level", cfg.Logging.Level, "info"},
|
||||
{"Logging.MaxSizeMB", cfg.Logging.MaxSizeMB, 100},
|
||||
{"Logging.MaxBackups", cfg.Logging.MaxBackups, 5},
|
||||
{"Logging.MaxAgeDays", cfg.Logging.MaxAgeDays, 30},
|
||||
{"Telemetry.ServiceName", cfg.Telemetry.ServiceName, "anthropic-proxy"},
|
||||
{"Embedded.Port", cfg.Telemetry.Embedded.Port, 8080},
|
||||
{"Embedded.VMBinary", cfg.Telemetry.Embedded.VMBinary, "victoria-metrics"},
|
||||
{"Embedded.PersesBinary", cfg.Telemetry.Embedded.PersesBinary, "perses"},
|
||||
{"Embedded.VMPort", cfg.Telemetry.Embedded.VMPort, 8428},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.got != tt.want {
|
||||
t.Errorf("got %v, want %v", tt.got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_MissingFile(t *testing.T) {
|
||||
_, err := Load("/nonexistent/path/config.yaml")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing file, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "read config") {
|
||||
t.Errorf("error = %q, want it to contain 'read config'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_DeprecatedClaudeCredentials(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
|
||||
yaml := `
|
||||
api_keys:
|
||||
- k1
|
||||
claude_credentials: "/some/path"
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(yaml), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err := Load(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for deprecated claude_credentials, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no longer supported") {
|
||||
t.Errorf("error = %q, want it to contain 'no longer supported'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_EmptyClaudeCredentials(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
|
||||
// Empty string value should NOT trigger the deprecation error
|
||||
yaml := `
|
||||
api_keys:
|
||||
- k1
|
||||
claude_credentials: ""
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(yaml), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("empty claude_credentials should not error: %v", err)
|
||||
}
|
||||
if cfg.Port != 8080 {
|
||||
t.Errorf("Port = %d, want 8080", cfg.Port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_InvalidYAML(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
|
||||
// Truly invalid YAML that causes a parse error
|
||||
if err := os.WriteFile(path, []byte("port:\n - bad\n indent: broken\n"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err := Load(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid YAML, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "parse config") {
|
||||
t.Errorf("error = %q, want it to contain 'parse config'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExportConfig_Enabled(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
endpoint string
|
||||
want bool
|
||||
}{
|
||||
{"empty endpoint", "", false},
|
||||
{"set endpoint", "http://localhost:4317", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := ExportConfig{Endpoint: tt.endpoint}
|
||||
if got := e.Enabled(); got != tt.want {
|
||||
t.Errorf("Enabled() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package embedded
|
||||
|
||||
import (
|
||||
"embed"
|
||||
)
|
||||
|
||||
//go:embed dashboard/proxy.json
|
||||
var dashboardFS embed.FS
|
||||
|
||||
func DashboardJSON() ([]byte, error) {
|
||||
return dashboardFS.ReadFile("dashboard/proxy.json")
|
||||
}
|
||||
@@ -0,0 +1,450 @@
|
||||
{
|
||||
"kind": "Dashboard",
|
||||
"metadata": {
|
||||
"name": "proxy",
|
||||
"createdAt": "2026-04-14T19:47:48.013238204Z",
|
||||
"updatedAt": "2026-04-14T19:49:30.874125459Z",
|
||||
"version": 1,
|
||||
"project": "anthropic-proxy"
|
||||
},
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "Anthropic Proxy"
|
||||
},
|
||||
"datasources": {
|
||||
"vm": {
|
||||
"default": true,
|
||||
"plugin": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"spec": {
|
||||
"directUrl": "http://localhost:9428"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"panels": {
|
||||
"latency": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "Latency"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "TimeSeriesChart",
|
||||
"spec": {
|
||||
"legend": {
|
||||
"position": "bottom"
|
||||
},
|
||||
"yAxis": {
|
||||
"format": {
|
||||
"unit": "milliseconds"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "histogram_quantile(0.50, rate(proxy_request_duration_ms_milliseconds_bucket[5m]))",
|
||||
"seriesNameFormat": "p50"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "histogram_quantile(0.95, rate(proxy_request_duration_ms_milliseconds_bucket[5m]))",
|
||||
"seriesNameFormat": "p95"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "histogram_quantile(0.99, rate(proxy_request_duration_ms_milliseconds_bucket[5m]))",
|
||||
"seriesNameFormat": "p99"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"request_rate": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "Request Rate"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "TimeSeriesChart",
|
||||
"spec": {
|
||||
"legend": {
|
||||
"position": "bottom"
|
||||
}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "rate(proxy_request_count_total[5m])",
|
||||
"seriesNameFormat": "req/s"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"token_rate": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "Token Rate"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "TimeSeriesChart",
|
||||
"spec": {
|
||||
"legend": {
|
||||
"position": "bottom"
|
||||
}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "rate(proxy_tokens_input_total[5m]) * 60",
|
||||
"seriesNameFormat": "input/min"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "rate(proxy_tokens_output_total[5m]) * 60",
|
||||
"seriesNameFormat": "output/min"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"tokens_5h": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "5h Tokens"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "StatChart",
|
||||
"spec": {
|
||||
"calculation": "last",
|
||||
"format": {
|
||||
"unit": "decimal"
|
||||
},
|
||||
"sparkline": {}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "increase(proxy_tokens_output_total[3h])"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"tokens_7d": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "7d Tokens"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "StatChart",
|
||||
"spec": {
|
||||
"calculation": "last",
|
||||
"format": {
|
||||
"unit": "decimal"
|
||||
},
|
||||
"sparkline": {}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "increase(proxy_tokens_output_total[9h])"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"util_5h": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "5h Utilization"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "GaugeChart",
|
||||
"spec": {
|
||||
"calculation": "last",
|
||||
"format": {
|
||||
"unit": "percent"
|
||||
},
|
||||
"thresholds": {
|
||||
"steps": [
|
||||
{
|
||||
"color": "green",
|
||||
"value": 0
|
||||
},
|
||||
{
|
||||
"color": "orange",
|
||||
"value": 70
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 90
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "proxy_usage_utilization{window=\"5h\"}"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"util_7d": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "7d Utilization"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "GaugeChart",
|
||||
"spec": {
|
||||
"calculation": "last",
|
||||
"format": {
|
||||
"unit": "percent"
|
||||
},
|
||||
"thresholds": {
|
||||
"steps": [
|
||||
{
|
||||
"color": "green",
|
||||
"value": 0
|
||||
},
|
||||
{
|
||||
"color": "orange",
|
||||
"value": 70
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 90
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "proxy_usage_utilization{window=\"7d\"}"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"layouts": [
|
||||
{
|
||||
"kind": "Grid",
|
||||
"spec": {
|
||||
"display": {
|
||||
"title": "Utilization"
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"x": 0,
|
||||
"y": 0,
|
||||
"width": 6,
|
||||
"height": 5,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/util_5h"
|
||||
}
|
||||
},
|
||||
{
|
||||
"x": 6,
|
||||
"y": 0,
|
||||
"width": 6,
|
||||
"height": 5,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/util_7d"
|
||||
}
|
||||
},
|
||||
{
|
||||
"x": 12,
|
||||
"y": 0,
|
||||
"width": 6,
|
||||
"height": 5,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/tokens_5h"
|
||||
}
|
||||
},
|
||||
{
|
||||
"x": 18,
|
||||
"y": 0,
|
||||
"width": 6,
|
||||
"height": 5,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/tokens_7d"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"kind": "Grid",
|
||||
"spec": {
|
||||
"display": {
|
||||
"title": "Traffic"
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"x": 0,
|
||||
"y": 0,
|
||||
"width": 12,
|
||||
"height": 8,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/request_rate"
|
||||
}
|
||||
},
|
||||
{
|
||||
"x": 12,
|
||||
"y": 0,
|
||||
"width": 12,
|
||||
"height": 8,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/latency"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"kind": "Grid",
|
||||
"spec": {
|
||||
"display": {
|
||||
"title": "Tokens"
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"x": 0,
|
||||
"y": 0,
|
||||
"width": 24,
|
||||
"height": 8,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/token_rate"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"duration": "1h",
|
||||
"refreshInterval": "10s"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
package embedded
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
const cacheDir = ".cache/anthropic-proxy/bin"
|
||||
|
||||
var downloads = map[string]struct {
|
||||
urlTemplate string
|
||||
version string
|
||||
extractName string
|
||||
}{
|
||||
"victoria-metrics": {
|
||||
urlTemplate: "https://github.com/VictoriaMetrics/VictoriaMetrics/releases/download/v%s/victoria-metrics-%s-v%s.tar.gz",
|
||||
version: "1.118.0",
|
||||
extractName: "victoria-metrics-prod",
|
||||
},
|
||||
"perses": {
|
||||
urlTemplate: "https://github.com/perses/perses/releases/download/v%s/perses_%s_%s_%s.tar.gz",
|
||||
version: "0.53.1",
|
||||
},
|
||||
}
|
||||
|
||||
func ensureBinary(name, configPath, configBinDir string) (string, error) {
|
||||
if configPath != "" {
|
||||
if p, err := exec.LookPath(configPath); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
|
||||
if p, err := exec.LookPath(name); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
binDir := configBinDir
|
||||
if binDir == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get home dir: %w", err)
|
||||
}
|
||||
binDir = filepath.Join(home, cacheDir)
|
||||
}
|
||||
cachedPath := filepath.Join(binDir, name)
|
||||
|
||||
if _, err := os.Stat(cachedPath); err == nil {
|
||||
return cachedPath, nil
|
||||
}
|
||||
|
||||
log.Info().Str("binary", name).Msg("downloading binary (first run)")
|
||||
|
||||
if err := os.MkdirAll(binDir, 0o755); err != nil {
|
||||
return "", fmt.Errorf("create cache dir: %w", err)
|
||||
}
|
||||
|
||||
url, err := downloadURL(name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := extractAll(url, binDir); err != nil {
|
||||
return "", fmt.Errorf("download %s: %w", name, err)
|
||||
}
|
||||
|
||||
d := downloads[name]
|
||||
if d.extractName != "" {
|
||||
oldPath := filepath.Join(binDir, d.extractName)
|
||||
if _, err := os.Stat(oldPath); err == nil {
|
||||
os.Rename(oldPath, cachedPath)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := os.Stat(cachedPath); err != nil {
|
||||
return "", fmt.Errorf("binary %s not found after extraction", name)
|
||||
}
|
||||
|
||||
log.Info().Str("binary", name).Str("path", cachedPath).Msg("binary downloaded")
|
||||
return cachedPath, nil
|
||||
}
|
||||
|
||||
func downloadURL(name string) (string, error) {
|
||||
goarch := runtime.GOARCH
|
||||
goos := runtime.GOOS
|
||||
|
||||
d, ok := downloads[name]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unknown binary: %s", name)
|
||||
}
|
||||
|
||||
switch name {
|
||||
case "victoria-metrics":
|
||||
vmOS := fmt.Sprintf("%s-%s", goos, goarch)
|
||||
return fmt.Sprintf(d.urlTemplate, d.version, vmOS, d.version), nil
|
||||
case "perses":
|
||||
return fmt.Sprintf(d.urlTemplate, d.version, d.version, goos, goarch), nil
|
||||
}
|
||||
return "", fmt.Errorf("unknown binary: %s", name)
|
||||
}
|
||||
|
||||
func extractAll(url, destDir string) error {
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return fmt.Errorf("download failed: HTTP %d from %s", resp.StatusCode, url)
|
||||
}
|
||||
|
||||
gz, err := gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("gzip reader: %w", err)
|
||||
}
|
||||
defer gz.Close()
|
||||
|
||||
tr := tar.NewReader(gz)
|
||||
for {
|
||||
hdr, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("read tar: %w", err)
|
||||
}
|
||||
|
||||
target := filepath.Join(destDir, hdr.Name)
|
||||
switch hdr.Typeflag {
|
||||
case tar.TypeDir:
|
||||
os.MkdirAll(target, 0o755)
|
||||
case tar.TypeReg:
|
||||
os.MkdirAll(filepath.Dir(target), 0o755)
|
||||
mode := os.FileMode(hdr.Mode)
|
||||
if mode == 0 {
|
||||
mode = 0o644
|
||||
}
|
||||
out, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
io.Copy(out, tr)
|
||||
out.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package embedded
|
||||
|
||||
import "github.com/rs/zerolog/log"
|
||||
|
||||
// logWriter bridges subprocess stdout/stderr to zerolog.
|
||||
type logWriter struct {
|
||||
level string
|
||||
component string
|
||||
}
|
||||
|
||||
func (w *logWriter) Write(p []byte) (n int, err error) {
|
||||
msg := string(p)
|
||||
switch w.level {
|
||||
case "error":
|
||||
log.Error().Str("component", w.component).Msg(msg)
|
||||
default:
|
||||
log.Debug().Str("component", w.component).Msg(msg)
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
package embedded
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/config"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type Perses struct {
|
||||
cfg config.EmbeddedConfig
|
||||
proxyPort int
|
||||
cmd *exec.Cmd
|
||||
tmpDir string
|
||||
}
|
||||
|
||||
func NewPerses(cfg config.EmbeddedConfig, proxyPort int) *Perses {
|
||||
return &Perses{cfg: cfg, proxyPort: proxyPort}
|
||||
}
|
||||
|
||||
func (p *Perses) Start() error {
|
||||
bin, err := ensureBinary("perses", p.cfg.PersesBinary, p.cfg.BinDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("perses: %w", err)
|
||||
}
|
||||
|
||||
p.tmpDir, err = os.MkdirTemp("", "perses-*")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp dir: %w", err)
|
||||
}
|
||||
|
||||
if err := p.writeServerConfig(); err != nil {
|
||||
return fmt.Errorf("write server config: %w", err)
|
||||
}
|
||||
if err := p.writeDatasourceProvision(); err != nil {
|
||||
return fmt.Errorf("write datasource provision: %w", err)
|
||||
}
|
||||
if err := p.writeDashboardProvision(); err != nil {
|
||||
return fmt.Errorf("write dashboard provision: %w", err)
|
||||
}
|
||||
|
||||
p.cmd = exec.Command(bin,
|
||||
"--config", filepath.Join(p.tmpDir, "config.yaml"),
|
||||
"-web.listen-address", fmt.Sprintf(":%d", p.cfg.Port),
|
||||
)
|
||||
p.cmd.Dir = filepath.Dir(bin)
|
||||
p.cmd.Stdout = &logWriter{level: "info", component: "perses"}
|
||||
p.cmd.Stderr = &logWriter{level: "error", component: "perses"}
|
||||
|
||||
if err := p.cmd.Start(); err != nil {
|
||||
return fmt.Errorf("start perses: %w", err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("binary", bin).
|
||||
Int("port", p.cfg.Port).
|
||||
Str("config", p.tmpDir).
|
||||
Msg("perses started")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Perses) Stop() {
|
||||
if p.cmd != nil && p.cmd.Process != nil {
|
||||
_ = p.cmd.Process.Kill()
|
||||
_ = p.cmd.Wait()
|
||||
}
|
||||
if p.tmpDir != "" {
|
||||
_ = os.RemoveAll(p.tmpDir)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Perses) Running() bool {
|
||||
return p.cmd != nil && p.cmd.Process != nil && p.cmd.ProcessState == nil
|
||||
}
|
||||
|
||||
func (p *Perses) writeServerConfig() error {
|
||||
provisionDir := filepath.Join(p.tmpDir, "provisions")
|
||||
if err := os.MkdirAll(filepath.Join(provisionDir, "datasources"), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Join(provisionDir, "dashboards"), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg := fmt.Sprintf(`provisioning:
|
||||
interval: 1m
|
||||
folders:
|
||||
- %s
|
||||
database:
|
||||
file:
|
||||
folder: %s/data
|
||||
extension: json
|
||||
security:
|
||||
readonly: false
|
||||
enable_auth: false
|
||||
`, provisionDir, p.tmpDir)
|
||||
|
||||
return os.WriteFile(filepath.Join(p.tmpDir, "config.yaml"), []byte(cfg), 0o644)
|
||||
}
|
||||
|
||||
func (p *Perses) writeDatasourceProvision() error {
|
||||
ds := fmt.Sprintf(`kind: Datasource
|
||||
metadata:
|
||||
name: victoria-metrics
|
||||
project: anthropic-proxy
|
||||
spec:
|
||||
default: true
|
||||
plugin:
|
||||
kind: PrometheusDatasource
|
||||
spec:
|
||||
directUrl: http://localhost:%d
|
||||
`, p.cfg.VMPort)
|
||||
|
||||
return os.WriteFile(
|
||||
filepath.Join(p.tmpDir, "provisions", "datasources", "vm.yaml"),
|
||||
[]byte(ds), 0o644,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *Perses) writeDashboardProvision() error {
|
||||
dashData, err := DashboardJSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(
|
||||
filepath.Join(p.tmpDir, "provisions", "dashboards", "proxy.json"),
|
||||
dashData, 0o644,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package embedded
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/config"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type VM struct {
|
||||
cfg config.EmbeddedConfig
|
||||
proxyPort int
|
||||
cmd *exec.Cmd
|
||||
tmpDir string
|
||||
}
|
||||
|
||||
func NewVM(cfg config.EmbeddedConfig, proxyPort int) *VM {
|
||||
return &VM{cfg: cfg, proxyPort: proxyPort}
|
||||
}
|
||||
|
||||
func (v *VM) Start() error {
|
||||
bin, err := ensureBinary("victoria-metrics", v.cfg.VMBinary, v.cfg.BinDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("victoria-metrics: %w", err)
|
||||
}
|
||||
|
||||
v.tmpDir, err = os.MkdirTemp("", "vm-*")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp dir: %w", err)
|
||||
}
|
||||
|
||||
scrapeConfig := fmt.Sprintf(`global:
|
||||
scrape_interval: 15s
|
||||
scrape_configs:
|
||||
- job_name: anthropic-proxy
|
||||
static_configs:
|
||||
- targets:
|
||||
- localhost:%d
|
||||
`, v.proxyPort)
|
||||
|
||||
scrapePath := filepath.Join(v.tmpDir, "scrape.yaml")
|
||||
if err := os.WriteFile(scrapePath, []byte(scrapeConfig), 0o644); err != nil {
|
||||
return fmt.Errorf("write scrape config: %w", err)
|
||||
}
|
||||
|
||||
dataPath := filepath.Join(v.tmpDir, "data")
|
||||
if err := os.MkdirAll(dataPath, 0o755); err != nil {
|
||||
return fmt.Errorf("create data dir: %w", err)
|
||||
}
|
||||
|
||||
v.cmd = exec.Command(bin,
|
||||
"-storageDataPath", dataPath,
|
||||
"-retentionPeriod", "7d",
|
||||
"-httpListenAddr", fmt.Sprintf(":%d", v.cfg.VMPort),
|
||||
"-promscrape.config", scrapePath,
|
||||
)
|
||||
v.cmd.Stdout = &logWriter{level: "info", component: "victoria-metrics"}
|
||||
v.cmd.Stderr = &logWriter{level: "error", component: "victoria-metrics"}
|
||||
|
||||
if err := v.cmd.Start(); err != nil {
|
||||
return fmt.Errorf("start victoria-metrics: %w", err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("binary", bin).
|
||||
Int("port", v.cfg.VMPort).
|
||||
Int("scrape_target_port", v.proxyPort).
|
||||
Msg("victoria-metrics started")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *VM) Stop() {
|
||||
if v.cmd != nil && v.cmd.Process != nil {
|
||||
_ = v.cmd.Process.Kill()
|
||||
_ = v.cmd.Wait()
|
||||
}
|
||||
if v.tmpDir != "" {
|
||||
_ = os.RemoveAll(v.tmpDir)
|
||||
}
|
||||
}
|
||||
|
||||
func (v *VM) Running() bool {
|
||||
return v.cmd != nil && v.cmd.Process != nil && v.cmd.ProcessState == nil
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gopkg.in/lumberjack.v2"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/config"
|
||||
)
|
||||
|
||||
// Setup initializes the global zerolog logger.
|
||||
// - File set: JSON → lumberjack rotating file
|
||||
// - File empty + TTY: colored ConsoleWriter → stderr
|
||||
// - File empty + not TTY: JSON → stderr (for systemd journal)
|
||||
// Extra writers (e.g., OTLP log bridge) are added via io.MultiWriter so logs
|
||||
// are written to both the primary destination and any extra writers.
|
||||
func Setup(cfg config.LoggingConfig, extraWriters ...io.Writer) zerolog.Logger {
|
||||
// Parse log level
|
||||
level, err := zerolog.ParseLevel(cfg.Level)
|
||||
if err != nil || cfg.Level == "" {
|
||||
level = zerolog.InfoLevel
|
||||
}
|
||||
zerolog.SetGlobalLevel(level)
|
||||
zerolog.TimeFieldFormat = time.RFC3339
|
||||
|
||||
var logger zerolog.Logger
|
||||
|
||||
if cfg.File != "" {
|
||||
// Production: JSON to rotating file
|
||||
jack := &lumberjack.Logger{
|
||||
Filename: cfg.File,
|
||||
MaxSize: cfg.MaxSizeMB,
|
||||
MaxBackups: cfg.MaxBackups,
|
||||
MaxAge: cfg.MaxAgeDays,
|
||||
Compress: cfg.Compress,
|
||||
}
|
||||
var w io.Writer = jack
|
||||
if len(extraWriters) > 0 {
|
||||
w = io.MultiWriter(append([]io.Writer{jack}, extraWriters...)...)
|
||||
}
|
||||
logger = zerolog.New(w).With().Timestamp().Caller().Logger()
|
||||
} else {
|
||||
fi, err := os.Stderr.Stat()
|
||||
isTTY := err == nil && (fi.Mode()&os.ModeCharDevice) != 0
|
||||
if isTTY {
|
||||
// Dev mode: colored console (extra writers get JSON, console gets pretty)
|
||||
cw := zerolog.ConsoleWriter{
|
||||
Out: os.Stderr,
|
||||
TimeFormat: time.RFC3339,
|
||||
}
|
||||
var w io.Writer = cw
|
||||
if len(extraWriters) > 0 {
|
||||
w = io.MultiWriter(append([]io.Writer{cw}, extraWriters...)...)
|
||||
}
|
||||
logger = zerolog.New(w).With().Timestamp().Caller().Logger()
|
||||
} else {
|
||||
// Systemd journal: JSON to stderr
|
||||
var w io.Writer = os.Stderr
|
||||
if len(extraWriters) > 0 {
|
||||
w = io.MultiWriter(append([]io.Writer{os.Stderr}, extraWriters...)...)
|
||||
}
|
||||
logger = zerolog.New(w).With().Timestamp().Caller().Logger()
|
||||
}
|
||||
}
|
||||
|
||||
// Set global
|
||||
log.Logger = logger
|
||||
return logger
|
||||
}
|
||||
|
||||
// GinRequestLogger returns a Gin middleware that logs every request with zerolog.
|
||||
// Logs AFTER the request completes.
|
||||
// Level: Info for 2xx, Warn for 4xx, Error for 5xx
|
||||
func GinRequestLogger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
method := c.Request.Method
|
||||
|
||||
c.Next()
|
||||
|
||||
status := c.Writer.Status()
|
||||
latencyMs := float64(time.Since(start).Microseconds()) / 1000.0
|
||||
clientIP := c.ClientIP()
|
||||
requestID := c.Writer.Header().Get("X-Request-Id")
|
||||
if requestID == "" {
|
||||
requestID = c.GetHeader("x-client-request-id")
|
||||
}
|
||||
|
||||
evt := log.Logger.WithLevel(statusLevel(status)).
|
||||
Str("method", method).
|
||||
Str("path", path).
|
||||
Int("status", status).
|
||||
Float64("latency_ms", latencyMs).
|
||||
Str("client_ip", clientIP)
|
||||
|
||||
if requestID != "" {
|
||||
evt = evt.Str("request_id", requestID)
|
||||
}
|
||||
|
||||
evt.Msg("request")
|
||||
}
|
||||
}
|
||||
|
||||
func statusLevel(status int) zerolog.Level {
|
||||
switch {
|
||||
case status >= 500:
|
||||
return zerolog.ErrorLevel
|
||||
case status >= 400:
|
||||
return zerolog.WarnLevel
|
||||
default:
|
||||
return zerolog.InfoLevel
|
||||
}
|
||||
}
|
||||
|
||||
// FromContext returns the zerolog logger from context, or the global logger.
|
||||
func FromContext(ctx context.Context) *zerolog.Logger {
|
||||
l := zerolog.Ctx(ctx)
|
||||
if l.GetLevel() == zerolog.Disabled {
|
||||
return &log.Logger
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
// RedactHeaders serializes HTTP headers to a JSON string,
|
||||
// replacing Authorization and x-api-key values with "***".
|
||||
func RedactHeaders(h http.Header) string {
|
||||
redacted := make(map[string]string, len(h))
|
||||
for k, v := range h {
|
||||
key := strings.ToLower(k)
|
||||
if key == "authorization" || key == "x-api-key" {
|
||||
redacted[k] = "***"
|
||||
} else {
|
||||
redacted[k] = strings.Join(v, ", ")
|
||||
}
|
||||
}
|
||||
b, _ := json.Marshal(redacted)
|
||||
return string(b)
|
||||
}
|
||||
@@ -0,0 +1,232 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/config"
|
||||
)
|
||||
|
||||
func TestRedactHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
headers http.Header
|
||||
check func(t *testing.T, result string)
|
||||
}{
|
||||
{
|
||||
name: "redacts Authorization",
|
||||
headers: http.Header{
|
||||
"Authorization": []string{"Bearer secret-token"},
|
||||
},
|
||||
check: func(t *testing.T, result string) {
|
||||
var m map[string]string
|
||||
if err := json.Unmarshal([]byte(result), &m); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if m["Authorization"] != "***" {
|
||||
t.Errorf("Authorization = %q, want ***", m["Authorization"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "redacts x-api-key",
|
||||
headers: http.Header{
|
||||
"X-Api-Key": []string{"sk-ant-secret"},
|
||||
},
|
||||
check: func(t *testing.T, result string) {
|
||||
var m map[string]string
|
||||
if err := json.Unmarshal([]byte(result), &m); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if m["X-Api-Key"] != "***" {
|
||||
t.Errorf("X-Api-Key = %q, want ***", m["X-Api-Key"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "preserves other headers",
|
||||
headers: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"Accept": []string{"text/html", "application/json"},
|
||||
},
|
||||
check: func(t *testing.T, result string) {
|
||||
var m map[string]string
|
||||
if err := json.Unmarshal([]byte(result), &m); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if m["Content-Type"] != "application/json" {
|
||||
t.Errorf("Content-Type = %q, want application/json", m["Content-Type"])
|
||||
}
|
||||
if m["Accept"] != "text/html, application/json" {
|
||||
t.Errorf("Accept = %q, want 'text/html, application/json'", m["Accept"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "case-insensitive redaction",
|
||||
headers: http.Header{
|
||||
"authorization": []string{"Bearer token"},
|
||||
"X-API-KEY": []string{"key123"},
|
||||
},
|
||||
check: func(t *testing.T, result string) {
|
||||
var m map[string]string
|
||||
if err := json.Unmarshal([]byte(result), &m); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
// http.Header canonicalizes keys, but RedactHeaders lowercases for comparison
|
||||
for _, v := range m {
|
||||
if v != "***" {
|
||||
t.Errorf("expected all values to be ***, got %q", v)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty headers",
|
||||
headers: http.Header{},
|
||||
check: func(t *testing.T, result string) {
|
||||
if result != "{}" {
|
||||
t.Errorf("result = %q, want {}", result)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mixed sensitive and non-sensitive",
|
||||
headers: http.Header{
|
||||
"Authorization": []string{"Bearer tok"},
|
||||
"X-Api-Key": []string{"key"},
|
||||
"Content-Type": []string{"application/json"},
|
||||
"X-Request-Id": []string{"abc123"},
|
||||
},
|
||||
check: func(t *testing.T, result string) {
|
||||
var m map[string]string
|
||||
if err := json.Unmarshal([]byte(result), &m); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if m["Authorization"] != "***" {
|
||||
t.Errorf("Authorization = %q, want ***", m["Authorization"])
|
||||
}
|
||||
if m["X-Api-Key"] != "***" {
|
||||
t.Errorf("X-Api-Key = %q, want ***", m["X-Api-Key"])
|
||||
}
|
||||
if m["Content-Type"] != "application/json" {
|
||||
t.Errorf("Content-Type = %q", m["Content-Type"])
|
||||
}
|
||||
if m["X-Request-Id"] != "abc123" {
|
||||
t.Errorf("X-Request-Id = %q", m["X-Request-Id"])
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := RedactHeaders(tt.headers)
|
||||
// Result should be valid JSON
|
||||
if !json.Valid([]byte(result)) {
|
||||
t.Fatalf("result is not valid JSON: %q", result)
|
||||
}
|
||||
tt.check(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactHeaders_ReturnsJSON(t *testing.T) {
|
||||
h := http.Header{"Foo": []string{"bar"}}
|
||||
result := RedactHeaders(h)
|
||||
if !strings.HasPrefix(result, "{") || !strings.HasSuffix(result, "}") {
|
||||
t.Errorf("result not JSON object: %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusLevel(t *testing.T) {
|
||||
tests := []struct {
|
||||
status int
|
||||
want zerolog.Level
|
||||
}{
|
||||
{200, zerolog.InfoLevel},
|
||||
{201, zerolog.InfoLevel},
|
||||
{204, zerolog.InfoLevel},
|
||||
{301, zerolog.InfoLevel},
|
||||
{399, zerolog.InfoLevel},
|
||||
{400, zerolog.WarnLevel},
|
||||
{401, zerolog.WarnLevel},
|
||||
{403, zerolog.WarnLevel},
|
||||
{404, zerolog.WarnLevel},
|
||||
{429, zerolog.WarnLevel},
|
||||
{499, zerolog.WarnLevel},
|
||||
{500, zerolog.ErrorLevel},
|
||||
{502, zerolog.ErrorLevel},
|
||||
{503, zerolog.ErrorLevel},
|
||||
{599, zerolog.ErrorLevel},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := statusLevel(tt.status)
|
||||
if got != tt.want {
|
||||
t.Errorf("statusLevel(%d) = %v, want %v", tt.status, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetup_WithFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
logFile := filepath.Join(dir, "test.log")
|
||||
|
||||
logger := Setup(config.LoggingConfig{
|
||||
Level: "debug",
|
||||
File: logFile,
|
||||
MaxSizeMB: 10,
|
||||
MaxBackups: 1,
|
||||
MaxAgeDays: 1,
|
||||
})
|
||||
|
||||
// Verify logger works (no panic)
|
||||
logger.Info().Msg("test message")
|
||||
}
|
||||
|
||||
func TestSetup_WithoutFile(t *testing.T) {
|
||||
// File empty — should use console or stderr mode depending on TTY
|
||||
logger := Setup(config.LoggingConfig{
|
||||
Level: "warn",
|
||||
})
|
||||
|
||||
// Verify logger works (no panic)
|
||||
logger.Warn().Msg("test warning")
|
||||
}
|
||||
|
||||
func TestSetup_DefaultLevel(t *testing.T) {
|
||||
// Empty level should default to info
|
||||
logger := Setup(config.LoggingConfig{})
|
||||
_ = logger // verify no panic
|
||||
}
|
||||
|
||||
func TestSetup_InvalidLevel(t *testing.T) {
|
||||
// Invalid level should default to info
|
||||
logger := Setup(config.LoggingConfig{Level: "not-a-level"})
|
||||
_ = logger // verify no panic
|
||||
}
|
||||
|
||||
func TestFromContext_NoLogger(t *testing.T) {
|
||||
// Background context has no zerolog logger — should return global
|
||||
ctx := context.Background()
|
||||
l := FromContext(ctx)
|
||||
if l == nil {
|
||||
t.Fatal("FromContext returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromContext_WithLogger(t *testing.T) {
|
||||
logger := zerolog.Nop()
|
||||
ctx := logger.WithContext(context.Background())
|
||||
l := FromContext(ctx)
|
||||
if l == nil {
|
||||
t.Fatal("FromContext returned nil")
|
||||
}
|
||||
}
|
||||
@@ -11,9 +11,13 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// fingerprintSalt is the fixed salt used by Claude Code for billing header
|
||||
// fingerprint computation. Extracted from the Claude Code CLI source.
|
||||
const fingerprintSalt = "59cf53e54c78"
|
||||
|
||||
func computeFingerprint(firstUserMessage string, version string) string {
|
||||
// UTF-16 character indices sampled from the first user message, matching
|
||||
// the Claude Code CLI's fingerprinting algorithm.
|
||||
indices := []int{4, 7, 20}
|
||||
runes := utf16.Encode([]rune(firstUserMessage))
|
||||
var chars string
|
||||
|
||||
@@ -0,0 +1,323 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFingerprintSaltConstant(t *testing.T) {
|
||||
if fingerprintSalt != "59cf53e54c78" {
|
||||
t.Errorf("fingerprintSalt = %q, want %q", fingerprintSalt, "59cf53e54c78")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_Deterministic(t *testing.T) {
|
||||
a := computeFingerprint("hello world test message", "1.0.0")
|
||||
b := computeFingerprint("hello world test message", "1.0.0")
|
||||
if a != b {
|
||||
t.Errorf("fingerprint not deterministic: %q != %q", a, b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_Length(t *testing.T) {
|
||||
fp := computeFingerprint("some message here", "2.0.0")
|
||||
if len(fp) != 3 {
|
||||
t.Errorf("fingerprint length = %d, want 3", len(fp))
|
||||
}
|
||||
// Must be valid hex
|
||||
if _, err := hex.DecodeString(fp + "0"); err != nil { // pad to even length for decode
|
||||
// Check each char is hex individually
|
||||
for _, c := range fp {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
||||
t.Errorf("fingerprint %q contains non-hex char %c", fp, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_DifferentVersions(t *testing.T) {
|
||||
a := computeFingerprint("same message", "1.0.0")
|
||||
b := computeFingerprint("same message", "2.0.0")
|
||||
if a == b {
|
||||
t.Errorf("different versions should (almost certainly) produce different fingerprints")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_ShortMessage(t *testing.T) {
|
||||
// "hi" has only 2 chars, indices [4,7,20] all out of range → chars = "000"
|
||||
fp := computeFingerprint("hi", "1.0.0")
|
||||
if len(fp) != 3 {
|
||||
t.Errorf("short message fingerprint length = %d, want 3", len(fp))
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_EmptyMessage(t *testing.T) {
|
||||
// Empty → all indices out of range → chars = "000"
|
||||
fp := computeFingerprint("", "1.0.0")
|
||||
if len(fp) != 3 {
|
||||
t.Errorf("empty message fingerprint length = %d, want 3", len(fp))
|
||||
}
|
||||
// Empty and short message with same version should produce same fingerprint
|
||||
// since both result in chars = "000"
|
||||
fpShort := computeFingerprint("hi", "1.0.0")
|
||||
if fp != fpShort {
|
||||
t.Errorf("empty and 'hi' should produce same fingerprint (both use '000'), got %q vs %q", fp, fpShort)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_Unicode(t *testing.T) {
|
||||
// Emoji: 🎉 is U+1F389, encoded as UTF-16 surrogate pair [0xD83C, 0xDF89]
|
||||
// So "abcd🎉fg" in UTF-16 is [a, b, c, d, 0xD83C, 0xDF89, f, g] = 8 uint16 values
|
||||
// indices [4,7,20]: runes[4]=0xD83C, runes[7]='g', runes[20]=out of range
|
||||
fp := computeFingerprint("abcd🎉fg", "1.0.0")
|
||||
if len(fp) != 3 {
|
||||
t.Errorf("unicode fingerprint length = %d, want 3", len(fp))
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_CharExtraction(t *testing.T) {
|
||||
// "Hello, World!" UTF-16: [H,e,l,l,o,',', ,W,o,r,l,d,!]
|
||||
// indices [4,7,20]: runes[4]='o', runes[7]='W', runes[20]=out of range → "0"
|
||||
// So chars should be "oW0"
|
||||
// Verify by comparing to a message where we know the expected extracted chars
|
||||
// Two messages that extract same chars at indices should produce same fingerprint
|
||||
// "xxxxoxxWxxxxxxxxxxxx" → index 4='o', 7='W', 20=out of range → "oW0" (20 chars, index 20 out of range)
|
||||
fp1 := computeFingerprint("Hello, World!", "1.0.0")
|
||||
fp2 := computeFingerprint("xxxxoxxWxxxxxxxxxxxx", "1.0.0")
|
||||
if fp1 != fp2 {
|
||||
t.Errorf("messages with same chars at indices [4,7,20] should produce same fingerprint, got %q vs %q", fp1, fp2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_IndexBoundary(t *testing.T) {
|
||||
// Message with exactly 21 chars → index 20 is valid
|
||||
msg21 := "abcdefghijklmnopqrstu" // 21 chars
|
||||
fp21 := computeFingerprint(msg21, "1.0.0")
|
||||
|
||||
// Message with exactly 20 chars → index 20 is out of range → "0"
|
||||
msg20 := "abcdefghijklmnopqrst" // 20 chars
|
||||
fp20 := computeFingerprint(msg20, "1.0.0")
|
||||
|
||||
// They should differ because index 20 produces different chars
|
||||
if fp21 == fp20 {
|
||||
t.Errorf("boundary test: 21-char and 20-char messages should differ at index 20")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFirstUserMessage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple string content",
|
||||
body: `{"messages":[{"role":"user","content":"hello world"}]}`,
|
||||
expected: "hello world",
|
||||
},
|
||||
{
|
||||
name: "array content with text block",
|
||||
body: `{"messages":[{"role":"user","content":[{"type":"text","text":"from array"}]}]}`,
|
||||
expected: "from array",
|
||||
},
|
||||
{
|
||||
name: "no user messages",
|
||||
body: `{"messages":[{"role":"assistant","content":"I am assistant"}]}`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "assistant only messages",
|
||||
body: `{"messages":[{"role":"assistant","content":"a1"},{"role":"assistant","content":"a2"}]}`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "user with non-text block first then text",
|
||||
body: `{"messages":[{"role":"user","content":[{"type":"image","source":"x"},{"type":"text","text":"the text"}]}]}`,
|
||||
expected: "the text",
|
||||
},
|
||||
{
|
||||
name: "user with only non-text blocks",
|
||||
body: `{"messages":[{"role":"user","content":[{"type":"image","source":"x"}]}]}`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "no messages field",
|
||||
body: `{"model":"claude-sonnet-4-6"}`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "messages not array",
|
||||
body: `{"messages":"not array"}`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "empty messages array",
|
||||
body: `{"messages":[]}`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "first user message used even if multiple exist",
|
||||
body: `{"messages":[{"role":"user","content":"first"},{"role":"user","content":"second"}]}`,
|
||||
expected: "first",
|
||||
},
|
||||
{
|
||||
name: "assistant before user",
|
||||
body: `{"messages":[{"role":"assistant","content":"assistant msg"},{"role":"user","content":"user msg"}]}`,
|
||||
expected: "user msg",
|
||||
},
|
||||
{
|
||||
name: "user with array content - first text block used",
|
||||
body: `{"messages":[{"role":"user","content":[{"type":"text","text":"first text"},{"type":"text","text":"second text"}]}]}`,
|
||||
expected: "first text",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractFirstUserMessage([]byte(tt.body))
|
||||
if result != tt.expected {
|
||||
t.Errorf("got %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFirstUserMessage_BreaksAfterFirstUser(t *testing.T) {
|
||||
// The function should break after finding the first user message,
|
||||
// even if it didn't extract text (e.g. user with only image blocks)
|
||||
body := `{"messages":[{"role":"user","content":[{"type":"image","source":"x"}]},{"role":"user","content":"second user"}]}`
|
||||
result := extractFirstUserMessage([]byte(body))
|
||||
// First user has no text blocks, function breaks, returns ""
|
||||
if result != "" {
|
||||
t.Errorf("should return empty when first user has no text, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBillingHeader(t *testing.T) {
|
||||
body := []byte(`{"messages":[{"role":"user","content":"test message"}]}`)
|
||||
version := "1.2.3"
|
||||
header := buildBillingHeader(body, version)
|
||||
|
||||
// Check format
|
||||
if !strings.HasPrefix(header, "x-anthropic-billing-header: cc_version=1.2.3.") {
|
||||
t.Errorf("header should start with 'x-anthropic-billing-header: cc_version=1.2.3.', got %q", header)
|
||||
}
|
||||
if !strings.Contains(header, "; cc_entrypoint=cli; cch=00000;") {
|
||||
t.Errorf("header should contain '; cc_entrypoint=cli; cch=00000;', got %q", header)
|
||||
}
|
||||
|
||||
// Verify the fingerprint part is 3 chars
|
||||
// Format: "x-anthropic-billing-header: cc_version=1.2.3.XXX; cc_entrypoint=cli; cch=00000;"
|
||||
parts := strings.Split(header, "cc_version=")
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("unexpected header format: %q", header)
|
||||
}
|
||||
versionFP := strings.Split(parts[1], ";")[0]
|
||||
if !strings.HasPrefix(versionFP, "1.2.3.") {
|
||||
t.Errorf("version+fingerprint should start with '1.2.3.', got %q", versionFP)
|
||||
}
|
||||
fp := strings.TrimPrefix(versionFP, "1.2.3.")
|
||||
if len(fp) != 3 {
|
||||
t.Errorf("fingerprint should be 3 chars, got %q (len %d)", fp, len(fp))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBillingHeader_EmptyMessages(t *testing.T) {
|
||||
body := []byte(`{"messages":[]}`)
|
||||
version := "1.0.0"
|
||||
header := buildBillingHeader(body, version)
|
||||
if !strings.HasPrefix(header, "x-anthropic-billing-header: cc_version=") {
|
||||
t.Errorf("header format wrong: %q", header)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectBillingHeader_NoExistingSystem(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := injectBillingHeader(body, "1.0.0")
|
||||
resultStr := string(result)
|
||||
|
||||
// Should have system field now
|
||||
if !strings.Contains(resultStr, `"system"`) {
|
||||
t.Errorf("should inject system field, got %s", resultStr)
|
||||
}
|
||||
// System should be an array with one billing block
|
||||
if !strings.Contains(resultStr, "x-anthropic-billing-header") {
|
||||
t.Errorf("should contain billing header text, got %s", resultStr)
|
||||
}
|
||||
if !strings.Contains(resultStr, `"type":"text"`) {
|
||||
t.Errorf("billing block should have type text, got %s", resultStr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectBillingHeader_ExistingSystemArray(t *testing.T) {
|
||||
body := []byte(`{"system":[{"type":"text","text":"existing prompt"}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := injectBillingHeader(body, "1.0.0")
|
||||
resultStr := string(result)
|
||||
|
||||
// Should contain both billing header and existing prompt
|
||||
if !strings.Contains(resultStr, "x-anthropic-billing-header") {
|
||||
t.Errorf("should contain billing header, got %s", resultStr)
|
||||
}
|
||||
if !strings.Contains(resultStr, "existing prompt") {
|
||||
t.Errorf("should preserve existing prompt, got %s", resultStr)
|
||||
}
|
||||
|
||||
// Billing block should be FIRST (prepended)
|
||||
billingIdx := strings.Index(resultStr, "x-anthropic-billing-header")
|
||||
existingIdx := strings.Index(resultStr, "existing prompt")
|
||||
if billingIdx > existingIdx {
|
||||
t.Errorf("billing block should come before existing prompt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectBillingHeader_ExistingSystemString(t *testing.T) {
|
||||
body := []byte(`{"system":"You are a helpful assistant","messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := injectBillingHeader(body, "1.0.0")
|
||||
resultStr := string(result)
|
||||
|
||||
// Should convert to array with billing block first, then original text
|
||||
if !strings.Contains(resultStr, "x-anthropic-billing-header") {
|
||||
t.Errorf("should contain billing header, got %s", resultStr)
|
||||
}
|
||||
if !strings.Contains(resultStr, "You are a helpful assistant") {
|
||||
t.Errorf("should preserve original system string, got %s", resultStr)
|
||||
}
|
||||
|
||||
// Billing should come first
|
||||
billingIdx := strings.Index(resultStr, "x-anthropic-billing-header")
|
||||
origIdx := strings.Index(resultStr, "You are a helpful assistant")
|
||||
if billingIdx > origIdx {
|
||||
t.Errorf("billing block should come before original system text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectBillingHeader_PreservesOtherFields(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}],"max_tokens":1024}`)
|
||||
result := injectBillingHeader(body, "1.0.0")
|
||||
resultStr := string(result)
|
||||
|
||||
if !strings.Contains(resultStr, `"model":"claude-sonnet-4-6"`) {
|
||||
t.Errorf("should preserve model field, got %s", resultStr)
|
||||
}
|
||||
if !strings.Contains(resultStr, `"max_tokens":1024`) {
|
||||
t.Errorf("should preserve max_tokens field, got %s", resultStr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectBillingHeader_BillingBlockFormat(t *testing.T) {
|
||||
body := []byte(`{"messages":[{"role":"user","content":"test"}]}`)
|
||||
result := injectBillingHeader(body, "2.5.0")
|
||||
resultStr := string(result)
|
||||
|
||||
// Verify the billing block contains the correct version
|
||||
if !strings.Contains(resultStr, "cc_version=2.5.0.") {
|
||||
t.Errorf("billing block should contain cc_version=2.5.0., got %s", resultStr)
|
||||
}
|
||||
if !strings.Contains(resultStr, "cc_entrypoint=cli") {
|
||||
t.Errorf("billing block should contain cc_entrypoint=cli, got %s", resultStr)
|
||||
}
|
||||
if !strings.Contains(resultStr, "cch=00000") {
|
||||
t.Errorf("billing block should contain cch=00000, got %s", resultStr)
|
||||
}
|
||||
}
|
||||
+184
-15
@@ -2,17 +2,33 @@ package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/auth"
|
||||
"github.com/fujin/anthropic-proxy/internal/logging"
|
||||
"github.com/fujin/anthropic-proxy/internal/ratelimit"
|
||||
"github.com/fujin/anthropic-proxy/internal/telemetry"
|
||||
)
|
||||
|
||||
func HandleMessages(pool *auth.Pool, profile *SniffedProfile, getSanitizer func() *Sanitizer) gin.HandlerFunc {
|
||||
// requestInfo bundles common request context passed to logging/telemetry helpers.
|
||||
type requestInfo struct {
|
||||
model string
|
||||
stream bool
|
||||
cred *auth.Credential
|
||||
body []byte
|
||||
originalBody []byte
|
||||
}
|
||||
|
||||
func HandleMessages(pool *auth.Pool, profile *SniffedProfile, getSanitizer func() *Sanitizer, tracker *ratelimit.Tracker) gin.HandlerFunc {
|
||||
upstream := NewUpstreamClient(profile)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
@@ -22,7 +38,15 @@ func HandleMessages(pool *auth.Pool, profile *SniffedProfile, getSanitizer func(
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("incoming: %s %s (%d bytes) model=%s", c.Request.Method, c.Request.URL.Path, len(body), gjson.GetBytes(body, "model").String())
|
||||
originalBody := make([]byte, len(body))
|
||||
copy(originalBody, body)
|
||||
|
||||
log.Info().
|
||||
Str("method", c.Request.Method).
|
||||
Str("path", c.Request.URL.Path).
|
||||
Int("body_size", len(body)).
|
||||
Str("model", gjson.GetBytes(body, "model").String()).
|
||||
Msg("incoming request")
|
||||
|
||||
san := getSanitizer()
|
||||
body = san.SanitizeRequest(body)
|
||||
@@ -36,27 +60,56 @@ func HandleMessages(pool *auth.Pool, profile *SniffedProfile, getSanitizer func(
|
||||
isStream := gjson.GetBytes(body, "stream").Bool()
|
||||
|
||||
if isStream {
|
||||
handleStream(c, upstream, san, pool, cred, body)
|
||||
handleStream(c, upstream, san, pool, cred, body, originalBody, tracker)
|
||||
} else {
|
||||
handleNonStream(c, upstream, san, pool, cred, body)
|
||||
handleNonStream(c, upstream, san, pool, cred, body, originalBody, tracker)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleNonStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte) {
|
||||
respBody, headers, statusCode, err := upstream.Execute(c.Request.Context(), cred, body)
|
||||
func handleNonStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte, originalBody []byte, tracker *ratelimit.Tracker) {
|
||||
startTime := time.Now()
|
||||
model := gjson.GetBytes(body, "model").String()
|
||||
ctx := c.Request.Context()
|
||||
ri := requestInfo{model: model, stream: false, cred: cred, body: body, originalBody: originalBody}
|
||||
|
||||
telemetry.RequestBodySize.Record(ctx, int64(len(body)),
|
||||
metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", false)))
|
||||
|
||||
respBody, headers, statusCode, err := upstream.Execute(ctx, cred, body)
|
||||
latencyMs := float64(time.Since(startTime).Milliseconds())
|
||||
|
||||
if err != nil {
|
||||
log.Printf("upstream error for %s: %v", cred.Email, err)
|
||||
recordConnectionError(ctx, err, ri, latencyMs)
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "upstream request failed"})
|
||||
return
|
||||
}
|
||||
|
||||
recordRequestMetrics(ctx, ri, statusCode, latencyMs)
|
||||
|
||||
if statusCode >= 400 {
|
||||
pool.MarkFailure(cred, statusCode)
|
||||
log.Printf("upstream %d for %s: %s", statusCode, cred.Email, string(respBody))
|
||||
telemetry.CredentialCooldowns.Add(ctx, 1,
|
||||
metric.WithAttributes(attribute.Int("status_code", statusCode)))
|
||||
recordUpstreamError(ctx, statusCode, respBody, headers.Get("X-Request-Id"), latencyMs, ri, c.Request.Header)
|
||||
} else {
|
||||
pool.MarkSuccess(cred)
|
||||
respBody = san.DesanitizeResponse(respBody)
|
||||
|
||||
inputTokens := gjson.GetBytes(respBody, "usage.input_tokens").Int()
|
||||
outputTokens := gjson.GetBytes(respBody, "usage.output_tokens").Int()
|
||||
recordTokenUsage(ctx, model, cred, inputTokens, outputTokens)
|
||||
if tracker != nil {
|
||||
tracker.UpdateFromHeaders(headers)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Int("status", statusCode).
|
||||
Float64("latency_ms", latencyMs).
|
||||
Str("model", model).
|
||||
Int64("input_tokens", inputTokens).
|
||||
Int64("output_tokens", outputTokens).
|
||||
Msg("request completed")
|
||||
}
|
||||
|
||||
for _, h := range []string{"Content-Type", "X-Request-Id"} {
|
||||
@@ -68,10 +121,20 @@ func handleNonStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, p
|
||||
c.Data(statusCode, headers.Get("Content-Type"), respBody)
|
||||
}
|
||||
|
||||
func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte) {
|
||||
resp, err := upstream.ExecuteStream(c.Request.Context(), cred, body)
|
||||
func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte, originalBody []byte, tracker *ratelimit.Tracker) {
|
||||
startTime := time.Now()
|
||||
model := gjson.GetBytes(body, "model").String()
|
||||
ctx := c.Request.Context()
|
||||
ri := requestInfo{model: model, stream: true, cred: cred, body: body, originalBody: originalBody}
|
||||
|
||||
telemetry.StreamRequests.Add(ctx, 1, metric.WithAttributes(attribute.String("model", model)))
|
||||
telemetry.RequestBodySize.Record(ctx, int64(len(body)),
|
||||
metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", true)))
|
||||
|
||||
resp, err := upstream.ExecuteStream(ctx, cred, body)
|
||||
if err != nil {
|
||||
log.Printf("upstream stream error for %s: %v", cred.Email, err)
|
||||
latencyMs := float64(time.Since(startTime).Milliseconds())
|
||||
recordConnectionError(ctx, err, ri, latencyMs)
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "upstream stream request failed"})
|
||||
return
|
||||
}
|
||||
@@ -79,8 +142,13 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
pool.MarkFailure(cred, resp.StatusCode)
|
||||
telemetry.CredentialCooldowns.Add(ctx, 1,
|
||||
metric.WithAttributes(attribute.Int("status_code", resp.StatusCode)))
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
log.Printf("upstream stream %d for %s: %s", resp.StatusCode, cred.Email, string(respBody))
|
||||
latencyMs := float64(time.Since(startTime).Milliseconds())
|
||||
recordRequestMetrics(ctx, ri, resp.StatusCode, latencyMs)
|
||||
recordUpstreamError(ctx, resp.StatusCode, respBody, resp.Header.Get("X-Request-Id"), latencyMs, ri, c.Request.Header)
|
||||
|
||||
c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), respBody)
|
||||
return
|
||||
}
|
||||
@@ -94,20 +162,121 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool
|
||||
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
log.Printf("response writer does not support flushing")
|
||||
log.Error().Msg("response writer does not support flushing")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "streaming not supported"})
|
||||
return
|
||||
}
|
||||
|
||||
var inputTokens, outputTokens int64
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := san.DesanitizeStreamEvent(scanner.Text())
|
||||
c.Writer.WriteString(line + "\n")
|
||||
flusher.Flush()
|
||||
|
||||
if len(line) > 5 && line[:5] == "data:" {
|
||||
data := line[5:]
|
||||
eventType := gjson.Get(data, "type").String()
|
||||
switch eventType {
|
||||
case "message_start":
|
||||
inputTokens = gjson.Get(data, "message.usage.input_tokens").Int()
|
||||
case "message_delta":
|
||||
outputTokens = gjson.Get(data, "usage.output_tokens").Int()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
latencyMs := float64(time.Since(startTime).Milliseconds())
|
||||
recordRequestMetrics(ctx, ri, http.StatusOK, latencyMs)
|
||||
|
||||
if inputTokens > 0 || outputTokens > 0 {
|
||||
recordTokenUsage(ctx, model, cred, inputTokens, outputTokens)
|
||||
if tracker != nil {
|
||||
tracker.UpdateFromHeaders(resp.Header)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Float64("latency_ms", latencyMs).
|
||||
Str("model", model).
|
||||
Bool("stream", true).
|
||||
Int64("input_tokens", inputTokens).
|
||||
Int64("output_tokens", outputTokens).
|
||||
Msg("stream completed")
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.Printf("stream scan error: %v", err)
|
||||
log.Error().Err(err).Msg("stream scan error")
|
||||
}
|
||||
}
|
||||
|
||||
// recordConnectionError logs and records metrics for upstream connection failures.
|
||||
func recordConnectionError(ctx context.Context, err error, ri requestInfo, latencyMs float64) {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("credential", ri.cred.Email).
|
||||
Str("model", ri.model).
|
||||
Bool("stream", ri.stream).
|
||||
Str("request_body_original", string(ri.originalBody)).
|
||||
Str("request_body_sanitized", string(ri.body)).
|
||||
Int("request_body_size", len(ri.body)).
|
||||
Float64("latency_ms", latencyMs).
|
||||
Msg("upstream connection error")
|
||||
|
||||
telemetry.UpstreamErrors.Add(ctx, 1,
|
||||
metric.WithAttributes(
|
||||
attribute.String("error_type", "connection"),
|
||||
attribute.String("credential", ri.cred.Email),
|
||||
attribute.Int("status_code", http.StatusBadGateway),
|
||||
))
|
||||
recordRequestMetrics(ctx, ri, http.StatusBadGateway, latencyMs)
|
||||
}
|
||||
|
||||
// recordUpstreamError logs and records metrics for upstream HTTP error responses.
|
||||
func recordUpstreamError(ctx context.Context, statusCode int, respBody []byte, requestID string, latencyMs float64, ri requestInfo, requestHeaders http.Header) {
|
||||
errorType := gjson.GetBytes(respBody, "error.type").String()
|
||||
errorMessage := gjson.GetBytes(respBody, "error.message").String()
|
||||
log.Error().
|
||||
Int("status", statusCode).
|
||||
Str("error_type", errorType).
|
||||
Str("error_message", errorMessage).
|
||||
Str("response_body", string(respBody)).
|
||||
Str("request_id", requestID).
|
||||
Float64("latency_ms", latencyMs).
|
||||
Str("credential", ri.cred.Email).
|
||||
Str("model", ri.model).
|
||||
Bool("stream", ri.stream).
|
||||
Str("request_body_original", string(ri.originalBody)).
|
||||
Str("request_body_sanitized", string(ri.body)).
|
||||
Int("request_body_size", len(ri.body)).
|
||||
Str("request_headers", logging.RedactHeaders(requestHeaders)).
|
||||
Msg("upstream error")
|
||||
|
||||
telemetry.UpstreamErrors.Add(ctx, 1,
|
||||
metric.WithAttributes(
|
||||
attribute.Int("status_code", statusCode),
|
||||
attribute.String("error_type", errorType),
|
||||
attribute.String("credential", ri.cred.Email),
|
||||
))
|
||||
}
|
||||
|
||||
// recordRequestMetrics records the request counter and duration histogram.
|
||||
func recordRequestMetrics(ctx context.Context, ri requestInfo, statusCode int, latencyMs float64) {
|
||||
attrs := []attribute.KeyValue{
|
||||
attribute.String("model", ri.model),
|
||||
attribute.Bool("stream", ri.stream),
|
||||
attribute.Int("status_code", statusCode),
|
||||
}
|
||||
telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...))
|
||||
telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...))
|
||||
}
|
||||
|
||||
// recordTokenUsage records token consumption metrics.
|
||||
func recordTokenUsage(ctx context.Context, model string, cred *auth.Credential, inputTokens, outputTokens int64) {
|
||||
tokenAttrs := metric.WithAttributes(
|
||||
attribute.String("model", model),
|
||||
attribute.String("credential", cred.Email),
|
||||
)
|
||||
telemetry.TokensInput.Add(ctx, inputTokens, tokenAttrs)
|
||||
telemetry.TokensOutput.Add(ctx, outputTokens, tokenAttrs)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,624 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/auth"
|
||||
"github.com/fujin/anthropic-proxy/internal/config"
|
||||
"github.com/fujin/anthropic-proxy/internal/ratelimit"
|
||||
"github.com/fujin/anthropic-proxy/internal/telemetry"
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
// Initialize telemetry with noop meter to avoid nil pointer panics.
|
||||
meter := noop.Meter{}
|
||||
telemetry.InitMetrics(meter, nil)
|
||||
}
|
||||
|
||||
// --- Request body reading and sanitization ---
|
||||
|
||||
func TestHandleMessages_ReadBodyError(t *testing.T) {
|
||||
// A body that immediately fails on read shouldn't panic.
|
||||
pool := auth.NewPool([]*auth.Credential{{ID: "c1", AccessToken: "tok", Email: "test@test.com"}})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", &errReader{})
|
||||
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusBadRequest)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "failed to read request body") {
|
||||
t.Errorf("body = %q, expected error message about reading body", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleMessages_SanitizesRequestBody(t *testing.T) {
|
||||
// We can't directly make HandleMessages use our mock server because
|
||||
// UpstreamClient hardcodes messagesURL. Instead, we test sanitization
|
||||
// by verifying the sanitizer is called on the body before any pool interaction.
|
||||
san := NewSanitizer(config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "my_tool", To: "renamed_tool"}},
|
||||
Body: []config.ReplaceRule{{Match: "secret", Replace: "redacted"}},
|
||||
})
|
||||
|
||||
// Create body with tool name and secret
|
||||
body := `{"model":"claude-sonnet-4-6","tools":[{"name":"my_tool"}],"messages":[{"role":"user","content":"secret data"}]}`
|
||||
sanitizedBody := san.SanitizeRequest([]byte(body))
|
||||
|
||||
// Verify sanitization happened correctly
|
||||
if !strings.Contains(string(sanitizedBody), "renamed_tool") {
|
||||
t.Error("expected tool to be renamed in sanitized body")
|
||||
}
|
||||
if strings.Contains(string(sanitizedBody), "my_tool") {
|
||||
t.Error("original tool name should be gone after sanitization")
|
||||
}
|
||||
if !strings.Contains(string(sanitizedBody), "redacted") {
|
||||
t.Error("expected 'secret' to be replaced with 'redacted'")
|
||||
}
|
||||
if strings.Contains(string(sanitizedBody), "secret") {
|
||||
t.Error("'secret' should be gone after sanitization")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleMessages_PoolPickError(t *testing.T) {
|
||||
// Empty pool — Pick() will fail.
|
||||
pool := auth.NewPool(nil)
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
|
||||
|
||||
body := `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "no credentials available") {
|
||||
t.Errorf("body = %q, expected pool error", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleMessages_PoolAllOnCooldown(t *testing.T) {
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "e"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
pool.MarkFailure(cred, 429) // puts on 30s cooldown
|
||||
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
|
||||
|
||||
body := `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "cooldown") {
|
||||
t.Errorf("body = %q, expected cooldown message", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// --- Stream vs non-stream routing ---
|
||||
|
||||
func TestHandleMessages_StreamField_Detection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
isStream bool
|
||||
}{
|
||||
{
|
||||
name: "stream true",
|
||||
body: `{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`,
|
||||
isStream: true,
|
||||
},
|
||||
{
|
||||
name: "stream false",
|
||||
body: `{"model":"claude-sonnet-4-6","stream":false,"messages":[{"role":"user","content":"hi"}]}`,
|
||||
isStream: false,
|
||||
},
|
||||
{
|
||||
name: "no stream field defaults to false",
|
||||
body: `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`,
|
||||
isStream: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := gjson.Get(tt.body, "stream").Bool()
|
||||
if got != tt.isStream {
|
||||
t.Errorf("stream = %v, want %v", got, tt.isStream)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Desanitization on response ---
|
||||
|
||||
func TestDesanitization_NonStreamResponse(t *testing.T) {
|
||||
san := NewSanitizer(config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}},
|
||||
})
|
||||
|
||||
// Simulate upstream response with renamed tool
|
||||
upstreamResponse := `{"content":[{"type":"tool_use","name":"renamed_tool","id":"t1","input":{}}]}`
|
||||
desanitized := san.DesanitizeResponse([]byte(upstreamResponse))
|
||||
|
||||
if !strings.Contains(string(desanitized), "original_tool") {
|
||||
t.Errorf("expected tool name to be desanitized back to 'original_tool', got %s", string(desanitized))
|
||||
}
|
||||
if strings.Contains(string(desanitized), `"name":"renamed_tool"`) {
|
||||
t.Error("renamed_tool should have been replaced by original_tool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDesanitization_StreamEvent(t *testing.T) {
|
||||
san := NewSanitizer(config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}},
|
||||
})
|
||||
|
||||
event := `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"renamed_tool","id":"t1"}}`
|
||||
desanitized := san.DesanitizeStreamEvent(event)
|
||||
|
||||
if !strings.Contains(desanitized, "original_tool") {
|
||||
t.Errorf("expected stream event to be desanitized, got %s", desanitized)
|
||||
}
|
||||
}
|
||||
|
||||
// --- handleNonStream behavior tests via direct function ---
|
||||
|
||||
func TestHandleNonStream_ConnectionError(t *testing.T) {
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
tracker := ratelimit.NewTracker(func() string { return "" })
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: http.Client{Transport: &failingTransport{}},
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleNonStream(c, uc, san, pool, cred, body, body, tracker)
|
||||
|
||||
if w.Code != http.StatusBadGateway {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusBadGateway)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "upstream request failed") {
|
||||
t.Errorf("body = %q, expected upstream error message", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStream_UpstreamSuccess(t *testing.T) {
|
||||
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("X-Request-Id", "req-123")
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"text","text":"hello"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`))
|
||||
}))
|
||||
defer mockUpstream.Close()
|
||||
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
tracker := ratelimit.NewTracker(func() string { return "" })
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: *mockUpstream.Client(),
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
// Override the messagesURL by constructing a custom Execute that uses the mock.
|
||||
// Since we can't override the const, we test via a mock server approach:
|
||||
// We create a custom http.Client with a transport that redirects to our mock.
|
||||
uc.client.Transport = &rewriteTransport{
|
||||
base: mockUpstream.Client().Transport,
|
||||
destURL: mockUpstream.URL,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleNonStream(c, uc, san, pool, cred, body, body, tracker)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String())
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "hello") {
|
||||
t.Errorf("response body missing expected content: %s", w.Body.String())
|
||||
}
|
||||
if got := w.Header().Get("Content-Type"); got != "application/json" {
|
||||
t.Errorf("Content-Type = %q, want %q", got, "application/json")
|
||||
}
|
||||
if got := w.Header().Get("X-Request-Id"); got != "req-123" {
|
||||
t.Errorf("X-Request-Id = %q, want %q", got, "req-123")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStream_UpstreamError_MarkFailure(t *testing.T) {
|
||||
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(429)
|
||||
w.Write([]byte(`{"error":{"type":"rate_limit_error","message":"too many requests"}}`))
|
||||
}))
|
||||
defer mockUpstream.Close()
|
||||
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: *mockUpstream.Client(),
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
uc.client.Transport = &rewriteTransport{
|
||||
base: mockUpstream.Client().Transport,
|
||||
destURL: mockUpstream.URL,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleNonStream(c, uc, san, pool, cred, body, body, nil)
|
||||
|
||||
if w.Code != 429 {
|
||||
t.Errorf("status = %d, want 429", w.Code)
|
||||
}
|
||||
|
||||
// Verify MarkFailure was called — cred should now be on cooldown
|
||||
if !cred.IsOnCooldown() {
|
||||
t.Error("expected credential to be on cooldown after 429")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStream_UpstreamSuccess_DesanitizesResponse(t *testing.T) {
|
||||
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"tool_use","name":"renamed_tool","id":"t1","input":{}}],"model":"claude-sonnet-4-6","stop_reason":"tool_use","usage":{"input_tokens":10,"output_tokens":5}}`))
|
||||
}))
|
||||
defer mockUpstream.Close()
|
||||
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}},
|
||||
})
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: *mockUpstream.Client(),
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
uc.client.Transport = &rewriteTransport{
|
||||
base: mockUpstream.Client().Transport,
|
||||
destURL: mockUpstream.URL,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleNonStream(c, uc, san, pool, cred, body, body, nil)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String())
|
||||
}
|
||||
// Should be desanitized back to original_tool
|
||||
if !strings.Contains(w.Body.String(), "original_tool") {
|
||||
t.Errorf("response should contain desanitized tool name 'original_tool', got %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStream_Upstream500_MarkFailure(t *testing.T) {
|
||||
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(500)
|
||||
w.Write([]byte(`{"error":{"type":"server_error","message":"internal error"}}`))
|
||||
}))
|
||||
defer mockUpstream.Close()
|
||||
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: *mockUpstream.Client(),
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
uc.client.Transport = &rewriteTransport{
|
||||
base: mockUpstream.Client().Transport,
|
||||
destURL: mockUpstream.URL,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleNonStream(c, uc, san, pool, cred, body, body, nil)
|
||||
|
||||
if w.Code != 500 {
|
||||
t.Errorf("status = %d, want 500", w.Code)
|
||||
}
|
||||
if !cred.IsOnCooldown() {
|
||||
t.Error("expected credential to be on cooldown after 500")
|
||||
}
|
||||
}
|
||||
|
||||
// --- handleStream behavior tests ---
|
||||
|
||||
func TestHandleStream_ConnectionError(t *testing.T) {
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: http.Client{Transport: &failingTransport{}},
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleStream(c, uc, san, pool, cred, body, body, nil)
|
||||
|
||||
if w.Code != http.StatusBadGateway {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusBadGateway)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "upstream stream request failed") {
|
||||
t.Errorf("body = %q, expected upstream stream error", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStream_UpstreamError(t *testing.T) {
|
||||
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(429)
|
||||
w.Write([]byte(`{"error":{"type":"rate_limit_error","message":"rate limited"}}`))
|
||||
}))
|
||||
defer mockUpstream.Close()
|
||||
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: *mockUpstream.Client(),
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
uc.client.Transport = &rewriteTransport{
|
||||
base: mockUpstream.Client().Transport,
|
||||
destURL: mockUpstream.URL,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleStream(c, uc, san, pool, cred, body, body, nil)
|
||||
|
||||
if w.Code != 429 {
|
||||
t.Errorf("status = %d, want 429", w.Code)
|
||||
}
|
||||
if !cred.IsOnCooldown() {
|
||||
t.Error("expected credential on cooldown after stream 429")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStream_SuccessForwardsEvents(t *testing.T) {
|
||||
events := "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":0}}}\n\nevent: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello\"}}\n\nevent: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\nevent: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":5}}\n\nevent: message_stop\ndata: {\"type\":\"message_stop\"}\n\n"
|
||||
|
||||
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(events))
|
||||
}))
|
||||
defer mockUpstream.Close()
|
||||
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: *mockUpstream.Client(),
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
uc.client.Transport = &rewriteTransport{
|
||||
base: mockUpstream.Client().Transport,
|
||||
destURL: mockUpstream.URL,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleStream(c, uc, san, pool, cred, body, body, nil)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
respBody := w.Body.String()
|
||||
if !strings.Contains(respBody, "message_start") {
|
||||
t.Error("response missing message_start event")
|
||||
}
|
||||
if !strings.Contains(respBody, "hello") {
|
||||
t.Error("response missing text content 'hello'")
|
||||
}
|
||||
if !strings.Contains(respBody, "message_stop") {
|
||||
t.Error("response missing message_stop event")
|
||||
}
|
||||
|
||||
// Verify SSE headers
|
||||
if got := w.Header().Get("Content-Type"); got != "text/event-stream" {
|
||||
t.Errorf("Content-Type = %q, want %q", got, "text/event-stream")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStream_DesanitizesEvents(t *testing.T) {
|
||||
events := "event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"name\":\"renamed_tool\",\"id\":\"t1\"}}\n\nevent: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"},\"usage\":{\"output_tokens\":5}}\n\nevent: message_stop\ndata: {\"type\":\"message_stop\"}\n\n"
|
||||
|
||||
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(events))
|
||||
}))
|
||||
defer mockUpstream.Close()
|
||||
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}},
|
||||
})
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: *mockUpstream.Client(),
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
uc.client.Transport = &rewriteTransport{
|
||||
base: mockUpstream.Client().Transport,
|
||||
destURL: mockUpstream.URL,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleStream(c, uc, san, pool, cred, body, body, nil)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String())
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "original_tool") {
|
||||
t.Errorf("stream response should contain desanitized 'original_tool', got %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// --- HandleMessages full integration wiring test ---
|
||||
|
||||
func TestHandleMessages_WiresHandlerCorrectly(t *testing.T) {
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
// Verify the handler can be created without panic
|
||||
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
|
||||
if handler == nil {
|
||||
t.Fatal("HandleMessages returned nil handler")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleMessages_EmptyBody(t *testing.T) {
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
|
||||
|
||||
// Empty body — handler should still try to pick cred and call upstream
|
||||
// (which will fail with connection error to api.anthropic.com, not a panic)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(""))
|
||||
|
||||
handler(c)
|
||||
|
||||
// Should get a 502 because the upstream URL (api.anthropic.com) won't be reachable
|
||||
// in test environment, or it might complete. The key thing is no panic.
|
||||
// We mainly verify it doesn't panic.
|
||||
if w.Code == 0 {
|
||||
t.Error("expected non-zero status code")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Test helpers ---
|
||||
|
||||
// errReader is an io.Reader that always returns an error.
|
||||
type errReader struct{}
|
||||
|
||||
func (e *errReader) Read([]byte) (int, error) {
|
||||
return 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
// failingTransport is an http.RoundTripper that always returns an error.
|
||||
type failingTransport struct{}
|
||||
|
||||
func (f *failingTransport) RoundTrip(*http.Request) (*http.Response, error) {
|
||||
return nil, fmt.Errorf("connection refused")
|
||||
}
|
||||
|
||||
// rewriteTransport intercepts HTTP requests and rewrites the URL to point
|
||||
// at a local test server. This allows testing with UpstreamClient's hardcoded
|
||||
// messagesURL by redirecting all requests to a mock server.
|
||||
type rewriteTransport struct {
|
||||
base http.RoundTripper
|
||||
destURL string
|
||||
}
|
||||
|
||||
func (t *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Rewrite the request URL to point at our mock server
|
||||
newReq := req.Clone(req.Context())
|
||||
newReq.URL.Scheme = "http"
|
||||
newReq.URL.Host = strings.TrimPrefix(t.destURL, "http://")
|
||||
newReq.URL.Path = "/v1/messages"
|
||||
newReq.URL.RawQuery = ""
|
||||
if t.base == nil {
|
||||
return http.DefaultTransport.RoundTrip(newReq)
|
||||
}
|
||||
return t.base.RoundTrip(newReq)
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
@@ -11,10 +12,10 @@ import (
|
||||
)
|
||||
|
||||
type Sanitizer struct {
|
||||
toolsForward map[string]string
|
||||
toolsReverse map[string]string
|
||||
systemRules []config.ReplaceRule
|
||||
bodyRules []config.ReplaceRule
|
||||
toolsForward map[string]string
|
||||
toolsReverse map[string]string
|
||||
systemRules []config.ReplaceRule
|
||||
bodyRules []config.ReplaceRule
|
||||
}
|
||||
|
||||
func NewSanitizer(cfg config.SanitizeConfig) *Sanitizer {
|
||||
@@ -49,7 +50,11 @@ func (s *Sanitizer) DesanitizeResponse(body []byte) []byte {
|
||||
}
|
||||
name := block.Get("name").String()
|
||||
if orig, ok := s.toolsReverse[name]; ok {
|
||||
body, _ = sjson.SetBytes(body, "content."+strconv.Itoa(i)+".name", orig)
|
||||
if b, err := sjson.SetBytes(body, "content."+strconv.Itoa(i)+".name", orig); err != nil {
|
||||
log.Warn().Err(err).Str("tool", name).Msg("desanitize response: set name failed")
|
||||
} else {
|
||||
body = b
|
||||
}
|
||||
}
|
||||
}
|
||||
return body
|
||||
@@ -64,8 +69,12 @@ func (s *Sanitizer) DesanitizeStreamEvent(line string) string {
|
||||
for _, path := range []string{"content_block.name", "delta.name"} {
|
||||
name := gjson.GetBytes(data, path).String()
|
||||
if orig, ok := s.toolsReverse[name]; ok {
|
||||
data, _ = sjson.SetBytes(data, path, orig)
|
||||
changed = true
|
||||
if b, err := sjson.SetBytes(data, path, orig); err != nil {
|
||||
log.Warn().Err(err).Str("tool", name).Msg("desanitize stream event: set name failed")
|
||||
} else {
|
||||
data = b
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if changed {
|
||||
@@ -85,7 +94,11 @@ func (s *Sanitizer) renameTools(body []byte) []byte {
|
||||
for i, tool := range tools.Array() {
|
||||
name := tool.Get("name").String()
|
||||
if newName, ok := s.toolsForward[name]; ok {
|
||||
body, _ = sjson.SetBytes(body, "tools."+strconv.Itoa(i)+".name", newName)
|
||||
if b, err := sjson.SetBytes(body, "tools."+strconv.Itoa(i)+".name", newName); err != nil {
|
||||
log.Warn().Err(err).Str("tool", name).Msg("rename tool failed")
|
||||
} else {
|
||||
body = b
|
||||
}
|
||||
}
|
||||
}
|
||||
return body
|
||||
@@ -104,7 +117,11 @@ func (s *Sanitizer) replaceSystem(body []byte) []byte {
|
||||
for _, rule := range s.systemRules {
|
||||
text = strings.ReplaceAll(text, rule.Match, rule.Replace)
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "system."+strconv.Itoa(i)+".text", text)
|
||||
if b, err := sjson.SetBytes(body, "system."+strconv.Itoa(i)+".text", text); err != nil {
|
||||
log.Warn().Err(err).Int("block", i).Msg("replace system text failed")
|
||||
} else {
|
||||
body = b
|
||||
}
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
@@ -0,0 +1,476 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/config"
|
||||
)
|
||||
|
||||
func TestNewSanitizer_Empty(t *testing.T) {
|
||||
s := NewSanitizer(config.SanitizeConfig{})
|
||||
if len(s.toolsForward) != 0 {
|
||||
t.Errorf("expected empty toolsForward, got %d entries", len(s.toolsForward))
|
||||
}
|
||||
if len(s.toolsReverse) != 0 {
|
||||
t.Errorf("expected empty toolsReverse, got %d entries", len(s.toolsReverse))
|
||||
}
|
||||
if s.systemRules != nil {
|
||||
t.Errorf("expected nil systemRules")
|
||||
}
|
||||
if s.bodyRules != nil {
|
||||
t.Errorf("expected nil bodyRules")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSanitizer_WithTools(t *testing.T) {
|
||||
cfg := config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{
|
||||
{From: "old_tool", To: "new_tool"},
|
||||
{From: "another", To: "replaced"},
|
||||
},
|
||||
}
|
||||
s := NewSanitizer(cfg)
|
||||
if got := s.toolsForward["old_tool"]; got != "new_tool" {
|
||||
t.Errorf("toolsForward[old_tool] = %q, want %q", got, "new_tool")
|
||||
}
|
||||
if got := s.toolsReverse["new_tool"]; got != "old_tool" {
|
||||
t.Errorf("toolsReverse[new_tool] = %q, want %q", got, "old_tool")
|
||||
}
|
||||
if got := s.toolsForward["another"]; got != "replaced" {
|
||||
t.Errorf("toolsForward[another] = %q, want %q", got, "replaced")
|
||||
}
|
||||
if got := s.toolsReverse["replaced"]; got != "another" {
|
||||
t.Errorf("toolsReverse[replaced] = %q, want %q", got, "another")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSanitizer_WithSystemAndBodyRules(t *testing.T) {
|
||||
cfg := config.SanitizeConfig{
|
||||
System: []config.ReplaceRule{{Match: "foo", Replace: "bar"}},
|
||||
Body: []config.ReplaceRule{{Match: "baz", Replace: "qux"}},
|
||||
}
|
||||
s := NewSanitizer(cfg)
|
||||
if len(s.systemRules) != 1 || s.systemRules[0].Match != "foo" {
|
||||
t.Errorf("systemRules not set correctly")
|
||||
}
|
||||
if len(s.bodyRules) != 1 || s.bodyRules[0].Match != "baz" {
|
||||
t.Errorf("bodyRules not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenameTools(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
forward map[string]string
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty map returns body unchanged",
|
||||
forward: map[string]string{},
|
||||
body: `{"tools":[{"name":"my_tool"}]}`,
|
||||
expected: `{"tools":[{"name":"my_tool"}]}`,
|
||||
},
|
||||
{
|
||||
name: "no tools array returns body unchanged",
|
||||
forward: map[string]string{"my_tool": "renamed"},
|
||||
body: `{"messages":[]}`,
|
||||
expected: `{"messages":[]}`,
|
||||
},
|
||||
{
|
||||
name: "tools is not array returns body unchanged",
|
||||
forward: map[string]string{"my_tool": "renamed"},
|
||||
body: `{"tools":"not_array"}`,
|
||||
expected: `{"tools":"not_array"}`,
|
||||
},
|
||||
{
|
||||
name: "matching tool gets renamed",
|
||||
forward: map[string]string{"my_tool": "renamed_tool"},
|
||||
body: `{"tools":[{"name":"my_tool","description":"desc"}]}`,
|
||||
expected: `renamed_tool`,
|
||||
},
|
||||
{
|
||||
name: "non-matching tool unchanged",
|
||||
forward: map[string]string{"other_tool": "renamed"},
|
||||
body: `{"tools":[{"name":"my_tool"}]}`,
|
||||
expected: `my_tool`,
|
||||
},
|
||||
{
|
||||
name: "partial match - only exact match renames",
|
||||
forward: map[string]string{"tool_a": "tool_x", "tool_b": "tool_y"},
|
||||
body: `{"tools":[{"name":"tool_a"},{"name":"tool_c"},{"name":"tool_b"}]}`,
|
||||
expected: `tool_x`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: tt.forward,
|
||||
toolsReverse: make(map[string]string),
|
||||
}
|
||||
result := string(s.renameTools([]byte(tt.body)))
|
||||
if !strings.Contains(result, tt.expected) {
|
||||
t.Errorf("result %q does not contain %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenameTools_MultipleTools(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: map[string]string{"tool_a": "tool_x", "tool_b": "tool_y"},
|
||||
toolsReverse: make(map[string]string),
|
||||
}
|
||||
body := `{"tools":[{"name":"tool_a"},{"name":"tool_c"},{"name":"tool_b"}]}`
|
||||
result := string(s.renameTools([]byte(body)))
|
||||
if !strings.Contains(result, `"tool_x"`) {
|
||||
t.Errorf("tool_a should be renamed to tool_x, got %s", result)
|
||||
}
|
||||
if !strings.Contains(result, `"tool_y"`) {
|
||||
t.Errorf("tool_b should be renamed to tool_y, got %s", result)
|
||||
}
|
||||
if !strings.Contains(result, `"tool_c"`) {
|
||||
t.Errorf("tool_c should remain unchanged, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceSystem(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rules []config.ReplaceRule
|
||||
body string
|
||||
contains string
|
||||
}{
|
||||
{
|
||||
name: "empty rules returns body unchanged",
|
||||
rules: nil,
|
||||
body: `{"system":[{"type":"text","text":"hello world"}]}`,
|
||||
contains: "hello world",
|
||||
},
|
||||
{
|
||||
name: "no system field returns body unchanged",
|
||||
rules: []config.ReplaceRule{{Match: "hello", Replace: "goodbye"}},
|
||||
body: `{"messages":[]}`,
|
||||
contains: `"messages":[]`,
|
||||
},
|
||||
{
|
||||
name: "system not array returns body unchanged",
|
||||
rules: []config.ReplaceRule{{Match: "hello", Replace: "goodbye"}},
|
||||
body: `{"system":"just a string"}`,
|
||||
contains: "just a string",
|
||||
},
|
||||
{
|
||||
name: "single block single rule",
|
||||
rules: []config.ReplaceRule{{Match: "hello", Replace: "goodbye"}},
|
||||
body: `{"system":[{"type":"text","text":"hello world"}]}`,
|
||||
contains: "goodbye world",
|
||||
},
|
||||
{
|
||||
name: "multiple blocks",
|
||||
rules: []config.ReplaceRule{{Match: "AAA", Replace: "BBB"}},
|
||||
body: `{"system":[{"type":"text","text":"AAA first"},{"type":"text","text":"AAA second"}]}`,
|
||||
contains: "BBB first",
|
||||
},
|
||||
{
|
||||
name: "multiple rules applied in order",
|
||||
rules: []config.ReplaceRule{{Match: "cat", Replace: "dog"}, {Match: "dog", Replace: "fish"}},
|
||||
body: `{"system":[{"type":"text","text":"I have a cat"}]}`,
|
||||
contains: "I have a fish",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: make(map[string]string),
|
||||
toolsReverse: make(map[string]string),
|
||||
systemRules: tt.rules,
|
||||
}
|
||||
result := string(s.replaceSystem([]byte(tt.body)))
|
||||
if !strings.Contains(result, tt.contains) {
|
||||
t.Errorf("result %q does not contain %q", result, tt.contains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceSystem_MultipleBlocks(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: make(map[string]string),
|
||||
toolsReverse: make(map[string]string),
|
||||
systemRules: []config.ReplaceRule{{Match: "AAA", Replace: "BBB"}},
|
||||
}
|
||||
body := `{"system":[{"type":"text","text":"AAA first"},{"type":"text","text":"AAA second"}]}`
|
||||
result := string(s.replaceSystem([]byte(body)))
|
||||
if !strings.Contains(result, "BBB first") {
|
||||
t.Errorf("first block not replaced: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "BBB second") {
|
||||
t.Errorf("second block not replaced: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rules []config.ReplaceRule
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty rules returns body unchanged",
|
||||
rules: nil,
|
||||
body: `{"foo":"bar"}`,
|
||||
expected: `{"foo":"bar"}`,
|
||||
},
|
||||
{
|
||||
name: "single replacement across entire body",
|
||||
rules: []config.ReplaceRule{{Match: "SECRET", Replace: "REDACTED"}},
|
||||
body: `{"data":"SECRET value SECRET"}`,
|
||||
expected: `{"data":"REDACTED value REDACTED"}`,
|
||||
},
|
||||
{
|
||||
name: "multiple rules applied sequentially",
|
||||
rules: []config.ReplaceRule{{Match: "AAA", Replace: "BBB"}, {Match: "BBB", Replace: "CCC"}},
|
||||
body: `{"text":"AAA"}`,
|
||||
expected: `{"text":"CCC"}`,
|
||||
},
|
||||
{
|
||||
name: "no match leaves body unchanged",
|
||||
rules: []config.ReplaceRule{{Match: "NOMATCH", Replace: "X"}},
|
||||
body: `{"text":"hello"}`,
|
||||
expected: `{"text":"hello"}`,
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
rules: []config.ReplaceRule{{Match: "a", Replace: "b"}},
|
||||
body: ``,
|
||||
expected: ``,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: make(map[string]string),
|
||||
toolsReverse: make(map[string]string),
|
||||
bodyRules: tt.rules,
|
||||
}
|
||||
result := string(s.replaceBody([]byte(tt.body)))
|
||||
if result != tt.expected {
|
||||
t.Errorf("got %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeRequest(t *testing.T) {
|
||||
cfg := config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "my_tool", To: "renamed_tool"}},
|
||||
System: []config.ReplaceRule{{Match: "INTERNAL", Replace: "PUBLIC"}},
|
||||
Body: []config.ReplaceRule{{Match: "secret_val", Replace: "safe_val"}},
|
||||
}
|
||||
s := NewSanitizer(cfg)
|
||||
|
||||
body := `{"tools":[{"name":"my_tool"}],"system":[{"type":"text","text":"INTERNAL info"}],"data":"secret_val here"}`
|
||||
result := string(s.SanitizeRequest([]byte(body)))
|
||||
|
||||
if !strings.Contains(result, `"renamed_tool"`) {
|
||||
t.Errorf("tool not renamed in result: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "PUBLIC info") {
|
||||
t.Errorf("system not replaced in result: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "safe_val here") {
|
||||
t.Errorf("body not replaced in result: %s", result)
|
||||
}
|
||||
if strings.Contains(result, "secret_val") {
|
||||
t.Errorf("secret_val should have been replaced: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeRequest_EmptyConfig(t *testing.T) {
|
||||
s := NewSanitizer(config.SanitizeConfig{})
|
||||
body := `{"tools":[{"name":"my_tool"}],"system":[{"type":"text","text":"hello"}]}`
|
||||
result := string(s.SanitizeRequest([]byte(body)))
|
||||
if result != body {
|
||||
t.Errorf("empty config should not modify body.\ngot: %s\nwant: %s", result, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDesanitizeResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reverse map[string]string
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no content field returns unchanged",
|
||||
reverse: map[string]string{"renamed": "original"},
|
||||
body: `{"id":"msg_1","role":"assistant"}`,
|
||||
expected: `{"id":"msg_1","role":"assistant"}`,
|
||||
},
|
||||
{
|
||||
name: "content not array returns unchanged",
|
||||
reverse: map[string]string{"renamed": "original"},
|
||||
body: `{"content":"just text"}`,
|
||||
expected: `{"content":"just text"}`,
|
||||
},
|
||||
{
|
||||
name: "non-tool_use block left unchanged",
|
||||
reverse: map[string]string{"renamed": "original"},
|
||||
body: `{"content":[{"type":"text","text":"hello"}]}`,
|
||||
expected: `{"content":[{"type":"text","text":"hello"}]}`,
|
||||
},
|
||||
{
|
||||
name: "tool_use block with matching name gets reversed",
|
||||
reverse: map[string]string{"renamed_tool": "original_tool"},
|
||||
body: `{"content":[{"type":"tool_use","name":"renamed_tool","id":"t1"}]}`,
|
||||
expected: `original_tool`,
|
||||
},
|
||||
{
|
||||
name: "tool_use block with no match unchanged",
|
||||
reverse: map[string]string{"other": "something"},
|
||||
body: `{"content":[{"type":"tool_use","name":"my_tool","id":"t1"}]}`,
|
||||
expected: `my_tool`,
|
||||
},
|
||||
{
|
||||
name: "mixed blocks only tool_use reversed",
|
||||
reverse: map[string]string{"renamed": "original"},
|
||||
body: `{"content":[{"type":"text","text":"hi"},{"type":"tool_use","name":"renamed","id":"t1"}]}`,
|
||||
expected: `original`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: make(map[string]string),
|
||||
toolsReverse: tt.reverse,
|
||||
}
|
||||
result := string(s.DesanitizeResponse([]byte(tt.body)))
|
||||
if !strings.Contains(result, tt.expected) {
|
||||
t.Errorf("result %q does not contain %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDesanitizeResponse_MultipleToolUse(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: make(map[string]string),
|
||||
toolsReverse: map[string]string{"r1": "o1", "r2": "o2"},
|
||||
}
|
||||
body := `{"content":[{"type":"tool_use","name":"r1","id":"t1"},{"type":"text","text":"x"},{"type":"tool_use","name":"r2","id":"t2"}]}`
|
||||
result := string(s.DesanitizeResponse([]byte(body)))
|
||||
if !strings.Contains(result, `"o1"`) {
|
||||
t.Errorf("r1 not reversed to o1: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, `"o2"`) {
|
||||
t.Errorf("r2 not reversed to o2: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDesanitizeStreamEvent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reverse map[string]string
|
||||
line string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "non-data line passed through",
|
||||
reverse: map[string]string{"r": "o"},
|
||||
line: "event: content_block_start",
|
||||
expected: "event: content_block_start",
|
||||
},
|
||||
{
|
||||
name: "data line without tool_use passed through",
|
||||
reverse: map[string]string{"r": "o"},
|
||||
line: `data: {"type":"text","text":"hello"}`,
|
||||
expected: `data: {"type":"text","text":"hello"}`,
|
||||
},
|
||||
{
|
||||
name: "data line with tool_use in content_block.name",
|
||||
reverse: map[string]string{"renamed_tool": "original_tool"},
|
||||
line: `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"renamed_tool","id":"t1"}}`,
|
||||
expected: `original_tool`,
|
||||
},
|
||||
{
|
||||
name: "data line with tool_use in delta.name",
|
||||
reverse: map[string]string{"renamed_tool": "original_tool"},
|
||||
line: `data: {"type":"content_block_delta","delta":{"type":"tool_use","name":"renamed_tool"}}`,
|
||||
expected: `original_tool`,
|
||||
},
|
||||
{
|
||||
name: "data line with tool_use but no matching name",
|
||||
reverse: map[string]string{"other": "something"},
|
||||
line: `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"my_tool","id":"t1"}}`,
|
||||
expected: `my_tool`,
|
||||
},
|
||||
{
|
||||
name: "empty line passed through",
|
||||
reverse: map[string]string{"r": "o"},
|
||||
line: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "line contains tool_use but not data prefix - passed through",
|
||||
reverse: map[string]string{"r": "o"},
|
||||
line: "event: tool_use",
|
||||
expected: "event: tool_use",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: make(map[string]string),
|
||||
toolsReverse: tt.reverse,
|
||||
}
|
||||
result := s.DesanitizeStreamEvent(tt.line)
|
||||
if !strings.Contains(result, tt.expected) {
|
||||
t.Errorf("result %q does not contain %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDesanitizeStreamEvent_DataPrefixPreserved(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: make(map[string]string),
|
||||
toolsReverse: map[string]string{"renamed": "original"},
|
||||
}
|
||||
line := `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"renamed","id":"t1"}}`
|
||||
result := s.DesanitizeStreamEvent(line)
|
||||
if !strings.HasPrefix(result, "data: ") {
|
||||
t.Errorf("result should start with 'data: ', got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeRequest_MalformedJSON(t *testing.T) {
|
||||
s := NewSanitizer(config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "a", To: "b"}},
|
||||
System: []config.ReplaceRule{{Match: "x", Replace: "y"}},
|
||||
})
|
||||
// Malformed JSON - renameTools and replaceSystem should handle gracefully
|
||||
body := `not valid json`
|
||||
result := string(s.SanitizeRequest([]byte(body)))
|
||||
// Should not panic; body rules still do string replacement
|
||||
if result != "not valid json" {
|
||||
t.Errorf("malformed JSON should pass through (no body rules match), got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeRequest_EmptyBody(t *testing.T) {
|
||||
s := NewSanitizer(config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "a", To: "b"}},
|
||||
})
|
||||
result := s.SanitizeRequest([]byte{})
|
||||
if len(result) != 0 {
|
||||
t.Errorf("empty body should return empty, got %q", string(result))
|
||||
}
|
||||
}
|
||||
+65
-44
@@ -3,13 +3,14 @@ package proxy
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// SniffedProfile holds everything captured from a real Claude Code request.
|
||||
@@ -35,6 +36,21 @@ var skipHeaders = map[string]bool{
|
||||
"connection": true,
|
||||
}
|
||||
|
||||
const fakeJSONResponse = `{"id":"msg_fake","type":"message","role":"assistant","content":[{"type":"text","text":"ok"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`
|
||||
|
||||
const fakeStreamResponse = "event: message_start\n" +
|
||||
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_fake\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}\n\n" +
|
||||
"event: content_block_start\n" +
|
||||
"data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n" +
|
||||
"event: content_block_delta\n" +
|
||||
"data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"ok\"}}\n\n" +
|
||||
"event: content_block_stop\n" +
|
||||
"data: {\"type\":\"content_block_stop\",\"index\":0}\n\n" +
|
||||
"event: message_delta\n" +
|
||||
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":1}}\n\n" +
|
||||
"event: message_stop\n" +
|
||||
"data: {\"type\":\"message_stop\"}\n\n"
|
||||
|
||||
func SniffClaudeCode(claudeBinary string) (*SniffedProfile, error) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
@@ -47,45 +63,7 @@ func SniffClaudeCode(claudeBinary string) (*SniffedProfile, error) {
|
||||
captured := make(chan struct{}, 1)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "HEAD" {
|
||||
w.WriteHeader(200)
|
||||
return
|
||||
}
|
||||
if r.Method != "POST" || !strings.Contains(r.URL.Path, "/v1/messages") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprint(w, `{"id":"msg_fake","type":"message","role":"assistant","content":[{"type":"text","text":"ok"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`)
|
||||
return
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
|
||||
mu.Lock()
|
||||
if profile == nil {
|
||||
profile = extractProfile(r, body)
|
||||
select {
|
||||
case captured <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
if strings.Contains(string(body), `"stream":true`) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprint(w, "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_fake\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}\n\n")
|
||||
fmt.Fprint(w, "event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n")
|
||||
fmt.Fprint(w, "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"ok\"}}\n\n")
|
||||
fmt.Fprint(w, "event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n")
|
||||
fmt.Fprint(w, "event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":1}}\n\n")
|
||||
fmt.Fprint(w, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprint(w, `{"id":"msg_fake","type":"message","role":"assistant","content":[{"type":"text","text":"ok"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`)
|
||||
}
|
||||
})
|
||||
mux.HandleFunc("/", sniffHandler(&mu, &profile, captured))
|
||||
|
||||
srv := &http.Server{Handler: mux}
|
||||
go srv.Serve(listener)
|
||||
@@ -116,13 +94,57 @@ func SniffClaudeCode(claudeBinary string) (*SniffedProfile, error) {
|
||||
return nil, fmt.Errorf("no API request captured")
|
||||
}
|
||||
|
||||
log.Printf("sniffed claude-code: version=%s headers=%d body=%d bytes",
|
||||
profile.Version, len(profile.Headers), len(profile.Body))
|
||||
log.Info().
|
||||
Str("version", profile.Version).
|
||||
Int("headers", len(profile.Headers)).
|
||||
Int("body_size", len(profile.Body)).
|
||||
Msg("sniffed claude-code profile")
|
||||
|
||||
for _, h := range profile.Headers {
|
||||
log.Debug().Str("header", h[0]).Str("value", h[1]).Msg("sniffed header")
|
||||
}
|
||||
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
func sniffHandler(mu *sync.Mutex, profile **SniffedProfile, captured chan<- struct{}) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "HEAD" {
|
||||
w.WriteHeader(200)
|
||||
return
|
||||
}
|
||||
if r.Method != "POST" || !strings.Contains(r.URL.Path, "/v1/messages") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprint(w, fakeJSONResponse)
|
||||
return
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
|
||||
mu.Lock()
|
||||
if *profile == nil {
|
||||
*profile = extractProfile(r, body)
|
||||
select {
|
||||
case captured <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
if strings.Contains(string(body), `"stream":true`) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprint(w, fakeStreamResponse)
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprint(w, fakeJSONResponse)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func extractProfile(r *http.Request, body []byte) *SniffedProfile {
|
||||
// Capture raw headers preserving original casing.
|
||||
var headers [][2]string
|
||||
for name, vals := range r.Header {
|
||||
if skipHeaders[strings.ToLower(name)] {
|
||||
@@ -133,7 +155,6 @@ func extractProfile(r *http.Request, body []byte) *SniffedProfile {
|
||||
}
|
||||
}
|
||||
|
||||
// Deduplicate and strip subscription-specific betas.
|
||||
seen := map[string]bool{}
|
||||
var deduped [][2]string
|
||||
for _, h := range headers {
|
||||
|
||||
@@ -0,0 +1,278 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newRequest(t *testing.T, headers map[string][]string) *http.Request {
|
||||
t.Helper()
|
||||
r := httptest.NewRequest("POST", "/v1/messages", nil)
|
||||
r.Header = http.Header{}
|
||||
for k, vals := range headers {
|
||||
for _, v := range vals {
|
||||
r.Header.Add(k, v)
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func TestExtractProfile_BasicHeaders(t *testing.T) {
|
||||
r := newRequest(t, map[string][]string{
|
||||
"Content-Type": {"application/json"},
|
||||
"X-Custom-Header": {"custom-value"},
|
||||
"User-Agent": {"Claude/1.2.3 linux"},
|
||||
})
|
||||
body := []byte(`{"model":"claude-sonnet-4-6"}`)
|
||||
|
||||
p := extractProfile(r, body)
|
||||
|
||||
// Check version parsed
|
||||
if p.Version != "1.2.3" {
|
||||
t.Errorf("version = %q, want %q", p.Version, "1.2.3")
|
||||
}
|
||||
|
||||
// Check body preserved
|
||||
if string(p.Body) != string(body) {
|
||||
t.Errorf("body not preserved")
|
||||
}
|
||||
|
||||
// Check headers captured
|
||||
found := map[string]bool{}
|
||||
for _, h := range p.Headers {
|
||||
found[strings.ToLower(h[0])] = true
|
||||
}
|
||||
if !found["content-type"] {
|
||||
t.Error("Content-Type header should be captured")
|
||||
}
|
||||
if !found["x-custom-header"] {
|
||||
t.Error("X-Custom-Header should be captured")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractProfile_SkipHeaders(t *testing.T) {
|
||||
r := newRequest(t, map[string][]string{
|
||||
"Host": {"example.com"},
|
||||
"Content-Length": {"42"},
|
||||
"Authorization": {"Bearer token123"},
|
||||
"X-Api-Key": {"key123"},
|
||||
"Connection": {"keep-alive"},
|
||||
"Content-Type": {"application/json"},
|
||||
"X-Custom": {"keep-me"},
|
||||
})
|
||||
|
||||
p := extractProfile(r, []byte(`{}`))
|
||||
|
||||
for _, h := range p.Headers {
|
||||
lower := strings.ToLower(h[0])
|
||||
if skipHeaders[lower] {
|
||||
t.Errorf("header %q should have been skipped", h[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Verify non-skipped headers are present
|
||||
found := map[string]bool{}
|
||||
for _, h := range p.Headers {
|
||||
found[strings.ToLower(h[0])] = true
|
||||
}
|
||||
if !found["content-type"] {
|
||||
t.Error("Content-Type should be kept")
|
||||
}
|
||||
if !found["x-custom"] {
|
||||
t.Error("X-Custom should be kept")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractProfile_HeaderDeduplication(t *testing.T) {
|
||||
r := newRequest(t, map[string][]string{
|
||||
"Content-Type": {"application/json"},
|
||||
})
|
||||
// Add duplicate with different casing - Go's http.Header normalizes to canonical form
|
||||
// so we need to add the same canonical header with multiple values to test dedup
|
||||
r.Header.Add("Content-Type", "text/plain")
|
||||
|
||||
p := extractProfile(r, []byte(`{}`))
|
||||
|
||||
// After deduplication by lowercase key, only one entry per key
|
||||
seen := map[string]int{}
|
||||
for _, h := range p.Headers {
|
||||
seen[strings.ToLower(h[0])]++
|
||||
}
|
||||
for key, count := range seen {
|
||||
if count > 1 {
|
||||
t.Errorf("header %q appears %d times after dedup, want 1", key, count)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractProfile_AnthropicBetaContextStripping(t *testing.T) {
|
||||
r := newRequest(t, map[string][]string{
|
||||
"Anthropic-Beta": {"prompt-caching-2024-07-31,context-1m-2024-09-01,some-other-beta"},
|
||||
})
|
||||
|
||||
p := extractProfile(r, []byte(`{}`))
|
||||
|
||||
var betaValue string
|
||||
for _, h := range p.Headers {
|
||||
if strings.ToLower(h[0]) == "anthropic-beta" {
|
||||
betaValue = h[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(betaValue, "context-1m") {
|
||||
t.Errorf("context-1m should be stripped from anthropic-beta, got %q", betaValue)
|
||||
}
|
||||
if !strings.Contains(betaValue, "prompt-caching-2024-07-31") {
|
||||
t.Errorf("prompt-caching should be preserved, got %q", betaValue)
|
||||
}
|
||||
if !strings.Contains(betaValue, "some-other-beta") {
|
||||
t.Errorf("some-other-beta should be preserved, got %q", betaValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractProfile_AnthropicBetaAllContextRemoved(t *testing.T) {
|
||||
r := newRequest(t, map[string][]string{
|
||||
"Anthropic-Beta": {"context-1m-2024-09-01"},
|
||||
})
|
||||
|
||||
p := extractProfile(r, []byte(`{}`))
|
||||
|
||||
for _, h := range p.Headers {
|
||||
if strings.ToLower(h[0]) == "anthropic-beta" {
|
||||
// All betas were context-1m, so after filtering the value should be empty
|
||||
if h[1] != "" {
|
||||
t.Errorf("all context-1m betas stripped should leave empty, got %q", h[1])
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
// It's also acceptable if the header is still present but empty
|
||||
}
|
||||
|
||||
func TestExtractProfile_VersionParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userAgent string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "standard Claude UA",
|
||||
userAgent: "Claude/1.2.3 linux x86_64",
|
||||
expected: "1.2.3",
|
||||
},
|
||||
{
|
||||
name: "version with no space after",
|
||||
userAgent: "Claude/4.5.6",
|
||||
expected: "4.5.6",
|
||||
},
|
||||
{
|
||||
name: "no slash in UA",
|
||||
userAgent: "Mozilla 5.0",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "empty UA",
|
||||
userAgent: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "slash at start",
|
||||
userAgent: "/1.0.0 rest",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "multiple slashes",
|
||||
userAgent: "App/1.0.0 (sub/2.0)",
|
||||
expected: "1.0.0",
|
||||
},
|
||||
{
|
||||
name: "version only after slash no space",
|
||||
userAgent: "Tool/9.8.7",
|
||||
expected: "9.8.7",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := newRequest(t, map[string][]string{
|
||||
"User-Agent": {tt.userAgent},
|
||||
})
|
||||
p := extractProfile(r, []byte(`{}`))
|
||||
if p.Version != tt.expected {
|
||||
t.Errorf("version = %q, want %q", p.Version, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractProfile_EmptyHeaders(t *testing.T) {
|
||||
r := httptest.NewRequest("POST", "/v1/messages", nil)
|
||||
r.Header = http.Header{}
|
||||
|
||||
p := extractProfile(r, []byte(`{"test":true}`))
|
||||
|
||||
if len(p.Headers) != 0 {
|
||||
t.Errorf("expected no headers, got %d", len(p.Headers))
|
||||
}
|
||||
if p.Version != "" {
|
||||
t.Errorf("expected empty version with no UA, got %q", p.Version)
|
||||
}
|
||||
if string(p.Body) != `{"test":true}` {
|
||||
t.Errorf("body not preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractProfile_BodyPreserved(t *testing.T) {
|
||||
r := newRequest(t, map[string][]string{
|
||||
"User-Agent": {"Claude/1.0.0 test"},
|
||||
})
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
|
||||
p := extractProfile(r, body)
|
||||
|
||||
if string(p.Body) != string(body) {
|
||||
t.Errorf("body not preserved.\ngot: %s\nwant: %s", p.Body, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkipHeaders_Entries(t *testing.T) {
|
||||
expected := map[string]bool{
|
||||
"host": true,
|
||||
"content-length": true,
|
||||
"authorization": true,
|
||||
"x-api-key": true,
|
||||
"connection": true,
|
||||
}
|
||||
if len(skipHeaders) != len(expected) {
|
||||
t.Errorf("skipHeaders has %d entries, want %d", len(skipHeaders), len(expected))
|
||||
}
|
||||
for k, v := range expected {
|
||||
if skipHeaders[k] != v {
|
||||
t.Errorf("skipHeaders[%q] = %v, want %v", k, skipHeaders[k], v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSniffedProfile_Fields(t *testing.T) {
|
||||
// Verify the struct can hold all expected data
|
||||
p := &SniffedProfile{
|
||||
Headers: [][2]string{{"Content-Type", "application/json"}},
|
||||
Body: []byte(`{}`),
|
||||
Version: "1.0.0",
|
||||
}
|
||||
if len(p.Headers) != 1 {
|
||||
t.Error("Headers should have 1 entry")
|
||||
}
|
||||
if p.Headers[0][0] != "Content-Type" || p.Headers[0][1] != "application/json" {
|
||||
t.Error("Header not stored correctly")
|
||||
}
|
||||
if string(p.Body) != `{}` {
|
||||
t.Error("Body not stored correctly")
|
||||
}
|
||||
if p.Version != "1.0.0" {
|
||||
t.Error("Version not stored correctly")
|
||||
}
|
||||
}
|
||||
@@ -9,8 +9,12 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/auth"
|
||||
"github.com/fujin/anthropic-proxy/internal/logging"
|
||||
"github.com/fujin/anthropic-proxy/internal/transport"
|
||||
"github.com/fujin/anthropic-proxy/internal/version"
|
||||
)
|
||||
|
||||
const messagesURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||
@@ -25,7 +29,7 @@ func NewUpstreamClient(profile *SniffedProfile) *UpstreamClient {
|
||||
return &UpstreamClient{
|
||||
client: http.Client{
|
||||
Timeout: 0,
|
||||
Transport: newUtlsRoundTripper(),
|
||||
Transport: transport.NewUTLS(),
|
||||
},
|
||||
sessionID: uuid.New().String(),
|
||||
profile: profile,
|
||||
@@ -36,7 +40,7 @@ func (u *UpstreamClient) version() string {
|
||||
if u.profile != nil && u.profile.Version != "" {
|
||||
return u.profile.Version
|
||||
}
|
||||
return "2.1.92"
|
||||
return version.ClaudeCodeFallback
|
||||
}
|
||||
|
||||
// applyHeaders replays sniffed headers, substituting auth + per-request IDs + accept.
|
||||
@@ -51,6 +55,15 @@ func (u *UpstreamClient) applyHeaders(req *http.Request, token string, streaming
|
||||
req.Header.Del("x-api-key")
|
||||
if strings.HasPrefix(token, "sk-ant-oat") {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
// OAuth tokens require this beta flag — without it the API rejects with 401
|
||||
existing := req.Header.Get("anthropic-beta")
|
||||
if !strings.Contains(existing, "oauth-2025-04-20") {
|
||||
if existing == "" {
|
||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
} else {
|
||||
req.Header.Set("anthropic-beta", existing+",oauth-2025-04-20")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
req.Header.Set("x-api-key", token)
|
||||
}
|
||||
@@ -75,6 +88,12 @@ func (u *UpstreamClient) Execute(ctx context.Context, cred *auth.Credential, bod
|
||||
}
|
||||
u.applyHeaders(req, cred.Token(), false)
|
||||
|
||||
log.Debug().
|
||||
Str("url", messagesURL).
|
||||
Str("upstream_headers", logging.RedactHeaders(req.Header)).
|
||||
Int("body_size", len(body)).
|
||||
Msg("upstream request")
|
||||
|
||||
resp, err := u.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, 0, fmt.Errorf("upstream request: %w", err)
|
||||
@@ -85,6 +104,12 @@ func (u *UpstreamClient) Execute(ctx context.Context, cred *auth.Credential, bod
|
||||
if err != nil {
|
||||
return nil, nil, resp.StatusCode, fmt.Errorf("read upstream response: %w", err)
|
||||
}
|
||||
log.Debug().
|
||||
Int("status", resp.StatusCode).
|
||||
Str("response_headers", logging.RedactHeaders(resp.Header)).
|
||||
Int("response_size", len(respBody)).
|
||||
Msg("upstream response")
|
||||
|
||||
return respBody, resp.Header, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
@@ -97,9 +122,21 @@ func (u *UpstreamClient) ExecuteStream(ctx context.Context, cred *auth.Credentia
|
||||
}
|
||||
u.applyHeaders(req, cred.Token(), true)
|
||||
|
||||
log.Debug().
|
||||
Str("url", messagesURL).
|
||||
Str("upstream_headers", logging.RedactHeaders(req.Header)).
|
||||
Int("body_size", len(body)).
|
||||
Msg("upstream stream request")
|
||||
|
||||
resp, err := u.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upstream stream request: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Int("status", resp.StatusCode).
|
||||
Str("response_headers", logging.RedactHeaders(resp.Header)).
|
||||
Msg("upstream stream response")
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,334 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// --- NewUpstreamClient ---
|
||||
|
||||
func TestNewUpstreamClient_NilProfile(t *testing.T) {
|
||||
uc := NewUpstreamClient(nil)
|
||||
if uc == nil {
|
||||
t.Fatal("NewUpstreamClient returned nil")
|
||||
}
|
||||
if uc.sessionID == "" {
|
||||
t.Error("expected non-empty sessionID")
|
||||
}
|
||||
if uc.profile != nil {
|
||||
t.Error("expected nil profile")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUpstreamClient_WithProfile(t *testing.T) {
|
||||
profile := &SniffedProfile{
|
||||
Version: "1.2.3",
|
||||
Headers: [][2]string{{"User-Agent", "test/1.0"}},
|
||||
}
|
||||
uc := NewUpstreamClient(profile)
|
||||
if uc.profile != profile {
|
||||
t.Error("expected profile to be stored")
|
||||
}
|
||||
if uc.sessionID == "" {
|
||||
t.Error("expected non-empty sessionID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUpstreamClient_UniqueSessionIDs(t *testing.T) {
|
||||
uc1 := NewUpstreamClient(nil)
|
||||
uc2 := NewUpstreamClient(nil)
|
||||
if uc1.sessionID == uc2.sessionID {
|
||||
t.Errorf("expected different session IDs, both got %q", uc1.sessionID)
|
||||
}
|
||||
}
|
||||
|
||||
// --- version() ---
|
||||
|
||||
func TestVersion_WithProfileVersion(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
profile: &SniffedProfile{Version: "3.5.7"},
|
||||
}
|
||||
if got := uc.version(); got != "3.5.7" {
|
||||
t.Errorf("version() = %q, want %q", got, "3.5.7")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVersion_NilProfile_Fallback(t *testing.T) {
|
||||
uc := &UpstreamClient{profile: nil}
|
||||
if got := uc.version(); got != "2.1.92" {
|
||||
t.Errorf("version() = %q, want %q", got, "2.1.92")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVersion_EmptyProfileVersion_Fallback(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
profile: &SniffedProfile{Version: ""},
|
||||
}
|
||||
if got := uc.version(); got != "2.1.92" {
|
||||
t.Errorf("version() = %q, want %q", got, "2.1.92")
|
||||
}
|
||||
}
|
||||
|
||||
// --- applyHeaders ---
|
||||
|
||||
func TestApplyHeaders_NilProfile_NonOAuth_NonStream(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "test-session-id",
|
||||
profile: nil,
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-api123", false)
|
||||
|
||||
// x-api-key for non-OAuth token
|
||||
if got := req.Header.Get("x-api-key"); got != "sk-ant-api123" {
|
||||
t.Errorf("x-api-key = %q, want %q", got, "sk-ant-api123")
|
||||
}
|
||||
// Should NOT have Authorization
|
||||
if got := req.Header.Get("Authorization"); got != "" {
|
||||
t.Errorf("Authorization = %q, want empty", got)
|
||||
}
|
||||
// Session ID
|
||||
if got := req.Header.Get("X-Claude-Code-Session-Id"); got != "test-session-id" {
|
||||
t.Errorf("X-Claude-Code-Session-Id = %q, want %q", got, "test-session-id")
|
||||
}
|
||||
// Request ID should be a UUID
|
||||
if got := req.Header.Get("x-client-request-id"); got == "" {
|
||||
t.Error("expected non-empty x-client-request-id")
|
||||
}
|
||||
// Non-stream: application/json
|
||||
if got := req.Header.Get("Accept"); got != "application/json" {
|
||||
t.Errorf("Accept = %q, want %q", got, "application/json")
|
||||
}
|
||||
// Accept-Encoding always identity
|
||||
if got := req.Header.Get("Accept-Encoding"); got != "identity" {
|
||||
t.Errorf("Accept-Encoding = %q, want %q", got, "identity")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_NilProfile_NonOAuth_Stream(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "sess",
|
||||
profile: nil,
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-api123", true)
|
||||
|
||||
if got := req.Header.Get("Accept"); got != "text/event-stream" {
|
||||
t.Errorf("Accept = %q, want %q", got, "text/event-stream")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_OAuthToken_SetsBearerAndBetaFlag(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "sess",
|
||||
profile: nil,
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-oat-mytoken", false)
|
||||
|
||||
// OAuth: Authorization Bearer
|
||||
if got := req.Header.Get("Authorization"); got != "Bearer sk-ant-oat-mytoken" {
|
||||
t.Errorf("Authorization = %q, want %q", got, "Bearer sk-ant-oat-mytoken")
|
||||
}
|
||||
// Should NOT have x-api-key
|
||||
if got := req.Header.Get("x-api-key"); got != "" {
|
||||
t.Errorf("x-api-key = %q, want empty for OAuth", got)
|
||||
}
|
||||
// anthropic-beta should include oauth-2025-04-20
|
||||
if got := req.Header.Get("anthropic-beta"); got != "oauth-2025-04-20" {
|
||||
t.Errorf("anthropic-beta = %q, want %q", got, "oauth-2025-04-20")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_OAuthToken_AppendsToExistingBeta(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "sess",
|
||||
profile: &SniffedProfile{
|
||||
Headers: [][2]string{
|
||||
{"anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15"},
|
||||
},
|
||||
},
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-oat-tok", false)
|
||||
|
||||
beta := req.Header.Get("anthropic-beta")
|
||||
if !strings.Contains(beta, "max-tokens-3-5-sonnet-2024-07-15") {
|
||||
t.Errorf("anthropic-beta %q should contain existing beta", beta)
|
||||
}
|
||||
if !strings.Contains(beta, "oauth-2025-04-20") {
|
||||
t.Errorf("anthropic-beta %q should contain oauth flag", beta)
|
||||
}
|
||||
// Should be appended with comma
|
||||
if beta != "max-tokens-3-5-sonnet-2024-07-15,oauth-2025-04-20" {
|
||||
t.Errorf("anthropic-beta = %q, want %q", beta, "max-tokens-3-5-sonnet-2024-07-15,oauth-2025-04-20")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_OAuthToken_ExistingBetaAlreadyHasOAuth(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "sess",
|
||||
profile: &SniffedProfile{
|
||||
Headers: [][2]string{
|
||||
{"anthropic-beta", "oauth-2025-04-20,something-else"},
|
||||
},
|
||||
},
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-oat-tok", false)
|
||||
|
||||
beta := req.Header.Get("anthropic-beta")
|
||||
// Should NOT duplicate oauth flag
|
||||
count := strings.Count(beta, "oauth-2025-04-20")
|
||||
if count != 1 {
|
||||
t.Errorf("oauth flag appeared %d times in %q, want 1", count, beta)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_WithProfile_ReplaysHeaders(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "sess",
|
||||
profile: &SniffedProfile{
|
||||
Headers: [][2]string{
|
||||
{"User-Agent", "Claude/1.0"},
|
||||
{"anthropic-version", "2023-06-01"},
|
||||
{"Custom-Header", "custom-value"},
|
||||
},
|
||||
},
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-api123", false)
|
||||
|
||||
if got := req.Header.Get("User-Agent"); got != "Claude/1.0" {
|
||||
t.Errorf("User-Agent = %q, want %q", got, "Claude/1.0")
|
||||
}
|
||||
if got := req.Header.Get("anthropic-version"); got != "2023-06-01" {
|
||||
t.Errorf("anthropic-version = %q, want %q", got, "2023-06-01")
|
||||
}
|
||||
if got := req.Header.Get("Custom-Header"); got != "custom-value" {
|
||||
t.Errorf("Custom-Header = %q, want %q", got, "custom-value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_ProfileAuthHeadersRemoved(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "sess",
|
||||
profile: &SniffedProfile{
|
||||
Headers: [][2]string{
|
||||
{"Authorization", "Bearer old-token"},
|
||||
{"x-api-key", "old-api-key"},
|
||||
{"User-Agent", "test"},
|
||||
},
|
||||
},
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-api-new", false)
|
||||
|
||||
// Old auth headers from profile should be removed
|
||||
if got := req.Header.Get("Authorization"); got != "" {
|
||||
t.Errorf("Authorization should be empty for non-OAuth, got %q", got)
|
||||
}
|
||||
// New auth should be set via x-api-key
|
||||
if got := req.Header.Get("x-api-key"); got != "sk-ant-api-new" {
|
||||
t.Errorf("x-api-key = %q, want %q", got, "sk-ant-api-new")
|
||||
}
|
||||
// User-Agent from profile should remain
|
||||
if got := req.Header.Get("User-Agent"); got != "test" {
|
||||
t.Errorf("User-Agent = %q, want %q", got, "test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_ProfileAuthHeadersRemovedForOAuth(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "sess",
|
||||
profile: &SniffedProfile{
|
||||
Headers: [][2]string{
|
||||
{"Authorization", "Bearer old-token"},
|
||||
{"x-api-key", "old-api-key"},
|
||||
},
|
||||
},
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-oat-new", false)
|
||||
|
||||
// Old x-api-key removed
|
||||
if got := req.Header.Get("x-api-key"); got != "" {
|
||||
t.Errorf("x-api-key should be empty for OAuth, got %q", got)
|
||||
}
|
||||
// New auth set via Authorization
|
||||
if got := req.Header.Get("Authorization"); got != "Bearer sk-ant-oat-new" {
|
||||
t.Errorf("Authorization = %q, want %q", got, "Bearer sk-ant-oat-new")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_AcceptEncoding_AlwaysIdentity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
streaming bool
|
||||
}{
|
||||
{"non-stream", false},
|
||||
{"stream", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
uc := &UpstreamClient{sessionID: "s", profile: nil}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
uc.applyHeaders(req, "token", tt.streaming)
|
||||
|
||||
if got := req.Header.Get("Accept-Encoding"); got != "identity" {
|
||||
t.Errorf("Accept-Encoding = %q, want %q", got, "identity")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_UniqueRequestIDs(t *testing.T) {
|
||||
uc := &UpstreamClient{sessionID: "s", profile: nil}
|
||||
|
||||
req1, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
uc.applyHeaders(req1, "tok", false)
|
||||
|
||||
req2, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
uc.applyHeaders(req2, "tok", false)
|
||||
|
||||
id1 := req1.Header.Get("x-client-request-id")
|
||||
id2 := req2.Header.Get("x-client-request-id")
|
||||
if id1 == "" || id2 == "" {
|
||||
t.Fatal("expected non-empty request IDs")
|
||||
}
|
||||
if id1 == id2 {
|
||||
t.Errorf("expected unique request IDs, both got %q", id1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_NonOAuth_NoAnthroPicBetaSet(t *testing.T) {
|
||||
// Non-OAuth tokens should NOT set anthropic-beta oauth flag
|
||||
uc := &UpstreamClient{sessionID: "s", profile: nil}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
uc.applyHeaders(req, "sk-ant-api123", false)
|
||||
|
||||
beta := req.Header.Get("anthropic-beta")
|
||||
if strings.Contains(beta, "oauth-2025-04-20") {
|
||||
t.Errorf("non-OAuth token should not have oauth beta flag, got %q", beta)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_OAuthToken_FreshBeta(t *testing.T) {
|
||||
// No profile, no existing beta — should set fresh
|
||||
uc := &UpstreamClient{sessionID: "s", profile: nil}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
uc.applyHeaders(req, "sk-ant-oat-tok", false)
|
||||
|
||||
if got := req.Header.Get("anthropic-beta"); got != "oauth-2025-04-20" {
|
||||
t.Errorf("anthropic-beta = %q, want %q", got, "oauth-2025-04-20")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Window holds per-window usage state.
|
||||
type Window struct {
|
||||
Utilization float64 // 0-100 from API
|
||||
ResetsAt time.Time // when window resets
|
||||
}
|
||||
|
||||
// Snapshot is a read-only copy of a Window's state.
|
||||
type Snapshot struct {
|
||||
Utilization float64
|
||||
ResetsAt time.Time
|
||||
}
|
||||
|
||||
// Tracker polls /api/oauth/usage and tracks per-window utilization.
|
||||
type Tracker struct {
|
||||
tokenFn func() string
|
||||
mu sync.RWMutex
|
||||
fiveHour Window
|
||||
sevenDay Window
|
||||
sonnet Window
|
||||
extra ExtraUsage
|
||||
}
|
||||
|
||||
// NewTracker creates a tracker. tokenFn should return the current access token.
|
||||
func NewTracker(tokenFn func() string) *Tracker {
|
||||
return &Tracker{tokenFn: tokenFn}
|
||||
}
|
||||
|
||||
// Start begins the background poll loop.
|
||||
func (t *Tracker) Start(ctx context.Context) {
|
||||
go func() {
|
||||
t.poll(ctx)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(5 * time.Minute):
|
||||
t.poll(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// UpdateFromHeaders extracts rate limit data from /v1/messages response headers.
|
||||
func (t *Tracker) UpdateFromHeaders(h http.Header) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if v := h.Get("Anthropic-Ratelimit-Unified-5h-Utilization"); v != "" {
|
||||
if f, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
t.fiveHour.Utilization = f * 100
|
||||
}
|
||||
}
|
||||
if v := h.Get("Anthropic-Ratelimit-Unified-5h-Reset"); v != "" {
|
||||
if ts, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
t.fiveHour.ResetsAt = time.Unix(ts, 0).UTC().Truncate(time.Minute)
|
||||
}
|
||||
}
|
||||
if v := h.Get("Anthropic-Ratelimit-Unified-7d-Utilization"); v != "" {
|
||||
if f, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
t.sevenDay.Utilization = f * 100
|
||||
}
|
||||
}
|
||||
if v := h.Get("Anthropic-Ratelimit-Unified-7d-Reset"); v != "" {
|
||||
if ts, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
t.sevenDay.ResetsAt = time.Unix(ts, 0).UTC().Truncate(time.Minute)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FiveHour returns a snapshot of the 5-hour window.
|
||||
func (t *Tracker) FiveHour() Snapshot {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return Snapshot{Utilization: t.fiveHour.Utilization, ResetsAt: t.fiveHour.ResetsAt}
|
||||
}
|
||||
|
||||
// SevenDay returns a snapshot of the 7-day window.
|
||||
func (t *Tracker) SevenDay() Snapshot {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return Snapshot{Utilization: t.sevenDay.Utilization, ResetsAt: t.sevenDay.ResetsAt}
|
||||
}
|
||||
|
||||
// Sonnet returns a snapshot of the 7-day sonnet window.
|
||||
func (t *Tracker) Sonnet() Snapshot {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return Snapshot{Utilization: t.sonnet.Utilization, ResetsAt: t.sonnet.ResetsAt}
|
||||
}
|
||||
|
||||
// Extra returns the current extra usage state.
|
||||
func (t *Tracker) Extra() ExtraUsage {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return t.extra
|
||||
}
|
||||
|
||||
func (t *Tracker) poll(ctx context.Context) {
|
||||
token := t.tokenFn()
|
||||
if token == "" {
|
||||
return
|
||||
}
|
||||
|
||||
usage, err := fetchUsage(ctx, token)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("usage poll failed")
|
||||
return
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if usage.FiveHour != nil {
|
||||
t.updateWindow(&t.fiveHour, usage.FiveHour)
|
||||
}
|
||||
if usage.SevenDay != nil {
|
||||
t.updateWindow(&t.sevenDay, usage.SevenDay)
|
||||
}
|
||||
if usage.SevenDaySonnet != nil {
|
||||
t.updateWindow(&t.sonnet, usage.SevenDaySonnet)
|
||||
}
|
||||
if usage.ExtraUsage != nil {
|
||||
t.extra = *usage.ExtraUsage
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Float64("5h_util", t.fiveHour.Utilization).
|
||||
Time("5h_resets", t.fiveHour.ResetsAt).
|
||||
Float64("7d_util", t.sevenDay.Utilization).
|
||||
Time("7d_resets", t.sevenDay.ResetsAt).
|
||||
Msg("usage poll")
|
||||
|
||||
if t.fiveHour.Utilization > 80 {
|
||||
log.Warn().Float64("utilization", t.fiveHour.Utilization).Time("resets_at", t.fiveHour.ResetsAt).Msg("5h window utilization high")
|
||||
}
|
||||
if t.sevenDay.Utilization > 80 {
|
||||
log.Warn().Float64("utilization", t.sevenDay.Utilization).Time("resets_at", t.sevenDay.ResetsAt).Msg("7d window utilization high")
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tracker) updateWindow(w *Window, rl *RateLimit) {
|
||||
if rl.Utilization != nil {
|
||||
w.Utilization = *rl.Utilization
|
||||
}
|
||||
if rl.ResetsAt != nil {
|
||||
parsed, err := time.Parse(time.RFC3339Nano, *rl.ResetsAt)
|
||||
if err != nil {
|
||||
parsed, err = time.Parse(time.RFC3339, *rl.ResetsAt)
|
||||
}
|
||||
if err == nil {
|
||||
w.ResetsAt = parsed.UTC().Truncate(time.Minute)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,278 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewTracker(t *testing.T) {
|
||||
called := false
|
||||
tr := NewTracker(func() string {
|
||||
called = true
|
||||
return "tok"
|
||||
})
|
||||
if tr == nil {
|
||||
t.Fatal("NewTracker returned nil")
|
||||
}
|
||||
// tokenFn stored but not called during construction
|
||||
if called {
|
||||
t.Error("tokenFn should not be called by NewTracker")
|
||||
}
|
||||
// Invoke to verify it's wired
|
||||
if got := tr.tokenFn(); got != "tok" {
|
||||
t.Errorf("tokenFn() = %q, want tok", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateFromHeaders_Full(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
h := http.Header{}
|
||||
h.Set("Anthropic-Ratelimit-Unified-5h-Utilization", "0.42")
|
||||
h.Set("Anthropic-Ratelimit-Unified-5h-Reset", "1700000000")
|
||||
h.Set("Anthropic-Ratelimit-Unified-7d-Utilization", "0.75")
|
||||
h.Set("Anthropic-Ratelimit-Unified-7d-Reset", "1700100000")
|
||||
|
||||
tr.UpdateFromHeaders(h)
|
||||
|
||||
fh := tr.FiveHour()
|
||||
if fh.Utilization != 42.0 {
|
||||
t.Errorf("FiveHour.Utilization = %f, want 42.0", fh.Utilization)
|
||||
}
|
||||
wantReset5h := time.Unix(1700000000, 0).UTC().Truncate(time.Minute)
|
||||
if !fh.ResetsAt.Equal(wantReset5h) {
|
||||
t.Errorf("FiveHour.ResetsAt = %v, want %v", fh.ResetsAt, wantReset5h)
|
||||
}
|
||||
|
||||
sd := tr.SevenDay()
|
||||
if sd.Utilization != 75.0 {
|
||||
t.Errorf("SevenDay.Utilization = %f, want 75.0", sd.Utilization)
|
||||
}
|
||||
wantReset7d := time.Unix(1700100000, 0).UTC().Truncate(time.Minute)
|
||||
if !sd.ResetsAt.Equal(wantReset7d) {
|
||||
t.Errorf("SevenDay.ResetsAt = %v, want %v", sd.ResetsAt, wantReset7d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateFromHeaders_Partial(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
// Only set 5h utilization, no reset, no 7d
|
||||
h := http.Header{}
|
||||
h.Set("Anthropic-Ratelimit-Unified-5h-Utilization", "0.33")
|
||||
tr.UpdateFromHeaders(h)
|
||||
|
||||
fh := tr.FiveHour()
|
||||
if fh.Utilization != 33.0 {
|
||||
t.Errorf("FiveHour.Utilization = %f, want 33.0", fh.Utilization)
|
||||
}
|
||||
if !fh.ResetsAt.IsZero() {
|
||||
t.Errorf("FiveHour.ResetsAt should be zero, got %v", fh.ResetsAt)
|
||||
}
|
||||
|
||||
sd := tr.SevenDay()
|
||||
if sd.Utilization != 0 {
|
||||
t.Errorf("SevenDay.Utilization = %f, want 0", sd.Utilization)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateFromHeaders_Missing(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
// Pre-set some state
|
||||
tr.mu.Lock()
|
||||
tr.fiveHour.Utilization = 50.0
|
||||
tr.mu.Unlock()
|
||||
|
||||
// Update with empty headers — should not change state
|
||||
tr.UpdateFromHeaders(http.Header{})
|
||||
|
||||
fh := tr.FiveHour()
|
||||
if fh.Utilization != 50.0 {
|
||||
t.Errorf("FiveHour.Utilization = %f, want 50.0 (unchanged)", fh.Utilization)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateFromHeaders_InvalidValues(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
h := http.Header{}
|
||||
h.Set("Anthropic-Ratelimit-Unified-5h-Utilization", "not-a-number")
|
||||
h.Set("Anthropic-Ratelimit-Unified-5h-Reset", "not-a-timestamp")
|
||||
|
||||
tr.UpdateFromHeaders(h)
|
||||
|
||||
fh := tr.FiveHour()
|
||||
if fh.Utilization != 0 {
|
||||
t.Errorf("Utilization should stay 0 for invalid input, got %f", fh.Utilization)
|
||||
}
|
||||
if !fh.ResetsAt.IsZero() {
|
||||
t.Errorf("ResetsAt should stay zero for invalid input, got %v", fh.ResetsAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSonnet_Snapshot(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
// Sonnet is only set via poll/updateWindow, not UpdateFromHeaders
|
||||
// Verify it starts at zero
|
||||
s := tr.Sonnet()
|
||||
if s.Utilization != 0 {
|
||||
t.Errorf("Sonnet.Utilization = %f, want 0", s.Utilization)
|
||||
}
|
||||
if !s.ResetsAt.IsZero() {
|
||||
t.Errorf("Sonnet.ResetsAt should be zero, got %v", s.ResetsAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtra_Default(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
extra := tr.Extra()
|
||||
if extra.IsEnabled {
|
||||
t.Error("Extra.IsEnabled should be false by default")
|
||||
}
|
||||
if extra.MonthlyLimit != nil {
|
||||
t.Error("Extra.MonthlyLimit should be nil by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateWindow(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
util *float64
|
||||
resetsAt *string
|
||||
wantUtil float64
|
||||
wantResetOK bool
|
||||
}{
|
||||
{
|
||||
name: "both fields",
|
||||
util: float64Ptr(65.5),
|
||||
resetsAt: stringPtr("2024-01-15T10:30:45Z"),
|
||||
wantUtil: 65.5,
|
||||
wantResetOK: true,
|
||||
},
|
||||
{
|
||||
name: "utilization only",
|
||||
util: float64Ptr(30.0),
|
||||
resetsAt: nil,
|
||||
wantUtil: 30.0,
|
||||
wantResetOK: false,
|
||||
},
|
||||
{
|
||||
name: "reset only (RFC3339Nano)",
|
||||
util: nil,
|
||||
resetsAt: stringPtr("2024-06-01T12:00:00.123456789Z"),
|
||||
wantUtil: 0,
|
||||
wantResetOK: true,
|
||||
},
|
||||
{
|
||||
name: "nil both",
|
||||
util: nil,
|
||||
resetsAt: nil,
|
||||
wantUtil: 0,
|
||||
wantResetOK: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := &Window{}
|
||||
rl := &RateLimit{
|
||||
Utilization: tt.util,
|
||||
ResetsAt: tt.resetsAt,
|
||||
}
|
||||
tr.updateWindow(w, rl)
|
||||
|
||||
if w.Utilization != tt.wantUtil {
|
||||
t.Errorf("Utilization = %f, want %f", w.Utilization, tt.wantUtil)
|
||||
}
|
||||
if tt.wantResetOK {
|
||||
if w.ResetsAt.IsZero() {
|
||||
t.Error("ResetsAt should be set")
|
||||
}
|
||||
// Verify truncation to minute
|
||||
if w.ResetsAt.Second() != 0 || w.ResetsAt.Nanosecond() != 0 {
|
||||
t.Errorf("ResetsAt not truncated to minute: %v", w.ResetsAt)
|
||||
}
|
||||
if w.ResetsAt.Location() != time.UTC {
|
||||
t.Errorf("ResetsAt not in UTC: %v", w.ResetsAt.Location())
|
||||
}
|
||||
} else if tt.resetsAt == nil {
|
||||
if !w.ResetsAt.IsZero() {
|
||||
t.Errorf("ResetsAt should be zero when input is nil, got %v", w.ResetsAt)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateWindow_InvalidTime(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
w := &Window{}
|
||||
bad := "not-a-time"
|
||||
rl := &RateLimit{ResetsAt: &bad}
|
||||
tr.updateWindow(w, rl)
|
||||
if !w.ResetsAt.IsZero() {
|
||||
t.Errorf("ResetsAt should stay zero for invalid time, got %v", w.ResetsAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoll_SetsStateFromUsageResponse(t *testing.T) {
|
||||
// White-box: directly set fields that poll would set after fetchUsage
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
// Simulate what poll does after fetching usage
|
||||
tr.mu.Lock()
|
||||
usage := &UsageResponse{
|
||||
FiveHour: &RateLimit{Utilization: float64Ptr(55.5), ResetsAt: stringPtr("2024-03-01T08:00:00Z")},
|
||||
SevenDay: &RateLimit{Utilization: float64Ptr(22.3), ResetsAt: stringPtr("2024-03-07T00:00:00Z")},
|
||||
SevenDaySonnet: &RateLimit{Utilization: float64Ptr(10.0), ResetsAt: stringPtr("2024-03-07T00:00:00Z")},
|
||||
ExtraUsage: &ExtraUsage{IsEnabled: true, MonthlyLimit: float64Ptr(100.0), UsedCredits: float64Ptr(42.5)},
|
||||
}
|
||||
if usage.FiveHour != nil {
|
||||
tr.updateWindow(&tr.fiveHour, usage.FiveHour)
|
||||
}
|
||||
if usage.SevenDay != nil {
|
||||
tr.updateWindow(&tr.sevenDay, usage.SevenDay)
|
||||
}
|
||||
if usage.SevenDaySonnet != nil {
|
||||
tr.updateWindow(&tr.sonnet, usage.SevenDaySonnet)
|
||||
}
|
||||
if usage.ExtraUsage != nil {
|
||||
tr.extra = *usage.ExtraUsage
|
||||
}
|
||||
tr.mu.Unlock()
|
||||
|
||||
fh := tr.FiveHour()
|
||||
if fh.Utilization != 55.5 {
|
||||
t.Errorf("FiveHour.Utilization = %f, want 55.5", fh.Utilization)
|
||||
}
|
||||
|
||||
sd := tr.SevenDay()
|
||||
if sd.Utilization != 22.3 {
|
||||
t.Errorf("SevenDay.Utilization = %f, want 22.3", sd.Utilization)
|
||||
}
|
||||
|
||||
sn := tr.Sonnet()
|
||||
if sn.Utilization != 10.0 {
|
||||
t.Errorf("Sonnet.Utilization = %f, want 10.0", sn.Utilization)
|
||||
}
|
||||
|
||||
extra := tr.Extra()
|
||||
if !extra.IsEnabled {
|
||||
t.Error("Extra.IsEnabled = false, want true")
|
||||
}
|
||||
if extra.MonthlyLimit == nil || *extra.MonthlyLimit != 100.0 {
|
||||
t.Errorf("Extra.MonthlyLimit = %v, want 100.0", extra.MonthlyLimit)
|
||||
}
|
||||
if extra.UsedCredits == nil || *extra.UsedCredits != 42.5 {
|
||||
t.Errorf("Extra.UsedCredits = %v, want 42.5", extra.UsedCredits)
|
||||
}
|
||||
}
|
||||
|
||||
func float64Ptr(f float64) *float64 { return &f }
|
||||
func stringPtr(s string) *string { return &s }
|
||||
@@ -0,0 +1,67 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/transport"
|
||||
"github.com/fujin/anthropic-proxy/internal/version"
|
||||
)
|
||||
|
||||
var usageClient = transport.NewHTTPClient(10 * time.Second)
|
||||
|
||||
const usageURL = "https://api.anthropic.com/api/oauth/usage"
|
||||
|
||||
type RateLimit struct {
|
||||
Utilization *float64 `json:"utilization"` // 0-100
|
||||
ResetsAt *string `json:"resets_at"` // ISO 8601
|
||||
}
|
||||
|
||||
type ExtraUsage struct {
|
||||
IsEnabled bool `json:"is_enabled"`
|
||||
MonthlyLimit *float64 `json:"monthly_limit"`
|
||||
UsedCredits *float64 `json:"used_credits"`
|
||||
Utilization *float64 `json:"utilization"`
|
||||
}
|
||||
|
||||
type UsageResponse struct {
|
||||
FiveHour *RateLimit `json:"five_hour"`
|
||||
SevenDay *RateLimit `json:"seven_day"`
|
||||
SevenDaySonnet *RateLimit `json:"seven_day_sonnet"`
|
||||
ExtraUsage *ExtraUsage `json:"extra_usage"`
|
||||
}
|
||||
|
||||
func fetchUsage(ctx context.Context, token string) (*UsageResponse, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, usageURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
req.Header.Set("User-Agent", "claude-cli/"+version.ClaudeCodeFallback)
|
||||
|
||||
resp, err := usageClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("usage returned %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var usage UsageResponse
|
||||
if err := json.Unmarshal(body, &usage); err != nil {
|
||||
return nil, fmt.Errorf("decode: %w", err)
|
||||
}
|
||||
return &usage, nil
|
||||
}
|
||||
@@ -0,0 +1,241 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUsageResponse_FullJSON(t *testing.T) {
|
||||
raw := `{
|
||||
"five_hour": {"utilization": 42.5, "resets_at": "2024-01-15T10:30:00Z"},
|
||||
"seven_day": {"utilization": 75.0, "resets_at": "2024-01-20T00:00:00Z"},
|
||||
"seven_day_sonnet": {"utilization": 10.0, "resets_at": "2024-01-20T00:00:00Z"},
|
||||
"extra_usage": {
|
||||
"is_enabled": true,
|
||||
"monthly_limit": 100.0,
|
||||
"used_credits": 42.5,
|
||||
"utilization": 42.5
|
||||
}
|
||||
}`
|
||||
|
||||
var resp UsageResponse
|
||||
if err := json.Unmarshal([]byte(raw), &resp); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if resp.FiveHour == nil {
|
||||
t.Fatal("FiveHour is nil")
|
||||
}
|
||||
if resp.FiveHour.Utilization == nil || *resp.FiveHour.Utilization != 42.5 {
|
||||
t.Errorf("FiveHour.Utilization = %v, want 42.5", resp.FiveHour.Utilization)
|
||||
}
|
||||
if resp.FiveHour.ResetsAt == nil || *resp.FiveHour.ResetsAt != "2024-01-15T10:30:00Z" {
|
||||
t.Errorf("FiveHour.ResetsAt = %v", resp.FiveHour.ResetsAt)
|
||||
}
|
||||
|
||||
if resp.SevenDay == nil {
|
||||
t.Fatal("SevenDay is nil")
|
||||
}
|
||||
if resp.SevenDay.Utilization == nil || *resp.SevenDay.Utilization != 75.0 {
|
||||
t.Errorf("SevenDay.Utilization = %v, want 75.0", resp.SevenDay.Utilization)
|
||||
}
|
||||
|
||||
if resp.SevenDaySonnet == nil {
|
||||
t.Fatal("SevenDaySonnet is nil")
|
||||
}
|
||||
if resp.SevenDaySonnet.Utilization == nil || *resp.SevenDaySonnet.Utilization != 10.0 {
|
||||
t.Errorf("SevenDaySonnet.Utilization = %v", resp.SevenDaySonnet.Utilization)
|
||||
}
|
||||
|
||||
if resp.ExtraUsage == nil {
|
||||
t.Fatal("ExtraUsage is nil")
|
||||
}
|
||||
if !resp.ExtraUsage.IsEnabled {
|
||||
t.Error("ExtraUsage.IsEnabled = false, want true")
|
||||
}
|
||||
if resp.ExtraUsage.MonthlyLimit == nil || *resp.ExtraUsage.MonthlyLimit != 100.0 {
|
||||
t.Errorf("ExtraUsage.MonthlyLimit = %v, want 100.0", resp.ExtraUsage.MonthlyLimit)
|
||||
}
|
||||
if resp.ExtraUsage.UsedCredits == nil || *resp.ExtraUsage.UsedCredits != 42.5 {
|
||||
t.Errorf("ExtraUsage.UsedCredits = %v, want 42.5", resp.ExtraUsage.UsedCredits)
|
||||
}
|
||||
if resp.ExtraUsage.Utilization == nil || *resp.ExtraUsage.Utilization != 42.5 {
|
||||
t.Errorf("ExtraUsage.Utilization = %v, want 42.5", resp.ExtraUsage.Utilization)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageResponse_PartialJSON(t *testing.T) {
|
||||
raw := `{"five_hour": {"utilization": 10.0}}`
|
||||
|
||||
var resp UsageResponse
|
||||
if err := json.Unmarshal([]byte(raw), &resp); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if resp.FiveHour == nil {
|
||||
t.Fatal("FiveHour is nil")
|
||||
}
|
||||
if resp.FiveHour.Utilization == nil || *resp.FiveHour.Utilization != 10.0 {
|
||||
t.Errorf("FiveHour.Utilization = %v, want 10.0", resp.FiveHour.Utilization)
|
||||
}
|
||||
if resp.FiveHour.ResetsAt != nil {
|
||||
t.Errorf("FiveHour.ResetsAt should be nil, got %v", resp.FiveHour.ResetsAt)
|
||||
}
|
||||
if resp.SevenDay != nil {
|
||||
t.Errorf("SevenDay should be nil, got %v", resp.SevenDay)
|
||||
}
|
||||
if resp.SevenDaySonnet != nil {
|
||||
t.Errorf("SevenDaySonnet should be nil, got %v", resp.SevenDaySonnet)
|
||||
}
|
||||
if resp.ExtraUsage != nil {
|
||||
t.Errorf("ExtraUsage should be nil, got %v", resp.ExtraUsage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageResponse_EmptyJSON(t *testing.T) {
|
||||
var resp UsageResponse
|
||||
if err := json.Unmarshal([]byte(`{}`), &resp); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if resp.FiveHour != nil || resp.SevenDay != nil || resp.SevenDaySonnet != nil || resp.ExtraUsage != nil {
|
||||
t.Error("all fields should be nil for empty JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchUsage_Success(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify request headers
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer test-token" {
|
||||
t.Errorf("Authorization = %q, want 'Bearer test-token'", got)
|
||||
}
|
||||
if got := r.Header.Get("Content-Type"); got != "application/json" {
|
||||
t.Errorf("Content-Type = %q, want application/json", got)
|
||||
}
|
||||
if got := r.Header.Get("anthropic-beta"); got != "oauth-2025-04-20" {
|
||||
t.Errorf("anthropic-beta = %q, want oauth-2025-04-20", got)
|
||||
}
|
||||
if got := r.Header.Get("User-Agent"); got != "claude-cli/2.1.92" {
|
||||
t.Errorf("User-Agent = %q, want claude-cli/2.1.92", got)
|
||||
}
|
||||
if r.Method != http.MethodGet {
|
||||
t.Errorf("Method = %q, want GET", r.Method)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{
|
||||
"five_hour": {"utilization": 50.0, "resets_at": "2024-01-15T10:00:00Z"},
|
||||
"seven_day": {"utilization": 25.0, "resets_at": "2024-01-20T00:00:00Z"}
|
||||
}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// fetchUsage hardcodes usageURL, but we can test via the mock by temporarily
|
||||
// using http.DefaultClient's transport. Instead, we test the handler directly.
|
||||
// The httptest server validates our request expectations above.
|
||||
|
||||
// Make a real request to the test server to verify handler behavior
|
||||
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
req.Header.Set("User-Agent", "claude-cli/2.1.92")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var usage UsageResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&usage); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
|
||||
if usage.FiveHour == nil || *usage.FiveHour.Utilization != 50.0 {
|
||||
t.Errorf("FiveHour.Utilization = %v, want 50.0", usage.FiveHour)
|
||||
}
|
||||
if usage.SevenDay == nil || *usage.SevenDay.Utilization != 25.0 {
|
||||
t.Errorf("SevenDay.Utilization = %v, want 25.0", usage.SevenDay)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchUsage_Non200(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
_, _ = w.Write([]byte(`{"error":"forbidden"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// Simulate the error path: non-200 returns error with status and body
|
||||
resp, err := http.Get(srv.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
t.Fatal("expected non-200 status")
|
||||
}
|
||||
|
||||
// This matches the fetchUsage error format
|
||||
body := make([]byte, 1024)
|
||||
n, _ := resp.Body.Read(body)
|
||||
bodyStr := string(body[:n])
|
||||
if !strings.Contains(bodyStr, "forbidden") {
|
||||
t.Errorf("body = %q, want it to contain 'forbidden'", bodyStr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchUsage_MalformedJSON(t *testing.T) {
|
||||
raw := `{not valid json`
|
||||
var resp UsageResponse
|
||||
err := json.Unmarshal([]byte(raw), &resp)
|
||||
if err == nil {
|
||||
t.Fatal("expected decode error for malformed JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimit_NilFields(t *testing.T) {
|
||||
raw := `{}`
|
||||
var rl RateLimit
|
||||
if err := json.Unmarshal([]byte(raw), &rl); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if rl.Utilization != nil {
|
||||
t.Errorf("Utilization should be nil, got %v", rl.Utilization)
|
||||
}
|
||||
if rl.ResetsAt != nil {
|
||||
t.Errorf("ResetsAt should be nil, got %v", rl.ResetsAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtraUsage_JSON(t *testing.T) {
|
||||
raw := `{"is_enabled":false,"monthly_limit":null,"used_credits":null,"utilization":null}`
|
||||
var eu ExtraUsage
|
||||
if err := json.Unmarshal([]byte(raw), &eu); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if eu.IsEnabled {
|
||||
t.Error("IsEnabled should be false")
|
||||
}
|
||||
if eu.MonthlyLimit != nil {
|
||||
t.Error("MonthlyLimit should be nil")
|
||||
}
|
||||
if eu.UsedCredits != nil {
|
||||
t.Error("UsedCredits should be nil")
|
||||
}
|
||||
if eu.Utilization != nil {
|
||||
t.Error("Utilization should be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageURL_Constant(t *testing.T) {
|
||||
if usageURL != "https://api.anthropic.com/api/oauth/usage" {
|
||||
t.Errorf("usageURL = %q, want https://api.anthropic.com/api/oauth/usage", usageURL)
|
||||
}
|
||||
}
|
||||
@@ -3,16 +3,19 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/auth"
|
||||
"github.com/fujin/anthropic-proxy/internal/config"
|
||||
"github.com/fujin/anthropic-proxy/internal/logging"
|
||||
"github.com/fujin/anthropic-proxy/internal/proxy"
|
||||
"github.com/fujin/anthropic-proxy/internal/ratelimit"
|
||||
"go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
@@ -23,7 +26,7 @@ type Server struct {
|
||||
apiKeys atomic.Pointer[map[string]struct{}]
|
||||
}
|
||||
|
||||
func New(cfg *config.Config, pool *auth.Pool, profile *proxy.SniffedProfile) *Server {
|
||||
func New(cfg *config.Config, pool *auth.Pool, profile *proxy.SniffedProfile, tracker *ratelimit.Tracker, metricsHandler http.Handler) *Server {
|
||||
s := &Server{configPath: "config.yaml"}
|
||||
|
||||
san := proxy.NewSanitizer(cfg.Sanitize)
|
||||
@@ -36,21 +39,29 @@ func New(cfg *config.Config, pool *auth.Pool, profile *proxy.SniffedProfile) *Se
|
||||
engine := gin.New()
|
||||
engine.Use(gin.Recovery())
|
||||
engine.Use(corsMiddleware())
|
||||
if cfg.Telemetry.Export.Enabled() {
|
||||
engine.Use(otelgin.Middleware(cfg.Telemetry.ServiceName))
|
||||
}
|
||||
engine.Use(s.authMiddleware())
|
||||
engine.Use(logging.GinRequestLogger())
|
||||
|
||||
handler := proxy.HandleMessages(pool, profile, func() *proxy.Sanitizer {
|
||||
return s.sanitizer.Load()
|
||||
})
|
||||
}, tracker)
|
||||
engine.POST("/v1/messages", handler)
|
||||
engine.POST("/messages", handler)
|
||||
|
||||
if metricsHandler != nil {
|
||||
engine.GET("/metrics", gin.WrapH(metricsHandler))
|
||||
}
|
||||
|
||||
engine.POST("/reload", s.handleReload())
|
||||
engine.POST("/debug/refresh", handleDebugRefresh(pool))
|
||||
engine.GET("/healthz", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
engine.NoRoute(func(c *gin.Context) {
|
||||
log.Printf("unmatched route: %s %s", c.Request.Method, c.Request.URL.Path)
|
||||
log.Warn().Str("method", c.Request.Method).Str("path", c.Request.URL.Path).Msg("unmatched route")
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
|
||||
})
|
||||
|
||||
@@ -85,8 +96,7 @@ func (s *Server) handleReload() gin.HandlerFunc {
|
||||
keys := makeKeySet(cfg.APIKeys)
|
||||
s.apiKeys.Store(&keys)
|
||||
|
||||
log.Printf("config reloaded: %d tool renames, %d system rules, %d body rules, %d api keys",
|
||||
len(cfg.Sanitize.Tools), len(cfg.Sanitize.System), len(cfg.Sanitize.Body), len(cfg.APIKeys))
|
||||
log.Info().Int("tool_renames", len(cfg.Sanitize.Tools)).Int("system_rules", len(cfg.Sanitize.System)).Int("body_rules", len(cfg.Sanitize.Body)).Int("api_keys", len(cfg.APIKeys)).Msg("config reloaded")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "reloaded",
|
||||
@@ -128,10 +138,16 @@ func corsMiddleware() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// authBypassPaths lists endpoints that do not require API key authentication.
|
||||
var authBypassPaths = map[string]bool{
|
||||
"/healthz": true,
|
||||
"/reload": true,
|
||||
"/metrics": true,
|
||||
}
|
||||
|
||||
func (s *Server) authMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
if path == "/healthz" || path == "/reload" {
|
||||
if authBypassPaths[c.Request.URL.Path] {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -0,0 +1,529 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// --- makeKeySet ---
|
||||
|
||||
func TestMakeKeySet(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keys []string
|
||||
wantN int
|
||||
lookup string
|
||||
found bool
|
||||
}{
|
||||
{
|
||||
name: "nil slice returns empty map",
|
||||
keys: nil,
|
||||
wantN: 0,
|
||||
},
|
||||
{
|
||||
name: "empty slice returns empty map",
|
||||
keys: []string{},
|
||||
wantN: 0,
|
||||
},
|
||||
{
|
||||
name: "single key",
|
||||
keys: []string{"key1"},
|
||||
wantN: 1,
|
||||
lookup: "key1",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "multiple keys",
|
||||
keys: []string{"a", "b", "c"},
|
||||
wantN: 3,
|
||||
lookup: "b",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "missing key not found",
|
||||
keys: []string{"a", "b"},
|
||||
wantN: 2,
|
||||
lookup: "c",
|
||||
found: false,
|
||||
},
|
||||
{
|
||||
name: "duplicate keys deduped",
|
||||
keys: []string{"x", "x", "x"},
|
||||
wantN: 1,
|
||||
lookup: "x",
|
||||
found: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := makeKeySet(tt.keys)
|
||||
if len(got) != tt.wantN {
|
||||
t.Errorf("len(makeKeySet) = %d, want %d", len(got), tt.wantN)
|
||||
}
|
||||
if tt.lookup != "" {
|
||||
_, ok := got[tt.lookup]
|
||||
if ok != tt.found {
|
||||
t.Errorf("keySet[%q] found=%v, want %v", tt.lookup, ok, tt.found)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- corsMiddleware ---
|
||||
|
||||
func TestCorsMiddleware_SetsHeaders(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
handler := corsMiddleware()
|
||||
handler(c)
|
||||
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
|
||||
t.Errorf("Access-Control-Allow-Origin = %q, want %q", got, "*")
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Methods"); got != "GET, POST, PUT, DELETE, OPTIONS" {
|
||||
t.Errorf("Access-Control-Allow-Methods = %q, want %q", got, "GET, POST, PUT, DELETE, OPTIONS")
|
||||
}
|
||||
allowHeaders := w.Header().Get("Access-Control-Allow-Headers")
|
||||
for _, h := range []string{"x-api-key", "anthropic-version", "anthropic-beta", "Authorization", "Content-Type", "Origin"} {
|
||||
if !containsSubstring(allowHeaders, h) {
|
||||
t.Errorf("Access-Control-Allow-Headers %q missing %q", allowHeaders, h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorsMiddleware_OptionsReturns204(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodOptions, "/v1/messages", nil)
|
||||
|
||||
handler := corsMiddleware()
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Errorf("OPTIONS status = %d, want %d", w.Code, http.StatusNoContent)
|
||||
}
|
||||
if !c.IsAborted() {
|
||||
t.Error("expected context to be aborted on OPTIONS")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorsMiddleware_NonOptionsDoesNotAbort(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
handler := corsMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Error("POST request should not be aborted")
|
||||
}
|
||||
}
|
||||
|
||||
// --- authMiddleware ---
|
||||
|
||||
func newServerWithKeys(keys []string) *Server {
|
||||
s := &Server{}
|
||||
keySet := makeKeySet(keys)
|
||||
s.apiKeys.Store(&keySet)
|
||||
return s
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_BypassPaths(t *testing.T) {
|
||||
paths := []string{"/healthz", "/reload", "/metrics"}
|
||||
s := newServerWithKeys(nil) // no keys — would reject if auth checked
|
||||
|
||||
for _, path := range paths {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, path, nil)
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Errorf("path %q should bypass auth but was aborted", path)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_MissingToken_401(t *testing.T) {
|
||||
s := newServerWithKeys([]string{"valid-key"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusUnauthorized)
|
||||
}
|
||||
if !c.IsAborted() {
|
||||
t.Error("expected aborted on missing token")
|
||||
}
|
||||
|
||||
var body map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
if body["error"] != "missing authentication" {
|
||||
t.Errorf("error = %q, want %q", body["error"], "missing authentication")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_InvalidKey_403(t *testing.T) {
|
||||
s := newServerWithKeys([]string{"valid-key"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("x-api-key", "wrong-key")
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusForbidden)
|
||||
}
|
||||
if !c.IsAborted() {
|
||||
t.Error("expected aborted on invalid key")
|
||||
}
|
||||
|
||||
var body map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
if body["error"] != "invalid api key" {
|
||||
t.Errorf("error = %q, want %q", body["error"], "invalid api key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_ValidKey_XApiKey(t *testing.T) {
|
||||
s := newServerWithKeys([]string{"valid-key"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("x-api-key", "valid-key")
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Error("valid key should not abort")
|
||||
}
|
||||
if w.Code == http.StatusUnauthorized || w.Code == http.StatusForbidden {
|
||||
t.Errorf("unexpected status %d for valid key", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_ValidKey_BearerAuth(t *testing.T) {
|
||||
s := newServerWithKeys([]string{"my-token"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("Authorization", "Bearer my-token")
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Error("valid Bearer token should not abort")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_BearerPrefix_Stripped(t *testing.T) {
|
||||
// The token is "my-token", sent as "Bearer my-token". The middleware should
|
||||
// strip "Bearer " and compare "my-token" against the key set.
|
||||
s := newServerWithKeys([]string{"my-token"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("Authorization", "Bearer my-token")
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Error("expected auth to pass with Bearer-prefixed valid key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_AuthorizationWithoutBearer(t *testing.T) {
|
||||
// If Authorization header doesn't have Bearer prefix, TrimPrefix is a no-op,
|
||||
// so the full header value is used as the token.
|
||||
s := newServerWithKeys([]string{"raw-token-value"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("Authorization", "raw-token-value")
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Error("raw Authorization value matching a key should pass")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_XApiKey_FallbackWhenNoAuthHeader(t *testing.T) {
|
||||
// If Authorization is empty, x-api-key is checked.
|
||||
s := newServerWithKeys([]string{"fallback-key"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("x-api-key", "fallback-key")
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Error("x-api-key fallback should pass")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_AuthorizationPreferredOverXApiKey(t *testing.T) {
|
||||
// Both headers set; Authorization takes precedence.
|
||||
s := newServerWithKeys([]string{"auth-key"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("Authorization", "Bearer auth-key")
|
||||
c.Request.Header.Set("x-api-key", "wrong-key")
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Error("Authorization should take precedence over x-api-key")
|
||||
}
|
||||
}
|
||||
|
||||
// --- handleReload ---
|
||||
|
||||
func TestHandleReload_Success(t *testing.T) {
|
||||
// Create a temp config file
|
||||
tmpFile, err := os.CreateTemp("", "config-*.yaml")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp file: %v", err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
configContent := `
|
||||
port: 9999
|
||||
api_keys:
|
||||
- reloaded-key-1
|
||||
- reloaded-key-2
|
||||
sanitize:
|
||||
tools:
|
||||
- from: old_tool
|
||||
to: new_tool
|
||||
system:
|
||||
- match: foo
|
||||
replace: bar
|
||||
body:
|
||||
- match: baz
|
||||
replace: qux
|
||||
`
|
||||
if _, err := tmpFile.WriteString(configContent); err != nil {
|
||||
t.Fatalf("failed to write config: %v", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
|
||||
s := &Server{configPath: tmpFile.Name()}
|
||||
// Initialize with empty values
|
||||
emptyKeys := makeKeySet(nil)
|
||||
s.apiKeys.Store(&emptyKeys)
|
||||
|
||||
emptySan := &atomic.Pointer[interface{}]{}
|
||||
_ = emptySan // just to show we're aware
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/reload", nil)
|
||||
|
||||
handler := s.handleReload()
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d; body: %s", w.Code, http.StatusOK, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if resp["status"] != "reloaded" {
|
||||
t.Errorf("status = %v, want %q", resp["status"], "reloaded")
|
||||
}
|
||||
|
||||
// Verify api keys were updated
|
||||
keys := s.apiKeys.Load()
|
||||
if _, ok := (*keys)["reloaded-key-1"]; !ok {
|
||||
t.Error("expected reloaded-key-1 in api keys after reload")
|
||||
}
|
||||
if _, ok := (*keys)["reloaded-key-2"]; !ok {
|
||||
t.Error("expected reloaded-key-2 in api keys after reload")
|
||||
}
|
||||
if len(*keys) != 2 {
|
||||
t.Errorf("expected 2 api keys, got %d", len(*keys))
|
||||
}
|
||||
|
||||
// Verify sanitizer was updated
|
||||
san := s.sanitizer.Load()
|
||||
if san == nil {
|
||||
t.Fatal("sanitizer is nil after reload")
|
||||
}
|
||||
|
||||
// Check tool_renames in response
|
||||
if toolRenames, ok := resp["tool_renames"].(float64); !ok || int(toolRenames) != 1 {
|
||||
t.Errorf("tool_renames = %v, want 1", resp["tool_renames"])
|
||||
}
|
||||
if apiKeys, ok := resp["api_keys"].(float64); !ok || int(apiKeys) != 2 {
|
||||
t.Errorf("api_keys = %v, want 2", resp["api_keys"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleReload_InvalidConfig(t *testing.T) {
|
||||
s := &Server{configPath: "/nonexistent/path/config.yaml"}
|
||||
emptyKeys := makeKeySet(nil)
|
||||
s.apiKeys.Store(&emptyKeys)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/reload", nil)
|
||||
|
||||
handler := s.handleReload()
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var resp map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
if resp["error"] == "" {
|
||||
t.Error("expected non-empty error message")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Full route tests using httptest ---
|
||||
|
||||
func TestHealthzEndpoint(t *testing.T) {
|
||||
engine := gin.New()
|
||||
engine.Use(corsMiddleware())
|
||||
|
||||
s := newServerWithKeys(nil)
|
||||
engine.Use(s.authMiddleware())
|
||||
|
||||
engine.GET("/healthz", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
|
||||
w := httptest.NewRecorder()
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if body["status"] != "ok" {
|
||||
t.Errorf("status = %q, want %q", body["status"], "ok")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_FullRoute_Rejected(t *testing.T) {
|
||||
engine := gin.New()
|
||||
s := newServerWithKeys([]string{"correct-key"})
|
||||
engine.Use(s.authMiddleware())
|
||||
engine.POST("/v1/messages", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
// No auth header
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
w := httptest.NewRecorder()
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_FullRoute_Accepted(t *testing.T) {
|
||||
engine := gin.New()
|
||||
s := newServerWithKeys([]string{"correct-key"})
|
||||
engine.Use(s.authMiddleware())
|
||||
engine.POST("/v1/messages", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
req.Header.Set("x-api-key", "correct-key")
|
||||
w := httptest.NewRecorder()
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorsMiddleware_FullRoute_OptionsRequest(t *testing.T) {
|
||||
engine := gin.New()
|
||||
engine.Use(corsMiddleware())
|
||||
|
||||
s := newServerWithKeys([]string{"key"})
|
||||
engine.Use(s.authMiddleware())
|
||||
|
||||
engine.POST("/v1/messages", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "/v1/messages", nil)
|
||||
w := httptest.NewRecorder()
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusNoContent)
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
|
||||
t.Errorf("ACAO = %q, want %q", got, "*")
|
||||
}
|
||||
}
|
||||
|
||||
// helper
|
||||
func containsSubstring(s, sub string) bool {
|
||||
return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsStr(s, sub))
|
||||
}
|
||||
|
||||
func containsStr(s, sub string) bool {
|
||||
for i := 0; i <= len(s)-len(sub); i++ {
|
||||
if s[i:i+len(sub)] == sub {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
otellog "go.opentelemetry.io/otel/log"
|
||||
sdklog "go.opentelemetry.io/otel/sdk/log"
|
||||
)
|
||||
|
||||
// LogBridge implements io.Writer and forwards zerolog JSON lines to the
|
||||
// OTel LoggerProvider. It is used as an extra writer in zerolog's MultiWriter
|
||||
// so that logs go to both file and OTLP.
|
||||
type LogBridge struct {
|
||||
provider *sdklog.LoggerProvider
|
||||
}
|
||||
|
||||
func (b *LogBridge) Write(p []byte) (n int, err error) {
|
||||
var entry map[string]interface{}
|
||||
if err := json.Unmarshal(p, &entry); err != nil {
|
||||
return len(p), nil // skip malformed lines
|
||||
}
|
||||
|
||||
logger := b.provider.Logger("zerolog")
|
||||
|
||||
var rec otellog.Record
|
||||
rec.SetTimestamp(time.Now())
|
||||
|
||||
if msg, ok := entry["message"].(string); ok {
|
||||
rec.SetBody(otellog.StringValue(msg))
|
||||
}
|
||||
|
||||
if lvl, ok := entry["level"].(string); ok {
|
||||
rec.SetSeverity(mapSeverity(lvl))
|
||||
}
|
||||
|
||||
// Forward all fields as attributes
|
||||
attrs := make([]otellog.KeyValue, 0, len(entry))
|
||||
for k, v := range entry {
|
||||
if k == "message" || k == "level" || k == "time" {
|
||||
continue
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
attrs = append(attrs, otellog.String(k, val))
|
||||
case float64:
|
||||
attrs = append(attrs, otellog.Float64(k, val))
|
||||
case bool:
|
||||
attrs = append(attrs, otellog.Bool(k, val))
|
||||
default:
|
||||
b, _ := json.Marshal(val)
|
||||
attrs = append(attrs, otellog.String(k, string(b)))
|
||||
}
|
||||
}
|
||||
rec.AddAttributes(attrs...)
|
||||
|
||||
logger.Emit(context.Background(), rec)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func mapSeverity(level string) otellog.Severity {
|
||||
switch level {
|
||||
case "trace":
|
||||
return otellog.SeverityTrace
|
||||
case "debug":
|
||||
return otellog.SeverityDebug
|
||||
case "info":
|
||||
return otellog.SeverityInfo
|
||||
case "warn", "warning":
|
||||
return otellog.SeverityWarn
|
||||
case "error":
|
||||
return otellog.SeverityError
|
||||
case "fatal":
|
||||
return otellog.SeverityFatal
|
||||
case "panic":
|
||||
return otellog.SeverityFatal2
|
||||
default:
|
||||
return otellog.SeverityInfo
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,178 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
otellog "go.opentelemetry.io/otel/log"
|
||||
sdklog "go.opentelemetry.io/otel/sdk/log"
|
||||
)
|
||||
|
||||
func TestMapSeverity(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want otellog.Severity
|
||||
}{
|
||||
{"trace", otellog.SeverityTrace},
|
||||
{"debug", otellog.SeverityDebug},
|
||||
{"info", otellog.SeverityInfo},
|
||||
{"warn", otellog.SeverityWarn},
|
||||
{"warning", otellog.SeverityWarn},
|
||||
{"error", otellog.SeverityError},
|
||||
{"fatal", otellog.SeverityFatal},
|
||||
{"panic", otellog.SeverityFatal2},
|
||||
{"unknown", otellog.SeverityInfo},
|
||||
{"", otellog.SeverityInfo},
|
||||
{"INFO", otellog.SeverityInfo}, // uppercase falls to default
|
||||
{"PANIC", otellog.SeverityInfo}, // uppercase falls to default
|
||||
{"gibberish", otellog.SeverityInfo},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run("level_"+tc.input, func(t *testing.T) {
|
||||
got := mapSeverity(tc.input)
|
||||
if got != tc.want {
|
||||
t.Errorf("mapSeverity(%q) = %v, want %v", tc.input, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newTestBridge(t *testing.T) *LogBridge {
|
||||
t.Helper()
|
||||
provider := sdklog.NewLoggerProvider()
|
||||
t.Cleanup(func() {
|
||||
_ = provider.Shutdown(t.Context())
|
||||
})
|
||||
return &LogBridge{provider: provider}
|
||||
}
|
||||
|
||||
func TestLogBridgeWrite(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{} // will be marshaled to JSON; use string for raw input
|
||||
raw string // if non-empty, use this directly instead of marshaling input
|
||||
}{
|
||||
{
|
||||
name: "valid_json_with_message_level_and_extras",
|
||||
input: map[string]interface{}{
|
||||
"message": "request handled",
|
||||
"level": "info",
|
||||
"method": "GET",
|
||||
"status": float64(200),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "message_only_no_level",
|
||||
input: map[string]interface{}{
|
||||
"message": "hello world",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "level_only_no_message",
|
||||
input: map[string]interface{}{
|
||||
"level": "error",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty_json_object",
|
||||
input: map[string]interface{}{},
|
||||
},
|
||||
{
|
||||
name: "string_float64_bool_attributes",
|
||||
input: map[string]interface{}{
|
||||
"message": "test",
|
||||
"level": "debug",
|
||||
"str_val": "hello",
|
||||
"num_val": float64(3.14),
|
||||
"bool_val": true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "complex_nested_object_attribute",
|
||||
input: map[string]interface{}{
|
||||
"message": "nested",
|
||||
"level": "warn",
|
||||
"nested": map[string]interface{}{"foo": "bar", "n": float64(1)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "time_field_skipped_in_attributes",
|
||||
input: map[string]interface{}{
|
||||
"message": "with time",
|
||||
"level": "info",
|
||||
"time": "2025-01-01T00:00:00Z",
|
||||
"extra": "kept",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "malformed_json",
|
||||
raw: "this is not json at all",
|
||||
},
|
||||
{
|
||||
name: "malformed_json_partial",
|
||||
raw: `{"broken":`,
|
||||
},
|
||||
{
|
||||
name: "array_attribute_marshaled_as_string",
|
||||
input: map[string]interface{}{
|
||||
"message": "arrays",
|
||||
"tags": []interface{}{"a", "b"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "null_value_attribute",
|
||||
input: map[string]interface{}{
|
||||
"message": "nulls",
|
||||
"val": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
bridge := newTestBridge(t)
|
||||
|
||||
var p []byte
|
||||
if tc.raw != "" {
|
||||
p = []byte(tc.raw)
|
||||
} else {
|
||||
var err error
|
||||
p, err = json.Marshal(tc.input)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal test input: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
n, err := bridge.Write(p)
|
||||
if n != len(p) {
|
||||
t.Errorf("Write() returned n=%d, want %d", n, len(p))
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Write() returned err=%v, want nil", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogBridgeWriteAlwaysReturnsLenAndNil(t *testing.T) {
|
||||
bridge := newTestBridge(t)
|
||||
|
||||
inputs := [][]byte{
|
||||
[]byte(`{"message":"ok","level":"info"}`),
|
||||
[]byte(`not json`),
|
||||
[]byte(`{}`),
|
||||
[]byte(``),
|
||||
[]byte(`[]`),
|
||||
}
|
||||
|
||||
for _, p := range inputs {
|
||||
n, err := bridge.Write(p)
|
||||
if n != len(p) {
|
||||
t.Errorf("Write(%q) n=%d, want %d", string(p), n, len(p))
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Write(%q) err=%v, want nil", string(p), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/ratelimit"
|
||||
)
|
||||
|
||||
var (
|
||||
RequestCounter metric.Int64Counter
|
||||
RequestDuration metric.Float64Histogram
|
||||
RequestBodySize metric.Int64Histogram
|
||||
UpstreamErrors metric.Int64Counter
|
||||
TokensInput metric.Int64Counter
|
||||
TokensOutput metric.Int64Counter
|
||||
CredentialCooldowns metric.Int64Counter
|
||||
ActiveCredentials metric.Int64UpDownCounter
|
||||
StreamRequests metric.Int64Counter
|
||||
)
|
||||
|
||||
// InitMetrics creates all metric instruments from the given meter.
|
||||
// If tracker is non-nil, registers observable gauges for per-window usage.
|
||||
func InitMetrics(meter metric.Meter, tracker *ratelimit.Tracker) {
|
||||
RequestCounter, _ = meter.Int64Counter("proxy.request.count",
|
||||
metric.WithDescription("Total proxy requests"),
|
||||
)
|
||||
RequestDuration, _ = meter.Float64Histogram("proxy.request.duration_ms",
|
||||
metric.WithDescription("Request latency in milliseconds"),
|
||||
metric.WithUnit("ms"),
|
||||
)
|
||||
RequestBodySize, _ = meter.Int64Histogram("proxy.request.body_size_bytes",
|
||||
metric.WithDescription("Request body size in bytes"),
|
||||
metric.WithUnit("By"),
|
||||
)
|
||||
UpstreamErrors, _ = meter.Int64Counter("proxy.upstream.errors",
|
||||
metric.WithDescription("Upstream error count"),
|
||||
)
|
||||
TokensInput, _ = meter.Int64Counter("proxy.tokens.input",
|
||||
metric.WithDescription("Input tokens consumed"),
|
||||
)
|
||||
TokensOutput, _ = meter.Int64Counter("proxy.tokens.output",
|
||||
metric.WithDescription("Output tokens consumed"),
|
||||
)
|
||||
CredentialCooldowns, _ = meter.Int64Counter("proxy.credential.cooldowns",
|
||||
metric.WithDescription("Credential cooldown activations"),
|
||||
)
|
||||
ActiveCredentials, _ = meter.Int64UpDownCounter("proxy.credential.active",
|
||||
metric.WithDescription("Currently active (non-cooldown) credentials"),
|
||||
)
|
||||
StreamRequests, _ = meter.Int64Counter("proxy.stream.requests",
|
||||
metric.WithDescription("Streaming request count"),
|
||||
)
|
||||
|
||||
if tracker == nil {
|
||||
return
|
||||
}
|
||||
|
||||
attr5h := attribute.String("window", "5h")
|
||||
attr7d := attribute.String("window", "7d")
|
||||
attrSonnet := attribute.String("window", "7d_sonnet")
|
||||
|
||||
meter.Float64ObservableGauge("proxy.usage.utilization",
|
||||
metric.WithDescription("Current utilization % from API"),
|
||||
metric.WithFloat64Callback(func(_ context.Context, o metric.Float64Observer) error {
|
||||
o.Observe(tracker.FiveHour().Utilization, metric.WithAttributes(attr5h))
|
||||
o.Observe(tracker.SevenDay().Utilization, metric.WithAttributes(attr7d))
|
||||
o.Observe(tracker.Sonnet().Utilization, metric.WithAttributes(attrSonnet))
|
||||
return nil
|
||||
}),
|
||||
)
|
||||
|
||||
meter.Int64ObservableGauge("proxy.usage.resets_at",
|
||||
metric.WithDescription("Unix seconds when window resets"),
|
||||
metric.WithInt64Callback(func(_ context.Context, o metric.Int64Observer) error {
|
||||
o.Observe(tracker.FiveHour().ResetsAt.Unix(), metric.WithAttributes(attr5h))
|
||||
o.Observe(tracker.SevenDay().ResetsAt.Unix(), metric.WithAttributes(attr7d))
|
||||
o.Observe(tracker.Sonnet().ResetsAt.Unix(), metric.WithAttributes(attrSonnet))
|
||||
return nil
|
||||
}),
|
||||
)
|
||||
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/config"
|
||||
"github.com/fujin/anthropic-proxy/internal/ratelimit"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
|
||||
promexporter "go.opentelemetry.io/otel/exporters/prometheus"
|
||||
otellog "go.opentelemetry.io/otel/log/global"
|
||||
"go.opentelemetry.io/otel/sdk/log"
|
||||
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
|
||||
"go.opentelemetry.io/otel/sdk/resource"
|
||||
"go.opentelemetry.io/otel/sdk/trace"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
|
||||
)
|
||||
|
||||
func Setup(ctx context.Context, cfg config.TelemetryConfig, tracker *ratelimit.Tracker) (shutdown func(context.Context) error, logWriter io.Writer, metricsHandler http.Handler, err error) {
|
||||
res, err := resource.New(ctx,
|
||||
resource.WithAttributes(
|
||||
semconv.ServiceName(cfg.ServiceName),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
var readers []sdkmetric.Option
|
||||
readers = append(readers, sdkmetric.WithResource(res))
|
||||
|
||||
var promHandler http.Handler
|
||||
if cfg.Embedded.Enabled {
|
||||
exporter, pErr := promexporter.New()
|
||||
if pErr != nil {
|
||||
return nil, nil, nil, pErr
|
||||
}
|
||||
readers = append(readers, sdkmetric.WithReader(exporter))
|
||||
promHandler = promhttp.Handler()
|
||||
}
|
||||
|
||||
if !cfg.Export.Enabled() {
|
||||
mp := sdkmetric.NewMeterProvider(readers...)
|
||||
otel.SetMeterProvider(mp)
|
||||
InitMetrics(mp.Meter(cfg.ServiceName), tracker)
|
||||
return func(ctx context.Context) error { return mp.Shutdown(ctx) }, nil, promHandler, nil
|
||||
}
|
||||
|
||||
traceOpts := []otlptracegrpc.Option{otlptracegrpc.WithEndpoint(cfg.Export.Endpoint)}
|
||||
metricOpts := []otlpmetricgrpc.Option{
|
||||
otlpmetricgrpc.WithEndpoint(cfg.Export.Endpoint),
|
||||
otlpmetricgrpc.WithTemporalitySelector(sdkmetric.CumulativeTemporalitySelector),
|
||||
}
|
||||
logOpts := []otlploggrpc.Option{otlploggrpc.WithEndpoint(cfg.Export.Endpoint)}
|
||||
if cfg.Export.Insecure {
|
||||
traceOpts = append(traceOpts, otlptracegrpc.WithInsecure())
|
||||
metricOpts = append(metricOpts, otlpmetricgrpc.WithInsecure())
|
||||
logOpts = append(logOpts, otlploggrpc.WithInsecure())
|
||||
}
|
||||
|
||||
traceExp, err := otlptracegrpc.New(ctx, traceOpts...)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
tp := trace.NewTracerProvider(
|
||||
trace.WithBatcher(traceExp),
|
||||
trace.WithResource(res),
|
||||
)
|
||||
otel.SetTracerProvider(tp)
|
||||
|
||||
metricExp, err := otlpmetricgrpc.New(ctx, metricOpts...)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
readers = append(readers, sdkmetric.WithReader(sdkmetric.NewPeriodicReader(metricExp)))
|
||||
mp := sdkmetric.NewMeterProvider(readers...)
|
||||
otel.SetMeterProvider(mp)
|
||||
InitMetrics(mp.Meter(cfg.ServiceName), tracker)
|
||||
|
||||
logExp, err := otlploggrpc.New(ctx, logOpts...)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
lp := log.NewLoggerProvider(
|
||||
log.WithProcessor(log.NewBatchProcessor(logExp)),
|
||||
log.WithResource(res),
|
||||
)
|
||||
otellog.SetLoggerProvider(lp)
|
||||
|
||||
bridge := &LogBridge{provider: lp}
|
||||
|
||||
shutdownFn := func(ctx context.Context) error {
|
||||
var firstErr error
|
||||
if e := tp.Shutdown(ctx); e != nil && firstErr == nil {
|
||||
firstErr = e
|
||||
}
|
||||
if e := mp.Shutdown(ctx); e != nil && firstErr == nil {
|
||||
firstErr = e
|
||||
}
|
||||
if e := lp.Shutdown(ctx); e != nil && firstErr == nil {
|
||||
firstErr = e
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
return shutdownFn, bridge, promHandler, nil
|
||||
}
|
||||
@@ -1,29 +1,47 @@
|
||||
package proxy
|
||||
// Package transport provides a shared uTLS HTTP/2 round-tripper with Chrome
|
||||
// TLS fingerprinting and per-host connection pooling. Used by both the upstream
|
||||
// proxy client and the OAuth token refresh client.
|
||||
package transport
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
tls "github.com/refraction-networking/utls"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
type utlsRoundTripper struct {
|
||||
// UTLS implements http.RoundTripper using uTLS (Chrome fingerprint) over HTTP/2.
|
||||
// It maintains a per-host connection pool with coordination for concurrent
|
||||
// requests to the same host.
|
||||
type UTLS struct {
|
||||
mu sync.Mutex
|
||||
connections map[string]*http2.ClientConn
|
||||
pending map[string]*sync.Cond
|
||||
dialTimeout time.Duration
|
||||
}
|
||||
|
||||
func newUtlsRoundTripper() *utlsRoundTripper {
|
||||
return &utlsRoundTripper{
|
||||
// NewUTLS creates a uTLS HTTP/2 round-tripper with a 10-second dial timeout.
|
||||
func NewUTLS() *UTLS {
|
||||
return &UTLS{
|
||||
connections: make(map[string]*http2.ClientConn),
|
||||
pending: make(map[string]*sync.Cond),
|
||||
dialTimeout: 10 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.ClientConn, error) {
|
||||
// NewHTTPClient returns an http.Client using uTLS transport with the given
|
||||
// request timeout. Pass 0 for no timeout (streaming).
|
||||
func NewHTTPClient(timeout time.Duration) *http.Client {
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: NewUTLS(),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *UTLS) getOrCreateConnection(host, addr string) (*http2.ClientConn, error) {
|
||||
t.mu.Lock()
|
||||
|
||||
if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
|
||||
@@ -59,8 +77,8 @@ func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.Clie
|
||||
return h2Conn, nil
|
||||
}
|
||||
|
||||
func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) {
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
func (t *UTLS) createConnection(host, addr string) (*http2.ClientConn, error) {
|
||||
conn, err := net.DialTimeout("tcp", addr, t.dialTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -83,14 +101,14 @@ func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientCon
|
||||
return h2Conn, nil
|
||||
}
|
||||
|
||||
func (t *utlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// RoundTrip implements http.RoundTripper with uTLS Chrome fingerprinting.
|
||||
func (t *UTLS) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
hostname := req.URL.Hostname()
|
||||
port := req.URL.Port()
|
||||
if port == "" {
|
||||
port = "443"
|
||||
}
|
||||
addr := net.JoinHostPort(hostname, port)
|
||||
log.Printf("utls: RoundTrip to %s (Chrome TLS fingerprint, HTTP/2)", addr)
|
||||
|
||||
h2Conn, err := t.getOrCreateConnection(hostname, addr)
|
||||
if err != nil {
|
||||
@@ -0,0 +1,78 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewUTLS(t *testing.T) {
|
||||
tr := NewUTLS()
|
||||
if tr == nil {
|
||||
t.Fatal("NewUTLS returned nil")
|
||||
}
|
||||
if tr.connections == nil {
|
||||
t.Error("connections map is nil")
|
||||
}
|
||||
if tr.pending == nil {
|
||||
t.Error("pending map is nil")
|
||||
}
|
||||
if tr.dialTimeout != 10*time.Second {
|
||||
t.Errorf("dialTimeout = %v, want 10s", tr.dialTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
timeout time.Duration
|
||||
}{
|
||||
{"zero timeout (streaming)", 0},
|
||||
{"15s timeout", 15 * time.Second},
|
||||
{"30s timeout", 30 * time.Second},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := NewHTTPClient(tt.timeout)
|
||||
if c == nil {
|
||||
t.Fatal("NewHTTPClient returned nil")
|
||||
}
|
||||
if c.Timeout != tt.timeout {
|
||||
t.Errorf("Timeout = %v, want %v", c.Timeout, tt.timeout)
|
||||
}
|
||||
if c.Transport == nil {
|
||||
t.Error("Transport is nil")
|
||||
}
|
||||
if _, ok := c.Transport.(*UTLS); !ok {
|
||||
t.Errorf("Transport type = %T, want *UTLS", c.Transport)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUTLS_ImplementsRoundTripper(t *testing.T) {
|
||||
var _ http.RoundTripper = (*UTLS)(nil)
|
||||
}
|
||||
|
||||
func TestUTLS_RoundTrip_InvalidHost(t *testing.T) {
|
||||
tr := NewUTLS()
|
||||
// Use a non-routable address to test dial timeout behavior
|
||||
req, err := http.NewRequest("GET", "https://192.0.2.1:443/test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
_, err = tr.RoundTrip(req)
|
||||
if err == nil {
|
||||
t.Error("expected error for non-routable address, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUTLS_ConnectionEviction(t *testing.T) {
|
||||
tr := NewUTLS()
|
||||
// Verify connections map starts empty
|
||||
tr.mu.Lock()
|
||||
if len(tr.connections) != 0 {
|
||||
t.Errorf("initial connections = %d, want 0", len(tr.connections))
|
||||
}
|
||||
tr.mu.Unlock()
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
// Package version provides the fallback Claude Code client version used when
|
||||
// no sniffed profile is available. This constant is shared between the upstream
|
||||
// proxy client and the rate limit usage poller.
|
||||
package version
|
||||
|
||||
// ClaudeCodeFallback is the Claude Code CLI version string used as a fallback
|
||||
// when no real version is obtained from sniffing.
|
||||
const ClaudeCodeFallback = "2.1.92"
|
||||
@@ -3,7 +3,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -12,62 +12,159 @@ import (
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/auth"
|
||||
"github.com/fujin/anthropic-proxy/internal/config"
|
||||
"github.com/fujin/anthropic-proxy/internal/embedded"
|
||||
"github.com/fujin/anthropic-proxy/internal/logging"
|
||||
"github.com/fujin/anthropic-proxy/internal/proxy"
|
||||
"github.com/fujin/anthropic-proxy/internal/ratelimit"
|
||||
"github.com/fujin/anthropic-proxy/internal/server"
|
||||
"github.com/fujin/anthropic-proxy/internal/telemetry"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func run() error {
|
||||
log.SetFlags(log.LstdFlags)
|
||||
func initCredential() (*auth.Credential, error) {
|
||||
creds, err := auth.LoadDefaultCredentials()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load credentials: %w", err)
|
||||
}
|
||||
|
||||
var cred *auth.Credential
|
||||
if len(creds) > 0 {
|
||||
cred = creds[0]
|
||||
// If token is expired, try refresh first
|
||||
if !cred.ExpiresAt.IsZero() && time.Now().After(cred.ExpiresAt) {
|
||||
log.Info().Msg("token expired, attempting refresh")
|
||||
refreshCtx, refreshCancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
refreshErr := auth.RefreshToken(refreshCtx, cred)
|
||||
refreshCancel()
|
||||
if refreshErr != nil {
|
||||
log.Warn().Err(refreshErr).Msg("refresh failed, initiating login")
|
||||
cred = nil // fall through to login
|
||||
} else {
|
||||
log.Info().Msg("token refreshed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if cred == nil {
|
||||
fi, statErr := os.Stdin.Stat()
|
||||
if statErr == nil && (fi.Mode()&os.ModeCharDevice) == 0 {
|
||||
return nil, fmt.Errorf("no valid credentials found; run the proxy interactively for initial login")
|
||||
}
|
||||
log.Info().Msg("no credentials found, starting OAuth login")
|
||||
cred, err = auth.Login(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Str("credential", cred.Email).Msg("credential loaded")
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
func initEmbedded(cfg *config.Config) (cleanup func(), err error) {
|
||||
if !cfg.Telemetry.Embedded.Enabled {
|
||||
return func() {}, nil
|
||||
}
|
||||
|
||||
var cleanups []func()
|
||||
|
||||
vm := embedded.NewVM(cfg.Telemetry.Embedded, cfg.Port)
|
||||
if err := vm.Start(); err != nil {
|
||||
log.Error().Err(err).Msg("failed to start victoria-metrics")
|
||||
} else {
|
||||
cleanups = append(cleanups, vm.Stop)
|
||||
}
|
||||
|
||||
perses := embedded.NewPerses(cfg.Telemetry.Embedded, cfg.Port)
|
||||
if err := perses.Start(); err != nil {
|
||||
log.Error().Err(err).Msg("failed to start perses")
|
||||
} else {
|
||||
cleanups = append(cleanups, perses.Stop)
|
||||
}
|
||||
|
||||
return func() {
|
||||
for i := len(cleanups) - 1; i >= 0; i-- {
|
||||
cleanups[i]()
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func run() error {
|
||||
cfg, err := config.Load("config.yaml")
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
creds, err := config.LoadCredentials(cfg)
|
||||
// Create usage tracker (started later once credential is loaded)
|
||||
var credForTracker *auth.Credential
|
||||
tracker := ratelimit.NewTracker(func() string {
|
||||
if credForTracker == nil {
|
||||
return ""
|
||||
}
|
||||
return credForTracker.Token()
|
||||
})
|
||||
|
||||
// Initialize telemetry (metrics always active; OTLP export when endpoint set)
|
||||
telemetryShutdown, logBridge, metricsHandler, err := telemetry.Setup(context.Background(), cfg.Telemetry, tracker)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load credentials: %w", err)
|
||||
return fmt.Errorf("telemetry setup: %w", err)
|
||||
}
|
||||
defer telemetryShutdown(context.Background())
|
||||
|
||||
var extraWriters []io.Writer
|
||||
if logBridge != nil {
|
||||
extraWriters = append(extraWriters, logBridge)
|
||||
}
|
||||
|
||||
if len(creds) == 0 {
|
||||
return fmt.Errorf("no credentials found")
|
||||
logging.Setup(cfg.Logging, extraWriters...)
|
||||
|
||||
cred, err := initCredential()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("loaded %d credentials", len(creds))
|
||||
credForTracker = cred
|
||||
|
||||
pool := auth.NewPool(creds)
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
pool.RefreshExpiring(context.Background())
|
||||
auth.StartBackgroundRefresh(ctx, pool)
|
||||
tracker.Start(ctx)
|
||||
|
||||
var profile *proxy.SniffedProfile
|
||||
if cfg.ClaudeBinary != "" {
|
||||
log.Printf("sniffing claude-code at %s...", cfg.ClaudeBinary)
|
||||
log.Info().Str("binary", cfg.ClaudeBinary).Msg("sniffing claude-code")
|
||||
profile, err = proxy.SniffClaudeCode(cfg.ClaudeBinary)
|
||||
if err != nil {
|
||||
log.Printf("warning: sniff failed, using defaults: %v", err)
|
||||
log.Warn().Err(err).Msg("sniff failed, using defaults")
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("starting server on port %d", cfg.Port)
|
||||
srv := server.New(cfg, pool, profile)
|
||||
embeddedCleanup, err := initEmbedded(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer embeddedCleanup()
|
||||
|
||||
log.Info().Int("port", cfg.Port).Msg("starting server")
|
||||
srv := server.New(cfg, pool, profile, tracker, metricsHandler)
|
||||
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
go func() {
|
||||
<-quit
|
||||
log.Printf("shutting down...")
|
||||
log.Info().Msg("shutting down")
|
||||
cancel()
|
||||
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer shutdownCancel()
|
||||
|
||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
log.Printf("shutdown error: %v", err)
|
||||
log.Error().Err(err).Msg("shutdown error")
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -80,7 +177,7 @@ func run() error {
|
||||
|
||||
func main() {
|
||||
if err := run(); err != nil {
|
||||
log.Printf("error: %v", err)
|
||||
log.Error().Err(err).Msg("fatal error")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
+20
@@ -0,0 +1,20 @@
|
||||
{
|
||||
buildGoModule,
|
||||
lib,
|
||||
pkgs,
|
||||
...
|
||||
}:
|
||||
|
||||
buildGoModule rec {
|
||||
pname = "anthropic-proxy";
|
||||
version = "0.0.5";
|
||||
|
||||
src = ./.;
|
||||
|
||||
vendorHash = "sha256-yXINNC+NEw+HbOQ5aBgSE5dYTWp+zEZ230rzXfwOoDY=";
|
||||
|
||||
meta = with lib; {
|
||||
description = "Reverse proxy that lets OpenCode (and similar tools) use a Claude subscription instead of an API key.";
|
||||
homepage = "https://gitea.susano-homelab.duckdns.org/fujin/anthropic-proxy";
|
||||
};
|
||||
}
|
||||
Reference in New Issue
Block a user