Compare commits
2 Commits
master
..
b5ee53b225
| Author | SHA1 | Date | |
|---|---|---|---|
| b5ee53b225 | |||
| 9dc664a3ba |
@@ -4,5 +4,3 @@
|
||||
anthropic-proxy
|
||||
result
|
||||
config.yaml
|
||||
|
||||
vendor/**
|
||||
|
||||
@@ -1,62 +1,63 @@
|
||||
# anthropic-proxy
|
||||
|
||||
Reverse proxy for the Anthropic Messages API that authenticates with a Claude subscription (OAuth) instead of an API key. Lets you use tools like [OpenCode](https://github.com/opencode-ai/opencode) through your existing Claude Pro/Team plan.
|
||||
Reverse proxy that lets OpenCode (and similar tools) use a Claude subscription instead of an API key.
|
||||
|
||||
## How it works
|
||||
## Prerequisites
|
||||
|
||||
Clients send standard Anthropic API requests to the proxy. The proxy authenticates upstream using OAuth credentials from your Claude subscription, forwards the request, and streams the response back. Requests are optionally sanitized (tool name remapping, string replacement) before forwarding and de-sanitized on return.
|
||||
- Go 1.26+
|
||||
|
||||
## Features
|
||||
Optional: [Nix](https://nixos.org/) flake for dev shell (`nix develop`).
|
||||
|
||||
- **OAuth credential management** — reuses `~/.claude/.credentials.json`, auto-refreshes tokens
|
||||
- **Request sanitization** — rename tools, replace strings in system prompts and body (configurable, hot-reloadable)
|
||||
- **Rate limit tracking** — polls Anthropic usage API and reads response headers to track 5h/7d utilization windows
|
||||
- **OpenTelemetry metrics** — request counts, latency, token usage, errors (optional OTLP export)
|
||||
- **Structured logging** — zerolog with file rotation via lumberjack
|
||||
|
||||
## Quick start
|
||||
## Setup
|
||||
|
||||
```
|
||||
cp config.example.yaml config.yaml
|
||||
# edit config.yaml — set api_keys and optionally claude_binary
|
||||
```
|
||||
|
||||
Edit `config.yaml`:
|
||||
- `api_keys` — key(s) your clients use to authenticate with the proxy
|
||||
- `claude_binary` — optional path to `claude` binary (used for request fingerprinting via sniff only)
|
||||
|
||||
## Authentication
|
||||
|
||||
On first run, if no credentials are found at `~/.claude/.credentials.json`, the proxy starts an OAuth login flow in your browser. Credentials are stored at `~/.claude/.credentials.json` (the same file Claude Code CLI uses). On subsequent runs, existing credentials are reused and refreshed automatically.
|
||||
|
||||
If running headlessly (SSH/server), the authorization URL is printed to stdout and you can paste the authorization code manually.
|
||||
|
||||
If you've already logged in with Claude Code CLI, the proxy will use the same credentials.
|
||||
|
||||
## Build and run
|
||||
|
||||
```
|
||||
go build -o anthropic-proxy .
|
||||
./anthropic-proxy
|
||||
```
|
||||
|
||||
On first run, if no credentials exist at `~/.claude/.credentials.json`, an OAuth login flow starts in your browser. If running headlessly, the authorization URL is printed to stdout. If you've already logged in with Claude Code CLI, the proxy reuses those credentials.
|
||||
|
||||
### Nix
|
||||
|
||||
```
|
||||
nix develop # dev shell with Go
|
||||
nix build # build the binary
|
||||
```
|
||||
|
||||
## Client configuration
|
||||
|
||||
Point any Anthropic-compatible client at the proxy:
|
||||
## Usage with OpenCode
|
||||
|
||||
```
|
||||
export ANTHROPIC_API_KEY=your-proxy-api-key
|
||||
export ANTHROPIC_BASE_URL=http://localhost:8082
|
||||
opencode
|
||||
```
|
||||
|
||||
## Endpoints
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| POST | `/v1/messages` | Anthropic Messages API (proxied) |
|
||||
| POST | `/v1/messages` | Anthropic messages API (proxied) |
|
||||
| POST | `/messages` | Same, without `/v1` prefix |
|
||||
| GET | `/healthz` | Health check |
|
||||
| POST | `/reload` | Hot-reload config (sanitize rules + API keys) |
|
||||
| POST | `/reload` | Hot-reload `config.yaml` |
|
||||
|
||||
## Configuration
|
||||
## Request sanitization
|
||||
|
||||
See [`config.example.yaml`](config.example.yaml) for all options. Key sections:
|
||||
The `sanitize` section in config renames tool names and replaces strings in system prompts before forwarding to Anthropic. Responses are de-sanitized before returning to the client.
|
||||
|
||||
- **`api_keys`** — keys clients use to authenticate with the proxy
|
||||
- **`sanitize`** — tool renames, system prompt replacements, body replacements
|
||||
- **`telemetry`** — OTLP endpoint, service name, auth headers
|
||||
- **`logging`** — level, file path, rotation settings
|
||||
- **`claude_binary`** — path to `claude` CLI for request fingerprinting (optional)
|
||||
See `config.example.yaml` for the default rules.
|
||||
|
||||
Reload after editing config:
|
||||
|
||||
```
|
||||
curl -X POST localhost:8082/reload
|
||||
```
|
||||
|
||||
+17
-24
@@ -1,20 +1,5 @@
|
||||
port: 8082
|
||||
|
||||
# telemetry:
|
||||
# service_name: "anthropic-proxy"
|
||||
# export:
|
||||
# endpoint: "localhost:4317" # OTLP gRPC endpoint (omit to disable export)
|
||||
# insecure: true # disable TLS for local dev
|
||||
# headers: # optional auth headers (e.g. Grafana Cloud)
|
||||
# Authorization: "Basic ..."
|
||||
# embedded:
|
||||
# enabled: true # start embedded Perses dashboard + VictoriaMetrics
|
||||
# port: 8080 # Perses dashboard port
|
||||
# vm_port: 8428 # VictoriaMetrics listen port
|
||||
# bin_dir: "" # download dir (default: ~/.cache/anthropic-proxy/bin)
|
||||
# perses_binary: "" # custom path to perses binary (default: auto-download)
|
||||
# vm_binary: "" # custom path to victoria-metrics binary (default: auto-download)
|
||||
|
||||
logging:
|
||||
level: debug
|
||||
file: /home/fujin/.local/log/anthropic-proxy.log
|
||||
@@ -51,21 +36,29 @@ sanitize:
|
||||
- match: "Workspace root folder"
|
||||
replace: "Working directory"
|
||||
body:
|
||||
- match: "anthropics/claude-code"
|
||||
- match: "anomalyco/opencode"
|
||||
replace: "anthropics/claude-code"
|
||||
- match: "anthropic"
|
||||
- match: "anomalyco"
|
||||
replace: "anthropic"
|
||||
- match: "system-directive"
|
||||
- match: "oh-my-opencode"
|
||||
replace: "system-directive"
|
||||
- match: "claude-code"
|
||||
- match: "ohmyopencode"
|
||||
replace: "claude-code"
|
||||
- match: "claude-agent"
|
||||
- match: "oh-my-openagent"
|
||||
replace: "claude-agent"
|
||||
- match: "system_initiator"
|
||||
- match: "omo_internal_initiator"
|
||||
replace: "system_initiator"
|
||||
- match: "call_agent"
|
||||
- match: "call_omo_agent"
|
||||
replace: "call_agent"
|
||||
- match: "claude.ai"
|
||||
- match: "opencode.ai"
|
||||
replace: "claude.ai"
|
||||
- match: "agent"
|
||||
- match: "opencode"
|
||||
replace: "agent"
|
||||
|
||||
logging:
|
||||
level: info
|
||||
# file: /var/log/anthropic-proxy.log # omit to log to stderr
|
||||
# max_size_mb: 100
|
||||
# max_backups: 5
|
||||
# max_age_days: 30
|
||||
# compress: true
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
# Metrics Reference
|
||||
|
||||
All metrics are emitted via OpenTelemetry (OTLP gRPC) with cumulative temporality. OTel metric names use dots; Prometheus converts them to underscores and appends `_total` for counters.
|
||||
|
||||
## Counters
|
||||
|
||||
| OTel Name | Prometheus Name | Attributes | Description |
|
||||
|---|---|---|---|
|
||||
| `proxy.request.count` | `proxy_request_count_total` | `model`, `stream`, `status_code` | Total proxied requests |
|
||||
| `proxy.tokens.input` | `proxy_tokens_input_total` | `model`, `credential` | Input tokens consumed |
|
||||
| `proxy.tokens.output` | `proxy_tokens_output_total` | `model`, `credential` | Output tokens consumed |
|
||||
| `proxy.upstream.errors` | `proxy_upstream_errors_total` | `error_type`, `credential`, `status_code` | Upstream errors (connection failures, 4xx/5xx) |
|
||||
| `proxy.credential.cooldowns` | `proxy_credential_cooldowns_total` | `status_code` | Credential cooldown activations (rate limited) |
|
||||
| `proxy.stream.requests` | `proxy_stream_requests_total` | `model` | Streaming request count |
|
||||
|
||||
## Histograms
|
||||
|
||||
| OTel Name | Prometheus Name | Unit | Attributes | Description |
|
||||
|---|---|---|---|---|
|
||||
| `proxy.request.duration_ms` | `proxy_request_duration_ms_milliseconds` | ms | `model`, `stream`, `status_code` | Request latency |
|
||||
| `proxy.request.body_size_bytes` | `proxy_request_body_size_bytes` | bytes | `model`, `stream` | Request body size |
|
||||
|
||||
## Gauges
|
||||
|
||||
| OTel Name | Prometheus Name | Attributes | Description |
|
||||
|---|---|---|---|
|
||||
| `proxy.usage.utilization` | `proxy_usage_utilization` | `window` | Current utilization % from Anthropic API (0-100) |
|
||||
| `proxy.usage.resets_at` | `proxy_usage_resets_at` | `window` | Unix timestamp when the rate limit window resets |
|
||||
| `proxy.credential.active` | `proxy_credential_active` | — | Currently active (non-cooldown) credentials |
|
||||
|
||||
### Window attribute values
|
||||
|
||||
- `5h` — 5-hour rolling window
|
||||
- `7d` — 7-day rolling window
|
||||
- `7d_sonnet` — 7-day Sonnet-specific window
|
||||
|
||||
## Structured Logs (Loki)
|
||||
|
||||
Each completed request emits a structured log line via OTel LogBridge. Fields are stored as Loki stream labels/structured metadata.
|
||||
|
||||
| Field | Type | Description |
|
||||
|---|---|---|
|
||||
| `input_tokens` | int | Input tokens for this request |
|
||||
| `output_tokens` | int | Output tokens for this request |
|
||||
| `model` | string | Model used |
|
||||
| `latency_ms` | float | Request latency in milliseconds |
|
||||
| `status` | int | HTTP status code (non-stream only) |
|
||||
| `stream` | bool | Whether this was a streaming request |
|
||||
|
||||
Log messages: `"request completed"` (non-stream), `"stream completed"` (stream).
|
||||
|
||||
### Per-window token counting via Loki
|
||||
|
||||
Window token totals are computed in Grafana using LogQL, not tracked in-memory. This approach survives process restarts and uses exact Anthropic window boundaries.
|
||||
|
||||
Grafana variables derive the window age from Prometheus:
|
||||
|
||||
```
|
||||
window_age_5h = time() - proxy_usage_resets_at{window="5h"} + 18000
|
||||
window_age_7d = time() - proxy_usage_resets_at{window="7d"} + 604800
|
||||
```
|
||||
|
||||
LogQL queries sum token values from individual log events within the window:
|
||||
|
||||
```logql
|
||||
sum(sum_over_time(
|
||||
{service_name="anthropic-proxy"} |= "completed"
|
||||
| unwrap output_tokens
|
||||
| __error__=""
|
||||
[${window_age_5h}s]
|
||||
))
|
||||
```
|
||||
|
||||
## Annotations (Loki)
|
||||
|
||||
429 rate limit events are surfaced as Grafana annotations:
|
||||
|
||||
```logql
|
||||
{service_name="anthropic-proxy"} |= "upstream error" | json | status = "429"
|
||||
```
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,450 +0,0 @@
|
||||
{
|
||||
"kind": "Dashboard",
|
||||
"metadata": {
|
||||
"name": "proxy",
|
||||
"createdAt": "2026-04-14T19:47:48.013238204Z",
|
||||
"updatedAt": "2026-04-14T19:49:30.874125459Z",
|
||||
"version": 1,
|
||||
"project": "anthropic-proxy"
|
||||
},
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "Anthropic Proxy"
|
||||
},
|
||||
"datasources": {
|
||||
"vm": {
|
||||
"default": true,
|
||||
"plugin": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"spec": {
|
||||
"directUrl": "http://localhost:9428"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"panels": {
|
||||
"latency": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "Latency"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "TimeSeriesChart",
|
||||
"spec": {
|
||||
"legend": {
|
||||
"position": "bottom"
|
||||
},
|
||||
"yAxis": {
|
||||
"format": {
|
||||
"unit": "milliseconds"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "histogram_quantile(0.50, rate(proxy_request_duration_ms_milliseconds_bucket[5m]))",
|
||||
"seriesNameFormat": "p50"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "histogram_quantile(0.95, rate(proxy_request_duration_ms_milliseconds_bucket[5m]))",
|
||||
"seriesNameFormat": "p95"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "histogram_quantile(0.99, rate(proxy_request_duration_ms_milliseconds_bucket[5m]))",
|
||||
"seriesNameFormat": "p99"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"request_rate": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "Request Rate"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "TimeSeriesChart",
|
||||
"spec": {
|
||||
"legend": {
|
||||
"position": "bottom"
|
||||
}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "rate(proxy_request_count_total[5m])",
|
||||
"seriesNameFormat": "req/s"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"token_rate": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "Token Rate"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "TimeSeriesChart",
|
||||
"spec": {
|
||||
"legend": {
|
||||
"position": "bottom"
|
||||
}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "rate(proxy_tokens_input_total[5m]) * 60",
|
||||
"seriesNameFormat": "input/min"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "rate(proxy_tokens_output_total[5m]) * 60",
|
||||
"seriesNameFormat": "output/min"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"tokens_5h": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "5h Tokens"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "StatChart",
|
||||
"spec": {
|
||||
"calculation": "last",
|
||||
"format": {
|
||||
"unit": "decimal"
|
||||
},
|
||||
"sparkline": {}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "increase(proxy_tokens_output_total[3h])"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"tokens_7d": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "7d Tokens"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "StatChart",
|
||||
"spec": {
|
||||
"calculation": "last",
|
||||
"format": {
|
||||
"unit": "decimal"
|
||||
},
|
||||
"sparkline": {}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "increase(proxy_tokens_output_total[9h])"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"util_5h": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "5h Utilization"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "GaugeChart",
|
||||
"spec": {
|
||||
"calculation": "last",
|
||||
"format": {
|
||||
"unit": "percent"
|
||||
},
|
||||
"thresholds": {
|
||||
"steps": [
|
||||
{
|
||||
"color": "green",
|
||||
"value": 0
|
||||
},
|
||||
{
|
||||
"color": "orange",
|
||||
"value": 70
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 90
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "proxy_usage_utilization{window=\"5h\"}"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"util_7d": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "7d Utilization"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "GaugeChart",
|
||||
"spec": {
|
||||
"calculation": "last",
|
||||
"format": {
|
||||
"unit": "percent"
|
||||
},
|
||||
"thresholds": {
|
||||
"steps": [
|
||||
{
|
||||
"color": "green",
|
||||
"value": 0
|
||||
},
|
||||
{
|
||||
"color": "orange",
|
||||
"value": 70
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 90
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "proxy_usage_utilization{window=\"7d\"}"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"layouts": [
|
||||
{
|
||||
"kind": "Grid",
|
||||
"spec": {
|
||||
"display": {
|
||||
"title": "Utilization"
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"x": 0,
|
||||
"y": 0,
|
||||
"width": 6,
|
||||
"height": 5,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/util_5h"
|
||||
}
|
||||
},
|
||||
{
|
||||
"x": 6,
|
||||
"y": 0,
|
||||
"width": 6,
|
||||
"height": 5,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/util_7d"
|
||||
}
|
||||
},
|
||||
{
|
||||
"x": 12,
|
||||
"y": 0,
|
||||
"width": 6,
|
||||
"height": 5,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/tokens_5h"
|
||||
}
|
||||
},
|
||||
{
|
||||
"x": 18,
|
||||
"y": 0,
|
||||
"width": 6,
|
||||
"height": 5,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/tokens_7d"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"kind": "Grid",
|
||||
"spec": {
|
||||
"display": {
|
||||
"title": "Traffic"
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"x": 0,
|
||||
"y": 0,
|
||||
"width": 12,
|
||||
"height": 8,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/request_rate"
|
||||
}
|
||||
},
|
||||
{
|
||||
"x": 12,
|
||||
"y": 0,
|
||||
"width": 12,
|
||||
"height": 8,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/latency"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"kind": "Grid",
|
||||
"spec": {
|
||||
"display": {
|
||||
"title": "Tokens"
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"x": 0,
|
||||
"y": 0,
|
||||
"width": 24,
|
||||
"height": 8,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/token_rate"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"duration": "1h",
|
||||
"refreshInterval": "10s"
|
||||
}
|
||||
}
|
||||
Generated
+3
-3
@@ -20,11 +20,11 @@
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1776169885,
|
||||
"narHash": "sha256-l/iNYDZ4bGOAFQY2q8y5OAfBBtrDAaPuRQqWaFHVRXM=",
|
||||
"lastModified": 1775710090,
|
||||
"narHash": "sha256-ar3rofg+awPB8QXDaFJhJ2jJhu+KqN/PRCXeyuXR76E=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "4bd9165a9165d7b5e33ae57f3eecbcb28fb231c9",
|
||||
"rev": "4c1018dae018162ec878d42fec712642d214fdfa",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
||||
@@ -42,9 +42,6 @@
|
||||
shellHook = ''
|
||||
export GOPATH="$PWD/.go"
|
||||
export PATH="$GOPATH/bin:$PATH"
|
||||
|
||||
export ANTHROPIC_BASE_URL=http://localhost:8082
|
||||
export ANTHROPIC_API_KEY=sk-cliproxy-fujin
|
||||
'';
|
||||
};
|
||||
}
|
||||
|
||||
@@ -5,81 +5,47 @@ go 1.26
|
||||
require (
|
||||
github.com/gin-gonic/gin v1.12.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/refraction-networking/utls v1.8.2
|
||||
github.com/rs/zerolog v1.35.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.68.0
|
||||
go.opentelemetry.io/otel v1.43.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.19.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.65.0
|
||||
go.opentelemetry.io/otel/log v0.19.0
|
||||
go.opentelemetry.io/otel/metric v1.43.0
|
||||
go.opentelemetry.io/otel/sdk v1.43.0
|
||||
go.opentelemetry.io/otel/sdk/log v0.19.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0
|
||||
golang.org/x/net v0.52.0
|
||||
gopkg.in/lumberjack.v2 v2.0.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/BurntSushi/toml v1.6.0 // indirect
|
||||
github.com/andybalholm/brotli v1.0.6 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/bytedance/gopkg v0.1.4 // indirect
|
||||
github.com/bytedance/gopkg v0.1.3 // indirect
|
||||
github.com/bytedance/sonic v1.15.0 // indirect
|
||||
github.com/bytedance/sonic/loader v0.5.1 // indirect
|
||||
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.13 // indirect
|
||||
github.com/gin-contrib/sse v1.1.1 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
|
||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.30.2 // indirect
|
||||
github.com/goccy/go-json v0.10.6 // indirect
|
||||
github.com/go-playground/validator/v10 v10.30.1 // indirect
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/klauspost/compress v1.17.6 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 // indirect
|
||||
github.com/prometheus/client_model v0.6.2 // indirect
|
||||
github.com/prometheus/common v0.67.5 // indirect
|
||||
github.com/prometheus/otlptranslator v1.0.0 // indirect
|
||||
github.com/prometheus/procfs v0.20.1 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/quic-go/qpack v0.6.0 // indirect
|
||||
github.com/quic-go/quic-go v0.59.0 // indirect
|
||||
github.com/refraction-networking/utls v1.8.2 // indirect
|
||||
github.com/rs/zerolog v1.35.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.3.1 // indirect
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.43.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.10.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.4 // indirect
|
||||
golang.org/x/arch v0.25.0 // indirect
|
||||
golang.org/x/arch v0.22.0 // indirect
|
||||
golang.org/x/crypto v0.49.0 // indirect
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/text v0.35.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
|
||||
google.golang.org/grpc v1.80.0 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
gopkg.in/lumberjack.v2 v2.0.0 // indirect
|
||||
)
|
||||
|
||||
@@ -1,68 +1,49 @@
|
||||
github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk=
|
||||
github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||
github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI=
|
||||
github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/bytedance/gopkg v0.1.4 h1:oZnQwnX82KAIWb7033bEwtxvTqXcYMxDBaQxo5JJHWM=
|
||||
github.com/bytedance/gopkg v0.1.4/go.mod h1:v1zWfPm21Fb+OsyXN2VAHdL6TBb2L88anLQgdyje6R4=
|
||||
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
|
||||
github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
|
||||
github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE=
|
||||
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
|
||||
github.com/bytedance/sonic/loader v0.5.1 h1:Ygpfa9zwRCCKSlrp5bBP/b/Xzc3VxsAW+5NIYXrOOpI=
|
||||
github.com/bytedance/sonic/loader v0.5.1/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
|
||||
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
|
||||
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
|
||||
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
|
||||
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM=
|
||||
github.com/gabriel-vasile/mimetype v1.4.13/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
||||
github.com/gin-contrib/sse v1.1.1 h1:uGYpNwTacv5R68bSGMapo62iLTRa9l5zxGCps4hK6ko=
|
||||
github.com/gin-contrib/sse v1.1.1/go.mod h1:QXzuVkA0YO7o/gun03UI1Q+FTI8ZV/n5t03kIQAI89s=
|
||||
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
|
||||
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
||||
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
|
||||
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
|
||||
github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8=
|
||||
github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.30.2 h1:JiFIMtSSHb2/XBUbWM4i/MpeQm9ZK2xqPNk8vgvu5JQ=
|
||||
github.com/go-playground/validator/v10 v10.30.2/go.mod h1:mAf2pIOVXjTEBrwUMGKkCWKKPs9NheYGabeB04txQSc=
|
||||
github.com/goccy/go-json v0.10.6 h1:p8HrPJzOakx/mn/bQtjgNjdTcN+/S6FcG2CTtQOrHVU=
|
||||
github.com/goccy/go-json v0.10.6/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w=
|
||||
github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM=
|
||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
|
||||
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI=
|
||||
github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
@@ -74,30 +55,18 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
|
||||
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||
github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4=
|
||||
github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw=
|
||||
github.com/prometheus/otlptranslator v1.0.0 h1:s0LJW/iN9dkIH+EnhiD3BlkkP5QVIUVEoIwkU+A6qos=
|
||||
github.com/prometheus/otlptranslator v1.0.0/go.mod h1:vRYWnXvI6aWGpsdY/mOT/cbeVRBlPWtBNDb7kGR3uKM=
|
||||
github.com/prometheus/procfs v0.20.1 h1:XwbrGOIplXW/AU3YhIhLODXMJYyC1isLFfYCsTEycfc=
|
||||
github.com/prometheus/procfs v0.20.1/go.mod h1:o9EMBZGRyvDrSPH1RqdxhojkuXstoe4UlK79eF5TGGo=
|
||||
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
|
||||
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
|
||||
github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw=
|
||||
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
|
||||
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
|
||||
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/rs/zerolog v1.35.0 h1:VD0ykx7HMiMJytqINBsKcbLS+BJ4WYjz+05us+LRTdI=
|
||||
github.com/rs/zerolog v1.35.0/go.mod h1:EjML9kdfa/RMA7h/6z6pYmq1ykOuA8/mjWaEvGI+jcw=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
@@ -126,78 +95,34 @@ github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY
|
||||
github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.68.0 h1:5FXSL2s6afUC1bzNzl1iedZZ8yqR7GOhbCoEXtyeK6Q=
|
||||
go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.68.0/go.mod h1:MdHW7tLtkeGJnR4TyOrnd5D0zUGZQB1l84uHCe8hRpE=
|
||||
go.opentelemetry.io/contrib/propagators/b3 v1.43.0 h1:CETqV3QLLPTy5yNrqyMr41VnAOOD4lsRved7n4QG00A=
|
||||
go.opentelemetry.io/contrib/propagators/b3 v1.43.0/go.mod h1:Q4mCiCdziYzpNR0g+6UqVotAlCDZdzz6L8jwY4knOrw=
|
||||
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
|
||||
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.19.0 h1:Dn8rkudDzY6KV9dr/D/bTUuWgqDf9xe0rr4G2elrn0Y=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.19.0/go.mod h1:gMk9F0xDgyN9M/3Ed5Y1wKcx/9mlU91NXY2SNq7RQuU=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0 h1:8UQVDcZxOJLtX6gxtDt3vY2WTgvZqMQRzjsqiIHQdkc=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0/go.mod h1:2lmweYCiHYpEjQ/lSJBYhj9jP1zvCvQW4BqL9dnT7FQ=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 h1:88Y4s2C8oTui1LGM6bTWkw0ICGcOLCAI5l6zsD1j20k=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0/go.mod h1:Vl1/iaggsuRlrHf/hfPJPvVag77kKyvrLeD10kpMl+A=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0 h1:RAE+JPfvEmvy+0LzyUA25/SGawPwIUbZ6u0Wug54sLc=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0/go.mod h1:AGmbycVGEsRx9mXMZ75CsOyhSP6MFIcj/6dnG+vhVjk=
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.65.0 h1:jOveH/b4lU9HT7y+Gfamf18BqlOuz2PWEvs8yM7Q6XE=
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.65.0/go.mod h1:i1P8pcumauPtUI4YNopea1dhzEMuEqWP1xoUZDylLHo=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.43.0 h1:mS47AX77OtFfKG4vtp+84kuGSFZHTyxtXIN269vChY0=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.43.0/go.mod h1:PJnsC41lAGncJlPUniSwM81gc80GkgWJWr3cu2nKEtU=
|
||||
go.opentelemetry.io/otel/log v0.19.0 h1:KUZs/GOsw79TBBMfDWsXS+KZ4g2Ckzksd1ymzsIEbo4=
|
||||
go.opentelemetry.io/otel/log v0.19.0/go.mod h1:5DQYeGmxVIr4n0/BcJvF4upsraHjg6vudJJpnkL6Ipk=
|
||||
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM=
|
||||
go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg=
|
||||
go.opentelemetry.io/otel/sdk/log v0.19.0 h1:scYVLqT22D2gqXItnWiocLUKGH9yvkkeql5dBDiXyko=
|
||||
go.opentelemetry.io/otel/sdk/log v0.19.0/go.mod h1:vFBowwXGLlW9AvpuF7bMgnNI95LiW10szrOdvzBHlAg=
|
||||
go.opentelemetry.io/otel/sdk/log/logtest v0.19.0 h1:BEbF7ZBB6qQloV/Ub1+3NQoOUnVtcGkU3XX4Ws3GQfk=
|
||||
go.opentelemetry.io/otel/sdk/log/logtest v0.19.0/go.mod h1:Lua81/3yM0wOmoHTokLj9y9ADeA02v1naRrVrkAZuKk=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||
go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g=
|
||||
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||
go.yaml.in/yaml/v2 v2.4.4 h1:tuyd0P+2Ont/d6e2rl3be67goVK4R6deVxCUX5vyPaQ=
|
||||
go.yaml.in/yaml/v2 v2.4.4/go.mod h1:gMZqIpDtDqOfM0uNfy0SkpRhvUryYH0Z6wdMYcacYXQ=
|
||||
golang.org/x/arch v0.25.0 h1:qnk6Ksugpi5Bz32947rkUgDt9/s5qvqDPl/gBKdMJLE=
|
||||
golang.org/x/arch v0.25.0/go.mod h1:0X+GdSIP+kL5wPmpK7sdkEVTt2XoYP0cSjQSbZBwOi8=
|
||||
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
||||
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
||||
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
|
||||
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
|
||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
||||
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/lumberjack.v2 v2.0.0 h1:IDj6hi8KbNiPQ5VaYNFZ7dBJLF5LFeKvsFrWHjA5aq4=
|
||||
gopkg.in/lumberjack.v2 v2.0.0/go.mod h1:bp5nQ2kK/lLQSmTk29azj9+JB6bWci56xFn/lvd5GLI=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// claudeCredentialsJSON matches the structure of ~/.claude/.credentials.json.
|
||||
type claudeCredentialsJSON struct {
|
||||
ClaudeAiOauth struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ExpiresAt int64 `json:"expiresAt"`
|
||||
SubscriptionType string `json:"subscriptionType"`
|
||||
} `json:"claudeAiOauth"`
|
||||
}
|
||||
|
||||
// LoadDefaultCredentials reads credentials from ~/.claude/.credentials.json.
|
||||
// Returns nil, nil if the file does not exist.
|
||||
func LoadDefaultCredentials() ([]*Credential, error) {
|
||||
path, err := DefaultCredentialPath()
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cf claudeCredentialsJSON
|
||||
if err := json.Unmarshal(data, &cf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oauth := cf.ClaudeAiOauth
|
||||
if oauth.AccessToken == "" {
|
||||
return nil, fmt.Errorf("no access token in %s", path)
|
||||
}
|
||||
|
||||
cred := &Credential{
|
||||
ID: "claude-native",
|
||||
Email: oauth.SubscriptionType,
|
||||
AccessToken: oauth.AccessToken,
|
||||
RefreshToken: oauth.RefreshToken,
|
||||
ExpiresAt: time.UnixMilli(oauth.ExpiresAt),
|
||||
FilePath: path,
|
||||
}
|
||||
|
||||
return []*Credential{cred}, nil
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefaultCredentialPath(t *testing.T) {
|
||||
path, err := DefaultCredentialPath()
|
||||
if err != nil {
|
||||
t.Fatalf("DefaultCredentialPath error: %v", err)
|
||||
}
|
||||
if !strings.HasSuffix(path, filepath.Join(".claude", ".credentials.json")) {
|
||||
t.Errorf("path = %q, want suffix .claude/.credentials.json", path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDefaultCredentials_MissingFile(t *testing.T) {
|
||||
// When credential file doesn't exist, returns nil, nil
|
||||
path, err := DefaultCredentialPath()
|
||||
if err != nil {
|
||||
t.Skip("cannot determine home dir")
|
||||
}
|
||||
if _, statErr := os.Stat(path); os.IsNotExist(statErr) {
|
||||
creds, err := LoadDefaultCredentials()
|
||||
if creds != nil {
|
||||
t.Errorf("expected nil creds for missing file, got %v", creds)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("expected nil error for missing file, got %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeCredentialsJSON_ParsesCorrectly(t *testing.T) {
|
||||
jsonData := `{"claudeAiOauth":{"accessToken":"test-token","refreshToken":"test-refresh","expiresAt":1234567890,"subscriptionType":"pro"}}`
|
||||
|
||||
var cf claudeCredentialsJSON
|
||||
if err := json.Unmarshal([]byte(jsonData), &cf); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if cf.ClaudeAiOauth.AccessToken != "test-token" {
|
||||
t.Errorf("AccessToken = %q, want test-token", cf.ClaudeAiOauth.AccessToken)
|
||||
}
|
||||
if cf.ClaudeAiOauth.RefreshToken != "test-refresh" {
|
||||
t.Errorf("RefreshToken = %q, want test-refresh", cf.ClaudeAiOauth.RefreshToken)
|
||||
}
|
||||
if cf.ClaudeAiOauth.ExpiresAt != 1234567890 {
|
||||
t.Errorf("ExpiresAt = %d, want 1234567890", cf.ClaudeAiOauth.ExpiresAt)
|
||||
}
|
||||
if cf.ClaudeAiOauth.SubscriptionType != "pro" {
|
||||
t.Errorf("SubscriptionType = %q, want pro", cf.ClaudeAiOauth.SubscriptionType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeCredentialsJSON_EmptyAccessToken(t *testing.T) {
|
||||
jsonData := `{"claudeAiOauth":{"accessToken":"","refreshToken":"r","expiresAt":1}}`
|
||||
|
||||
var cf claudeCredentialsJSON
|
||||
if err := json.Unmarshal([]byte(jsonData), &cf); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if cf.ClaudeAiOauth.AccessToken != "" {
|
||||
t.Errorf("expected empty access token")
|
||||
}
|
||||
}
|
||||
+66
-16
@@ -6,14 +6,16 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
tls "github.com/refraction-networking/utls"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/transport"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -26,7 +28,7 @@ const (
|
||||
refreshBackoff = 5 * time.Minute
|
||||
)
|
||||
|
||||
var utlsClient = transport.NewHTTPClient(15 * time.Second)
|
||||
var utlsClient = newUTLSClient()
|
||||
|
||||
type tokenRequest struct {
|
||||
ClientID string `json:"client_id"`
|
||||
@@ -62,13 +64,6 @@ func RefreshToken(ctx context.Context, cred *Credential) error {
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
log.Debug().
|
||||
Str("url", tokenEndpoint).
|
||||
Str("grant_type", "refresh_token").
|
||||
Str("client_id", clientID).
|
||||
Str("scope", oauthScopes).
|
||||
Msg("token refresh request")
|
||||
|
||||
resp, err := utlsClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("execute request: %w", err)
|
||||
@@ -76,12 +71,6 @@ func RefreshToken(ctx context.Context, cred *Credential) error {
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
log.Debug().
|
||||
Int("status", resp.StatusCode).
|
||||
Int("response_size", len(body)).
|
||||
Msg("token refresh response")
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("refresh returned %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
@@ -145,6 +134,67 @@ func persistCredential(cred *Credential) error {
|
||||
return os.WriteFile(filePath, out, 0600)
|
||||
}
|
||||
|
||||
func newUTLSClient() *http.Client {
|
||||
return &http.Client{
|
||||
Timeout: 15 * time.Second,
|
||||
Transport: &utlsRefreshTransport{},
|
||||
}
|
||||
}
|
||||
|
||||
type utlsRefreshTransport struct {
|
||||
mu sync.Mutex
|
||||
conn *http2.ClientConn
|
||||
host string
|
||||
}
|
||||
|
||||
func (t *utlsRefreshTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
host := req.URL.Hostname()
|
||||
port := req.URL.Port()
|
||||
if port == "" {
|
||||
port = "443"
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
if t.conn != nil && t.host == host && t.conn.CanTakeNewRequest() {
|
||||
conn := t.conn
|
||||
t.mu.Unlock()
|
||||
resp, err := conn.RoundTrip(req)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
t.mu.Lock()
|
||||
t.conn = nil
|
||||
t.mu.Unlock()
|
||||
} else {
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
addr := net.JoinHostPort(host, port)
|
||||
rawConn, err := net.DialTimeout("tcp", addr, 10*time.Second)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConn := tls.UClient(rawConn, &tls.Config{ServerName: host}, tls.HelloChrome_Auto)
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
rawConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h2Conn, err := (&http2.Transport{}).NewClientConn(tlsConn)
|
||||
if err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
t.conn = h2Conn
|
||||
t.host = host
|
||||
t.mu.Unlock()
|
||||
|
||||
return h2Conn.RoundTrip(req)
|
||||
}
|
||||
|
||||
func StartBackgroundRefresh(ctx context.Context, pool *Pool) {
|
||||
go func() {
|
||||
for {
|
||||
|
||||
@@ -1,318 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewPool(t *testing.T) {
|
||||
creds := []*Credential{
|
||||
{ID: "a", AccessToken: "tok-a"},
|
||||
{ID: "b", AccessToken: "tok-b"},
|
||||
}
|
||||
p := NewPool(creds)
|
||||
if p == nil {
|
||||
t.Fatal("NewPool returned nil")
|
||||
}
|
||||
if len(p.creds) != 2 {
|
||||
t.Errorf("pool has %d creds, want 2", len(p.creds))
|
||||
}
|
||||
if p.cursor != 0 {
|
||||
t.Errorf("initial cursor = %d, want 0", p.cursor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_Pick_EmptyPool(t *testing.T) {
|
||||
p := NewPool(nil)
|
||||
_, err := p.Pick()
|
||||
if err == nil {
|
||||
t.Fatal("expected error from empty pool, got nil")
|
||||
}
|
||||
want := "no credentials available"
|
||||
if err.Error() != want {
|
||||
t.Errorf("error = %q, want %q", err.Error(), want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_Pick_SingleCredential(t *testing.T) {
|
||||
cred := &Credential{ID: "only", AccessToken: "tok-only"}
|
||||
p := NewPool([]*Credential{cred})
|
||||
|
||||
got, err := p.Pick()
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() error = %v", err)
|
||||
}
|
||||
if got.ID != "only" {
|
||||
t.Errorf("Pick() returned cred ID %q, want %q", got.ID, "only")
|
||||
}
|
||||
|
||||
// Picking again should return the same credential
|
||||
got2, err := p.Pick()
|
||||
if err != nil {
|
||||
t.Fatalf("second Pick() error = %v", err)
|
||||
}
|
||||
if got2.ID != "only" {
|
||||
t.Errorf("second Pick() returned cred ID %q, want %q", got2.ID, "only")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_Pick_RoundRobin(t *testing.T) {
|
||||
creds := []*Credential{
|
||||
{ID: "a"},
|
||||
{ID: "b"},
|
||||
{ID: "c"},
|
||||
}
|
||||
p := NewPool(creds)
|
||||
|
||||
// Should cycle through a, b, c, a, b, c
|
||||
expected := []string{"a", "b", "c", "a", "b", "c"}
|
||||
for i, want := range expected {
|
||||
got, err := p.Pick()
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||
}
|
||||
if got.ID != want {
|
||||
t.Errorf("Pick() #%d = %q, want %q", i, got.ID, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_Pick_SkipsCooldown(t *testing.T) {
|
||||
creds := []*Credential{
|
||||
{ID: "a"},
|
||||
{ID: "b", cooldownUntil: time.Now().Add(1 * time.Hour)},
|
||||
{ID: "c"},
|
||||
}
|
||||
p := NewPool(creds)
|
||||
|
||||
// First pick: "a" (index 0, not on cooldown)
|
||||
got, err := p.Pick()
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #1 error = %v", err)
|
||||
}
|
||||
if got.ID != "a" {
|
||||
t.Errorf("Pick() #1 = %q, want %q", got.ID, "a")
|
||||
}
|
||||
|
||||
// Second pick: cursor at 1, but "b" is on cooldown → skip to "c"
|
||||
got, err = p.Pick()
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #2 error = %v", err)
|
||||
}
|
||||
if got.ID != "c" {
|
||||
t.Errorf("Pick() #2 = %q, want %q", got.ID, "c")
|
||||
}
|
||||
|
||||
// Third pick: cursor advanced past "c" to 0 → "a"
|
||||
got, err = p.Pick()
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #3 error = %v", err)
|
||||
}
|
||||
if got.ID != "a" {
|
||||
t.Errorf("Pick() #3 = %q, want %q", got.ID, "a")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_Pick_AllOnCooldown(t *testing.T) {
|
||||
future := time.Now().Add(1 * time.Hour)
|
||||
creds := []*Credential{
|
||||
{ID: "a", cooldownUntil: future},
|
||||
{ID: "b", cooldownUntil: future},
|
||||
}
|
||||
p := NewPool(creds)
|
||||
|
||||
_, err := p.Pick()
|
||||
if err == nil {
|
||||
t.Fatal("expected error when all on cooldown, got nil")
|
||||
}
|
||||
want := "all 2 credentials are on cooldown"
|
||||
if err.Error() != want {
|
||||
t.Errorf("error = %q, want %q", err.Error(), want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_MarkFailure(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
expectCooldown bool
|
||||
expectedDur time.Duration
|
||||
}{
|
||||
{
|
||||
name: "429 sets 30s cooldown",
|
||||
statusCode: 429,
|
||||
expectCooldown: true,
|
||||
expectedDur: 30 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "500 sets 5s cooldown",
|
||||
statusCode: 500,
|
||||
expectCooldown: true,
|
||||
expectedDur: 5 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "502 sets 5s cooldown",
|
||||
statusCode: 502,
|
||||
expectCooldown: true,
|
||||
expectedDur: 5 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "503 sets 5s cooldown",
|
||||
statusCode: 503,
|
||||
expectCooldown: true,
|
||||
expectedDur: 5 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "400 does NOT set cooldown",
|
||||
statusCode: 400,
|
||||
expectCooldown: false,
|
||||
},
|
||||
{
|
||||
name: "401 does NOT set cooldown",
|
||||
statusCode: 401,
|
||||
expectCooldown: false,
|
||||
},
|
||||
{
|
||||
name: "403 does NOT set cooldown",
|
||||
statusCode: 403,
|
||||
expectCooldown: false,
|
||||
},
|
||||
{
|
||||
name: "404 does NOT set cooldown",
|
||||
statusCode: 404,
|
||||
expectCooldown: false,
|
||||
},
|
||||
{
|
||||
name: "422 does NOT set cooldown",
|
||||
statusCode: 422,
|
||||
expectCooldown: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cred := &Credential{ID: "test"}
|
||||
p := NewPool([]*Credential{cred})
|
||||
|
||||
before := time.Now()
|
||||
p.MarkFailure(cred, tt.statusCode)
|
||||
|
||||
if tt.expectCooldown {
|
||||
if !cred.IsOnCooldown() {
|
||||
t.Errorf("expected cooldown after status %d", tt.statusCode)
|
||||
}
|
||||
// Verify approximate duration
|
||||
cred.mu.Lock()
|
||||
cooldownEnd := cred.cooldownUntil
|
||||
cred.mu.Unlock()
|
||||
|
||||
lower := before.Add(tt.expectedDur)
|
||||
upper := time.Now().Add(tt.expectedDur)
|
||||
if cooldownEnd.Before(lower) || cooldownEnd.After(upper) {
|
||||
t.Errorf("cooldownUntil %v not in expected range [%v, %v]", cooldownEnd, lower, upper)
|
||||
}
|
||||
} else {
|
||||
if cred.IsOnCooldown() {
|
||||
t.Errorf("did not expect cooldown after status %d", tt.statusCode)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_MarkSuccess(t *testing.T) {
|
||||
cred := &Credential{
|
||||
ID: "test",
|
||||
cooldownUntil: time.Now().Add(1 * time.Hour),
|
||||
}
|
||||
p := NewPool([]*Credential{cred})
|
||||
|
||||
if !cred.IsOnCooldown() {
|
||||
t.Fatal("precondition: expected credential to be on cooldown")
|
||||
}
|
||||
|
||||
p.MarkSuccess(cred)
|
||||
|
||||
if cred.IsOnCooldown() {
|
||||
t.Error("expected cooldown to be cleared after MarkSuccess")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_RoundRobinCursorAdvancement(t *testing.T) {
|
||||
creds := []*Credential{
|
||||
{ID: "0"},
|
||||
{ID: "1"},
|
||||
{ID: "2"},
|
||||
}
|
||||
p := NewPool(creds)
|
||||
|
||||
// Verify cursor starts at 0
|
||||
if p.cursor != 0 {
|
||||
t.Fatalf("initial cursor = %d, want 0", p.cursor)
|
||||
}
|
||||
|
||||
// Pick cred[0], cursor should advance to 1
|
||||
got, _ := p.Pick()
|
||||
if got.ID != "0" {
|
||||
t.Errorf("first pick = %q, want %q", got.ID, "0")
|
||||
}
|
||||
if p.cursor != 1 {
|
||||
t.Errorf("cursor after first pick = %d, want 1", p.cursor)
|
||||
}
|
||||
|
||||
// Pick cred[1], cursor should advance to 2
|
||||
got, _ = p.Pick()
|
||||
if got.ID != "1" {
|
||||
t.Errorf("second pick = %q, want %q", got.ID, "1")
|
||||
}
|
||||
if p.cursor != 2 {
|
||||
t.Errorf("cursor after second pick = %d, want 2", p.cursor)
|
||||
}
|
||||
|
||||
// Pick cred[2], cursor should wrap to 0
|
||||
got, _ = p.Pick()
|
||||
if got.ID != "2" {
|
||||
t.Errorf("third pick = %q, want %q", got.ID, "2")
|
||||
}
|
||||
if p.cursor != 0 {
|
||||
t.Errorf("cursor after third pick = %d, want 0 (wrap)", p.cursor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_RoundRobinWithCooldownSkip(t *testing.T) {
|
||||
creds := []*Credential{
|
||||
{ID: "0"},
|
||||
{ID: "1", cooldownUntil: time.Now().Add(1 * time.Hour)},
|
||||
{ID: "2"},
|
||||
}
|
||||
p := NewPool(creds)
|
||||
|
||||
// First pick: cred[0]
|
||||
got, _ := p.Pick()
|
||||
if got.ID != "0" {
|
||||
t.Errorf("first pick = %q, want %q", got.ID, "0")
|
||||
}
|
||||
// Cursor should be at 1
|
||||
if p.cursor != 1 {
|
||||
t.Errorf("cursor after first pick = %d, want 1", p.cursor)
|
||||
}
|
||||
|
||||
// Second pick: cursor at 1, but cred[1] on cooldown → skip to cred[2]
|
||||
got, _ = p.Pick()
|
||||
if got.ID != "2" {
|
||||
t.Errorf("second pick = %q, want %q", got.ID, "2")
|
||||
}
|
||||
// Cursor should advance past cred[2] to 0
|
||||
if p.cursor != 0 {
|
||||
t.Errorf("cursor after second pick (skip) = %d, want 0", p.cursor)
|
||||
}
|
||||
|
||||
// Third pick: cursor at 0, cred[0] available
|
||||
got, _ = p.Pick()
|
||||
if got.ID != "0" {
|
||||
t.Errorf("third pick = %q, want %q", got.ID, "0")
|
||||
}
|
||||
if p.cursor != 1 {
|
||||
t.Errorf("cursor after third pick = %d, want 1", p.cursor)
|
||||
}
|
||||
}
|
||||
@@ -13,7 +13,7 @@ type Credential struct {
|
||||
RefreshToken string
|
||||
ExpiresAt time.Time
|
||||
FilePath string
|
||||
cooldownUntil time.Time
|
||||
CooldownUntil time.Time
|
||||
nextRefreshAfter time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
@@ -22,21 +22,21 @@ type Credential struct {
|
||||
func (c *Credential) IsOnCooldown() bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return time.Now().Before(c.cooldownUntil)
|
||||
return time.Now().Before(c.CooldownUntil)
|
||||
}
|
||||
|
||||
// SetCooldown puts the credential on cooldown for the given duration.
|
||||
func (c *Credential) SetCooldown(duration time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.cooldownUntil = time.Now().Add(duration)
|
||||
c.CooldownUntil = time.Now().Add(duration)
|
||||
}
|
||||
|
||||
// ClearCooldown removes any active cooldown on the credential.
|
||||
func (c *Credential) ClearCooldown() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.cooldownUntil = time.Time{}
|
||||
c.CooldownUntil = time.Time{}
|
||||
}
|
||||
|
||||
// Token returns the current access token.
|
||||
|
||||
@@ -1,167 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCredential_IsOnCooldown(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cooldownUntil time.Time
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "zero time — not on cooldown",
|
||||
cooldownUntil: time.Time{},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "future time — on cooldown",
|
||||
cooldownUntil: time.Now().Add(1 * time.Hour),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "past time — expired cooldown",
|
||||
cooldownUntil: time.Now().Add(-1 * time.Hour),
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Credential{cooldownUntil: tt.cooldownUntil}
|
||||
got := c.IsOnCooldown()
|
||||
if got != tt.want {
|
||||
t.Errorf("IsOnCooldown() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCredential_SetCooldown(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
duration time.Duration
|
||||
}{
|
||||
{name: "30 second cooldown", duration: 30 * time.Second},
|
||||
{name: "5 second cooldown", duration: 5 * time.Second},
|
||||
{name: "1 minute cooldown", duration: 1 * time.Minute},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Credential{}
|
||||
before := time.Now()
|
||||
c.SetCooldown(tt.duration)
|
||||
after := time.Now()
|
||||
|
||||
// cooldownUntil should be between before+duration and after+duration
|
||||
if c.cooldownUntil.Before(before.Add(tt.duration)) {
|
||||
t.Errorf("cooldownUntil %v is before expected lower bound %v", c.cooldownUntil, before.Add(tt.duration))
|
||||
}
|
||||
if c.cooldownUntil.After(after.Add(tt.duration)) {
|
||||
t.Errorf("cooldownUntil %v is after expected upper bound %v", c.cooldownUntil, after.Add(tt.duration))
|
||||
}
|
||||
|
||||
// Should now be on cooldown
|
||||
if !c.IsOnCooldown() {
|
||||
t.Error("expected credential to be on cooldown after SetCooldown")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCredential_ClearCooldown(t *testing.T) {
|
||||
t.Run("clears active cooldown", func(t *testing.T) {
|
||||
c := &Credential{cooldownUntil: time.Now().Add(1 * time.Hour)}
|
||||
if !c.IsOnCooldown() {
|
||||
t.Fatal("precondition: expected credential to be on cooldown")
|
||||
}
|
||||
|
||||
c.ClearCooldown()
|
||||
|
||||
if c.IsOnCooldown() {
|
||||
t.Error("expected credential to not be on cooldown after ClearCooldown")
|
||||
}
|
||||
if !c.cooldownUntil.IsZero() {
|
||||
t.Errorf("expected cooldownUntil to be zero time, got %v", c.cooldownUntil)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("clearing when not on cooldown is no-op", func(t *testing.T) {
|
||||
c := &Credential{}
|
||||
c.ClearCooldown()
|
||||
|
||||
if c.IsOnCooldown() {
|
||||
t.Error("expected credential to not be on cooldown")
|
||||
}
|
||||
if !c.cooldownUntil.IsZero() {
|
||||
t.Errorf("expected cooldownUntil to be zero time, got %v", c.cooldownUntil)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCredential_Token(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{name: "returns access token", token: "sk-ant-abc123"},
|
||||
{name: "empty token", token: ""},
|
||||
{name: "long token", token: "sk-ant-" + string(make([]byte, 200))},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Credential{AccessToken: tt.token}
|
||||
got := c.Token()
|
||||
if got != tt.token {
|
||||
t.Errorf("Token() = %q, want %q", got, tt.token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCredential_ConcurrentAccess(t *testing.T) {
|
||||
c := &Credential{
|
||||
AccessToken: "initial-token",
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 50
|
||||
|
||||
// Spawn goroutines that concurrently read and write
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(3)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = c.Token()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
c.SetCooldown(1 * time.Second)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = c.IsOnCooldown()
|
||||
}()
|
||||
}
|
||||
|
||||
// Also mix in ClearCooldown calls
|
||||
for i := 0; i < goroutines/2; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
c.ClearCooldown()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// If we get here without -race detecting issues, mutex is working
|
||||
}
|
||||
+61
-45
@@ -1,19 +1,21 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/auth"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Port int `yaml:"port"`
|
||||
APIKeys []string `yaml:"api_keys"`
|
||||
ClaudeBinary string `yaml:"claude_binary"`
|
||||
Sanitize SanitizeConfig `yaml:"sanitize"`
|
||||
Logging LoggingConfig `yaml:"logging"`
|
||||
Telemetry TelemetryConfig `yaml:"telemetry"`
|
||||
Port int `yaml:"port"`
|
||||
APIKeys []string `yaml:"api_keys"`
|
||||
ClaudeBinary string `yaml:"claude_binary"`
|
||||
Sanitize SanitizeConfig `yaml:"sanitize"`
|
||||
Logging LoggingConfig `yaml:"logging"`
|
||||
}
|
||||
|
||||
type SanitizeConfig struct {
|
||||
@@ -32,29 +34,6 @@ type ReplaceRule struct {
|
||||
Replace string `yaml:"replace"`
|
||||
}
|
||||
|
||||
type TelemetryConfig struct {
|
||||
Export ExportConfig `yaml:"export"`
|
||||
Embedded EmbeddedConfig `yaml:"embedded"`
|
||||
ServiceName string `yaml:"service_name"`
|
||||
}
|
||||
|
||||
type ExportConfig struct {
|
||||
Endpoint string `yaml:"endpoint"`
|
||||
Insecure bool `yaml:"insecure"`
|
||||
Headers map[string]string `yaml:"headers"`
|
||||
}
|
||||
|
||||
func (e ExportConfig) Enabled() bool { return e.Endpoint != "" }
|
||||
|
||||
type EmbeddedConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Port int `yaml:"port"`
|
||||
PersesBinary string `yaml:"perses_binary"`
|
||||
VMBinary string `yaml:"vm_binary"`
|
||||
VMPort int `yaml:"vm_port"`
|
||||
BinDir string `yaml:"bin_dir"`
|
||||
}
|
||||
|
||||
type LoggingConfig struct {
|
||||
Level string `yaml:"level"`
|
||||
File string `yaml:"file"`
|
||||
@@ -64,6 +43,15 @@ type LoggingConfig struct {
|
||||
Compress bool `yaml:"compress"`
|
||||
}
|
||||
|
||||
type claudeCredentialsJSON struct {
|
||||
ClaudeAiOauth struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ExpiresAt int64 `json:"expiresAt"`
|
||||
SubscriptionType string `json:"subscriptionType"`
|
||||
} `json:"claudeAiOauth"`
|
||||
}
|
||||
|
||||
func Load(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
@@ -88,22 +76,6 @@ func Load(path string) (*Config, error) {
|
||||
cfg.Logging.MaxAgeDays = 30
|
||||
}
|
||||
|
||||
if cfg.Telemetry.ServiceName == "" {
|
||||
cfg.Telemetry.ServiceName = "anthropic-proxy"
|
||||
}
|
||||
if cfg.Telemetry.Embedded.Port == 0 {
|
||||
cfg.Telemetry.Embedded.Port = 8080
|
||||
}
|
||||
if cfg.Telemetry.Embedded.PersesBinary == "" {
|
||||
cfg.Telemetry.Embedded.PersesBinary = "perses"
|
||||
}
|
||||
if cfg.Telemetry.Embedded.VMBinary == "" {
|
||||
cfg.Telemetry.Embedded.VMBinary = "victoria-metrics"
|
||||
}
|
||||
if cfg.Telemetry.Embedded.VMPort == 0 {
|
||||
cfg.Telemetry.Embedded.VMPort = 8428
|
||||
}
|
||||
|
||||
// Check for deprecated claude_credentials field
|
||||
var rawCfg map[string]interface{}
|
||||
if err := yaml.Unmarshal(data, &rawCfg); err == nil {
|
||||
@@ -116,3 +88,47 @@ func Load(path string) (*Config, error) {
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func DefaultCredentialPath() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return home + "/.claude/.credentials.json"
|
||||
}
|
||||
|
||||
func LoadDefaultCredentials() ([]*auth.Credential, error) {
|
||||
path := DefaultCredentialPath()
|
||||
if path == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cf claudeCredentialsJSON
|
||||
if err := json.Unmarshal(data, &cf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oauth := cf.ClaudeAiOauth
|
||||
if oauth.AccessToken == "" {
|
||||
return nil, fmt.Errorf("no access token in %s", path)
|
||||
}
|
||||
|
||||
cred := &auth.Credential{
|
||||
ID: "claude-native",
|
||||
Email: oauth.SubscriptionType,
|
||||
AccessToken: oauth.AccessToken,
|
||||
RefreshToken: oauth.RefreshToken,
|
||||
ExpiresAt: time.UnixMilli(oauth.ExpiresAt),
|
||||
FilePath: path,
|
||||
}
|
||||
|
||||
return []*auth.Credential{cred}, nil
|
||||
}
|
||||
|
||||
@@ -1,270 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoad_AllFields(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
|
||||
yaml := `
|
||||
port: 9090
|
||||
api_keys:
|
||||
- key1
|
||||
- key2
|
||||
claude_binary: /usr/bin/claude
|
||||
sanitize:
|
||||
tools:
|
||||
- from: tool_a
|
||||
to: tool_b
|
||||
system:
|
||||
- match: foo
|
||||
replace: bar
|
||||
body:
|
||||
- match: baz
|
||||
replace: qux
|
||||
logging:
|
||||
level: debug
|
||||
file: /tmp/test.log
|
||||
max_size_mb: 50
|
||||
max_backups: 3
|
||||
max_age_days: 7
|
||||
compress: true
|
||||
telemetry:
|
||||
service_name: my-proxy
|
||||
export:
|
||||
endpoint: http://localhost:4317
|
||||
insecure: true
|
||||
headers:
|
||||
x-token: abc
|
||||
embedded:
|
||||
enabled: true
|
||||
port: 9999
|
||||
perses_binary: /usr/bin/perses
|
||||
vm_binary: /usr/bin/vm
|
||||
vm_port: 9428
|
||||
bin_dir: /opt/bin
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(yaml), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load returned error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Port != 9090 {
|
||||
t.Errorf("Port = %d, want 9090", cfg.Port)
|
||||
}
|
||||
if len(cfg.APIKeys) != 2 || cfg.APIKeys[0] != "key1" || cfg.APIKeys[1] != "key2" {
|
||||
t.Errorf("APIKeys = %v, want [key1 key2]", cfg.APIKeys)
|
||||
}
|
||||
if cfg.ClaudeBinary != "/usr/bin/claude" {
|
||||
t.Errorf("ClaudeBinary = %q, want /usr/bin/claude", cfg.ClaudeBinary)
|
||||
}
|
||||
|
||||
// Sanitize
|
||||
if len(cfg.Sanitize.Tools) != 1 || cfg.Sanitize.Tools[0].From != "tool_a" || cfg.Sanitize.Tools[0].To != "tool_b" {
|
||||
t.Errorf("Sanitize.Tools = %v", cfg.Sanitize.Tools)
|
||||
}
|
||||
if len(cfg.Sanitize.System) != 1 || cfg.Sanitize.System[0].Match != "foo" {
|
||||
t.Errorf("Sanitize.System = %v", cfg.Sanitize.System)
|
||||
}
|
||||
if len(cfg.Sanitize.Body) != 1 || cfg.Sanitize.Body[0].Match != "baz" {
|
||||
t.Errorf("Sanitize.Body = %v", cfg.Sanitize.Body)
|
||||
}
|
||||
|
||||
// Logging
|
||||
if cfg.Logging.Level != "debug" {
|
||||
t.Errorf("Logging.Level = %q, want debug", cfg.Logging.Level)
|
||||
}
|
||||
if cfg.Logging.File != "/tmp/test.log" {
|
||||
t.Errorf("Logging.File = %q", cfg.Logging.File)
|
||||
}
|
||||
if cfg.Logging.MaxSizeMB != 50 {
|
||||
t.Errorf("Logging.MaxSizeMB = %d, want 50", cfg.Logging.MaxSizeMB)
|
||||
}
|
||||
if cfg.Logging.MaxBackups != 3 {
|
||||
t.Errorf("Logging.MaxBackups = %d, want 3", cfg.Logging.MaxBackups)
|
||||
}
|
||||
if cfg.Logging.MaxAgeDays != 7 {
|
||||
t.Errorf("Logging.MaxAgeDays = %d, want 7", cfg.Logging.MaxAgeDays)
|
||||
}
|
||||
if !cfg.Logging.Compress {
|
||||
t.Error("Logging.Compress = false, want true")
|
||||
}
|
||||
|
||||
// Telemetry
|
||||
if cfg.Telemetry.ServiceName != "my-proxy" {
|
||||
t.Errorf("Telemetry.ServiceName = %q, want my-proxy", cfg.Telemetry.ServiceName)
|
||||
}
|
||||
if cfg.Telemetry.Export.Endpoint != "http://localhost:4317" {
|
||||
t.Errorf("Export.Endpoint = %q", cfg.Telemetry.Export.Endpoint)
|
||||
}
|
||||
if !cfg.Telemetry.Export.Insecure {
|
||||
t.Error("Export.Insecure = false, want true")
|
||||
}
|
||||
if !cfg.Telemetry.Export.Enabled() {
|
||||
t.Error("Export.Enabled() = false, want true")
|
||||
}
|
||||
if cfg.Telemetry.Export.Headers["x-token"] != "abc" {
|
||||
t.Errorf("Export.Headers = %v", cfg.Telemetry.Export.Headers)
|
||||
}
|
||||
|
||||
// Embedded
|
||||
if !cfg.Telemetry.Embedded.Enabled {
|
||||
t.Error("Embedded.Enabled = false, want true")
|
||||
}
|
||||
if cfg.Telemetry.Embedded.Port != 9999 {
|
||||
t.Errorf("Embedded.Port = %d, want 9999", cfg.Telemetry.Embedded.Port)
|
||||
}
|
||||
if cfg.Telemetry.Embedded.PersesBinary != "/usr/bin/perses" {
|
||||
t.Errorf("Embedded.PersesBinary = %q", cfg.Telemetry.Embedded.PersesBinary)
|
||||
}
|
||||
if cfg.Telemetry.Embedded.VMBinary != "/usr/bin/vm" {
|
||||
t.Errorf("Embedded.VMBinary = %q", cfg.Telemetry.Embedded.VMBinary)
|
||||
}
|
||||
if cfg.Telemetry.Embedded.VMPort != 9428 {
|
||||
t.Errorf("Embedded.VMPort = %d, want 9428", cfg.Telemetry.Embedded.VMPort)
|
||||
}
|
||||
if cfg.Telemetry.Embedded.BinDir != "/opt/bin" {
|
||||
t.Errorf("Embedded.BinDir = %q", cfg.Telemetry.Embedded.BinDir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_Defaults(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
|
||||
// Minimal YAML — only api_keys
|
||||
if err := os.WriteFile(path, []byte("api_keys:\n - k1\n"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load returned error: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
got interface{}
|
||||
want interface{}
|
||||
}{
|
||||
{"Port", cfg.Port, 8080},
|
||||
{"Logging.Level", cfg.Logging.Level, "info"},
|
||||
{"Logging.MaxSizeMB", cfg.Logging.MaxSizeMB, 100},
|
||||
{"Logging.MaxBackups", cfg.Logging.MaxBackups, 5},
|
||||
{"Logging.MaxAgeDays", cfg.Logging.MaxAgeDays, 30},
|
||||
{"Telemetry.ServiceName", cfg.Telemetry.ServiceName, "anthropic-proxy"},
|
||||
{"Embedded.Port", cfg.Telemetry.Embedded.Port, 8080},
|
||||
{"Embedded.VMBinary", cfg.Telemetry.Embedded.VMBinary, "victoria-metrics"},
|
||||
{"Embedded.PersesBinary", cfg.Telemetry.Embedded.PersesBinary, "perses"},
|
||||
{"Embedded.VMPort", cfg.Telemetry.Embedded.VMPort, 8428},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.got != tt.want {
|
||||
t.Errorf("got %v, want %v", tt.got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_MissingFile(t *testing.T) {
|
||||
_, err := Load("/nonexistent/path/config.yaml")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing file, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "read config") {
|
||||
t.Errorf("error = %q, want it to contain 'read config'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_DeprecatedClaudeCredentials(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
|
||||
yaml := `
|
||||
api_keys:
|
||||
- k1
|
||||
claude_credentials: "/some/path"
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(yaml), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err := Load(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for deprecated claude_credentials, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no longer supported") {
|
||||
t.Errorf("error = %q, want it to contain 'no longer supported'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_EmptyClaudeCredentials(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
|
||||
// Empty string value should NOT trigger the deprecation error
|
||||
yaml := `
|
||||
api_keys:
|
||||
- k1
|
||||
claude_credentials: ""
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(yaml), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("empty claude_credentials should not error: %v", err)
|
||||
}
|
||||
if cfg.Port != 8080 {
|
||||
t.Errorf("Port = %d, want 8080", cfg.Port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_InvalidYAML(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
|
||||
// Truly invalid YAML that causes a parse error
|
||||
if err := os.WriteFile(path, []byte("port:\n - bad\n indent: broken\n"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err := Load(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid YAML, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "parse config") {
|
||||
t.Errorf("error = %q, want it to contain 'parse config'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExportConfig_Enabled(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
endpoint string
|
||||
want bool
|
||||
}{
|
||||
{"empty endpoint", "", false},
|
||||
{"set endpoint", "http://localhost:4317", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := ExportConfig{Endpoint: tt.endpoint}
|
||||
if got := e.Enabled(); got != tt.want {
|
||||
t.Errorf("Enabled() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
package embedded
|
||||
|
||||
import (
|
||||
"embed"
|
||||
)
|
||||
|
||||
//go:embed dashboard/proxy.json
|
||||
var dashboardFS embed.FS
|
||||
|
||||
func DashboardJSON() ([]byte, error) {
|
||||
return dashboardFS.ReadFile("dashboard/proxy.json")
|
||||
}
|
||||
@@ -1,450 +0,0 @@
|
||||
{
|
||||
"kind": "Dashboard",
|
||||
"metadata": {
|
||||
"name": "proxy",
|
||||
"createdAt": "2026-04-14T19:47:48.013238204Z",
|
||||
"updatedAt": "2026-04-14T19:49:30.874125459Z",
|
||||
"version": 1,
|
||||
"project": "anthropic-proxy"
|
||||
},
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "Anthropic Proxy"
|
||||
},
|
||||
"datasources": {
|
||||
"vm": {
|
||||
"default": true,
|
||||
"plugin": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"spec": {
|
||||
"directUrl": "http://localhost:9428"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"panels": {
|
||||
"latency": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "Latency"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "TimeSeriesChart",
|
||||
"spec": {
|
||||
"legend": {
|
||||
"position": "bottom"
|
||||
},
|
||||
"yAxis": {
|
||||
"format": {
|
||||
"unit": "milliseconds"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "histogram_quantile(0.50, rate(proxy_request_duration_ms_milliseconds_bucket[5m]))",
|
||||
"seriesNameFormat": "p50"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "histogram_quantile(0.95, rate(proxy_request_duration_ms_milliseconds_bucket[5m]))",
|
||||
"seriesNameFormat": "p95"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "histogram_quantile(0.99, rate(proxy_request_duration_ms_milliseconds_bucket[5m]))",
|
||||
"seriesNameFormat": "p99"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"request_rate": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "Request Rate"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "TimeSeriesChart",
|
||||
"spec": {
|
||||
"legend": {
|
||||
"position": "bottom"
|
||||
}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "rate(proxy_request_count_total[5m])",
|
||||
"seriesNameFormat": "req/s"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"token_rate": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "Token Rate"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "TimeSeriesChart",
|
||||
"spec": {
|
||||
"legend": {
|
||||
"position": "bottom"
|
||||
}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "rate(proxy_tokens_input_total[5m]) * 60",
|
||||
"seriesNameFormat": "input/min"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "rate(proxy_tokens_output_total[5m]) * 60",
|
||||
"seriesNameFormat": "output/min"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"tokens_5h": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "5h Tokens"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "StatChart",
|
||||
"spec": {
|
||||
"calculation": "last",
|
||||
"format": {
|
||||
"unit": "decimal"
|
||||
},
|
||||
"sparkline": {}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "increase(proxy_tokens_output_total[3h])"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"tokens_7d": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "7d Tokens"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "StatChart",
|
||||
"spec": {
|
||||
"calculation": "last",
|
||||
"format": {
|
||||
"unit": "decimal"
|
||||
},
|
||||
"sparkline": {}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "increase(proxy_tokens_output_total[9h])"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"util_5h": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "5h Utilization"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "GaugeChart",
|
||||
"spec": {
|
||||
"calculation": "last",
|
||||
"format": {
|
||||
"unit": "percent"
|
||||
},
|
||||
"thresholds": {
|
||||
"steps": [
|
||||
{
|
||||
"color": "green",
|
||||
"value": 0
|
||||
},
|
||||
{
|
||||
"color": "orange",
|
||||
"value": 70
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 90
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "proxy_usage_utilization{window=\"5h\"}"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"util_7d": {
|
||||
"kind": "Panel",
|
||||
"spec": {
|
||||
"display": {
|
||||
"name": "7d Utilization"
|
||||
},
|
||||
"plugin": {
|
||||
"kind": "GaugeChart",
|
||||
"spec": {
|
||||
"calculation": "last",
|
||||
"format": {
|
||||
"unit": "percent"
|
||||
},
|
||||
"thresholds": {
|
||||
"steps": [
|
||||
{
|
||||
"color": "green",
|
||||
"value": 0
|
||||
},
|
||||
{
|
||||
"color": "orange",
|
||||
"value": 70
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 90
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"queries": [
|
||||
{
|
||||
"kind": "TimeSeriesQuery",
|
||||
"spec": {
|
||||
"plugin": {
|
||||
"kind": "PrometheusTimeSeriesQuery",
|
||||
"spec": {
|
||||
"datasource": {
|
||||
"kind": "PrometheusDatasource",
|
||||
"name": "vm"
|
||||
},
|
||||
"query": "proxy_usage_utilization{window=\"7d\"}"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"layouts": [
|
||||
{
|
||||
"kind": "Grid",
|
||||
"spec": {
|
||||
"display": {
|
||||
"title": "Utilization"
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"x": 0,
|
||||
"y": 0,
|
||||
"width": 6,
|
||||
"height": 5,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/util_5h"
|
||||
}
|
||||
},
|
||||
{
|
||||
"x": 6,
|
||||
"y": 0,
|
||||
"width": 6,
|
||||
"height": 5,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/util_7d"
|
||||
}
|
||||
},
|
||||
{
|
||||
"x": 12,
|
||||
"y": 0,
|
||||
"width": 6,
|
||||
"height": 5,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/tokens_5h"
|
||||
}
|
||||
},
|
||||
{
|
||||
"x": 18,
|
||||
"y": 0,
|
||||
"width": 6,
|
||||
"height": 5,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/tokens_7d"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"kind": "Grid",
|
||||
"spec": {
|
||||
"display": {
|
||||
"title": "Traffic"
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"x": 0,
|
||||
"y": 0,
|
||||
"width": 12,
|
||||
"height": 8,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/request_rate"
|
||||
}
|
||||
},
|
||||
{
|
||||
"x": 12,
|
||||
"y": 0,
|
||||
"width": 12,
|
||||
"height": 8,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/latency"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"kind": "Grid",
|
||||
"spec": {
|
||||
"display": {
|
||||
"title": "Tokens"
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"x": 0,
|
||||
"y": 0,
|
||||
"width": 24,
|
||||
"height": 8,
|
||||
"content": {
|
||||
"$ref": "#/spec/panels/token_rate"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"duration": "1h",
|
||||
"refreshInterval": "10s"
|
||||
}
|
||||
}
|
||||
@@ -1,155 +0,0 @@
|
||||
package embedded
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
const cacheDir = ".cache/anthropic-proxy/bin"
|
||||
|
||||
var downloads = map[string]struct {
|
||||
urlTemplate string
|
||||
version string
|
||||
extractName string
|
||||
}{
|
||||
"victoria-metrics": {
|
||||
urlTemplate: "https://github.com/VictoriaMetrics/VictoriaMetrics/releases/download/v%s/victoria-metrics-%s-v%s.tar.gz",
|
||||
version: "1.118.0",
|
||||
extractName: "victoria-metrics-prod",
|
||||
},
|
||||
"perses": {
|
||||
urlTemplate: "https://github.com/perses/perses/releases/download/v%s/perses_%s_%s_%s.tar.gz",
|
||||
version: "0.53.1",
|
||||
},
|
||||
}
|
||||
|
||||
func ensureBinary(name, configPath, configBinDir string) (string, error) {
|
||||
if configPath != "" {
|
||||
if p, err := exec.LookPath(configPath); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
|
||||
if p, err := exec.LookPath(name); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
binDir := configBinDir
|
||||
if binDir == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get home dir: %w", err)
|
||||
}
|
||||
binDir = filepath.Join(home, cacheDir)
|
||||
}
|
||||
cachedPath := filepath.Join(binDir, name)
|
||||
|
||||
if _, err := os.Stat(cachedPath); err == nil {
|
||||
return cachedPath, nil
|
||||
}
|
||||
|
||||
log.Info().Str("binary", name).Msg("downloading binary (first run)")
|
||||
|
||||
if err := os.MkdirAll(binDir, 0o755); err != nil {
|
||||
return "", fmt.Errorf("create cache dir: %w", err)
|
||||
}
|
||||
|
||||
url, err := downloadURL(name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := extractAll(url, binDir); err != nil {
|
||||
return "", fmt.Errorf("download %s: %w", name, err)
|
||||
}
|
||||
|
||||
d := downloads[name]
|
||||
if d.extractName != "" {
|
||||
oldPath := filepath.Join(binDir, d.extractName)
|
||||
if _, err := os.Stat(oldPath); err == nil {
|
||||
os.Rename(oldPath, cachedPath)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := os.Stat(cachedPath); err != nil {
|
||||
return "", fmt.Errorf("binary %s not found after extraction", name)
|
||||
}
|
||||
|
||||
log.Info().Str("binary", name).Str("path", cachedPath).Msg("binary downloaded")
|
||||
return cachedPath, nil
|
||||
}
|
||||
|
||||
func downloadURL(name string) (string, error) {
|
||||
goarch := runtime.GOARCH
|
||||
goos := runtime.GOOS
|
||||
|
||||
d, ok := downloads[name]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unknown binary: %s", name)
|
||||
}
|
||||
|
||||
switch name {
|
||||
case "victoria-metrics":
|
||||
vmOS := fmt.Sprintf("%s-%s", goos, goarch)
|
||||
return fmt.Sprintf(d.urlTemplate, d.version, vmOS, d.version), nil
|
||||
case "perses":
|
||||
return fmt.Sprintf(d.urlTemplate, d.version, d.version, goos, goarch), nil
|
||||
}
|
||||
return "", fmt.Errorf("unknown binary: %s", name)
|
||||
}
|
||||
|
||||
func extractAll(url, destDir string) error {
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return fmt.Errorf("download failed: HTTP %d from %s", resp.StatusCode, url)
|
||||
}
|
||||
|
||||
gz, err := gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("gzip reader: %w", err)
|
||||
}
|
||||
defer gz.Close()
|
||||
|
||||
tr := tar.NewReader(gz)
|
||||
for {
|
||||
hdr, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("read tar: %w", err)
|
||||
}
|
||||
|
||||
target := filepath.Join(destDir, hdr.Name)
|
||||
switch hdr.Typeflag {
|
||||
case tar.TypeDir:
|
||||
os.MkdirAll(target, 0o755)
|
||||
case tar.TypeReg:
|
||||
os.MkdirAll(filepath.Dir(target), 0o755)
|
||||
mode := os.FileMode(hdr.Mode)
|
||||
if mode == 0 {
|
||||
mode = 0o644
|
||||
}
|
||||
out, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
io.Copy(out, tr)
|
||||
out.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
package embedded
|
||||
|
||||
import "github.com/rs/zerolog/log"
|
||||
|
||||
// logWriter bridges subprocess stdout/stderr to zerolog.
|
||||
type logWriter struct {
|
||||
level string
|
||||
component string
|
||||
}
|
||||
|
||||
func (w *logWriter) Write(p []byte) (n int, err error) {
|
||||
msg := string(p)
|
||||
switch w.level {
|
||||
case "error":
|
||||
log.Error().Str("component", w.component).Msg(msg)
|
||||
default:
|
||||
log.Debug().Str("component", w.component).Msg(msg)
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
@@ -1,133 +0,0 @@
|
||||
package embedded
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/config"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type Perses struct {
|
||||
cfg config.EmbeddedConfig
|
||||
proxyPort int
|
||||
cmd *exec.Cmd
|
||||
tmpDir string
|
||||
}
|
||||
|
||||
func NewPerses(cfg config.EmbeddedConfig, proxyPort int) *Perses {
|
||||
return &Perses{cfg: cfg, proxyPort: proxyPort}
|
||||
}
|
||||
|
||||
func (p *Perses) Start() error {
|
||||
bin, err := ensureBinary("perses", p.cfg.PersesBinary, p.cfg.BinDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("perses: %w", err)
|
||||
}
|
||||
|
||||
p.tmpDir, err = os.MkdirTemp("", "perses-*")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp dir: %w", err)
|
||||
}
|
||||
|
||||
if err := p.writeServerConfig(); err != nil {
|
||||
return fmt.Errorf("write server config: %w", err)
|
||||
}
|
||||
if err := p.writeDatasourceProvision(); err != nil {
|
||||
return fmt.Errorf("write datasource provision: %w", err)
|
||||
}
|
||||
if err := p.writeDashboardProvision(); err != nil {
|
||||
return fmt.Errorf("write dashboard provision: %w", err)
|
||||
}
|
||||
|
||||
p.cmd = exec.Command(bin,
|
||||
"--config", filepath.Join(p.tmpDir, "config.yaml"),
|
||||
"-web.listen-address", fmt.Sprintf(":%d", p.cfg.Port),
|
||||
)
|
||||
p.cmd.Dir = filepath.Dir(bin)
|
||||
p.cmd.Stdout = &logWriter{level: "info", component: "perses"}
|
||||
p.cmd.Stderr = &logWriter{level: "error", component: "perses"}
|
||||
|
||||
if err := p.cmd.Start(); err != nil {
|
||||
return fmt.Errorf("start perses: %w", err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("binary", bin).
|
||||
Int("port", p.cfg.Port).
|
||||
Str("config", p.tmpDir).
|
||||
Msg("perses started")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Perses) Stop() {
|
||||
if p.cmd != nil && p.cmd.Process != nil {
|
||||
_ = p.cmd.Process.Kill()
|
||||
_ = p.cmd.Wait()
|
||||
}
|
||||
if p.tmpDir != "" {
|
||||
_ = os.RemoveAll(p.tmpDir)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Perses) Running() bool {
|
||||
return p.cmd != nil && p.cmd.Process != nil && p.cmd.ProcessState == nil
|
||||
}
|
||||
|
||||
func (p *Perses) writeServerConfig() error {
|
||||
provisionDir := filepath.Join(p.tmpDir, "provisions")
|
||||
if err := os.MkdirAll(filepath.Join(provisionDir, "datasources"), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Join(provisionDir, "dashboards"), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg := fmt.Sprintf(`provisioning:
|
||||
interval: 1m
|
||||
folders:
|
||||
- %s
|
||||
database:
|
||||
file:
|
||||
folder: %s/data
|
||||
extension: json
|
||||
security:
|
||||
readonly: false
|
||||
enable_auth: false
|
||||
`, provisionDir, p.tmpDir)
|
||||
|
||||
return os.WriteFile(filepath.Join(p.tmpDir, "config.yaml"), []byte(cfg), 0o644)
|
||||
}
|
||||
|
||||
func (p *Perses) writeDatasourceProvision() error {
|
||||
ds := fmt.Sprintf(`kind: Datasource
|
||||
metadata:
|
||||
name: victoria-metrics
|
||||
project: anthropic-proxy
|
||||
spec:
|
||||
default: true
|
||||
plugin:
|
||||
kind: PrometheusDatasource
|
||||
spec:
|
||||
directUrl: http://localhost:%d
|
||||
`, p.cfg.VMPort)
|
||||
|
||||
return os.WriteFile(
|
||||
filepath.Join(p.tmpDir, "provisions", "datasources", "vm.yaml"),
|
||||
[]byte(ds), 0o644,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *Perses) writeDashboardProvision() error {
|
||||
dashData, err := DashboardJSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(
|
||||
filepath.Join(p.tmpDir, "provisions", "dashboards", "proxy.json"),
|
||||
dashData, 0o644,
|
||||
)
|
||||
}
|
||||
@@ -1,88 +0,0 @@
|
||||
package embedded
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/config"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type VM struct {
|
||||
cfg config.EmbeddedConfig
|
||||
proxyPort int
|
||||
cmd *exec.Cmd
|
||||
tmpDir string
|
||||
}
|
||||
|
||||
func NewVM(cfg config.EmbeddedConfig, proxyPort int) *VM {
|
||||
return &VM{cfg: cfg, proxyPort: proxyPort}
|
||||
}
|
||||
|
||||
func (v *VM) Start() error {
|
||||
bin, err := ensureBinary("victoria-metrics", v.cfg.VMBinary, v.cfg.BinDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("victoria-metrics: %w", err)
|
||||
}
|
||||
|
||||
v.tmpDir, err = os.MkdirTemp("", "vm-*")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp dir: %w", err)
|
||||
}
|
||||
|
||||
scrapeConfig := fmt.Sprintf(`global:
|
||||
scrape_interval: 15s
|
||||
scrape_configs:
|
||||
- job_name: anthropic-proxy
|
||||
static_configs:
|
||||
- targets:
|
||||
- localhost:%d
|
||||
`, v.proxyPort)
|
||||
|
||||
scrapePath := filepath.Join(v.tmpDir, "scrape.yaml")
|
||||
if err := os.WriteFile(scrapePath, []byte(scrapeConfig), 0o644); err != nil {
|
||||
return fmt.Errorf("write scrape config: %w", err)
|
||||
}
|
||||
|
||||
dataPath := filepath.Join(v.tmpDir, "data")
|
||||
if err := os.MkdirAll(dataPath, 0o755); err != nil {
|
||||
return fmt.Errorf("create data dir: %w", err)
|
||||
}
|
||||
|
||||
v.cmd = exec.Command(bin,
|
||||
"-storageDataPath", dataPath,
|
||||
"-retentionPeriod", "7d",
|
||||
"-httpListenAddr", fmt.Sprintf(":%d", v.cfg.VMPort),
|
||||
"-promscrape.config", scrapePath,
|
||||
)
|
||||
v.cmd.Stdout = &logWriter{level: "info", component: "victoria-metrics"}
|
||||
v.cmd.Stderr = &logWriter{level: "error", component: "victoria-metrics"}
|
||||
|
||||
if err := v.cmd.Start(); err != nil {
|
||||
return fmt.Errorf("start victoria-metrics: %w", err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("binary", bin).
|
||||
Int("port", v.cfg.VMPort).
|
||||
Int("scrape_target_port", v.proxyPort).
|
||||
Msg("victoria-metrics started")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *VM) Stop() {
|
||||
if v.cmd != nil && v.cmd.Process != nil {
|
||||
_ = v.cmd.Process.Kill()
|
||||
_ = v.cmd.Wait()
|
||||
}
|
||||
if v.tmpDir != "" {
|
||||
_ = os.RemoveAll(v.tmpDir)
|
||||
}
|
||||
}
|
||||
|
||||
func (v *VM) Running() bool {
|
||||
return v.cmd != nil && v.cmd.Process != nil && v.cmd.ProcessState == nil
|
||||
}
|
||||
+15
-22
@@ -3,7 +3,6 @@ package logging
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -13,17 +12,23 @@ import (
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gopkg.in/lumberjack.v2"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/config"
|
||||
)
|
||||
|
||||
// Config holds logging configuration, mirrors config.LoggingConfig.
|
||||
type Config struct {
|
||||
Level string
|
||||
File string
|
||||
MaxSizeMB int
|
||||
MaxBackups int
|
||||
MaxAgeDays int
|
||||
Compress bool
|
||||
}
|
||||
|
||||
// Setup initializes the global zerolog logger.
|
||||
// - File set: JSON → lumberjack rotating file
|
||||
// - File empty + TTY: colored ConsoleWriter → stderr
|
||||
// - File empty + not TTY: JSON → stderr (for systemd journal)
|
||||
// Extra writers (e.g., OTLP log bridge) are added via io.MultiWriter so logs
|
||||
// are written to both the primary destination and any extra writers.
|
||||
func Setup(cfg config.LoggingConfig, extraWriters ...io.Writer) zerolog.Logger {
|
||||
func Setup(cfg Config) zerolog.Logger {
|
||||
// Parse log level
|
||||
level, err := zerolog.ParseLevel(cfg.Level)
|
||||
if err != nil || cfg.Level == "" {
|
||||
@@ -43,32 +48,20 @@ func Setup(cfg config.LoggingConfig, extraWriters ...io.Writer) zerolog.Logger {
|
||||
MaxAge: cfg.MaxAgeDays,
|
||||
Compress: cfg.Compress,
|
||||
}
|
||||
var w io.Writer = jack
|
||||
if len(extraWriters) > 0 {
|
||||
w = io.MultiWriter(append([]io.Writer{jack}, extraWriters...)...)
|
||||
}
|
||||
logger = zerolog.New(w).With().Timestamp().Caller().Logger()
|
||||
logger = zerolog.New(jack).With().Timestamp().Caller().Logger()
|
||||
} else {
|
||||
fi, err := os.Stderr.Stat()
|
||||
isTTY := err == nil && (fi.Mode()&os.ModeCharDevice) != 0
|
||||
if isTTY {
|
||||
// Dev mode: colored console (extra writers get JSON, console gets pretty)
|
||||
// Dev mode: colored console
|
||||
cw := zerolog.ConsoleWriter{
|
||||
Out: os.Stderr,
|
||||
TimeFormat: time.RFC3339,
|
||||
}
|
||||
var w io.Writer = cw
|
||||
if len(extraWriters) > 0 {
|
||||
w = io.MultiWriter(append([]io.Writer{cw}, extraWriters...)...)
|
||||
}
|
||||
logger = zerolog.New(w).With().Timestamp().Caller().Logger()
|
||||
logger = zerolog.New(cw).With().Timestamp().Caller().Logger()
|
||||
} else {
|
||||
// Systemd journal: JSON to stderr
|
||||
var w io.Writer = os.Stderr
|
||||
if len(extraWriters) > 0 {
|
||||
w = io.MultiWriter(append([]io.Writer{os.Stderr}, extraWriters...)...)
|
||||
}
|
||||
logger = zerolog.New(w).With().Timestamp().Caller().Logger()
|
||||
logger = zerolog.New(os.Stderr).With().Timestamp().Caller().Logger()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,232 +0,0 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/config"
|
||||
)
|
||||
|
||||
func TestRedactHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
headers http.Header
|
||||
check func(t *testing.T, result string)
|
||||
}{
|
||||
{
|
||||
name: "redacts Authorization",
|
||||
headers: http.Header{
|
||||
"Authorization": []string{"Bearer secret-token"},
|
||||
},
|
||||
check: func(t *testing.T, result string) {
|
||||
var m map[string]string
|
||||
if err := json.Unmarshal([]byte(result), &m); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if m["Authorization"] != "***" {
|
||||
t.Errorf("Authorization = %q, want ***", m["Authorization"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "redacts x-api-key",
|
||||
headers: http.Header{
|
||||
"X-Api-Key": []string{"sk-ant-secret"},
|
||||
},
|
||||
check: func(t *testing.T, result string) {
|
||||
var m map[string]string
|
||||
if err := json.Unmarshal([]byte(result), &m); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if m["X-Api-Key"] != "***" {
|
||||
t.Errorf("X-Api-Key = %q, want ***", m["X-Api-Key"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "preserves other headers",
|
||||
headers: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"Accept": []string{"text/html", "application/json"},
|
||||
},
|
||||
check: func(t *testing.T, result string) {
|
||||
var m map[string]string
|
||||
if err := json.Unmarshal([]byte(result), &m); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if m["Content-Type"] != "application/json" {
|
||||
t.Errorf("Content-Type = %q, want application/json", m["Content-Type"])
|
||||
}
|
||||
if m["Accept"] != "text/html, application/json" {
|
||||
t.Errorf("Accept = %q, want 'text/html, application/json'", m["Accept"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "case-insensitive redaction",
|
||||
headers: http.Header{
|
||||
"authorization": []string{"Bearer token"},
|
||||
"X-API-KEY": []string{"key123"},
|
||||
},
|
||||
check: func(t *testing.T, result string) {
|
||||
var m map[string]string
|
||||
if err := json.Unmarshal([]byte(result), &m); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
// http.Header canonicalizes keys, but RedactHeaders lowercases for comparison
|
||||
for _, v := range m {
|
||||
if v != "***" {
|
||||
t.Errorf("expected all values to be ***, got %q", v)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty headers",
|
||||
headers: http.Header{},
|
||||
check: func(t *testing.T, result string) {
|
||||
if result != "{}" {
|
||||
t.Errorf("result = %q, want {}", result)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mixed sensitive and non-sensitive",
|
||||
headers: http.Header{
|
||||
"Authorization": []string{"Bearer tok"},
|
||||
"X-Api-Key": []string{"key"},
|
||||
"Content-Type": []string{"application/json"},
|
||||
"X-Request-Id": []string{"abc123"},
|
||||
},
|
||||
check: func(t *testing.T, result string) {
|
||||
var m map[string]string
|
||||
if err := json.Unmarshal([]byte(result), &m); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if m["Authorization"] != "***" {
|
||||
t.Errorf("Authorization = %q, want ***", m["Authorization"])
|
||||
}
|
||||
if m["X-Api-Key"] != "***" {
|
||||
t.Errorf("X-Api-Key = %q, want ***", m["X-Api-Key"])
|
||||
}
|
||||
if m["Content-Type"] != "application/json" {
|
||||
t.Errorf("Content-Type = %q", m["Content-Type"])
|
||||
}
|
||||
if m["X-Request-Id"] != "abc123" {
|
||||
t.Errorf("X-Request-Id = %q", m["X-Request-Id"])
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := RedactHeaders(tt.headers)
|
||||
// Result should be valid JSON
|
||||
if !json.Valid([]byte(result)) {
|
||||
t.Fatalf("result is not valid JSON: %q", result)
|
||||
}
|
||||
tt.check(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactHeaders_ReturnsJSON(t *testing.T) {
|
||||
h := http.Header{"Foo": []string{"bar"}}
|
||||
result := RedactHeaders(h)
|
||||
if !strings.HasPrefix(result, "{") || !strings.HasSuffix(result, "}") {
|
||||
t.Errorf("result not JSON object: %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusLevel(t *testing.T) {
|
||||
tests := []struct {
|
||||
status int
|
||||
want zerolog.Level
|
||||
}{
|
||||
{200, zerolog.InfoLevel},
|
||||
{201, zerolog.InfoLevel},
|
||||
{204, zerolog.InfoLevel},
|
||||
{301, zerolog.InfoLevel},
|
||||
{399, zerolog.InfoLevel},
|
||||
{400, zerolog.WarnLevel},
|
||||
{401, zerolog.WarnLevel},
|
||||
{403, zerolog.WarnLevel},
|
||||
{404, zerolog.WarnLevel},
|
||||
{429, zerolog.WarnLevel},
|
||||
{499, zerolog.WarnLevel},
|
||||
{500, zerolog.ErrorLevel},
|
||||
{502, zerolog.ErrorLevel},
|
||||
{503, zerolog.ErrorLevel},
|
||||
{599, zerolog.ErrorLevel},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := statusLevel(tt.status)
|
||||
if got != tt.want {
|
||||
t.Errorf("statusLevel(%d) = %v, want %v", tt.status, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetup_WithFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
logFile := filepath.Join(dir, "test.log")
|
||||
|
||||
logger := Setup(config.LoggingConfig{
|
||||
Level: "debug",
|
||||
File: logFile,
|
||||
MaxSizeMB: 10,
|
||||
MaxBackups: 1,
|
||||
MaxAgeDays: 1,
|
||||
})
|
||||
|
||||
// Verify logger works (no panic)
|
||||
logger.Info().Msg("test message")
|
||||
}
|
||||
|
||||
func TestSetup_WithoutFile(t *testing.T) {
|
||||
// File empty — should use console or stderr mode depending on TTY
|
||||
logger := Setup(config.LoggingConfig{
|
||||
Level: "warn",
|
||||
})
|
||||
|
||||
// Verify logger works (no panic)
|
||||
logger.Warn().Msg("test warning")
|
||||
}
|
||||
|
||||
func TestSetup_DefaultLevel(t *testing.T) {
|
||||
// Empty level should default to info
|
||||
logger := Setup(config.LoggingConfig{})
|
||||
_ = logger // verify no panic
|
||||
}
|
||||
|
||||
func TestSetup_InvalidLevel(t *testing.T) {
|
||||
// Invalid level should default to info
|
||||
logger := Setup(config.LoggingConfig{Level: "not-a-level"})
|
||||
_ = logger // verify no panic
|
||||
}
|
||||
|
||||
func TestFromContext_NoLogger(t *testing.T) {
|
||||
// Background context has no zerolog logger — should return global
|
||||
ctx := context.Background()
|
||||
l := FromContext(ctx)
|
||||
if l == nil {
|
||||
t.Fatal("FromContext returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromContext_WithLogger(t *testing.T) {
|
||||
logger := zerolog.Nop()
|
||||
ctx := logger.WithContext(context.Background())
|
||||
l := FromContext(ctx)
|
||||
if l == nil {
|
||||
t.Fatal("FromContext returned nil")
|
||||
}
|
||||
}
|
||||
@@ -11,13 +11,9 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// fingerprintSalt is the fixed salt used by Claude Code for billing header
|
||||
// fingerprint computation. Extracted from the Claude Code CLI source.
|
||||
const fingerprintSalt = "59cf53e54c78"
|
||||
|
||||
func computeFingerprint(firstUserMessage string, version string) string {
|
||||
// UTF-16 character indices sampled from the first user message, matching
|
||||
// the Claude Code CLI's fingerprinting algorithm.
|
||||
indices := []int{4, 7, 20}
|
||||
runes := utf16.Encode([]rune(firstUserMessage))
|
||||
var chars string
|
||||
|
||||
@@ -1,323 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFingerprintSaltConstant(t *testing.T) {
|
||||
if fingerprintSalt != "59cf53e54c78" {
|
||||
t.Errorf("fingerprintSalt = %q, want %q", fingerprintSalt, "59cf53e54c78")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_Deterministic(t *testing.T) {
|
||||
a := computeFingerprint("hello world test message", "1.0.0")
|
||||
b := computeFingerprint("hello world test message", "1.0.0")
|
||||
if a != b {
|
||||
t.Errorf("fingerprint not deterministic: %q != %q", a, b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_Length(t *testing.T) {
|
||||
fp := computeFingerprint("some message here", "2.0.0")
|
||||
if len(fp) != 3 {
|
||||
t.Errorf("fingerprint length = %d, want 3", len(fp))
|
||||
}
|
||||
// Must be valid hex
|
||||
if _, err := hex.DecodeString(fp + "0"); err != nil { // pad to even length for decode
|
||||
// Check each char is hex individually
|
||||
for _, c := range fp {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
||||
t.Errorf("fingerprint %q contains non-hex char %c", fp, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_DifferentVersions(t *testing.T) {
|
||||
a := computeFingerprint("same message", "1.0.0")
|
||||
b := computeFingerprint("same message", "2.0.0")
|
||||
if a == b {
|
||||
t.Errorf("different versions should (almost certainly) produce different fingerprints")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_ShortMessage(t *testing.T) {
|
||||
// "hi" has only 2 chars, indices [4,7,20] all out of range → chars = "000"
|
||||
fp := computeFingerprint("hi", "1.0.0")
|
||||
if len(fp) != 3 {
|
||||
t.Errorf("short message fingerprint length = %d, want 3", len(fp))
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_EmptyMessage(t *testing.T) {
|
||||
// Empty → all indices out of range → chars = "000"
|
||||
fp := computeFingerprint("", "1.0.0")
|
||||
if len(fp) != 3 {
|
||||
t.Errorf("empty message fingerprint length = %d, want 3", len(fp))
|
||||
}
|
||||
// Empty and short message with same version should produce same fingerprint
|
||||
// since both result in chars = "000"
|
||||
fpShort := computeFingerprint("hi", "1.0.0")
|
||||
if fp != fpShort {
|
||||
t.Errorf("empty and 'hi' should produce same fingerprint (both use '000'), got %q vs %q", fp, fpShort)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_Unicode(t *testing.T) {
|
||||
// Emoji: 🎉 is U+1F389, encoded as UTF-16 surrogate pair [0xD83C, 0xDF89]
|
||||
// So "abcd🎉fg" in UTF-16 is [a, b, c, d, 0xD83C, 0xDF89, f, g] = 8 uint16 values
|
||||
// indices [4,7,20]: runes[4]=0xD83C, runes[7]='g', runes[20]=out of range
|
||||
fp := computeFingerprint("abcd🎉fg", "1.0.0")
|
||||
if len(fp) != 3 {
|
||||
t.Errorf("unicode fingerprint length = %d, want 3", len(fp))
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_CharExtraction(t *testing.T) {
|
||||
// "Hello, World!" UTF-16: [H,e,l,l,o,',', ,W,o,r,l,d,!]
|
||||
// indices [4,7,20]: runes[4]='o', runes[7]='W', runes[20]=out of range → "0"
|
||||
// So chars should be "oW0"
|
||||
// Verify by comparing to a message where we know the expected extracted chars
|
||||
// Two messages that extract same chars at indices should produce same fingerprint
|
||||
// "xxxxoxxWxxxxxxxxxxxx" → index 4='o', 7='W', 20=out of range → "oW0" (20 chars, index 20 out of range)
|
||||
fp1 := computeFingerprint("Hello, World!", "1.0.0")
|
||||
fp2 := computeFingerprint("xxxxoxxWxxxxxxxxxxxx", "1.0.0")
|
||||
if fp1 != fp2 {
|
||||
t.Errorf("messages with same chars at indices [4,7,20] should produce same fingerprint, got %q vs %q", fp1, fp2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_IndexBoundary(t *testing.T) {
|
||||
// Message with exactly 21 chars → index 20 is valid
|
||||
msg21 := "abcdefghijklmnopqrstu" // 21 chars
|
||||
fp21 := computeFingerprint(msg21, "1.0.0")
|
||||
|
||||
// Message with exactly 20 chars → index 20 is out of range → "0"
|
||||
msg20 := "abcdefghijklmnopqrst" // 20 chars
|
||||
fp20 := computeFingerprint(msg20, "1.0.0")
|
||||
|
||||
// They should differ because index 20 produces different chars
|
||||
if fp21 == fp20 {
|
||||
t.Errorf("boundary test: 21-char and 20-char messages should differ at index 20")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFirstUserMessage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple string content",
|
||||
body: `{"messages":[{"role":"user","content":"hello world"}]}`,
|
||||
expected: "hello world",
|
||||
},
|
||||
{
|
||||
name: "array content with text block",
|
||||
body: `{"messages":[{"role":"user","content":[{"type":"text","text":"from array"}]}]}`,
|
||||
expected: "from array",
|
||||
},
|
||||
{
|
||||
name: "no user messages",
|
||||
body: `{"messages":[{"role":"assistant","content":"I am assistant"}]}`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "assistant only messages",
|
||||
body: `{"messages":[{"role":"assistant","content":"a1"},{"role":"assistant","content":"a2"}]}`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "user with non-text block first then text",
|
||||
body: `{"messages":[{"role":"user","content":[{"type":"image","source":"x"},{"type":"text","text":"the text"}]}]}`,
|
||||
expected: "the text",
|
||||
},
|
||||
{
|
||||
name: "user with only non-text blocks",
|
||||
body: `{"messages":[{"role":"user","content":[{"type":"image","source":"x"}]}]}`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "no messages field",
|
||||
body: `{"model":"claude-sonnet-4-6"}`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "messages not array",
|
||||
body: `{"messages":"not array"}`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "empty messages array",
|
||||
body: `{"messages":[]}`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "first user message used even if multiple exist",
|
||||
body: `{"messages":[{"role":"user","content":"first"},{"role":"user","content":"second"}]}`,
|
||||
expected: "first",
|
||||
},
|
||||
{
|
||||
name: "assistant before user",
|
||||
body: `{"messages":[{"role":"assistant","content":"assistant msg"},{"role":"user","content":"user msg"}]}`,
|
||||
expected: "user msg",
|
||||
},
|
||||
{
|
||||
name: "user with array content - first text block used",
|
||||
body: `{"messages":[{"role":"user","content":[{"type":"text","text":"first text"},{"type":"text","text":"second text"}]}]}`,
|
||||
expected: "first text",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractFirstUserMessage([]byte(tt.body))
|
||||
if result != tt.expected {
|
||||
t.Errorf("got %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFirstUserMessage_BreaksAfterFirstUser(t *testing.T) {
|
||||
// The function should break after finding the first user message,
|
||||
// even if it didn't extract text (e.g. user with only image blocks)
|
||||
body := `{"messages":[{"role":"user","content":[{"type":"image","source":"x"}]},{"role":"user","content":"second user"}]}`
|
||||
result := extractFirstUserMessage([]byte(body))
|
||||
// First user has no text blocks, function breaks, returns ""
|
||||
if result != "" {
|
||||
t.Errorf("should return empty when first user has no text, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBillingHeader(t *testing.T) {
|
||||
body := []byte(`{"messages":[{"role":"user","content":"test message"}]}`)
|
||||
version := "1.2.3"
|
||||
header := buildBillingHeader(body, version)
|
||||
|
||||
// Check format
|
||||
if !strings.HasPrefix(header, "x-anthropic-billing-header: cc_version=1.2.3.") {
|
||||
t.Errorf("header should start with 'x-anthropic-billing-header: cc_version=1.2.3.', got %q", header)
|
||||
}
|
||||
if !strings.Contains(header, "; cc_entrypoint=cli; cch=00000;") {
|
||||
t.Errorf("header should contain '; cc_entrypoint=cli; cch=00000;', got %q", header)
|
||||
}
|
||||
|
||||
// Verify the fingerprint part is 3 chars
|
||||
// Format: "x-anthropic-billing-header: cc_version=1.2.3.XXX; cc_entrypoint=cli; cch=00000;"
|
||||
parts := strings.Split(header, "cc_version=")
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("unexpected header format: %q", header)
|
||||
}
|
||||
versionFP := strings.Split(parts[1], ";")[0]
|
||||
if !strings.HasPrefix(versionFP, "1.2.3.") {
|
||||
t.Errorf("version+fingerprint should start with '1.2.3.', got %q", versionFP)
|
||||
}
|
||||
fp := strings.TrimPrefix(versionFP, "1.2.3.")
|
||||
if len(fp) != 3 {
|
||||
t.Errorf("fingerprint should be 3 chars, got %q (len %d)", fp, len(fp))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBillingHeader_EmptyMessages(t *testing.T) {
|
||||
body := []byte(`{"messages":[]}`)
|
||||
version := "1.0.0"
|
||||
header := buildBillingHeader(body, version)
|
||||
if !strings.HasPrefix(header, "x-anthropic-billing-header: cc_version=") {
|
||||
t.Errorf("header format wrong: %q", header)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectBillingHeader_NoExistingSystem(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := injectBillingHeader(body, "1.0.0")
|
||||
resultStr := string(result)
|
||||
|
||||
// Should have system field now
|
||||
if !strings.Contains(resultStr, `"system"`) {
|
||||
t.Errorf("should inject system field, got %s", resultStr)
|
||||
}
|
||||
// System should be an array with one billing block
|
||||
if !strings.Contains(resultStr, "x-anthropic-billing-header") {
|
||||
t.Errorf("should contain billing header text, got %s", resultStr)
|
||||
}
|
||||
if !strings.Contains(resultStr, `"type":"text"`) {
|
||||
t.Errorf("billing block should have type text, got %s", resultStr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectBillingHeader_ExistingSystemArray(t *testing.T) {
|
||||
body := []byte(`{"system":[{"type":"text","text":"existing prompt"}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := injectBillingHeader(body, "1.0.0")
|
||||
resultStr := string(result)
|
||||
|
||||
// Should contain both billing header and existing prompt
|
||||
if !strings.Contains(resultStr, "x-anthropic-billing-header") {
|
||||
t.Errorf("should contain billing header, got %s", resultStr)
|
||||
}
|
||||
if !strings.Contains(resultStr, "existing prompt") {
|
||||
t.Errorf("should preserve existing prompt, got %s", resultStr)
|
||||
}
|
||||
|
||||
// Billing block should be FIRST (prepended)
|
||||
billingIdx := strings.Index(resultStr, "x-anthropic-billing-header")
|
||||
existingIdx := strings.Index(resultStr, "existing prompt")
|
||||
if billingIdx > existingIdx {
|
||||
t.Errorf("billing block should come before existing prompt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectBillingHeader_ExistingSystemString(t *testing.T) {
|
||||
body := []byte(`{"system":"You are a helpful assistant","messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := injectBillingHeader(body, "1.0.0")
|
||||
resultStr := string(result)
|
||||
|
||||
// Should convert to array with billing block first, then original text
|
||||
if !strings.Contains(resultStr, "x-anthropic-billing-header") {
|
||||
t.Errorf("should contain billing header, got %s", resultStr)
|
||||
}
|
||||
if !strings.Contains(resultStr, "You are a helpful assistant") {
|
||||
t.Errorf("should preserve original system string, got %s", resultStr)
|
||||
}
|
||||
|
||||
// Billing should come first
|
||||
billingIdx := strings.Index(resultStr, "x-anthropic-billing-header")
|
||||
origIdx := strings.Index(resultStr, "You are a helpful assistant")
|
||||
if billingIdx > origIdx {
|
||||
t.Errorf("billing block should come before original system text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectBillingHeader_PreservesOtherFields(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}],"max_tokens":1024}`)
|
||||
result := injectBillingHeader(body, "1.0.0")
|
||||
resultStr := string(result)
|
||||
|
||||
if !strings.Contains(resultStr, `"model":"claude-sonnet-4-6"`) {
|
||||
t.Errorf("should preserve model field, got %s", resultStr)
|
||||
}
|
||||
if !strings.Contains(resultStr, `"max_tokens":1024`) {
|
||||
t.Errorf("should preserve max_tokens field, got %s", resultStr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectBillingHeader_BillingBlockFormat(t *testing.T) {
|
||||
body := []byte(`{"messages":[{"role":"user","content":"test"}]}`)
|
||||
result := injectBillingHeader(body, "2.5.0")
|
||||
resultStr := string(result)
|
||||
|
||||
// Verify the billing block contains the correct version
|
||||
if !strings.Contains(resultStr, "cc_version=2.5.0.") {
|
||||
t.Errorf("billing block should contain cc_version=2.5.0., got %s", resultStr)
|
||||
}
|
||||
if !strings.Contains(resultStr, "cc_entrypoint=cli") {
|
||||
t.Errorf("billing block should contain cc_entrypoint=cli, got %s", resultStr)
|
||||
}
|
||||
if !strings.Contains(resultStr, "cch=00000") {
|
||||
t.Errorf("billing block should contain cch=00000, got %s", resultStr)
|
||||
}
|
||||
}
|
||||
+67
-166
@@ -2,7 +2,6 @@ package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
@@ -10,25 +9,12 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/auth"
|
||||
"github.com/fujin/anthropic-proxy/internal/logging"
|
||||
"github.com/fujin/anthropic-proxy/internal/ratelimit"
|
||||
"github.com/fujin/anthropic-proxy/internal/telemetry"
|
||||
)
|
||||
|
||||
// requestInfo bundles common request context passed to logging/telemetry helpers.
|
||||
type requestInfo struct {
|
||||
model string
|
||||
stream bool
|
||||
cred *auth.Credential
|
||||
body []byte
|
||||
originalBody []byte
|
||||
}
|
||||
|
||||
func HandleMessages(pool *auth.Pool, profile *SniffedProfile, getSanitizer func() *Sanitizer, tracker *ratelimit.Tracker) gin.HandlerFunc {
|
||||
func HandleMessages(pool *auth.Pool, profile *SniffedProfile, getSanitizer func() *Sanitizer) gin.HandlerFunc {
|
||||
upstream := NewUpstreamClient(profile)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
@@ -60,56 +46,57 @@ func HandleMessages(pool *auth.Pool, profile *SniffedProfile, getSanitizer func(
|
||||
isStream := gjson.GetBytes(body, "stream").Bool()
|
||||
|
||||
if isStream {
|
||||
handleStream(c, upstream, san, pool, cred, body, originalBody, tracker)
|
||||
handleStream(c, upstream, san, pool, cred, body, originalBody)
|
||||
} else {
|
||||
handleNonStream(c, upstream, san, pool, cred, body, originalBody, tracker)
|
||||
handleNonStream(c, upstream, san, pool, cred, body, originalBody)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleNonStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte, originalBody []byte, tracker *ratelimit.Tracker) {
|
||||
func handleNonStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte, originalBody []byte) {
|
||||
startTime := time.Now()
|
||||
model := gjson.GetBytes(body, "model").String()
|
||||
ctx := c.Request.Context()
|
||||
ri := requestInfo{model: model, stream: false, cred: cred, body: body, originalBody: originalBody}
|
||||
|
||||
telemetry.RequestBodySize.Record(ctx, int64(len(body)),
|
||||
metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", false)))
|
||||
|
||||
respBody, headers, statusCode, err := upstream.Execute(ctx, cred, body)
|
||||
latencyMs := float64(time.Since(startTime).Milliseconds())
|
||||
|
||||
respBody, headers, statusCode, err := upstream.Execute(c.Request.Context(), cred, body)
|
||||
if err != nil {
|
||||
recordConnectionError(ctx, err, ri, latencyMs)
|
||||
latencyMs := float64(time.Since(startTime).Milliseconds())
|
||||
model := gjson.GetBytes(body, "model").String()
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("credential", cred.Email).
|
||||
Str("model", model).
|
||||
Bool("stream", false).
|
||||
Str("request_body_original", string(originalBody)).
|
||||
Str("request_body_sanitized", string(body)).
|
||||
Int("request_body_size", len(body)).
|
||||
Float64("latency_ms", latencyMs).
|
||||
Msg("upstream connection error")
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "upstream request failed"})
|
||||
return
|
||||
}
|
||||
|
||||
recordRequestMetrics(ctx, ri, statusCode, latencyMs)
|
||||
|
||||
if statusCode >= 400 {
|
||||
pool.MarkFailure(cred, statusCode)
|
||||
telemetry.CredentialCooldowns.Add(ctx, 1,
|
||||
metric.WithAttributes(attribute.Int("status_code", statusCode)))
|
||||
recordUpstreamError(ctx, statusCode, respBody, headers.Get("X-Request-Id"), latencyMs, ri, c.Request.Header)
|
||||
latencyMs := float64(time.Since(startTime).Milliseconds())
|
||||
model := gjson.GetBytes(body, "model").String()
|
||||
errorType := gjson.GetBytes(respBody, "error.type").String()
|
||||
errorMessage := gjson.GetBytes(respBody, "error.message").String()
|
||||
log.Error().
|
||||
Int("status", statusCode).
|
||||
Str("error_type", errorType).
|
||||
Str("error_message", errorMessage).
|
||||
Str("response_body", string(respBody)).
|
||||
Str("request_id", headers.Get("X-Request-Id")).
|
||||
Float64("latency_ms", latencyMs).
|
||||
Str("credential", cred.Email).
|
||||
Str("model", model).
|
||||
Bool("stream", false).
|
||||
Str("request_body_original", string(originalBody)).
|
||||
Str("request_body_sanitized", string(body)).
|
||||
Int("request_body_size", len(body)).
|
||||
Str("request_headers", logging.RedactHeaders(c.Request.Header)).
|
||||
Msg("upstream error")
|
||||
} else {
|
||||
pool.MarkSuccess(cred)
|
||||
respBody = san.DesanitizeResponse(respBody)
|
||||
|
||||
inputTokens := gjson.GetBytes(respBody, "usage.input_tokens").Int()
|
||||
outputTokens := gjson.GetBytes(respBody, "usage.output_tokens").Int()
|
||||
recordTokenUsage(ctx, model, cred, inputTokens, outputTokens)
|
||||
if tracker != nil {
|
||||
tracker.UpdateFromHeaders(headers)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Int("status", statusCode).
|
||||
Float64("latency_ms", latencyMs).
|
||||
Str("model", model).
|
||||
Int64("input_tokens", inputTokens).
|
||||
Int64("output_tokens", outputTokens).
|
||||
Msg("request completed")
|
||||
}
|
||||
|
||||
for _, h := range []string{"Content-Type", "X-Request-Id"} {
|
||||
@@ -121,20 +108,22 @@ func handleNonStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, p
|
||||
c.Data(statusCode, headers.Get("Content-Type"), respBody)
|
||||
}
|
||||
|
||||
func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte, originalBody []byte, tracker *ratelimit.Tracker) {
|
||||
func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool *auth.Pool, cred *auth.Credential, body []byte, originalBody []byte) {
|
||||
startTime := time.Now()
|
||||
model := gjson.GetBytes(body, "model").String()
|
||||
ctx := c.Request.Context()
|
||||
ri := requestInfo{model: model, stream: true, cred: cred, body: body, originalBody: originalBody}
|
||||
|
||||
telemetry.StreamRequests.Add(ctx, 1, metric.WithAttributes(attribute.String("model", model)))
|
||||
telemetry.RequestBodySize.Record(ctx, int64(len(body)),
|
||||
metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", true)))
|
||||
|
||||
resp, err := upstream.ExecuteStream(ctx, cred, body)
|
||||
resp, err := upstream.ExecuteStream(c.Request.Context(), cred, body)
|
||||
if err != nil {
|
||||
latencyMs := float64(time.Since(startTime).Milliseconds())
|
||||
recordConnectionError(ctx, err, ri, latencyMs)
|
||||
model := gjson.GetBytes(body, "model").String()
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("credential", cred.Email).
|
||||
Str("model", model).
|
||||
Bool("stream", true).
|
||||
Str("request_body_original", string(originalBody)).
|
||||
Str("request_body_sanitized", string(body)).
|
||||
Int("request_body_size", len(body)).
|
||||
Float64("latency_ms", latencyMs).
|
||||
Msg("upstream connection error")
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "upstream stream request failed"})
|
||||
return
|
||||
}
|
||||
@@ -142,13 +131,26 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
pool.MarkFailure(cred, resp.StatusCode)
|
||||
telemetry.CredentialCooldowns.Add(ctx, 1,
|
||||
metric.WithAttributes(attribute.Int("status_code", resp.StatusCode)))
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
latencyMs := float64(time.Since(startTime).Milliseconds())
|
||||
recordRequestMetrics(ctx, ri, resp.StatusCode, latencyMs)
|
||||
recordUpstreamError(ctx, resp.StatusCode, respBody, resp.Header.Get("X-Request-Id"), latencyMs, ri, c.Request.Header)
|
||||
|
||||
model := gjson.GetBytes(body, "model").String()
|
||||
errorType := gjson.GetBytes(respBody, "error.type").String()
|
||||
errorMessage := gjson.GetBytes(respBody, "error.message").String()
|
||||
log.Error().
|
||||
Int("status", resp.StatusCode).
|
||||
Str("error_type", errorType).
|
||||
Str("error_message", errorMessage).
|
||||
Str("response_body", string(respBody)).
|
||||
Str("request_id", resp.Header.Get("X-Request-Id")).
|
||||
Float64("latency_ms", latencyMs).
|
||||
Str("credential", cred.Email).
|
||||
Str("model", model).
|
||||
Bool("stream", true).
|
||||
Str("request_body_original", string(originalBody)).
|
||||
Str("request_body_sanitized", string(body)).
|
||||
Int("request_body_size", len(body)).
|
||||
Str("request_headers", logging.RedactHeaders(c.Request.Header)).
|
||||
Msg("upstream error")
|
||||
c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), respBody)
|
||||
return
|
||||
}
|
||||
@@ -167,116 +169,15 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool
|
||||
return
|
||||
}
|
||||
|
||||
var inputTokens, outputTokens int64
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := san.DesanitizeStreamEvent(scanner.Text())
|
||||
c.Writer.WriteString(line + "\n")
|
||||
flusher.Flush()
|
||||
|
||||
if len(line) > 5 && line[:5] == "data:" {
|
||||
data := line[5:]
|
||||
eventType := gjson.Get(data, "type").String()
|
||||
switch eventType {
|
||||
case "message_start":
|
||||
inputTokens = gjson.Get(data, "message.usage.input_tokens").Int()
|
||||
case "message_delta":
|
||||
outputTokens = gjson.Get(data, "usage.output_tokens").Int()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
latencyMs := float64(time.Since(startTime).Milliseconds())
|
||||
recordRequestMetrics(ctx, ri, http.StatusOK, latencyMs)
|
||||
|
||||
if inputTokens > 0 || outputTokens > 0 {
|
||||
recordTokenUsage(ctx, model, cred, inputTokens, outputTokens)
|
||||
if tracker != nil {
|
||||
tracker.UpdateFromHeaders(resp.Header)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Float64("latency_ms", latencyMs).
|
||||
Str("model", model).
|
||||
Bool("stream", true).
|
||||
Int64("input_tokens", inputTokens).
|
||||
Int64("output_tokens", outputTokens).
|
||||
Msg("stream completed")
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.Error().Err(err).Msg("stream scan error")
|
||||
}
|
||||
}
|
||||
|
||||
// recordConnectionError logs and records metrics for upstream connection failures.
|
||||
func recordConnectionError(ctx context.Context, err error, ri requestInfo, latencyMs float64) {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("credential", ri.cred.Email).
|
||||
Str("model", ri.model).
|
||||
Bool("stream", ri.stream).
|
||||
Str("request_body_original", string(ri.originalBody)).
|
||||
Str("request_body_sanitized", string(ri.body)).
|
||||
Int("request_body_size", len(ri.body)).
|
||||
Float64("latency_ms", latencyMs).
|
||||
Msg("upstream connection error")
|
||||
|
||||
telemetry.UpstreamErrors.Add(ctx, 1,
|
||||
metric.WithAttributes(
|
||||
attribute.String("error_type", "connection"),
|
||||
attribute.String("credential", ri.cred.Email),
|
||||
attribute.Int("status_code", http.StatusBadGateway),
|
||||
))
|
||||
recordRequestMetrics(ctx, ri, http.StatusBadGateway, latencyMs)
|
||||
}
|
||||
|
||||
// recordUpstreamError logs and records metrics for upstream HTTP error responses.
|
||||
func recordUpstreamError(ctx context.Context, statusCode int, respBody []byte, requestID string, latencyMs float64, ri requestInfo, requestHeaders http.Header) {
|
||||
errorType := gjson.GetBytes(respBody, "error.type").String()
|
||||
errorMessage := gjson.GetBytes(respBody, "error.message").String()
|
||||
log.Error().
|
||||
Int("status", statusCode).
|
||||
Str("error_type", errorType).
|
||||
Str("error_message", errorMessage).
|
||||
Str("response_body", string(respBody)).
|
||||
Str("request_id", requestID).
|
||||
Float64("latency_ms", latencyMs).
|
||||
Str("credential", ri.cred.Email).
|
||||
Str("model", ri.model).
|
||||
Bool("stream", ri.stream).
|
||||
Str("request_body_original", string(ri.originalBody)).
|
||||
Str("request_body_sanitized", string(ri.body)).
|
||||
Int("request_body_size", len(ri.body)).
|
||||
Str("request_headers", logging.RedactHeaders(requestHeaders)).
|
||||
Msg("upstream error")
|
||||
|
||||
telemetry.UpstreamErrors.Add(ctx, 1,
|
||||
metric.WithAttributes(
|
||||
attribute.Int("status_code", statusCode),
|
||||
attribute.String("error_type", errorType),
|
||||
attribute.String("credential", ri.cred.Email),
|
||||
))
|
||||
}
|
||||
|
||||
// recordRequestMetrics records the request counter and duration histogram.
|
||||
func recordRequestMetrics(ctx context.Context, ri requestInfo, statusCode int, latencyMs float64) {
|
||||
attrs := []attribute.KeyValue{
|
||||
attribute.String("model", ri.model),
|
||||
attribute.Bool("stream", ri.stream),
|
||||
attribute.Int("status_code", statusCode),
|
||||
}
|
||||
telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...))
|
||||
telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...))
|
||||
}
|
||||
|
||||
// recordTokenUsage records token consumption metrics.
|
||||
func recordTokenUsage(ctx context.Context, model string, cred *auth.Credential, inputTokens, outputTokens int64) {
|
||||
tokenAttrs := metric.WithAttributes(
|
||||
attribute.String("model", model),
|
||||
attribute.String("credential", cred.Email),
|
||||
)
|
||||
telemetry.TokensInput.Add(ctx, inputTokens, tokenAttrs)
|
||||
telemetry.TokensOutput.Add(ctx, outputTokens, tokenAttrs)
|
||||
}
|
||||
|
||||
@@ -1,624 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/auth"
|
||||
"github.com/fujin/anthropic-proxy/internal/config"
|
||||
"github.com/fujin/anthropic-proxy/internal/ratelimit"
|
||||
"github.com/fujin/anthropic-proxy/internal/telemetry"
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
// Initialize telemetry with noop meter to avoid nil pointer panics.
|
||||
meter := noop.Meter{}
|
||||
telemetry.InitMetrics(meter, nil)
|
||||
}
|
||||
|
||||
// --- Request body reading and sanitization ---
|
||||
|
||||
func TestHandleMessages_ReadBodyError(t *testing.T) {
|
||||
// A body that immediately fails on read shouldn't panic.
|
||||
pool := auth.NewPool([]*auth.Credential{{ID: "c1", AccessToken: "tok", Email: "test@test.com"}})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", &errReader{})
|
||||
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusBadRequest)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "failed to read request body") {
|
||||
t.Errorf("body = %q, expected error message about reading body", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleMessages_SanitizesRequestBody(t *testing.T) {
|
||||
// We can't directly make HandleMessages use our mock server because
|
||||
// UpstreamClient hardcodes messagesURL. Instead, we test sanitization
|
||||
// by verifying the sanitizer is called on the body before any pool interaction.
|
||||
san := NewSanitizer(config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "my_tool", To: "renamed_tool"}},
|
||||
Body: []config.ReplaceRule{{Match: "secret", Replace: "redacted"}},
|
||||
})
|
||||
|
||||
// Create body with tool name and secret
|
||||
body := `{"model":"claude-sonnet-4-6","tools":[{"name":"my_tool"}],"messages":[{"role":"user","content":"secret data"}]}`
|
||||
sanitizedBody := san.SanitizeRequest([]byte(body))
|
||||
|
||||
// Verify sanitization happened correctly
|
||||
if !strings.Contains(string(sanitizedBody), "renamed_tool") {
|
||||
t.Error("expected tool to be renamed in sanitized body")
|
||||
}
|
||||
if strings.Contains(string(sanitizedBody), "my_tool") {
|
||||
t.Error("original tool name should be gone after sanitization")
|
||||
}
|
||||
if !strings.Contains(string(sanitizedBody), "redacted") {
|
||||
t.Error("expected 'secret' to be replaced with 'redacted'")
|
||||
}
|
||||
if strings.Contains(string(sanitizedBody), "secret") {
|
||||
t.Error("'secret' should be gone after sanitization")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleMessages_PoolPickError(t *testing.T) {
|
||||
// Empty pool — Pick() will fail.
|
||||
pool := auth.NewPool(nil)
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
|
||||
|
||||
body := `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "no credentials available") {
|
||||
t.Errorf("body = %q, expected pool error", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleMessages_PoolAllOnCooldown(t *testing.T) {
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "e"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
pool.MarkFailure(cred, 429) // puts on 30s cooldown
|
||||
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
|
||||
|
||||
body := `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "cooldown") {
|
||||
t.Errorf("body = %q, expected cooldown message", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// --- Stream vs non-stream routing ---
|
||||
|
||||
func TestHandleMessages_StreamField_Detection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
isStream bool
|
||||
}{
|
||||
{
|
||||
name: "stream true",
|
||||
body: `{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`,
|
||||
isStream: true,
|
||||
},
|
||||
{
|
||||
name: "stream false",
|
||||
body: `{"model":"claude-sonnet-4-6","stream":false,"messages":[{"role":"user","content":"hi"}]}`,
|
||||
isStream: false,
|
||||
},
|
||||
{
|
||||
name: "no stream field defaults to false",
|
||||
body: `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`,
|
||||
isStream: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := gjson.Get(tt.body, "stream").Bool()
|
||||
if got != tt.isStream {
|
||||
t.Errorf("stream = %v, want %v", got, tt.isStream)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Desanitization on response ---
|
||||
|
||||
func TestDesanitization_NonStreamResponse(t *testing.T) {
|
||||
san := NewSanitizer(config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}},
|
||||
})
|
||||
|
||||
// Simulate upstream response with renamed tool
|
||||
upstreamResponse := `{"content":[{"type":"tool_use","name":"renamed_tool","id":"t1","input":{}}]}`
|
||||
desanitized := san.DesanitizeResponse([]byte(upstreamResponse))
|
||||
|
||||
if !strings.Contains(string(desanitized), "original_tool") {
|
||||
t.Errorf("expected tool name to be desanitized back to 'original_tool', got %s", string(desanitized))
|
||||
}
|
||||
if strings.Contains(string(desanitized), `"name":"renamed_tool"`) {
|
||||
t.Error("renamed_tool should have been replaced by original_tool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDesanitization_StreamEvent(t *testing.T) {
|
||||
san := NewSanitizer(config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}},
|
||||
})
|
||||
|
||||
event := `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"renamed_tool","id":"t1"}}`
|
||||
desanitized := san.DesanitizeStreamEvent(event)
|
||||
|
||||
if !strings.Contains(desanitized, "original_tool") {
|
||||
t.Errorf("expected stream event to be desanitized, got %s", desanitized)
|
||||
}
|
||||
}
|
||||
|
||||
// --- handleNonStream behavior tests via direct function ---
|
||||
|
||||
func TestHandleNonStream_ConnectionError(t *testing.T) {
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
tracker := ratelimit.NewTracker(func() string { return "" })
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: http.Client{Transport: &failingTransport{}},
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleNonStream(c, uc, san, pool, cred, body, body, tracker)
|
||||
|
||||
if w.Code != http.StatusBadGateway {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusBadGateway)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "upstream request failed") {
|
||||
t.Errorf("body = %q, expected upstream error message", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStream_UpstreamSuccess(t *testing.T) {
|
||||
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("X-Request-Id", "req-123")
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"text","text":"hello"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`))
|
||||
}))
|
||||
defer mockUpstream.Close()
|
||||
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
tracker := ratelimit.NewTracker(func() string { return "" })
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: *mockUpstream.Client(),
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
// Override the messagesURL by constructing a custom Execute that uses the mock.
|
||||
// Since we can't override the const, we test via a mock server approach:
|
||||
// We create a custom http.Client with a transport that redirects to our mock.
|
||||
uc.client.Transport = &rewriteTransport{
|
||||
base: mockUpstream.Client().Transport,
|
||||
destURL: mockUpstream.URL,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleNonStream(c, uc, san, pool, cred, body, body, tracker)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String())
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "hello") {
|
||||
t.Errorf("response body missing expected content: %s", w.Body.String())
|
||||
}
|
||||
if got := w.Header().Get("Content-Type"); got != "application/json" {
|
||||
t.Errorf("Content-Type = %q, want %q", got, "application/json")
|
||||
}
|
||||
if got := w.Header().Get("X-Request-Id"); got != "req-123" {
|
||||
t.Errorf("X-Request-Id = %q, want %q", got, "req-123")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStream_UpstreamError_MarkFailure(t *testing.T) {
|
||||
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(429)
|
||||
w.Write([]byte(`{"error":{"type":"rate_limit_error","message":"too many requests"}}`))
|
||||
}))
|
||||
defer mockUpstream.Close()
|
||||
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: *mockUpstream.Client(),
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
uc.client.Transport = &rewriteTransport{
|
||||
base: mockUpstream.Client().Transport,
|
||||
destURL: mockUpstream.URL,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleNonStream(c, uc, san, pool, cred, body, body, nil)
|
||||
|
||||
if w.Code != 429 {
|
||||
t.Errorf("status = %d, want 429", w.Code)
|
||||
}
|
||||
|
||||
// Verify MarkFailure was called — cred should now be on cooldown
|
||||
if !cred.IsOnCooldown() {
|
||||
t.Error("expected credential to be on cooldown after 429")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStream_UpstreamSuccess_DesanitizesResponse(t *testing.T) {
|
||||
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"tool_use","name":"renamed_tool","id":"t1","input":{}}],"model":"claude-sonnet-4-6","stop_reason":"tool_use","usage":{"input_tokens":10,"output_tokens":5}}`))
|
||||
}))
|
||||
defer mockUpstream.Close()
|
||||
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}},
|
||||
})
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: *mockUpstream.Client(),
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
uc.client.Transport = &rewriteTransport{
|
||||
base: mockUpstream.Client().Transport,
|
||||
destURL: mockUpstream.URL,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleNonStream(c, uc, san, pool, cred, body, body, nil)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String())
|
||||
}
|
||||
// Should be desanitized back to original_tool
|
||||
if !strings.Contains(w.Body.String(), "original_tool") {
|
||||
t.Errorf("response should contain desanitized tool name 'original_tool', got %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStream_Upstream500_MarkFailure(t *testing.T) {
|
||||
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(500)
|
||||
w.Write([]byte(`{"error":{"type":"server_error","message":"internal error"}}`))
|
||||
}))
|
||||
defer mockUpstream.Close()
|
||||
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: *mockUpstream.Client(),
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
uc.client.Transport = &rewriteTransport{
|
||||
base: mockUpstream.Client().Transport,
|
||||
destURL: mockUpstream.URL,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleNonStream(c, uc, san, pool, cred, body, body, nil)
|
||||
|
||||
if w.Code != 500 {
|
||||
t.Errorf("status = %d, want 500", w.Code)
|
||||
}
|
||||
if !cred.IsOnCooldown() {
|
||||
t.Error("expected credential to be on cooldown after 500")
|
||||
}
|
||||
}
|
||||
|
||||
// --- handleStream behavior tests ---
|
||||
|
||||
func TestHandleStream_ConnectionError(t *testing.T) {
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: http.Client{Transport: &failingTransport{}},
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleStream(c, uc, san, pool, cred, body, body, nil)
|
||||
|
||||
if w.Code != http.StatusBadGateway {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusBadGateway)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "upstream stream request failed") {
|
||||
t.Errorf("body = %q, expected upstream stream error", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStream_UpstreamError(t *testing.T) {
|
||||
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(429)
|
||||
w.Write([]byte(`{"error":{"type":"rate_limit_error","message":"rate limited"}}`))
|
||||
}))
|
||||
defer mockUpstream.Close()
|
||||
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: *mockUpstream.Client(),
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
uc.client.Transport = &rewriteTransport{
|
||||
base: mockUpstream.Client().Transport,
|
||||
destURL: mockUpstream.URL,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleStream(c, uc, san, pool, cred, body, body, nil)
|
||||
|
||||
if w.Code != 429 {
|
||||
t.Errorf("status = %d, want 429", w.Code)
|
||||
}
|
||||
if !cred.IsOnCooldown() {
|
||||
t.Error("expected credential on cooldown after stream 429")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStream_SuccessForwardsEvents(t *testing.T) {
|
||||
events := "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":0}}}\n\nevent: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello\"}}\n\nevent: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\nevent: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":5}}\n\nevent: message_stop\ndata: {\"type\":\"message_stop\"}\n\n"
|
||||
|
||||
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(events))
|
||||
}))
|
||||
defer mockUpstream.Close()
|
||||
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: *mockUpstream.Client(),
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
uc.client.Transport = &rewriteTransport{
|
||||
base: mockUpstream.Client().Transport,
|
||||
destURL: mockUpstream.URL,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleStream(c, uc, san, pool, cred, body, body, nil)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
respBody := w.Body.String()
|
||||
if !strings.Contains(respBody, "message_start") {
|
||||
t.Error("response missing message_start event")
|
||||
}
|
||||
if !strings.Contains(respBody, "hello") {
|
||||
t.Error("response missing text content 'hello'")
|
||||
}
|
||||
if !strings.Contains(respBody, "message_stop") {
|
||||
t.Error("response missing message_stop event")
|
||||
}
|
||||
|
||||
// Verify SSE headers
|
||||
if got := w.Header().Get("Content-Type"); got != "text/event-stream" {
|
||||
t.Errorf("Content-Type = %q, want %q", got, "text/event-stream")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStream_DesanitizesEvents(t *testing.T) {
|
||||
events := "event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"name\":\"renamed_tool\",\"id\":\"t1\"}}\n\nevent: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"},\"usage\":{\"output_tokens\":5}}\n\nevent: message_stop\ndata: {\"type\":\"message_stop\"}\n\n"
|
||||
|
||||
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(events))
|
||||
}))
|
||||
defer mockUpstream.Close()
|
||||
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}},
|
||||
})
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: *mockUpstream.Client(),
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
uc.client.Transport = &rewriteTransport{
|
||||
base: mockUpstream.Client().Transport,
|
||||
destURL: mockUpstream.URL,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleStream(c, uc, san, pool, cred, body, body, nil)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String())
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "original_tool") {
|
||||
t.Errorf("stream response should contain desanitized 'original_tool', got %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// --- HandleMessages full integration wiring test ---
|
||||
|
||||
func TestHandleMessages_WiresHandlerCorrectly(t *testing.T) {
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
// Verify the handler can be created without panic
|
||||
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
|
||||
if handler == nil {
|
||||
t.Fatal("HandleMessages returned nil handler")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleMessages_EmptyBody(t *testing.T) {
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
|
||||
|
||||
// Empty body — handler should still try to pick cred and call upstream
|
||||
// (which will fail with connection error to api.anthropic.com, not a panic)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(""))
|
||||
|
||||
handler(c)
|
||||
|
||||
// Should get a 502 because the upstream URL (api.anthropic.com) won't be reachable
|
||||
// in test environment, or it might complete. The key thing is no panic.
|
||||
// We mainly verify it doesn't panic.
|
||||
if w.Code == 0 {
|
||||
t.Error("expected non-zero status code")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Test helpers ---
|
||||
|
||||
// errReader is an io.Reader that always returns an error.
|
||||
type errReader struct{}
|
||||
|
||||
func (e *errReader) Read([]byte) (int, error) {
|
||||
return 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
// failingTransport is an http.RoundTripper that always returns an error.
|
||||
type failingTransport struct{}
|
||||
|
||||
func (f *failingTransport) RoundTrip(*http.Request) (*http.Response, error) {
|
||||
return nil, fmt.Errorf("connection refused")
|
||||
}
|
||||
|
||||
// rewriteTransport intercepts HTTP requests and rewrites the URL to point
|
||||
// at a local test server. This allows testing with UpstreamClient's hardcoded
|
||||
// messagesURL by redirecting all requests to a mock server.
|
||||
type rewriteTransport struct {
|
||||
base http.RoundTripper
|
||||
destURL string
|
||||
}
|
||||
|
||||
func (t *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Rewrite the request URL to point at our mock server
|
||||
newReq := req.Clone(req.Context())
|
||||
newReq.URL.Scheme = "http"
|
||||
newReq.URL.Host = strings.TrimPrefix(t.destURL, "http://")
|
||||
newReq.URL.Path = "/v1/messages"
|
||||
newReq.URL.RawQuery = ""
|
||||
if t.base == nil {
|
||||
return http.DefaultTransport.RoundTrip(newReq)
|
||||
}
|
||||
return t.base.RoundTrip(newReq)
|
||||
}
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
@@ -12,10 +11,10 @@ import (
|
||||
)
|
||||
|
||||
type Sanitizer struct {
|
||||
toolsForward map[string]string
|
||||
toolsReverse map[string]string
|
||||
systemRules []config.ReplaceRule
|
||||
bodyRules []config.ReplaceRule
|
||||
toolsForward map[string]string
|
||||
toolsReverse map[string]string
|
||||
systemRules []config.ReplaceRule
|
||||
bodyRules []config.ReplaceRule
|
||||
}
|
||||
|
||||
func NewSanitizer(cfg config.SanitizeConfig) *Sanitizer {
|
||||
@@ -50,11 +49,7 @@ func (s *Sanitizer) DesanitizeResponse(body []byte) []byte {
|
||||
}
|
||||
name := block.Get("name").String()
|
||||
if orig, ok := s.toolsReverse[name]; ok {
|
||||
if b, err := sjson.SetBytes(body, "content."+strconv.Itoa(i)+".name", orig); err != nil {
|
||||
log.Warn().Err(err).Str("tool", name).Msg("desanitize response: set name failed")
|
||||
} else {
|
||||
body = b
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "content."+strconv.Itoa(i)+".name", orig)
|
||||
}
|
||||
}
|
||||
return body
|
||||
@@ -69,12 +64,8 @@ func (s *Sanitizer) DesanitizeStreamEvent(line string) string {
|
||||
for _, path := range []string{"content_block.name", "delta.name"} {
|
||||
name := gjson.GetBytes(data, path).String()
|
||||
if orig, ok := s.toolsReverse[name]; ok {
|
||||
if b, err := sjson.SetBytes(data, path, orig); err != nil {
|
||||
log.Warn().Err(err).Str("tool", name).Msg("desanitize stream event: set name failed")
|
||||
} else {
|
||||
data = b
|
||||
changed = true
|
||||
}
|
||||
data, _ = sjson.SetBytes(data, path, orig)
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if changed {
|
||||
@@ -94,11 +85,7 @@ func (s *Sanitizer) renameTools(body []byte) []byte {
|
||||
for i, tool := range tools.Array() {
|
||||
name := tool.Get("name").String()
|
||||
if newName, ok := s.toolsForward[name]; ok {
|
||||
if b, err := sjson.SetBytes(body, "tools."+strconv.Itoa(i)+".name", newName); err != nil {
|
||||
log.Warn().Err(err).Str("tool", name).Msg("rename tool failed")
|
||||
} else {
|
||||
body = b
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "tools."+strconv.Itoa(i)+".name", newName)
|
||||
}
|
||||
}
|
||||
return body
|
||||
@@ -117,11 +104,7 @@ func (s *Sanitizer) replaceSystem(body []byte) []byte {
|
||||
for _, rule := range s.systemRules {
|
||||
text = strings.ReplaceAll(text, rule.Match, rule.Replace)
|
||||
}
|
||||
if b, err := sjson.SetBytes(body, "system."+strconv.Itoa(i)+".text", text); err != nil {
|
||||
log.Warn().Err(err).Int("block", i).Msg("replace system text failed")
|
||||
} else {
|
||||
body = b
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "system."+strconv.Itoa(i)+".text", text)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
@@ -1,476 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/config"
|
||||
)
|
||||
|
||||
func TestNewSanitizer_Empty(t *testing.T) {
|
||||
s := NewSanitizer(config.SanitizeConfig{})
|
||||
if len(s.toolsForward) != 0 {
|
||||
t.Errorf("expected empty toolsForward, got %d entries", len(s.toolsForward))
|
||||
}
|
||||
if len(s.toolsReverse) != 0 {
|
||||
t.Errorf("expected empty toolsReverse, got %d entries", len(s.toolsReverse))
|
||||
}
|
||||
if s.systemRules != nil {
|
||||
t.Errorf("expected nil systemRules")
|
||||
}
|
||||
if s.bodyRules != nil {
|
||||
t.Errorf("expected nil bodyRules")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSanitizer_WithTools(t *testing.T) {
|
||||
cfg := config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{
|
||||
{From: "old_tool", To: "new_tool"},
|
||||
{From: "another", To: "replaced"},
|
||||
},
|
||||
}
|
||||
s := NewSanitizer(cfg)
|
||||
if got := s.toolsForward["old_tool"]; got != "new_tool" {
|
||||
t.Errorf("toolsForward[old_tool] = %q, want %q", got, "new_tool")
|
||||
}
|
||||
if got := s.toolsReverse["new_tool"]; got != "old_tool" {
|
||||
t.Errorf("toolsReverse[new_tool] = %q, want %q", got, "old_tool")
|
||||
}
|
||||
if got := s.toolsForward["another"]; got != "replaced" {
|
||||
t.Errorf("toolsForward[another] = %q, want %q", got, "replaced")
|
||||
}
|
||||
if got := s.toolsReverse["replaced"]; got != "another" {
|
||||
t.Errorf("toolsReverse[replaced] = %q, want %q", got, "another")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSanitizer_WithSystemAndBodyRules(t *testing.T) {
|
||||
cfg := config.SanitizeConfig{
|
||||
System: []config.ReplaceRule{{Match: "foo", Replace: "bar"}},
|
||||
Body: []config.ReplaceRule{{Match: "baz", Replace: "qux"}},
|
||||
}
|
||||
s := NewSanitizer(cfg)
|
||||
if len(s.systemRules) != 1 || s.systemRules[0].Match != "foo" {
|
||||
t.Errorf("systemRules not set correctly")
|
||||
}
|
||||
if len(s.bodyRules) != 1 || s.bodyRules[0].Match != "baz" {
|
||||
t.Errorf("bodyRules not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenameTools(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
forward map[string]string
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty map returns body unchanged",
|
||||
forward: map[string]string{},
|
||||
body: `{"tools":[{"name":"my_tool"}]}`,
|
||||
expected: `{"tools":[{"name":"my_tool"}]}`,
|
||||
},
|
||||
{
|
||||
name: "no tools array returns body unchanged",
|
||||
forward: map[string]string{"my_tool": "renamed"},
|
||||
body: `{"messages":[]}`,
|
||||
expected: `{"messages":[]}`,
|
||||
},
|
||||
{
|
||||
name: "tools is not array returns body unchanged",
|
||||
forward: map[string]string{"my_tool": "renamed"},
|
||||
body: `{"tools":"not_array"}`,
|
||||
expected: `{"tools":"not_array"}`,
|
||||
},
|
||||
{
|
||||
name: "matching tool gets renamed",
|
||||
forward: map[string]string{"my_tool": "renamed_tool"},
|
||||
body: `{"tools":[{"name":"my_tool","description":"desc"}]}`,
|
||||
expected: `renamed_tool`,
|
||||
},
|
||||
{
|
||||
name: "non-matching tool unchanged",
|
||||
forward: map[string]string{"other_tool": "renamed"},
|
||||
body: `{"tools":[{"name":"my_tool"}]}`,
|
||||
expected: `my_tool`,
|
||||
},
|
||||
{
|
||||
name: "partial match - only exact match renames",
|
||||
forward: map[string]string{"tool_a": "tool_x", "tool_b": "tool_y"},
|
||||
body: `{"tools":[{"name":"tool_a"},{"name":"tool_c"},{"name":"tool_b"}]}`,
|
||||
expected: `tool_x`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: tt.forward,
|
||||
toolsReverse: make(map[string]string),
|
||||
}
|
||||
result := string(s.renameTools([]byte(tt.body)))
|
||||
if !strings.Contains(result, tt.expected) {
|
||||
t.Errorf("result %q does not contain %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenameTools_MultipleTools(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: map[string]string{"tool_a": "tool_x", "tool_b": "tool_y"},
|
||||
toolsReverse: make(map[string]string),
|
||||
}
|
||||
body := `{"tools":[{"name":"tool_a"},{"name":"tool_c"},{"name":"tool_b"}]}`
|
||||
result := string(s.renameTools([]byte(body)))
|
||||
if !strings.Contains(result, `"tool_x"`) {
|
||||
t.Errorf("tool_a should be renamed to tool_x, got %s", result)
|
||||
}
|
||||
if !strings.Contains(result, `"tool_y"`) {
|
||||
t.Errorf("tool_b should be renamed to tool_y, got %s", result)
|
||||
}
|
||||
if !strings.Contains(result, `"tool_c"`) {
|
||||
t.Errorf("tool_c should remain unchanged, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceSystem(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rules []config.ReplaceRule
|
||||
body string
|
||||
contains string
|
||||
}{
|
||||
{
|
||||
name: "empty rules returns body unchanged",
|
||||
rules: nil,
|
||||
body: `{"system":[{"type":"text","text":"hello world"}]}`,
|
||||
contains: "hello world",
|
||||
},
|
||||
{
|
||||
name: "no system field returns body unchanged",
|
||||
rules: []config.ReplaceRule{{Match: "hello", Replace: "goodbye"}},
|
||||
body: `{"messages":[]}`,
|
||||
contains: `"messages":[]`,
|
||||
},
|
||||
{
|
||||
name: "system not array returns body unchanged",
|
||||
rules: []config.ReplaceRule{{Match: "hello", Replace: "goodbye"}},
|
||||
body: `{"system":"just a string"}`,
|
||||
contains: "just a string",
|
||||
},
|
||||
{
|
||||
name: "single block single rule",
|
||||
rules: []config.ReplaceRule{{Match: "hello", Replace: "goodbye"}},
|
||||
body: `{"system":[{"type":"text","text":"hello world"}]}`,
|
||||
contains: "goodbye world",
|
||||
},
|
||||
{
|
||||
name: "multiple blocks",
|
||||
rules: []config.ReplaceRule{{Match: "AAA", Replace: "BBB"}},
|
||||
body: `{"system":[{"type":"text","text":"AAA first"},{"type":"text","text":"AAA second"}]}`,
|
||||
contains: "BBB first",
|
||||
},
|
||||
{
|
||||
name: "multiple rules applied in order",
|
||||
rules: []config.ReplaceRule{{Match: "cat", Replace: "dog"}, {Match: "dog", Replace: "fish"}},
|
||||
body: `{"system":[{"type":"text","text":"I have a cat"}]}`,
|
||||
contains: "I have a fish",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: make(map[string]string),
|
||||
toolsReverse: make(map[string]string),
|
||||
systemRules: tt.rules,
|
||||
}
|
||||
result := string(s.replaceSystem([]byte(tt.body)))
|
||||
if !strings.Contains(result, tt.contains) {
|
||||
t.Errorf("result %q does not contain %q", result, tt.contains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceSystem_MultipleBlocks(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: make(map[string]string),
|
||||
toolsReverse: make(map[string]string),
|
||||
systemRules: []config.ReplaceRule{{Match: "AAA", Replace: "BBB"}},
|
||||
}
|
||||
body := `{"system":[{"type":"text","text":"AAA first"},{"type":"text","text":"AAA second"}]}`
|
||||
result := string(s.replaceSystem([]byte(body)))
|
||||
if !strings.Contains(result, "BBB first") {
|
||||
t.Errorf("first block not replaced: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "BBB second") {
|
||||
t.Errorf("second block not replaced: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rules []config.ReplaceRule
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty rules returns body unchanged",
|
||||
rules: nil,
|
||||
body: `{"foo":"bar"}`,
|
||||
expected: `{"foo":"bar"}`,
|
||||
},
|
||||
{
|
||||
name: "single replacement across entire body",
|
||||
rules: []config.ReplaceRule{{Match: "SECRET", Replace: "REDACTED"}},
|
||||
body: `{"data":"SECRET value SECRET"}`,
|
||||
expected: `{"data":"REDACTED value REDACTED"}`,
|
||||
},
|
||||
{
|
||||
name: "multiple rules applied sequentially",
|
||||
rules: []config.ReplaceRule{{Match: "AAA", Replace: "BBB"}, {Match: "BBB", Replace: "CCC"}},
|
||||
body: `{"text":"AAA"}`,
|
||||
expected: `{"text":"CCC"}`,
|
||||
},
|
||||
{
|
||||
name: "no match leaves body unchanged",
|
||||
rules: []config.ReplaceRule{{Match: "NOMATCH", Replace: "X"}},
|
||||
body: `{"text":"hello"}`,
|
||||
expected: `{"text":"hello"}`,
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
rules: []config.ReplaceRule{{Match: "a", Replace: "b"}},
|
||||
body: ``,
|
||||
expected: ``,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: make(map[string]string),
|
||||
toolsReverse: make(map[string]string),
|
||||
bodyRules: tt.rules,
|
||||
}
|
||||
result := string(s.replaceBody([]byte(tt.body)))
|
||||
if result != tt.expected {
|
||||
t.Errorf("got %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeRequest(t *testing.T) {
|
||||
cfg := config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "my_tool", To: "renamed_tool"}},
|
||||
System: []config.ReplaceRule{{Match: "INTERNAL", Replace: "PUBLIC"}},
|
||||
Body: []config.ReplaceRule{{Match: "secret_val", Replace: "safe_val"}},
|
||||
}
|
||||
s := NewSanitizer(cfg)
|
||||
|
||||
body := `{"tools":[{"name":"my_tool"}],"system":[{"type":"text","text":"INTERNAL info"}],"data":"secret_val here"}`
|
||||
result := string(s.SanitizeRequest([]byte(body)))
|
||||
|
||||
if !strings.Contains(result, `"renamed_tool"`) {
|
||||
t.Errorf("tool not renamed in result: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "PUBLIC info") {
|
||||
t.Errorf("system not replaced in result: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "safe_val here") {
|
||||
t.Errorf("body not replaced in result: %s", result)
|
||||
}
|
||||
if strings.Contains(result, "secret_val") {
|
||||
t.Errorf("secret_val should have been replaced: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeRequest_EmptyConfig(t *testing.T) {
|
||||
s := NewSanitizer(config.SanitizeConfig{})
|
||||
body := `{"tools":[{"name":"my_tool"}],"system":[{"type":"text","text":"hello"}]}`
|
||||
result := string(s.SanitizeRequest([]byte(body)))
|
||||
if result != body {
|
||||
t.Errorf("empty config should not modify body.\ngot: %s\nwant: %s", result, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDesanitizeResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reverse map[string]string
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no content field returns unchanged",
|
||||
reverse: map[string]string{"renamed": "original"},
|
||||
body: `{"id":"msg_1","role":"assistant"}`,
|
||||
expected: `{"id":"msg_1","role":"assistant"}`,
|
||||
},
|
||||
{
|
||||
name: "content not array returns unchanged",
|
||||
reverse: map[string]string{"renamed": "original"},
|
||||
body: `{"content":"just text"}`,
|
||||
expected: `{"content":"just text"}`,
|
||||
},
|
||||
{
|
||||
name: "non-tool_use block left unchanged",
|
||||
reverse: map[string]string{"renamed": "original"},
|
||||
body: `{"content":[{"type":"text","text":"hello"}]}`,
|
||||
expected: `{"content":[{"type":"text","text":"hello"}]}`,
|
||||
},
|
||||
{
|
||||
name: "tool_use block with matching name gets reversed",
|
||||
reverse: map[string]string{"renamed_tool": "original_tool"},
|
||||
body: `{"content":[{"type":"tool_use","name":"renamed_tool","id":"t1"}]}`,
|
||||
expected: `original_tool`,
|
||||
},
|
||||
{
|
||||
name: "tool_use block with no match unchanged",
|
||||
reverse: map[string]string{"other": "something"},
|
||||
body: `{"content":[{"type":"tool_use","name":"my_tool","id":"t1"}]}`,
|
||||
expected: `my_tool`,
|
||||
},
|
||||
{
|
||||
name: "mixed blocks only tool_use reversed",
|
||||
reverse: map[string]string{"renamed": "original"},
|
||||
body: `{"content":[{"type":"text","text":"hi"},{"type":"tool_use","name":"renamed","id":"t1"}]}`,
|
||||
expected: `original`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: make(map[string]string),
|
||||
toolsReverse: tt.reverse,
|
||||
}
|
||||
result := string(s.DesanitizeResponse([]byte(tt.body)))
|
||||
if !strings.Contains(result, tt.expected) {
|
||||
t.Errorf("result %q does not contain %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDesanitizeResponse_MultipleToolUse(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: make(map[string]string),
|
||||
toolsReverse: map[string]string{"r1": "o1", "r2": "o2"},
|
||||
}
|
||||
body := `{"content":[{"type":"tool_use","name":"r1","id":"t1"},{"type":"text","text":"x"},{"type":"tool_use","name":"r2","id":"t2"}]}`
|
||||
result := string(s.DesanitizeResponse([]byte(body)))
|
||||
if !strings.Contains(result, `"o1"`) {
|
||||
t.Errorf("r1 not reversed to o1: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, `"o2"`) {
|
||||
t.Errorf("r2 not reversed to o2: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDesanitizeStreamEvent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reverse map[string]string
|
||||
line string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "non-data line passed through",
|
||||
reverse: map[string]string{"r": "o"},
|
||||
line: "event: content_block_start",
|
||||
expected: "event: content_block_start",
|
||||
},
|
||||
{
|
||||
name: "data line without tool_use passed through",
|
||||
reverse: map[string]string{"r": "o"},
|
||||
line: `data: {"type":"text","text":"hello"}`,
|
||||
expected: `data: {"type":"text","text":"hello"}`,
|
||||
},
|
||||
{
|
||||
name: "data line with tool_use in content_block.name",
|
||||
reverse: map[string]string{"renamed_tool": "original_tool"},
|
||||
line: `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"renamed_tool","id":"t1"}}`,
|
||||
expected: `original_tool`,
|
||||
},
|
||||
{
|
||||
name: "data line with tool_use in delta.name",
|
||||
reverse: map[string]string{"renamed_tool": "original_tool"},
|
||||
line: `data: {"type":"content_block_delta","delta":{"type":"tool_use","name":"renamed_tool"}}`,
|
||||
expected: `original_tool`,
|
||||
},
|
||||
{
|
||||
name: "data line with tool_use but no matching name",
|
||||
reverse: map[string]string{"other": "something"},
|
||||
line: `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"my_tool","id":"t1"}}`,
|
||||
expected: `my_tool`,
|
||||
},
|
||||
{
|
||||
name: "empty line passed through",
|
||||
reverse: map[string]string{"r": "o"},
|
||||
line: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "line contains tool_use but not data prefix - passed through",
|
||||
reverse: map[string]string{"r": "o"},
|
||||
line: "event: tool_use",
|
||||
expected: "event: tool_use",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: make(map[string]string),
|
||||
toolsReverse: tt.reverse,
|
||||
}
|
||||
result := s.DesanitizeStreamEvent(tt.line)
|
||||
if !strings.Contains(result, tt.expected) {
|
||||
t.Errorf("result %q does not contain %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDesanitizeStreamEvent_DataPrefixPreserved(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: make(map[string]string),
|
||||
toolsReverse: map[string]string{"renamed": "original"},
|
||||
}
|
||||
line := `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"renamed","id":"t1"}}`
|
||||
result := s.DesanitizeStreamEvent(line)
|
||||
if !strings.HasPrefix(result, "data: ") {
|
||||
t.Errorf("result should start with 'data: ', got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeRequest_MalformedJSON(t *testing.T) {
|
||||
s := NewSanitizer(config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "a", To: "b"}},
|
||||
System: []config.ReplaceRule{{Match: "x", Replace: "y"}},
|
||||
})
|
||||
// Malformed JSON - renameTools and replaceSystem should handle gracefully
|
||||
body := `not valid json`
|
||||
result := string(s.SanitizeRequest([]byte(body)))
|
||||
// Should not panic; body rules still do string replacement
|
||||
if result != "not valid json" {
|
||||
t.Errorf("malformed JSON should pass through (no body rules match), got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeRequest_EmptyBody(t *testing.T) {
|
||||
s := NewSanitizer(config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "a", To: "b"}},
|
||||
})
|
||||
result := s.SanitizeRequest([]byte{})
|
||||
if len(result) != 0 {
|
||||
t.Errorf("empty body should return empty, got %q", string(result))
|
||||
}
|
||||
}
|
||||
+41
-58
@@ -36,21 +36,6 @@ var skipHeaders = map[string]bool{
|
||||
"connection": true,
|
||||
}
|
||||
|
||||
const fakeJSONResponse = `{"id":"msg_fake","type":"message","role":"assistant","content":[{"type":"text","text":"ok"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`
|
||||
|
||||
const fakeStreamResponse = "event: message_start\n" +
|
||||
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_fake\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}\n\n" +
|
||||
"event: content_block_start\n" +
|
||||
"data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n" +
|
||||
"event: content_block_delta\n" +
|
||||
"data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"ok\"}}\n\n" +
|
||||
"event: content_block_stop\n" +
|
||||
"data: {\"type\":\"content_block_stop\",\"index\":0}\n\n" +
|
||||
"event: message_delta\n" +
|
||||
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":1}}\n\n" +
|
||||
"event: message_stop\n" +
|
||||
"data: {\"type\":\"message_stop\"}\n\n"
|
||||
|
||||
func SniffClaudeCode(claudeBinary string) (*SniffedProfile, error) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
@@ -63,7 +48,45 @@ func SniffClaudeCode(claudeBinary string) (*SniffedProfile, error) {
|
||||
captured := make(chan struct{}, 1)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", sniffHandler(&mu, &profile, captured))
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "HEAD" {
|
||||
w.WriteHeader(200)
|
||||
return
|
||||
}
|
||||
if r.Method != "POST" || !strings.Contains(r.URL.Path, "/v1/messages") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprint(w, `{"id":"msg_fake","type":"message","role":"assistant","content":[{"type":"text","text":"ok"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`)
|
||||
return
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
|
||||
mu.Lock()
|
||||
if profile == nil {
|
||||
profile = extractProfile(r, body)
|
||||
select {
|
||||
case captured <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
if strings.Contains(string(body), `"stream":true`) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprint(w, "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_fake\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}\n\n")
|
||||
fmt.Fprint(w, "event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n")
|
||||
fmt.Fprint(w, "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"ok\"}}\n\n")
|
||||
fmt.Fprint(w, "event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n")
|
||||
fmt.Fprint(w, "event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":1}}\n\n")
|
||||
fmt.Fprint(w, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprint(w, `{"id":"msg_fake","type":"message","role":"assistant","content":[{"type":"text","text":"ok"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`)
|
||||
}
|
||||
})
|
||||
|
||||
srv := &http.Server{Handler: mux}
|
||||
go srv.Serve(listener)
|
||||
@@ -99,52 +122,11 @@ func SniffClaudeCode(claudeBinary string) (*SniffedProfile, error) {
|
||||
Int("headers", len(profile.Headers)).
|
||||
Int("body_size", len(profile.Body)).
|
||||
Msg("sniffed claude-code profile")
|
||||
|
||||
for _, h := range profile.Headers {
|
||||
log.Debug().Str("header", h[0]).Str("value", h[1]).Msg("sniffed header")
|
||||
}
|
||||
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
func sniffHandler(mu *sync.Mutex, profile **SniffedProfile, captured chan<- struct{}) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "HEAD" {
|
||||
w.WriteHeader(200)
|
||||
return
|
||||
}
|
||||
if r.Method != "POST" || !strings.Contains(r.URL.Path, "/v1/messages") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprint(w, fakeJSONResponse)
|
||||
return
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
|
||||
mu.Lock()
|
||||
if *profile == nil {
|
||||
*profile = extractProfile(r, body)
|
||||
select {
|
||||
case captured <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
if strings.Contains(string(body), `"stream":true`) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprint(w, fakeStreamResponse)
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprint(w, fakeJSONResponse)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func extractProfile(r *http.Request, body []byte) *SniffedProfile {
|
||||
// Capture raw headers preserving original casing.
|
||||
var headers [][2]string
|
||||
for name, vals := range r.Header {
|
||||
if skipHeaders[strings.ToLower(name)] {
|
||||
@@ -155,6 +137,7 @@ func extractProfile(r *http.Request, body []byte) *SniffedProfile {
|
||||
}
|
||||
}
|
||||
|
||||
// Deduplicate and strip subscription-specific betas.
|
||||
seen := map[string]bool{}
|
||||
var deduped [][2]string
|
||||
for _, h := range headers {
|
||||
|
||||
@@ -1,278 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newRequest(t *testing.T, headers map[string][]string) *http.Request {
|
||||
t.Helper()
|
||||
r := httptest.NewRequest("POST", "/v1/messages", nil)
|
||||
r.Header = http.Header{}
|
||||
for k, vals := range headers {
|
||||
for _, v := range vals {
|
||||
r.Header.Add(k, v)
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func TestExtractProfile_BasicHeaders(t *testing.T) {
|
||||
r := newRequest(t, map[string][]string{
|
||||
"Content-Type": {"application/json"},
|
||||
"X-Custom-Header": {"custom-value"},
|
||||
"User-Agent": {"Claude/1.2.3 linux"},
|
||||
})
|
||||
body := []byte(`{"model":"claude-sonnet-4-6"}`)
|
||||
|
||||
p := extractProfile(r, body)
|
||||
|
||||
// Check version parsed
|
||||
if p.Version != "1.2.3" {
|
||||
t.Errorf("version = %q, want %q", p.Version, "1.2.3")
|
||||
}
|
||||
|
||||
// Check body preserved
|
||||
if string(p.Body) != string(body) {
|
||||
t.Errorf("body not preserved")
|
||||
}
|
||||
|
||||
// Check headers captured
|
||||
found := map[string]bool{}
|
||||
for _, h := range p.Headers {
|
||||
found[strings.ToLower(h[0])] = true
|
||||
}
|
||||
if !found["content-type"] {
|
||||
t.Error("Content-Type header should be captured")
|
||||
}
|
||||
if !found["x-custom-header"] {
|
||||
t.Error("X-Custom-Header should be captured")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractProfile_SkipHeaders(t *testing.T) {
|
||||
r := newRequest(t, map[string][]string{
|
||||
"Host": {"example.com"},
|
||||
"Content-Length": {"42"},
|
||||
"Authorization": {"Bearer token123"},
|
||||
"X-Api-Key": {"key123"},
|
||||
"Connection": {"keep-alive"},
|
||||
"Content-Type": {"application/json"},
|
||||
"X-Custom": {"keep-me"},
|
||||
})
|
||||
|
||||
p := extractProfile(r, []byte(`{}`))
|
||||
|
||||
for _, h := range p.Headers {
|
||||
lower := strings.ToLower(h[0])
|
||||
if skipHeaders[lower] {
|
||||
t.Errorf("header %q should have been skipped", h[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Verify non-skipped headers are present
|
||||
found := map[string]bool{}
|
||||
for _, h := range p.Headers {
|
||||
found[strings.ToLower(h[0])] = true
|
||||
}
|
||||
if !found["content-type"] {
|
||||
t.Error("Content-Type should be kept")
|
||||
}
|
||||
if !found["x-custom"] {
|
||||
t.Error("X-Custom should be kept")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractProfile_HeaderDeduplication(t *testing.T) {
|
||||
r := newRequest(t, map[string][]string{
|
||||
"Content-Type": {"application/json"},
|
||||
})
|
||||
// Add duplicate with different casing - Go's http.Header normalizes to canonical form
|
||||
// so we need to add the same canonical header with multiple values to test dedup
|
||||
r.Header.Add("Content-Type", "text/plain")
|
||||
|
||||
p := extractProfile(r, []byte(`{}`))
|
||||
|
||||
// After deduplication by lowercase key, only one entry per key
|
||||
seen := map[string]int{}
|
||||
for _, h := range p.Headers {
|
||||
seen[strings.ToLower(h[0])]++
|
||||
}
|
||||
for key, count := range seen {
|
||||
if count > 1 {
|
||||
t.Errorf("header %q appears %d times after dedup, want 1", key, count)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractProfile_AnthropicBetaContextStripping(t *testing.T) {
|
||||
r := newRequest(t, map[string][]string{
|
||||
"Anthropic-Beta": {"prompt-caching-2024-07-31,context-1m-2024-09-01,some-other-beta"},
|
||||
})
|
||||
|
||||
p := extractProfile(r, []byte(`{}`))
|
||||
|
||||
var betaValue string
|
||||
for _, h := range p.Headers {
|
||||
if strings.ToLower(h[0]) == "anthropic-beta" {
|
||||
betaValue = h[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(betaValue, "context-1m") {
|
||||
t.Errorf("context-1m should be stripped from anthropic-beta, got %q", betaValue)
|
||||
}
|
||||
if !strings.Contains(betaValue, "prompt-caching-2024-07-31") {
|
||||
t.Errorf("prompt-caching should be preserved, got %q", betaValue)
|
||||
}
|
||||
if !strings.Contains(betaValue, "some-other-beta") {
|
||||
t.Errorf("some-other-beta should be preserved, got %q", betaValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractProfile_AnthropicBetaAllContextRemoved(t *testing.T) {
|
||||
r := newRequest(t, map[string][]string{
|
||||
"Anthropic-Beta": {"context-1m-2024-09-01"},
|
||||
})
|
||||
|
||||
p := extractProfile(r, []byte(`{}`))
|
||||
|
||||
for _, h := range p.Headers {
|
||||
if strings.ToLower(h[0]) == "anthropic-beta" {
|
||||
// All betas were context-1m, so after filtering the value should be empty
|
||||
if h[1] != "" {
|
||||
t.Errorf("all context-1m betas stripped should leave empty, got %q", h[1])
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
// It's also acceptable if the header is still present but empty
|
||||
}
|
||||
|
||||
func TestExtractProfile_VersionParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userAgent string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "standard Claude UA",
|
||||
userAgent: "Claude/1.2.3 linux x86_64",
|
||||
expected: "1.2.3",
|
||||
},
|
||||
{
|
||||
name: "version with no space after",
|
||||
userAgent: "Claude/4.5.6",
|
||||
expected: "4.5.6",
|
||||
},
|
||||
{
|
||||
name: "no slash in UA",
|
||||
userAgent: "Mozilla 5.0",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "empty UA",
|
||||
userAgent: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "slash at start",
|
||||
userAgent: "/1.0.0 rest",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "multiple slashes",
|
||||
userAgent: "App/1.0.0 (sub/2.0)",
|
||||
expected: "1.0.0",
|
||||
},
|
||||
{
|
||||
name: "version only after slash no space",
|
||||
userAgent: "Tool/9.8.7",
|
||||
expected: "9.8.7",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := newRequest(t, map[string][]string{
|
||||
"User-Agent": {tt.userAgent},
|
||||
})
|
||||
p := extractProfile(r, []byte(`{}`))
|
||||
if p.Version != tt.expected {
|
||||
t.Errorf("version = %q, want %q", p.Version, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractProfile_EmptyHeaders(t *testing.T) {
|
||||
r := httptest.NewRequest("POST", "/v1/messages", nil)
|
||||
r.Header = http.Header{}
|
||||
|
||||
p := extractProfile(r, []byte(`{"test":true}`))
|
||||
|
||||
if len(p.Headers) != 0 {
|
||||
t.Errorf("expected no headers, got %d", len(p.Headers))
|
||||
}
|
||||
if p.Version != "" {
|
||||
t.Errorf("expected empty version with no UA, got %q", p.Version)
|
||||
}
|
||||
if string(p.Body) != `{"test":true}` {
|
||||
t.Errorf("body not preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractProfile_BodyPreserved(t *testing.T) {
|
||||
r := newRequest(t, map[string][]string{
|
||||
"User-Agent": {"Claude/1.0.0 test"},
|
||||
})
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
|
||||
p := extractProfile(r, body)
|
||||
|
||||
if string(p.Body) != string(body) {
|
||||
t.Errorf("body not preserved.\ngot: %s\nwant: %s", p.Body, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkipHeaders_Entries(t *testing.T) {
|
||||
expected := map[string]bool{
|
||||
"host": true,
|
||||
"content-length": true,
|
||||
"authorization": true,
|
||||
"x-api-key": true,
|
||||
"connection": true,
|
||||
}
|
||||
if len(skipHeaders) != len(expected) {
|
||||
t.Errorf("skipHeaders has %d entries, want %d", len(skipHeaders), len(expected))
|
||||
}
|
||||
for k, v := range expected {
|
||||
if skipHeaders[k] != v {
|
||||
t.Errorf("skipHeaders[%q] = %v, want %v", k, skipHeaders[k], v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSniffedProfile_Fields(t *testing.T) {
|
||||
// Verify the struct can hold all expected data
|
||||
p := &SniffedProfile{
|
||||
Headers: [][2]string{{"Content-Type", "application/json"}},
|
||||
Body: []byte(`{}`),
|
||||
Version: "1.0.0",
|
||||
}
|
||||
if len(p.Headers) != 1 {
|
||||
t.Error("Headers should have 1 entry")
|
||||
}
|
||||
if p.Headers[0][0] != "Content-Type" || p.Headers[0][1] != "application/json" {
|
||||
t.Error("Header not stored correctly")
|
||||
}
|
||||
if string(p.Body) != `{}` {
|
||||
t.Error("Body not stored correctly")
|
||||
}
|
||||
if p.Version != "1.0.0" {
|
||||
t.Error("Version not stored correctly")
|
||||
}
|
||||
}
|
||||
@@ -1,47 +1,29 @@
|
||||
// Package transport provides a shared uTLS HTTP/2 round-tripper with Chrome
|
||||
// TLS fingerprinting and per-host connection pooling. Used by both the upstream
|
||||
// proxy client and the OAuth token refresh client.
|
||||
package transport
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
tls "github.com/refraction-networking/utls"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
// UTLS implements http.RoundTripper using uTLS (Chrome fingerprint) over HTTP/2.
|
||||
// It maintains a per-host connection pool with coordination for concurrent
|
||||
// requests to the same host.
|
||||
type UTLS struct {
|
||||
type utlsRoundTripper struct {
|
||||
mu sync.Mutex
|
||||
connections map[string]*http2.ClientConn
|
||||
pending map[string]*sync.Cond
|
||||
dialTimeout time.Duration
|
||||
}
|
||||
|
||||
// NewUTLS creates a uTLS HTTP/2 round-tripper with a 10-second dial timeout.
|
||||
func NewUTLS() *UTLS {
|
||||
return &UTLS{
|
||||
func newUtlsRoundTripper() *utlsRoundTripper {
|
||||
return &utlsRoundTripper{
|
||||
connections: make(map[string]*http2.ClientConn),
|
||||
pending: make(map[string]*sync.Cond),
|
||||
dialTimeout: 10 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// NewHTTPClient returns an http.Client using uTLS transport with the given
|
||||
// request timeout. Pass 0 for no timeout (streaming).
|
||||
func NewHTTPClient(timeout time.Duration) *http.Client {
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: NewUTLS(),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *UTLS) getOrCreateConnection(host, addr string) (*http2.ClientConn, error) {
|
||||
func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.ClientConn, error) {
|
||||
t.mu.Lock()
|
||||
|
||||
if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
|
||||
@@ -77,8 +59,8 @@ func (t *UTLS) getOrCreateConnection(host, addr string) (*http2.ClientConn, erro
|
||||
return h2Conn, nil
|
||||
}
|
||||
|
||||
func (t *UTLS) createConnection(host, addr string) (*http2.ClientConn, error) {
|
||||
conn, err := net.DialTimeout("tcp", addr, t.dialTimeout)
|
||||
func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) {
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -101,14 +83,14 @@ func (t *UTLS) createConnection(host, addr string) (*http2.ClientConn, error) {
|
||||
return h2Conn, nil
|
||||
}
|
||||
|
||||
// RoundTrip implements http.RoundTripper with uTLS Chrome fingerprinting.
|
||||
func (t *UTLS) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
func (t *utlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
hostname := req.URL.Hostname()
|
||||
port := req.URL.Port()
|
||||
if port == "" {
|
||||
port = "443"
|
||||
}
|
||||
addr := net.JoinHostPort(hostname, port)
|
||||
log.Debug().Str("addr", addr).Msg("uTLS round trip")
|
||||
|
||||
h2Conn, err := t.getOrCreateConnection(hostname, addr)
|
||||
if err != nil {
|
||||
@@ -9,12 +9,8 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/auth"
|
||||
"github.com/fujin/anthropic-proxy/internal/logging"
|
||||
"github.com/fujin/anthropic-proxy/internal/transport"
|
||||
"github.com/fujin/anthropic-proxy/internal/version"
|
||||
)
|
||||
|
||||
const messagesURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||
@@ -29,7 +25,7 @@ func NewUpstreamClient(profile *SniffedProfile) *UpstreamClient {
|
||||
return &UpstreamClient{
|
||||
client: http.Client{
|
||||
Timeout: 0,
|
||||
Transport: transport.NewUTLS(),
|
||||
Transport: newUtlsRoundTripper(),
|
||||
},
|
||||
sessionID: uuid.New().String(),
|
||||
profile: profile,
|
||||
@@ -40,7 +36,7 @@ func (u *UpstreamClient) version() string {
|
||||
if u.profile != nil && u.profile.Version != "" {
|
||||
return u.profile.Version
|
||||
}
|
||||
return version.ClaudeCodeFallback
|
||||
return "2.1.92"
|
||||
}
|
||||
|
||||
// applyHeaders replays sniffed headers, substituting auth + per-request IDs + accept.
|
||||
@@ -55,15 +51,6 @@ func (u *UpstreamClient) applyHeaders(req *http.Request, token string, streaming
|
||||
req.Header.Del("x-api-key")
|
||||
if strings.HasPrefix(token, "sk-ant-oat") {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
// OAuth tokens require this beta flag — without it the API rejects with 401
|
||||
existing := req.Header.Get("anthropic-beta")
|
||||
if !strings.Contains(existing, "oauth-2025-04-20") {
|
||||
if existing == "" {
|
||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
} else {
|
||||
req.Header.Set("anthropic-beta", existing+",oauth-2025-04-20")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
req.Header.Set("x-api-key", token)
|
||||
}
|
||||
@@ -88,12 +75,6 @@ func (u *UpstreamClient) Execute(ctx context.Context, cred *auth.Credential, bod
|
||||
}
|
||||
u.applyHeaders(req, cred.Token(), false)
|
||||
|
||||
log.Debug().
|
||||
Str("url", messagesURL).
|
||||
Str("upstream_headers", logging.RedactHeaders(req.Header)).
|
||||
Int("body_size", len(body)).
|
||||
Msg("upstream request")
|
||||
|
||||
resp, err := u.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, 0, fmt.Errorf("upstream request: %w", err)
|
||||
@@ -104,12 +85,6 @@ func (u *UpstreamClient) Execute(ctx context.Context, cred *auth.Credential, bod
|
||||
if err != nil {
|
||||
return nil, nil, resp.StatusCode, fmt.Errorf("read upstream response: %w", err)
|
||||
}
|
||||
log.Debug().
|
||||
Int("status", resp.StatusCode).
|
||||
Str("response_headers", logging.RedactHeaders(resp.Header)).
|
||||
Int("response_size", len(respBody)).
|
||||
Msg("upstream response")
|
||||
|
||||
return respBody, resp.Header, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
@@ -122,21 +97,9 @@ func (u *UpstreamClient) ExecuteStream(ctx context.Context, cred *auth.Credentia
|
||||
}
|
||||
u.applyHeaders(req, cred.Token(), true)
|
||||
|
||||
log.Debug().
|
||||
Str("url", messagesURL).
|
||||
Str("upstream_headers", logging.RedactHeaders(req.Header)).
|
||||
Int("body_size", len(body)).
|
||||
Msg("upstream stream request")
|
||||
|
||||
resp, err := u.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upstream stream request: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Int("status", resp.StatusCode).
|
||||
Str("response_headers", logging.RedactHeaders(resp.Header)).
|
||||
Msg("upstream stream response")
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -1,334 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// --- NewUpstreamClient ---
|
||||
|
||||
func TestNewUpstreamClient_NilProfile(t *testing.T) {
|
||||
uc := NewUpstreamClient(nil)
|
||||
if uc == nil {
|
||||
t.Fatal("NewUpstreamClient returned nil")
|
||||
}
|
||||
if uc.sessionID == "" {
|
||||
t.Error("expected non-empty sessionID")
|
||||
}
|
||||
if uc.profile != nil {
|
||||
t.Error("expected nil profile")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUpstreamClient_WithProfile(t *testing.T) {
|
||||
profile := &SniffedProfile{
|
||||
Version: "1.2.3",
|
||||
Headers: [][2]string{{"User-Agent", "test/1.0"}},
|
||||
}
|
||||
uc := NewUpstreamClient(profile)
|
||||
if uc.profile != profile {
|
||||
t.Error("expected profile to be stored")
|
||||
}
|
||||
if uc.sessionID == "" {
|
||||
t.Error("expected non-empty sessionID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUpstreamClient_UniqueSessionIDs(t *testing.T) {
|
||||
uc1 := NewUpstreamClient(nil)
|
||||
uc2 := NewUpstreamClient(nil)
|
||||
if uc1.sessionID == uc2.sessionID {
|
||||
t.Errorf("expected different session IDs, both got %q", uc1.sessionID)
|
||||
}
|
||||
}
|
||||
|
||||
// --- version() ---
|
||||
|
||||
func TestVersion_WithProfileVersion(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
profile: &SniffedProfile{Version: "3.5.7"},
|
||||
}
|
||||
if got := uc.version(); got != "3.5.7" {
|
||||
t.Errorf("version() = %q, want %q", got, "3.5.7")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVersion_NilProfile_Fallback(t *testing.T) {
|
||||
uc := &UpstreamClient{profile: nil}
|
||||
if got := uc.version(); got != "2.1.92" {
|
||||
t.Errorf("version() = %q, want %q", got, "2.1.92")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVersion_EmptyProfileVersion_Fallback(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
profile: &SniffedProfile{Version: ""},
|
||||
}
|
||||
if got := uc.version(); got != "2.1.92" {
|
||||
t.Errorf("version() = %q, want %q", got, "2.1.92")
|
||||
}
|
||||
}
|
||||
|
||||
// --- applyHeaders ---
|
||||
|
||||
func TestApplyHeaders_NilProfile_NonOAuth_NonStream(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "test-session-id",
|
||||
profile: nil,
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-api123", false)
|
||||
|
||||
// x-api-key for non-OAuth token
|
||||
if got := req.Header.Get("x-api-key"); got != "sk-ant-api123" {
|
||||
t.Errorf("x-api-key = %q, want %q", got, "sk-ant-api123")
|
||||
}
|
||||
// Should NOT have Authorization
|
||||
if got := req.Header.Get("Authorization"); got != "" {
|
||||
t.Errorf("Authorization = %q, want empty", got)
|
||||
}
|
||||
// Session ID
|
||||
if got := req.Header.Get("X-Claude-Code-Session-Id"); got != "test-session-id" {
|
||||
t.Errorf("X-Claude-Code-Session-Id = %q, want %q", got, "test-session-id")
|
||||
}
|
||||
// Request ID should be a UUID
|
||||
if got := req.Header.Get("x-client-request-id"); got == "" {
|
||||
t.Error("expected non-empty x-client-request-id")
|
||||
}
|
||||
// Non-stream: application/json
|
||||
if got := req.Header.Get("Accept"); got != "application/json" {
|
||||
t.Errorf("Accept = %q, want %q", got, "application/json")
|
||||
}
|
||||
// Accept-Encoding always identity
|
||||
if got := req.Header.Get("Accept-Encoding"); got != "identity" {
|
||||
t.Errorf("Accept-Encoding = %q, want %q", got, "identity")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_NilProfile_NonOAuth_Stream(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "sess",
|
||||
profile: nil,
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-api123", true)
|
||||
|
||||
if got := req.Header.Get("Accept"); got != "text/event-stream" {
|
||||
t.Errorf("Accept = %q, want %q", got, "text/event-stream")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_OAuthToken_SetsBearerAndBetaFlag(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "sess",
|
||||
profile: nil,
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-oat-mytoken", false)
|
||||
|
||||
// OAuth: Authorization Bearer
|
||||
if got := req.Header.Get("Authorization"); got != "Bearer sk-ant-oat-mytoken" {
|
||||
t.Errorf("Authorization = %q, want %q", got, "Bearer sk-ant-oat-mytoken")
|
||||
}
|
||||
// Should NOT have x-api-key
|
||||
if got := req.Header.Get("x-api-key"); got != "" {
|
||||
t.Errorf("x-api-key = %q, want empty for OAuth", got)
|
||||
}
|
||||
// anthropic-beta should include oauth-2025-04-20
|
||||
if got := req.Header.Get("anthropic-beta"); got != "oauth-2025-04-20" {
|
||||
t.Errorf("anthropic-beta = %q, want %q", got, "oauth-2025-04-20")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_OAuthToken_AppendsToExistingBeta(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "sess",
|
||||
profile: &SniffedProfile{
|
||||
Headers: [][2]string{
|
||||
{"anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15"},
|
||||
},
|
||||
},
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-oat-tok", false)
|
||||
|
||||
beta := req.Header.Get("anthropic-beta")
|
||||
if !strings.Contains(beta, "max-tokens-3-5-sonnet-2024-07-15") {
|
||||
t.Errorf("anthropic-beta %q should contain existing beta", beta)
|
||||
}
|
||||
if !strings.Contains(beta, "oauth-2025-04-20") {
|
||||
t.Errorf("anthropic-beta %q should contain oauth flag", beta)
|
||||
}
|
||||
// Should be appended with comma
|
||||
if beta != "max-tokens-3-5-sonnet-2024-07-15,oauth-2025-04-20" {
|
||||
t.Errorf("anthropic-beta = %q, want %q", beta, "max-tokens-3-5-sonnet-2024-07-15,oauth-2025-04-20")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_OAuthToken_ExistingBetaAlreadyHasOAuth(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "sess",
|
||||
profile: &SniffedProfile{
|
||||
Headers: [][2]string{
|
||||
{"anthropic-beta", "oauth-2025-04-20,something-else"},
|
||||
},
|
||||
},
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-oat-tok", false)
|
||||
|
||||
beta := req.Header.Get("anthropic-beta")
|
||||
// Should NOT duplicate oauth flag
|
||||
count := strings.Count(beta, "oauth-2025-04-20")
|
||||
if count != 1 {
|
||||
t.Errorf("oauth flag appeared %d times in %q, want 1", count, beta)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_WithProfile_ReplaysHeaders(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "sess",
|
||||
profile: &SniffedProfile{
|
||||
Headers: [][2]string{
|
||||
{"User-Agent", "Claude/1.0"},
|
||||
{"anthropic-version", "2023-06-01"},
|
||||
{"Custom-Header", "custom-value"},
|
||||
},
|
||||
},
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-api123", false)
|
||||
|
||||
if got := req.Header.Get("User-Agent"); got != "Claude/1.0" {
|
||||
t.Errorf("User-Agent = %q, want %q", got, "Claude/1.0")
|
||||
}
|
||||
if got := req.Header.Get("anthropic-version"); got != "2023-06-01" {
|
||||
t.Errorf("anthropic-version = %q, want %q", got, "2023-06-01")
|
||||
}
|
||||
if got := req.Header.Get("Custom-Header"); got != "custom-value" {
|
||||
t.Errorf("Custom-Header = %q, want %q", got, "custom-value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_ProfileAuthHeadersRemoved(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "sess",
|
||||
profile: &SniffedProfile{
|
||||
Headers: [][2]string{
|
||||
{"Authorization", "Bearer old-token"},
|
||||
{"x-api-key", "old-api-key"},
|
||||
{"User-Agent", "test"},
|
||||
},
|
||||
},
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-api-new", false)
|
||||
|
||||
// Old auth headers from profile should be removed
|
||||
if got := req.Header.Get("Authorization"); got != "" {
|
||||
t.Errorf("Authorization should be empty for non-OAuth, got %q", got)
|
||||
}
|
||||
// New auth should be set via x-api-key
|
||||
if got := req.Header.Get("x-api-key"); got != "sk-ant-api-new" {
|
||||
t.Errorf("x-api-key = %q, want %q", got, "sk-ant-api-new")
|
||||
}
|
||||
// User-Agent from profile should remain
|
||||
if got := req.Header.Get("User-Agent"); got != "test" {
|
||||
t.Errorf("User-Agent = %q, want %q", got, "test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_ProfileAuthHeadersRemovedForOAuth(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "sess",
|
||||
profile: &SniffedProfile{
|
||||
Headers: [][2]string{
|
||||
{"Authorization", "Bearer old-token"},
|
||||
{"x-api-key", "old-api-key"},
|
||||
},
|
||||
},
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-oat-new", false)
|
||||
|
||||
// Old x-api-key removed
|
||||
if got := req.Header.Get("x-api-key"); got != "" {
|
||||
t.Errorf("x-api-key should be empty for OAuth, got %q", got)
|
||||
}
|
||||
// New auth set via Authorization
|
||||
if got := req.Header.Get("Authorization"); got != "Bearer sk-ant-oat-new" {
|
||||
t.Errorf("Authorization = %q, want %q", got, "Bearer sk-ant-oat-new")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_AcceptEncoding_AlwaysIdentity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
streaming bool
|
||||
}{
|
||||
{"non-stream", false},
|
||||
{"stream", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
uc := &UpstreamClient{sessionID: "s", profile: nil}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
uc.applyHeaders(req, "token", tt.streaming)
|
||||
|
||||
if got := req.Header.Get("Accept-Encoding"); got != "identity" {
|
||||
t.Errorf("Accept-Encoding = %q, want %q", got, "identity")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_UniqueRequestIDs(t *testing.T) {
|
||||
uc := &UpstreamClient{sessionID: "s", profile: nil}
|
||||
|
||||
req1, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
uc.applyHeaders(req1, "tok", false)
|
||||
|
||||
req2, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
uc.applyHeaders(req2, "tok", false)
|
||||
|
||||
id1 := req1.Header.Get("x-client-request-id")
|
||||
id2 := req2.Header.Get("x-client-request-id")
|
||||
if id1 == "" || id2 == "" {
|
||||
t.Fatal("expected non-empty request IDs")
|
||||
}
|
||||
if id1 == id2 {
|
||||
t.Errorf("expected unique request IDs, both got %q", id1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_NonOAuth_NoAnthroPicBetaSet(t *testing.T) {
|
||||
// Non-OAuth tokens should NOT set anthropic-beta oauth flag
|
||||
uc := &UpstreamClient{sessionID: "s", profile: nil}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
uc.applyHeaders(req, "sk-ant-api123", false)
|
||||
|
||||
beta := req.Header.Get("anthropic-beta")
|
||||
if strings.Contains(beta, "oauth-2025-04-20") {
|
||||
t.Errorf("non-OAuth token should not have oauth beta flag, got %q", beta)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_OAuthToken_FreshBeta(t *testing.T) {
|
||||
// No profile, no existing beta — should set fresh
|
||||
uc := &UpstreamClient{sessionID: "s", profile: nil}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
uc.applyHeaders(req, "sk-ant-oat-tok", false)
|
||||
|
||||
if got := req.Header.Get("anthropic-beta"); got != "oauth-2025-04-20" {
|
||||
t.Errorf("anthropic-beta = %q, want %q", got, "oauth-2025-04-20")
|
||||
}
|
||||
}
|
||||
@@ -1,166 +0,0 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Window holds per-window usage state.
|
||||
type Window struct {
|
||||
Utilization float64 // 0-100 from API
|
||||
ResetsAt time.Time // when window resets
|
||||
}
|
||||
|
||||
// Snapshot is a read-only copy of a Window's state.
|
||||
type Snapshot struct {
|
||||
Utilization float64
|
||||
ResetsAt time.Time
|
||||
}
|
||||
|
||||
// Tracker polls /api/oauth/usage and tracks per-window utilization.
|
||||
type Tracker struct {
|
||||
tokenFn func() string
|
||||
mu sync.RWMutex
|
||||
fiveHour Window
|
||||
sevenDay Window
|
||||
sonnet Window
|
||||
extra ExtraUsage
|
||||
}
|
||||
|
||||
// NewTracker creates a tracker. tokenFn should return the current access token.
|
||||
func NewTracker(tokenFn func() string) *Tracker {
|
||||
return &Tracker{tokenFn: tokenFn}
|
||||
}
|
||||
|
||||
// Start begins the background poll loop.
|
||||
func (t *Tracker) Start(ctx context.Context) {
|
||||
go func() {
|
||||
t.poll(ctx)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(5 * time.Minute):
|
||||
t.poll(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// UpdateFromHeaders extracts rate limit data from /v1/messages response headers.
|
||||
func (t *Tracker) UpdateFromHeaders(h http.Header) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if v := h.Get("Anthropic-Ratelimit-Unified-5h-Utilization"); v != "" {
|
||||
if f, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
t.fiveHour.Utilization = f * 100
|
||||
}
|
||||
}
|
||||
if v := h.Get("Anthropic-Ratelimit-Unified-5h-Reset"); v != "" {
|
||||
if ts, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
t.fiveHour.ResetsAt = time.Unix(ts, 0).UTC().Truncate(time.Minute)
|
||||
}
|
||||
}
|
||||
if v := h.Get("Anthropic-Ratelimit-Unified-7d-Utilization"); v != "" {
|
||||
if f, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
t.sevenDay.Utilization = f * 100
|
||||
}
|
||||
}
|
||||
if v := h.Get("Anthropic-Ratelimit-Unified-7d-Reset"); v != "" {
|
||||
if ts, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
t.sevenDay.ResetsAt = time.Unix(ts, 0).UTC().Truncate(time.Minute)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FiveHour returns a snapshot of the 5-hour window.
|
||||
func (t *Tracker) FiveHour() Snapshot {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return Snapshot{Utilization: t.fiveHour.Utilization, ResetsAt: t.fiveHour.ResetsAt}
|
||||
}
|
||||
|
||||
// SevenDay returns a snapshot of the 7-day window.
|
||||
func (t *Tracker) SevenDay() Snapshot {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return Snapshot{Utilization: t.sevenDay.Utilization, ResetsAt: t.sevenDay.ResetsAt}
|
||||
}
|
||||
|
||||
// Sonnet returns a snapshot of the 7-day sonnet window.
|
||||
func (t *Tracker) Sonnet() Snapshot {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return Snapshot{Utilization: t.sonnet.Utilization, ResetsAt: t.sonnet.ResetsAt}
|
||||
}
|
||||
|
||||
// Extra returns the current extra usage state.
|
||||
func (t *Tracker) Extra() ExtraUsage {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return t.extra
|
||||
}
|
||||
|
||||
func (t *Tracker) poll(ctx context.Context) {
|
||||
token := t.tokenFn()
|
||||
if token == "" {
|
||||
return
|
||||
}
|
||||
|
||||
usage, err := fetchUsage(ctx, token)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("usage poll failed")
|
||||
return
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if usage.FiveHour != nil {
|
||||
t.updateWindow(&t.fiveHour, usage.FiveHour)
|
||||
}
|
||||
if usage.SevenDay != nil {
|
||||
t.updateWindow(&t.sevenDay, usage.SevenDay)
|
||||
}
|
||||
if usage.SevenDaySonnet != nil {
|
||||
t.updateWindow(&t.sonnet, usage.SevenDaySonnet)
|
||||
}
|
||||
if usage.ExtraUsage != nil {
|
||||
t.extra = *usage.ExtraUsage
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Float64("5h_util", t.fiveHour.Utilization).
|
||||
Time("5h_resets", t.fiveHour.ResetsAt).
|
||||
Float64("7d_util", t.sevenDay.Utilization).
|
||||
Time("7d_resets", t.sevenDay.ResetsAt).
|
||||
Msg("usage poll")
|
||||
|
||||
if t.fiveHour.Utilization > 80 {
|
||||
log.Warn().Float64("utilization", t.fiveHour.Utilization).Time("resets_at", t.fiveHour.ResetsAt).Msg("5h window utilization high")
|
||||
}
|
||||
if t.sevenDay.Utilization > 80 {
|
||||
log.Warn().Float64("utilization", t.sevenDay.Utilization).Time("resets_at", t.sevenDay.ResetsAt).Msg("7d window utilization high")
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tracker) updateWindow(w *Window, rl *RateLimit) {
|
||||
if rl.Utilization != nil {
|
||||
w.Utilization = *rl.Utilization
|
||||
}
|
||||
if rl.ResetsAt != nil {
|
||||
parsed, err := time.Parse(time.RFC3339Nano, *rl.ResetsAt)
|
||||
if err != nil {
|
||||
parsed, err = time.Parse(time.RFC3339, *rl.ResetsAt)
|
||||
}
|
||||
if err == nil {
|
||||
w.ResetsAt = parsed.UTC().Truncate(time.Minute)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,278 +0,0 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewTracker(t *testing.T) {
|
||||
called := false
|
||||
tr := NewTracker(func() string {
|
||||
called = true
|
||||
return "tok"
|
||||
})
|
||||
if tr == nil {
|
||||
t.Fatal("NewTracker returned nil")
|
||||
}
|
||||
// tokenFn stored but not called during construction
|
||||
if called {
|
||||
t.Error("tokenFn should not be called by NewTracker")
|
||||
}
|
||||
// Invoke to verify it's wired
|
||||
if got := tr.tokenFn(); got != "tok" {
|
||||
t.Errorf("tokenFn() = %q, want tok", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateFromHeaders_Full(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
h := http.Header{}
|
||||
h.Set("Anthropic-Ratelimit-Unified-5h-Utilization", "0.42")
|
||||
h.Set("Anthropic-Ratelimit-Unified-5h-Reset", "1700000000")
|
||||
h.Set("Anthropic-Ratelimit-Unified-7d-Utilization", "0.75")
|
||||
h.Set("Anthropic-Ratelimit-Unified-7d-Reset", "1700100000")
|
||||
|
||||
tr.UpdateFromHeaders(h)
|
||||
|
||||
fh := tr.FiveHour()
|
||||
if fh.Utilization != 42.0 {
|
||||
t.Errorf("FiveHour.Utilization = %f, want 42.0", fh.Utilization)
|
||||
}
|
||||
wantReset5h := time.Unix(1700000000, 0).UTC().Truncate(time.Minute)
|
||||
if !fh.ResetsAt.Equal(wantReset5h) {
|
||||
t.Errorf("FiveHour.ResetsAt = %v, want %v", fh.ResetsAt, wantReset5h)
|
||||
}
|
||||
|
||||
sd := tr.SevenDay()
|
||||
if sd.Utilization != 75.0 {
|
||||
t.Errorf("SevenDay.Utilization = %f, want 75.0", sd.Utilization)
|
||||
}
|
||||
wantReset7d := time.Unix(1700100000, 0).UTC().Truncate(time.Minute)
|
||||
if !sd.ResetsAt.Equal(wantReset7d) {
|
||||
t.Errorf("SevenDay.ResetsAt = %v, want %v", sd.ResetsAt, wantReset7d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateFromHeaders_Partial(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
// Only set 5h utilization, no reset, no 7d
|
||||
h := http.Header{}
|
||||
h.Set("Anthropic-Ratelimit-Unified-5h-Utilization", "0.33")
|
||||
tr.UpdateFromHeaders(h)
|
||||
|
||||
fh := tr.FiveHour()
|
||||
if fh.Utilization != 33.0 {
|
||||
t.Errorf("FiveHour.Utilization = %f, want 33.0", fh.Utilization)
|
||||
}
|
||||
if !fh.ResetsAt.IsZero() {
|
||||
t.Errorf("FiveHour.ResetsAt should be zero, got %v", fh.ResetsAt)
|
||||
}
|
||||
|
||||
sd := tr.SevenDay()
|
||||
if sd.Utilization != 0 {
|
||||
t.Errorf("SevenDay.Utilization = %f, want 0", sd.Utilization)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateFromHeaders_Missing(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
// Pre-set some state
|
||||
tr.mu.Lock()
|
||||
tr.fiveHour.Utilization = 50.0
|
||||
tr.mu.Unlock()
|
||||
|
||||
// Update with empty headers — should not change state
|
||||
tr.UpdateFromHeaders(http.Header{})
|
||||
|
||||
fh := tr.FiveHour()
|
||||
if fh.Utilization != 50.0 {
|
||||
t.Errorf("FiveHour.Utilization = %f, want 50.0 (unchanged)", fh.Utilization)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateFromHeaders_InvalidValues(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
h := http.Header{}
|
||||
h.Set("Anthropic-Ratelimit-Unified-5h-Utilization", "not-a-number")
|
||||
h.Set("Anthropic-Ratelimit-Unified-5h-Reset", "not-a-timestamp")
|
||||
|
||||
tr.UpdateFromHeaders(h)
|
||||
|
||||
fh := tr.FiveHour()
|
||||
if fh.Utilization != 0 {
|
||||
t.Errorf("Utilization should stay 0 for invalid input, got %f", fh.Utilization)
|
||||
}
|
||||
if !fh.ResetsAt.IsZero() {
|
||||
t.Errorf("ResetsAt should stay zero for invalid input, got %v", fh.ResetsAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSonnet_Snapshot(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
// Sonnet is only set via poll/updateWindow, not UpdateFromHeaders
|
||||
// Verify it starts at zero
|
||||
s := tr.Sonnet()
|
||||
if s.Utilization != 0 {
|
||||
t.Errorf("Sonnet.Utilization = %f, want 0", s.Utilization)
|
||||
}
|
||||
if !s.ResetsAt.IsZero() {
|
||||
t.Errorf("Sonnet.ResetsAt should be zero, got %v", s.ResetsAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtra_Default(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
extra := tr.Extra()
|
||||
if extra.IsEnabled {
|
||||
t.Error("Extra.IsEnabled should be false by default")
|
||||
}
|
||||
if extra.MonthlyLimit != nil {
|
||||
t.Error("Extra.MonthlyLimit should be nil by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateWindow(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
util *float64
|
||||
resetsAt *string
|
||||
wantUtil float64
|
||||
wantResetOK bool
|
||||
}{
|
||||
{
|
||||
name: "both fields",
|
||||
util: float64Ptr(65.5),
|
||||
resetsAt: stringPtr("2024-01-15T10:30:45Z"),
|
||||
wantUtil: 65.5,
|
||||
wantResetOK: true,
|
||||
},
|
||||
{
|
||||
name: "utilization only",
|
||||
util: float64Ptr(30.0),
|
||||
resetsAt: nil,
|
||||
wantUtil: 30.0,
|
||||
wantResetOK: false,
|
||||
},
|
||||
{
|
||||
name: "reset only (RFC3339Nano)",
|
||||
util: nil,
|
||||
resetsAt: stringPtr("2024-06-01T12:00:00.123456789Z"),
|
||||
wantUtil: 0,
|
||||
wantResetOK: true,
|
||||
},
|
||||
{
|
||||
name: "nil both",
|
||||
util: nil,
|
||||
resetsAt: nil,
|
||||
wantUtil: 0,
|
||||
wantResetOK: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := &Window{}
|
||||
rl := &RateLimit{
|
||||
Utilization: tt.util,
|
||||
ResetsAt: tt.resetsAt,
|
||||
}
|
||||
tr.updateWindow(w, rl)
|
||||
|
||||
if w.Utilization != tt.wantUtil {
|
||||
t.Errorf("Utilization = %f, want %f", w.Utilization, tt.wantUtil)
|
||||
}
|
||||
if tt.wantResetOK {
|
||||
if w.ResetsAt.IsZero() {
|
||||
t.Error("ResetsAt should be set")
|
||||
}
|
||||
// Verify truncation to minute
|
||||
if w.ResetsAt.Second() != 0 || w.ResetsAt.Nanosecond() != 0 {
|
||||
t.Errorf("ResetsAt not truncated to minute: %v", w.ResetsAt)
|
||||
}
|
||||
if w.ResetsAt.Location() != time.UTC {
|
||||
t.Errorf("ResetsAt not in UTC: %v", w.ResetsAt.Location())
|
||||
}
|
||||
} else if tt.resetsAt == nil {
|
||||
if !w.ResetsAt.IsZero() {
|
||||
t.Errorf("ResetsAt should be zero when input is nil, got %v", w.ResetsAt)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateWindow_InvalidTime(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
w := &Window{}
|
||||
bad := "not-a-time"
|
||||
rl := &RateLimit{ResetsAt: &bad}
|
||||
tr.updateWindow(w, rl)
|
||||
if !w.ResetsAt.IsZero() {
|
||||
t.Errorf("ResetsAt should stay zero for invalid time, got %v", w.ResetsAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoll_SetsStateFromUsageResponse(t *testing.T) {
|
||||
// White-box: directly set fields that poll would set after fetchUsage
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
// Simulate what poll does after fetching usage
|
||||
tr.mu.Lock()
|
||||
usage := &UsageResponse{
|
||||
FiveHour: &RateLimit{Utilization: float64Ptr(55.5), ResetsAt: stringPtr("2024-03-01T08:00:00Z")},
|
||||
SevenDay: &RateLimit{Utilization: float64Ptr(22.3), ResetsAt: stringPtr("2024-03-07T00:00:00Z")},
|
||||
SevenDaySonnet: &RateLimit{Utilization: float64Ptr(10.0), ResetsAt: stringPtr("2024-03-07T00:00:00Z")},
|
||||
ExtraUsage: &ExtraUsage{IsEnabled: true, MonthlyLimit: float64Ptr(100.0), UsedCredits: float64Ptr(42.5)},
|
||||
}
|
||||
if usage.FiveHour != nil {
|
||||
tr.updateWindow(&tr.fiveHour, usage.FiveHour)
|
||||
}
|
||||
if usage.SevenDay != nil {
|
||||
tr.updateWindow(&tr.sevenDay, usage.SevenDay)
|
||||
}
|
||||
if usage.SevenDaySonnet != nil {
|
||||
tr.updateWindow(&tr.sonnet, usage.SevenDaySonnet)
|
||||
}
|
||||
if usage.ExtraUsage != nil {
|
||||
tr.extra = *usage.ExtraUsage
|
||||
}
|
||||
tr.mu.Unlock()
|
||||
|
||||
fh := tr.FiveHour()
|
||||
if fh.Utilization != 55.5 {
|
||||
t.Errorf("FiveHour.Utilization = %f, want 55.5", fh.Utilization)
|
||||
}
|
||||
|
||||
sd := tr.SevenDay()
|
||||
if sd.Utilization != 22.3 {
|
||||
t.Errorf("SevenDay.Utilization = %f, want 22.3", sd.Utilization)
|
||||
}
|
||||
|
||||
sn := tr.Sonnet()
|
||||
if sn.Utilization != 10.0 {
|
||||
t.Errorf("Sonnet.Utilization = %f, want 10.0", sn.Utilization)
|
||||
}
|
||||
|
||||
extra := tr.Extra()
|
||||
if !extra.IsEnabled {
|
||||
t.Error("Extra.IsEnabled = false, want true")
|
||||
}
|
||||
if extra.MonthlyLimit == nil || *extra.MonthlyLimit != 100.0 {
|
||||
t.Errorf("Extra.MonthlyLimit = %v, want 100.0", extra.MonthlyLimit)
|
||||
}
|
||||
if extra.UsedCredits == nil || *extra.UsedCredits != 42.5 {
|
||||
t.Errorf("Extra.UsedCredits = %v, want 42.5", extra.UsedCredits)
|
||||
}
|
||||
}
|
||||
|
||||
func float64Ptr(f float64) *float64 { return &f }
|
||||
func stringPtr(s string) *string { return &s }
|
||||
@@ -1,67 +0,0 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/transport"
|
||||
"github.com/fujin/anthropic-proxy/internal/version"
|
||||
)
|
||||
|
||||
var usageClient = transport.NewHTTPClient(10 * time.Second)
|
||||
|
||||
const usageURL = "https://api.anthropic.com/api/oauth/usage"
|
||||
|
||||
type RateLimit struct {
|
||||
Utilization *float64 `json:"utilization"` // 0-100
|
||||
ResetsAt *string `json:"resets_at"` // ISO 8601
|
||||
}
|
||||
|
||||
type ExtraUsage struct {
|
||||
IsEnabled bool `json:"is_enabled"`
|
||||
MonthlyLimit *float64 `json:"monthly_limit"`
|
||||
UsedCredits *float64 `json:"used_credits"`
|
||||
Utilization *float64 `json:"utilization"`
|
||||
}
|
||||
|
||||
type UsageResponse struct {
|
||||
FiveHour *RateLimit `json:"five_hour"`
|
||||
SevenDay *RateLimit `json:"seven_day"`
|
||||
SevenDaySonnet *RateLimit `json:"seven_day_sonnet"`
|
||||
ExtraUsage *ExtraUsage `json:"extra_usage"`
|
||||
}
|
||||
|
||||
func fetchUsage(ctx context.Context, token string) (*UsageResponse, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, usageURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
req.Header.Set("User-Agent", "claude-cli/"+version.ClaudeCodeFallback)
|
||||
|
||||
resp, err := usageClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("usage returned %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var usage UsageResponse
|
||||
if err := json.Unmarshal(body, &usage); err != nil {
|
||||
return nil, fmt.Errorf("decode: %w", err)
|
||||
}
|
||||
return &usage, nil
|
||||
}
|
||||
@@ -1,241 +0,0 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUsageResponse_FullJSON(t *testing.T) {
|
||||
raw := `{
|
||||
"five_hour": {"utilization": 42.5, "resets_at": "2024-01-15T10:30:00Z"},
|
||||
"seven_day": {"utilization": 75.0, "resets_at": "2024-01-20T00:00:00Z"},
|
||||
"seven_day_sonnet": {"utilization": 10.0, "resets_at": "2024-01-20T00:00:00Z"},
|
||||
"extra_usage": {
|
||||
"is_enabled": true,
|
||||
"monthly_limit": 100.0,
|
||||
"used_credits": 42.5,
|
||||
"utilization": 42.5
|
||||
}
|
||||
}`
|
||||
|
||||
var resp UsageResponse
|
||||
if err := json.Unmarshal([]byte(raw), &resp); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if resp.FiveHour == nil {
|
||||
t.Fatal("FiveHour is nil")
|
||||
}
|
||||
if resp.FiveHour.Utilization == nil || *resp.FiveHour.Utilization != 42.5 {
|
||||
t.Errorf("FiveHour.Utilization = %v, want 42.5", resp.FiveHour.Utilization)
|
||||
}
|
||||
if resp.FiveHour.ResetsAt == nil || *resp.FiveHour.ResetsAt != "2024-01-15T10:30:00Z" {
|
||||
t.Errorf("FiveHour.ResetsAt = %v", resp.FiveHour.ResetsAt)
|
||||
}
|
||||
|
||||
if resp.SevenDay == nil {
|
||||
t.Fatal("SevenDay is nil")
|
||||
}
|
||||
if resp.SevenDay.Utilization == nil || *resp.SevenDay.Utilization != 75.0 {
|
||||
t.Errorf("SevenDay.Utilization = %v, want 75.0", resp.SevenDay.Utilization)
|
||||
}
|
||||
|
||||
if resp.SevenDaySonnet == nil {
|
||||
t.Fatal("SevenDaySonnet is nil")
|
||||
}
|
||||
if resp.SevenDaySonnet.Utilization == nil || *resp.SevenDaySonnet.Utilization != 10.0 {
|
||||
t.Errorf("SevenDaySonnet.Utilization = %v", resp.SevenDaySonnet.Utilization)
|
||||
}
|
||||
|
||||
if resp.ExtraUsage == nil {
|
||||
t.Fatal("ExtraUsage is nil")
|
||||
}
|
||||
if !resp.ExtraUsage.IsEnabled {
|
||||
t.Error("ExtraUsage.IsEnabled = false, want true")
|
||||
}
|
||||
if resp.ExtraUsage.MonthlyLimit == nil || *resp.ExtraUsage.MonthlyLimit != 100.0 {
|
||||
t.Errorf("ExtraUsage.MonthlyLimit = %v, want 100.0", resp.ExtraUsage.MonthlyLimit)
|
||||
}
|
||||
if resp.ExtraUsage.UsedCredits == nil || *resp.ExtraUsage.UsedCredits != 42.5 {
|
||||
t.Errorf("ExtraUsage.UsedCredits = %v, want 42.5", resp.ExtraUsage.UsedCredits)
|
||||
}
|
||||
if resp.ExtraUsage.Utilization == nil || *resp.ExtraUsage.Utilization != 42.5 {
|
||||
t.Errorf("ExtraUsage.Utilization = %v, want 42.5", resp.ExtraUsage.Utilization)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageResponse_PartialJSON(t *testing.T) {
|
||||
raw := `{"five_hour": {"utilization": 10.0}}`
|
||||
|
||||
var resp UsageResponse
|
||||
if err := json.Unmarshal([]byte(raw), &resp); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if resp.FiveHour == nil {
|
||||
t.Fatal("FiveHour is nil")
|
||||
}
|
||||
if resp.FiveHour.Utilization == nil || *resp.FiveHour.Utilization != 10.0 {
|
||||
t.Errorf("FiveHour.Utilization = %v, want 10.0", resp.FiveHour.Utilization)
|
||||
}
|
||||
if resp.FiveHour.ResetsAt != nil {
|
||||
t.Errorf("FiveHour.ResetsAt should be nil, got %v", resp.FiveHour.ResetsAt)
|
||||
}
|
||||
if resp.SevenDay != nil {
|
||||
t.Errorf("SevenDay should be nil, got %v", resp.SevenDay)
|
||||
}
|
||||
if resp.SevenDaySonnet != nil {
|
||||
t.Errorf("SevenDaySonnet should be nil, got %v", resp.SevenDaySonnet)
|
||||
}
|
||||
if resp.ExtraUsage != nil {
|
||||
t.Errorf("ExtraUsage should be nil, got %v", resp.ExtraUsage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageResponse_EmptyJSON(t *testing.T) {
|
||||
var resp UsageResponse
|
||||
if err := json.Unmarshal([]byte(`{}`), &resp); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if resp.FiveHour != nil || resp.SevenDay != nil || resp.SevenDaySonnet != nil || resp.ExtraUsage != nil {
|
||||
t.Error("all fields should be nil for empty JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchUsage_Success(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify request headers
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer test-token" {
|
||||
t.Errorf("Authorization = %q, want 'Bearer test-token'", got)
|
||||
}
|
||||
if got := r.Header.Get("Content-Type"); got != "application/json" {
|
||||
t.Errorf("Content-Type = %q, want application/json", got)
|
||||
}
|
||||
if got := r.Header.Get("anthropic-beta"); got != "oauth-2025-04-20" {
|
||||
t.Errorf("anthropic-beta = %q, want oauth-2025-04-20", got)
|
||||
}
|
||||
if got := r.Header.Get("User-Agent"); got != "claude-cli/2.1.92" {
|
||||
t.Errorf("User-Agent = %q, want claude-cli/2.1.92", got)
|
||||
}
|
||||
if r.Method != http.MethodGet {
|
||||
t.Errorf("Method = %q, want GET", r.Method)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{
|
||||
"five_hour": {"utilization": 50.0, "resets_at": "2024-01-15T10:00:00Z"},
|
||||
"seven_day": {"utilization": 25.0, "resets_at": "2024-01-20T00:00:00Z"}
|
||||
}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// fetchUsage hardcodes usageURL, but we can test via the mock by temporarily
|
||||
// using http.DefaultClient's transport. Instead, we test the handler directly.
|
||||
// The httptest server validates our request expectations above.
|
||||
|
||||
// Make a real request to the test server to verify handler behavior
|
||||
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
req.Header.Set("User-Agent", "claude-cli/2.1.92")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var usage UsageResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&usage); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
|
||||
if usage.FiveHour == nil || *usage.FiveHour.Utilization != 50.0 {
|
||||
t.Errorf("FiveHour.Utilization = %v, want 50.0", usage.FiveHour)
|
||||
}
|
||||
if usage.SevenDay == nil || *usage.SevenDay.Utilization != 25.0 {
|
||||
t.Errorf("SevenDay.Utilization = %v, want 25.0", usage.SevenDay)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchUsage_Non200(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
_, _ = w.Write([]byte(`{"error":"forbidden"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// Simulate the error path: non-200 returns error with status and body
|
||||
resp, err := http.Get(srv.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
t.Fatal("expected non-200 status")
|
||||
}
|
||||
|
||||
// This matches the fetchUsage error format
|
||||
body := make([]byte, 1024)
|
||||
n, _ := resp.Body.Read(body)
|
||||
bodyStr := string(body[:n])
|
||||
if !strings.Contains(bodyStr, "forbidden") {
|
||||
t.Errorf("body = %q, want it to contain 'forbidden'", bodyStr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchUsage_MalformedJSON(t *testing.T) {
|
||||
raw := `{not valid json`
|
||||
var resp UsageResponse
|
||||
err := json.Unmarshal([]byte(raw), &resp)
|
||||
if err == nil {
|
||||
t.Fatal("expected decode error for malformed JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimit_NilFields(t *testing.T) {
|
||||
raw := `{}`
|
||||
var rl RateLimit
|
||||
if err := json.Unmarshal([]byte(raw), &rl); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if rl.Utilization != nil {
|
||||
t.Errorf("Utilization should be nil, got %v", rl.Utilization)
|
||||
}
|
||||
if rl.ResetsAt != nil {
|
||||
t.Errorf("ResetsAt should be nil, got %v", rl.ResetsAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtraUsage_JSON(t *testing.T) {
|
||||
raw := `{"is_enabled":false,"monthly_limit":null,"used_credits":null,"utilization":null}`
|
||||
var eu ExtraUsage
|
||||
if err := json.Unmarshal([]byte(raw), &eu); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if eu.IsEnabled {
|
||||
t.Error("IsEnabled should be false")
|
||||
}
|
||||
if eu.MonthlyLimit != nil {
|
||||
t.Error("MonthlyLimit should be nil")
|
||||
}
|
||||
if eu.UsedCredits != nil {
|
||||
t.Error("UsedCredits should be nil")
|
||||
}
|
||||
if eu.Utilization != nil {
|
||||
t.Error("Utilization should be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageURL_Constant(t *testing.T) {
|
||||
if usageURL != "https://api.anthropic.com/api/oauth/usage" {
|
||||
t.Errorf("usageURL = %q, want https://api.anthropic.com/api/oauth/usage", usageURL)
|
||||
}
|
||||
}
|
||||
@@ -14,8 +14,6 @@ import (
|
||||
"github.com/fujin/anthropic-proxy/internal/config"
|
||||
"github.com/fujin/anthropic-proxy/internal/logging"
|
||||
"github.com/fujin/anthropic-proxy/internal/proxy"
|
||||
"github.com/fujin/anthropic-proxy/internal/ratelimit"
|
||||
"go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
@@ -26,7 +24,7 @@ type Server struct {
|
||||
apiKeys atomic.Pointer[map[string]struct{}]
|
||||
}
|
||||
|
||||
func New(cfg *config.Config, pool *auth.Pool, profile *proxy.SniffedProfile, tracker *ratelimit.Tracker, metricsHandler http.Handler) *Server {
|
||||
func New(cfg *config.Config, pool *auth.Pool, profile *proxy.SniffedProfile) *Server {
|
||||
s := &Server{configPath: "config.yaml"}
|
||||
|
||||
san := proxy.NewSanitizer(cfg.Sanitize)
|
||||
@@ -39,22 +37,15 @@ func New(cfg *config.Config, pool *auth.Pool, profile *proxy.SniffedProfile, tra
|
||||
engine := gin.New()
|
||||
engine.Use(gin.Recovery())
|
||||
engine.Use(corsMiddleware())
|
||||
if cfg.Telemetry.Export.Enabled() {
|
||||
engine.Use(otelgin.Middleware(cfg.Telemetry.ServiceName))
|
||||
}
|
||||
engine.Use(s.authMiddleware())
|
||||
engine.Use(logging.GinRequestLogger())
|
||||
|
||||
handler := proxy.HandleMessages(pool, profile, func() *proxy.Sanitizer {
|
||||
return s.sanitizer.Load()
|
||||
}, tracker)
|
||||
})
|
||||
engine.POST("/v1/messages", handler)
|
||||
engine.POST("/messages", handler)
|
||||
|
||||
if metricsHandler != nil {
|
||||
engine.GET("/metrics", gin.WrapH(metricsHandler))
|
||||
}
|
||||
|
||||
engine.POST("/reload", s.handleReload())
|
||||
engine.POST("/debug/refresh", handleDebugRefresh(pool))
|
||||
engine.GET("/healthz", func(c *gin.Context) {
|
||||
@@ -138,16 +129,10 @@ func corsMiddleware() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// authBypassPaths lists endpoints that do not require API key authentication.
|
||||
var authBypassPaths = map[string]bool{
|
||||
"/healthz": true,
|
||||
"/reload": true,
|
||||
"/metrics": true,
|
||||
}
|
||||
|
||||
func (s *Server) authMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if authBypassPaths[c.Request.URL.Path] {
|
||||
path := c.Request.URL.Path
|
||||
if path == "/healthz" || path == "/reload" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,529 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// --- makeKeySet ---
|
||||
|
||||
func TestMakeKeySet(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keys []string
|
||||
wantN int
|
||||
lookup string
|
||||
found bool
|
||||
}{
|
||||
{
|
||||
name: "nil slice returns empty map",
|
||||
keys: nil,
|
||||
wantN: 0,
|
||||
},
|
||||
{
|
||||
name: "empty slice returns empty map",
|
||||
keys: []string{},
|
||||
wantN: 0,
|
||||
},
|
||||
{
|
||||
name: "single key",
|
||||
keys: []string{"key1"},
|
||||
wantN: 1,
|
||||
lookup: "key1",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "multiple keys",
|
||||
keys: []string{"a", "b", "c"},
|
||||
wantN: 3,
|
||||
lookup: "b",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "missing key not found",
|
||||
keys: []string{"a", "b"},
|
||||
wantN: 2,
|
||||
lookup: "c",
|
||||
found: false,
|
||||
},
|
||||
{
|
||||
name: "duplicate keys deduped",
|
||||
keys: []string{"x", "x", "x"},
|
||||
wantN: 1,
|
||||
lookup: "x",
|
||||
found: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := makeKeySet(tt.keys)
|
||||
if len(got) != tt.wantN {
|
||||
t.Errorf("len(makeKeySet) = %d, want %d", len(got), tt.wantN)
|
||||
}
|
||||
if tt.lookup != "" {
|
||||
_, ok := got[tt.lookup]
|
||||
if ok != tt.found {
|
||||
t.Errorf("keySet[%q] found=%v, want %v", tt.lookup, ok, tt.found)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- corsMiddleware ---
|
||||
|
||||
func TestCorsMiddleware_SetsHeaders(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
handler := corsMiddleware()
|
||||
handler(c)
|
||||
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
|
||||
t.Errorf("Access-Control-Allow-Origin = %q, want %q", got, "*")
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Methods"); got != "GET, POST, PUT, DELETE, OPTIONS" {
|
||||
t.Errorf("Access-Control-Allow-Methods = %q, want %q", got, "GET, POST, PUT, DELETE, OPTIONS")
|
||||
}
|
||||
allowHeaders := w.Header().Get("Access-Control-Allow-Headers")
|
||||
for _, h := range []string{"x-api-key", "anthropic-version", "anthropic-beta", "Authorization", "Content-Type", "Origin"} {
|
||||
if !containsSubstring(allowHeaders, h) {
|
||||
t.Errorf("Access-Control-Allow-Headers %q missing %q", allowHeaders, h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorsMiddleware_OptionsReturns204(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodOptions, "/v1/messages", nil)
|
||||
|
||||
handler := corsMiddleware()
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Errorf("OPTIONS status = %d, want %d", w.Code, http.StatusNoContent)
|
||||
}
|
||||
if !c.IsAborted() {
|
||||
t.Error("expected context to be aborted on OPTIONS")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorsMiddleware_NonOptionsDoesNotAbort(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
handler := corsMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Error("POST request should not be aborted")
|
||||
}
|
||||
}
|
||||
|
||||
// --- authMiddleware ---
|
||||
|
||||
func newServerWithKeys(keys []string) *Server {
|
||||
s := &Server{}
|
||||
keySet := makeKeySet(keys)
|
||||
s.apiKeys.Store(&keySet)
|
||||
return s
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_BypassPaths(t *testing.T) {
|
||||
paths := []string{"/healthz", "/reload", "/metrics"}
|
||||
s := newServerWithKeys(nil) // no keys — would reject if auth checked
|
||||
|
||||
for _, path := range paths {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, path, nil)
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Errorf("path %q should bypass auth but was aborted", path)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_MissingToken_401(t *testing.T) {
|
||||
s := newServerWithKeys([]string{"valid-key"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusUnauthorized)
|
||||
}
|
||||
if !c.IsAborted() {
|
||||
t.Error("expected aborted on missing token")
|
||||
}
|
||||
|
||||
var body map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
if body["error"] != "missing authentication" {
|
||||
t.Errorf("error = %q, want %q", body["error"], "missing authentication")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_InvalidKey_403(t *testing.T) {
|
||||
s := newServerWithKeys([]string{"valid-key"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("x-api-key", "wrong-key")
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusForbidden)
|
||||
}
|
||||
if !c.IsAborted() {
|
||||
t.Error("expected aborted on invalid key")
|
||||
}
|
||||
|
||||
var body map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
if body["error"] != "invalid api key" {
|
||||
t.Errorf("error = %q, want %q", body["error"], "invalid api key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_ValidKey_XApiKey(t *testing.T) {
|
||||
s := newServerWithKeys([]string{"valid-key"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("x-api-key", "valid-key")
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Error("valid key should not abort")
|
||||
}
|
||||
if w.Code == http.StatusUnauthorized || w.Code == http.StatusForbidden {
|
||||
t.Errorf("unexpected status %d for valid key", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_ValidKey_BearerAuth(t *testing.T) {
|
||||
s := newServerWithKeys([]string{"my-token"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("Authorization", "Bearer my-token")
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Error("valid Bearer token should not abort")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_BearerPrefix_Stripped(t *testing.T) {
|
||||
// The token is "my-token", sent as "Bearer my-token". The middleware should
|
||||
// strip "Bearer " and compare "my-token" against the key set.
|
||||
s := newServerWithKeys([]string{"my-token"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("Authorization", "Bearer my-token")
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Error("expected auth to pass with Bearer-prefixed valid key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_AuthorizationWithoutBearer(t *testing.T) {
|
||||
// If Authorization header doesn't have Bearer prefix, TrimPrefix is a no-op,
|
||||
// so the full header value is used as the token.
|
||||
s := newServerWithKeys([]string{"raw-token-value"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("Authorization", "raw-token-value")
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Error("raw Authorization value matching a key should pass")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_XApiKey_FallbackWhenNoAuthHeader(t *testing.T) {
|
||||
// If Authorization is empty, x-api-key is checked.
|
||||
s := newServerWithKeys([]string{"fallback-key"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("x-api-key", "fallback-key")
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Error("x-api-key fallback should pass")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_AuthorizationPreferredOverXApiKey(t *testing.T) {
|
||||
// Both headers set; Authorization takes precedence.
|
||||
s := newServerWithKeys([]string{"auth-key"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("Authorization", "Bearer auth-key")
|
||||
c.Request.Header.Set("x-api-key", "wrong-key")
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Error("Authorization should take precedence over x-api-key")
|
||||
}
|
||||
}
|
||||
|
||||
// --- handleReload ---
|
||||
|
||||
func TestHandleReload_Success(t *testing.T) {
|
||||
// Create a temp config file
|
||||
tmpFile, err := os.CreateTemp("", "config-*.yaml")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp file: %v", err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
configContent := `
|
||||
port: 9999
|
||||
api_keys:
|
||||
- reloaded-key-1
|
||||
- reloaded-key-2
|
||||
sanitize:
|
||||
tools:
|
||||
- from: old_tool
|
||||
to: new_tool
|
||||
system:
|
||||
- match: foo
|
||||
replace: bar
|
||||
body:
|
||||
- match: baz
|
||||
replace: qux
|
||||
`
|
||||
if _, err := tmpFile.WriteString(configContent); err != nil {
|
||||
t.Fatalf("failed to write config: %v", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
|
||||
s := &Server{configPath: tmpFile.Name()}
|
||||
// Initialize with empty values
|
||||
emptyKeys := makeKeySet(nil)
|
||||
s.apiKeys.Store(&emptyKeys)
|
||||
|
||||
emptySan := &atomic.Pointer[interface{}]{}
|
||||
_ = emptySan // just to show we're aware
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/reload", nil)
|
||||
|
||||
handler := s.handleReload()
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d; body: %s", w.Code, http.StatusOK, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if resp["status"] != "reloaded" {
|
||||
t.Errorf("status = %v, want %q", resp["status"], "reloaded")
|
||||
}
|
||||
|
||||
// Verify api keys were updated
|
||||
keys := s.apiKeys.Load()
|
||||
if _, ok := (*keys)["reloaded-key-1"]; !ok {
|
||||
t.Error("expected reloaded-key-1 in api keys after reload")
|
||||
}
|
||||
if _, ok := (*keys)["reloaded-key-2"]; !ok {
|
||||
t.Error("expected reloaded-key-2 in api keys after reload")
|
||||
}
|
||||
if len(*keys) != 2 {
|
||||
t.Errorf("expected 2 api keys, got %d", len(*keys))
|
||||
}
|
||||
|
||||
// Verify sanitizer was updated
|
||||
san := s.sanitizer.Load()
|
||||
if san == nil {
|
||||
t.Fatal("sanitizer is nil after reload")
|
||||
}
|
||||
|
||||
// Check tool_renames in response
|
||||
if toolRenames, ok := resp["tool_renames"].(float64); !ok || int(toolRenames) != 1 {
|
||||
t.Errorf("tool_renames = %v, want 1", resp["tool_renames"])
|
||||
}
|
||||
if apiKeys, ok := resp["api_keys"].(float64); !ok || int(apiKeys) != 2 {
|
||||
t.Errorf("api_keys = %v, want 2", resp["api_keys"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleReload_InvalidConfig(t *testing.T) {
|
||||
s := &Server{configPath: "/nonexistent/path/config.yaml"}
|
||||
emptyKeys := makeKeySet(nil)
|
||||
s.apiKeys.Store(&emptyKeys)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/reload", nil)
|
||||
|
||||
handler := s.handleReload()
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var resp map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
if resp["error"] == "" {
|
||||
t.Error("expected non-empty error message")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Full route tests using httptest ---
|
||||
|
||||
func TestHealthzEndpoint(t *testing.T) {
|
||||
engine := gin.New()
|
||||
engine.Use(corsMiddleware())
|
||||
|
||||
s := newServerWithKeys(nil)
|
||||
engine.Use(s.authMiddleware())
|
||||
|
||||
engine.GET("/healthz", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
|
||||
w := httptest.NewRecorder()
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if body["status"] != "ok" {
|
||||
t.Errorf("status = %q, want %q", body["status"], "ok")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_FullRoute_Rejected(t *testing.T) {
|
||||
engine := gin.New()
|
||||
s := newServerWithKeys([]string{"correct-key"})
|
||||
engine.Use(s.authMiddleware())
|
||||
engine.POST("/v1/messages", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
// No auth header
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
w := httptest.NewRecorder()
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_FullRoute_Accepted(t *testing.T) {
|
||||
engine := gin.New()
|
||||
s := newServerWithKeys([]string{"correct-key"})
|
||||
engine.Use(s.authMiddleware())
|
||||
engine.POST("/v1/messages", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
req.Header.Set("x-api-key", "correct-key")
|
||||
w := httptest.NewRecorder()
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorsMiddleware_FullRoute_OptionsRequest(t *testing.T) {
|
||||
engine := gin.New()
|
||||
engine.Use(corsMiddleware())
|
||||
|
||||
s := newServerWithKeys([]string{"key"})
|
||||
engine.Use(s.authMiddleware())
|
||||
|
||||
engine.POST("/v1/messages", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "/v1/messages", nil)
|
||||
w := httptest.NewRecorder()
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusNoContent)
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
|
||||
t.Errorf("ACAO = %q, want %q", got, "*")
|
||||
}
|
||||
}
|
||||
|
||||
// helper
|
||||
func containsSubstring(s, sub string) bool {
|
||||
return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsStr(s, sub))
|
||||
}
|
||||
|
||||
func containsStr(s, sub string) bool {
|
||||
for i := 0; i <= len(s)-len(sub); i++ {
|
||||
if s[i:i+len(sub)] == sub {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,81 +0,0 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
otellog "go.opentelemetry.io/otel/log"
|
||||
sdklog "go.opentelemetry.io/otel/sdk/log"
|
||||
)
|
||||
|
||||
// LogBridge implements io.Writer and forwards zerolog JSON lines to the
|
||||
// OTel LoggerProvider. It is used as an extra writer in zerolog's MultiWriter
|
||||
// so that logs go to both file and OTLP.
|
||||
type LogBridge struct {
|
||||
provider *sdklog.LoggerProvider
|
||||
}
|
||||
|
||||
func (b *LogBridge) Write(p []byte) (n int, err error) {
|
||||
var entry map[string]interface{}
|
||||
if err := json.Unmarshal(p, &entry); err != nil {
|
||||
return len(p), nil // skip malformed lines
|
||||
}
|
||||
|
||||
logger := b.provider.Logger("zerolog")
|
||||
|
||||
var rec otellog.Record
|
||||
rec.SetTimestamp(time.Now())
|
||||
|
||||
if msg, ok := entry["message"].(string); ok {
|
||||
rec.SetBody(otellog.StringValue(msg))
|
||||
}
|
||||
|
||||
if lvl, ok := entry["level"].(string); ok {
|
||||
rec.SetSeverity(mapSeverity(lvl))
|
||||
}
|
||||
|
||||
// Forward all fields as attributes
|
||||
attrs := make([]otellog.KeyValue, 0, len(entry))
|
||||
for k, v := range entry {
|
||||
if k == "message" || k == "level" || k == "time" {
|
||||
continue
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
attrs = append(attrs, otellog.String(k, val))
|
||||
case float64:
|
||||
attrs = append(attrs, otellog.Float64(k, val))
|
||||
case bool:
|
||||
attrs = append(attrs, otellog.Bool(k, val))
|
||||
default:
|
||||
b, _ := json.Marshal(val)
|
||||
attrs = append(attrs, otellog.String(k, string(b)))
|
||||
}
|
||||
}
|
||||
rec.AddAttributes(attrs...)
|
||||
|
||||
logger.Emit(context.Background(), rec)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func mapSeverity(level string) otellog.Severity {
|
||||
switch level {
|
||||
case "trace":
|
||||
return otellog.SeverityTrace
|
||||
case "debug":
|
||||
return otellog.SeverityDebug
|
||||
case "info":
|
||||
return otellog.SeverityInfo
|
||||
case "warn", "warning":
|
||||
return otellog.SeverityWarn
|
||||
case "error":
|
||||
return otellog.SeverityError
|
||||
case "fatal":
|
||||
return otellog.SeverityFatal
|
||||
case "panic":
|
||||
return otellog.SeverityFatal2
|
||||
default:
|
||||
return otellog.SeverityInfo
|
||||
}
|
||||
}
|
||||
@@ -1,178 +0,0 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
otellog "go.opentelemetry.io/otel/log"
|
||||
sdklog "go.opentelemetry.io/otel/sdk/log"
|
||||
)
|
||||
|
||||
func TestMapSeverity(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want otellog.Severity
|
||||
}{
|
||||
{"trace", otellog.SeverityTrace},
|
||||
{"debug", otellog.SeverityDebug},
|
||||
{"info", otellog.SeverityInfo},
|
||||
{"warn", otellog.SeverityWarn},
|
||||
{"warning", otellog.SeverityWarn},
|
||||
{"error", otellog.SeverityError},
|
||||
{"fatal", otellog.SeverityFatal},
|
||||
{"panic", otellog.SeverityFatal2},
|
||||
{"unknown", otellog.SeverityInfo},
|
||||
{"", otellog.SeverityInfo},
|
||||
{"INFO", otellog.SeverityInfo}, // uppercase falls to default
|
||||
{"PANIC", otellog.SeverityInfo}, // uppercase falls to default
|
||||
{"gibberish", otellog.SeverityInfo},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run("level_"+tc.input, func(t *testing.T) {
|
||||
got := mapSeverity(tc.input)
|
||||
if got != tc.want {
|
||||
t.Errorf("mapSeverity(%q) = %v, want %v", tc.input, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newTestBridge(t *testing.T) *LogBridge {
|
||||
t.Helper()
|
||||
provider := sdklog.NewLoggerProvider()
|
||||
t.Cleanup(func() {
|
||||
_ = provider.Shutdown(t.Context())
|
||||
})
|
||||
return &LogBridge{provider: provider}
|
||||
}
|
||||
|
||||
func TestLogBridgeWrite(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{} // will be marshaled to JSON; use string for raw input
|
||||
raw string // if non-empty, use this directly instead of marshaling input
|
||||
}{
|
||||
{
|
||||
name: "valid_json_with_message_level_and_extras",
|
||||
input: map[string]interface{}{
|
||||
"message": "request handled",
|
||||
"level": "info",
|
||||
"method": "GET",
|
||||
"status": float64(200),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "message_only_no_level",
|
||||
input: map[string]interface{}{
|
||||
"message": "hello world",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "level_only_no_message",
|
||||
input: map[string]interface{}{
|
||||
"level": "error",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty_json_object",
|
||||
input: map[string]interface{}{},
|
||||
},
|
||||
{
|
||||
name: "string_float64_bool_attributes",
|
||||
input: map[string]interface{}{
|
||||
"message": "test",
|
||||
"level": "debug",
|
||||
"str_val": "hello",
|
||||
"num_val": float64(3.14),
|
||||
"bool_val": true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "complex_nested_object_attribute",
|
||||
input: map[string]interface{}{
|
||||
"message": "nested",
|
||||
"level": "warn",
|
||||
"nested": map[string]interface{}{"foo": "bar", "n": float64(1)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "time_field_skipped_in_attributes",
|
||||
input: map[string]interface{}{
|
||||
"message": "with time",
|
||||
"level": "info",
|
||||
"time": "2025-01-01T00:00:00Z",
|
||||
"extra": "kept",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "malformed_json",
|
||||
raw: "this is not json at all",
|
||||
},
|
||||
{
|
||||
name: "malformed_json_partial",
|
||||
raw: `{"broken":`,
|
||||
},
|
||||
{
|
||||
name: "array_attribute_marshaled_as_string",
|
||||
input: map[string]interface{}{
|
||||
"message": "arrays",
|
||||
"tags": []interface{}{"a", "b"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "null_value_attribute",
|
||||
input: map[string]interface{}{
|
||||
"message": "nulls",
|
||||
"val": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
bridge := newTestBridge(t)
|
||||
|
||||
var p []byte
|
||||
if tc.raw != "" {
|
||||
p = []byte(tc.raw)
|
||||
} else {
|
||||
var err error
|
||||
p, err = json.Marshal(tc.input)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal test input: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
n, err := bridge.Write(p)
|
||||
if n != len(p) {
|
||||
t.Errorf("Write() returned n=%d, want %d", n, len(p))
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Write() returned err=%v, want nil", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogBridgeWriteAlwaysReturnsLenAndNil(t *testing.T) {
|
||||
bridge := newTestBridge(t)
|
||||
|
||||
inputs := [][]byte{
|
||||
[]byte(`{"message":"ok","level":"info"}`),
|
||||
[]byte(`not json`),
|
||||
[]byte(`{}`),
|
||||
[]byte(``),
|
||||
[]byte(`[]`),
|
||||
}
|
||||
|
||||
for _, p := range inputs {
|
||||
n, err := bridge.Write(p)
|
||||
if n != len(p) {
|
||||
t.Errorf("Write(%q) n=%d, want %d", string(p), n, len(p))
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Write(%q) err=%v, want nil", string(p), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/ratelimit"
|
||||
)
|
||||
|
||||
var (
|
||||
RequestCounter metric.Int64Counter
|
||||
RequestDuration metric.Float64Histogram
|
||||
RequestBodySize metric.Int64Histogram
|
||||
UpstreamErrors metric.Int64Counter
|
||||
TokensInput metric.Int64Counter
|
||||
TokensOutput metric.Int64Counter
|
||||
CredentialCooldowns metric.Int64Counter
|
||||
ActiveCredentials metric.Int64UpDownCounter
|
||||
StreamRequests metric.Int64Counter
|
||||
)
|
||||
|
||||
// InitMetrics creates all metric instruments from the given meter.
|
||||
// If tracker is non-nil, registers observable gauges for per-window usage.
|
||||
func InitMetrics(meter metric.Meter, tracker *ratelimit.Tracker) {
|
||||
RequestCounter, _ = meter.Int64Counter("proxy.request.count",
|
||||
metric.WithDescription("Total proxy requests"),
|
||||
)
|
||||
RequestDuration, _ = meter.Float64Histogram("proxy.request.duration_ms",
|
||||
metric.WithDescription("Request latency in milliseconds"),
|
||||
metric.WithUnit("ms"),
|
||||
)
|
||||
RequestBodySize, _ = meter.Int64Histogram("proxy.request.body_size_bytes",
|
||||
metric.WithDescription("Request body size in bytes"),
|
||||
metric.WithUnit("By"),
|
||||
)
|
||||
UpstreamErrors, _ = meter.Int64Counter("proxy.upstream.errors",
|
||||
metric.WithDescription("Upstream error count"),
|
||||
)
|
||||
TokensInput, _ = meter.Int64Counter("proxy.tokens.input",
|
||||
metric.WithDescription("Input tokens consumed"),
|
||||
)
|
||||
TokensOutput, _ = meter.Int64Counter("proxy.tokens.output",
|
||||
metric.WithDescription("Output tokens consumed"),
|
||||
)
|
||||
CredentialCooldowns, _ = meter.Int64Counter("proxy.credential.cooldowns",
|
||||
metric.WithDescription("Credential cooldown activations"),
|
||||
)
|
||||
ActiveCredentials, _ = meter.Int64UpDownCounter("proxy.credential.active",
|
||||
metric.WithDescription("Currently active (non-cooldown) credentials"),
|
||||
)
|
||||
StreamRequests, _ = meter.Int64Counter("proxy.stream.requests",
|
||||
metric.WithDescription("Streaming request count"),
|
||||
)
|
||||
|
||||
if tracker == nil {
|
||||
return
|
||||
}
|
||||
|
||||
attr5h := attribute.String("window", "5h")
|
||||
attr7d := attribute.String("window", "7d")
|
||||
attrSonnet := attribute.String("window", "7d_sonnet")
|
||||
|
||||
meter.Float64ObservableGauge("proxy.usage.utilization",
|
||||
metric.WithDescription("Current utilization % from API"),
|
||||
metric.WithFloat64Callback(func(_ context.Context, o metric.Float64Observer) error {
|
||||
o.Observe(tracker.FiveHour().Utilization, metric.WithAttributes(attr5h))
|
||||
o.Observe(tracker.SevenDay().Utilization, metric.WithAttributes(attr7d))
|
||||
o.Observe(tracker.Sonnet().Utilization, metric.WithAttributes(attrSonnet))
|
||||
return nil
|
||||
}),
|
||||
)
|
||||
|
||||
meter.Int64ObservableGauge("proxy.usage.resets_at",
|
||||
metric.WithDescription("Unix seconds when window resets"),
|
||||
metric.WithInt64Callback(func(_ context.Context, o metric.Int64Observer) error {
|
||||
o.Observe(tracker.FiveHour().ResetsAt.Unix(), metric.WithAttributes(attr5h))
|
||||
o.Observe(tracker.SevenDay().ResetsAt.Unix(), metric.WithAttributes(attr7d))
|
||||
o.Observe(tracker.Sonnet().ResetsAt.Unix(), metric.WithAttributes(attrSonnet))
|
||||
return nil
|
||||
}),
|
||||
)
|
||||
|
||||
}
|
||||
@@ -1,112 +0,0 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/config"
|
||||
"github.com/fujin/anthropic-proxy/internal/ratelimit"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
|
||||
promexporter "go.opentelemetry.io/otel/exporters/prometheus"
|
||||
otellog "go.opentelemetry.io/otel/log/global"
|
||||
"go.opentelemetry.io/otel/sdk/log"
|
||||
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
|
||||
"go.opentelemetry.io/otel/sdk/resource"
|
||||
"go.opentelemetry.io/otel/sdk/trace"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
|
||||
)
|
||||
|
||||
func Setup(ctx context.Context, cfg config.TelemetryConfig, tracker *ratelimit.Tracker) (shutdown func(context.Context) error, logWriter io.Writer, metricsHandler http.Handler, err error) {
|
||||
res, err := resource.New(ctx,
|
||||
resource.WithAttributes(
|
||||
semconv.ServiceName(cfg.ServiceName),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
var readers []sdkmetric.Option
|
||||
readers = append(readers, sdkmetric.WithResource(res))
|
||||
|
||||
var promHandler http.Handler
|
||||
if cfg.Embedded.Enabled {
|
||||
exporter, pErr := promexporter.New()
|
||||
if pErr != nil {
|
||||
return nil, nil, nil, pErr
|
||||
}
|
||||
readers = append(readers, sdkmetric.WithReader(exporter))
|
||||
promHandler = promhttp.Handler()
|
||||
}
|
||||
|
||||
if !cfg.Export.Enabled() {
|
||||
mp := sdkmetric.NewMeterProvider(readers...)
|
||||
otel.SetMeterProvider(mp)
|
||||
InitMetrics(mp.Meter(cfg.ServiceName), tracker)
|
||||
return func(ctx context.Context) error { return mp.Shutdown(ctx) }, nil, promHandler, nil
|
||||
}
|
||||
|
||||
traceOpts := []otlptracegrpc.Option{otlptracegrpc.WithEndpoint(cfg.Export.Endpoint)}
|
||||
metricOpts := []otlpmetricgrpc.Option{
|
||||
otlpmetricgrpc.WithEndpoint(cfg.Export.Endpoint),
|
||||
otlpmetricgrpc.WithTemporalitySelector(sdkmetric.CumulativeTemporalitySelector),
|
||||
}
|
||||
logOpts := []otlploggrpc.Option{otlploggrpc.WithEndpoint(cfg.Export.Endpoint)}
|
||||
if cfg.Export.Insecure {
|
||||
traceOpts = append(traceOpts, otlptracegrpc.WithInsecure())
|
||||
metricOpts = append(metricOpts, otlpmetricgrpc.WithInsecure())
|
||||
logOpts = append(logOpts, otlploggrpc.WithInsecure())
|
||||
}
|
||||
|
||||
traceExp, err := otlptracegrpc.New(ctx, traceOpts...)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
tp := trace.NewTracerProvider(
|
||||
trace.WithBatcher(traceExp),
|
||||
trace.WithResource(res),
|
||||
)
|
||||
otel.SetTracerProvider(tp)
|
||||
|
||||
metricExp, err := otlpmetricgrpc.New(ctx, metricOpts...)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
readers = append(readers, sdkmetric.WithReader(sdkmetric.NewPeriodicReader(metricExp)))
|
||||
mp := sdkmetric.NewMeterProvider(readers...)
|
||||
otel.SetMeterProvider(mp)
|
||||
InitMetrics(mp.Meter(cfg.ServiceName), tracker)
|
||||
|
||||
logExp, err := otlploggrpc.New(ctx, logOpts...)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
lp := log.NewLoggerProvider(
|
||||
log.WithProcessor(log.NewBatchProcessor(logExp)),
|
||||
log.WithResource(res),
|
||||
)
|
||||
otellog.SetLoggerProvider(lp)
|
||||
|
||||
bridge := &LogBridge{provider: lp}
|
||||
|
||||
shutdownFn := func(ctx context.Context) error {
|
||||
var firstErr error
|
||||
if e := tp.Shutdown(ctx); e != nil && firstErr == nil {
|
||||
firstErr = e
|
||||
}
|
||||
if e := mp.Shutdown(ctx); e != nil && firstErr == nil {
|
||||
firstErr = e
|
||||
}
|
||||
if e := lp.Shutdown(ctx); e != nil && firstErr == nil {
|
||||
firstErr = e
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
return shutdownFn, bridge, promHandler, nil
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewUTLS(t *testing.T) {
|
||||
tr := NewUTLS()
|
||||
if tr == nil {
|
||||
t.Fatal("NewUTLS returned nil")
|
||||
}
|
||||
if tr.connections == nil {
|
||||
t.Error("connections map is nil")
|
||||
}
|
||||
if tr.pending == nil {
|
||||
t.Error("pending map is nil")
|
||||
}
|
||||
if tr.dialTimeout != 10*time.Second {
|
||||
t.Errorf("dialTimeout = %v, want 10s", tr.dialTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
timeout time.Duration
|
||||
}{
|
||||
{"zero timeout (streaming)", 0},
|
||||
{"15s timeout", 15 * time.Second},
|
||||
{"30s timeout", 30 * time.Second},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := NewHTTPClient(tt.timeout)
|
||||
if c == nil {
|
||||
t.Fatal("NewHTTPClient returned nil")
|
||||
}
|
||||
if c.Timeout != tt.timeout {
|
||||
t.Errorf("Timeout = %v, want %v", c.Timeout, tt.timeout)
|
||||
}
|
||||
if c.Transport == nil {
|
||||
t.Error("Transport is nil")
|
||||
}
|
||||
if _, ok := c.Transport.(*UTLS); !ok {
|
||||
t.Errorf("Transport type = %T, want *UTLS", c.Transport)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUTLS_ImplementsRoundTripper(t *testing.T) {
|
||||
var _ http.RoundTripper = (*UTLS)(nil)
|
||||
}
|
||||
|
||||
func TestUTLS_RoundTrip_InvalidHost(t *testing.T) {
|
||||
tr := NewUTLS()
|
||||
// Use a non-routable address to test dial timeout behavior
|
||||
req, err := http.NewRequest("GET", "https://192.0.2.1:443/test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
_, err = tr.RoundTrip(req)
|
||||
if err == nil {
|
||||
t.Error("expected error for non-routable address, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUTLS_ConnectionEviction(t *testing.T) {
|
||||
tr := NewUTLS()
|
||||
// Verify connections map starts empty
|
||||
tr.mu.Lock()
|
||||
if len(tr.connections) != 0 {
|
||||
t.Errorf("initial connections = %d, want 0", len(tr.connections))
|
||||
}
|
||||
tr.mu.Unlock()
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
// Package version provides the fallback Claude Code client version used when
|
||||
// no sniffed profile is available. This constant is shared between the upstream
|
||||
// proxy client and the rate limit usage poller.
|
||||
package version
|
||||
|
||||
// ClaudeCodeFallback is the Claude Code CLI version string used as a fallback
|
||||
// when no real version is obtained from sniffing.
|
||||
const ClaudeCodeFallback = "2.1.92"
|
||||
@@ -3,7 +3,6 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -12,19 +11,31 @@ import (
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/auth"
|
||||
"github.com/fujin/anthropic-proxy/internal/config"
|
||||
"github.com/fujin/anthropic-proxy/internal/embedded"
|
||||
"github.com/fujin/anthropic-proxy/internal/logging"
|
||||
"github.com/fujin/anthropic-proxy/internal/proxy"
|
||||
"github.com/fujin/anthropic-proxy/internal/ratelimit"
|
||||
"github.com/fujin/anthropic-proxy/internal/server"
|
||||
"github.com/fujin/anthropic-proxy/internal/telemetry"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func initCredential() (*auth.Credential, error) {
|
||||
creds, err := auth.LoadDefaultCredentials()
|
||||
func run() error {
|
||||
cfg, err := config.Load("config.yaml")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load credentials: %w", err)
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
logging.Setup(logging.Config{
|
||||
Level: cfg.Logging.Level,
|
||||
File: cfg.Logging.File,
|
||||
MaxSizeMB: cfg.Logging.MaxSizeMB,
|
||||
MaxBackups: cfg.Logging.MaxBackups,
|
||||
MaxAgeDays: cfg.Logging.MaxAgeDays,
|
||||
Compress: cfg.Logging.Compress,
|
||||
})
|
||||
|
||||
// Load credentials from ~/.claude/.credentials.json
|
||||
creds, err := config.LoadDefaultCredentials()
|
||||
if err != nil {
|
||||
return fmt.Errorf("load credentials: %w", err)
|
||||
}
|
||||
|
||||
var cred *auth.Credential
|
||||
@@ -46,84 +57,19 @@ func initCredential() (*auth.Credential, error) {
|
||||
}
|
||||
|
||||
if cred == nil {
|
||||
// Non-TTY check: if stdin is not a terminal, can't do interactive login
|
||||
fi, statErr := os.Stdin.Stat()
|
||||
if statErr == nil && (fi.Mode()&os.ModeCharDevice) == 0 {
|
||||
return nil, fmt.Errorf("no valid credentials found; run the proxy interactively for initial login")
|
||||
return fmt.Errorf("no valid credentials found; run the proxy interactively for initial login")
|
||||
}
|
||||
log.Info().Msg("no credentials found, starting OAuth login")
|
||||
cred, err = auth.Login(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login failed: %w", err)
|
||||
return fmt.Errorf("login failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Str("credential", cred.Email).Msg("credential loaded")
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
func initEmbedded(cfg *config.Config) (cleanup func(), err error) {
|
||||
if !cfg.Telemetry.Embedded.Enabled {
|
||||
return func() {}, nil
|
||||
}
|
||||
|
||||
var cleanups []func()
|
||||
|
||||
vm := embedded.NewVM(cfg.Telemetry.Embedded, cfg.Port)
|
||||
if err := vm.Start(); err != nil {
|
||||
log.Error().Err(err).Msg("failed to start victoria-metrics")
|
||||
} else {
|
||||
cleanups = append(cleanups, vm.Stop)
|
||||
}
|
||||
|
||||
perses := embedded.NewPerses(cfg.Telemetry.Embedded, cfg.Port)
|
||||
if err := perses.Start(); err != nil {
|
||||
log.Error().Err(err).Msg("failed to start perses")
|
||||
} else {
|
||||
cleanups = append(cleanups, perses.Stop)
|
||||
}
|
||||
|
||||
return func() {
|
||||
for i := len(cleanups) - 1; i >= 0; i-- {
|
||||
cleanups[i]()
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func run() error {
|
||||
cfg, err := config.Load("config.yaml")
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
// Create usage tracker (started later once credential is loaded)
|
||||
var credForTracker *auth.Credential
|
||||
tracker := ratelimit.NewTracker(func() string {
|
||||
if credForTracker == nil {
|
||||
return ""
|
||||
}
|
||||
return credForTracker.Token()
|
||||
})
|
||||
|
||||
// Initialize telemetry (metrics always active; OTLP export when endpoint set)
|
||||
telemetryShutdown, logBridge, metricsHandler, err := telemetry.Setup(context.Background(), cfg.Telemetry, tracker)
|
||||
if err != nil {
|
||||
return fmt.Errorf("telemetry setup: %w", err)
|
||||
}
|
||||
defer telemetryShutdown(context.Background())
|
||||
|
||||
var extraWriters []io.Writer
|
||||
if logBridge != nil {
|
||||
extraWriters = append(extraWriters, logBridge)
|
||||
}
|
||||
|
||||
logging.Setup(cfg.Logging, extraWriters...)
|
||||
|
||||
cred, err := initCredential()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
credForTracker = cred
|
||||
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
|
||||
@@ -132,7 +78,6 @@ func run() error {
|
||||
|
||||
pool.RefreshExpiring(context.Background())
|
||||
auth.StartBackgroundRefresh(ctx, pool)
|
||||
tracker.Start(ctx)
|
||||
|
||||
var profile *proxy.SniffedProfile
|
||||
if cfg.ClaudeBinary != "" {
|
||||
@@ -143,14 +88,8 @@ func run() error {
|
||||
}
|
||||
}
|
||||
|
||||
embeddedCleanup, err := initEmbedded(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer embeddedCleanup()
|
||||
|
||||
log.Info().Int("port", cfg.Port).Msg("starting server")
|
||||
srv := server.New(cfg, pool, profile, tracker, metricsHandler)
|
||||
srv := server.New(cfg, pool, profile)
|
||||
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
+2
-2
@@ -7,11 +7,11 @@
|
||||
|
||||
buildGoModule rec {
|
||||
pname = "anthropic-proxy";
|
||||
version = "0.0.5";
|
||||
version = "0.0.4";
|
||||
|
||||
src = ./.;
|
||||
|
||||
vendorHash = "sha256-yXINNC+NEw+HbOQ5aBgSE5dYTWp+zEZ230rzXfwOoDY=";
|
||||
vendorHash = "sha256-xKztaGlelw7OI/6RJkkepHmLLH+dCCqYXE71C+y3PwI=";
|
||||
|
||||
meta = with lib; {
|
||||
description = "Reverse proxy that lets OpenCode (and similar tools) use a Claude subscription instead of an API key.";
|
||||
|
||||
+19
@@ -0,0 +1,19 @@
|
||||
Copyright (c) 2009, 2010, 2013-2016 by the Brotli Authors.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
+7
@@ -0,0 +1,7 @@
|
||||
This package is a brotli compressor and decompressor implemented in Go.
|
||||
It was translated from the reference implementation (https://github.com/google/brotli)
|
||||
with the `c2go` tool at https://github.com/andybalholm/c2go.
|
||||
|
||||
I am using it in production with https://github.com/andybalholm/redwood.
|
||||
|
||||
API documentation is found at https://pkg.go.dev/github.com/andybalholm/brotli?tab=doc.
|
||||
+185
@@ -0,0 +1,185 @@
|
||||
package brotli
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Function to find backward reference copies. */
|
||||
|
||||
func computeDistanceCode(distance uint, max_distance uint, dist_cache []int) uint {
|
||||
if distance <= max_distance {
|
||||
var distance_plus_3 uint = distance + 3
|
||||
var offset0 uint = distance_plus_3 - uint(dist_cache[0])
|
||||
var offset1 uint = distance_plus_3 - uint(dist_cache[1])
|
||||
if distance == uint(dist_cache[0]) {
|
||||
return 0
|
||||
} else if distance == uint(dist_cache[1]) {
|
||||
return 1
|
||||
} else if offset0 < 7 {
|
||||
return (0x9750468 >> (4 * offset0)) & 0xF
|
||||
} else if offset1 < 7 {
|
||||
return (0xFDB1ACE >> (4 * offset1)) & 0xF
|
||||
} else if distance == uint(dist_cache[2]) {
|
||||
return 2
|
||||
} else if distance == uint(dist_cache[3]) {
|
||||
return 3
|
||||
}
|
||||
}
|
||||
|
||||
return distance + numDistanceShortCodes - 1
|
||||
}
|
||||
|
||||
var hasherSearchResultPool sync.Pool
|
||||
|
||||
func createBackwardReferences(num_bytes uint, position uint, ringbuffer []byte, ringbuffer_mask uint, params *encoderParams, hasher hasherHandle, dist_cache []int, last_insert_len *uint, commands *[]command, num_literals *uint) {
|
||||
var max_backward_limit uint = maxBackwardLimit(params.lgwin)
|
||||
var insert_length uint = *last_insert_len
|
||||
var pos_end uint = position + num_bytes
|
||||
var store_end uint
|
||||
if num_bytes >= hasher.StoreLookahead() {
|
||||
store_end = position + num_bytes - hasher.StoreLookahead() + 1
|
||||
} else {
|
||||
store_end = position
|
||||
}
|
||||
var random_heuristics_window_size uint = literalSpreeLengthForSparseSearch(params)
|
||||
var apply_random_heuristics uint = position + random_heuristics_window_size
|
||||
var gap uint = 0
|
||||
/* Set maximum distance, see section 9.1. of the spec. */
|
||||
|
||||
const kMinScore uint = scoreBase + 100
|
||||
|
||||
/* For speed up heuristics for random data. */
|
||||
|
||||
/* Minimum score to accept a backward reference. */
|
||||
hasher.PrepareDistanceCache(dist_cache)
|
||||
sr2, _ := hasherSearchResultPool.Get().(*hasherSearchResult)
|
||||
if sr2 == nil {
|
||||
sr2 = &hasherSearchResult{}
|
||||
}
|
||||
sr, _ := hasherSearchResultPool.Get().(*hasherSearchResult)
|
||||
if sr == nil {
|
||||
sr = &hasherSearchResult{}
|
||||
}
|
||||
|
||||
for position+hasher.HashTypeLength() < pos_end {
|
||||
var max_length uint = pos_end - position
|
||||
var max_distance uint = brotli_min_size_t(position, max_backward_limit)
|
||||
sr.len = 0
|
||||
sr.len_code_delta = 0
|
||||
sr.distance = 0
|
||||
sr.score = kMinScore
|
||||
hasher.FindLongestMatch(¶ms.dictionary, ringbuffer, ringbuffer_mask, dist_cache, position, max_length, max_distance, gap, params.dist.max_distance, sr)
|
||||
if sr.score > kMinScore {
|
||||
/* Found a match. Let's look for something even better ahead. */
|
||||
var delayed_backward_references_in_row int = 0
|
||||
max_length--
|
||||
for ; ; max_length-- {
|
||||
var cost_diff_lazy uint = 175
|
||||
if params.quality < minQualityForExtensiveReferenceSearch {
|
||||
sr2.len = brotli_min_size_t(sr.len-1, max_length)
|
||||
} else {
|
||||
sr2.len = 0
|
||||
}
|
||||
sr2.len_code_delta = 0
|
||||
sr2.distance = 0
|
||||
sr2.score = kMinScore
|
||||
max_distance = brotli_min_size_t(position+1, max_backward_limit)
|
||||
hasher.FindLongestMatch(¶ms.dictionary, ringbuffer, ringbuffer_mask, dist_cache, position+1, max_length, max_distance, gap, params.dist.max_distance, sr2)
|
||||
if sr2.score >= sr.score+cost_diff_lazy {
|
||||
/* Ok, let's just write one byte for now and start a match from the
|
||||
next byte. */
|
||||
position++
|
||||
|
||||
insert_length++
|
||||
*sr = *sr2
|
||||
delayed_backward_references_in_row++
|
||||
if delayed_backward_references_in_row < 4 && position+hasher.HashTypeLength() < pos_end {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
apply_random_heuristics = position + 2*sr.len + random_heuristics_window_size
|
||||
max_distance = brotli_min_size_t(position, max_backward_limit)
|
||||
{
|
||||
/* The first 16 codes are special short-codes,
|
||||
and the minimum offset is 1. */
|
||||
var distance_code uint = computeDistanceCode(sr.distance, max_distance+gap, dist_cache)
|
||||
if (sr.distance <= (max_distance + gap)) && distance_code > 0 {
|
||||
dist_cache[3] = dist_cache[2]
|
||||
dist_cache[2] = dist_cache[1]
|
||||
dist_cache[1] = dist_cache[0]
|
||||
dist_cache[0] = int(sr.distance)
|
||||
hasher.PrepareDistanceCache(dist_cache)
|
||||
}
|
||||
|
||||
*commands = append(*commands, makeCommand(¶ms.dist, insert_length, sr.len, sr.len_code_delta, distance_code))
|
||||
}
|
||||
|
||||
*num_literals += insert_length
|
||||
insert_length = 0
|
||||
/* Put the hash keys into the table, if there are enough bytes left.
|
||||
Depending on the hasher implementation, it can push all positions
|
||||
in the given range or only a subset of them.
|
||||
Avoid hash poisoning with RLE data. */
|
||||
{
|
||||
var range_start uint = position + 2
|
||||
var range_end uint = brotli_min_size_t(position+sr.len, store_end)
|
||||
if sr.distance < sr.len>>2 {
|
||||
range_start = brotli_min_size_t(range_end, brotli_max_size_t(range_start, position+sr.len-(sr.distance<<2)))
|
||||
}
|
||||
|
||||
hasher.StoreRange(ringbuffer, ringbuffer_mask, range_start, range_end)
|
||||
}
|
||||
|
||||
position += sr.len
|
||||
} else {
|
||||
insert_length++
|
||||
position++
|
||||
|
||||
/* If we have not seen matches for a long time, we can skip some
|
||||
match lookups. Unsuccessful match lookups are very very expensive
|
||||
and this kind of a heuristic speeds up compression quite
|
||||
a lot. */
|
||||
if position > apply_random_heuristics {
|
||||
/* Going through uncompressible data, jump. */
|
||||
if position > apply_random_heuristics+4*random_heuristics_window_size {
|
||||
var kMargin uint = brotli_max_size_t(hasher.StoreLookahead()-1, 4)
|
||||
/* It is quite a long time since we saw a copy, so we assume
|
||||
that this data is not compressible, and store hashes less
|
||||
often. Hashes of non compressible data are less likely to
|
||||
turn out to be useful in the future, too, so we store less of
|
||||
them to not to flood out the hash table of good compressible
|
||||
data. */
|
||||
|
||||
var pos_jump uint = brotli_min_size_t(position+16, pos_end-kMargin)
|
||||
for ; position < pos_jump; position += 4 {
|
||||
hasher.Store(ringbuffer, ringbuffer_mask, position)
|
||||
insert_length += 4
|
||||
}
|
||||
} else {
|
||||
var kMargin uint = brotli_max_size_t(hasher.StoreLookahead()-1, 2)
|
||||
var pos_jump uint = brotli_min_size_t(position+8, pos_end-kMargin)
|
||||
for ; position < pos_jump; position += 2 {
|
||||
hasher.Store(ringbuffer, ringbuffer_mask, position)
|
||||
insert_length += 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
insert_length += pos_end - position
|
||||
*last_insert_len = insert_length
|
||||
|
||||
hasherSearchResultPool.Put(sr)
|
||||
hasherSearchResultPool.Put(sr2)
|
||||
}
|
||||
+796
@@ -0,0 +1,796 @@
|
||||
package brotli
|
||||
|
||||
import "math"
|
||||
|
||||
type zopfliNode struct {
|
||||
length uint32
|
||||
distance uint32
|
||||
dcode_insert_length uint32
|
||||
u struct {
|
||||
cost float32
|
||||
next uint32
|
||||
shortcut uint32
|
||||
}
|
||||
}
|
||||
|
||||
const maxEffectiveDistanceAlphabetSize = 544
|
||||
|
||||
const kInfinity float32 = 1.7e38 /* ~= 2 ^ 127 */
|
||||
|
||||
var kDistanceCacheIndex = []uint32{0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1}
|
||||
|
||||
var kDistanceCacheOffset = []int{0, 0, 0, 0, -1, 1, -2, 2, -3, 3, -1, 1, -2, 2, -3, 3}
|
||||
|
||||
func initZopfliNodes(array []zopfliNode, length uint) {
|
||||
var stub zopfliNode
|
||||
var i uint
|
||||
stub.length = 1
|
||||
stub.distance = 0
|
||||
stub.dcode_insert_length = 0
|
||||
stub.u.cost = kInfinity
|
||||
for i = 0; i < length; i++ {
|
||||
array[i] = stub
|
||||
}
|
||||
}
|
||||
|
||||
func zopfliNodeCopyLength(self *zopfliNode) uint32 {
|
||||
return self.length & 0x1FFFFFF
|
||||
}
|
||||
|
||||
func zopfliNodeLengthCode(self *zopfliNode) uint32 {
|
||||
var modifier uint32 = self.length >> 25
|
||||
return zopfliNodeCopyLength(self) + 9 - modifier
|
||||
}
|
||||
|
||||
func zopfliNodeCopyDistance(self *zopfliNode) uint32 {
|
||||
return self.distance
|
||||
}
|
||||
|
||||
func zopfliNodeDistanceCode(self *zopfliNode) uint32 {
|
||||
var short_code uint32 = self.dcode_insert_length >> 27
|
||||
if short_code == 0 {
|
||||
return zopfliNodeCopyDistance(self) + numDistanceShortCodes - 1
|
||||
} else {
|
||||
return short_code - 1
|
||||
}
|
||||
}
|
||||
|
||||
func zopfliNodeCommandLength(self *zopfliNode) uint32 {
|
||||
return zopfliNodeCopyLength(self) + (self.dcode_insert_length & 0x7FFFFFF)
|
||||
}
|
||||
|
||||
/* Histogram based cost model for zopflification. */
|
||||
type zopfliCostModel struct {
|
||||
cost_cmd_ [numCommandSymbols]float32
|
||||
cost_dist_ []float32
|
||||
distance_histogram_size uint32
|
||||
literal_costs_ []float32
|
||||
min_cost_cmd_ float32
|
||||
num_bytes_ uint
|
||||
}
|
||||
|
||||
func initZopfliCostModel(self *zopfliCostModel, dist *distanceParams, num_bytes uint) {
|
||||
var distance_histogram_size uint32 = dist.alphabet_size
|
||||
if distance_histogram_size > maxEffectiveDistanceAlphabetSize {
|
||||
distance_histogram_size = maxEffectiveDistanceAlphabetSize
|
||||
}
|
||||
|
||||
self.num_bytes_ = num_bytes
|
||||
self.literal_costs_ = make([]float32, (num_bytes + 2))
|
||||
self.cost_dist_ = make([]float32, (dist.alphabet_size))
|
||||
self.distance_histogram_size = distance_histogram_size
|
||||
}
|
||||
|
||||
func cleanupZopfliCostModel(self *zopfliCostModel) {
|
||||
self.literal_costs_ = nil
|
||||
self.cost_dist_ = nil
|
||||
}
|
||||
|
||||
func setCost(histogram []uint32, histogram_size uint, literal_histogram bool, cost []float32) {
|
||||
var sum uint = 0
|
||||
var missing_symbol_sum uint
|
||||
var log2sum float32
|
||||
var missing_symbol_cost float32
|
||||
var i uint
|
||||
for i = 0; i < histogram_size; i++ {
|
||||
sum += uint(histogram[i])
|
||||
}
|
||||
|
||||
log2sum = float32(fastLog2(sum))
|
||||
missing_symbol_sum = sum
|
||||
if !literal_histogram {
|
||||
for i = 0; i < histogram_size; i++ {
|
||||
if histogram[i] == 0 {
|
||||
missing_symbol_sum++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
missing_symbol_cost = float32(fastLog2(missing_symbol_sum)) + 2
|
||||
for i = 0; i < histogram_size; i++ {
|
||||
if histogram[i] == 0 {
|
||||
cost[i] = missing_symbol_cost
|
||||
continue
|
||||
}
|
||||
|
||||
/* Shannon bits for this symbol. */
|
||||
cost[i] = log2sum - float32(fastLog2(uint(histogram[i])))
|
||||
|
||||
/* Cannot be coded with less than 1 bit */
|
||||
if cost[i] < 1 {
|
||||
cost[i] = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func zopfliCostModelSetFromCommands(self *zopfliCostModel, position uint, ringbuffer []byte, ringbuffer_mask uint, commands []command, last_insert_len uint) {
|
||||
var histogram_literal [numLiteralSymbols]uint32
|
||||
var histogram_cmd [numCommandSymbols]uint32
|
||||
var histogram_dist [maxEffectiveDistanceAlphabetSize]uint32
|
||||
var cost_literal [numLiteralSymbols]float32
|
||||
var pos uint = position - last_insert_len
|
||||
var min_cost_cmd float32 = kInfinity
|
||||
var cost_cmd []float32 = self.cost_cmd_[:]
|
||||
var literal_costs []float32
|
||||
|
||||
histogram_literal = [numLiteralSymbols]uint32{}
|
||||
histogram_cmd = [numCommandSymbols]uint32{}
|
||||
histogram_dist = [maxEffectiveDistanceAlphabetSize]uint32{}
|
||||
|
||||
for i := range commands {
|
||||
var inslength uint = uint(commands[i].insert_len_)
|
||||
var copylength uint = uint(commandCopyLen(&commands[i]))
|
||||
var distcode uint = uint(commands[i].dist_prefix_) & 0x3FF
|
||||
var cmdcode uint = uint(commands[i].cmd_prefix_)
|
||||
var j uint
|
||||
|
||||
histogram_cmd[cmdcode]++
|
||||
if cmdcode >= 128 {
|
||||
histogram_dist[distcode]++
|
||||
}
|
||||
|
||||
for j = 0; j < inslength; j++ {
|
||||
histogram_literal[ringbuffer[(pos+j)&ringbuffer_mask]]++
|
||||
}
|
||||
|
||||
pos += inslength + copylength
|
||||
}
|
||||
|
||||
setCost(histogram_literal[:], numLiteralSymbols, true, cost_literal[:])
|
||||
setCost(histogram_cmd[:], numCommandSymbols, false, cost_cmd)
|
||||
setCost(histogram_dist[:], uint(self.distance_histogram_size), false, self.cost_dist_)
|
||||
|
||||
for i := 0; i < numCommandSymbols; i++ {
|
||||
min_cost_cmd = brotli_min_float(min_cost_cmd, cost_cmd[i])
|
||||
}
|
||||
|
||||
self.min_cost_cmd_ = min_cost_cmd
|
||||
{
|
||||
literal_costs = self.literal_costs_
|
||||
var literal_carry float32 = 0.0
|
||||
num_bytes := int(self.num_bytes_)
|
||||
literal_costs[0] = 0.0
|
||||
for i := 0; i < num_bytes; i++ {
|
||||
literal_carry += cost_literal[ringbuffer[(position+uint(i))&ringbuffer_mask]]
|
||||
literal_costs[i+1] = literal_costs[i] + literal_carry
|
||||
literal_carry -= literal_costs[i+1] - literal_costs[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func zopfliCostModelSetFromLiteralCosts(self *zopfliCostModel, position uint, ringbuffer []byte, ringbuffer_mask uint) {
|
||||
var literal_costs []float32 = self.literal_costs_
|
||||
var literal_carry float32 = 0.0
|
||||
var cost_dist []float32 = self.cost_dist_
|
||||
var cost_cmd []float32 = self.cost_cmd_[:]
|
||||
var num_bytes uint = self.num_bytes_
|
||||
var i uint
|
||||
estimateBitCostsForLiterals(position, num_bytes, ringbuffer_mask, ringbuffer, literal_costs[1:])
|
||||
literal_costs[0] = 0.0
|
||||
for i = 0; i < num_bytes; i++ {
|
||||
literal_carry += literal_costs[i+1]
|
||||
literal_costs[i+1] = literal_costs[i] + literal_carry
|
||||
literal_carry -= literal_costs[i+1] - literal_costs[i]
|
||||
}
|
||||
|
||||
for i = 0; i < numCommandSymbols; i++ {
|
||||
cost_cmd[i] = float32(fastLog2(uint(11 + uint32(i))))
|
||||
}
|
||||
|
||||
for i = 0; uint32(i) < self.distance_histogram_size; i++ {
|
||||
cost_dist[i] = float32(fastLog2(uint(20 + uint32(i))))
|
||||
}
|
||||
|
||||
self.min_cost_cmd_ = float32(fastLog2(11))
|
||||
}
|
||||
|
||||
func zopfliCostModelGetCommandCost(self *zopfliCostModel, cmdcode uint16) float32 {
|
||||
return self.cost_cmd_[cmdcode]
|
||||
}
|
||||
|
||||
func zopfliCostModelGetDistanceCost(self *zopfliCostModel, distcode uint) float32 {
|
||||
return self.cost_dist_[distcode]
|
||||
}
|
||||
|
||||
func zopfliCostModelGetLiteralCosts(self *zopfliCostModel, from uint, to uint) float32 {
|
||||
return self.literal_costs_[to] - self.literal_costs_[from]
|
||||
}
|
||||
|
||||
func zopfliCostModelGetMinCostCmd(self *zopfliCostModel) float32 {
|
||||
return self.min_cost_cmd_
|
||||
}
|
||||
|
||||
/* REQUIRES: len >= 2, start_pos <= pos */
|
||||
/* REQUIRES: cost < kInfinity, nodes[start_pos].cost < kInfinity */
|
||||
/* Maintains the "ZopfliNode array invariant". */
|
||||
func updateZopfliNode(nodes []zopfliNode, pos uint, start_pos uint, len uint, len_code uint, dist uint, short_code uint, cost float32) {
|
||||
var next *zopfliNode = &nodes[pos+len]
|
||||
next.length = uint32(len | (len+9-len_code)<<25)
|
||||
next.distance = uint32(dist)
|
||||
next.dcode_insert_length = uint32(short_code<<27 | (pos - start_pos))
|
||||
next.u.cost = cost
|
||||
}
|
||||
|
||||
type posData struct {
|
||||
pos uint
|
||||
distance_cache [4]int
|
||||
costdiff float32
|
||||
cost float32
|
||||
}
|
||||
|
||||
/* Maintains the smallest 8 cost difference together with their positions */
|
||||
type startPosQueue struct {
|
||||
q_ [8]posData
|
||||
idx_ uint
|
||||
}
|
||||
|
||||
func initStartPosQueue(self *startPosQueue) {
|
||||
self.idx_ = 0
|
||||
}
|
||||
|
||||
func startPosQueueSize(self *startPosQueue) uint {
|
||||
return brotli_min_size_t(self.idx_, 8)
|
||||
}
|
||||
|
||||
func startPosQueuePush(self *startPosQueue, posdata *posData) {
|
||||
var offset uint = ^(self.idx_) & 7
|
||||
self.idx_++
|
||||
var len uint = startPosQueueSize(self)
|
||||
var i uint
|
||||
var q []posData = self.q_[:]
|
||||
q[offset] = *posdata
|
||||
|
||||
/* Restore the sorted order. In the list of |len| items at most |len - 1|
|
||||
adjacent element comparisons / swaps are required. */
|
||||
for i = 1; i < len; i++ {
|
||||
if q[offset&7].costdiff > q[(offset+1)&7].costdiff {
|
||||
var tmp posData = q[offset&7]
|
||||
q[offset&7] = q[(offset+1)&7]
|
||||
q[(offset+1)&7] = tmp
|
||||
}
|
||||
|
||||
offset++
|
||||
}
|
||||
}
|
||||
|
||||
func startPosQueueAt(self *startPosQueue, k uint) *posData {
|
||||
return &self.q_[(k-self.idx_)&7]
|
||||
}
|
||||
|
||||
/* Returns the minimum possible copy length that can improve the cost of any */
|
||||
/* future position. */
|
||||
func computeMinimumCopyLength(start_cost float32, nodes []zopfliNode, num_bytes uint, pos uint) uint {
|
||||
var min_cost float32 = start_cost
|
||||
var len uint = 2
|
||||
var next_len_bucket uint = 4
|
||||
/* Compute the minimum possible cost of reaching any future position. */
|
||||
|
||||
var next_len_offset uint = 10
|
||||
for pos+len <= num_bytes && nodes[pos+len].u.cost <= min_cost {
|
||||
/* We already reached (pos + len) with no more cost than the minimum
|
||||
possible cost of reaching anything from this pos, so there is no point in
|
||||
looking for lengths <= len. */
|
||||
len++
|
||||
|
||||
if len == next_len_offset {
|
||||
/* We reached the next copy length code bucket, so we add one more
|
||||
extra bit to the minimum cost. */
|
||||
min_cost += 1.0
|
||||
|
||||
next_len_offset += next_len_bucket
|
||||
next_len_bucket *= 2
|
||||
}
|
||||
}
|
||||
|
||||
return uint(len)
|
||||
}
|
||||
|
||||
/* REQUIRES: nodes[pos].cost < kInfinity
|
||||
REQUIRES: nodes[0..pos] satisfies that "ZopfliNode array invariant". */
|
||||
func computeDistanceShortcut(block_start uint, pos uint, max_backward_limit uint, gap uint, nodes []zopfliNode) uint32 {
|
||||
var clen uint = uint(zopfliNodeCopyLength(&nodes[pos]))
|
||||
var ilen uint = uint(nodes[pos].dcode_insert_length & 0x7FFFFFF)
|
||||
var dist uint = uint(zopfliNodeCopyDistance(&nodes[pos]))
|
||||
|
||||
/* Since |block_start + pos| is the end position of the command, the copy part
|
||||
starts from |block_start + pos - clen|. Distances that are greater than
|
||||
this or greater than |max_backward_limit| + |gap| are static dictionary
|
||||
references, and do not update the last distances.
|
||||
Also distance code 0 (last distance) does not update the last distances. */
|
||||
if pos == 0 {
|
||||
return 0
|
||||
} else if dist+clen <= block_start+pos+gap && dist <= max_backward_limit+gap && zopfliNodeDistanceCode(&nodes[pos]) > 0 {
|
||||
return uint32(pos)
|
||||
} else {
|
||||
return nodes[pos-clen-ilen].u.shortcut
|
||||
}
|
||||
}
|
||||
|
||||
/* Fills in dist_cache[0..3] with the last four distances (as defined by
|
||||
Section 4. of the Spec) that would be used at (block_start + pos) if we
|
||||
used the shortest path of commands from block_start, computed from
|
||||
nodes[0..pos]. The last four distances at block_start are in
|
||||
starting_dist_cache[0..3].
|
||||
REQUIRES: nodes[pos].cost < kInfinity
|
||||
REQUIRES: nodes[0..pos] satisfies that "ZopfliNode array invariant". */
|
||||
func computeDistanceCache(pos uint, starting_dist_cache []int, nodes []zopfliNode, dist_cache []int) {
|
||||
var idx int = 0
|
||||
var p uint = uint(nodes[pos].u.shortcut)
|
||||
for idx < 4 && p > 0 {
|
||||
var ilen uint = uint(nodes[p].dcode_insert_length & 0x7FFFFFF)
|
||||
var clen uint = uint(zopfliNodeCopyLength(&nodes[p]))
|
||||
var dist uint = uint(zopfliNodeCopyDistance(&nodes[p]))
|
||||
dist_cache[idx] = int(dist)
|
||||
idx++
|
||||
|
||||
/* Because of prerequisite, p >= clen + ilen >= 2. */
|
||||
p = uint(nodes[p-clen-ilen].u.shortcut)
|
||||
}
|
||||
|
||||
for ; idx < 4; idx++ {
|
||||
dist_cache[idx] = starting_dist_cache[0]
|
||||
starting_dist_cache = starting_dist_cache[1:]
|
||||
}
|
||||
}
|
||||
|
||||
/* Maintains "ZopfliNode array invariant" and pushes node to the queue, if it
|
||||
is eligible. */
|
||||
func evaluateNode(block_start uint, pos uint, max_backward_limit uint, gap uint, starting_dist_cache []int, model *zopfliCostModel, queue *startPosQueue, nodes []zopfliNode) {
|
||||
/* Save cost, because ComputeDistanceCache invalidates it. */
|
||||
var node_cost float32 = nodes[pos].u.cost
|
||||
nodes[pos].u.shortcut = computeDistanceShortcut(block_start, pos, max_backward_limit, gap, nodes)
|
||||
if node_cost <= zopfliCostModelGetLiteralCosts(model, 0, pos) {
|
||||
var posdata posData
|
||||
posdata.pos = pos
|
||||
posdata.cost = node_cost
|
||||
posdata.costdiff = node_cost - zopfliCostModelGetLiteralCosts(model, 0, pos)
|
||||
computeDistanceCache(pos, starting_dist_cache, nodes, posdata.distance_cache[:])
|
||||
startPosQueuePush(queue, &posdata)
|
||||
}
|
||||
}
|
||||
|
||||
/* Returns longest copy length. */
|
||||
func updateNodes(num_bytes uint, block_start uint, pos uint, ringbuffer []byte, ringbuffer_mask uint, params *encoderParams, max_backward_limit uint, starting_dist_cache []int, num_matches uint, matches []backwardMatch, model *zopfliCostModel, queue *startPosQueue, nodes []zopfliNode) uint {
|
||||
var cur_ix uint = block_start + pos
|
||||
var cur_ix_masked uint = cur_ix & ringbuffer_mask
|
||||
var max_distance uint = brotli_min_size_t(cur_ix, max_backward_limit)
|
||||
var max_len uint = num_bytes - pos
|
||||
var max_zopfli_len uint = maxZopfliLen(params)
|
||||
var max_iters uint = maxZopfliCandidates(params)
|
||||
var min_len uint
|
||||
var result uint = 0
|
||||
var k uint
|
||||
var gap uint = 0
|
||||
|
||||
evaluateNode(block_start, pos, max_backward_limit, gap, starting_dist_cache, model, queue, nodes)
|
||||
{
|
||||
var posdata *posData = startPosQueueAt(queue, 0)
|
||||
var min_cost float32 = (posdata.cost + zopfliCostModelGetMinCostCmd(model) + zopfliCostModelGetLiteralCosts(model, posdata.pos, pos))
|
||||
min_len = computeMinimumCopyLength(min_cost, nodes, num_bytes, pos)
|
||||
}
|
||||
|
||||
/* Go over the command starting positions in order of increasing cost
|
||||
difference. */
|
||||
for k = 0; k < max_iters && k < startPosQueueSize(queue); k++ {
|
||||
var posdata *posData = startPosQueueAt(queue, k)
|
||||
var start uint = posdata.pos
|
||||
var inscode uint16 = getInsertLengthCode(pos - start)
|
||||
var start_costdiff float32 = posdata.costdiff
|
||||
var base_cost float32 = start_costdiff + float32(getInsertExtra(inscode)) + zopfliCostModelGetLiteralCosts(model, 0, pos)
|
||||
var best_len uint = min_len - 1
|
||||
var j uint = 0
|
||||
/* Look for last distance matches using the distance cache from this
|
||||
starting position. */
|
||||
for ; j < numDistanceShortCodes && best_len < max_len; j++ {
|
||||
var idx uint = uint(kDistanceCacheIndex[j])
|
||||
var backward uint = uint(posdata.distance_cache[idx] + kDistanceCacheOffset[j])
|
||||
var prev_ix uint = cur_ix - backward
|
||||
var len uint = 0
|
||||
var continuation byte = ringbuffer[cur_ix_masked+best_len]
|
||||
if cur_ix_masked+best_len > ringbuffer_mask {
|
||||
break
|
||||
}
|
||||
|
||||
if backward > max_distance+gap {
|
||||
/* Word dictionary -> ignore. */
|
||||
continue
|
||||
}
|
||||
|
||||
if backward <= max_distance {
|
||||
/* Regular backward reference. */
|
||||
if prev_ix >= cur_ix {
|
||||
continue
|
||||
}
|
||||
|
||||
prev_ix &= ringbuffer_mask
|
||||
if prev_ix+best_len > ringbuffer_mask || continuation != ringbuffer[prev_ix+best_len] {
|
||||
continue
|
||||
}
|
||||
|
||||
len = findMatchLengthWithLimit(ringbuffer[prev_ix:], ringbuffer[cur_ix_masked:], max_len)
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
{
|
||||
var dist_cost float32 = base_cost + zopfliCostModelGetDistanceCost(model, j)
|
||||
var l uint
|
||||
for l = best_len + 1; l <= len; l++ {
|
||||
var copycode uint16 = getCopyLengthCode(l)
|
||||
var cmdcode uint16 = combineLengthCodes(inscode, copycode, j == 0)
|
||||
var tmp float32
|
||||
if cmdcode < 128 {
|
||||
tmp = base_cost
|
||||
} else {
|
||||
tmp = dist_cost
|
||||
}
|
||||
var cost float32 = tmp + float32(getCopyExtra(copycode)) + zopfliCostModelGetCommandCost(model, cmdcode)
|
||||
if cost < nodes[pos+l].u.cost {
|
||||
updateZopfliNode(nodes, pos, start, l, l, backward, j+1, cost)
|
||||
result = brotli_max_size_t(result, l)
|
||||
}
|
||||
|
||||
best_len = l
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* At higher iterations look only for new last distance matches, since
|
||||
looking only for new command start positions with the same distances
|
||||
does not help much. */
|
||||
if k >= 2 {
|
||||
continue
|
||||
}
|
||||
{
|
||||
/* Loop through all possible copy lengths at this position. */
|
||||
var len uint = min_len
|
||||
for j = 0; j < num_matches; j++ {
|
||||
var match backwardMatch = matches[j]
|
||||
var dist uint = uint(match.distance)
|
||||
var is_dictionary_match bool = (dist > max_distance+gap)
|
||||
var dist_code uint = dist + numDistanceShortCodes - 1
|
||||
var dist_symbol uint16
|
||||
var distextra uint32
|
||||
var distnumextra uint32
|
||||
var dist_cost float32
|
||||
var max_match_len uint
|
||||
/* We already tried all possible last distance matches, so we can use
|
||||
normal distance code here. */
|
||||
prefixEncodeCopyDistance(dist_code, uint(params.dist.num_direct_distance_codes), uint(params.dist.distance_postfix_bits), &dist_symbol, &distextra)
|
||||
|
||||
distnumextra = uint32(dist_symbol) >> 10
|
||||
dist_cost = base_cost + float32(distnumextra) + zopfliCostModelGetDistanceCost(model, uint(dist_symbol)&0x3FF)
|
||||
|
||||
/* Try all copy lengths up until the maximum copy length corresponding
|
||||
to this distance. If the distance refers to the static dictionary, or
|
||||
the maximum length is long enough, try only one maximum length. */
|
||||
max_match_len = backwardMatchLength(&match)
|
||||
|
||||
if len < max_match_len && (is_dictionary_match || max_match_len > max_zopfli_len) {
|
||||
len = max_match_len
|
||||
}
|
||||
|
||||
for ; len <= max_match_len; len++ {
|
||||
var len_code uint
|
||||
if is_dictionary_match {
|
||||
len_code = backwardMatchLengthCode(&match)
|
||||
} else {
|
||||
len_code = len
|
||||
}
|
||||
var copycode uint16 = getCopyLengthCode(len_code)
|
||||
var cmdcode uint16 = combineLengthCodes(inscode, copycode, false)
|
||||
var cost float32 = dist_cost + float32(getCopyExtra(copycode)) + zopfliCostModelGetCommandCost(model, cmdcode)
|
||||
if cost < nodes[pos+len].u.cost {
|
||||
updateZopfliNode(nodes, pos, start, uint(len), len_code, dist, 0, cost)
|
||||
if len > result {
|
||||
result = len
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func computeShortestPathFromNodes(num_bytes uint, nodes []zopfliNode) uint {
|
||||
var index uint = num_bytes
|
||||
var num_commands uint = 0
|
||||
for nodes[index].dcode_insert_length&0x7FFFFFF == 0 && nodes[index].length == 1 {
|
||||
index--
|
||||
}
|
||||
nodes[index].u.next = math.MaxUint32
|
||||
for index != 0 {
|
||||
var len uint = uint(zopfliNodeCommandLength(&nodes[index]))
|
||||
index -= uint(len)
|
||||
nodes[index].u.next = uint32(len)
|
||||
num_commands++
|
||||
}
|
||||
|
||||
return num_commands
|
||||
}
|
||||
|
||||
/* REQUIRES: nodes != NULL and len(nodes) >= num_bytes + 1 */
|
||||
func zopfliCreateCommands(num_bytes uint, block_start uint, nodes []zopfliNode, dist_cache []int, last_insert_len *uint, params *encoderParams, commands *[]command, num_literals *uint) {
|
||||
var max_backward_limit uint = maxBackwardLimit(params.lgwin)
|
||||
var pos uint = 0
|
||||
var offset uint32 = nodes[0].u.next
|
||||
var i uint
|
||||
var gap uint = 0
|
||||
for i = 0; offset != math.MaxUint32; i++ {
|
||||
var next *zopfliNode = &nodes[uint32(pos)+offset]
|
||||
var copy_length uint = uint(zopfliNodeCopyLength(next))
|
||||
var insert_length uint = uint(next.dcode_insert_length & 0x7FFFFFF)
|
||||
pos += insert_length
|
||||
offset = next.u.next
|
||||
if i == 0 {
|
||||
insert_length += *last_insert_len
|
||||
*last_insert_len = 0
|
||||
}
|
||||
{
|
||||
var distance uint = uint(zopfliNodeCopyDistance(next))
|
||||
var len_code uint = uint(zopfliNodeLengthCode(next))
|
||||
var max_distance uint = brotli_min_size_t(block_start+pos, max_backward_limit)
|
||||
var is_dictionary bool = (distance > max_distance+gap)
|
||||
var dist_code uint = uint(zopfliNodeDistanceCode(next))
|
||||
*commands = append(*commands, makeCommand(¶ms.dist, insert_length, copy_length, int(len_code)-int(copy_length), dist_code))
|
||||
|
||||
if !is_dictionary && dist_code > 0 {
|
||||
dist_cache[3] = dist_cache[2]
|
||||
dist_cache[2] = dist_cache[1]
|
||||
dist_cache[1] = dist_cache[0]
|
||||
dist_cache[0] = int(distance)
|
||||
}
|
||||
}
|
||||
|
||||
*num_literals += insert_length
|
||||
pos += copy_length
|
||||
}
|
||||
|
||||
*last_insert_len += num_bytes - pos
|
||||
}
|
||||
|
||||
func zopfliIterate(num_bytes uint, position uint, ringbuffer []byte, ringbuffer_mask uint, params *encoderParams, gap uint, dist_cache []int, model *zopfliCostModel, num_matches []uint32, matches []backwardMatch, nodes []zopfliNode) uint {
|
||||
var max_backward_limit uint = maxBackwardLimit(params.lgwin)
|
||||
var max_zopfli_len uint = maxZopfliLen(params)
|
||||
var queue startPosQueue
|
||||
var cur_match_pos uint = 0
|
||||
var i uint
|
||||
nodes[0].length = 0
|
||||
nodes[0].u.cost = 0
|
||||
initStartPosQueue(&queue)
|
||||
for i = 0; i+3 < num_bytes; i++ {
|
||||
var skip uint = updateNodes(num_bytes, position, i, ringbuffer, ringbuffer_mask, params, max_backward_limit, dist_cache, uint(num_matches[i]), matches[cur_match_pos:], model, &queue, nodes)
|
||||
if skip < longCopyQuickStep {
|
||||
skip = 0
|
||||
}
|
||||
cur_match_pos += uint(num_matches[i])
|
||||
if num_matches[i] == 1 && backwardMatchLength(&matches[cur_match_pos-1]) > max_zopfli_len {
|
||||
skip = brotli_max_size_t(backwardMatchLength(&matches[cur_match_pos-1]), skip)
|
||||
}
|
||||
|
||||
if skip > 1 {
|
||||
skip--
|
||||
for skip != 0 {
|
||||
i++
|
||||
if i+3 >= num_bytes {
|
||||
break
|
||||
}
|
||||
evaluateNode(position, i, max_backward_limit, gap, dist_cache, model, &queue, nodes)
|
||||
cur_match_pos += uint(num_matches[i])
|
||||
skip--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return computeShortestPathFromNodes(num_bytes, nodes)
|
||||
}
|
||||
|
||||
/* Computes the shortest path of commands from position to at most
|
||||
position + num_bytes.
|
||||
|
||||
On return, path->size() is the number of commands found and path[i] is the
|
||||
length of the i-th command (copy length plus insert length).
|
||||
Note that the sum of the lengths of all commands can be less than num_bytes.
|
||||
|
||||
On return, the nodes[0..num_bytes] array will have the following
|
||||
"ZopfliNode array invariant":
|
||||
For each i in [1..num_bytes], if nodes[i].cost < kInfinity, then
|
||||
(1) nodes[i].copy_length() >= 2
|
||||
(2) nodes[i].command_length() <= i and
|
||||
(3) nodes[i - nodes[i].command_length()].cost < kInfinity
|
||||
|
||||
REQUIRES: nodes != nil and len(nodes) >= num_bytes + 1 */
|
||||
func zopfliComputeShortestPath(num_bytes uint, position uint, ringbuffer []byte, ringbuffer_mask uint, params *encoderParams, dist_cache []int, hasher *h10, nodes []zopfliNode) uint {
|
||||
var max_backward_limit uint = maxBackwardLimit(params.lgwin)
|
||||
var max_zopfli_len uint = maxZopfliLen(params)
|
||||
var model zopfliCostModel
|
||||
var queue startPosQueue
|
||||
var matches [2 * (maxNumMatchesH10 + 64)]backwardMatch
|
||||
var store_end uint
|
||||
if num_bytes >= hasher.StoreLookahead() {
|
||||
store_end = position + num_bytes - hasher.StoreLookahead() + 1
|
||||
} else {
|
||||
store_end = position
|
||||
}
|
||||
var i uint
|
||||
var gap uint = 0
|
||||
var lz_matches_offset uint = 0
|
||||
nodes[0].length = 0
|
||||
nodes[0].u.cost = 0
|
||||
initZopfliCostModel(&model, ¶ms.dist, num_bytes)
|
||||
zopfliCostModelSetFromLiteralCosts(&model, position, ringbuffer, ringbuffer_mask)
|
||||
initStartPosQueue(&queue)
|
||||
for i = 0; i+hasher.HashTypeLength()-1 < num_bytes; i++ {
|
||||
var pos uint = position + i
|
||||
var max_distance uint = brotli_min_size_t(pos, max_backward_limit)
|
||||
var skip uint
|
||||
var num_matches uint
|
||||
num_matches = findAllMatchesH10(hasher, ¶ms.dictionary, ringbuffer, ringbuffer_mask, pos, num_bytes-i, max_distance, gap, params, matches[lz_matches_offset:])
|
||||
if num_matches > 0 && backwardMatchLength(&matches[num_matches-1]) > max_zopfli_len {
|
||||
matches[0] = matches[num_matches-1]
|
||||
num_matches = 1
|
||||
}
|
||||
|
||||
skip = updateNodes(num_bytes, position, i, ringbuffer, ringbuffer_mask, params, max_backward_limit, dist_cache, num_matches, matches[:], &model, &queue, nodes)
|
||||
if skip < longCopyQuickStep {
|
||||
skip = 0
|
||||
}
|
||||
if num_matches == 1 && backwardMatchLength(&matches[0]) > max_zopfli_len {
|
||||
skip = brotli_max_size_t(backwardMatchLength(&matches[0]), skip)
|
||||
}
|
||||
|
||||
if skip > 1 {
|
||||
/* Add the tail of the copy to the hasher. */
|
||||
hasher.StoreRange(ringbuffer, ringbuffer_mask, pos+1, brotli_min_size_t(pos+skip, store_end))
|
||||
|
||||
skip--
|
||||
for skip != 0 {
|
||||
i++
|
||||
if i+hasher.HashTypeLength()-1 >= num_bytes {
|
||||
break
|
||||
}
|
||||
evaluateNode(position, i, max_backward_limit, gap, dist_cache, &model, &queue, nodes)
|
||||
skip--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cleanupZopfliCostModel(&model)
|
||||
return computeShortestPathFromNodes(num_bytes, nodes)
|
||||
}
|
||||
|
||||
func createZopfliBackwardReferences(num_bytes uint, position uint, ringbuffer []byte, ringbuffer_mask uint, params *encoderParams, hasher *h10, dist_cache []int, last_insert_len *uint, commands *[]command, num_literals *uint) {
|
||||
var nodes []zopfliNode
|
||||
nodes = make([]zopfliNode, (num_bytes + 1))
|
||||
initZopfliNodes(nodes, num_bytes+1)
|
||||
zopfliComputeShortestPath(num_bytes, position, ringbuffer, ringbuffer_mask, params, dist_cache, hasher, nodes)
|
||||
zopfliCreateCommands(num_bytes, position, nodes, dist_cache, last_insert_len, params, commands, num_literals)
|
||||
nodes = nil
|
||||
}
|
||||
|
||||
func createHqZopfliBackwardReferences(num_bytes uint, position uint, ringbuffer []byte, ringbuffer_mask uint, params *encoderParams, hasher hasherHandle, dist_cache []int, last_insert_len *uint, commands *[]command, num_literals *uint) {
|
||||
var max_backward_limit uint = maxBackwardLimit(params.lgwin)
|
||||
var num_matches []uint32 = make([]uint32, num_bytes)
|
||||
var matches_size uint = 4 * num_bytes
|
||||
var store_end uint
|
||||
if num_bytes >= hasher.StoreLookahead() {
|
||||
store_end = position + num_bytes - hasher.StoreLookahead() + 1
|
||||
} else {
|
||||
store_end = position
|
||||
}
|
||||
var cur_match_pos uint = 0
|
||||
var i uint
|
||||
var orig_num_literals uint
|
||||
var orig_last_insert_len uint
|
||||
var orig_dist_cache [4]int
|
||||
var orig_num_commands int
|
||||
var model zopfliCostModel
|
||||
var nodes []zopfliNode
|
||||
var matches []backwardMatch = make([]backwardMatch, matches_size)
|
||||
var gap uint = 0
|
||||
var shadow_matches uint = 0
|
||||
var new_array []backwardMatch
|
||||
for i = 0; i+hasher.HashTypeLength()-1 < num_bytes; i++ {
|
||||
var pos uint = position + i
|
||||
var max_distance uint = brotli_min_size_t(pos, max_backward_limit)
|
||||
var max_length uint = num_bytes - i
|
||||
var num_found_matches uint
|
||||
var cur_match_end uint
|
||||
var j uint
|
||||
|
||||
/* Ensure that we have enough free slots. */
|
||||
if matches_size < cur_match_pos+maxNumMatchesH10+shadow_matches {
|
||||
var new_size uint = matches_size
|
||||
if new_size == 0 {
|
||||
new_size = cur_match_pos + maxNumMatchesH10 + shadow_matches
|
||||
}
|
||||
|
||||
for new_size < cur_match_pos+maxNumMatchesH10+shadow_matches {
|
||||
new_size *= 2
|
||||
}
|
||||
|
||||
new_array = make([]backwardMatch, new_size)
|
||||
if matches_size != 0 {
|
||||
copy(new_array, matches[:matches_size])
|
||||
}
|
||||
|
||||
matches = new_array
|
||||
matches_size = new_size
|
||||
}
|
||||
|
||||
num_found_matches = findAllMatchesH10(hasher.(*h10), ¶ms.dictionary, ringbuffer, ringbuffer_mask, pos, max_length, max_distance, gap, params, matches[cur_match_pos+shadow_matches:])
|
||||
cur_match_end = cur_match_pos + num_found_matches
|
||||
for j = cur_match_pos; j+1 < cur_match_end; j++ {
|
||||
assert(backwardMatchLength(&matches[j]) <= backwardMatchLength(&matches[j+1]))
|
||||
}
|
||||
|
||||
num_matches[i] = uint32(num_found_matches)
|
||||
if num_found_matches > 0 {
|
||||
var match_len uint = backwardMatchLength(&matches[cur_match_end-1])
|
||||
if match_len > maxZopfliLenQuality11 {
|
||||
var skip uint = match_len - 1
|
||||
matches[cur_match_pos] = matches[cur_match_end-1]
|
||||
cur_match_pos++
|
||||
num_matches[i] = 1
|
||||
|
||||
/* Add the tail of the copy to the hasher. */
|
||||
hasher.StoreRange(ringbuffer, ringbuffer_mask, pos+1, brotli_min_size_t(pos+match_len, store_end))
|
||||
var pos uint = i
|
||||
for i := 0; i < int(skip); i++ {
|
||||
num_matches[pos+1:][i] = 0
|
||||
}
|
||||
i += skip
|
||||
} else {
|
||||
cur_match_pos = cur_match_end
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
orig_num_literals = *num_literals
|
||||
orig_last_insert_len = *last_insert_len
|
||||
copy(orig_dist_cache[:], dist_cache[:4])
|
||||
orig_num_commands = len(*commands)
|
||||
nodes = make([]zopfliNode, (num_bytes + 1))
|
||||
initZopfliCostModel(&model, ¶ms.dist, num_bytes)
|
||||
for i = 0; i < 2; i++ {
|
||||
initZopfliNodes(nodes, num_bytes+1)
|
||||
if i == 0 {
|
||||
zopfliCostModelSetFromLiteralCosts(&model, position, ringbuffer, ringbuffer_mask)
|
||||
} else {
|
||||
zopfliCostModelSetFromCommands(&model, position, ringbuffer, ringbuffer_mask, (*commands)[orig_num_commands:], orig_last_insert_len)
|
||||
}
|
||||
|
||||
*commands = (*commands)[:orig_num_commands]
|
||||
*num_literals = orig_num_literals
|
||||
*last_insert_len = orig_last_insert_len
|
||||
copy(dist_cache, orig_dist_cache[:4])
|
||||
zopfliIterate(num_bytes, position, ringbuffer, ringbuffer_mask, params, gap, dist_cache, &model, num_matches, matches, nodes)
|
||||
zopfliCreateCommands(num_bytes, position, nodes, dist_cache, last_insert_len, params, commands, num_literals)
|
||||
}
|
||||
|
||||
cleanupZopfliCostModel(&model)
|
||||
nodes = nil
|
||||
matches = nil
|
||||
num_matches = nil
|
||||
}
|
||||
+436
@@ -0,0 +1,436 @@
|
||||
package brotli
|
||||
|
||||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Functions to estimate the bit cost of Huffman trees. */
|
||||
func shannonEntropy(population []uint32, size uint, total *uint) float64 {
|
||||
var sum uint = 0
|
||||
var retval float64 = 0
|
||||
var population_end []uint32 = population[size:]
|
||||
var p uint
|
||||
for -cap(population) < -cap(population_end) {
|
||||
p = uint(population[0])
|
||||
population = population[1:]
|
||||
sum += p
|
||||
retval -= float64(p) * fastLog2(p)
|
||||
}
|
||||
|
||||
if sum != 0 {
|
||||
retval += float64(sum) * fastLog2(sum)
|
||||
}
|
||||
*total = sum
|
||||
return retval
|
||||
}
|
||||
|
||||
func bitsEntropy(population []uint32, size uint) float64 {
|
||||
var sum uint
|
||||
var retval float64 = shannonEntropy(population, size, &sum)
|
||||
if retval < float64(sum) {
|
||||
/* At least one bit per literal is needed. */
|
||||
retval = float64(sum)
|
||||
}
|
||||
|
||||
return retval
|
||||
}
|
||||
|
||||
const kOneSymbolHistogramCost float64 = 12
|
||||
const kTwoSymbolHistogramCost float64 = 20
|
||||
const kThreeSymbolHistogramCost float64 = 28
|
||||
const kFourSymbolHistogramCost float64 = 37
|
||||
|
||||
func populationCostLiteral(histogram *histogramLiteral) float64 {
|
||||
var data_size uint = histogramDataSizeLiteral()
|
||||
var count int = 0
|
||||
var s [5]uint
|
||||
var bits float64 = 0.0
|
||||
var i uint
|
||||
if histogram.total_count_ == 0 {
|
||||
return kOneSymbolHistogramCost
|
||||
}
|
||||
|
||||
for i = 0; i < data_size; i++ {
|
||||
if histogram.data_[i] > 0 {
|
||||
s[count] = i
|
||||
count++
|
||||
if count > 4 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if count == 1 {
|
||||
return kOneSymbolHistogramCost
|
||||
}
|
||||
|
||||
if count == 2 {
|
||||
return kTwoSymbolHistogramCost + float64(histogram.total_count_)
|
||||
}
|
||||
|
||||
if count == 3 {
|
||||
var histo0 uint32 = histogram.data_[s[0]]
|
||||
var histo1 uint32 = histogram.data_[s[1]]
|
||||
var histo2 uint32 = histogram.data_[s[2]]
|
||||
var histomax uint32 = brotli_max_uint32_t(histo0, brotli_max_uint32_t(histo1, histo2))
|
||||
return kThreeSymbolHistogramCost + 2*(float64(histo0)+float64(histo1)+float64(histo2)) - float64(histomax)
|
||||
}
|
||||
|
||||
if count == 4 {
|
||||
var histo [4]uint32
|
||||
var h23 uint32
|
||||
var histomax uint32
|
||||
for i = 0; i < 4; i++ {
|
||||
histo[i] = histogram.data_[s[i]]
|
||||
}
|
||||
|
||||
/* Sort */
|
||||
for i = 0; i < 4; i++ {
|
||||
var j uint
|
||||
for j = i + 1; j < 4; j++ {
|
||||
if histo[j] > histo[i] {
|
||||
var tmp uint32 = histo[j]
|
||||
histo[j] = histo[i]
|
||||
histo[i] = tmp
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
h23 = histo[2] + histo[3]
|
||||
histomax = brotli_max_uint32_t(h23, histo[0])
|
||||
return kFourSymbolHistogramCost + 3*float64(h23) + 2*(float64(histo[0])+float64(histo[1])) - float64(histomax)
|
||||
}
|
||||
{
|
||||
var max_depth uint = 1
|
||||
var depth_histo = [codeLengthCodes]uint32{0}
|
||||
/* In this loop we compute the entropy of the histogram and simultaneously
|
||||
build a simplified histogram of the code length codes where we use the
|
||||
zero repeat code 17, but we don't use the non-zero repeat code 16. */
|
||||
|
||||
var log2total float64 = fastLog2(histogram.total_count_)
|
||||
for i = 0; i < data_size; {
|
||||
if histogram.data_[i] > 0 {
|
||||
var log2p float64 = log2total - fastLog2(uint(histogram.data_[i]))
|
||||
/* Compute -log2(P(symbol)) = -log2(count(symbol)/total_count) =
|
||||
= log2(total_count) - log2(count(symbol)) */
|
||||
|
||||
var depth uint = uint(log2p + 0.5)
|
||||
/* Approximate the bit depth by round(-log2(P(symbol))) */
|
||||
bits += float64(histogram.data_[i]) * log2p
|
||||
|
||||
if depth > 15 {
|
||||
depth = 15
|
||||
}
|
||||
|
||||
if depth > max_depth {
|
||||
max_depth = depth
|
||||
}
|
||||
|
||||
depth_histo[depth]++
|
||||
i++
|
||||
} else {
|
||||
var reps uint32 = 1
|
||||
/* Compute the run length of zeros and add the appropriate number of 0
|
||||
and 17 code length codes to the code length code histogram. */
|
||||
|
||||
var k uint
|
||||
for k = i + 1; k < data_size && histogram.data_[k] == 0; k++ {
|
||||
reps++
|
||||
}
|
||||
|
||||
i += uint(reps)
|
||||
if i == data_size {
|
||||
/* Don't add any cost for the last zero run, since these are encoded
|
||||
only implicitly. */
|
||||
break
|
||||
}
|
||||
|
||||
if reps < 3 {
|
||||
depth_histo[0] += reps
|
||||
} else {
|
||||
reps -= 2
|
||||
for reps > 0 {
|
||||
depth_histo[repeatZeroCodeLength]++
|
||||
|
||||
/* Add the 3 extra bits for the 17 code length code. */
|
||||
bits += 3
|
||||
|
||||
reps >>= 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Add the estimated encoding cost of the code length code histogram. */
|
||||
bits += float64(18 + 2*max_depth)
|
||||
|
||||
/* Add the entropy of the code length code histogram. */
|
||||
bits += bitsEntropy(depth_histo[:], codeLengthCodes)
|
||||
}
|
||||
|
||||
return bits
|
||||
}
|
||||
|
||||
func populationCostCommand(histogram *histogramCommand) float64 {
|
||||
var data_size uint = histogramDataSizeCommand()
|
||||
var count int = 0
|
||||
var s [5]uint
|
||||
var bits float64 = 0.0
|
||||
var i uint
|
||||
if histogram.total_count_ == 0 {
|
||||
return kOneSymbolHistogramCost
|
||||
}
|
||||
|
||||
for i = 0; i < data_size; i++ {
|
||||
if histogram.data_[i] > 0 {
|
||||
s[count] = i
|
||||
count++
|
||||
if count > 4 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if count == 1 {
|
||||
return kOneSymbolHistogramCost
|
||||
}
|
||||
|
||||
if count == 2 {
|
||||
return kTwoSymbolHistogramCost + float64(histogram.total_count_)
|
||||
}
|
||||
|
||||
if count == 3 {
|
||||
var histo0 uint32 = histogram.data_[s[0]]
|
||||
var histo1 uint32 = histogram.data_[s[1]]
|
||||
var histo2 uint32 = histogram.data_[s[2]]
|
||||
var histomax uint32 = brotli_max_uint32_t(histo0, brotli_max_uint32_t(histo1, histo2))
|
||||
return kThreeSymbolHistogramCost + 2*(float64(histo0)+float64(histo1)+float64(histo2)) - float64(histomax)
|
||||
}
|
||||
|
||||
if count == 4 {
|
||||
var histo [4]uint32
|
||||
var h23 uint32
|
||||
var histomax uint32
|
||||
for i = 0; i < 4; i++ {
|
||||
histo[i] = histogram.data_[s[i]]
|
||||
}
|
||||
|
||||
/* Sort */
|
||||
for i = 0; i < 4; i++ {
|
||||
var j uint
|
||||
for j = i + 1; j < 4; j++ {
|
||||
if histo[j] > histo[i] {
|
||||
var tmp uint32 = histo[j]
|
||||
histo[j] = histo[i]
|
||||
histo[i] = tmp
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
h23 = histo[2] + histo[3]
|
||||
histomax = brotli_max_uint32_t(h23, histo[0])
|
||||
return kFourSymbolHistogramCost + 3*float64(h23) + 2*(float64(histo[0])+float64(histo[1])) - float64(histomax)
|
||||
}
|
||||
{
|
||||
var max_depth uint = 1
|
||||
var depth_histo = [codeLengthCodes]uint32{0}
|
||||
/* In this loop we compute the entropy of the histogram and simultaneously
|
||||
build a simplified histogram of the code length codes where we use the
|
||||
zero repeat code 17, but we don't use the non-zero repeat code 16. */
|
||||
|
||||
var log2total float64 = fastLog2(histogram.total_count_)
|
||||
for i = 0; i < data_size; {
|
||||
if histogram.data_[i] > 0 {
|
||||
var log2p float64 = log2total - fastLog2(uint(histogram.data_[i]))
|
||||
/* Compute -log2(P(symbol)) = -log2(count(symbol)/total_count) =
|
||||
= log2(total_count) - log2(count(symbol)) */
|
||||
|
||||
var depth uint = uint(log2p + 0.5)
|
||||
/* Approximate the bit depth by round(-log2(P(symbol))) */
|
||||
bits += float64(histogram.data_[i]) * log2p
|
||||
|
||||
if depth > 15 {
|
||||
depth = 15
|
||||
}
|
||||
|
||||
if depth > max_depth {
|
||||
max_depth = depth
|
||||
}
|
||||
|
||||
depth_histo[depth]++
|
||||
i++
|
||||
} else {
|
||||
var reps uint32 = 1
|
||||
/* Compute the run length of zeros and add the appropriate number of 0
|
||||
and 17 code length codes to the code length code histogram. */
|
||||
|
||||
var k uint
|
||||
for k = i + 1; k < data_size && histogram.data_[k] == 0; k++ {
|
||||
reps++
|
||||
}
|
||||
|
||||
i += uint(reps)
|
||||
if i == data_size {
|
||||
/* Don't add any cost for the last zero run, since these are encoded
|
||||
only implicitly. */
|
||||
break
|
||||
}
|
||||
|
||||
if reps < 3 {
|
||||
depth_histo[0] += reps
|
||||
} else {
|
||||
reps -= 2
|
||||
for reps > 0 {
|
||||
depth_histo[repeatZeroCodeLength]++
|
||||
|
||||
/* Add the 3 extra bits for the 17 code length code. */
|
||||
bits += 3
|
||||
|
||||
reps >>= 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Add the estimated encoding cost of the code length code histogram. */
|
||||
bits += float64(18 + 2*max_depth)
|
||||
|
||||
/* Add the entropy of the code length code histogram. */
|
||||
bits += bitsEntropy(depth_histo[:], codeLengthCodes)
|
||||
}
|
||||
|
||||
return bits
|
||||
}
|
||||
|
||||
func populationCostDistance(histogram *histogramDistance) float64 {
|
||||
var data_size uint = histogramDataSizeDistance()
|
||||
var count int = 0
|
||||
var s [5]uint
|
||||
var bits float64 = 0.0
|
||||
var i uint
|
||||
if histogram.total_count_ == 0 {
|
||||
return kOneSymbolHistogramCost
|
||||
}
|
||||
|
||||
for i = 0; i < data_size; i++ {
|
||||
if histogram.data_[i] > 0 {
|
||||
s[count] = i
|
||||
count++
|
||||
if count > 4 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if count == 1 {
|
||||
return kOneSymbolHistogramCost
|
||||
}
|
||||
|
||||
if count == 2 {
|
||||
return kTwoSymbolHistogramCost + float64(histogram.total_count_)
|
||||
}
|
||||
|
||||
if count == 3 {
|
||||
var histo0 uint32 = histogram.data_[s[0]]
|
||||
var histo1 uint32 = histogram.data_[s[1]]
|
||||
var histo2 uint32 = histogram.data_[s[2]]
|
||||
var histomax uint32 = brotli_max_uint32_t(histo0, brotli_max_uint32_t(histo1, histo2))
|
||||
return kThreeSymbolHistogramCost + 2*(float64(histo0)+float64(histo1)+float64(histo2)) - float64(histomax)
|
||||
}
|
||||
|
||||
if count == 4 {
|
||||
var histo [4]uint32
|
||||
var h23 uint32
|
||||
var histomax uint32
|
||||
for i = 0; i < 4; i++ {
|
||||
histo[i] = histogram.data_[s[i]]
|
||||
}
|
||||
|
||||
/* Sort */
|
||||
for i = 0; i < 4; i++ {
|
||||
var j uint
|
||||
for j = i + 1; j < 4; j++ {
|
||||
if histo[j] > histo[i] {
|
||||
var tmp uint32 = histo[j]
|
||||
histo[j] = histo[i]
|
||||
histo[i] = tmp
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
h23 = histo[2] + histo[3]
|
||||
histomax = brotli_max_uint32_t(h23, histo[0])
|
||||
return kFourSymbolHistogramCost + 3*float64(h23) + 2*(float64(histo[0])+float64(histo[1])) - float64(histomax)
|
||||
}
|
||||
{
|
||||
var max_depth uint = 1
|
||||
var depth_histo = [codeLengthCodes]uint32{0}
|
||||
/* In this loop we compute the entropy of the histogram and simultaneously
|
||||
build a simplified histogram of the code length codes where we use the
|
||||
zero repeat code 17, but we don't use the non-zero repeat code 16. */
|
||||
|
||||
var log2total float64 = fastLog2(histogram.total_count_)
|
||||
for i = 0; i < data_size; {
|
||||
if histogram.data_[i] > 0 {
|
||||
var log2p float64 = log2total - fastLog2(uint(histogram.data_[i]))
|
||||
/* Compute -log2(P(symbol)) = -log2(count(symbol)/total_count) =
|
||||
= log2(total_count) - log2(count(symbol)) */
|
||||
|
||||
var depth uint = uint(log2p + 0.5)
|
||||
/* Approximate the bit depth by round(-log2(P(symbol))) */
|
||||
bits += float64(histogram.data_[i]) * log2p
|
||||
|
||||
if depth > 15 {
|
||||
depth = 15
|
||||
}
|
||||
|
||||
if depth > max_depth {
|
||||
max_depth = depth
|
||||
}
|
||||
|
||||
depth_histo[depth]++
|
||||
i++
|
||||
} else {
|
||||
var reps uint32 = 1
|
||||
/* Compute the run length of zeros and add the appropriate number of 0
|
||||
and 17 code length codes to the code length code histogram. */
|
||||
|
||||
var k uint
|
||||
for k = i + 1; k < data_size && histogram.data_[k] == 0; k++ {
|
||||
reps++
|
||||
}
|
||||
|
||||
i += uint(reps)
|
||||
if i == data_size {
|
||||
/* Don't add any cost for the last zero run, since these are encoded
|
||||
only implicitly. */
|
||||
break
|
||||
}
|
||||
|
||||
if reps < 3 {
|
||||
depth_histo[0] += reps
|
||||
} else {
|
||||
reps -= 2
|
||||
for reps > 0 {
|
||||
depth_histo[repeatZeroCodeLength]++
|
||||
|
||||
/* Add the 3 extra bits for the 17 code length code. */
|
||||
bits += 3
|
||||
|
||||
reps >>= 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Add the estimated encoding cost of the code length code histogram. */
|
||||
bits += float64(18 + 2*max_depth)
|
||||
|
||||
/* Add the entropy of the code length code histogram. */
|
||||
bits += bitsEntropy(depth_histo[:], codeLengthCodes)
|
||||
}
|
||||
|
||||
return bits
|
||||
}
|
||||
+266
@@ -0,0 +1,266 @@
|
||||
package brotli
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Bit reading helpers */
|
||||
|
||||
const shortFillBitWindowRead = (8 >> 1)
|
||||
|
||||
var kBitMask = [33]uint32{
|
||||
0x00000000,
|
||||
0x00000001,
|
||||
0x00000003,
|
||||
0x00000007,
|
||||
0x0000000F,
|
||||
0x0000001F,
|
||||
0x0000003F,
|
||||
0x0000007F,
|
||||
0x000000FF,
|
||||
0x000001FF,
|
||||
0x000003FF,
|
||||
0x000007FF,
|
||||
0x00000FFF,
|
||||
0x00001FFF,
|
||||
0x00003FFF,
|
||||
0x00007FFF,
|
||||
0x0000FFFF,
|
||||
0x0001FFFF,
|
||||
0x0003FFFF,
|
||||
0x0007FFFF,
|
||||
0x000FFFFF,
|
||||
0x001FFFFF,
|
||||
0x003FFFFF,
|
||||
0x007FFFFF,
|
||||
0x00FFFFFF,
|
||||
0x01FFFFFF,
|
||||
0x03FFFFFF,
|
||||
0x07FFFFFF,
|
||||
0x0FFFFFFF,
|
||||
0x1FFFFFFF,
|
||||
0x3FFFFFFF,
|
||||
0x7FFFFFFF,
|
||||
0xFFFFFFFF,
|
||||
}
|
||||
|
||||
func bitMask(n uint32) uint32 {
|
||||
return kBitMask[n]
|
||||
}
|
||||
|
||||
type bitReader struct {
|
||||
val_ uint64
|
||||
bit_pos_ uint32
|
||||
input []byte
|
||||
input_len uint
|
||||
byte_pos uint
|
||||
}
|
||||
|
||||
type bitReaderState struct {
|
||||
val_ uint64
|
||||
bit_pos_ uint32
|
||||
input []byte
|
||||
input_len uint
|
||||
byte_pos uint
|
||||
}
|
||||
|
||||
/* Initializes the BrotliBitReader fields. */
|
||||
|
||||
/* Ensures that accumulator is not empty.
|
||||
May consume up to sizeof(brotli_reg_t) - 1 bytes of input.
|
||||
Returns false if data is required but there is no input available.
|
||||
For BROTLI_ALIGNED_READ this function also prepares bit reader for aligned
|
||||
reading. */
|
||||
func bitReaderSaveState(from *bitReader, to *bitReaderState) {
|
||||
to.val_ = from.val_
|
||||
to.bit_pos_ = from.bit_pos_
|
||||
to.input = from.input
|
||||
to.input_len = from.input_len
|
||||
to.byte_pos = from.byte_pos
|
||||
}
|
||||
|
||||
func bitReaderRestoreState(to *bitReader, from *bitReaderState) {
|
||||
to.val_ = from.val_
|
||||
to.bit_pos_ = from.bit_pos_
|
||||
to.input = from.input
|
||||
to.input_len = from.input_len
|
||||
to.byte_pos = from.byte_pos
|
||||
}
|
||||
|
||||
func getAvailableBits(br *bitReader) uint32 {
|
||||
return 64 - br.bit_pos_
|
||||
}
|
||||
|
||||
/* Returns amount of unread bytes the bit reader still has buffered from the
|
||||
BrotliInput, including whole bytes in br->val_. */
|
||||
func getRemainingBytes(br *bitReader) uint {
|
||||
return uint(uint32(br.input_len-br.byte_pos) + (getAvailableBits(br) >> 3))
|
||||
}
|
||||
|
||||
/* Checks if there is at least |num| bytes left in the input ring-buffer
|
||||
(excluding the bits remaining in br->val_). */
|
||||
func checkInputAmount(br *bitReader, num uint) bool {
|
||||
return br.input_len-br.byte_pos >= num
|
||||
}
|
||||
|
||||
/* Guarantees that there are at least |n_bits| + 1 bits in accumulator.
|
||||
Precondition: accumulator contains at least 1 bit.
|
||||
|n_bits| should be in the range [1..24] for regular build. For portable
|
||||
non-64-bit little-endian build only 16 bits are safe to request. */
|
||||
func fillBitWindow(br *bitReader, n_bits uint32) {
|
||||
if br.bit_pos_ >= 32 {
|
||||
br.val_ >>= 32
|
||||
br.bit_pos_ ^= 32 /* here same as -= 32 because of the if condition */
|
||||
br.val_ |= (uint64(binary.LittleEndian.Uint32(br.input[br.byte_pos:]))) << 32
|
||||
br.byte_pos += 4
|
||||
}
|
||||
}
|
||||
|
||||
/* Mostly like BrotliFillBitWindow, but guarantees only 16 bits and reads no
|
||||
more than BROTLI_SHORT_FILL_BIT_WINDOW_READ bytes of input. */
|
||||
func fillBitWindow16(br *bitReader) {
|
||||
fillBitWindow(br, 17)
|
||||
}
|
||||
|
||||
/* Tries to pull one byte of input to accumulator.
|
||||
Returns false if there is no input available. */
|
||||
func pullByte(br *bitReader) bool {
|
||||
if br.byte_pos == br.input_len {
|
||||
return false
|
||||
}
|
||||
|
||||
br.val_ >>= 8
|
||||
br.val_ |= (uint64(br.input[br.byte_pos])) << 56
|
||||
br.bit_pos_ -= 8
|
||||
br.byte_pos++
|
||||
return true
|
||||
}
|
||||
|
||||
/* Returns currently available bits.
|
||||
The number of valid bits could be calculated by BrotliGetAvailableBits. */
|
||||
func getBitsUnmasked(br *bitReader) uint64 {
|
||||
return br.val_ >> br.bit_pos_
|
||||
}
|
||||
|
||||
/* Like BrotliGetBits, but does not mask the result.
|
||||
The result contains at least 16 valid bits. */
|
||||
func get16BitsUnmasked(br *bitReader) uint32 {
|
||||
fillBitWindow(br, 16)
|
||||
return uint32(getBitsUnmasked(br))
|
||||
}
|
||||
|
||||
/* Returns the specified number of bits from |br| without advancing bit
|
||||
position. */
|
||||
func getBits(br *bitReader, n_bits uint32) uint32 {
|
||||
fillBitWindow(br, n_bits)
|
||||
return uint32(getBitsUnmasked(br)) & bitMask(n_bits)
|
||||
}
|
||||
|
||||
/* Tries to peek the specified amount of bits. Returns false, if there
|
||||
is not enough input. */
|
||||
func safeGetBits(br *bitReader, n_bits uint32, val *uint32) bool {
|
||||
for getAvailableBits(br) < n_bits {
|
||||
if !pullByte(br) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
*val = uint32(getBitsUnmasked(br)) & bitMask(n_bits)
|
||||
return true
|
||||
}
|
||||
|
||||
/* Advances the bit pos by |n_bits|. */
|
||||
func dropBits(br *bitReader, n_bits uint32) {
|
||||
br.bit_pos_ += n_bits
|
||||
}
|
||||
|
||||
func bitReaderUnload(br *bitReader) {
|
||||
var unused_bytes uint32 = getAvailableBits(br) >> 3
|
||||
var unused_bits uint32 = unused_bytes << 3
|
||||
br.byte_pos -= uint(unused_bytes)
|
||||
if unused_bits == 64 {
|
||||
br.val_ = 0
|
||||
} else {
|
||||
br.val_ <<= unused_bits
|
||||
}
|
||||
|
||||
br.bit_pos_ += unused_bits
|
||||
}
|
||||
|
||||
/* Reads the specified number of bits from |br| and advances the bit pos.
|
||||
Precondition: accumulator MUST contain at least |n_bits|. */
|
||||
func takeBits(br *bitReader, n_bits uint32, val *uint32) {
|
||||
*val = uint32(getBitsUnmasked(br)) & bitMask(n_bits)
|
||||
dropBits(br, n_bits)
|
||||
}
|
||||
|
||||
/* Reads the specified number of bits from |br| and advances the bit pos.
|
||||
Assumes that there is enough input to perform BrotliFillBitWindow. */
|
||||
func readBits(br *bitReader, n_bits uint32) uint32 {
|
||||
var val uint32
|
||||
fillBitWindow(br, n_bits)
|
||||
takeBits(br, n_bits, &val)
|
||||
return val
|
||||
}
|
||||
|
||||
/* Tries to read the specified amount of bits. Returns false, if there
|
||||
is not enough input. |n_bits| MUST be positive. */
|
||||
func safeReadBits(br *bitReader, n_bits uint32, val *uint32) bool {
|
||||
for getAvailableBits(br) < n_bits {
|
||||
if !pullByte(br) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
takeBits(br, n_bits, val)
|
||||
return true
|
||||
}
|
||||
|
||||
/* Advances the bit reader position to the next byte boundary and verifies
|
||||
that any skipped bits are set to zero. */
|
||||
func bitReaderJumpToByteBoundary(br *bitReader) bool {
|
||||
var pad_bits_count uint32 = getAvailableBits(br) & 0x7
|
||||
var pad_bits uint32 = 0
|
||||
if pad_bits_count != 0 {
|
||||
takeBits(br, pad_bits_count, &pad_bits)
|
||||
}
|
||||
|
||||
return pad_bits == 0
|
||||
}
|
||||
|
||||
/* Copies remaining input bytes stored in the bit reader to the output. Value
|
||||
|num| may not be larger than BrotliGetRemainingBytes. The bit reader must be
|
||||
warmed up again after this. */
|
||||
func copyBytes(dest []byte, br *bitReader, num uint) {
|
||||
for getAvailableBits(br) >= 8 && num > 0 {
|
||||
dest[0] = byte(getBitsUnmasked(br))
|
||||
dropBits(br, 8)
|
||||
dest = dest[1:]
|
||||
num--
|
||||
}
|
||||
|
||||
copy(dest, br.input[br.byte_pos:][:num])
|
||||
br.byte_pos += num
|
||||
}
|
||||
|
||||
func initBitReader(br *bitReader) {
|
||||
br.val_ = 0
|
||||
br.bit_pos_ = 64
|
||||
}
|
||||
|
||||
func warmupBitReader(br *bitReader) bool {
|
||||
/* Fixing alignment after unaligned BrotliFillWindow would result accumulator
|
||||
overflow. If unalignment is caused by BrotliSafeReadBits, then there is
|
||||
enough space in accumulator to fix alignment. */
|
||||
if getAvailableBits(br) == 0 {
|
||||
if !pullByte(br) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
+144
@@ -0,0 +1,144 @@
|
||||
package brotli
|
||||
|
||||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Block split point selection utilities. */
|
||||
|
||||
type blockSplit struct {
|
||||
num_types uint
|
||||
num_blocks uint
|
||||
types []byte
|
||||
lengths []uint32
|
||||
types_alloc_size uint
|
||||
lengths_alloc_size uint
|
||||
}
|
||||
|
||||
const (
|
||||
kMaxLiteralHistograms uint = 100
|
||||
kMaxCommandHistograms uint = 50
|
||||
kLiteralBlockSwitchCost float64 = 28.1
|
||||
kCommandBlockSwitchCost float64 = 13.5
|
||||
kDistanceBlockSwitchCost float64 = 14.6
|
||||
kLiteralStrideLength uint = 70
|
||||
kCommandStrideLength uint = 40
|
||||
kSymbolsPerLiteralHistogram uint = 544
|
||||
kSymbolsPerCommandHistogram uint = 530
|
||||
kSymbolsPerDistanceHistogram uint = 544
|
||||
kMinLengthForBlockSplitting uint = 128
|
||||
kIterMulForRefining uint = 2
|
||||
kMinItersForRefining uint = 100
|
||||
)
|
||||
|
||||
func countLiterals(cmds []command) uint {
|
||||
var total_length uint = 0
|
||||
/* Count how many we have. */
|
||||
|
||||
for i := range cmds {
|
||||
total_length += uint(cmds[i].insert_len_)
|
||||
}
|
||||
|
||||
return total_length
|
||||
}
|
||||
|
||||
func copyLiteralsToByteArray(cmds []command, data []byte, offset uint, mask uint, literals []byte) {
|
||||
var pos uint = 0
|
||||
var from_pos uint = offset & mask
|
||||
for i := range cmds {
|
||||
var insert_len uint = uint(cmds[i].insert_len_)
|
||||
if from_pos+insert_len > mask {
|
||||
var head_size uint = mask + 1 - from_pos
|
||||
copy(literals[pos:], data[from_pos:][:head_size])
|
||||
from_pos = 0
|
||||
pos += head_size
|
||||
insert_len -= head_size
|
||||
}
|
||||
|
||||
if insert_len > 0 {
|
||||
copy(literals[pos:], data[from_pos:][:insert_len])
|
||||
pos += insert_len
|
||||
}
|
||||
|
||||
from_pos = uint((uint32(from_pos+insert_len) + commandCopyLen(&cmds[i])) & uint32(mask))
|
||||
}
|
||||
}
|
||||
|
||||
func myRand(seed *uint32) uint32 {
|
||||
/* Initial seed should be 7. In this case, loop length is (1 << 29). */
|
||||
*seed *= 16807
|
||||
|
||||
return *seed
|
||||
}
|
||||
|
||||
func bitCost(count uint) float64 {
|
||||
if count == 0 {
|
||||
return -2.0
|
||||
} else {
|
||||
return fastLog2(count)
|
||||
}
|
||||
}
|
||||
|
||||
const histogramsPerBatch = 64
|
||||
|
||||
const clustersPerBatch = 16
|
||||
|
||||
func initBlockSplit(self *blockSplit) {
|
||||
self.num_types = 0
|
||||
self.num_blocks = 0
|
||||
self.types = self.types[:0]
|
||||
self.lengths = self.lengths[:0]
|
||||
self.types_alloc_size = 0
|
||||
self.lengths_alloc_size = 0
|
||||
}
|
||||
|
||||
func splitBlock(cmds []command, data []byte, pos uint, mask uint, params *encoderParams, literal_split *blockSplit, insert_and_copy_split *blockSplit, dist_split *blockSplit) {
|
||||
{
|
||||
var literals_count uint = countLiterals(cmds)
|
||||
var literals []byte = make([]byte, literals_count)
|
||||
|
||||
/* Create a continuous array of literals. */
|
||||
copyLiteralsToByteArray(cmds, data, pos, mask, literals)
|
||||
|
||||
/* Create the block split on the array of literals.
|
||||
Literal histograms have alphabet size 256. */
|
||||
splitByteVectorLiteral(literals, literals_count, kSymbolsPerLiteralHistogram, kMaxLiteralHistograms, kLiteralStrideLength, kLiteralBlockSwitchCost, params, literal_split)
|
||||
|
||||
literals = nil
|
||||
}
|
||||
{
|
||||
var insert_and_copy_codes []uint16 = make([]uint16, len(cmds))
|
||||
/* Compute prefix codes for commands. */
|
||||
|
||||
for i := range cmds {
|
||||
insert_and_copy_codes[i] = cmds[i].cmd_prefix_
|
||||
}
|
||||
|
||||
/* Create the block split on the array of command prefixes. */
|
||||
splitByteVectorCommand(insert_and_copy_codes, kSymbolsPerCommandHistogram, kMaxCommandHistograms, kCommandStrideLength, kCommandBlockSwitchCost, params, insert_and_copy_split)
|
||||
|
||||
/* TODO: reuse for distances? */
|
||||
|
||||
insert_and_copy_codes = nil
|
||||
}
|
||||
{
|
||||
var distance_prefixes []uint16 = make([]uint16, len(cmds))
|
||||
var j uint = 0
|
||||
/* Create a continuous array of distance prefixes. */
|
||||
|
||||
for i := range cmds {
|
||||
var cmd *command = &cmds[i]
|
||||
if commandCopyLen(cmd) != 0 && cmd.cmd_prefix_ >= 128 {
|
||||
distance_prefixes[j] = cmd.dist_prefix_ & 0x3FF
|
||||
j++
|
||||
}
|
||||
}
|
||||
|
||||
/* Create the block split on the array of distance prefixes. */
|
||||
splitByteVectorDistance(distance_prefixes, j, kSymbolsPerDistanceHistogram, kMaxCommandHistograms, kCommandStrideLength, kDistanceBlockSwitchCost, params, dist_split)
|
||||
|
||||
distance_prefixes = nil
|
||||
}
|
||||
}
|
||||
+434
@@ -0,0 +1,434 @@
|
||||
package brotli
|
||||
|
||||
import "math"
|
||||
|
||||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
func initialEntropyCodesCommand(data []uint16, length uint, stride uint, num_histograms uint, histograms []histogramCommand) {
|
||||
var seed uint32 = 7
|
||||
var block_length uint = length / num_histograms
|
||||
var i uint
|
||||
clearHistogramsCommand(histograms, num_histograms)
|
||||
for i = 0; i < num_histograms; i++ {
|
||||
var pos uint = length * i / num_histograms
|
||||
if i != 0 {
|
||||
pos += uint(myRand(&seed) % uint32(block_length))
|
||||
}
|
||||
|
||||
if pos+stride >= length {
|
||||
pos = length - stride - 1
|
||||
}
|
||||
|
||||
histogramAddVectorCommand(&histograms[i], data[pos:], stride)
|
||||
}
|
||||
}
|
||||
|
||||
func randomSampleCommand(seed *uint32, data []uint16, length uint, stride uint, sample *histogramCommand) {
|
||||
var pos uint = 0
|
||||
if stride >= length {
|
||||
stride = length
|
||||
} else {
|
||||
pos = uint(myRand(seed) % uint32(length-stride+1))
|
||||
}
|
||||
|
||||
histogramAddVectorCommand(sample, data[pos:], stride)
|
||||
}
|
||||
|
||||
func refineEntropyCodesCommand(data []uint16, length uint, stride uint, num_histograms uint, histograms []histogramCommand) {
|
||||
var iters uint = kIterMulForRefining*length/stride + kMinItersForRefining
|
||||
var seed uint32 = 7
|
||||
var iter uint
|
||||
iters = ((iters + num_histograms - 1) / num_histograms) * num_histograms
|
||||
for iter = 0; iter < iters; iter++ {
|
||||
var sample histogramCommand
|
||||
histogramClearCommand(&sample)
|
||||
randomSampleCommand(&seed, data, length, stride, &sample)
|
||||
histogramAddHistogramCommand(&histograms[iter%num_histograms], &sample)
|
||||
}
|
||||
}
|
||||
|
||||
/* Assigns a block id from the range [0, num_histograms) to each data element
|
||||
in data[0..length) and fills in block_id[0..length) with the assigned values.
|
||||
Returns the number of blocks, i.e. one plus the number of block switches. */
|
||||
func findBlocksCommand(data []uint16, length uint, block_switch_bitcost float64, num_histograms uint, histograms []histogramCommand, insert_cost []float64, cost []float64, switch_signal []byte, block_id []byte) uint {
|
||||
var data_size uint = histogramDataSizeCommand()
|
||||
var bitmaplen uint = (num_histograms + 7) >> 3
|
||||
var num_blocks uint = 1
|
||||
var i uint
|
||||
var j uint
|
||||
assert(num_histograms <= 256)
|
||||
if num_histograms <= 1 {
|
||||
for i = 0; i < length; i++ {
|
||||
block_id[i] = 0
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
for i := 0; i < int(data_size*num_histograms); i++ {
|
||||
insert_cost[i] = 0
|
||||
}
|
||||
for i = 0; i < num_histograms; i++ {
|
||||
insert_cost[i] = fastLog2(uint(uint32(histograms[i].total_count_)))
|
||||
}
|
||||
|
||||
for i = data_size; i != 0; {
|
||||
i--
|
||||
for j = 0; j < num_histograms; j++ {
|
||||
insert_cost[i*num_histograms+j] = insert_cost[j] - bitCost(uint(histograms[j].data_[i]))
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < int(num_histograms); i++ {
|
||||
cost[i] = 0
|
||||
}
|
||||
for i := 0; i < int(length*bitmaplen); i++ {
|
||||
switch_signal[i] = 0
|
||||
}
|
||||
|
||||
/* After each iteration of this loop, cost[k] will contain the difference
|
||||
between the minimum cost of arriving at the current byte position using
|
||||
entropy code k, and the minimum cost of arriving at the current byte
|
||||
position. This difference is capped at the block switch cost, and if it
|
||||
reaches block switch cost, it means that when we trace back from the last
|
||||
position, we need to switch here. */
|
||||
for i = 0; i < length; i++ {
|
||||
var byte_ix uint = i
|
||||
var ix uint = byte_ix * bitmaplen
|
||||
var insert_cost_ix uint = uint(data[byte_ix]) * num_histograms
|
||||
var min_cost float64 = 1e99
|
||||
var block_switch_cost float64 = block_switch_bitcost
|
||||
var k uint
|
||||
for k = 0; k < num_histograms; k++ {
|
||||
/* We are coding the symbol in data[byte_ix] with entropy code k. */
|
||||
cost[k] += insert_cost[insert_cost_ix+k]
|
||||
|
||||
if cost[k] < min_cost {
|
||||
min_cost = cost[k]
|
||||
block_id[byte_ix] = byte(k)
|
||||
}
|
||||
}
|
||||
|
||||
/* More blocks for the beginning. */
|
||||
if byte_ix < 2000 {
|
||||
block_switch_cost *= 0.77 + 0.07*float64(byte_ix)/2000
|
||||
}
|
||||
|
||||
for k = 0; k < num_histograms; k++ {
|
||||
cost[k] -= min_cost
|
||||
if cost[k] >= block_switch_cost {
|
||||
var mask byte = byte(1 << (k & 7))
|
||||
cost[k] = block_switch_cost
|
||||
assert(k>>3 < bitmaplen)
|
||||
switch_signal[ix+(k>>3)] |= mask
|
||||
/* Trace back from the last position and switch at the marked places. */
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
var byte_ix uint = length - 1
|
||||
var ix uint = byte_ix * bitmaplen
|
||||
var cur_id byte = block_id[byte_ix]
|
||||
for byte_ix > 0 {
|
||||
var mask byte = byte(1 << (cur_id & 7))
|
||||
assert(uint(cur_id)>>3 < bitmaplen)
|
||||
byte_ix--
|
||||
ix -= bitmaplen
|
||||
if switch_signal[ix+uint(cur_id>>3)]&mask != 0 {
|
||||
if cur_id != block_id[byte_ix] {
|
||||
cur_id = block_id[byte_ix]
|
||||
num_blocks++
|
||||
}
|
||||
}
|
||||
|
||||
block_id[byte_ix] = cur_id
|
||||
}
|
||||
}
|
||||
|
||||
return num_blocks
|
||||
}
|
||||
|
||||
var remapBlockIdsCommand_kInvalidId uint16 = 256
|
||||
|
||||
func remapBlockIdsCommand(block_ids []byte, length uint, new_id []uint16, num_histograms uint) uint {
|
||||
var next_id uint16 = 0
|
||||
var i uint
|
||||
for i = 0; i < num_histograms; i++ {
|
||||
new_id[i] = remapBlockIdsCommand_kInvalidId
|
||||
}
|
||||
|
||||
for i = 0; i < length; i++ {
|
||||
assert(uint(block_ids[i]) < num_histograms)
|
||||
if new_id[block_ids[i]] == remapBlockIdsCommand_kInvalidId {
|
||||
new_id[block_ids[i]] = next_id
|
||||
next_id++
|
||||
}
|
||||
}
|
||||
|
||||
for i = 0; i < length; i++ {
|
||||
block_ids[i] = byte(new_id[block_ids[i]])
|
||||
assert(uint(block_ids[i]) < num_histograms)
|
||||
}
|
||||
|
||||
assert(uint(next_id) <= num_histograms)
|
||||
return uint(next_id)
|
||||
}
|
||||
|
||||
func buildBlockHistogramsCommand(data []uint16, length uint, block_ids []byte, num_histograms uint, histograms []histogramCommand) {
|
||||
var i uint
|
||||
clearHistogramsCommand(histograms, num_histograms)
|
||||
for i = 0; i < length; i++ {
|
||||
histogramAddCommand(&histograms[block_ids[i]], uint(data[i]))
|
||||
}
|
||||
}
|
||||
|
||||
var clusterBlocksCommand_kInvalidIndex uint32 = math.MaxUint32
|
||||
|
||||
func clusterBlocksCommand(data []uint16, length uint, num_blocks uint, block_ids []byte, split *blockSplit) {
|
||||
var histogram_symbols []uint32 = make([]uint32, num_blocks)
|
||||
var block_lengths []uint32 = make([]uint32, num_blocks)
|
||||
var expected_num_clusters uint = clustersPerBatch * (num_blocks + histogramsPerBatch - 1) / histogramsPerBatch
|
||||
var all_histograms_size uint = 0
|
||||
var all_histograms_capacity uint = expected_num_clusters
|
||||
var all_histograms []histogramCommand = make([]histogramCommand, all_histograms_capacity)
|
||||
var cluster_size_size uint = 0
|
||||
var cluster_size_capacity uint = expected_num_clusters
|
||||
var cluster_size []uint32 = make([]uint32, cluster_size_capacity)
|
||||
var num_clusters uint = 0
|
||||
var histograms []histogramCommand = make([]histogramCommand, brotli_min_size_t(num_blocks, histogramsPerBatch))
|
||||
var max_num_pairs uint = histogramsPerBatch * histogramsPerBatch / 2
|
||||
var pairs_capacity uint = max_num_pairs + 1
|
||||
var pairs []histogramPair = make([]histogramPair, pairs_capacity)
|
||||
var pos uint = 0
|
||||
var clusters []uint32
|
||||
var num_final_clusters uint
|
||||
var new_index []uint32
|
||||
var i uint
|
||||
var sizes = [histogramsPerBatch]uint32{0}
|
||||
var new_clusters = [histogramsPerBatch]uint32{0}
|
||||
var symbols = [histogramsPerBatch]uint32{0}
|
||||
var remap = [histogramsPerBatch]uint32{0}
|
||||
|
||||
for i := 0; i < int(num_blocks); i++ {
|
||||
block_lengths[i] = 0
|
||||
}
|
||||
{
|
||||
var block_idx uint = 0
|
||||
for i = 0; i < length; i++ {
|
||||
assert(block_idx < num_blocks)
|
||||
block_lengths[block_idx]++
|
||||
if i+1 == length || block_ids[i] != block_ids[i+1] {
|
||||
block_idx++
|
||||
}
|
||||
}
|
||||
|
||||
assert(block_idx == num_blocks)
|
||||
}
|
||||
|
||||
for i = 0; i < num_blocks; i += histogramsPerBatch {
|
||||
var num_to_combine uint = brotli_min_size_t(num_blocks-i, histogramsPerBatch)
|
||||
var num_new_clusters uint
|
||||
var j uint
|
||||
for j = 0; j < num_to_combine; j++ {
|
||||
var k uint
|
||||
histogramClearCommand(&histograms[j])
|
||||
for k = 0; uint32(k) < block_lengths[i+j]; k++ {
|
||||
histogramAddCommand(&histograms[j], uint(data[pos]))
|
||||
pos++
|
||||
}
|
||||
|
||||
histograms[j].bit_cost_ = populationCostCommand(&histograms[j])
|
||||
new_clusters[j] = uint32(j)
|
||||
symbols[j] = uint32(j)
|
||||
sizes[j] = 1
|
||||
}
|
||||
|
||||
num_new_clusters = histogramCombineCommand(histograms, sizes[:], symbols[:], new_clusters[:], []histogramPair(pairs), num_to_combine, num_to_combine, histogramsPerBatch, max_num_pairs)
|
||||
if all_histograms_capacity < (all_histograms_size + num_new_clusters) {
|
||||
var _new_size uint
|
||||
if all_histograms_capacity == 0 {
|
||||
_new_size = all_histograms_size + num_new_clusters
|
||||
} else {
|
||||
_new_size = all_histograms_capacity
|
||||
}
|
||||
var new_array []histogramCommand
|
||||
for _new_size < (all_histograms_size + num_new_clusters) {
|
||||
_new_size *= 2
|
||||
}
|
||||
new_array = make([]histogramCommand, _new_size)
|
||||
if all_histograms_capacity != 0 {
|
||||
copy(new_array, all_histograms[:all_histograms_capacity])
|
||||
}
|
||||
|
||||
all_histograms = new_array
|
||||
all_histograms_capacity = _new_size
|
||||
}
|
||||
|
||||
brotli_ensure_capacity_uint32_t(&cluster_size, &cluster_size_capacity, cluster_size_size+num_new_clusters)
|
||||
for j = 0; j < num_new_clusters; j++ {
|
||||
all_histograms[all_histograms_size] = histograms[new_clusters[j]]
|
||||
all_histograms_size++
|
||||
cluster_size[cluster_size_size] = sizes[new_clusters[j]]
|
||||
cluster_size_size++
|
||||
remap[new_clusters[j]] = uint32(j)
|
||||
}
|
||||
|
||||
for j = 0; j < num_to_combine; j++ {
|
||||
histogram_symbols[i+j] = uint32(num_clusters) + remap[symbols[j]]
|
||||
}
|
||||
|
||||
num_clusters += num_new_clusters
|
||||
assert(num_clusters == cluster_size_size)
|
||||
assert(num_clusters == all_histograms_size)
|
||||
}
|
||||
|
||||
histograms = nil
|
||||
|
||||
max_num_pairs = brotli_min_size_t(64*num_clusters, (num_clusters/2)*num_clusters)
|
||||
if pairs_capacity < max_num_pairs+1 {
|
||||
pairs = nil
|
||||
pairs = make([]histogramPair, (max_num_pairs + 1))
|
||||
}
|
||||
|
||||
clusters = make([]uint32, num_clusters)
|
||||
for i = 0; i < num_clusters; i++ {
|
||||
clusters[i] = uint32(i)
|
||||
}
|
||||
|
||||
num_final_clusters = histogramCombineCommand(all_histograms, cluster_size, histogram_symbols, clusters, pairs, num_clusters, num_blocks, maxNumberOfBlockTypes, max_num_pairs)
|
||||
pairs = nil
|
||||
cluster_size = nil
|
||||
|
||||
new_index = make([]uint32, num_clusters)
|
||||
for i = 0; i < num_clusters; i++ {
|
||||
new_index[i] = clusterBlocksCommand_kInvalidIndex
|
||||
}
|
||||
pos = 0
|
||||
{
|
||||
var next_index uint32 = 0
|
||||
for i = 0; i < num_blocks; i++ {
|
||||
var histo histogramCommand
|
||||
var j uint
|
||||
var best_out uint32
|
||||
var best_bits float64
|
||||
histogramClearCommand(&histo)
|
||||
for j = 0; uint32(j) < block_lengths[i]; j++ {
|
||||
histogramAddCommand(&histo, uint(data[pos]))
|
||||
pos++
|
||||
}
|
||||
|
||||
if i == 0 {
|
||||
best_out = histogram_symbols[0]
|
||||
} else {
|
||||
best_out = histogram_symbols[i-1]
|
||||
}
|
||||
best_bits = histogramBitCostDistanceCommand(&histo, &all_histograms[best_out])
|
||||
for j = 0; j < num_final_clusters; j++ {
|
||||
var cur_bits float64 = histogramBitCostDistanceCommand(&histo, &all_histograms[clusters[j]])
|
||||
if cur_bits < best_bits {
|
||||
best_bits = cur_bits
|
||||
best_out = clusters[j]
|
||||
}
|
||||
}
|
||||
|
||||
histogram_symbols[i] = best_out
|
||||
if new_index[best_out] == clusterBlocksCommand_kInvalidIndex {
|
||||
new_index[best_out] = next_index
|
||||
next_index++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
clusters = nil
|
||||
all_histograms = nil
|
||||
brotli_ensure_capacity_uint8_t(&split.types, &split.types_alloc_size, num_blocks)
|
||||
brotli_ensure_capacity_uint32_t(&split.lengths, &split.lengths_alloc_size, num_blocks)
|
||||
{
|
||||
var cur_length uint32 = 0
|
||||
var block_idx uint = 0
|
||||
var max_type byte = 0
|
||||
for i = 0; i < num_blocks; i++ {
|
||||
cur_length += block_lengths[i]
|
||||
if i+1 == num_blocks || histogram_symbols[i] != histogram_symbols[i+1] {
|
||||
var id byte = byte(new_index[histogram_symbols[i]])
|
||||
split.types[block_idx] = id
|
||||
split.lengths[block_idx] = cur_length
|
||||
max_type = brotli_max_uint8_t(max_type, id)
|
||||
cur_length = 0
|
||||
block_idx++
|
||||
}
|
||||
}
|
||||
|
||||
split.num_blocks = block_idx
|
||||
split.num_types = uint(max_type) + 1
|
||||
}
|
||||
|
||||
new_index = nil
|
||||
block_lengths = nil
|
||||
histogram_symbols = nil
|
||||
}
|
||||
|
||||
func splitByteVectorCommand(data []uint16, literals_per_histogram uint, max_histograms uint, sampling_stride_length uint, block_switch_cost float64, params *encoderParams, split *blockSplit) {
|
||||
length := uint(len(data))
|
||||
var data_size uint = histogramDataSizeCommand()
|
||||
var num_histograms uint = length/literals_per_histogram + 1
|
||||
var histograms []histogramCommand
|
||||
if num_histograms > max_histograms {
|
||||
num_histograms = max_histograms
|
||||
}
|
||||
|
||||
if length == 0 {
|
||||
split.num_types = 1
|
||||
return
|
||||
} else if length < kMinLengthForBlockSplitting {
|
||||
brotli_ensure_capacity_uint8_t(&split.types, &split.types_alloc_size, split.num_blocks+1)
|
||||
brotli_ensure_capacity_uint32_t(&split.lengths, &split.lengths_alloc_size, split.num_blocks+1)
|
||||
split.num_types = 1
|
||||
split.types[split.num_blocks] = 0
|
||||
split.lengths[split.num_blocks] = uint32(length)
|
||||
split.num_blocks++
|
||||
return
|
||||
}
|
||||
|
||||
histograms = make([]histogramCommand, num_histograms)
|
||||
|
||||
/* Find good entropy codes. */
|
||||
initialEntropyCodesCommand(data, length, sampling_stride_length, num_histograms, histograms)
|
||||
|
||||
refineEntropyCodesCommand(data, length, sampling_stride_length, num_histograms, histograms)
|
||||
{
|
||||
var block_ids []byte = make([]byte, length)
|
||||
var num_blocks uint = 0
|
||||
var bitmaplen uint = (num_histograms + 7) >> 3
|
||||
var insert_cost []float64 = make([]float64, (data_size * num_histograms))
|
||||
var cost []float64 = make([]float64, num_histograms)
|
||||
var switch_signal []byte = make([]byte, (length * bitmaplen))
|
||||
var new_id []uint16 = make([]uint16, num_histograms)
|
||||
var iters uint
|
||||
if params.quality < hqZopflificationQuality {
|
||||
iters = 3
|
||||
} else {
|
||||
iters = 10
|
||||
}
|
||||
/* Find a good path through literals with the good entropy codes. */
|
||||
|
||||
var i uint
|
||||
for i = 0; i < iters; i++ {
|
||||
num_blocks = findBlocksCommand(data, length, block_switch_cost, num_histograms, histograms, insert_cost, cost, switch_signal, block_ids)
|
||||
num_histograms = remapBlockIdsCommand(block_ids, length, new_id, num_histograms)
|
||||
buildBlockHistogramsCommand(data, length, block_ids, num_histograms, histograms)
|
||||
}
|
||||
|
||||
insert_cost = nil
|
||||
cost = nil
|
||||
switch_signal = nil
|
||||
new_id = nil
|
||||
histograms = nil
|
||||
clusterBlocksCommand(data, length, num_blocks, block_ids, split)
|
||||
block_ids = nil
|
||||
}
|
||||
}
|
||||
+433
@@ -0,0 +1,433 @@
|
||||
package brotli
|
||||
|
||||
import "math"
|
||||
|
||||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
func initialEntropyCodesDistance(data []uint16, length uint, stride uint, num_histograms uint, histograms []histogramDistance) {
|
||||
var seed uint32 = 7
|
||||
var block_length uint = length / num_histograms
|
||||
var i uint
|
||||
clearHistogramsDistance(histograms, num_histograms)
|
||||
for i = 0; i < num_histograms; i++ {
|
||||
var pos uint = length * i / num_histograms
|
||||
if i != 0 {
|
||||
pos += uint(myRand(&seed) % uint32(block_length))
|
||||
}
|
||||
|
||||
if pos+stride >= length {
|
||||
pos = length - stride - 1
|
||||
}
|
||||
|
||||
histogramAddVectorDistance(&histograms[i], data[pos:], stride)
|
||||
}
|
||||
}
|
||||
|
||||
func randomSampleDistance(seed *uint32, data []uint16, length uint, stride uint, sample *histogramDistance) {
|
||||
var pos uint = 0
|
||||
if stride >= length {
|
||||
stride = length
|
||||
} else {
|
||||
pos = uint(myRand(seed) % uint32(length-stride+1))
|
||||
}
|
||||
|
||||
histogramAddVectorDistance(sample, data[pos:], stride)
|
||||
}
|
||||
|
||||
func refineEntropyCodesDistance(data []uint16, length uint, stride uint, num_histograms uint, histograms []histogramDistance) {
|
||||
var iters uint = kIterMulForRefining*length/stride + kMinItersForRefining
|
||||
var seed uint32 = 7
|
||||
var iter uint
|
||||
iters = ((iters + num_histograms - 1) / num_histograms) * num_histograms
|
||||
for iter = 0; iter < iters; iter++ {
|
||||
var sample histogramDistance
|
||||
histogramClearDistance(&sample)
|
||||
randomSampleDistance(&seed, data, length, stride, &sample)
|
||||
histogramAddHistogramDistance(&histograms[iter%num_histograms], &sample)
|
||||
}
|
||||
}
|
||||
|
||||
/* Assigns a block id from the range [0, num_histograms) to each data element
|
||||
in data[0..length) and fills in block_id[0..length) with the assigned values.
|
||||
Returns the number of blocks, i.e. one plus the number of block switches. */
|
||||
func findBlocksDistance(data []uint16, length uint, block_switch_bitcost float64, num_histograms uint, histograms []histogramDistance, insert_cost []float64, cost []float64, switch_signal []byte, block_id []byte) uint {
|
||||
var data_size uint = histogramDataSizeDistance()
|
||||
var bitmaplen uint = (num_histograms + 7) >> 3
|
||||
var num_blocks uint = 1
|
||||
var i uint
|
||||
var j uint
|
||||
assert(num_histograms <= 256)
|
||||
if num_histograms <= 1 {
|
||||
for i = 0; i < length; i++ {
|
||||
block_id[i] = 0
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
for i := 0; i < int(data_size*num_histograms); i++ {
|
||||
insert_cost[i] = 0
|
||||
}
|
||||
for i = 0; i < num_histograms; i++ {
|
||||
insert_cost[i] = fastLog2(uint(uint32(histograms[i].total_count_)))
|
||||
}
|
||||
|
||||
for i = data_size; i != 0; {
|
||||
i--
|
||||
for j = 0; j < num_histograms; j++ {
|
||||
insert_cost[i*num_histograms+j] = insert_cost[j] - bitCost(uint(histograms[j].data_[i]))
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < int(num_histograms); i++ {
|
||||
cost[i] = 0
|
||||
}
|
||||
for i := 0; i < int(length*bitmaplen); i++ {
|
||||
switch_signal[i] = 0
|
||||
}
|
||||
|
||||
/* After each iteration of this loop, cost[k] will contain the difference
|
||||
between the minimum cost of arriving at the current byte position using
|
||||
entropy code k, and the minimum cost of arriving at the current byte
|
||||
position. This difference is capped at the block switch cost, and if it
|
||||
reaches block switch cost, it means that when we trace back from the last
|
||||
position, we need to switch here. */
|
||||
for i = 0; i < length; i++ {
|
||||
var byte_ix uint = i
|
||||
var ix uint = byte_ix * bitmaplen
|
||||
var insert_cost_ix uint = uint(data[byte_ix]) * num_histograms
|
||||
var min_cost float64 = 1e99
|
||||
var block_switch_cost float64 = block_switch_bitcost
|
||||
var k uint
|
||||
for k = 0; k < num_histograms; k++ {
|
||||
/* We are coding the symbol in data[byte_ix] with entropy code k. */
|
||||
cost[k] += insert_cost[insert_cost_ix+k]
|
||||
|
||||
if cost[k] < min_cost {
|
||||
min_cost = cost[k]
|
||||
block_id[byte_ix] = byte(k)
|
||||
}
|
||||
}
|
||||
|
||||
/* More blocks for the beginning. */
|
||||
if byte_ix < 2000 {
|
||||
block_switch_cost *= 0.77 + 0.07*float64(byte_ix)/2000
|
||||
}
|
||||
|
||||
for k = 0; k < num_histograms; k++ {
|
||||
cost[k] -= min_cost
|
||||
if cost[k] >= block_switch_cost {
|
||||
var mask byte = byte(1 << (k & 7))
|
||||
cost[k] = block_switch_cost
|
||||
assert(k>>3 < bitmaplen)
|
||||
switch_signal[ix+(k>>3)] |= mask
|
||||
/* Trace back from the last position and switch at the marked places. */
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
var byte_ix uint = length - 1
|
||||
var ix uint = byte_ix * bitmaplen
|
||||
var cur_id byte = block_id[byte_ix]
|
||||
for byte_ix > 0 {
|
||||
var mask byte = byte(1 << (cur_id & 7))
|
||||
assert(uint(cur_id)>>3 < bitmaplen)
|
||||
byte_ix--
|
||||
ix -= bitmaplen
|
||||
if switch_signal[ix+uint(cur_id>>3)]&mask != 0 {
|
||||
if cur_id != block_id[byte_ix] {
|
||||
cur_id = block_id[byte_ix]
|
||||
num_blocks++
|
||||
}
|
||||
}
|
||||
|
||||
block_id[byte_ix] = cur_id
|
||||
}
|
||||
}
|
||||
|
||||
return num_blocks
|
||||
}
|
||||
|
||||
var remapBlockIdsDistance_kInvalidId uint16 = 256
|
||||
|
||||
func remapBlockIdsDistance(block_ids []byte, length uint, new_id []uint16, num_histograms uint) uint {
|
||||
var next_id uint16 = 0
|
||||
var i uint
|
||||
for i = 0; i < num_histograms; i++ {
|
||||
new_id[i] = remapBlockIdsDistance_kInvalidId
|
||||
}
|
||||
|
||||
for i = 0; i < length; i++ {
|
||||
assert(uint(block_ids[i]) < num_histograms)
|
||||
if new_id[block_ids[i]] == remapBlockIdsDistance_kInvalidId {
|
||||
new_id[block_ids[i]] = next_id
|
||||
next_id++
|
||||
}
|
||||
}
|
||||
|
||||
for i = 0; i < length; i++ {
|
||||
block_ids[i] = byte(new_id[block_ids[i]])
|
||||
assert(uint(block_ids[i]) < num_histograms)
|
||||
}
|
||||
|
||||
assert(uint(next_id) <= num_histograms)
|
||||
return uint(next_id)
|
||||
}
|
||||
|
||||
func buildBlockHistogramsDistance(data []uint16, length uint, block_ids []byte, num_histograms uint, histograms []histogramDistance) {
|
||||
var i uint
|
||||
clearHistogramsDistance(histograms, num_histograms)
|
||||
for i = 0; i < length; i++ {
|
||||
histogramAddDistance(&histograms[block_ids[i]], uint(data[i]))
|
||||
}
|
||||
}
|
||||
|
||||
var clusterBlocksDistance_kInvalidIndex uint32 = math.MaxUint32
|
||||
|
||||
func clusterBlocksDistance(data []uint16, length uint, num_blocks uint, block_ids []byte, split *blockSplit) {
|
||||
var histogram_symbols []uint32 = make([]uint32, num_blocks)
|
||||
var block_lengths []uint32 = make([]uint32, num_blocks)
|
||||
var expected_num_clusters uint = clustersPerBatch * (num_blocks + histogramsPerBatch - 1) / histogramsPerBatch
|
||||
var all_histograms_size uint = 0
|
||||
var all_histograms_capacity uint = expected_num_clusters
|
||||
var all_histograms []histogramDistance = make([]histogramDistance, all_histograms_capacity)
|
||||
var cluster_size_size uint = 0
|
||||
var cluster_size_capacity uint = expected_num_clusters
|
||||
var cluster_size []uint32 = make([]uint32, cluster_size_capacity)
|
||||
var num_clusters uint = 0
|
||||
var histograms []histogramDistance = make([]histogramDistance, brotli_min_size_t(num_blocks, histogramsPerBatch))
|
||||
var max_num_pairs uint = histogramsPerBatch * histogramsPerBatch / 2
|
||||
var pairs_capacity uint = max_num_pairs + 1
|
||||
var pairs []histogramPair = make([]histogramPair, pairs_capacity)
|
||||
var pos uint = 0
|
||||
var clusters []uint32
|
||||
var num_final_clusters uint
|
||||
var new_index []uint32
|
||||
var i uint
|
||||
var sizes = [histogramsPerBatch]uint32{0}
|
||||
var new_clusters = [histogramsPerBatch]uint32{0}
|
||||
var symbols = [histogramsPerBatch]uint32{0}
|
||||
var remap = [histogramsPerBatch]uint32{0}
|
||||
|
||||
for i := 0; i < int(num_blocks); i++ {
|
||||
block_lengths[i] = 0
|
||||
}
|
||||
{
|
||||
var block_idx uint = 0
|
||||
for i = 0; i < length; i++ {
|
||||
assert(block_idx < num_blocks)
|
||||
block_lengths[block_idx]++
|
||||
if i+1 == length || block_ids[i] != block_ids[i+1] {
|
||||
block_idx++
|
||||
}
|
||||
}
|
||||
|
||||
assert(block_idx == num_blocks)
|
||||
}
|
||||
|
||||
for i = 0; i < num_blocks; i += histogramsPerBatch {
|
||||
var num_to_combine uint = brotli_min_size_t(num_blocks-i, histogramsPerBatch)
|
||||
var num_new_clusters uint
|
||||
var j uint
|
||||
for j = 0; j < num_to_combine; j++ {
|
||||
var k uint
|
||||
histogramClearDistance(&histograms[j])
|
||||
for k = 0; uint32(k) < block_lengths[i+j]; k++ {
|
||||
histogramAddDistance(&histograms[j], uint(data[pos]))
|
||||
pos++
|
||||
}
|
||||
|
||||
histograms[j].bit_cost_ = populationCostDistance(&histograms[j])
|
||||
new_clusters[j] = uint32(j)
|
||||
symbols[j] = uint32(j)
|
||||
sizes[j] = 1
|
||||
}
|
||||
|
||||
num_new_clusters = histogramCombineDistance(histograms, sizes[:], symbols[:], new_clusters[:], []histogramPair(pairs), num_to_combine, num_to_combine, histogramsPerBatch, max_num_pairs)
|
||||
if all_histograms_capacity < (all_histograms_size + num_new_clusters) {
|
||||
var _new_size uint
|
||||
if all_histograms_capacity == 0 {
|
||||
_new_size = all_histograms_size + num_new_clusters
|
||||
} else {
|
||||
_new_size = all_histograms_capacity
|
||||
}
|
||||
var new_array []histogramDistance
|
||||
for _new_size < (all_histograms_size + num_new_clusters) {
|
||||
_new_size *= 2
|
||||
}
|
||||
new_array = make([]histogramDistance, _new_size)
|
||||
if all_histograms_capacity != 0 {
|
||||
copy(new_array, all_histograms[:all_histograms_capacity])
|
||||
}
|
||||
|
||||
all_histograms = new_array
|
||||
all_histograms_capacity = _new_size
|
||||
}
|
||||
|
||||
brotli_ensure_capacity_uint32_t(&cluster_size, &cluster_size_capacity, cluster_size_size+num_new_clusters)
|
||||
for j = 0; j < num_new_clusters; j++ {
|
||||
all_histograms[all_histograms_size] = histograms[new_clusters[j]]
|
||||
all_histograms_size++
|
||||
cluster_size[cluster_size_size] = sizes[new_clusters[j]]
|
||||
cluster_size_size++
|
||||
remap[new_clusters[j]] = uint32(j)
|
||||
}
|
||||
|
||||
for j = 0; j < num_to_combine; j++ {
|
||||
histogram_symbols[i+j] = uint32(num_clusters) + remap[symbols[j]]
|
||||
}
|
||||
|
||||
num_clusters += num_new_clusters
|
||||
assert(num_clusters == cluster_size_size)
|
||||
assert(num_clusters == all_histograms_size)
|
||||
}
|
||||
|
||||
histograms = nil
|
||||
|
||||
max_num_pairs = brotli_min_size_t(64*num_clusters, (num_clusters/2)*num_clusters)
|
||||
if pairs_capacity < max_num_pairs+1 {
|
||||
pairs = nil
|
||||
pairs = make([]histogramPair, (max_num_pairs + 1))
|
||||
}
|
||||
|
||||
clusters = make([]uint32, num_clusters)
|
||||
for i = 0; i < num_clusters; i++ {
|
||||
clusters[i] = uint32(i)
|
||||
}
|
||||
|
||||
num_final_clusters = histogramCombineDistance(all_histograms, cluster_size, histogram_symbols, clusters, pairs, num_clusters, num_blocks, maxNumberOfBlockTypes, max_num_pairs)
|
||||
pairs = nil
|
||||
cluster_size = nil
|
||||
|
||||
new_index = make([]uint32, num_clusters)
|
||||
for i = 0; i < num_clusters; i++ {
|
||||
new_index[i] = clusterBlocksDistance_kInvalidIndex
|
||||
}
|
||||
pos = 0
|
||||
{
|
||||
var next_index uint32 = 0
|
||||
for i = 0; i < num_blocks; i++ {
|
||||
var histo histogramDistance
|
||||
var j uint
|
||||
var best_out uint32
|
||||
var best_bits float64
|
||||
histogramClearDistance(&histo)
|
||||
for j = 0; uint32(j) < block_lengths[i]; j++ {
|
||||
histogramAddDistance(&histo, uint(data[pos]))
|
||||
pos++
|
||||
}
|
||||
|
||||
if i == 0 {
|
||||
best_out = histogram_symbols[0]
|
||||
} else {
|
||||
best_out = histogram_symbols[i-1]
|
||||
}
|
||||
best_bits = histogramBitCostDistanceDistance(&histo, &all_histograms[best_out])
|
||||
for j = 0; j < num_final_clusters; j++ {
|
||||
var cur_bits float64 = histogramBitCostDistanceDistance(&histo, &all_histograms[clusters[j]])
|
||||
if cur_bits < best_bits {
|
||||
best_bits = cur_bits
|
||||
best_out = clusters[j]
|
||||
}
|
||||
}
|
||||
|
||||
histogram_symbols[i] = best_out
|
||||
if new_index[best_out] == clusterBlocksDistance_kInvalidIndex {
|
||||
new_index[best_out] = next_index
|
||||
next_index++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
clusters = nil
|
||||
all_histograms = nil
|
||||
brotli_ensure_capacity_uint8_t(&split.types, &split.types_alloc_size, num_blocks)
|
||||
brotli_ensure_capacity_uint32_t(&split.lengths, &split.lengths_alloc_size, num_blocks)
|
||||
{
|
||||
var cur_length uint32 = 0
|
||||
var block_idx uint = 0
|
||||
var max_type byte = 0
|
||||
for i = 0; i < num_blocks; i++ {
|
||||
cur_length += block_lengths[i]
|
||||
if i+1 == num_blocks || histogram_symbols[i] != histogram_symbols[i+1] {
|
||||
var id byte = byte(new_index[histogram_symbols[i]])
|
||||
split.types[block_idx] = id
|
||||
split.lengths[block_idx] = cur_length
|
||||
max_type = brotli_max_uint8_t(max_type, id)
|
||||
cur_length = 0
|
||||
block_idx++
|
||||
}
|
||||
}
|
||||
|
||||
split.num_blocks = block_idx
|
||||
split.num_types = uint(max_type) + 1
|
||||
}
|
||||
|
||||
new_index = nil
|
||||
block_lengths = nil
|
||||
histogram_symbols = nil
|
||||
}
|
||||
|
||||
func splitByteVectorDistance(data []uint16, length uint, literals_per_histogram uint, max_histograms uint, sampling_stride_length uint, block_switch_cost float64, params *encoderParams, split *blockSplit) {
|
||||
var data_size uint = histogramDataSizeDistance()
|
||||
var num_histograms uint = length/literals_per_histogram + 1
|
||||
var histograms []histogramDistance
|
||||
if num_histograms > max_histograms {
|
||||
num_histograms = max_histograms
|
||||
}
|
||||
|
||||
if length == 0 {
|
||||
split.num_types = 1
|
||||
return
|
||||
} else if length < kMinLengthForBlockSplitting {
|
||||
brotli_ensure_capacity_uint8_t(&split.types, &split.types_alloc_size, split.num_blocks+1)
|
||||
brotli_ensure_capacity_uint32_t(&split.lengths, &split.lengths_alloc_size, split.num_blocks+1)
|
||||
split.num_types = 1
|
||||
split.types[split.num_blocks] = 0
|
||||
split.lengths[split.num_blocks] = uint32(length)
|
||||
split.num_blocks++
|
||||
return
|
||||
}
|
||||
|
||||
histograms = make([]histogramDistance, num_histograms)
|
||||
|
||||
/* Find good entropy codes. */
|
||||
initialEntropyCodesDistance(data, length, sampling_stride_length, num_histograms, histograms)
|
||||
|
||||
refineEntropyCodesDistance(data, length, sampling_stride_length, num_histograms, histograms)
|
||||
{
|
||||
var block_ids []byte = make([]byte, length)
|
||||
var num_blocks uint = 0
|
||||
var bitmaplen uint = (num_histograms + 7) >> 3
|
||||
var insert_cost []float64 = make([]float64, (data_size * num_histograms))
|
||||
var cost []float64 = make([]float64, num_histograms)
|
||||
var switch_signal []byte = make([]byte, (length * bitmaplen))
|
||||
var new_id []uint16 = make([]uint16, num_histograms)
|
||||
var iters uint
|
||||
if params.quality < hqZopflificationQuality {
|
||||
iters = 3
|
||||
} else {
|
||||
iters = 10
|
||||
}
|
||||
/* Find a good path through literals with the good entropy codes. */
|
||||
|
||||
var i uint
|
||||
for i = 0; i < iters; i++ {
|
||||
num_blocks = findBlocksDistance(data, length, block_switch_cost, num_histograms, histograms, insert_cost, cost, switch_signal, block_ids)
|
||||
num_histograms = remapBlockIdsDistance(block_ids, length, new_id, num_histograms)
|
||||
buildBlockHistogramsDistance(data, length, block_ids, num_histograms, histograms)
|
||||
}
|
||||
|
||||
insert_cost = nil
|
||||
cost = nil
|
||||
switch_signal = nil
|
||||
new_id = nil
|
||||
histograms = nil
|
||||
clusterBlocksDistance(data, length, num_blocks, block_ids, split)
|
||||
block_ids = nil
|
||||
}
|
||||
}
|
||||
+433
@@ -0,0 +1,433 @@
|
||||
package brotli
|
||||
|
||||
import "math"
|
||||
|
||||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
func initialEntropyCodesLiteral(data []byte, length uint, stride uint, num_histograms uint, histograms []histogramLiteral) {
|
||||
var seed uint32 = 7
|
||||
var block_length uint = length / num_histograms
|
||||
var i uint
|
||||
clearHistogramsLiteral(histograms, num_histograms)
|
||||
for i = 0; i < num_histograms; i++ {
|
||||
var pos uint = length * i / num_histograms
|
||||
if i != 0 {
|
||||
pos += uint(myRand(&seed) % uint32(block_length))
|
||||
}
|
||||
|
||||
if pos+stride >= length {
|
||||
pos = length - stride - 1
|
||||
}
|
||||
|
||||
histogramAddVectorLiteral(&histograms[i], data[pos:], stride)
|
||||
}
|
||||
}
|
||||
|
||||
func randomSampleLiteral(seed *uint32, data []byte, length uint, stride uint, sample *histogramLiteral) {
|
||||
var pos uint = 0
|
||||
if stride >= length {
|
||||
stride = length
|
||||
} else {
|
||||
pos = uint(myRand(seed) % uint32(length-stride+1))
|
||||
}
|
||||
|
||||
histogramAddVectorLiteral(sample, data[pos:], stride)
|
||||
}
|
||||
|
||||
func refineEntropyCodesLiteral(data []byte, length uint, stride uint, num_histograms uint, histograms []histogramLiteral) {
|
||||
var iters uint = kIterMulForRefining*length/stride + kMinItersForRefining
|
||||
var seed uint32 = 7
|
||||
var iter uint
|
||||
iters = ((iters + num_histograms - 1) / num_histograms) * num_histograms
|
||||
for iter = 0; iter < iters; iter++ {
|
||||
var sample histogramLiteral
|
||||
histogramClearLiteral(&sample)
|
||||
randomSampleLiteral(&seed, data, length, stride, &sample)
|
||||
histogramAddHistogramLiteral(&histograms[iter%num_histograms], &sample)
|
||||
}
|
||||
}
|
||||
|
||||
/* Assigns a block id from the range [0, num_histograms) to each data element
|
||||
in data[0..length) and fills in block_id[0..length) with the assigned values.
|
||||
Returns the number of blocks, i.e. one plus the number of block switches. */
|
||||
func findBlocksLiteral(data []byte, length uint, block_switch_bitcost float64, num_histograms uint, histograms []histogramLiteral, insert_cost []float64, cost []float64, switch_signal []byte, block_id []byte) uint {
|
||||
var data_size uint = histogramDataSizeLiteral()
|
||||
var bitmaplen uint = (num_histograms + 7) >> 3
|
||||
var num_blocks uint = 1
|
||||
var i uint
|
||||
var j uint
|
||||
assert(num_histograms <= 256)
|
||||
if num_histograms <= 1 {
|
||||
for i = 0; i < length; i++ {
|
||||
block_id[i] = 0
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
for i := 0; i < int(data_size*num_histograms); i++ {
|
||||
insert_cost[i] = 0
|
||||
}
|
||||
for i = 0; i < num_histograms; i++ {
|
||||
insert_cost[i] = fastLog2(uint(uint32(histograms[i].total_count_)))
|
||||
}
|
||||
|
||||
for i = data_size; i != 0; {
|
||||
i--
|
||||
for j = 0; j < num_histograms; j++ {
|
||||
insert_cost[i*num_histograms+j] = insert_cost[j] - bitCost(uint(histograms[j].data_[i]))
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < int(num_histograms); i++ {
|
||||
cost[i] = 0
|
||||
}
|
||||
for i := 0; i < int(length*bitmaplen); i++ {
|
||||
switch_signal[i] = 0
|
||||
}
|
||||
|
||||
/* After each iteration of this loop, cost[k] will contain the difference
|
||||
between the minimum cost of arriving at the current byte position using
|
||||
entropy code k, and the minimum cost of arriving at the current byte
|
||||
position. This difference is capped at the block switch cost, and if it
|
||||
reaches block switch cost, it means that when we trace back from the last
|
||||
position, we need to switch here. */
|
||||
for i = 0; i < length; i++ {
|
||||
var byte_ix uint = i
|
||||
var ix uint = byte_ix * bitmaplen
|
||||
var insert_cost_ix uint = uint(data[byte_ix]) * num_histograms
|
||||
var min_cost float64 = 1e99
|
||||
var block_switch_cost float64 = block_switch_bitcost
|
||||
var k uint
|
||||
for k = 0; k < num_histograms; k++ {
|
||||
/* We are coding the symbol in data[byte_ix] with entropy code k. */
|
||||
cost[k] += insert_cost[insert_cost_ix+k]
|
||||
|
||||
if cost[k] < min_cost {
|
||||
min_cost = cost[k]
|
||||
block_id[byte_ix] = byte(k)
|
||||
}
|
||||
}
|
||||
|
||||
/* More blocks for the beginning. */
|
||||
if byte_ix < 2000 {
|
||||
block_switch_cost *= 0.77 + 0.07*float64(byte_ix)/2000
|
||||
}
|
||||
|
||||
for k = 0; k < num_histograms; k++ {
|
||||
cost[k] -= min_cost
|
||||
if cost[k] >= block_switch_cost {
|
||||
var mask byte = byte(1 << (k & 7))
|
||||
cost[k] = block_switch_cost
|
||||
assert(k>>3 < bitmaplen)
|
||||
switch_signal[ix+(k>>3)] |= mask
|
||||
/* Trace back from the last position and switch at the marked places. */
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
var byte_ix uint = length - 1
|
||||
var ix uint = byte_ix * bitmaplen
|
||||
var cur_id byte = block_id[byte_ix]
|
||||
for byte_ix > 0 {
|
||||
var mask byte = byte(1 << (cur_id & 7))
|
||||
assert(uint(cur_id)>>3 < bitmaplen)
|
||||
byte_ix--
|
||||
ix -= bitmaplen
|
||||
if switch_signal[ix+uint(cur_id>>3)]&mask != 0 {
|
||||
if cur_id != block_id[byte_ix] {
|
||||
cur_id = block_id[byte_ix]
|
||||
num_blocks++
|
||||
}
|
||||
}
|
||||
|
||||
block_id[byte_ix] = cur_id
|
||||
}
|
||||
}
|
||||
|
||||
return num_blocks
|
||||
}
|
||||
|
||||
var remapBlockIdsLiteral_kInvalidId uint16 = 256
|
||||
|
||||
func remapBlockIdsLiteral(block_ids []byte, length uint, new_id []uint16, num_histograms uint) uint {
|
||||
var next_id uint16 = 0
|
||||
var i uint
|
||||
for i = 0; i < num_histograms; i++ {
|
||||
new_id[i] = remapBlockIdsLiteral_kInvalidId
|
||||
}
|
||||
|
||||
for i = 0; i < length; i++ {
|
||||
assert(uint(block_ids[i]) < num_histograms)
|
||||
if new_id[block_ids[i]] == remapBlockIdsLiteral_kInvalidId {
|
||||
new_id[block_ids[i]] = next_id
|
||||
next_id++
|
||||
}
|
||||
}
|
||||
|
||||
for i = 0; i < length; i++ {
|
||||
block_ids[i] = byte(new_id[block_ids[i]])
|
||||
assert(uint(block_ids[i]) < num_histograms)
|
||||
}
|
||||
|
||||
assert(uint(next_id) <= num_histograms)
|
||||
return uint(next_id)
|
||||
}
|
||||
|
||||
func buildBlockHistogramsLiteral(data []byte, length uint, block_ids []byte, num_histograms uint, histograms []histogramLiteral) {
|
||||
var i uint
|
||||
clearHistogramsLiteral(histograms, num_histograms)
|
||||
for i = 0; i < length; i++ {
|
||||
histogramAddLiteral(&histograms[block_ids[i]], uint(data[i]))
|
||||
}
|
||||
}
|
||||
|
||||
var clusterBlocksLiteral_kInvalidIndex uint32 = math.MaxUint32
|
||||
|
||||
func clusterBlocksLiteral(data []byte, length uint, num_blocks uint, block_ids []byte, split *blockSplit) {
|
||||
var histogram_symbols []uint32 = make([]uint32, num_blocks)
|
||||
var block_lengths []uint32 = make([]uint32, num_blocks)
|
||||
var expected_num_clusters uint = clustersPerBatch * (num_blocks + histogramsPerBatch - 1) / histogramsPerBatch
|
||||
var all_histograms_size uint = 0
|
||||
var all_histograms_capacity uint = expected_num_clusters
|
||||
var all_histograms []histogramLiteral = make([]histogramLiteral, all_histograms_capacity)
|
||||
var cluster_size_size uint = 0
|
||||
var cluster_size_capacity uint = expected_num_clusters
|
||||
var cluster_size []uint32 = make([]uint32, cluster_size_capacity)
|
||||
var num_clusters uint = 0
|
||||
var histograms []histogramLiteral = make([]histogramLiteral, brotli_min_size_t(num_blocks, histogramsPerBatch))
|
||||
var max_num_pairs uint = histogramsPerBatch * histogramsPerBatch / 2
|
||||
var pairs_capacity uint = max_num_pairs + 1
|
||||
var pairs []histogramPair = make([]histogramPair, pairs_capacity)
|
||||
var pos uint = 0
|
||||
var clusters []uint32
|
||||
var num_final_clusters uint
|
||||
var new_index []uint32
|
||||
var i uint
|
||||
var sizes = [histogramsPerBatch]uint32{0}
|
||||
var new_clusters = [histogramsPerBatch]uint32{0}
|
||||
var symbols = [histogramsPerBatch]uint32{0}
|
||||
var remap = [histogramsPerBatch]uint32{0}
|
||||
|
||||
for i := 0; i < int(num_blocks); i++ {
|
||||
block_lengths[i] = 0
|
||||
}
|
||||
{
|
||||
var block_idx uint = 0
|
||||
for i = 0; i < length; i++ {
|
||||
assert(block_idx < num_blocks)
|
||||
block_lengths[block_idx]++
|
||||
if i+1 == length || block_ids[i] != block_ids[i+1] {
|
||||
block_idx++
|
||||
}
|
||||
}
|
||||
|
||||
assert(block_idx == num_blocks)
|
||||
}
|
||||
|
||||
for i = 0; i < num_blocks; i += histogramsPerBatch {
|
||||
var num_to_combine uint = brotli_min_size_t(num_blocks-i, histogramsPerBatch)
|
||||
var num_new_clusters uint
|
||||
var j uint
|
||||
for j = 0; j < num_to_combine; j++ {
|
||||
var k uint
|
||||
histogramClearLiteral(&histograms[j])
|
||||
for k = 0; uint32(k) < block_lengths[i+j]; k++ {
|
||||
histogramAddLiteral(&histograms[j], uint(data[pos]))
|
||||
pos++
|
||||
}
|
||||
|
||||
histograms[j].bit_cost_ = populationCostLiteral(&histograms[j])
|
||||
new_clusters[j] = uint32(j)
|
||||
symbols[j] = uint32(j)
|
||||
sizes[j] = 1
|
||||
}
|
||||
|
||||
num_new_clusters = histogramCombineLiteral(histograms, sizes[:], symbols[:], new_clusters[:], []histogramPair(pairs), num_to_combine, num_to_combine, histogramsPerBatch, max_num_pairs)
|
||||
if all_histograms_capacity < (all_histograms_size + num_new_clusters) {
|
||||
var _new_size uint
|
||||
if all_histograms_capacity == 0 {
|
||||
_new_size = all_histograms_size + num_new_clusters
|
||||
} else {
|
||||
_new_size = all_histograms_capacity
|
||||
}
|
||||
var new_array []histogramLiteral
|
||||
for _new_size < (all_histograms_size + num_new_clusters) {
|
||||
_new_size *= 2
|
||||
}
|
||||
new_array = make([]histogramLiteral, _new_size)
|
||||
if all_histograms_capacity != 0 {
|
||||
copy(new_array, all_histograms[:all_histograms_capacity])
|
||||
}
|
||||
|
||||
all_histograms = new_array
|
||||
all_histograms_capacity = _new_size
|
||||
}
|
||||
|
||||
brotli_ensure_capacity_uint32_t(&cluster_size, &cluster_size_capacity, cluster_size_size+num_new_clusters)
|
||||
for j = 0; j < num_new_clusters; j++ {
|
||||
all_histograms[all_histograms_size] = histograms[new_clusters[j]]
|
||||
all_histograms_size++
|
||||
cluster_size[cluster_size_size] = sizes[new_clusters[j]]
|
||||
cluster_size_size++
|
||||
remap[new_clusters[j]] = uint32(j)
|
||||
}
|
||||
|
||||
for j = 0; j < num_to_combine; j++ {
|
||||
histogram_symbols[i+j] = uint32(num_clusters) + remap[symbols[j]]
|
||||
}
|
||||
|
||||
num_clusters += num_new_clusters
|
||||
assert(num_clusters == cluster_size_size)
|
||||
assert(num_clusters == all_histograms_size)
|
||||
}
|
||||
|
||||
histograms = nil
|
||||
|
||||
max_num_pairs = brotli_min_size_t(64*num_clusters, (num_clusters/2)*num_clusters)
|
||||
if pairs_capacity < max_num_pairs+1 {
|
||||
pairs = nil
|
||||
pairs = make([]histogramPair, (max_num_pairs + 1))
|
||||
}
|
||||
|
||||
clusters = make([]uint32, num_clusters)
|
||||
for i = 0; i < num_clusters; i++ {
|
||||
clusters[i] = uint32(i)
|
||||
}
|
||||
|
||||
num_final_clusters = histogramCombineLiteral(all_histograms, cluster_size, histogram_symbols, clusters, pairs, num_clusters, num_blocks, maxNumberOfBlockTypes, max_num_pairs)
|
||||
pairs = nil
|
||||
cluster_size = nil
|
||||
|
||||
new_index = make([]uint32, num_clusters)
|
||||
for i = 0; i < num_clusters; i++ {
|
||||
new_index[i] = clusterBlocksLiteral_kInvalidIndex
|
||||
}
|
||||
pos = 0
|
||||
{
|
||||
var next_index uint32 = 0
|
||||
for i = 0; i < num_blocks; i++ {
|
||||
var histo histogramLiteral
|
||||
var j uint
|
||||
var best_out uint32
|
||||
var best_bits float64
|
||||
histogramClearLiteral(&histo)
|
||||
for j = 0; uint32(j) < block_lengths[i]; j++ {
|
||||
histogramAddLiteral(&histo, uint(data[pos]))
|
||||
pos++
|
||||
}
|
||||
|
||||
if i == 0 {
|
||||
best_out = histogram_symbols[0]
|
||||
} else {
|
||||
best_out = histogram_symbols[i-1]
|
||||
}
|
||||
best_bits = histogramBitCostDistanceLiteral(&histo, &all_histograms[best_out])
|
||||
for j = 0; j < num_final_clusters; j++ {
|
||||
var cur_bits float64 = histogramBitCostDistanceLiteral(&histo, &all_histograms[clusters[j]])
|
||||
if cur_bits < best_bits {
|
||||
best_bits = cur_bits
|
||||
best_out = clusters[j]
|
||||
}
|
||||
}
|
||||
|
||||
histogram_symbols[i] = best_out
|
||||
if new_index[best_out] == clusterBlocksLiteral_kInvalidIndex {
|
||||
new_index[best_out] = next_index
|
||||
next_index++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
clusters = nil
|
||||
all_histograms = nil
|
||||
brotli_ensure_capacity_uint8_t(&split.types, &split.types_alloc_size, num_blocks)
|
||||
brotli_ensure_capacity_uint32_t(&split.lengths, &split.lengths_alloc_size, num_blocks)
|
||||
{
|
||||
var cur_length uint32 = 0
|
||||
var block_idx uint = 0
|
||||
var max_type byte = 0
|
||||
for i = 0; i < num_blocks; i++ {
|
||||
cur_length += block_lengths[i]
|
||||
if i+1 == num_blocks || histogram_symbols[i] != histogram_symbols[i+1] {
|
||||
var id byte = byte(new_index[histogram_symbols[i]])
|
||||
split.types[block_idx] = id
|
||||
split.lengths[block_idx] = cur_length
|
||||
max_type = brotli_max_uint8_t(max_type, id)
|
||||
cur_length = 0
|
||||
block_idx++
|
||||
}
|
||||
}
|
||||
|
||||
split.num_blocks = block_idx
|
||||
split.num_types = uint(max_type) + 1
|
||||
}
|
||||
|
||||
new_index = nil
|
||||
block_lengths = nil
|
||||
histogram_symbols = nil
|
||||
}
|
||||
|
||||
func splitByteVectorLiteral(data []byte, length uint, literals_per_histogram uint, max_histograms uint, sampling_stride_length uint, block_switch_cost float64, params *encoderParams, split *blockSplit) {
|
||||
var data_size uint = histogramDataSizeLiteral()
|
||||
var num_histograms uint = length/literals_per_histogram + 1
|
||||
var histograms []histogramLiteral
|
||||
if num_histograms > max_histograms {
|
||||
num_histograms = max_histograms
|
||||
}
|
||||
|
||||
if length == 0 {
|
||||
split.num_types = 1
|
||||
return
|
||||
} else if length < kMinLengthForBlockSplitting {
|
||||
brotli_ensure_capacity_uint8_t(&split.types, &split.types_alloc_size, split.num_blocks+1)
|
||||
brotli_ensure_capacity_uint32_t(&split.lengths, &split.lengths_alloc_size, split.num_blocks+1)
|
||||
split.num_types = 1
|
||||
split.types[split.num_blocks] = 0
|
||||
split.lengths[split.num_blocks] = uint32(length)
|
||||
split.num_blocks++
|
||||
return
|
||||
}
|
||||
|
||||
histograms = make([]histogramLiteral, num_histograms)
|
||||
|
||||
/* Find good entropy codes. */
|
||||
initialEntropyCodesLiteral(data, length, sampling_stride_length, num_histograms, histograms)
|
||||
|
||||
refineEntropyCodesLiteral(data, length, sampling_stride_length, num_histograms, histograms)
|
||||
{
|
||||
var block_ids []byte = make([]byte, length)
|
||||
var num_blocks uint = 0
|
||||
var bitmaplen uint = (num_histograms + 7) >> 3
|
||||
var insert_cost []float64 = make([]float64, (data_size * num_histograms))
|
||||
var cost []float64 = make([]float64, num_histograms)
|
||||
var switch_signal []byte = make([]byte, (length * bitmaplen))
|
||||
var new_id []uint16 = make([]uint16, num_histograms)
|
||||
var iters uint
|
||||
if params.quality < hqZopflificationQuality {
|
||||
iters = 3
|
||||
} else {
|
||||
iters = 10
|
||||
}
|
||||
/* Find a good path through literals with the good entropy codes. */
|
||||
|
||||
var i uint
|
||||
for i = 0; i < iters; i++ {
|
||||
num_blocks = findBlocksLiteral(data, length, block_switch_cost, num_histograms, histograms, insert_cost, cost, switch_signal, block_ids)
|
||||
num_histograms = remapBlockIdsLiteral(block_ids, length, new_id, num_histograms)
|
||||
buildBlockHistogramsLiteral(data, length, block_ids, num_histograms, histograms)
|
||||
}
|
||||
|
||||
insert_cost = nil
|
||||
cost = nil
|
||||
switch_signal = nil
|
||||
new_id = nil
|
||||
histograms = nil
|
||||
clusterBlocksLiteral(data, length, num_blocks, block_ids, split)
|
||||
block_ids = nil
|
||||
}
|
||||
}
|
||||
+1300
File diff suppressed because it is too large
Load Diff
+30
@@ -0,0 +1,30 @@
|
||||
package brotli
|
||||
|
||||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Functions for clustering similar histograms together. */
|
||||
|
||||
type histogramPair struct {
|
||||
idx1 uint32
|
||||
idx2 uint32
|
||||
cost_combo float64
|
||||
cost_diff float64
|
||||
}
|
||||
|
||||
func histogramPairIsLess(p1 *histogramPair, p2 *histogramPair) bool {
|
||||
if p1.cost_diff != p2.cost_diff {
|
||||
return p1.cost_diff > p2.cost_diff
|
||||
}
|
||||
|
||||
return (p1.idx2 - p1.idx1) > (p2.idx2 - p2.idx1)
|
||||
}
|
||||
|
||||
/* Returns entropy reduction of the context map when we combine two clusters. */
|
||||
func clusterCostDiff(size_a uint, size_b uint) float64 {
|
||||
var size_c uint = size_a + size_b
|
||||
return float64(size_a)*fastLog2(size_a) + float64(size_b)*fastLog2(size_b) - float64(size_c)*fastLog2(size_c)
|
||||
}
|
||||
+164
@@ -0,0 +1,164 @@
|
||||
package brotli
|
||||
|
||||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Computes the bit cost reduction by combining out[idx1] and out[idx2] and if
|
||||
it is below a threshold, stores the pair (idx1, idx2) in the *pairs queue. */
|
||||
func compareAndPushToQueueCommand(out []histogramCommand, cluster_size []uint32, idx1 uint32, idx2 uint32, max_num_pairs uint, pairs []histogramPair, num_pairs *uint) {
|
||||
var is_good_pair bool = false
|
||||
var p histogramPair
|
||||
p.idx2 = 0
|
||||
p.idx1 = p.idx2
|
||||
p.cost_combo = 0
|
||||
p.cost_diff = p.cost_combo
|
||||
if idx1 == idx2 {
|
||||
return
|
||||
}
|
||||
|
||||
if idx2 < idx1 {
|
||||
var t uint32 = idx2
|
||||
idx2 = idx1
|
||||
idx1 = t
|
||||
}
|
||||
|
||||
p.idx1 = idx1
|
||||
p.idx2 = idx2
|
||||
p.cost_diff = 0.5 * clusterCostDiff(uint(cluster_size[idx1]), uint(cluster_size[idx2]))
|
||||
p.cost_diff -= out[idx1].bit_cost_
|
||||
p.cost_diff -= out[idx2].bit_cost_
|
||||
|
||||
if out[idx1].total_count_ == 0 {
|
||||
p.cost_combo = out[idx2].bit_cost_
|
||||
is_good_pair = true
|
||||
} else if out[idx2].total_count_ == 0 {
|
||||
p.cost_combo = out[idx1].bit_cost_
|
||||
is_good_pair = true
|
||||
} else {
|
||||
var threshold float64
|
||||
if *num_pairs == 0 {
|
||||
threshold = 1e99
|
||||
} else {
|
||||
threshold = brotli_max_double(0.0, pairs[0].cost_diff)
|
||||
}
|
||||
var combo histogramCommand = out[idx1]
|
||||
var cost_combo float64
|
||||
histogramAddHistogramCommand(&combo, &out[idx2])
|
||||
cost_combo = populationCostCommand(&combo)
|
||||
if cost_combo < threshold-p.cost_diff {
|
||||
p.cost_combo = cost_combo
|
||||
is_good_pair = true
|
||||
}
|
||||
}
|
||||
|
||||
if is_good_pair {
|
||||
p.cost_diff += p.cost_combo
|
||||
if *num_pairs > 0 && histogramPairIsLess(&pairs[0], &p) {
|
||||
/* Replace the top of the queue if needed. */
|
||||
if *num_pairs < max_num_pairs {
|
||||
pairs[*num_pairs] = pairs[0]
|
||||
(*num_pairs)++
|
||||
}
|
||||
|
||||
pairs[0] = p
|
||||
} else if *num_pairs < max_num_pairs {
|
||||
pairs[*num_pairs] = p
|
||||
(*num_pairs)++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func histogramCombineCommand(out []histogramCommand, cluster_size []uint32, symbols []uint32, clusters []uint32, pairs []histogramPair, num_clusters uint, symbols_size uint, max_clusters uint, max_num_pairs uint) uint {
|
||||
var cost_diff_threshold float64 = 0.0
|
||||
var min_cluster_size uint = 1
|
||||
var num_pairs uint = 0
|
||||
{
|
||||
/* We maintain a vector of histogram pairs, with the property that the pair
|
||||
with the maximum bit cost reduction is the first. */
|
||||
var idx1 uint
|
||||
for idx1 = 0; idx1 < num_clusters; idx1++ {
|
||||
var idx2 uint
|
||||
for idx2 = idx1 + 1; idx2 < num_clusters; idx2++ {
|
||||
compareAndPushToQueueCommand(out, cluster_size, clusters[idx1], clusters[idx2], max_num_pairs, pairs[0:], &num_pairs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for num_clusters > min_cluster_size {
|
||||
var best_idx1 uint32
|
||||
var best_idx2 uint32
|
||||
var i uint
|
||||
if pairs[0].cost_diff >= cost_diff_threshold {
|
||||
cost_diff_threshold = 1e99
|
||||
min_cluster_size = max_clusters
|
||||
continue
|
||||
}
|
||||
|
||||
/* Take the best pair from the top of heap. */
|
||||
best_idx1 = pairs[0].idx1
|
||||
|
||||
best_idx2 = pairs[0].idx2
|
||||
histogramAddHistogramCommand(&out[best_idx1], &out[best_idx2])
|
||||
out[best_idx1].bit_cost_ = pairs[0].cost_combo
|
||||
cluster_size[best_idx1] += cluster_size[best_idx2]
|
||||
for i = 0; i < symbols_size; i++ {
|
||||
if symbols[i] == best_idx2 {
|
||||
symbols[i] = best_idx1
|
||||
}
|
||||
}
|
||||
|
||||
for i = 0; i < num_clusters; i++ {
|
||||
if clusters[i] == best_idx2 {
|
||||
copy(clusters[i:], clusters[i+1:][:num_clusters-i-1])
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
num_clusters--
|
||||
{
|
||||
/* Remove pairs intersecting the just combined best pair. */
|
||||
var copy_to_idx uint = 0
|
||||
for i = 0; i < num_pairs; i++ {
|
||||
var p *histogramPair = &pairs[i]
|
||||
if p.idx1 == best_idx1 || p.idx2 == best_idx1 || p.idx1 == best_idx2 || p.idx2 == best_idx2 {
|
||||
/* Remove invalid pair from the queue. */
|
||||
continue
|
||||
}
|
||||
|
||||
if histogramPairIsLess(&pairs[0], p) {
|
||||
/* Replace the top of the queue if needed. */
|
||||
var front histogramPair = pairs[0]
|
||||
pairs[0] = *p
|
||||
pairs[copy_to_idx] = front
|
||||
} else {
|
||||
pairs[copy_to_idx] = *p
|
||||
}
|
||||
|
||||
copy_to_idx++
|
||||
}
|
||||
|
||||
num_pairs = copy_to_idx
|
||||
}
|
||||
|
||||
/* Push new pairs formed with the combined histogram to the heap. */
|
||||
for i = 0; i < num_clusters; i++ {
|
||||
compareAndPushToQueueCommand(out, cluster_size, best_idx1, clusters[i], max_num_pairs, pairs[0:], &num_pairs)
|
||||
}
|
||||
}
|
||||
|
||||
return num_clusters
|
||||
}
|
||||
|
||||
/* What is the bit cost of moving histogram from cur_symbol to candidate. */
|
||||
func histogramBitCostDistanceCommand(histogram *histogramCommand, candidate *histogramCommand) float64 {
|
||||
if histogram.total_count_ == 0 {
|
||||
return 0.0
|
||||
} else {
|
||||
var tmp histogramCommand = *histogram
|
||||
histogramAddHistogramCommand(&tmp, candidate)
|
||||
return populationCostCommand(&tmp) - candidate.bit_cost_
|
||||
}
|
||||
}
|
||||
+326
@@ -0,0 +1,326 @@
|
||||
package brotli
|
||||
|
||||
import "math"
|
||||
|
||||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Computes the bit cost reduction by combining out[idx1] and out[idx2] and if
|
||||
it is below a threshold, stores the pair (idx1, idx2) in the *pairs queue. */
|
||||
func compareAndPushToQueueDistance(out []histogramDistance, cluster_size []uint32, idx1 uint32, idx2 uint32, max_num_pairs uint, pairs []histogramPair, num_pairs *uint) {
|
||||
var is_good_pair bool = false
|
||||
var p histogramPair
|
||||
p.idx2 = 0
|
||||
p.idx1 = p.idx2
|
||||
p.cost_combo = 0
|
||||
p.cost_diff = p.cost_combo
|
||||
if idx1 == idx2 {
|
||||
return
|
||||
}
|
||||
|
||||
if idx2 < idx1 {
|
||||
var t uint32 = idx2
|
||||
idx2 = idx1
|
||||
idx1 = t
|
||||
}
|
||||
|
||||
p.idx1 = idx1
|
||||
p.idx2 = idx2
|
||||
p.cost_diff = 0.5 * clusterCostDiff(uint(cluster_size[idx1]), uint(cluster_size[idx2]))
|
||||
p.cost_diff -= out[idx1].bit_cost_
|
||||
p.cost_diff -= out[idx2].bit_cost_
|
||||
|
||||
if out[idx1].total_count_ == 0 {
|
||||
p.cost_combo = out[idx2].bit_cost_
|
||||
is_good_pair = true
|
||||
} else if out[idx2].total_count_ == 0 {
|
||||
p.cost_combo = out[idx1].bit_cost_
|
||||
is_good_pair = true
|
||||
} else {
|
||||
var threshold float64
|
||||
if *num_pairs == 0 {
|
||||
threshold = 1e99
|
||||
} else {
|
||||
threshold = brotli_max_double(0.0, pairs[0].cost_diff)
|
||||
}
|
||||
var combo histogramDistance = out[idx1]
|
||||
var cost_combo float64
|
||||
histogramAddHistogramDistance(&combo, &out[idx2])
|
||||
cost_combo = populationCostDistance(&combo)
|
||||
if cost_combo < threshold-p.cost_diff {
|
||||
p.cost_combo = cost_combo
|
||||
is_good_pair = true
|
||||
}
|
||||
}
|
||||
|
||||
if is_good_pair {
|
||||
p.cost_diff += p.cost_combo
|
||||
if *num_pairs > 0 && histogramPairIsLess(&pairs[0], &p) {
|
||||
/* Replace the top of the queue if needed. */
|
||||
if *num_pairs < max_num_pairs {
|
||||
pairs[*num_pairs] = pairs[0]
|
||||
(*num_pairs)++
|
||||
}
|
||||
|
||||
pairs[0] = p
|
||||
} else if *num_pairs < max_num_pairs {
|
||||
pairs[*num_pairs] = p
|
||||
(*num_pairs)++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func histogramCombineDistance(out []histogramDistance, cluster_size []uint32, symbols []uint32, clusters []uint32, pairs []histogramPair, num_clusters uint, symbols_size uint, max_clusters uint, max_num_pairs uint) uint {
|
||||
var cost_diff_threshold float64 = 0.0
|
||||
var min_cluster_size uint = 1
|
||||
var num_pairs uint = 0
|
||||
{
|
||||
/* We maintain a vector of histogram pairs, with the property that the pair
|
||||
with the maximum bit cost reduction is the first. */
|
||||
var idx1 uint
|
||||
for idx1 = 0; idx1 < num_clusters; idx1++ {
|
||||
var idx2 uint
|
||||
for idx2 = idx1 + 1; idx2 < num_clusters; idx2++ {
|
||||
compareAndPushToQueueDistance(out, cluster_size, clusters[idx1], clusters[idx2], max_num_pairs, pairs[0:], &num_pairs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for num_clusters > min_cluster_size {
|
||||
var best_idx1 uint32
|
||||
var best_idx2 uint32
|
||||
var i uint
|
||||
if pairs[0].cost_diff >= cost_diff_threshold {
|
||||
cost_diff_threshold = 1e99
|
||||
min_cluster_size = max_clusters
|
||||
continue
|
||||
}
|
||||
|
||||
/* Take the best pair from the top of heap. */
|
||||
best_idx1 = pairs[0].idx1
|
||||
|
||||
best_idx2 = pairs[0].idx2
|
||||
histogramAddHistogramDistance(&out[best_idx1], &out[best_idx2])
|
||||
out[best_idx1].bit_cost_ = pairs[0].cost_combo
|
||||
cluster_size[best_idx1] += cluster_size[best_idx2]
|
||||
for i = 0; i < symbols_size; i++ {
|
||||
if symbols[i] == best_idx2 {
|
||||
symbols[i] = best_idx1
|
||||
}
|
||||
}
|
||||
|
||||
for i = 0; i < num_clusters; i++ {
|
||||
if clusters[i] == best_idx2 {
|
||||
copy(clusters[i:], clusters[i+1:][:num_clusters-i-1])
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
num_clusters--
|
||||
{
|
||||
/* Remove pairs intersecting the just combined best pair. */
|
||||
var copy_to_idx uint = 0
|
||||
for i = 0; i < num_pairs; i++ {
|
||||
var p *histogramPair = &pairs[i]
|
||||
if p.idx1 == best_idx1 || p.idx2 == best_idx1 || p.idx1 == best_idx2 || p.idx2 == best_idx2 {
|
||||
/* Remove invalid pair from the queue. */
|
||||
continue
|
||||
}
|
||||
|
||||
if histogramPairIsLess(&pairs[0], p) {
|
||||
/* Replace the top of the queue if needed. */
|
||||
var front histogramPair = pairs[0]
|
||||
pairs[0] = *p
|
||||
pairs[copy_to_idx] = front
|
||||
} else {
|
||||
pairs[copy_to_idx] = *p
|
||||
}
|
||||
|
||||
copy_to_idx++
|
||||
}
|
||||
|
||||
num_pairs = copy_to_idx
|
||||
}
|
||||
|
||||
/* Push new pairs formed with the combined histogram to the heap. */
|
||||
for i = 0; i < num_clusters; i++ {
|
||||
compareAndPushToQueueDistance(out, cluster_size, best_idx1, clusters[i], max_num_pairs, pairs[0:], &num_pairs)
|
||||
}
|
||||
}
|
||||
|
||||
return num_clusters
|
||||
}
|
||||
|
||||
/* What is the bit cost of moving histogram from cur_symbol to candidate. */
|
||||
func histogramBitCostDistanceDistance(histogram *histogramDistance, candidate *histogramDistance) float64 {
|
||||
if histogram.total_count_ == 0 {
|
||||
return 0.0
|
||||
} else {
|
||||
var tmp histogramDistance = *histogram
|
||||
histogramAddHistogramDistance(&tmp, candidate)
|
||||
return populationCostDistance(&tmp) - candidate.bit_cost_
|
||||
}
|
||||
}
|
||||
|
||||
/* Find the best 'out' histogram for each of the 'in' histograms.
|
||||
When called, clusters[0..num_clusters) contains the unique values from
|
||||
symbols[0..in_size), but this property is not preserved in this function.
|
||||
Note: we assume that out[]->bit_cost_ is already up-to-date. */
|
||||
func histogramRemapDistance(in []histogramDistance, in_size uint, clusters []uint32, num_clusters uint, out []histogramDistance, symbols []uint32) {
|
||||
var i uint
|
||||
for i = 0; i < in_size; i++ {
|
||||
var best_out uint32
|
||||
if i == 0 {
|
||||
best_out = symbols[0]
|
||||
} else {
|
||||
best_out = symbols[i-1]
|
||||
}
|
||||
var best_bits float64 = histogramBitCostDistanceDistance(&in[i], &out[best_out])
|
||||
var j uint
|
||||
for j = 0; j < num_clusters; j++ {
|
||||
var cur_bits float64 = histogramBitCostDistanceDistance(&in[i], &out[clusters[j]])
|
||||
if cur_bits < best_bits {
|
||||
best_bits = cur_bits
|
||||
best_out = clusters[j]
|
||||
}
|
||||
}
|
||||
|
||||
symbols[i] = best_out
|
||||
}
|
||||
|
||||
/* Recompute each out based on raw and symbols. */
|
||||
for i = 0; i < num_clusters; i++ {
|
||||
histogramClearDistance(&out[clusters[i]])
|
||||
}
|
||||
|
||||
for i = 0; i < in_size; i++ {
|
||||
histogramAddHistogramDistance(&out[symbols[i]], &in[i])
|
||||
}
|
||||
}
|
||||
|
||||
/* Reorders elements of the out[0..length) array and changes values in
|
||||
symbols[0..length) array in the following way:
|
||||
* when called, symbols[] contains indexes into out[], and has N unique
|
||||
values (possibly N < length)
|
||||
* on return, symbols'[i] = f(symbols[i]) and
|
||||
out'[symbols'[i]] = out[symbols[i]], for each 0 <= i < length,
|
||||
where f is a bijection between the range of symbols[] and [0..N), and
|
||||
the first occurrences of values in symbols'[i] come in consecutive
|
||||
increasing order.
|
||||
Returns N, the number of unique values in symbols[]. */
|
||||
|
||||
var histogramReindexDistance_kInvalidIndex uint32 = math.MaxUint32
|
||||
|
||||
func histogramReindexDistance(out []histogramDistance, symbols []uint32, length uint) uint {
|
||||
var new_index []uint32 = make([]uint32, length)
|
||||
var next_index uint32
|
||||
var tmp []histogramDistance
|
||||
var i uint
|
||||
for i = 0; i < length; i++ {
|
||||
new_index[i] = histogramReindexDistance_kInvalidIndex
|
||||
}
|
||||
|
||||
next_index = 0
|
||||
for i = 0; i < length; i++ {
|
||||
if new_index[symbols[i]] == histogramReindexDistance_kInvalidIndex {
|
||||
new_index[symbols[i]] = next_index
|
||||
next_index++
|
||||
}
|
||||
}
|
||||
|
||||
/* TODO: by using idea of "cycle-sort" we can avoid allocation of
|
||||
tmp and reduce the number of copying by the factor of 2. */
|
||||
tmp = make([]histogramDistance, next_index)
|
||||
|
||||
next_index = 0
|
||||
for i = 0; i < length; i++ {
|
||||
if new_index[symbols[i]] == next_index {
|
||||
tmp[next_index] = out[symbols[i]]
|
||||
next_index++
|
||||
}
|
||||
|
||||
symbols[i] = new_index[symbols[i]]
|
||||
}
|
||||
|
||||
new_index = nil
|
||||
for i = 0; uint32(i) < next_index; i++ {
|
||||
out[i] = tmp[i]
|
||||
}
|
||||
|
||||
tmp = nil
|
||||
return uint(next_index)
|
||||
}
|
||||
|
||||
func clusterHistogramsDistance(in []histogramDistance, in_size uint, max_histograms uint, out []histogramDistance, out_size *uint, histogram_symbols []uint32) {
|
||||
var cluster_size []uint32 = make([]uint32, in_size)
|
||||
var clusters []uint32 = make([]uint32, in_size)
|
||||
var num_clusters uint = 0
|
||||
var max_input_histograms uint = 64
|
||||
var pairs_capacity uint = max_input_histograms * max_input_histograms / 2
|
||||
var pairs []histogramPair = make([]histogramPair, (pairs_capacity + 1))
|
||||
var i uint
|
||||
|
||||
/* For the first pass of clustering, we allow all pairs. */
|
||||
for i = 0; i < in_size; i++ {
|
||||
cluster_size[i] = 1
|
||||
}
|
||||
|
||||
for i = 0; i < in_size; i++ {
|
||||
out[i] = in[i]
|
||||
out[i].bit_cost_ = populationCostDistance(&in[i])
|
||||
histogram_symbols[i] = uint32(i)
|
||||
}
|
||||
|
||||
for i = 0; i < in_size; i += max_input_histograms {
|
||||
var num_to_combine uint = brotli_min_size_t(in_size-i, max_input_histograms)
|
||||
var num_new_clusters uint
|
||||
var j uint
|
||||
for j = 0; j < num_to_combine; j++ {
|
||||
clusters[num_clusters+j] = uint32(i + j)
|
||||
}
|
||||
|
||||
num_new_clusters = histogramCombineDistance(out, cluster_size, histogram_symbols[i:], clusters[num_clusters:], pairs, num_to_combine, num_to_combine, max_histograms, pairs_capacity)
|
||||
num_clusters += num_new_clusters
|
||||
}
|
||||
{
|
||||
/* For the second pass, we limit the total number of histogram pairs.
|
||||
After this limit is reached, we only keep searching for the best pair. */
|
||||
var max_num_pairs uint = brotli_min_size_t(64*num_clusters, (num_clusters/2)*num_clusters)
|
||||
if pairs_capacity < (max_num_pairs + 1) {
|
||||
var _new_size uint
|
||||
if pairs_capacity == 0 {
|
||||
_new_size = max_num_pairs + 1
|
||||
} else {
|
||||
_new_size = pairs_capacity
|
||||
}
|
||||
var new_array []histogramPair
|
||||
for _new_size < (max_num_pairs + 1) {
|
||||
_new_size *= 2
|
||||
}
|
||||
new_array = make([]histogramPair, _new_size)
|
||||
if pairs_capacity != 0 {
|
||||
copy(new_array, pairs[:pairs_capacity])
|
||||
}
|
||||
|
||||
pairs = new_array
|
||||
pairs_capacity = _new_size
|
||||
}
|
||||
|
||||
/* Collapse similar histograms. */
|
||||
num_clusters = histogramCombineDistance(out, cluster_size, histogram_symbols, clusters, pairs, num_clusters, in_size, max_histograms, max_num_pairs)
|
||||
}
|
||||
|
||||
pairs = nil
|
||||
cluster_size = nil
|
||||
|
||||
/* Find the optimal map from original histograms to the final ones. */
|
||||
histogramRemapDistance(in, in_size, clusters, num_clusters, out, histogram_symbols)
|
||||
|
||||
clusters = nil
|
||||
|
||||
/* Convert the context map to a canonical form. */
|
||||
*out_size = histogramReindexDistance(out, histogram_symbols, in_size)
|
||||
}
|
||||
+326
@@ -0,0 +1,326 @@
|
||||
package brotli
|
||||
|
||||
import "math"
|
||||
|
||||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Computes the bit cost reduction by combining out[idx1] and out[idx2] and if
|
||||
it is below a threshold, stores the pair (idx1, idx2) in the *pairs queue. */
|
||||
func compareAndPushToQueueLiteral(out []histogramLiteral, cluster_size []uint32, idx1 uint32, idx2 uint32, max_num_pairs uint, pairs []histogramPair, num_pairs *uint) {
|
||||
var is_good_pair bool = false
|
||||
var p histogramPair
|
||||
p.idx2 = 0
|
||||
p.idx1 = p.idx2
|
||||
p.cost_combo = 0
|
||||
p.cost_diff = p.cost_combo
|
||||
if idx1 == idx2 {
|
||||
return
|
||||
}
|
||||
|
||||
if idx2 < idx1 {
|
||||
var t uint32 = idx2
|
||||
idx2 = idx1
|
||||
idx1 = t
|
||||
}
|
||||
|
||||
p.idx1 = idx1
|
||||
p.idx2 = idx2
|
||||
p.cost_diff = 0.5 * clusterCostDiff(uint(cluster_size[idx1]), uint(cluster_size[idx2]))
|
||||
p.cost_diff -= out[idx1].bit_cost_
|
||||
p.cost_diff -= out[idx2].bit_cost_
|
||||
|
||||
if out[idx1].total_count_ == 0 {
|
||||
p.cost_combo = out[idx2].bit_cost_
|
||||
is_good_pair = true
|
||||
} else if out[idx2].total_count_ == 0 {
|
||||
p.cost_combo = out[idx1].bit_cost_
|
||||
is_good_pair = true
|
||||
} else {
|
||||
var threshold float64
|
||||
if *num_pairs == 0 {
|
||||
threshold = 1e99
|
||||
} else {
|
||||
threshold = brotli_max_double(0.0, pairs[0].cost_diff)
|
||||
}
|
||||
var combo histogramLiteral = out[idx1]
|
||||
var cost_combo float64
|
||||
histogramAddHistogramLiteral(&combo, &out[idx2])
|
||||
cost_combo = populationCostLiteral(&combo)
|
||||
if cost_combo < threshold-p.cost_diff {
|
||||
p.cost_combo = cost_combo
|
||||
is_good_pair = true
|
||||
}
|
||||
}
|
||||
|
||||
if is_good_pair {
|
||||
p.cost_diff += p.cost_combo
|
||||
if *num_pairs > 0 && histogramPairIsLess(&pairs[0], &p) {
|
||||
/* Replace the top of the queue if needed. */
|
||||
if *num_pairs < max_num_pairs {
|
||||
pairs[*num_pairs] = pairs[0]
|
||||
(*num_pairs)++
|
||||
}
|
||||
|
||||
pairs[0] = p
|
||||
} else if *num_pairs < max_num_pairs {
|
||||
pairs[*num_pairs] = p
|
||||
(*num_pairs)++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func histogramCombineLiteral(out []histogramLiteral, cluster_size []uint32, symbols []uint32, clusters []uint32, pairs []histogramPair, num_clusters uint, symbols_size uint, max_clusters uint, max_num_pairs uint) uint {
|
||||
var cost_diff_threshold float64 = 0.0
|
||||
var min_cluster_size uint = 1
|
||||
var num_pairs uint = 0
|
||||
{
|
||||
/* We maintain a vector of histogram pairs, with the property that the pair
|
||||
with the maximum bit cost reduction is the first. */
|
||||
var idx1 uint
|
||||
for idx1 = 0; idx1 < num_clusters; idx1++ {
|
||||
var idx2 uint
|
||||
for idx2 = idx1 + 1; idx2 < num_clusters; idx2++ {
|
||||
compareAndPushToQueueLiteral(out, cluster_size, clusters[idx1], clusters[idx2], max_num_pairs, pairs[0:], &num_pairs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for num_clusters > min_cluster_size {
|
||||
var best_idx1 uint32
|
||||
var best_idx2 uint32
|
||||
var i uint
|
||||
if pairs[0].cost_diff >= cost_diff_threshold {
|
||||
cost_diff_threshold = 1e99
|
||||
min_cluster_size = max_clusters
|
||||
continue
|
||||
}
|
||||
|
||||
/* Take the best pair from the top of heap. */
|
||||
best_idx1 = pairs[0].idx1
|
||||
|
||||
best_idx2 = pairs[0].idx2
|
||||
histogramAddHistogramLiteral(&out[best_idx1], &out[best_idx2])
|
||||
out[best_idx1].bit_cost_ = pairs[0].cost_combo
|
||||
cluster_size[best_idx1] += cluster_size[best_idx2]
|
||||
for i = 0; i < symbols_size; i++ {
|
||||
if symbols[i] == best_idx2 {
|
||||
symbols[i] = best_idx1
|
||||
}
|
||||
}
|
||||
|
||||
for i = 0; i < num_clusters; i++ {
|
||||
if clusters[i] == best_idx2 {
|
||||
copy(clusters[i:], clusters[i+1:][:num_clusters-i-1])
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
num_clusters--
|
||||
{
|
||||
/* Remove pairs intersecting the just combined best pair. */
|
||||
var copy_to_idx uint = 0
|
||||
for i = 0; i < num_pairs; i++ {
|
||||
var p *histogramPair = &pairs[i]
|
||||
if p.idx1 == best_idx1 || p.idx2 == best_idx1 || p.idx1 == best_idx2 || p.idx2 == best_idx2 {
|
||||
/* Remove invalid pair from the queue. */
|
||||
continue
|
||||
}
|
||||
|
||||
if histogramPairIsLess(&pairs[0], p) {
|
||||
/* Replace the top of the queue if needed. */
|
||||
var front histogramPair = pairs[0]
|
||||
pairs[0] = *p
|
||||
pairs[copy_to_idx] = front
|
||||
} else {
|
||||
pairs[copy_to_idx] = *p
|
||||
}
|
||||
|
||||
copy_to_idx++
|
||||
}
|
||||
|
||||
num_pairs = copy_to_idx
|
||||
}
|
||||
|
||||
/* Push new pairs formed with the combined histogram to the heap. */
|
||||
for i = 0; i < num_clusters; i++ {
|
||||
compareAndPushToQueueLiteral(out, cluster_size, best_idx1, clusters[i], max_num_pairs, pairs[0:], &num_pairs)
|
||||
}
|
||||
}
|
||||
|
||||
return num_clusters
|
||||
}
|
||||
|
||||
/* What is the bit cost of moving histogram from cur_symbol to candidate. */
|
||||
func histogramBitCostDistanceLiteral(histogram *histogramLiteral, candidate *histogramLiteral) float64 {
|
||||
if histogram.total_count_ == 0 {
|
||||
return 0.0
|
||||
} else {
|
||||
var tmp histogramLiteral = *histogram
|
||||
histogramAddHistogramLiteral(&tmp, candidate)
|
||||
return populationCostLiteral(&tmp) - candidate.bit_cost_
|
||||
}
|
||||
}
|
||||
|
||||
/* Find the best 'out' histogram for each of the 'in' histograms.
|
||||
When called, clusters[0..num_clusters) contains the unique values from
|
||||
symbols[0..in_size), but this property is not preserved in this function.
|
||||
Note: we assume that out[]->bit_cost_ is already up-to-date. */
|
||||
func histogramRemapLiteral(in []histogramLiteral, in_size uint, clusters []uint32, num_clusters uint, out []histogramLiteral, symbols []uint32) {
|
||||
var i uint
|
||||
for i = 0; i < in_size; i++ {
|
||||
var best_out uint32
|
||||
if i == 0 {
|
||||
best_out = symbols[0]
|
||||
} else {
|
||||
best_out = symbols[i-1]
|
||||
}
|
||||
var best_bits float64 = histogramBitCostDistanceLiteral(&in[i], &out[best_out])
|
||||
var j uint
|
||||
for j = 0; j < num_clusters; j++ {
|
||||
var cur_bits float64 = histogramBitCostDistanceLiteral(&in[i], &out[clusters[j]])
|
||||
if cur_bits < best_bits {
|
||||
best_bits = cur_bits
|
||||
best_out = clusters[j]
|
||||
}
|
||||
}
|
||||
|
||||
symbols[i] = best_out
|
||||
}
|
||||
|
||||
/* Recompute each out based on raw and symbols. */
|
||||
for i = 0; i < num_clusters; i++ {
|
||||
histogramClearLiteral(&out[clusters[i]])
|
||||
}
|
||||
|
||||
for i = 0; i < in_size; i++ {
|
||||
histogramAddHistogramLiteral(&out[symbols[i]], &in[i])
|
||||
}
|
||||
}
|
||||
|
||||
/* Reorders elements of the out[0..length) array and changes values in
|
||||
symbols[0..length) array in the following way:
|
||||
* when called, symbols[] contains indexes into out[], and has N unique
|
||||
values (possibly N < length)
|
||||
* on return, symbols'[i] = f(symbols[i]) and
|
||||
out'[symbols'[i]] = out[symbols[i]], for each 0 <= i < length,
|
||||
where f is a bijection between the range of symbols[] and [0..N), and
|
||||
the first occurrences of values in symbols'[i] come in consecutive
|
||||
increasing order.
|
||||
Returns N, the number of unique values in symbols[]. */
|
||||
|
||||
var histogramReindexLiteral_kInvalidIndex uint32 = math.MaxUint32
|
||||
|
||||
func histogramReindexLiteral(out []histogramLiteral, symbols []uint32, length uint) uint {
|
||||
var new_index []uint32 = make([]uint32, length)
|
||||
var next_index uint32
|
||||
var tmp []histogramLiteral
|
||||
var i uint
|
||||
for i = 0; i < length; i++ {
|
||||
new_index[i] = histogramReindexLiteral_kInvalidIndex
|
||||
}
|
||||
|
||||
next_index = 0
|
||||
for i = 0; i < length; i++ {
|
||||
if new_index[symbols[i]] == histogramReindexLiteral_kInvalidIndex {
|
||||
new_index[symbols[i]] = next_index
|
||||
next_index++
|
||||
}
|
||||
}
|
||||
|
||||
/* TODO: by using idea of "cycle-sort" we can avoid allocation of
|
||||
tmp and reduce the number of copying by the factor of 2. */
|
||||
tmp = make([]histogramLiteral, next_index)
|
||||
|
||||
next_index = 0
|
||||
for i = 0; i < length; i++ {
|
||||
if new_index[symbols[i]] == next_index {
|
||||
tmp[next_index] = out[symbols[i]]
|
||||
next_index++
|
||||
}
|
||||
|
||||
symbols[i] = new_index[symbols[i]]
|
||||
}
|
||||
|
||||
new_index = nil
|
||||
for i = 0; uint32(i) < next_index; i++ {
|
||||
out[i] = tmp[i]
|
||||
}
|
||||
|
||||
tmp = nil
|
||||
return uint(next_index)
|
||||
}
|
||||
|
||||
func clusterHistogramsLiteral(in []histogramLiteral, in_size uint, max_histograms uint, out []histogramLiteral, out_size *uint, histogram_symbols []uint32) {
|
||||
var cluster_size []uint32 = make([]uint32, in_size)
|
||||
var clusters []uint32 = make([]uint32, in_size)
|
||||
var num_clusters uint = 0
|
||||
var max_input_histograms uint = 64
|
||||
var pairs_capacity uint = max_input_histograms * max_input_histograms / 2
|
||||
var pairs []histogramPair = make([]histogramPair, (pairs_capacity + 1))
|
||||
var i uint
|
||||
|
||||
/* For the first pass of clustering, we allow all pairs. */
|
||||
for i = 0; i < in_size; i++ {
|
||||
cluster_size[i] = 1
|
||||
}
|
||||
|
||||
for i = 0; i < in_size; i++ {
|
||||
out[i] = in[i]
|
||||
out[i].bit_cost_ = populationCostLiteral(&in[i])
|
||||
histogram_symbols[i] = uint32(i)
|
||||
}
|
||||
|
||||
for i = 0; i < in_size; i += max_input_histograms {
|
||||
var num_to_combine uint = brotli_min_size_t(in_size-i, max_input_histograms)
|
||||
var num_new_clusters uint
|
||||
var j uint
|
||||
for j = 0; j < num_to_combine; j++ {
|
||||
clusters[num_clusters+j] = uint32(i + j)
|
||||
}
|
||||
|
||||
num_new_clusters = histogramCombineLiteral(out, cluster_size, histogram_symbols[i:], clusters[num_clusters:], pairs, num_to_combine, num_to_combine, max_histograms, pairs_capacity)
|
||||
num_clusters += num_new_clusters
|
||||
}
|
||||
{
|
||||
/* For the second pass, we limit the total number of histogram pairs.
|
||||
After this limit is reached, we only keep searching for the best pair. */
|
||||
var max_num_pairs uint = brotli_min_size_t(64*num_clusters, (num_clusters/2)*num_clusters)
|
||||
if pairs_capacity < (max_num_pairs + 1) {
|
||||
var _new_size uint
|
||||
if pairs_capacity == 0 {
|
||||
_new_size = max_num_pairs + 1
|
||||
} else {
|
||||
_new_size = pairs_capacity
|
||||
}
|
||||
var new_array []histogramPair
|
||||
for _new_size < (max_num_pairs + 1) {
|
||||
_new_size *= 2
|
||||
}
|
||||
new_array = make([]histogramPair, _new_size)
|
||||
if pairs_capacity != 0 {
|
||||
copy(new_array, pairs[:pairs_capacity])
|
||||
}
|
||||
|
||||
pairs = new_array
|
||||
pairs_capacity = _new_size
|
||||
}
|
||||
|
||||
/* Collapse similar histograms. */
|
||||
num_clusters = histogramCombineLiteral(out, cluster_size, histogram_symbols, clusters, pairs, num_clusters, in_size, max_histograms, max_num_pairs)
|
||||
}
|
||||
|
||||
pairs = nil
|
||||
cluster_size = nil
|
||||
|
||||
/* Find the optimal map from original histograms to the final ones. */
|
||||
histogramRemapLiteral(in, in_size, clusters, num_clusters, out, histogram_symbols)
|
||||
|
||||
clusters = nil
|
||||
|
||||
/* Convert the context map to a canonical form. */
|
||||
*out_size = histogramReindexLiteral(out, histogram_symbols, in_size)
|
||||
}
|
||||
+254
@@ -0,0 +1,254 @@
|
||||
package brotli
|
||||
|
||||
var kInsBase = []uint32{
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5,
|
||||
6,
|
||||
8,
|
||||
10,
|
||||
14,
|
||||
18,
|
||||
26,
|
||||
34,
|
||||
50,
|
||||
66,
|
||||
98,
|
||||
130,
|
||||
194,
|
||||
322,
|
||||
578,
|
||||
1090,
|
||||
2114,
|
||||
6210,
|
||||
22594,
|
||||
}
|
||||
|
||||
var kInsExtra = []uint32{
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
2,
|
||||
3,
|
||||
3,
|
||||
4,
|
||||
4,
|
||||
5,
|
||||
5,
|
||||
6,
|
||||
7,
|
||||
8,
|
||||
9,
|
||||
10,
|
||||
12,
|
||||
14,
|
||||
24,
|
||||
}
|
||||
|
||||
var kCopyBase = []uint32{
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5,
|
||||
6,
|
||||
7,
|
||||
8,
|
||||
9,
|
||||
10,
|
||||
12,
|
||||
14,
|
||||
18,
|
||||
22,
|
||||
30,
|
||||
38,
|
||||
54,
|
||||
70,
|
||||
102,
|
||||
134,
|
||||
198,
|
||||
326,
|
||||
582,
|
||||
1094,
|
||||
2118,
|
||||
}
|
||||
|
||||
var kCopyExtra = []uint32{
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
2,
|
||||
3,
|
||||
3,
|
||||
4,
|
||||
4,
|
||||
5,
|
||||
5,
|
||||
6,
|
||||
7,
|
||||
8,
|
||||
9,
|
||||
10,
|
||||
24,
|
||||
}
|
||||
|
||||
func getInsertLengthCode(insertlen uint) uint16 {
|
||||
if insertlen < 6 {
|
||||
return uint16(insertlen)
|
||||
} else if insertlen < 130 {
|
||||
var nbits uint32 = log2FloorNonZero(insertlen-2) - 1
|
||||
return uint16((nbits << 1) + uint32((insertlen-2)>>nbits) + 2)
|
||||
} else if insertlen < 2114 {
|
||||
return uint16(log2FloorNonZero(insertlen-66) + 10)
|
||||
} else if insertlen < 6210 {
|
||||
return 21
|
||||
} else if insertlen < 22594 {
|
||||
return 22
|
||||
} else {
|
||||
return 23
|
||||
}
|
||||
}
|
||||
|
||||
func getCopyLengthCode(copylen uint) uint16 {
|
||||
if copylen < 10 {
|
||||
return uint16(copylen - 2)
|
||||
} else if copylen < 134 {
|
||||
var nbits uint32 = log2FloorNonZero(copylen-6) - 1
|
||||
return uint16((nbits << 1) + uint32((copylen-6)>>nbits) + 4)
|
||||
} else if copylen < 2118 {
|
||||
return uint16(log2FloorNonZero(copylen-70) + 12)
|
||||
} else {
|
||||
return 23
|
||||
}
|
||||
}
|
||||
|
||||
func combineLengthCodes(inscode uint16, copycode uint16, use_last_distance bool) uint16 {
|
||||
var bits64 uint16 = uint16(copycode&0x7 | (inscode&0x7)<<3)
|
||||
if use_last_distance && inscode < 8 && copycode < 16 {
|
||||
if copycode < 8 {
|
||||
return bits64
|
||||
} else {
|
||||
return bits64 | 64
|
||||
}
|
||||
} else {
|
||||
/* Specification: 5 Encoding of ... (last table) */
|
||||
/* offset = 2 * index, where index is in range [0..8] */
|
||||
var offset uint32 = 2 * ((uint32(copycode) >> 3) + 3*(uint32(inscode)>>3))
|
||||
|
||||
/* All values in specification are K * 64,
|
||||
where K = [2, 3, 6, 4, 5, 8, 7, 9, 10],
|
||||
i + 1 = [1, 2, 3, 4, 5, 6, 7, 8, 9],
|
||||
K - i - 1 = [1, 1, 3, 0, 0, 2, 0, 1, 2] = D.
|
||||
All values in D require only 2 bits to encode.
|
||||
Magic constant is shifted 6 bits left, to avoid final multiplication. */
|
||||
offset = (offset << 5) + 0x40 + ((0x520D40 >> offset) & 0xC0)
|
||||
|
||||
return uint16(offset | uint32(bits64))
|
||||
}
|
||||
}
|
||||
|
||||
func getLengthCode(insertlen uint, copylen uint, use_last_distance bool, code *uint16) {
|
||||
var inscode uint16 = getInsertLengthCode(insertlen)
|
||||
var copycode uint16 = getCopyLengthCode(copylen)
|
||||
*code = combineLengthCodes(inscode, copycode, use_last_distance)
|
||||
}
|
||||
|
||||
func getInsertBase(inscode uint16) uint32 {
|
||||
return kInsBase[inscode]
|
||||
}
|
||||
|
||||
func getInsertExtra(inscode uint16) uint32 {
|
||||
return kInsExtra[inscode]
|
||||
}
|
||||
|
||||
func getCopyBase(copycode uint16) uint32 {
|
||||
return kCopyBase[copycode]
|
||||
}
|
||||
|
||||
func getCopyExtra(copycode uint16) uint32 {
|
||||
return kCopyExtra[copycode]
|
||||
}
|
||||
|
||||
type command struct {
|
||||
insert_len_ uint32
|
||||
copy_len_ uint32
|
||||
dist_extra_ uint32
|
||||
cmd_prefix_ uint16
|
||||
dist_prefix_ uint16
|
||||
}
|
||||
|
||||
/* distance_code is e.g. 0 for same-as-last short code, or 16 for offset 1. */
|
||||
func makeCommand(dist *distanceParams, insertlen uint, copylen uint, copylen_code_delta int, distance_code uint) (cmd command) {
|
||||
/* Don't rely on signed int representation, use honest casts. */
|
||||
var delta uint32 = uint32(byte(int8(copylen_code_delta)))
|
||||
cmd.insert_len_ = uint32(insertlen)
|
||||
cmd.copy_len_ = uint32(uint32(copylen) | delta<<25)
|
||||
|
||||
/* The distance prefix and extra bits are stored in this Command as if
|
||||
npostfix and ndirect were 0, they are only recomputed later after the
|
||||
clustering if needed. */
|
||||
prefixEncodeCopyDistance(distance_code, uint(dist.num_direct_distance_codes), uint(dist.distance_postfix_bits), &cmd.dist_prefix_, &cmd.dist_extra_)
|
||||
getLengthCode(insertlen, uint(int(copylen)+copylen_code_delta), (cmd.dist_prefix_&0x3FF == 0), &cmd.cmd_prefix_)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func makeInsertCommand(insertlen uint) (cmd command) {
|
||||
cmd.insert_len_ = uint32(insertlen)
|
||||
cmd.copy_len_ = 4 << 25
|
||||
cmd.dist_extra_ = 0
|
||||
cmd.dist_prefix_ = numDistanceShortCodes
|
||||
getLengthCode(insertlen, 4, false, &cmd.cmd_prefix_)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func commandRestoreDistanceCode(self *command, dist *distanceParams) uint32 {
|
||||
if uint32(self.dist_prefix_&0x3FF) < numDistanceShortCodes+dist.num_direct_distance_codes {
|
||||
return uint32(self.dist_prefix_) & 0x3FF
|
||||
} else {
|
||||
var dcode uint32 = uint32(self.dist_prefix_) & 0x3FF
|
||||
var nbits uint32 = uint32(self.dist_prefix_) >> 10
|
||||
var extra uint32 = self.dist_extra_
|
||||
var postfix_mask uint32 = (1 << dist.distance_postfix_bits) - 1
|
||||
var hcode uint32 = (dcode - dist.num_direct_distance_codes - numDistanceShortCodes) >> dist.distance_postfix_bits
|
||||
var lcode uint32 = (dcode - dist.num_direct_distance_codes - numDistanceShortCodes) & postfix_mask
|
||||
var offset uint32 = ((2 + (hcode & 1)) << nbits) - 4
|
||||
return ((offset + extra) << dist.distance_postfix_bits) + lcode + dist.num_direct_distance_codes + numDistanceShortCodes
|
||||
}
|
||||
}
|
||||
|
||||
func commandDistanceContext(self *command) uint32 {
|
||||
var r uint32 = uint32(self.cmd_prefix_) >> 6
|
||||
var c uint32 = uint32(self.cmd_prefix_) & 7
|
||||
if (r == 0 || r == 2 || r == 4 || r == 7) && (c <= 2) {
|
||||
return c
|
||||
}
|
||||
|
||||
return 3
|
||||
}
|
||||
|
||||
func commandCopyLen(self *command) uint32 {
|
||||
return self.copy_len_ & 0x1FFFFFF
|
||||
}
|
||||
|
||||
func commandCopyLenCode(self *command) uint32 {
|
||||
var modifier uint32 = self.copy_len_ >> 25
|
||||
var delta int32 = int32(int8(byte(modifier | (modifier&0x40)<<1)))
|
||||
return uint32(int32(self.copy_len_&0x1FFFFFF) + delta)
|
||||
}
|
||||
+834
@@ -0,0 +1,834 @@
|
||||
package brotli
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Function for fast encoding of an input fragment, independently from the input
|
||||
history. This function uses one-pass processing: when we find a backward
|
||||
match, we immediately emit the corresponding command and literal codes to
|
||||
the bit stream.
|
||||
|
||||
Adapted from the CompressFragment() function in
|
||||
https://github.com/google/snappy/blob/master/snappy.cc */
|
||||
|
||||
const maxDistance_compress_fragment = 262128
|
||||
|
||||
func hash5(p []byte, shift uint) uint32 {
|
||||
var h uint64 = (binary.LittleEndian.Uint64(p) << 24) * uint64(kHashMul32)
|
||||
return uint32(h >> shift)
|
||||
}
|
||||
|
||||
func hashBytesAtOffset5(v uint64, offset int, shift uint) uint32 {
|
||||
assert(offset >= 0)
|
||||
assert(offset <= 3)
|
||||
{
|
||||
var h uint64 = ((v >> uint(8*offset)) << 24) * uint64(kHashMul32)
|
||||
return uint32(h >> shift)
|
||||
}
|
||||
}
|
||||
|
||||
func isMatch5(p1 []byte, p2 []byte) bool {
|
||||
return binary.LittleEndian.Uint32(p1) == binary.LittleEndian.Uint32(p2) &&
|
||||
p1[4] == p2[4]
|
||||
}
|
||||
|
||||
/* Builds a literal prefix code into "depths" and "bits" based on the statistics
|
||||
of the "input" string and stores it into the bit stream.
|
||||
Note that the prefix code here is built from the pre-LZ77 input, therefore
|
||||
we can only approximate the statistics of the actual literal stream.
|
||||
Moreover, for long inputs we build a histogram from a sample of the input
|
||||
and thus have to assign a non-zero depth for each literal.
|
||||
Returns estimated compression ratio millibytes/char for encoding given input
|
||||
with generated code. */
|
||||
func buildAndStoreLiteralPrefixCode(input []byte, input_size uint, depths []byte, bits []uint16, storage_ix *uint, storage []byte) uint {
|
||||
var histogram = [256]uint32{0}
|
||||
var histogram_total uint
|
||||
var i uint
|
||||
if input_size < 1<<15 {
|
||||
for i = 0; i < input_size; i++ {
|
||||
histogram[input[i]]++
|
||||
}
|
||||
|
||||
histogram_total = input_size
|
||||
for i = 0; i < 256; i++ {
|
||||
/* We weigh the first 11 samples with weight 3 to account for the
|
||||
balancing effect of the LZ77 phase on the histogram. */
|
||||
var adjust uint32 = 2 * brotli_min_uint32_t(histogram[i], 11)
|
||||
histogram[i] += adjust
|
||||
histogram_total += uint(adjust)
|
||||
}
|
||||
} else {
|
||||
const kSampleRate uint = 29
|
||||
for i = 0; i < input_size; i += kSampleRate {
|
||||
histogram[input[i]]++
|
||||
}
|
||||
|
||||
histogram_total = (input_size + kSampleRate - 1) / kSampleRate
|
||||
for i = 0; i < 256; i++ {
|
||||
/* We add 1 to each population count to avoid 0 bit depths (since this is
|
||||
only a sample and we don't know if the symbol appears or not), and we
|
||||
weigh the first 11 samples with weight 3 to account for the balancing
|
||||
effect of the LZ77 phase on the histogram (more frequent symbols are
|
||||
more likely to be in backward references instead as literals). */
|
||||
var adjust uint32 = 1 + 2*brotli_min_uint32_t(histogram[i], 11)
|
||||
histogram[i] += adjust
|
||||
histogram_total += uint(adjust)
|
||||
}
|
||||
}
|
||||
|
||||
buildAndStoreHuffmanTreeFast(histogram[:], histogram_total, /* max_bits = */
|
||||
8, depths, bits, storage_ix, storage)
|
||||
{
|
||||
var literal_ratio uint = 0
|
||||
for i = 0; i < 256; i++ {
|
||||
if histogram[i] != 0 {
|
||||
literal_ratio += uint(histogram[i] * uint32(depths[i]))
|
||||
}
|
||||
}
|
||||
|
||||
/* Estimated encoding ratio, millibytes per symbol. */
|
||||
return (literal_ratio * 125) / histogram_total
|
||||
}
|
||||
}
|
||||
|
||||
/* Builds a command and distance prefix code (each 64 symbols) into "depth" and
|
||||
"bits" based on "histogram" and stores it into the bit stream. */
|
||||
func buildAndStoreCommandPrefixCode1(histogram []uint32, depth []byte, bits []uint16, storage_ix *uint, storage []byte) {
|
||||
var tree [129]huffmanTree
|
||||
var cmd_depth = [numCommandSymbols]byte{0}
|
||||
/* Tree size for building a tree over 64 symbols is 2 * 64 + 1. */
|
||||
|
||||
var cmd_bits [64]uint16
|
||||
|
||||
createHuffmanTree(histogram, 64, 15, tree[:], depth)
|
||||
createHuffmanTree(histogram[64:], 64, 14, tree[:], depth[64:])
|
||||
|
||||
/* We have to jump through a few hoops here in order to compute
|
||||
the command bits because the symbols are in a different order than in
|
||||
the full alphabet. This looks complicated, but having the symbols
|
||||
in this order in the command bits saves a few branches in the Emit*
|
||||
functions. */
|
||||
copy(cmd_depth[:], depth[:24])
|
||||
|
||||
copy(cmd_depth[24:][:], depth[40:][:8])
|
||||
copy(cmd_depth[32:][:], depth[24:][:8])
|
||||
copy(cmd_depth[40:][:], depth[48:][:8])
|
||||
copy(cmd_depth[48:][:], depth[32:][:8])
|
||||
copy(cmd_depth[56:][:], depth[56:][:8])
|
||||
convertBitDepthsToSymbols(cmd_depth[:], 64, cmd_bits[:])
|
||||
copy(bits, cmd_bits[:24])
|
||||
copy(bits[24:], cmd_bits[32:][:8])
|
||||
copy(bits[32:], cmd_bits[48:][:8])
|
||||
copy(bits[40:], cmd_bits[24:][:8])
|
||||
copy(bits[48:], cmd_bits[40:][:8])
|
||||
copy(bits[56:], cmd_bits[56:][:8])
|
||||
convertBitDepthsToSymbols(depth[64:], 64, bits[64:])
|
||||
{
|
||||
/* Create the bit length array for the full command alphabet. */
|
||||
var i uint
|
||||
for i := 0; i < int(64); i++ {
|
||||
cmd_depth[i] = 0
|
||||
} /* only 64 first values were used */
|
||||
copy(cmd_depth[:], depth[:8])
|
||||
copy(cmd_depth[64:][:], depth[8:][:8])
|
||||
copy(cmd_depth[128:][:], depth[16:][:8])
|
||||
copy(cmd_depth[192:][:], depth[24:][:8])
|
||||
copy(cmd_depth[384:][:], depth[32:][:8])
|
||||
for i = 0; i < 8; i++ {
|
||||
cmd_depth[128+8*i] = depth[40+i]
|
||||
cmd_depth[256+8*i] = depth[48+i]
|
||||
cmd_depth[448+8*i] = depth[56+i]
|
||||
}
|
||||
|
||||
storeHuffmanTree(cmd_depth[:], numCommandSymbols, tree[:], storage_ix, storage)
|
||||
}
|
||||
|
||||
storeHuffmanTree(depth[64:], 64, tree[:], storage_ix, storage)
|
||||
}
|
||||
|
||||
/* REQUIRES: insertlen < 6210 */
|
||||
func emitInsertLen1(insertlen uint, depth []byte, bits []uint16, histo []uint32, storage_ix *uint, storage []byte) {
|
||||
if insertlen < 6 {
|
||||
var code uint = insertlen + 40
|
||||
writeBits(uint(depth[code]), uint64(bits[code]), storage_ix, storage)
|
||||
histo[code]++
|
||||
} else if insertlen < 130 {
|
||||
var tail uint = insertlen - 2
|
||||
var nbits uint32 = log2FloorNonZero(tail) - 1
|
||||
var prefix uint = tail >> nbits
|
||||
var inscode uint = uint((nbits << 1) + uint32(prefix) + 42)
|
||||
writeBits(uint(depth[inscode]), uint64(bits[inscode]), storage_ix, storage)
|
||||
writeBits(uint(nbits), uint64(tail)-(uint64(prefix)<<nbits), storage_ix, storage)
|
||||
histo[inscode]++
|
||||
} else if insertlen < 2114 {
|
||||
var tail uint = insertlen - 66
|
||||
var nbits uint32 = log2FloorNonZero(tail)
|
||||
var code uint = uint(nbits + 50)
|
||||
writeBits(uint(depth[code]), uint64(bits[code]), storage_ix, storage)
|
||||
writeBits(uint(nbits), uint64(tail)-(uint64(uint(1))<<nbits), storage_ix, storage)
|
||||
histo[code]++
|
||||
} else {
|
||||
writeBits(uint(depth[61]), uint64(bits[61]), storage_ix, storage)
|
||||
writeBits(12, uint64(insertlen)-2114, storage_ix, storage)
|
||||
histo[61]++
|
||||
}
|
||||
}
|
||||
|
||||
func emitLongInsertLen(insertlen uint, depth []byte, bits []uint16, histo []uint32, storage_ix *uint, storage []byte) {
|
||||
if insertlen < 22594 {
|
||||
writeBits(uint(depth[62]), uint64(bits[62]), storage_ix, storage)
|
||||
writeBits(14, uint64(insertlen)-6210, storage_ix, storage)
|
||||
histo[62]++
|
||||
} else {
|
||||
writeBits(uint(depth[63]), uint64(bits[63]), storage_ix, storage)
|
||||
writeBits(24, uint64(insertlen)-22594, storage_ix, storage)
|
||||
histo[63]++
|
||||
}
|
||||
}
|
||||
|
||||
func emitCopyLen1(copylen uint, depth []byte, bits []uint16, histo []uint32, storage_ix *uint, storage []byte) {
|
||||
if copylen < 10 {
|
||||
writeBits(uint(depth[copylen+14]), uint64(bits[copylen+14]), storage_ix, storage)
|
||||
histo[copylen+14]++
|
||||
} else if copylen < 134 {
|
||||
var tail uint = copylen - 6
|
||||
var nbits uint32 = log2FloorNonZero(tail) - 1
|
||||
var prefix uint = tail >> nbits
|
||||
var code uint = uint((nbits << 1) + uint32(prefix) + 20)
|
||||
writeBits(uint(depth[code]), uint64(bits[code]), storage_ix, storage)
|
||||
writeBits(uint(nbits), uint64(tail)-(uint64(prefix)<<nbits), storage_ix, storage)
|
||||
histo[code]++
|
||||
} else if copylen < 2118 {
|
||||
var tail uint = copylen - 70
|
||||
var nbits uint32 = log2FloorNonZero(tail)
|
||||
var code uint = uint(nbits + 28)
|
||||
writeBits(uint(depth[code]), uint64(bits[code]), storage_ix, storage)
|
||||
writeBits(uint(nbits), uint64(tail)-(uint64(uint(1))<<nbits), storage_ix, storage)
|
||||
histo[code]++
|
||||
} else {
|
||||
writeBits(uint(depth[39]), uint64(bits[39]), storage_ix, storage)
|
||||
writeBits(24, uint64(copylen)-2118, storage_ix, storage)
|
||||
histo[39]++
|
||||
}
|
||||
}
|
||||
|
||||
func emitCopyLenLastDistance1(copylen uint, depth []byte, bits []uint16, histo []uint32, storage_ix *uint, storage []byte) {
|
||||
if copylen < 12 {
|
||||
writeBits(uint(depth[copylen-4]), uint64(bits[copylen-4]), storage_ix, storage)
|
||||
histo[copylen-4]++
|
||||
} else if copylen < 72 {
|
||||
var tail uint = copylen - 8
|
||||
var nbits uint32 = log2FloorNonZero(tail) - 1
|
||||
var prefix uint = tail >> nbits
|
||||
var code uint = uint((nbits << 1) + uint32(prefix) + 4)
|
||||
writeBits(uint(depth[code]), uint64(bits[code]), storage_ix, storage)
|
||||
writeBits(uint(nbits), uint64(tail)-(uint64(prefix)<<nbits), storage_ix, storage)
|
||||
histo[code]++
|
||||
} else if copylen < 136 {
|
||||
var tail uint = copylen - 8
|
||||
var code uint = (tail >> 5) + 30
|
||||
writeBits(uint(depth[code]), uint64(bits[code]), storage_ix, storage)
|
||||
writeBits(5, uint64(tail)&31, storage_ix, storage)
|
||||
writeBits(uint(depth[64]), uint64(bits[64]), storage_ix, storage)
|
||||
histo[code]++
|
||||
histo[64]++
|
||||
} else if copylen < 2120 {
|
||||
var tail uint = copylen - 72
|
||||
var nbits uint32 = log2FloorNonZero(tail)
|
||||
var code uint = uint(nbits + 28)
|
||||
writeBits(uint(depth[code]), uint64(bits[code]), storage_ix, storage)
|
||||
writeBits(uint(nbits), uint64(tail)-(uint64(uint(1))<<nbits), storage_ix, storage)
|
||||
writeBits(uint(depth[64]), uint64(bits[64]), storage_ix, storage)
|
||||
histo[code]++
|
||||
histo[64]++
|
||||
} else {
|
||||
writeBits(uint(depth[39]), uint64(bits[39]), storage_ix, storage)
|
||||
writeBits(24, uint64(copylen)-2120, storage_ix, storage)
|
||||
writeBits(uint(depth[64]), uint64(bits[64]), storage_ix, storage)
|
||||
histo[39]++
|
||||
histo[64]++
|
||||
}
|
||||
}
|
||||
|
||||
func emitDistance1(distance uint, depth []byte, bits []uint16, histo []uint32, storage_ix *uint, storage []byte) {
|
||||
var d uint = distance + 3
|
||||
var nbits uint32 = log2FloorNonZero(d) - 1
|
||||
var prefix uint = (d >> nbits) & 1
|
||||
var offset uint = (2 + prefix) << nbits
|
||||
var distcode uint = uint(2*(nbits-1) + uint32(prefix) + 80)
|
||||
writeBits(uint(depth[distcode]), uint64(bits[distcode]), storage_ix, storage)
|
||||
writeBits(uint(nbits), uint64(d)-uint64(offset), storage_ix, storage)
|
||||
histo[distcode]++
|
||||
}
|
||||
|
||||
func emitLiterals(input []byte, len uint, depth []byte, bits []uint16, storage_ix *uint, storage []byte) {
|
||||
var j uint
|
||||
for j = 0; j < len; j++ {
|
||||
var lit byte = input[j]
|
||||
writeBits(uint(depth[lit]), uint64(bits[lit]), storage_ix, storage)
|
||||
}
|
||||
}
|
||||
|
||||
/* REQUIRES: len <= 1 << 24. */
|
||||
func storeMetaBlockHeader1(len uint, is_uncompressed bool, storage_ix *uint, storage []byte) {
|
||||
var nibbles uint = 6
|
||||
|
||||
/* ISLAST */
|
||||
writeBits(1, 0, storage_ix, storage)
|
||||
|
||||
if len <= 1<<16 {
|
||||
nibbles = 4
|
||||
} else if len <= 1<<20 {
|
||||
nibbles = 5
|
||||
}
|
||||
|
||||
writeBits(2, uint64(nibbles)-4, storage_ix, storage)
|
||||
writeBits(nibbles*4, uint64(len)-1, storage_ix, storage)
|
||||
|
||||
/* ISUNCOMPRESSED */
|
||||
writeSingleBit(is_uncompressed, storage_ix, storage)
|
||||
}
|
||||
|
||||
func updateBits(n_bits uint, bits uint32, pos uint, array []byte) {
|
||||
for n_bits > 0 {
|
||||
var byte_pos uint = pos >> 3
|
||||
var n_unchanged_bits uint = pos & 7
|
||||
var n_changed_bits uint = brotli_min_size_t(n_bits, 8-n_unchanged_bits)
|
||||
var total_bits uint = n_unchanged_bits + n_changed_bits
|
||||
var mask uint32 = (^((1 << total_bits) - 1)) | ((1 << n_unchanged_bits) - 1)
|
||||
var unchanged_bits uint32 = uint32(array[byte_pos]) & mask
|
||||
var changed_bits uint32 = bits & ((1 << n_changed_bits) - 1)
|
||||
array[byte_pos] = byte(changed_bits<<n_unchanged_bits | unchanged_bits)
|
||||
n_bits -= n_changed_bits
|
||||
bits >>= n_changed_bits
|
||||
pos += n_changed_bits
|
||||
}
|
||||
}
|
||||
|
||||
func rewindBitPosition1(new_storage_ix uint, storage_ix *uint, storage []byte) {
|
||||
var bitpos uint = new_storage_ix & 7
|
||||
var mask uint = (1 << bitpos) - 1
|
||||
storage[new_storage_ix>>3] &= byte(mask)
|
||||
*storage_ix = new_storage_ix
|
||||
}
|
||||
|
||||
var shouldMergeBlock_kSampleRate uint = 43
|
||||
|
||||
func shouldMergeBlock(data []byte, len uint, depths []byte) bool {
|
||||
var histo = [256]uint{0}
|
||||
var i uint
|
||||
for i = 0; i < len; i += shouldMergeBlock_kSampleRate {
|
||||
histo[data[i]]++
|
||||
}
|
||||
{
|
||||
var total uint = (len + shouldMergeBlock_kSampleRate - 1) / shouldMergeBlock_kSampleRate
|
||||
var r float64 = (fastLog2(total)+0.5)*float64(total) + 200
|
||||
for i = 0; i < 256; i++ {
|
||||
r -= float64(histo[i]) * (float64(depths[i]) + fastLog2(histo[i]))
|
||||
}
|
||||
|
||||
return r >= 0.0
|
||||
}
|
||||
}
|
||||
|
||||
func shouldUseUncompressedMode(metablock_start []byte, next_emit []byte, insertlen uint, literal_ratio uint) bool {
|
||||
var compressed uint = uint(-cap(next_emit) + cap(metablock_start))
|
||||
if compressed*50 > insertlen {
|
||||
return false
|
||||
} else {
|
||||
return literal_ratio > 980
|
||||
}
|
||||
}
|
||||
|
||||
func emitUncompressedMetaBlock1(begin []byte, end []byte, storage_ix_start uint, storage_ix *uint, storage []byte) {
|
||||
var len uint = uint(-cap(end) + cap(begin))
|
||||
rewindBitPosition1(storage_ix_start, storage_ix, storage)
|
||||
storeMetaBlockHeader1(uint(len), true, storage_ix, storage)
|
||||
*storage_ix = (*storage_ix + 7) &^ 7
|
||||
copy(storage[*storage_ix>>3:], begin[:len])
|
||||
*storage_ix += uint(len << 3)
|
||||
storage[*storage_ix>>3] = 0
|
||||
}
|
||||
|
||||
var kCmdHistoSeed = [128]uint32{
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
}
|
||||
|
||||
var compressFragmentFastImpl_kFirstBlockSize uint = 3 << 15
|
||||
var compressFragmentFastImpl_kMergeBlockSize uint = 1 << 16
|
||||
|
||||
func compressFragmentFastImpl(in []byte, input_size uint, is_last bool, table []int, table_bits uint, cmd_depth []byte, cmd_bits []uint16, cmd_code_numbits *uint, cmd_code []byte, storage_ix *uint, storage []byte) {
|
||||
var cmd_histo [128]uint32
|
||||
var ip_end int
|
||||
var next_emit int = 0
|
||||
var base_ip int = 0
|
||||
var input int = 0
|
||||
const kInputMarginBytes uint = windowGap
|
||||
const kMinMatchLen uint = 5
|
||||
var metablock_start int = input
|
||||
var block_size uint = brotli_min_size_t(input_size, compressFragmentFastImpl_kFirstBlockSize)
|
||||
var total_block_size uint = block_size
|
||||
var mlen_storage_ix uint = *storage_ix + 3
|
||||
var lit_depth [256]byte
|
||||
var lit_bits [256]uint16
|
||||
var literal_ratio uint
|
||||
var ip int
|
||||
var last_distance int
|
||||
var shift uint = 64 - table_bits
|
||||
|
||||
/* "next_emit" is a pointer to the first byte that is not covered by a
|
||||
previous copy. Bytes between "next_emit" and the start of the next copy or
|
||||
the end of the input will be emitted as literal bytes. */
|
||||
|
||||
/* Save the start of the first block for position and distance computations.
|
||||
*/
|
||||
|
||||
/* Save the bit position of the MLEN field of the meta-block header, so that
|
||||
we can update it later if we decide to extend this meta-block. */
|
||||
storeMetaBlockHeader1(block_size, false, storage_ix, storage)
|
||||
|
||||
/* No block splits, no contexts. */
|
||||
writeBits(13, 0, storage_ix, storage)
|
||||
|
||||
literal_ratio = buildAndStoreLiteralPrefixCode(in[input:], block_size, lit_depth[:], lit_bits[:], storage_ix, storage)
|
||||
{
|
||||
/* Store the pre-compressed command and distance prefix codes. */
|
||||
var i uint
|
||||
for i = 0; i+7 < *cmd_code_numbits; i += 8 {
|
||||
writeBits(8, uint64(cmd_code[i>>3]), storage_ix, storage)
|
||||
}
|
||||
}
|
||||
|
||||
writeBits(*cmd_code_numbits&7, uint64(cmd_code[*cmd_code_numbits>>3]), storage_ix, storage)
|
||||
|
||||
/* Initialize the command and distance histograms. We will gather
|
||||
statistics of command and distance codes during the processing
|
||||
of this block and use it to update the command and distance
|
||||
prefix codes for the next block. */
|
||||
emit_commands:
|
||||
copy(cmd_histo[:], kCmdHistoSeed[:])
|
||||
|
||||
/* "ip" is the input pointer. */
|
||||
ip = input
|
||||
|
||||
last_distance = -1
|
||||
ip_end = int(uint(input) + block_size)
|
||||
|
||||
if block_size >= kInputMarginBytes {
|
||||
var len_limit uint = brotli_min_size_t(block_size-kMinMatchLen, input_size-kInputMarginBytes)
|
||||
var ip_limit int = int(uint(input) + len_limit)
|
||||
/* For the last block, we need to keep a 16 bytes margin so that we can be
|
||||
sure that all distances are at most window size - 16.
|
||||
For all other blocks, we only need to keep a margin of 5 bytes so that
|
||||
we don't go over the block size with a copy. */
|
||||
|
||||
var next_hash uint32
|
||||
ip++
|
||||
for next_hash = hash5(in[ip:], shift); ; {
|
||||
var skip uint32 = 32
|
||||
var next_ip int = ip
|
||||
/* Step 1: Scan forward in the input looking for a 5-byte-long match.
|
||||
If we get close to exhausting the input then goto emit_remainder.
|
||||
|
||||
Heuristic match skipping: If 32 bytes are scanned with no matches
|
||||
found, start looking only at every other byte. If 32 more bytes are
|
||||
scanned, look at every third byte, etc.. When a match is found,
|
||||
immediately go back to looking at every byte. This is a small loss
|
||||
(~5% performance, ~0.1% density) for compressible data due to more
|
||||
bookkeeping, but for non-compressible data (such as JPEG) it's a huge
|
||||
win since the compressor quickly "realizes" the data is incompressible
|
||||
and doesn't bother looking for matches everywhere.
|
||||
|
||||
The "skip" variable keeps track of how many bytes there are since the
|
||||
last match; dividing it by 32 (i.e. right-shifting by five) gives the
|
||||
number of bytes to move ahead for each iteration. */
|
||||
|
||||
var candidate int
|
||||
assert(next_emit < ip)
|
||||
|
||||
trawl:
|
||||
for {
|
||||
var hash uint32 = next_hash
|
||||
var bytes_between_hash_lookups uint32 = skip >> 5
|
||||
skip++
|
||||
assert(hash == hash5(in[next_ip:], shift))
|
||||
ip = next_ip
|
||||
next_ip = int(uint32(ip) + bytes_between_hash_lookups)
|
||||
if next_ip > ip_limit {
|
||||
goto emit_remainder
|
||||
}
|
||||
|
||||
next_hash = hash5(in[next_ip:], shift)
|
||||
candidate = ip - last_distance
|
||||
if isMatch5(in[ip:], in[candidate:]) {
|
||||
if candidate < ip {
|
||||
table[hash] = int(ip - base_ip)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
candidate = base_ip + table[hash]
|
||||
assert(candidate >= base_ip)
|
||||
assert(candidate < ip)
|
||||
|
||||
table[hash] = int(ip - base_ip)
|
||||
if isMatch5(in[ip:], in[candidate:]) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
/* Check copy distance. If candidate is not feasible, continue search.
|
||||
Checking is done outside of hot loop to reduce overhead. */
|
||||
if ip-candidate > maxDistance_compress_fragment {
|
||||
goto trawl
|
||||
}
|
||||
|
||||
/* Step 2: Emit the found match together with the literal bytes from
|
||||
"next_emit" to the bit stream, and then see if we can find a next match
|
||||
immediately afterwards. Repeat until we find no match for the input
|
||||
without emitting some literal bytes. */
|
||||
{
|
||||
var base int = ip
|
||||
/* > 0 */
|
||||
var matched uint = 5 + findMatchLengthWithLimit(in[candidate+5:], in[ip+5:], uint(ip_end-ip)-5)
|
||||
var distance int = int(base - candidate)
|
||||
/* We have a 5-byte match at ip, and we need to emit bytes in
|
||||
[next_emit, ip). */
|
||||
|
||||
var insert uint = uint(base - next_emit)
|
||||
ip += int(matched)
|
||||
if insert < 6210 {
|
||||
emitInsertLen1(insert, cmd_depth, cmd_bits, cmd_histo[:], storage_ix, storage)
|
||||
} else if shouldUseUncompressedMode(in[metablock_start:], in[next_emit:], insert, literal_ratio) {
|
||||
emitUncompressedMetaBlock1(in[metablock_start:], in[base:], mlen_storage_ix-3, storage_ix, storage)
|
||||
input_size -= uint(base - input)
|
||||
input = base
|
||||
next_emit = input
|
||||
goto next_block
|
||||
} else {
|
||||
emitLongInsertLen(insert, cmd_depth, cmd_bits, cmd_histo[:], storage_ix, storage)
|
||||
}
|
||||
|
||||
emitLiterals(in[next_emit:], insert, lit_depth[:], lit_bits[:], storage_ix, storage)
|
||||
if distance == last_distance {
|
||||
writeBits(uint(cmd_depth[64]), uint64(cmd_bits[64]), storage_ix, storage)
|
||||
cmd_histo[64]++
|
||||
} else {
|
||||
emitDistance1(uint(distance), cmd_depth, cmd_bits, cmd_histo[:], storage_ix, storage)
|
||||
last_distance = distance
|
||||
}
|
||||
|
||||
emitCopyLenLastDistance1(matched, cmd_depth, cmd_bits, cmd_histo[:], storage_ix, storage)
|
||||
|
||||
next_emit = ip
|
||||
if ip >= ip_limit {
|
||||
goto emit_remainder
|
||||
}
|
||||
|
||||
/* We could immediately start working at ip now, but to improve
|
||||
compression we first update "table" with the hashes of some positions
|
||||
within the last copy. */
|
||||
{
|
||||
var input_bytes uint64 = binary.LittleEndian.Uint64(in[ip-3:])
|
||||
var prev_hash uint32 = hashBytesAtOffset5(input_bytes, 0, shift)
|
||||
var cur_hash uint32 = hashBytesAtOffset5(input_bytes, 3, shift)
|
||||
table[prev_hash] = int(ip - base_ip - 3)
|
||||
prev_hash = hashBytesAtOffset5(input_bytes, 1, shift)
|
||||
table[prev_hash] = int(ip - base_ip - 2)
|
||||
prev_hash = hashBytesAtOffset5(input_bytes, 2, shift)
|
||||
table[prev_hash] = int(ip - base_ip - 1)
|
||||
|
||||
candidate = base_ip + table[cur_hash]
|
||||
table[cur_hash] = int(ip - base_ip)
|
||||
}
|
||||
}
|
||||
|
||||
for isMatch5(in[ip:], in[candidate:]) {
|
||||
var base int = ip
|
||||
/* We have a 5-byte match at ip, and no need to emit any literal bytes
|
||||
prior to ip. */
|
||||
|
||||
var matched uint = 5 + findMatchLengthWithLimit(in[candidate+5:], in[ip+5:], uint(ip_end-ip)-5)
|
||||
if ip-candidate > maxDistance_compress_fragment {
|
||||
break
|
||||
}
|
||||
ip += int(matched)
|
||||
last_distance = int(base - candidate) /* > 0 */
|
||||
emitCopyLen1(matched, cmd_depth, cmd_bits, cmd_histo[:], storage_ix, storage)
|
||||
emitDistance1(uint(last_distance), cmd_depth, cmd_bits, cmd_histo[:], storage_ix, storage)
|
||||
|
||||
next_emit = ip
|
||||
if ip >= ip_limit {
|
||||
goto emit_remainder
|
||||
}
|
||||
|
||||
/* We could immediately start working at ip now, but to improve
|
||||
compression we first update "table" with the hashes of some positions
|
||||
within the last copy. */
|
||||
{
|
||||
var input_bytes uint64 = binary.LittleEndian.Uint64(in[ip-3:])
|
||||
var prev_hash uint32 = hashBytesAtOffset5(input_bytes, 0, shift)
|
||||
var cur_hash uint32 = hashBytesAtOffset5(input_bytes, 3, shift)
|
||||
table[prev_hash] = int(ip - base_ip - 3)
|
||||
prev_hash = hashBytesAtOffset5(input_bytes, 1, shift)
|
||||
table[prev_hash] = int(ip - base_ip - 2)
|
||||
prev_hash = hashBytesAtOffset5(input_bytes, 2, shift)
|
||||
table[prev_hash] = int(ip - base_ip - 1)
|
||||
|
||||
candidate = base_ip + table[cur_hash]
|
||||
table[cur_hash] = int(ip - base_ip)
|
||||
}
|
||||
}
|
||||
|
||||
ip++
|
||||
next_hash = hash5(in[ip:], shift)
|
||||
}
|
||||
}
|
||||
|
||||
emit_remainder:
|
||||
assert(next_emit <= ip_end)
|
||||
input += int(block_size)
|
||||
input_size -= block_size
|
||||
block_size = brotli_min_size_t(input_size, compressFragmentFastImpl_kMergeBlockSize)
|
||||
|
||||
/* Decide if we want to continue this meta-block instead of emitting the
|
||||
last insert-only command. */
|
||||
if input_size > 0 && total_block_size+block_size <= 1<<20 && shouldMergeBlock(in[input:], block_size, lit_depth[:]) {
|
||||
assert(total_block_size > 1<<16)
|
||||
|
||||
/* Update the size of the current meta-block and continue emitting commands.
|
||||
We can do this because the current size and the new size both have 5
|
||||
nibbles. */
|
||||
total_block_size += block_size
|
||||
|
||||
updateBits(20, uint32(total_block_size-1), mlen_storage_ix, storage)
|
||||
goto emit_commands
|
||||
}
|
||||
|
||||
/* Emit the remaining bytes as literals. */
|
||||
if next_emit < ip_end {
|
||||
var insert uint = uint(ip_end - next_emit)
|
||||
if insert < 6210 {
|
||||
emitInsertLen1(insert, cmd_depth, cmd_bits, cmd_histo[:], storage_ix, storage)
|
||||
emitLiterals(in[next_emit:], insert, lit_depth[:], lit_bits[:], storage_ix, storage)
|
||||
} else if shouldUseUncompressedMode(in[metablock_start:], in[next_emit:], insert, literal_ratio) {
|
||||
emitUncompressedMetaBlock1(in[metablock_start:], in[ip_end:], mlen_storage_ix-3, storage_ix, storage)
|
||||
} else {
|
||||
emitLongInsertLen(insert, cmd_depth, cmd_bits, cmd_histo[:], storage_ix, storage)
|
||||
emitLiterals(in[next_emit:], insert, lit_depth[:], lit_bits[:], storage_ix, storage)
|
||||
}
|
||||
}
|
||||
|
||||
next_emit = ip_end
|
||||
|
||||
/* If we have more data, write a new meta-block header and prefix codes and
|
||||
then continue emitting commands. */
|
||||
next_block:
|
||||
if input_size > 0 {
|
||||
metablock_start = input
|
||||
block_size = brotli_min_size_t(input_size, compressFragmentFastImpl_kFirstBlockSize)
|
||||
total_block_size = block_size
|
||||
|
||||
/* Save the bit position of the MLEN field of the meta-block header, so that
|
||||
we can update it later if we decide to extend this meta-block. */
|
||||
mlen_storage_ix = *storage_ix + 3
|
||||
|
||||
storeMetaBlockHeader1(block_size, false, storage_ix, storage)
|
||||
|
||||
/* No block splits, no contexts. */
|
||||
writeBits(13, 0, storage_ix, storage)
|
||||
|
||||
literal_ratio = buildAndStoreLiteralPrefixCode(in[input:], block_size, lit_depth[:], lit_bits[:], storage_ix, storage)
|
||||
buildAndStoreCommandPrefixCode1(cmd_histo[:], cmd_depth, cmd_bits, storage_ix, storage)
|
||||
goto emit_commands
|
||||
}
|
||||
|
||||
if !is_last {
|
||||
/* If this is not the last block, update the command and distance prefix
|
||||
codes for the next block and store the compressed forms. */
|
||||
cmd_code[0] = 0
|
||||
|
||||
*cmd_code_numbits = 0
|
||||
buildAndStoreCommandPrefixCode1(cmd_histo[:], cmd_depth, cmd_bits, cmd_code_numbits, cmd_code)
|
||||
}
|
||||
}
|
||||
|
||||
/* Compresses "input" string to the "*storage" buffer as one or more complete
|
||||
meta-blocks, and updates the "*storage_ix" bit position.
|
||||
|
||||
If "is_last" is 1, emits an additional empty last meta-block.
|
||||
|
||||
"cmd_depth" and "cmd_bits" contain the command and distance prefix codes
|
||||
(see comment in encode.h) used for the encoding of this input fragment.
|
||||
If "is_last" is 0, they are updated to reflect the statistics
|
||||
of this input fragment, to be used for the encoding of the next fragment.
|
||||
|
||||
"*cmd_code_numbits" is the number of bits of the compressed representation
|
||||
of the command and distance prefix codes, and "cmd_code" is an array of
|
||||
at least "(*cmd_code_numbits + 7) >> 3" size that contains the compressed
|
||||
command and distance prefix codes. If "is_last" is 0, these are also
|
||||
updated to represent the updated "cmd_depth" and "cmd_bits".
|
||||
|
||||
REQUIRES: "input_size" is greater than zero, or "is_last" is 1.
|
||||
REQUIRES: "input_size" is less or equal to maximal metablock size (1 << 24).
|
||||
REQUIRES: All elements in "table[0..table_size-1]" are initialized to zero.
|
||||
REQUIRES: "table_size" is an odd (9, 11, 13, 15) power of two
|
||||
OUTPUT: maximal copy distance <= |input_size|
|
||||
OUTPUT: maximal copy distance <= BROTLI_MAX_BACKWARD_LIMIT(18) */
|
||||
func compressFragmentFast(input []byte, input_size uint, is_last bool, table []int, table_size uint, cmd_depth []byte, cmd_bits []uint16, cmd_code_numbits *uint, cmd_code []byte, storage_ix *uint, storage []byte) {
|
||||
var initial_storage_ix uint = *storage_ix
|
||||
var table_bits uint = uint(log2FloorNonZero(table_size))
|
||||
|
||||
if input_size == 0 {
|
||||
assert(is_last)
|
||||
writeBits(1, 1, storage_ix, storage) /* islast */
|
||||
writeBits(1, 1, storage_ix, storage) /* isempty */
|
||||
*storage_ix = (*storage_ix + 7) &^ 7
|
||||
return
|
||||
}
|
||||
|
||||
compressFragmentFastImpl(input, input_size, is_last, table, table_bits, cmd_depth, cmd_bits, cmd_code_numbits, cmd_code, storage_ix, storage)
|
||||
|
||||
/* If output is larger than single uncompressed block, rewrite it. */
|
||||
if *storage_ix-initial_storage_ix > 31+(input_size<<3) {
|
||||
emitUncompressedMetaBlock1(input, input[input_size:], initial_storage_ix, storage_ix, storage)
|
||||
}
|
||||
|
||||
if is_last {
|
||||
writeBits(1, 1, storage_ix, storage) /* islast */
|
||||
writeBits(1, 1, storage_ix, storage) /* isempty */
|
||||
*storage_ix = (*storage_ix + 7) &^ 7
|
||||
}
|
||||
}
|
||||
+748
@@ -0,0 +1,748 @@
|
||||
package brotli
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Function for fast encoding of an input fragment, independently from the input
|
||||
history. This function uses two-pass processing: in the first pass we save
|
||||
the found backward matches and literal bytes into a buffer, and in the
|
||||
second pass we emit them into the bit stream using prefix codes built based
|
||||
on the actual command and literal byte histograms. */
|
||||
|
||||
const kCompressFragmentTwoPassBlockSize uint = 1 << 17
|
||||
|
||||
func hash1(p []byte, shift uint, length uint) uint32 {
|
||||
var h uint64 = (binary.LittleEndian.Uint64(p) << ((8 - length) * 8)) * uint64(kHashMul32)
|
||||
return uint32(h >> shift)
|
||||
}
|
||||
|
||||
func hashBytesAtOffset(v uint64, offset uint, shift uint, length uint) uint32 {
|
||||
assert(offset <= 8-length)
|
||||
{
|
||||
var h uint64 = ((v >> (8 * offset)) << ((8 - length) * 8)) * uint64(kHashMul32)
|
||||
return uint32(h >> shift)
|
||||
}
|
||||
}
|
||||
|
||||
func isMatch1(p1 []byte, p2 []byte, length uint) bool {
|
||||
if binary.LittleEndian.Uint32(p1) != binary.LittleEndian.Uint32(p2) {
|
||||
return false
|
||||
}
|
||||
if length == 4 {
|
||||
return true
|
||||
}
|
||||
return p1[4] == p2[4] && p1[5] == p2[5]
|
||||
}
|
||||
|
||||
/* Builds a command and distance prefix code (each 64 symbols) into "depth" and
|
||||
"bits" based on "histogram" and stores it into the bit stream. */
|
||||
func buildAndStoreCommandPrefixCode(histogram []uint32, depth []byte, bits []uint16, storage_ix *uint, storage []byte) {
|
||||
var tree [129]huffmanTree
|
||||
var cmd_depth = [numCommandSymbols]byte{0}
|
||||
/* Tree size for building a tree over 64 symbols is 2 * 64 + 1. */
|
||||
|
||||
var cmd_bits [64]uint16
|
||||
createHuffmanTree(histogram, 64, 15, tree[:], depth)
|
||||
createHuffmanTree(histogram[64:], 64, 14, tree[:], depth[64:])
|
||||
|
||||
/* We have to jump through a few hoops here in order to compute
|
||||
the command bits because the symbols are in a different order than in
|
||||
the full alphabet. This looks complicated, but having the symbols
|
||||
in this order in the command bits saves a few branches in the Emit*
|
||||
functions. */
|
||||
copy(cmd_depth[:], depth[24:][:24])
|
||||
|
||||
copy(cmd_depth[24:][:], depth[:8])
|
||||
copy(cmd_depth[32:][:], depth[48:][:8])
|
||||
copy(cmd_depth[40:][:], depth[8:][:8])
|
||||
copy(cmd_depth[48:][:], depth[56:][:8])
|
||||
copy(cmd_depth[56:][:], depth[16:][:8])
|
||||
convertBitDepthsToSymbols(cmd_depth[:], 64, cmd_bits[:])
|
||||
copy(bits, cmd_bits[24:][:8])
|
||||
copy(bits[8:], cmd_bits[40:][:8])
|
||||
copy(bits[16:], cmd_bits[56:][:8])
|
||||
copy(bits[24:], cmd_bits[:24])
|
||||
copy(bits[48:], cmd_bits[32:][:8])
|
||||
copy(bits[56:], cmd_bits[48:][:8])
|
||||
convertBitDepthsToSymbols(depth[64:], 64, bits[64:])
|
||||
{
|
||||
/* Create the bit length array for the full command alphabet. */
|
||||
var i uint
|
||||
for i := 0; i < int(64); i++ {
|
||||
cmd_depth[i] = 0
|
||||
} /* only 64 first values were used */
|
||||
copy(cmd_depth[:], depth[24:][:8])
|
||||
copy(cmd_depth[64:][:], depth[32:][:8])
|
||||
copy(cmd_depth[128:][:], depth[40:][:8])
|
||||
copy(cmd_depth[192:][:], depth[48:][:8])
|
||||
copy(cmd_depth[384:][:], depth[56:][:8])
|
||||
for i = 0; i < 8; i++ {
|
||||
cmd_depth[128+8*i] = depth[i]
|
||||
cmd_depth[256+8*i] = depth[8+i]
|
||||
cmd_depth[448+8*i] = depth[16+i]
|
||||
}
|
||||
|
||||
storeHuffmanTree(cmd_depth[:], numCommandSymbols, tree[:], storage_ix, storage)
|
||||
}
|
||||
|
||||
storeHuffmanTree(depth[64:], 64, tree[:], storage_ix, storage)
|
||||
}
|
||||
|
||||
func emitInsertLen(insertlen uint32, commands *[]uint32) {
|
||||
if insertlen < 6 {
|
||||
(*commands)[0] = insertlen
|
||||
} else if insertlen < 130 {
|
||||
var tail uint32 = insertlen - 2
|
||||
var nbits uint32 = log2FloorNonZero(uint(tail)) - 1
|
||||
var prefix uint32 = tail >> nbits
|
||||
var inscode uint32 = (nbits << 1) + prefix + 2
|
||||
var extra uint32 = tail - (prefix << nbits)
|
||||
(*commands)[0] = inscode | extra<<8
|
||||
} else if insertlen < 2114 {
|
||||
var tail uint32 = insertlen - 66
|
||||
var nbits uint32 = log2FloorNonZero(uint(tail))
|
||||
var code uint32 = nbits + 10
|
||||
var extra uint32 = tail - (1 << nbits)
|
||||
(*commands)[0] = code | extra<<8
|
||||
} else if insertlen < 6210 {
|
||||
var extra uint32 = insertlen - 2114
|
||||
(*commands)[0] = 21 | extra<<8
|
||||
} else if insertlen < 22594 {
|
||||
var extra uint32 = insertlen - 6210
|
||||
(*commands)[0] = 22 | extra<<8
|
||||
} else {
|
||||
var extra uint32 = insertlen - 22594
|
||||
(*commands)[0] = 23 | extra<<8
|
||||
}
|
||||
|
||||
*commands = (*commands)[1:]
|
||||
}
|
||||
|
||||
func emitCopyLen(copylen uint, commands *[]uint32) {
|
||||
if copylen < 10 {
|
||||
(*commands)[0] = uint32(copylen + 38)
|
||||
} else if copylen < 134 {
|
||||
var tail uint = copylen - 6
|
||||
var nbits uint = uint(log2FloorNonZero(tail) - 1)
|
||||
var prefix uint = tail >> nbits
|
||||
var code uint = (nbits << 1) + prefix + 44
|
||||
var extra uint = tail - (prefix << nbits)
|
||||
(*commands)[0] = uint32(code | extra<<8)
|
||||
} else if copylen < 2118 {
|
||||
var tail uint = copylen - 70
|
||||
var nbits uint = uint(log2FloorNonZero(tail))
|
||||
var code uint = nbits + 52
|
||||
var extra uint = tail - (uint(1) << nbits)
|
||||
(*commands)[0] = uint32(code | extra<<8)
|
||||
} else {
|
||||
var extra uint = copylen - 2118
|
||||
(*commands)[0] = uint32(63 | extra<<8)
|
||||
}
|
||||
|
||||
*commands = (*commands)[1:]
|
||||
}
|
||||
|
||||
func emitCopyLenLastDistance(copylen uint, commands *[]uint32) {
|
||||
if copylen < 12 {
|
||||
(*commands)[0] = uint32(copylen + 20)
|
||||
*commands = (*commands)[1:]
|
||||
} else if copylen < 72 {
|
||||
var tail uint = copylen - 8
|
||||
var nbits uint = uint(log2FloorNonZero(tail) - 1)
|
||||
var prefix uint = tail >> nbits
|
||||
var code uint = (nbits << 1) + prefix + 28
|
||||
var extra uint = tail - (prefix << nbits)
|
||||
(*commands)[0] = uint32(code | extra<<8)
|
||||
*commands = (*commands)[1:]
|
||||
} else if copylen < 136 {
|
||||
var tail uint = copylen - 8
|
||||
var code uint = (tail >> 5) + 54
|
||||
var extra uint = tail & 31
|
||||
(*commands)[0] = uint32(code | extra<<8)
|
||||
*commands = (*commands)[1:]
|
||||
(*commands)[0] = 64
|
||||
*commands = (*commands)[1:]
|
||||
} else if copylen < 2120 {
|
||||
var tail uint = copylen - 72
|
||||
var nbits uint = uint(log2FloorNonZero(tail))
|
||||
var code uint = nbits + 52
|
||||
var extra uint = tail - (uint(1) << nbits)
|
||||
(*commands)[0] = uint32(code | extra<<8)
|
||||
*commands = (*commands)[1:]
|
||||
(*commands)[0] = 64
|
||||
*commands = (*commands)[1:]
|
||||
} else {
|
||||
var extra uint = copylen - 2120
|
||||
(*commands)[0] = uint32(63 | extra<<8)
|
||||
*commands = (*commands)[1:]
|
||||
(*commands)[0] = 64
|
||||
*commands = (*commands)[1:]
|
||||
}
|
||||
}
|
||||
|
||||
func emitDistance(distance uint32, commands *[]uint32) {
|
||||
var d uint32 = distance + 3
|
||||
var nbits uint32 = log2FloorNonZero(uint(d)) - 1
|
||||
var prefix uint32 = (d >> nbits) & 1
|
||||
var offset uint32 = (2 + prefix) << nbits
|
||||
var distcode uint32 = 2*(nbits-1) + prefix + 80
|
||||
var extra uint32 = d - offset
|
||||
(*commands)[0] = distcode | extra<<8
|
||||
*commands = (*commands)[1:]
|
||||
}
|
||||
|
||||
/* REQUIRES: len <= 1 << 24. */
|
||||
func storeMetaBlockHeader(len uint, is_uncompressed bool, storage_ix *uint, storage []byte) {
|
||||
var nibbles uint = 6
|
||||
|
||||
/* ISLAST */
|
||||
writeBits(1, 0, storage_ix, storage)
|
||||
|
||||
if len <= 1<<16 {
|
||||
nibbles = 4
|
||||
} else if len <= 1<<20 {
|
||||
nibbles = 5
|
||||
}
|
||||
|
||||
writeBits(2, uint64(nibbles)-4, storage_ix, storage)
|
||||
writeBits(nibbles*4, uint64(len)-1, storage_ix, storage)
|
||||
|
||||
/* ISUNCOMPRESSED */
|
||||
writeSingleBit(is_uncompressed, storage_ix, storage)
|
||||
}
|
||||
|
||||
func createCommands(input []byte, block_size uint, input_size uint, base_ip_ptr []byte, table []int, table_bits uint, min_match uint, literals *[]byte, commands *[]uint32) {
|
||||
var ip int = 0
|
||||
var shift uint = 64 - table_bits
|
||||
var ip_end int = int(block_size)
|
||||
var base_ip int = -cap(base_ip_ptr) + cap(input)
|
||||
var next_emit int = 0
|
||||
var last_distance int = -1
|
||||
/* "ip" is the input pointer. */
|
||||
|
||||
const kInputMarginBytes uint = windowGap
|
||||
|
||||
/* "next_emit" is a pointer to the first byte that is not covered by a
|
||||
previous copy. Bytes between "next_emit" and the start of the next copy or
|
||||
the end of the input will be emitted as literal bytes. */
|
||||
if block_size >= kInputMarginBytes {
|
||||
var len_limit uint = brotli_min_size_t(block_size-min_match, input_size-kInputMarginBytes)
|
||||
var ip_limit int = int(len_limit)
|
||||
/* For the last block, we need to keep a 16 bytes margin so that we can be
|
||||
sure that all distances are at most window size - 16.
|
||||
For all other blocks, we only need to keep a margin of 5 bytes so that
|
||||
we don't go over the block size with a copy. */
|
||||
|
||||
var next_hash uint32
|
||||
ip++
|
||||
for next_hash = hash1(input[ip:], shift, min_match); ; {
|
||||
var skip uint32 = 32
|
||||
var next_ip int = ip
|
||||
/* Step 1: Scan forward in the input looking for a 6-byte-long match.
|
||||
If we get close to exhausting the input then goto emit_remainder.
|
||||
|
||||
Heuristic match skipping: If 32 bytes are scanned with no matches
|
||||
found, start looking only at every other byte. If 32 more bytes are
|
||||
scanned, look at every third byte, etc.. When a match is found,
|
||||
immediately go back to looking at every byte. This is a small loss
|
||||
(~5% performance, ~0.1% density) for compressible data due to more
|
||||
bookkeeping, but for non-compressible data (such as JPEG) it's a huge
|
||||
win since the compressor quickly "realizes" the data is incompressible
|
||||
and doesn't bother looking for matches everywhere.
|
||||
|
||||
The "skip" variable keeps track of how many bytes there are since the
|
||||
last match; dividing it by 32 (ie. right-shifting by five) gives the
|
||||
number of bytes to move ahead for each iteration. */
|
||||
|
||||
var candidate int
|
||||
|
||||
assert(next_emit < ip)
|
||||
|
||||
trawl:
|
||||
for {
|
||||
var hash uint32 = next_hash
|
||||
var bytes_between_hash_lookups uint32 = skip >> 5
|
||||
skip++
|
||||
ip = next_ip
|
||||
assert(hash == hash1(input[ip:], shift, min_match))
|
||||
next_ip = int(uint32(ip) + bytes_between_hash_lookups)
|
||||
if next_ip > ip_limit {
|
||||
goto emit_remainder
|
||||
}
|
||||
|
||||
next_hash = hash1(input[next_ip:], shift, min_match)
|
||||
candidate = ip - last_distance
|
||||
if isMatch1(input[ip:], base_ip_ptr[candidate-base_ip:], min_match) {
|
||||
if candidate < ip {
|
||||
table[hash] = int(ip - base_ip)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
candidate = base_ip + table[hash]
|
||||
assert(candidate >= base_ip)
|
||||
assert(candidate < ip)
|
||||
|
||||
table[hash] = int(ip - base_ip)
|
||||
if isMatch1(input[ip:], base_ip_ptr[candidate-base_ip:], min_match) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
/* Check copy distance. If candidate is not feasible, continue search.
|
||||
Checking is done outside of hot loop to reduce overhead. */
|
||||
if ip-candidate > maxDistance_compress_fragment {
|
||||
goto trawl
|
||||
}
|
||||
|
||||
/* Step 2: Emit the found match together with the literal bytes from
|
||||
"next_emit", and then see if we can find a next match immediately
|
||||
afterwards. Repeat until we find no match for the input
|
||||
without emitting some literal bytes. */
|
||||
{
|
||||
var base int = ip
|
||||
/* > 0 */
|
||||
var matched uint = min_match + findMatchLengthWithLimit(base_ip_ptr[uint(candidate-base_ip)+min_match:], input[uint(ip)+min_match:], uint(ip_end-ip)-min_match)
|
||||
var distance int = int(base - candidate)
|
||||
/* We have a 6-byte match at ip, and we need to emit bytes in
|
||||
[next_emit, ip). */
|
||||
|
||||
var insert int = int(base - next_emit)
|
||||
ip += int(matched)
|
||||
emitInsertLen(uint32(insert), commands)
|
||||
copy(*literals, input[next_emit:][:uint(insert)])
|
||||
*literals = (*literals)[insert:]
|
||||
if distance == last_distance {
|
||||
(*commands)[0] = 64
|
||||
*commands = (*commands)[1:]
|
||||
} else {
|
||||
emitDistance(uint32(distance), commands)
|
||||
last_distance = distance
|
||||
}
|
||||
|
||||
emitCopyLenLastDistance(matched, commands)
|
||||
|
||||
next_emit = ip
|
||||
if ip >= ip_limit {
|
||||
goto emit_remainder
|
||||
}
|
||||
{
|
||||
var input_bytes uint64
|
||||
var cur_hash uint32
|
||||
/* We could immediately start working at ip now, but to improve
|
||||
compression we first update "table" with the hashes of some
|
||||
positions within the last copy. */
|
||||
|
||||
var prev_hash uint32
|
||||
if min_match == 4 {
|
||||
input_bytes = binary.LittleEndian.Uint64(input[ip-3:])
|
||||
cur_hash = hashBytesAtOffset(input_bytes, 3, shift, min_match)
|
||||
prev_hash = hashBytesAtOffset(input_bytes, 0, shift, min_match)
|
||||
table[prev_hash] = int(ip - base_ip - 3)
|
||||
prev_hash = hashBytesAtOffset(input_bytes, 1, shift, min_match)
|
||||
table[prev_hash] = int(ip - base_ip - 2)
|
||||
prev_hash = hashBytesAtOffset(input_bytes, 0, shift, min_match)
|
||||
table[prev_hash] = int(ip - base_ip - 1)
|
||||
} else {
|
||||
input_bytes = binary.LittleEndian.Uint64(input[ip-5:])
|
||||
prev_hash = hashBytesAtOffset(input_bytes, 0, shift, min_match)
|
||||
table[prev_hash] = int(ip - base_ip - 5)
|
||||
prev_hash = hashBytesAtOffset(input_bytes, 1, shift, min_match)
|
||||
table[prev_hash] = int(ip - base_ip - 4)
|
||||
prev_hash = hashBytesAtOffset(input_bytes, 2, shift, min_match)
|
||||
table[prev_hash] = int(ip - base_ip - 3)
|
||||
input_bytes = binary.LittleEndian.Uint64(input[ip-2:])
|
||||
cur_hash = hashBytesAtOffset(input_bytes, 2, shift, min_match)
|
||||
prev_hash = hashBytesAtOffset(input_bytes, 0, shift, min_match)
|
||||
table[prev_hash] = int(ip - base_ip - 2)
|
||||
prev_hash = hashBytesAtOffset(input_bytes, 1, shift, min_match)
|
||||
table[prev_hash] = int(ip - base_ip - 1)
|
||||
}
|
||||
|
||||
candidate = base_ip + table[cur_hash]
|
||||
table[cur_hash] = int(ip - base_ip)
|
||||
}
|
||||
}
|
||||
|
||||
for ip-candidate <= maxDistance_compress_fragment && isMatch1(input[ip:], base_ip_ptr[candidate-base_ip:], min_match) {
|
||||
var base int = ip
|
||||
/* We have a 6-byte match at ip, and no need to emit any
|
||||
literal bytes prior to ip. */
|
||||
|
||||
var matched uint = min_match + findMatchLengthWithLimit(base_ip_ptr[uint(candidate-base_ip)+min_match:], input[uint(ip)+min_match:], uint(ip_end-ip)-min_match)
|
||||
ip += int(matched)
|
||||
last_distance = int(base - candidate) /* > 0 */
|
||||
emitCopyLen(matched, commands)
|
||||
emitDistance(uint32(last_distance), commands)
|
||||
|
||||
next_emit = ip
|
||||
if ip >= ip_limit {
|
||||
goto emit_remainder
|
||||
}
|
||||
{
|
||||
var input_bytes uint64
|
||||
var cur_hash uint32
|
||||
/* We could immediately start working at ip now, but to improve
|
||||
compression we first update "table" with the hashes of some
|
||||
positions within the last copy. */
|
||||
|
||||
var prev_hash uint32
|
||||
if min_match == 4 {
|
||||
input_bytes = binary.LittleEndian.Uint64(input[ip-3:])
|
||||
cur_hash = hashBytesAtOffset(input_bytes, 3, shift, min_match)
|
||||
prev_hash = hashBytesAtOffset(input_bytes, 0, shift, min_match)
|
||||
table[prev_hash] = int(ip - base_ip - 3)
|
||||
prev_hash = hashBytesAtOffset(input_bytes, 1, shift, min_match)
|
||||
table[prev_hash] = int(ip - base_ip - 2)
|
||||
prev_hash = hashBytesAtOffset(input_bytes, 2, shift, min_match)
|
||||
table[prev_hash] = int(ip - base_ip - 1)
|
||||
} else {
|
||||
input_bytes = binary.LittleEndian.Uint64(input[ip-5:])
|
||||
prev_hash = hashBytesAtOffset(input_bytes, 0, shift, min_match)
|
||||
table[prev_hash] = int(ip - base_ip - 5)
|
||||
prev_hash = hashBytesAtOffset(input_bytes, 1, shift, min_match)
|
||||
table[prev_hash] = int(ip - base_ip - 4)
|
||||
prev_hash = hashBytesAtOffset(input_bytes, 2, shift, min_match)
|
||||
table[prev_hash] = int(ip - base_ip - 3)
|
||||
input_bytes = binary.LittleEndian.Uint64(input[ip-2:])
|
||||
cur_hash = hashBytesAtOffset(input_bytes, 2, shift, min_match)
|
||||
prev_hash = hashBytesAtOffset(input_bytes, 0, shift, min_match)
|
||||
table[prev_hash] = int(ip - base_ip - 2)
|
||||
prev_hash = hashBytesAtOffset(input_bytes, 1, shift, min_match)
|
||||
table[prev_hash] = int(ip - base_ip - 1)
|
||||
}
|
||||
|
||||
candidate = base_ip + table[cur_hash]
|
||||
table[cur_hash] = int(ip - base_ip)
|
||||
}
|
||||
}
|
||||
|
||||
ip++
|
||||
next_hash = hash1(input[ip:], shift, min_match)
|
||||
}
|
||||
}
|
||||
|
||||
emit_remainder:
|
||||
assert(next_emit <= ip_end)
|
||||
|
||||
/* Emit the remaining bytes as literals. */
|
||||
if next_emit < ip_end {
|
||||
var insert uint32 = uint32(ip_end - next_emit)
|
||||
emitInsertLen(insert, commands)
|
||||
copy(*literals, input[next_emit:][:insert])
|
||||
*literals = (*literals)[insert:]
|
||||
}
|
||||
}
|
||||
|
||||
var storeCommands_kNumExtraBits = [128]uint32{
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
2,
|
||||
3,
|
||||
3,
|
||||
4,
|
||||
4,
|
||||
5,
|
||||
5,
|
||||
6,
|
||||
7,
|
||||
8,
|
||||
9,
|
||||
10,
|
||||
12,
|
||||
14,
|
||||
24,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
2,
|
||||
3,
|
||||
3,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
2,
|
||||
3,
|
||||
3,
|
||||
4,
|
||||
4,
|
||||
5,
|
||||
5,
|
||||
6,
|
||||
7,
|
||||
8,
|
||||
9,
|
||||
10,
|
||||
24,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
2,
|
||||
3,
|
||||
3,
|
||||
4,
|
||||
4,
|
||||
5,
|
||||
5,
|
||||
6,
|
||||
6,
|
||||
7,
|
||||
7,
|
||||
8,
|
||||
8,
|
||||
9,
|
||||
9,
|
||||
10,
|
||||
10,
|
||||
11,
|
||||
11,
|
||||
12,
|
||||
12,
|
||||
13,
|
||||
13,
|
||||
14,
|
||||
14,
|
||||
15,
|
||||
15,
|
||||
16,
|
||||
16,
|
||||
17,
|
||||
17,
|
||||
18,
|
||||
18,
|
||||
19,
|
||||
19,
|
||||
20,
|
||||
20,
|
||||
21,
|
||||
21,
|
||||
22,
|
||||
22,
|
||||
23,
|
||||
23,
|
||||
24,
|
||||
24,
|
||||
}
|
||||
var storeCommands_kInsertOffset = [24]uint32{
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5,
|
||||
6,
|
||||
8,
|
||||
10,
|
||||
14,
|
||||
18,
|
||||
26,
|
||||
34,
|
||||
50,
|
||||
66,
|
||||
98,
|
||||
130,
|
||||
194,
|
||||
322,
|
||||
578,
|
||||
1090,
|
||||
2114,
|
||||
6210,
|
||||
22594,
|
||||
}
|
||||
|
||||
func storeCommands(literals []byte, num_literals uint, commands []uint32, num_commands uint, storage_ix *uint, storage []byte) {
|
||||
var lit_depths [256]byte
|
||||
var lit_bits [256]uint16
|
||||
var lit_histo = [256]uint32{0}
|
||||
var cmd_depths = [128]byte{0}
|
||||
var cmd_bits = [128]uint16{0}
|
||||
var cmd_histo = [128]uint32{0}
|
||||
var i uint
|
||||
for i = 0; i < num_literals; i++ {
|
||||
lit_histo[literals[i]]++
|
||||
}
|
||||
|
||||
buildAndStoreHuffmanTreeFast(lit_histo[:], num_literals, /* max_bits = */
|
||||
8, lit_depths[:], lit_bits[:], storage_ix, storage)
|
||||
|
||||
for i = 0; i < num_commands; i++ {
|
||||
var code uint32 = commands[i] & 0xFF
|
||||
assert(code < 128)
|
||||
cmd_histo[code]++
|
||||
}
|
||||
|
||||
cmd_histo[1] += 1
|
||||
cmd_histo[2] += 1
|
||||
cmd_histo[64] += 1
|
||||
cmd_histo[84] += 1
|
||||
buildAndStoreCommandPrefixCode(cmd_histo[:], cmd_depths[:], cmd_bits[:], storage_ix, storage)
|
||||
|
||||
for i = 0; i < num_commands; i++ {
|
||||
var cmd uint32 = commands[i]
|
||||
var code uint32 = cmd & 0xFF
|
||||
var extra uint32 = cmd >> 8
|
||||
assert(code < 128)
|
||||
writeBits(uint(cmd_depths[code]), uint64(cmd_bits[code]), storage_ix, storage)
|
||||
writeBits(uint(storeCommands_kNumExtraBits[code]), uint64(extra), storage_ix, storage)
|
||||
if code < 24 {
|
||||
var insert uint32 = storeCommands_kInsertOffset[code] + extra
|
||||
var j uint32
|
||||
for j = 0; j < insert; j++ {
|
||||
var lit byte = literals[0]
|
||||
writeBits(uint(lit_depths[lit]), uint64(lit_bits[lit]), storage_ix, storage)
|
||||
literals = literals[1:]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Acceptable loss for uncompressible speedup is 2% */
|
||||
const minRatio = 0.98
|
||||
|
||||
const sampleRate = 43
|
||||
|
||||
func shouldCompress(input []byte, input_size uint, num_literals uint) bool {
|
||||
var corpus_size float64 = float64(input_size)
|
||||
if float64(num_literals) < minRatio*corpus_size {
|
||||
return true
|
||||
} else {
|
||||
var literal_histo = [256]uint32{0}
|
||||
var max_total_bit_cost float64 = corpus_size * 8 * minRatio / sampleRate
|
||||
var i uint
|
||||
for i = 0; i < input_size; i += sampleRate {
|
||||
literal_histo[input[i]]++
|
||||
}
|
||||
|
||||
return bitsEntropy(literal_histo[:], 256) < max_total_bit_cost
|
||||
}
|
||||
}
|
||||
|
||||
func rewindBitPosition(new_storage_ix uint, storage_ix *uint, storage []byte) {
|
||||
var bitpos uint = new_storage_ix & 7
|
||||
var mask uint = (1 << bitpos) - 1
|
||||
storage[new_storage_ix>>3] &= byte(mask)
|
||||
*storage_ix = new_storage_ix
|
||||
}
|
||||
|
||||
func emitUncompressedMetaBlock(input []byte, input_size uint, storage_ix *uint, storage []byte) {
|
||||
storeMetaBlockHeader(input_size, true, storage_ix, storage)
|
||||
*storage_ix = (*storage_ix + 7) &^ 7
|
||||
copy(storage[*storage_ix>>3:], input[:input_size])
|
||||
*storage_ix += input_size << 3
|
||||
storage[*storage_ix>>3] = 0
|
||||
}
|
||||
|
||||
func compressFragmentTwoPassImpl(input []byte, input_size uint, is_last bool, command_buf []uint32, literal_buf []byte, table []int, table_bits uint, min_match uint, storage_ix *uint, storage []byte) {
|
||||
/* Save the start of the first block for position and distance computations.
|
||||
*/
|
||||
var base_ip []byte = input
|
||||
|
||||
for input_size > 0 {
|
||||
var block_size uint = brotli_min_size_t(input_size, kCompressFragmentTwoPassBlockSize)
|
||||
var commands []uint32 = command_buf
|
||||
var literals []byte = literal_buf
|
||||
var num_literals uint
|
||||
createCommands(input, block_size, input_size, base_ip, table, table_bits, min_match, &literals, &commands)
|
||||
num_literals = uint(-cap(literals) + cap(literal_buf))
|
||||
if shouldCompress(input, block_size, num_literals) {
|
||||
var num_commands uint = uint(-cap(commands) + cap(command_buf))
|
||||
storeMetaBlockHeader(block_size, false, storage_ix, storage)
|
||||
|
||||
/* No block splits, no contexts. */
|
||||
writeBits(13, 0, storage_ix, storage)
|
||||
|
||||
storeCommands(literal_buf, num_literals, command_buf, num_commands, storage_ix, storage)
|
||||
} else {
|
||||
/* Since we did not find many backward references and the entropy of
|
||||
the data is close to 8 bits, we can simply emit an uncompressed block.
|
||||
This makes compression speed of uncompressible data about 3x faster. */
|
||||
emitUncompressedMetaBlock(input, block_size, storage_ix, storage)
|
||||
}
|
||||
|
||||
input = input[block_size:]
|
||||
input_size -= block_size
|
||||
}
|
||||
}
|
||||
|
||||
/* Compresses "input" string to the "*storage" buffer as one or more complete
|
||||
meta-blocks, and updates the "*storage_ix" bit position.
|
||||
|
||||
If "is_last" is 1, emits an additional empty last meta-block.
|
||||
|
||||
REQUIRES: "input_size" is greater than zero, or "is_last" is 1.
|
||||
REQUIRES: "input_size" is less or equal to maximal metablock size (1 << 24).
|
||||
REQUIRES: "command_buf" and "literal_buf" point to at least
|
||||
kCompressFragmentTwoPassBlockSize long arrays.
|
||||
REQUIRES: All elements in "table[0..table_size-1]" are initialized to zero.
|
||||
REQUIRES: "table_size" is a power of two
|
||||
OUTPUT: maximal copy distance <= |input_size|
|
||||
OUTPUT: maximal copy distance <= BROTLI_MAX_BACKWARD_LIMIT(18) */
|
||||
func compressFragmentTwoPass(input []byte, input_size uint, is_last bool, command_buf []uint32, literal_buf []byte, table []int, table_size uint, storage_ix *uint, storage []byte) {
|
||||
var initial_storage_ix uint = *storage_ix
|
||||
var table_bits uint = uint(log2FloorNonZero(table_size))
|
||||
var min_match uint
|
||||
if table_bits <= 15 {
|
||||
min_match = 4
|
||||
} else {
|
||||
min_match = 6
|
||||
}
|
||||
compressFragmentTwoPassImpl(input, input_size, is_last, command_buf, literal_buf, table, table_bits, min_match, storage_ix, storage)
|
||||
|
||||
/* If output is larger than single uncompressed block, rewrite it. */
|
||||
if *storage_ix-initial_storage_ix > 31+(input_size<<3) {
|
||||
rewindBitPosition(initial_storage_ix, storage_ix, storage)
|
||||
emitUncompressedMetaBlock(input, input_size, storage_ix, storage)
|
||||
}
|
||||
|
||||
if is_last {
|
||||
writeBits(1, 1, storage_ix, storage) /* islast */
|
||||
writeBits(1, 1, storage_ix, storage) /* isempty */
|
||||
*storage_ix = (*storage_ix + 7) &^ 7
|
||||
}
|
||||
}
|
||||
+77
@@ -0,0 +1,77 @@
|
||||
package brotli
|
||||
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Specification: 7.3. Encoding of the context map */
|
||||
const contextMapMaxRle = 16
|
||||
|
||||
/* Specification: 2. Compressed representation overview */
|
||||
const maxNumberOfBlockTypes = 256
|
||||
|
||||
/* Specification: 3.3. Alphabet sizes: insert-and-copy length */
|
||||
const numLiteralSymbols = 256
|
||||
|
||||
const numCommandSymbols = 704
|
||||
|
||||
const numBlockLenSymbols = 26
|
||||
|
||||
const maxContextMapSymbols = (maxNumberOfBlockTypes + contextMapMaxRle)
|
||||
|
||||
const maxBlockTypeSymbols = (maxNumberOfBlockTypes + 2)
|
||||
|
||||
/* Specification: 3.5. Complex prefix codes */
|
||||
const repeatPreviousCodeLength = 16
|
||||
|
||||
const repeatZeroCodeLength = 17
|
||||
|
||||
const codeLengthCodes = (repeatZeroCodeLength + 1)
|
||||
|
||||
/* "code length of 8 is repeated" */
|
||||
const initialRepeatedCodeLength = 8
|
||||
|
||||
/* "Large Window Brotli" */
|
||||
const largeMaxDistanceBits = 62
|
||||
|
||||
const largeMinWbits = 10
|
||||
|
||||
const largeMaxWbits = 30
|
||||
|
||||
/* Specification: 4. Encoding of distances */
|
||||
const numDistanceShortCodes = 16
|
||||
|
||||
const maxNpostfix = 3
|
||||
|
||||
const maxNdirect = 120
|
||||
|
||||
const maxDistanceBits = 24
|
||||
|
||||
func distanceAlphabetSize(NPOSTFIX uint, NDIRECT uint, MAXNBITS uint) uint {
|
||||
return numDistanceShortCodes + NDIRECT + uint(MAXNBITS<<(NPOSTFIX+1))
|
||||
}
|
||||
|
||||
/* numDistanceSymbols == 1128 */
|
||||
const numDistanceSymbols = 1128
|
||||
|
||||
const maxDistance = 0x3FFFFFC
|
||||
|
||||
const maxAllowedDistance = 0x7FFFFFFC
|
||||
|
||||
/* 7.1. Context modes and context ID lookup for literals */
|
||||
/* "context IDs for literals are in the range of 0..63" */
|
||||
const literalContextBits = 6
|
||||
|
||||
/* 7.2. Context ID for distances */
|
||||
const distanceContextBits = 2
|
||||
|
||||
/* 9.1. Format of the Stream Header */
|
||||
/* Number of slack bytes for window size. Don't confuse
|
||||
with BROTLI_NUM_DISTANCE_SHORT_CODES. */
|
||||
const windowGap = 16
|
||||
|
||||
func maxBackwardLimit(W uint) uint {
|
||||
return (uint(1) << W) - windowGap
|
||||
}
|
||||
+2176
File diff suppressed because it is too large
Load Diff
+2581
File diff suppressed because it is too large
Load Diff
+122890
File diff suppressed because it is too large
Load Diff
+32779
File diff suppressed because it is too large
Load Diff
+1220
File diff suppressed because it is too large
Load Diff
+22
@@ -0,0 +1,22 @@
|
||||
package brotli
|
||||
|
||||
/* Dictionary data (words and transforms) for 1 possible context */
|
||||
type encoderDictionary struct {
|
||||
words *dictionary
|
||||
cutoffTransformsCount uint32
|
||||
cutoffTransforms uint64
|
||||
hash_table []uint16
|
||||
buckets []uint16
|
||||
dict_words []dictWord
|
||||
}
|
||||
|
||||
func initEncoderDictionary(dict *encoderDictionary) {
|
||||
dict.words = getDictionary()
|
||||
|
||||
dict.hash_table = kStaticDictionaryHash[:]
|
||||
dict.buckets = kStaticDictionaryBuckets[:]
|
||||
dict.dict_words = kStaticDictionaryWords[:]
|
||||
|
||||
dict.cutoffTransformsCount = kCutoffTransformsCount
|
||||
dict.cutoffTransforms = kCutoffTransforms
|
||||
}
|
||||
+592
@@ -0,0 +1,592 @@
|
||||
package brotli
|
||||
|
||||
import "math"
|
||||
|
||||
/* Copyright 2010 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Entropy encoding (Huffman) utilities. */
|
||||
|
||||
/* A node of a Huffman tree. */
|
||||
type huffmanTree struct {
|
||||
total_count_ uint32
|
||||
index_left_ int16
|
||||
index_right_or_value_ int16
|
||||
}
|
||||
|
||||
func initHuffmanTree(self *huffmanTree, count uint32, left int16, right int16) {
|
||||
self.total_count_ = count
|
||||
self.index_left_ = left
|
||||
self.index_right_or_value_ = right
|
||||
}
|
||||
|
||||
/* Input size optimized Shell sort. */
|
||||
type huffmanTreeComparator func(huffmanTree, huffmanTree) bool
|
||||
|
||||
var sortHuffmanTreeItems_gaps = []uint{132, 57, 23, 10, 4, 1}
|
||||
|
||||
func sortHuffmanTreeItems(items []huffmanTree, n uint, comparator huffmanTreeComparator) {
|
||||
if n < 13 {
|
||||
/* Insertion sort. */
|
||||
var i uint
|
||||
for i = 1; i < n; i++ {
|
||||
var tmp huffmanTree = items[i]
|
||||
var k uint = i
|
||||
var j uint = i - 1
|
||||
for comparator(tmp, items[j]) {
|
||||
items[k] = items[j]
|
||||
k = j
|
||||
if j == 0 {
|
||||
break
|
||||
}
|
||||
j--
|
||||
}
|
||||
|
||||
items[k] = tmp
|
||||
}
|
||||
|
||||
return
|
||||
} else {
|
||||
var g int
|
||||
if n < 57 {
|
||||
g = 2
|
||||
} else {
|
||||
g = 0
|
||||
}
|
||||
for ; g < 6; g++ {
|
||||
var gap uint = sortHuffmanTreeItems_gaps[g]
|
||||
var i uint
|
||||
for i = gap; i < n; i++ {
|
||||
var j uint = i
|
||||
var tmp huffmanTree = items[i]
|
||||
for ; j >= gap && comparator(tmp, items[j-gap]); j -= gap {
|
||||
items[j] = items[j-gap]
|
||||
}
|
||||
|
||||
items[j] = tmp
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Returns 1 if assignment of depths succeeded, otherwise 0. */
|
||||
func setDepth(p0 int, pool []huffmanTree, depth []byte, max_depth int) bool {
|
||||
var stack [16]int
|
||||
var level int = 0
|
||||
var p int = p0
|
||||
assert(max_depth <= 15)
|
||||
stack[0] = -1
|
||||
for {
|
||||
if pool[p].index_left_ >= 0 {
|
||||
level++
|
||||
if level > max_depth {
|
||||
return false
|
||||
}
|
||||
stack[level] = int(pool[p].index_right_or_value_)
|
||||
p = int(pool[p].index_left_)
|
||||
continue
|
||||
} else {
|
||||
depth[pool[p].index_right_or_value_] = byte(level)
|
||||
}
|
||||
|
||||
for level >= 0 && stack[level] == -1 {
|
||||
level--
|
||||
}
|
||||
if level < 0 {
|
||||
return true
|
||||
}
|
||||
p = stack[level]
|
||||
stack[level] = -1
|
||||
}
|
||||
}
|
||||
|
||||
/* Sort the root nodes, least popular first. */
|
||||
func sortHuffmanTree(v0 huffmanTree, v1 huffmanTree) bool {
|
||||
if v0.total_count_ != v1.total_count_ {
|
||||
return v0.total_count_ < v1.total_count_
|
||||
}
|
||||
|
||||
return v0.index_right_or_value_ > v1.index_right_or_value_
|
||||
}
|
||||
|
||||
/* This function will create a Huffman tree.
|
||||
|
||||
The catch here is that the tree cannot be arbitrarily deep.
|
||||
Brotli specifies a maximum depth of 15 bits for "code trees"
|
||||
and 7 bits for "code length code trees."
|
||||
|
||||
count_limit is the value that is to be faked as the minimum value
|
||||
and this minimum value is raised until the tree matches the
|
||||
maximum length requirement.
|
||||
|
||||
This algorithm is not of excellent performance for very long data blocks,
|
||||
especially when population counts are longer than 2**tree_limit, but
|
||||
we are not planning to use this with extremely long blocks.
|
||||
|
||||
See http://en.wikipedia.org/wiki/Huffman_coding */
|
||||
func createHuffmanTree(data []uint32, length uint, tree_limit int, tree []huffmanTree, depth []byte) {
|
||||
var count_limit uint32
|
||||
var sentinel huffmanTree
|
||||
initHuffmanTree(&sentinel, math.MaxUint32, -1, -1)
|
||||
|
||||
/* For block sizes below 64 kB, we never need to do a second iteration
|
||||
of this loop. Probably all of our block sizes will be smaller than
|
||||
that, so this loop is mostly of academic interest. If we actually
|
||||
would need this, we would be better off with the Katajainen algorithm. */
|
||||
for count_limit = 1; ; count_limit *= 2 {
|
||||
var n uint = 0
|
||||
var i uint
|
||||
var j uint
|
||||
var k uint
|
||||
for i = length; i != 0; {
|
||||
i--
|
||||
if data[i] != 0 {
|
||||
var count uint32 = brotli_max_uint32_t(data[i], count_limit)
|
||||
initHuffmanTree(&tree[n], count, -1, int16(i))
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
if n == 1 {
|
||||
depth[tree[0].index_right_or_value_] = 1 /* Only one element. */
|
||||
break
|
||||
}
|
||||
|
||||
sortHuffmanTreeItems(tree, n, huffmanTreeComparator(sortHuffmanTree))
|
||||
|
||||
/* The nodes are:
|
||||
[0, n): the sorted leaf nodes that we start with.
|
||||
[n]: we add a sentinel here.
|
||||
[n + 1, 2n): new parent nodes are added here, starting from
|
||||
(n+1). These are naturally in ascending order.
|
||||
[2n]: we add a sentinel at the end as well.
|
||||
There will be (2n+1) elements at the end. */
|
||||
tree[n] = sentinel
|
||||
|
||||
tree[n+1] = sentinel
|
||||
|
||||
i = 0 /* Points to the next leaf node. */
|
||||
j = n + 1 /* Points to the next non-leaf node. */
|
||||
for k = n - 1; k != 0; k-- {
|
||||
var left uint
|
||||
var right uint
|
||||
if tree[i].total_count_ <= tree[j].total_count_ {
|
||||
left = i
|
||||
i++
|
||||
} else {
|
||||
left = j
|
||||
j++
|
||||
}
|
||||
|
||||
if tree[i].total_count_ <= tree[j].total_count_ {
|
||||
right = i
|
||||
i++
|
||||
} else {
|
||||
right = j
|
||||
j++
|
||||
}
|
||||
{
|
||||
/* The sentinel node becomes the parent node. */
|
||||
var j_end uint = 2*n - k
|
||||
tree[j_end].total_count_ = tree[left].total_count_ + tree[right].total_count_
|
||||
tree[j_end].index_left_ = int16(left)
|
||||
tree[j_end].index_right_or_value_ = int16(right)
|
||||
|
||||
/* Add back the last sentinel node. */
|
||||
tree[j_end+1] = sentinel
|
||||
}
|
||||
}
|
||||
|
||||
if setDepth(int(2*n-1), tree[0:], depth, tree_limit) {
|
||||
/* We need to pack the Huffman tree in tree_limit bits. If this was not
|
||||
successful, add fake entities to the lowest values and retry. */
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func reverse(v []byte, start uint, end uint) {
|
||||
end--
|
||||
for start < end {
|
||||
var tmp byte = v[start]
|
||||
v[start] = v[end]
|
||||
v[end] = tmp
|
||||
start++
|
||||
end--
|
||||
}
|
||||
}
|
||||
|
||||
func writeHuffmanTreeRepetitions(previous_value byte, value byte, repetitions uint, tree_size *uint, tree []byte, extra_bits_data []byte) {
|
||||
assert(repetitions > 0)
|
||||
if previous_value != value {
|
||||
tree[*tree_size] = value
|
||||
extra_bits_data[*tree_size] = 0
|
||||
(*tree_size)++
|
||||
repetitions--
|
||||
}
|
||||
|
||||
if repetitions == 7 {
|
||||
tree[*tree_size] = value
|
||||
extra_bits_data[*tree_size] = 0
|
||||
(*tree_size)++
|
||||
repetitions--
|
||||
}
|
||||
|
||||
if repetitions < 3 {
|
||||
var i uint
|
||||
for i = 0; i < repetitions; i++ {
|
||||
tree[*tree_size] = value
|
||||
extra_bits_data[*tree_size] = 0
|
||||
(*tree_size)++
|
||||
}
|
||||
} else {
|
||||
var start uint = *tree_size
|
||||
repetitions -= 3
|
||||
for {
|
||||
tree[*tree_size] = repeatPreviousCodeLength
|
||||
extra_bits_data[*tree_size] = byte(repetitions & 0x3)
|
||||
(*tree_size)++
|
||||
repetitions >>= 2
|
||||
if repetitions == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
repetitions--
|
||||
}
|
||||
|
||||
reverse(tree, start, *tree_size)
|
||||
reverse(extra_bits_data, start, *tree_size)
|
||||
}
|
||||
}
|
||||
|
||||
func writeHuffmanTreeRepetitionsZeros(repetitions uint, tree_size *uint, tree []byte, extra_bits_data []byte) {
|
||||
if repetitions == 11 {
|
||||
tree[*tree_size] = 0
|
||||
extra_bits_data[*tree_size] = 0
|
||||
(*tree_size)++
|
||||
repetitions--
|
||||
}
|
||||
|
||||
if repetitions < 3 {
|
||||
var i uint
|
||||
for i = 0; i < repetitions; i++ {
|
||||
tree[*tree_size] = 0
|
||||
extra_bits_data[*tree_size] = 0
|
||||
(*tree_size)++
|
||||
}
|
||||
} else {
|
||||
var start uint = *tree_size
|
||||
repetitions -= 3
|
||||
for {
|
||||
tree[*tree_size] = repeatZeroCodeLength
|
||||
extra_bits_data[*tree_size] = byte(repetitions & 0x7)
|
||||
(*tree_size)++
|
||||
repetitions >>= 3
|
||||
if repetitions == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
repetitions--
|
||||
}
|
||||
|
||||
reverse(tree, start, *tree_size)
|
||||
reverse(extra_bits_data, start, *tree_size)
|
||||
}
|
||||
}
|
||||
|
||||
/* Change the population counts in a way that the consequent
|
||||
Huffman tree compression, especially its RLE-part will be more
|
||||
likely to compress this data more efficiently.
|
||||
|
||||
length contains the size of the histogram.
|
||||
counts contains the population counts.
|
||||
good_for_rle is a buffer of at least length size */
|
||||
func optimizeHuffmanCountsForRLE(length uint, counts []uint32, good_for_rle []byte) {
|
||||
var nonzero_count uint = 0
|
||||
var stride uint
|
||||
var limit uint
|
||||
var sum uint
|
||||
var streak_limit uint = 1240
|
||||
var i uint
|
||||
/* Let's make the Huffman code more compatible with RLE encoding. */
|
||||
for i = 0; i < length; i++ {
|
||||
if counts[i] != 0 {
|
||||
nonzero_count++
|
||||
}
|
||||
}
|
||||
|
||||
if nonzero_count < 16 {
|
||||
return
|
||||
}
|
||||
|
||||
for length != 0 && counts[length-1] == 0 {
|
||||
length--
|
||||
}
|
||||
|
||||
if length == 0 {
|
||||
return /* All zeros. */
|
||||
}
|
||||
|
||||
/* Now counts[0..length - 1] does not have trailing zeros. */
|
||||
{
|
||||
var nonzeros uint = 0
|
||||
var smallest_nonzero uint32 = 1 << 30
|
||||
for i = 0; i < length; i++ {
|
||||
if counts[i] != 0 {
|
||||
nonzeros++
|
||||
if smallest_nonzero > counts[i] {
|
||||
smallest_nonzero = counts[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if nonzeros < 5 {
|
||||
/* Small histogram will model it well. */
|
||||
return
|
||||
}
|
||||
|
||||
if smallest_nonzero < 4 {
|
||||
var zeros uint = length - nonzeros
|
||||
if zeros < 6 {
|
||||
for i = 1; i < length-1; i++ {
|
||||
if counts[i-1] != 0 && counts[i] == 0 && counts[i+1] != 0 {
|
||||
counts[i] = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if nonzeros < 28 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
/* 2) Let's mark all population counts that already can be encoded
|
||||
with an RLE code. */
|
||||
for i := 0; i < int(length); i++ {
|
||||
good_for_rle[i] = 0
|
||||
}
|
||||
{
|
||||
var symbol uint32 = counts[0]
|
||||
/* Let's not spoil any of the existing good RLE codes.
|
||||
Mark any seq of 0's that is longer as 5 as a good_for_rle.
|
||||
Mark any seq of non-0's that is longer as 7 as a good_for_rle. */
|
||||
|
||||
var step uint = 0
|
||||
for i = 0; i <= length; i++ {
|
||||
if i == length || counts[i] != symbol {
|
||||
if (symbol == 0 && step >= 5) || (symbol != 0 && step >= 7) {
|
||||
var k uint
|
||||
for k = 0; k < step; k++ {
|
||||
good_for_rle[i-k-1] = 1
|
||||
}
|
||||
}
|
||||
|
||||
step = 1
|
||||
if i != length {
|
||||
symbol = counts[i]
|
||||
}
|
||||
} else {
|
||||
step++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* 3) Let's replace those population counts that lead to more RLE codes.
|
||||
Math here is in 24.8 fixed point representation. */
|
||||
stride = 0
|
||||
|
||||
limit = uint(256*(counts[0]+counts[1]+counts[2])/3 + 420)
|
||||
sum = 0
|
||||
for i = 0; i <= length; i++ {
|
||||
if i == length || good_for_rle[i] != 0 || (i != 0 && good_for_rle[i-1] != 0) || (256*counts[i]-uint32(limit)+uint32(streak_limit)) >= uint32(2*streak_limit) {
|
||||
if stride >= 4 || (stride >= 3 && sum == 0) {
|
||||
var k uint
|
||||
var count uint = (sum + stride/2) / stride
|
||||
/* The stride must end, collapse what we have, if we have enough (4). */
|
||||
if count == 0 {
|
||||
count = 1
|
||||
}
|
||||
|
||||
if sum == 0 {
|
||||
/* Don't make an all zeros stride to be upgraded to ones. */
|
||||
count = 0
|
||||
}
|
||||
|
||||
for k = 0; k < stride; k++ {
|
||||
/* We don't want to change value at counts[i],
|
||||
that is already belonging to the next stride. Thus - 1. */
|
||||
counts[i-k-1] = uint32(count)
|
||||
}
|
||||
}
|
||||
|
||||
stride = 0
|
||||
sum = 0
|
||||
if i < length-2 {
|
||||
/* All interesting strides have a count of at least 4, */
|
||||
/* at least when non-zeros. */
|
||||
limit = uint(256*(counts[i]+counts[i+1]+counts[i+2])/3 + 420)
|
||||
} else if i < length {
|
||||
limit = uint(256 * counts[i])
|
||||
} else {
|
||||
limit = 0
|
||||
}
|
||||
}
|
||||
|
||||
stride++
|
||||
if i != length {
|
||||
sum += uint(counts[i])
|
||||
if stride >= 4 {
|
||||
limit = (256*sum + stride/2) / stride
|
||||
}
|
||||
|
||||
if stride == 4 {
|
||||
limit += 120
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decideOverRLEUse(depth []byte, length uint, use_rle_for_non_zero *bool, use_rle_for_zero *bool) {
|
||||
var total_reps_zero uint = 0
|
||||
var total_reps_non_zero uint = 0
|
||||
var count_reps_zero uint = 1
|
||||
var count_reps_non_zero uint = 1
|
||||
var i uint
|
||||
for i = 0; i < length; {
|
||||
var value byte = depth[i]
|
||||
var reps uint = 1
|
||||
var k uint
|
||||
for k = i + 1; k < length && depth[k] == value; k++ {
|
||||
reps++
|
||||
}
|
||||
|
||||
if reps >= 3 && value == 0 {
|
||||
total_reps_zero += reps
|
||||
count_reps_zero++
|
||||
}
|
||||
|
||||
if reps >= 4 && value != 0 {
|
||||
total_reps_non_zero += reps
|
||||
count_reps_non_zero++
|
||||
}
|
||||
|
||||
i += reps
|
||||
}
|
||||
|
||||
*use_rle_for_non_zero = total_reps_non_zero > count_reps_non_zero*2
|
||||
*use_rle_for_zero = total_reps_zero > count_reps_zero*2
|
||||
}
|
||||
|
||||
/* Write a Huffman tree from bit depths into the bit-stream representation
|
||||
of a Huffman tree. The generated Huffman tree is to be compressed once
|
||||
more using a Huffman tree */
|
||||
func writeHuffmanTree(depth []byte, length uint, tree_size *uint, tree []byte, extra_bits_data []byte) {
|
||||
var previous_value byte = initialRepeatedCodeLength
|
||||
var i uint
|
||||
var use_rle_for_non_zero bool = false
|
||||
var use_rle_for_zero bool = false
|
||||
var new_length uint = length
|
||||
/* Throw away trailing zeros. */
|
||||
for i = 0; i < length; i++ {
|
||||
if depth[length-i-1] == 0 {
|
||||
new_length--
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
/* First gather statistics on if it is a good idea to do RLE. */
|
||||
if length > 50 {
|
||||
/* Find RLE coding for longer codes.
|
||||
Shorter codes seem not to benefit from RLE. */
|
||||
decideOverRLEUse(depth, new_length, &use_rle_for_non_zero, &use_rle_for_zero)
|
||||
}
|
||||
|
||||
/* Actual RLE coding. */
|
||||
for i = 0; i < new_length; {
|
||||
var value byte = depth[i]
|
||||
var reps uint = 1
|
||||
if (value != 0 && use_rle_for_non_zero) || (value == 0 && use_rle_for_zero) {
|
||||
var k uint
|
||||
for k = i + 1; k < new_length && depth[k] == value; k++ {
|
||||
reps++
|
||||
}
|
||||
}
|
||||
|
||||
if value == 0 {
|
||||
writeHuffmanTreeRepetitionsZeros(reps, tree_size, tree, extra_bits_data)
|
||||
} else {
|
||||
writeHuffmanTreeRepetitions(previous_value, value, reps, tree_size, tree, extra_bits_data)
|
||||
previous_value = value
|
||||
}
|
||||
|
||||
i += reps
|
||||
}
|
||||
}
|
||||
|
||||
var reverseBits_kLut = [16]uint{
|
||||
0x00,
|
||||
0x08,
|
||||
0x04,
|
||||
0x0C,
|
||||
0x02,
|
||||
0x0A,
|
||||
0x06,
|
||||
0x0E,
|
||||
0x01,
|
||||
0x09,
|
||||
0x05,
|
||||
0x0D,
|
||||
0x03,
|
||||
0x0B,
|
||||
0x07,
|
||||
0x0F,
|
||||
}
|
||||
|
||||
func reverseBits(num_bits uint, bits uint16) uint16 {
|
||||
var retval uint = reverseBits_kLut[bits&0x0F]
|
||||
var i uint
|
||||
for i = 4; i < num_bits; i += 4 {
|
||||
retval <<= 4
|
||||
bits = uint16(bits >> 4)
|
||||
retval |= reverseBits_kLut[bits&0x0F]
|
||||
}
|
||||
|
||||
retval >>= ((0 - num_bits) & 0x03)
|
||||
return uint16(retval)
|
||||
}
|
||||
|
||||
/* 0..15 are values for bits */
|
||||
const maxHuffmanBits = 16
|
||||
|
||||
/* Get the actual bit values for a tree of bit depths. */
|
||||
func convertBitDepthsToSymbols(depth []byte, len uint, bits []uint16) {
|
||||
var bl_count = [maxHuffmanBits]uint16{0}
|
||||
var next_code [maxHuffmanBits]uint16
|
||||
var i uint
|
||||
/* In Brotli, all bit depths are [1..15]
|
||||
0 bit depth means that the symbol does not exist. */
|
||||
|
||||
var code int = 0
|
||||
for i = 0; i < len; i++ {
|
||||
bl_count[depth[i]]++
|
||||
}
|
||||
|
||||
bl_count[0] = 0
|
||||
next_code[0] = 0
|
||||
for i = 1; i < maxHuffmanBits; i++ {
|
||||
code = (code + int(bl_count[i-1])) << 1
|
||||
next_code[i] = uint16(code)
|
||||
}
|
||||
|
||||
for i = 0; i < len; i++ {
|
||||
if depth[i] != 0 {
|
||||
bits[i] = reverseBits(uint(depth[i]), next_code[depth[i]])
|
||||
next_code[depth[i]]++
|
||||
}
|
||||
}
|
||||
}
|
||||
+4394
File diff suppressed because it is too large
Load Diff
+290
@@ -0,0 +1,290 @@
|
||||
package brotli
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/bits"
|
||||
)
|
||||
|
||||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Utilities for fast computation of logarithms. */
|
||||
|
||||
func log2FloorNonZero(n uint) uint32 {
|
||||
return uint32(bits.Len(n)) - 1
|
||||
}
|
||||
|
||||
/* A lookup table for small values of log2(int) to be used in entropy
|
||||
computation.
|
||||
|
||||
", ".join(["%.16ff" % x for x in [0.0]+[log2(x) for x in range(1, 256)]]) */
|
||||
var kLog2Table = []float32{
|
||||
0.0000000000000000,
|
||||
0.0000000000000000,
|
||||
1.0000000000000000,
|
||||
1.5849625007211563,
|
||||
2.0000000000000000,
|
||||
2.3219280948873622,
|
||||
2.5849625007211561,
|
||||
2.8073549220576042,
|
||||
3.0000000000000000,
|
||||
3.1699250014423126,
|
||||
3.3219280948873626,
|
||||
3.4594316186372978,
|
||||
3.5849625007211565,
|
||||
3.7004397181410922,
|
||||
3.8073549220576037,
|
||||
3.9068905956085187,
|
||||
4.0000000000000000,
|
||||
4.0874628412503400,
|
||||
4.1699250014423122,
|
||||
4.2479275134435852,
|
||||
4.3219280948873626,
|
||||
4.3923174227787607,
|
||||
4.4594316186372973,
|
||||
4.5235619560570131,
|
||||
4.5849625007211570,
|
||||
4.6438561897747244,
|
||||
4.7004397181410926,
|
||||
4.7548875021634691,
|
||||
4.8073549220576037,
|
||||
4.8579809951275728,
|
||||
4.9068905956085187,
|
||||
4.9541963103868758,
|
||||
5.0000000000000000,
|
||||
5.0443941193584534,
|
||||
5.0874628412503400,
|
||||
5.1292830169449664,
|
||||
5.1699250014423122,
|
||||
5.2094533656289501,
|
||||
5.2479275134435852,
|
||||
5.2854022188622487,
|
||||
5.3219280948873626,
|
||||
5.3575520046180838,
|
||||
5.3923174227787607,
|
||||
5.4262647547020979,
|
||||
5.4594316186372973,
|
||||
5.4918530963296748,
|
||||
5.5235619560570131,
|
||||
5.5545888516776376,
|
||||
5.5849625007211570,
|
||||
5.6147098441152083,
|
||||
5.6438561897747244,
|
||||
5.6724253419714961,
|
||||
5.7004397181410926,
|
||||
5.7279204545631996,
|
||||
5.7548875021634691,
|
||||
5.7813597135246599,
|
||||
5.8073549220576046,
|
||||
5.8328900141647422,
|
||||
5.8579809951275719,
|
||||
5.8826430493618416,
|
||||
5.9068905956085187,
|
||||
5.9307373375628867,
|
||||
5.9541963103868758,
|
||||
5.9772799234999168,
|
||||
6.0000000000000000,
|
||||
6.0223678130284544,
|
||||
6.0443941193584534,
|
||||
6.0660891904577721,
|
||||
6.0874628412503400,
|
||||
6.1085244567781700,
|
||||
6.1292830169449672,
|
||||
6.1497471195046822,
|
||||
6.1699250014423122,
|
||||
6.1898245588800176,
|
||||
6.2094533656289510,
|
||||
6.2288186904958804,
|
||||
6.2479275134435861,
|
||||
6.2667865406949019,
|
||||
6.2854022188622487,
|
||||
6.3037807481771031,
|
||||
6.3219280948873617,
|
||||
6.3398500028846252,
|
||||
6.3575520046180847,
|
||||
6.3750394313469254,
|
||||
6.3923174227787598,
|
||||
6.4093909361377026,
|
||||
6.4262647547020979,
|
||||
6.4429434958487288,
|
||||
6.4594316186372982,
|
||||
6.4757334309663976,
|
||||
6.4918530963296748,
|
||||
6.5077946401986964,
|
||||
6.5235619560570131,
|
||||
6.5391588111080319,
|
||||
6.5545888516776376,
|
||||
6.5698556083309478,
|
||||
6.5849625007211561,
|
||||
6.5999128421871278,
|
||||
6.6147098441152092,
|
||||
6.6293566200796095,
|
||||
6.6438561897747253,
|
||||
6.6582114827517955,
|
||||
6.6724253419714952,
|
||||
6.6865005271832185,
|
||||
6.7004397181410917,
|
||||
6.7142455176661224,
|
||||
6.7279204545631988,
|
||||
6.7414669864011465,
|
||||
6.7548875021634691,
|
||||
6.7681843247769260,
|
||||
6.7813597135246599,
|
||||
6.7944158663501062,
|
||||
6.8073549220576037,
|
||||
6.8201789624151887,
|
||||
6.8328900141647422,
|
||||
6.8454900509443757,
|
||||
6.8579809951275719,
|
||||
6.8703647195834048,
|
||||
6.8826430493618416,
|
||||
6.8948177633079437,
|
||||
6.9068905956085187,
|
||||
6.9188632372745955,
|
||||
6.9307373375628867,
|
||||
6.9425145053392399,
|
||||
6.9541963103868758,
|
||||
6.9657842846620879,
|
||||
6.9772799234999168,
|
||||
6.9886846867721664,
|
||||
7.0000000000000000,
|
||||
7.0112272554232540,
|
||||
7.0223678130284544,
|
||||
7.0334230015374501,
|
||||
7.0443941193584534,
|
||||
7.0552824355011898,
|
||||
7.0660891904577721,
|
||||
7.0768155970508317,
|
||||
7.0874628412503400,
|
||||
7.0980320829605272,
|
||||
7.1085244567781700,
|
||||
7.1189410727235076,
|
||||
7.1292830169449664,
|
||||
7.1395513523987937,
|
||||
7.1497471195046822,
|
||||
7.1598713367783891,
|
||||
7.1699250014423130,
|
||||
7.1799090900149345,
|
||||
7.1898245588800176,
|
||||
7.1996723448363644,
|
||||
7.2094533656289492,
|
||||
7.2191685204621621,
|
||||
7.2288186904958804,
|
||||
7.2384047393250794,
|
||||
7.2479275134435861,
|
||||
7.2573878426926521,
|
||||
7.2667865406949019,
|
||||
7.2761244052742384,
|
||||
7.2854022188622487,
|
||||
7.2946207488916270,
|
||||
7.3037807481771031,
|
||||
7.3128829552843557,
|
||||
7.3219280948873617,
|
||||
7.3309168781146177,
|
||||
7.3398500028846243,
|
||||
7.3487281542310781,
|
||||
7.3575520046180847,
|
||||
7.3663222142458151,
|
||||
7.3750394313469254,
|
||||
7.3837042924740528,
|
||||
7.3923174227787607,
|
||||
7.4008794362821844,
|
||||
7.4093909361377026,
|
||||
7.4178525148858991,
|
||||
7.4262647547020979,
|
||||
7.4346282276367255,
|
||||
7.4429434958487288,
|
||||
7.4512111118323299,
|
||||
7.4594316186372973,
|
||||
7.4676055500829976,
|
||||
7.4757334309663976,
|
||||
7.4838157772642564,
|
||||
7.4918530963296748,
|
||||
7.4998458870832057,
|
||||
7.5077946401986964,
|
||||
7.5156998382840436,
|
||||
7.5235619560570131,
|
||||
7.5313814605163119,
|
||||
7.5391588111080319,
|
||||
7.5468944598876373,
|
||||
7.5545888516776376,
|
||||
7.5622424242210728,
|
||||
7.5698556083309478,
|
||||
7.5774288280357487,
|
||||
7.5849625007211561,
|
||||
7.5924570372680806,
|
||||
7.5999128421871278,
|
||||
7.6073303137496113,
|
||||
7.6147098441152075,
|
||||
7.6220518194563764,
|
||||
7.6293566200796095,
|
||||
7.6366246205436488,
|
||||
7.6438561897747244,
|
||||
7.6510516911789290,
|
||||
7.6582114827517955,
|
||||
7.6653359171851765,
|
||||
7.6724253419714952,
|
||||
7.6794800995054464,
|
||||
7.6865005271832185,
|
||||
7.6934869574993252,
|
||||
7.7004397181410926,
|
||||
7.7073591320808825,
|
||||
7.7142455176661224,
|
||||
7.7210991887071856,
|
||||
7.7279204545631996,
|
||||
7.7347096202258392,
|
||||
7.7414669864011465,
|
||||
7.7481928495894596,
|
||||
7.7548875021634691,
|
||||
7.7615512324444795,
|
||||
7.7681843247769260,
|
||||
7.7747870596011737,
|
||||
7.7813597135246608,
|
||||
7.7879025593914317,
|
||||
7.7944158663501062,
|
||||
7.8008998999203047,
|
||||
7.8073549220576037,
|
||||
7.8137811912170374,
|
||||
7.8201789624151887,
|
||||
7.8265484872909159,
|
||||
7.8328900141647422,
|
||||
7.8392037880969445,
|
||||
7.8454900509443757,
|
||||
7.8517490414160571,
|
||||
7.8579809951275719,
|
||||
7.8641861446542798,
|
||||
7.8703647195834048,
|
||||
7.8765169465650002,
|
||||
7.8826430493618425,
|
||||
7.8887432488982601,
|
||||
7.8948177633079446,
|
||||
7.9008668079807496,
|
||||
7.9068905956085187,
|
||||
7.9128893362299619,
|
||||
7.9188632372745955,
|
||||
7.9248125036057813,
|
||||
7.9307373375628867,
|
||||
7.9366379390025719,
|
||||
7.9425145053392399,
|
||||
7.9483672315846778,
|
||||
7.9541963103868758,
|
||||
7.9600019320680806,
|
||||
7.9657842846620870,
|
||||
7.9715435539507720,
|
||||
7.9772799234999168,
|
||||
7.9829935746943104,
|
||||
7.9886846867721664,
|
||||
7.9943534368588578,
|
||||
}
|
||||
|
||||
/* Faster logarithm for small integers, with the property of log2(0) == 0. */
|
||||
func fastLog2(v uint) float64 {
|
||||
if v < uint(len(kLog2Table)) {
|
||||
return float64(kLog2Table[v])
|
||||
}
|
||||
|
||||
return math.Log2(float64(v))
|
||||
}
|
||||
+45
@@ -0,0 +1,45 @@
|
||||
package brotli
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math/bits"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
/* Copyright 2010 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Function to find maximal matching prefixes of strings. */
|
||||
func findMatchLengthWithLimit(s1 []byte, s2 []byte, limit uint) uint {
|
||||
var matched uint = 0
|
||||
_, _ = s1[limit-1], s2[limit-1] // bounds check
|
||||
switch runtime.GOARCH {
|
||||
case "amd64":
|
||||
// Compare 8 bytes at at time.
|
||||
for matched+8 <= limit {
|
||||
w1 := binary.LittleEndian.Uint64(s1[matched:])
|
||||
w2 := binary.LittleEndian.Uint64(s2[matched:])
|
||||
if w1 != w2 {
|
||||
return matched + uint(bits.TrailingZeros64(w1^w2)>>3)
|
||||
}
|
||||
matched += 8
|
||||
}
|
||||
case "386":
|
||||
// Compare 4 bytes at at time.
|
||||
for matched+4 <= limit {
|
||||
w1 := binary.LittleEndian.Uint32(s1[matched:])
|
||||
w2 := binary.LittleEndian.Uint32(s2[matched:])
|
||||
if w1 != w2 {
|
||||
return matched + uint(bits.TrailingZeros32(w1^w2)>>3)
|
||||
}
|
||||
matched += 4
|
||||
}
|
||||
}
|
||||
for matched < limit && s1[matched] == s2[matched] {
|
||||
matched++
|
||||
}
|
||||
return matched
|
||||
}
|
||||
+287
@@ -0,0 +1,287 @@
|
||||
package brotli
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
func (*h10) HashTypeLength() uint {
|
||||
return 4
|
||||
}
|
||||
|
||||
func (*h10) StoreLookahead() uint {
|
||||
return 128
|
||||
}
|
||||
|
||||
func hashBytesH10(data []byte) uint32 {
|
||||
var h uint32 = binary.LittleEndian.Uint32(data) * kHashMul32
|
||||
|
||||
/* The higher bits contain more mixture from the multiplication,
|
||||
so we take our results from there. */
|
||||
return h >> (32 - 17)
|
||||
}
|
||||
|
||||
/* A (forgetful) hash table where each hash bucket contains a binary tree of
|
||||
sequences whose first 4 bytes share the same hash code.
|
||||
Each sequence is 128 long and is identified by its starting
|
||||
position in the input data. The binary tree is sorted by the lexicographic
|
||||
order of the sequences, and it is also a max-heap with respect to the
|
||||
starting positions. */
|
||||
type h10 struct {
|
||||
hasherCommon
|
||||
window_mask_ uint
|
||||
buckets_ [1 << 17]uint32
|
||||
invalid_pos_ uint32
|
||||
forest []uint32
|
||||
}
|
||||
|
||||
func (h *h10) Initialize(params *encoderParams) {
|
||||
h.window_mask_ = (1 << params.lgwin) - 1
|
||||
h.invalid_pos_ = uint32(0 - h.window_mask_)
|
||||
var num_nodes uint = uint(1) << params.lgwin
|
||||
h.forest = make([]uint32, 2*num_nodes)
|
||||
}
|
||||
|
||||
func (h *h10) Prepare(one_shot bool, input_size uint, data []byte) {
|
||||
var invalid_pos uint32 = h.invalid_pos_
|
||||
var i uint32
|
||||
for i = 0; i < 1<<17; i++ {
|
||||
h.buckets_[i] = invalid_pos
|
||||
}
|
||||
}
|
||||
|
||||
func leftChildIndexH10(self *h10, pos uint) uint {
|
||||
return 2 * (pos & self.window_mask_)
|
||||
}
|
||||
|
||||
func rightChildIndexH10(self *h10, pos uint) uint {
|
||||
return 2*(pos&self.window_mask_) + 1
|
||||
}
|
||||
|
||||
/* Stores the hash of the next 4 bytes and in a single tree-traversal, the
|
||||
hash bucket's binary tree is searched for matches and is re-rooted at the
|
||||
current position.
|
||||
|
||||
If less than 128 data is available, the hash bucket of the
|
||||
current position is searched for matches, but the state of the hash table
|
||||
is not changed, since we can not know the final sorting order of the
|
||||
current (incomplete) sequence.
|
||||
|
||||
This function must be called with increasing cur_ix positions. */
|
||||
func storeAndFindMatchesH10(self *h10, data []byte, cur_ix uint, ring_buffer_mask uint, max_length uint, max_backward uint, best_len *uint, matches []backwardMatch) []backwardMatch {
|
||||
var cur_ix_masked uint = cur_ix & ring_buffer_mask
|
||||
var max_comp_len uint = brotli_min_size_t(max_length, 128)
|
||||
var should_reroot_tree bool = (max_length >= 128)
|
||||
var key uint32 = hashBytesH10(data[cur_ix_masked:])
|
||||
var forest []uint32 = self.forest
|
||||
var prev_ix uint = uint(self.buckets_[key])
|
||||
var node_left uint = leftChildIndexH10(self, cur_ix)
|
||||
var node_right uint = rightChildIndexH10(self, cur_ix)
|
||||
var best_len_left uint = 0
|
||||
var best_len_right uint = 0
|
||||
var depth_remaining uint
|
||||
/* The forest index of the rightmost node of the left subtree of the new
|
||||
root, updated as we traverse and re-root the tree of the hash bucket. */
|
||||
|
||||
/* The forest index of the leftmost node of the right subtree of the new
|
||||
root, updated as we traverse and re-root the tree of the hash bucket. */
|
||||
|
||||
/* The match length of the rightmost node of the left subtree of the new
|
||||
root, updated as we traverse and re-root the tree of the hash bucket. */
|
||||
|
||||
/* The match length of the leftmost node of the right subtree of the new
|
||||
root, updated as we traverse and re-root the tree of the hash bucket. */
|
||||
if should_reroot_tree {
|
||||
self.buckets_[key] = uint32(cur_ix)
|
||||
}
|
||||
|
||||
for depth_remaining = 64; ; depth_remaining-- {
|
||||
var backward uint = cur_ix - prev_ix
|
||||
var prev_ix_masked uint = prev_ix & ring_buffer_mask
|
||||
if backward == 0 || backward > max_backward || depth_remaining == 0 {
|
||||
if should_reroot_tree {
|
||||
forest[node_left] = self.invalid_pos_
|
||||
forest[node_right] = self.invalid_pos_
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
{
|
||||
var cur_len uint = brotli_min_size_t(best_len_left, best_len_right)
|
||||
var len uint
|
||||
assert(cur_len <= 128)
|
||||
len = cur_len + findMatchLengthWithLimit(data[cur_ix_masked+cur_len:], data[prev_ix_masked+cur_len:], max_length-cur_len)
|
||||
if matches != nil && len > *best_len {
|
||||
*best_len = uint(len)
|
||||
initBackwardMatch(&matches[0], backward, uint(len))
|
||||
matches = matches[1:]
|
||||
}
|
||||
|
||||
if len >= max_comp_len {
|
||||
if should_reroot_tree {
|
||||
forest[node_left] = forest[leftChildIndexH10(self, prev_ix)]
|
||||
forest[node_right] = forest[rightChildIndexH10(self, prev_ix)]
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if data[cur_ix_masked+len] > data[prev_ix_masked+len] {
|
||||
best_len_left = uint(len)
|
||||
if should_reroot_tree {
|
||||
forest[node_left] = uint32(prev_ix)
|
||||
}
|
||||
|
||||
node_left = rightChildIndexH10(self, prev_ix)
|
||||
prev_ix = uint(forest[node_left])
|
||||
} else {
|
||||
best_len_right = uint(len)
|
||||
if should_reroot_tree {
|
||||
forest[node_right] = uint32(prev_ix)
|
||||
}
|
||||
|
||||
node_right = leftChildIndexH10(self, prev_ix)
|
||||
prev_ix = uint(forest[node_right])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return matches
|
||||
}
|
||||
|
||||
/* Finds all backward matches of &data[cur_ix & ring_buffer_mask] up to the
|
||||
length of max_length and stores the position cur_ix in the hash table.
|
||||
|
||||
Sets *num_matches to the number of matches found, and stores the found
|
||||
matches in matches[0] to matches[*num_matches - 1]. The matches will be
|
||||
sorted by strictly increasing length and (non-strictly) increasing
|
||||
distance. */
|
||||
func findAllMatchesH10(handle *h10, dictionary *encoderDictionary, data []byte, ring_buffer_mask uint, cur_ix uint, max_length uint, max_backward uint, gap uint, params *encoderParams, matches []backwardMatch) uint {
|
||||
var orig_matches []backwardMatch = matches
|
||||
var cur_ix_masked uint = cur_ix & ring_buffer_mask
|
||||
var best_len uint = 1
|
||||
var short_match_max_backward uint
|
||||
if params.quality != hqZopflificationQuality {
|
||||
short_match_max_backward = 16
|
||||
} else {
|
||||
short_match_max_backward = 64
|
||||
}
|
||||
var stop uint = cur_ix - short_match_max_backward
|
||||
var dict_matches [maxStaticDictionaryMatchLen + 1]uint32
|
||||
var i uint
|
||||
if cur_ix < short_match_max_backward {
|
||||
stop = 0
|
||||
}
|
||||
for i = cur_ix - 1; i > stop && best_len <= 2; i-- {
|
||||
var prev_ix uint = i
|
||||
var backward uint = cur_ix - prev_ix
|
||||
if backward > max_backward {
|
||||
break
|
||||
}
|
||||
|
||||
prev_ix &= ring_buffer_mask
|
||||
if data[cur_ix_masked] != data[prev_ix] || data[cur_ix_masked+1] != data[prev_ix+1] {
|
||||
continue
|
||||
}
|
||||
{
|
||||
var len uint = findMatchLengthWithLimit(data[prev_ix:], data[cur_ix_masked:], max_length)
|
||||
if len > best_len {
|
||||
best_len = uint(len)
|
||||
initBackwardMatch(&matches[0], backward, uint(len))
|
||||
matches = matches[1:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if best_len < max_length {
|
||||
matches = storeAndFindMatchesH10(handle, data, cur_ix, ring_buffer_mask, max_length, max_backward, &best_len, matches)
|
||||
}
|
||||
|
||||
for i = 0; i <= maxStaticDictionaryMatchLen; i++ {
|
||||
dict_matches[i] = kInvalidMatch
|
||||
}
|
||||
{
|
||||
var minlen uint = brotli_max_size_t(4, best_len+1)
|
||||
if findAllStaticDictionaryMatches(dictionary, data[cur_ix_masked:], minlen, max_length, dict_matches[0:]) {
|
||||
var maxlen uint = brotli_min_size_t(maxStaticDictionaryMatchLen, max_length)
|
||||
var l uint
|
||||
for l = minlen; l <= maxlen; l++ {
|
||||
var dict_id uint32 = dict_matches[l]
|
||||
if dict_id < kInvalidMatch {
|
||||
var distance uint = max_backward + gap + uint(dict_id>>5) + 1
|
||||
if distance <= params.dist.max_distance {
|
||||
initDictionaryBackwardMatch(&matches[0], distance, l, uint(dict_id&31))
|
||||
matches = matches[1:]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return uint(-cap(matches) + cap(orig_matches))
|
||||
}
|
||||
|
||||
/* Stores the hash of the next 4 bytes and re-roots the binary tree at the
|
||||
current sequence, without returning any matches.
|
||||
REQUIRES: ix + 128 <= end-of-current-block */
|
||||
func (h *h10) Store(data []byte, mask uint, ix uint) {
|
||||
var max_backward uint = h.window_mask_ - windowGap + 1
|
||||
/* Maximum distance is window size - 16, see section 9.1. of the spec. */
|
||||
storeAndFindMatchesH10(h, data, ix, mask, 128, max_backward, nil, nil)
|
||||
}
|
||||
|
||||
func (h *h10) StoreRange(data []byte, mask uint, ix_start uint, ix_end uint) {
|
||||
var i uint = ix_start
|
||||
var j uint = ix_start
|
||||
if ix_start+63 <= ix_end {
|
||||
i = ix_end - 63
|
||||
}
|
||||
|
||||
if ix_start+512 <= i {
|
||||
for ; j < i; j += 8 {
|
||||
h.Store(data, mask, j)
|
||||
}
|
||||
}
|
||||
|
||||
for ; i < ix_end; i++ {
|
||||
h.Store(data, mask, i)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *h10) StitchToPreviousBlock(num_bytes uint, position uint, ringbuffer []byte, ringbuffer_mask uint) {
|
||||
if num_bytes >= h.HashTypeLength()-1 && position >= 128 {
|
||||
var i_start uint = position - 128 + 1
|
||||
var i_end uint = brotli_min_size_t(position, i_start+num_bytes)
|
||||
/* Store the last `128 - 1` positions in the hasher.
|
||||
These could not be calculated before, since they require knowledge
|
||||
of both the previous and the current block. */
|
||||
|
||||
var i uint
|
||||
for i = i_start; i < i_end; i++ {
|
||||
/* Maximum distance is window size - 16, see section 9.1. of the spec.
|
||||
Furthermore, we have to make sure that we don't look further back
|
||||
from the start of the next block than the window size, otherwise we
|
||||
could access already overwritten areas of the ring-buffer. */
|
||||
var max_backward uint = h.window_mask_ - brotli_max_size_t(windowGap-1, position-i)
|
||||
|
||||
/* We know that i + 128 <= position + num_bytes, i.e. the
|
||||
end of the current block and that we have at least
|
||||
128 tail in the ring-buffer. */
|
||||
storeAndFindMatchesH10(h, ringbuffer, i, ringbuffer_mask, 128, max_backward, nil, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* MAX_NUM_MATCHES == 64 + MAX_TREE_SEARCH_DEPTH */
|
||||
const maxNumMatchesH10 = 128
|
||||
|
||||
func (*h10) FindLongestMatch(dictionary *encoderDictionary, data []byte, ring_buffer_mask uint, distance_cache []int, cur_ix uint, max_length uint, max_backward uint, gap uint, max_distance uint, out *hasherSearchResult) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (*h10) PrepareDistanceCache(distance_cache []int) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
+214
@@ -0,0 +1,214 @@
|
||||
package brotli
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
/* Copyright 2010 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* A (forgetful) hash table to the data seen by the compressor, to
|
||||
help create backward references to previous data.
|
||||
|
||||
This is a hash map of fixed size (bucket_size_) to a ring buffer of
|
||||
fixed size (block_size_). The ring buffer contains the last block_size_
|
||||
index positions of the given hash key in the compressed data. */
|
||||
func (*h5) HashTypeLength() uint {
|
||||
return 4
|
||||
}
|
||||
|
||||
func (*h5) StoreLookahead() uint {
|
||||
return 4
|
||||
}
|
||||
|
||||
/* HashBytes is the function that chooses the bucket to place the address in. */
|
||||
func hashBytesH5(data []byte, shift int) uint32 {
|
||||
var h uint32 = binary.LittleEndian.Uint32(data) * kHashMul32
|
||||
|
||||
/* The higher bits contain more mixture from the multiplication,
|
||||
so we take our results from there. */
|
||||
return uint32(h >> uint(shift))
|
||||
}
|
||||
|
||||
type h5 struct {
|
||||
hasherCommon
|
||||
bucket_size_ uint
|
||||
block_size_ uint
|
||||
hash_shift_ int
|
||||
block_mask_ uint32
|
||||
num []uint16
|
||||
buckets []uint32
|
||||
}
|
||||
|
||||
func (h *h5) Initialize(params *encoderParams) {
|
||||
h.hash_shift_ = 32 - h.params.bucket_bits
|
||||
h.bucket_size_ = uint(1) << uint(h.params.bucket_bits)
|
||||
h.block_size_ = uint(1) << uint(h.params.block_bits)
|
||||
h.block_mask_ = uint32(h.block_size_ - 1)
|
||||
h.num = make([]uint16, h.bucket_size_)
|
||||
h.buckets = make([]uint32, h.block_size_*h.bucket_size_)
|
||||
}
|
||||
|
||||
func (h *h5) Prepare(one_shot bool, input_size uint, data []byte) {
|
||||
var num []uint16 = h.num
|
||||
var partial_prepare_threshold uint = h.bucket_size_ >> 6
|
||||
/* Partial preparation is 100 times slower (per socket). */
|
||||
if one_shot && input_size <= partial_prepare_threshold {
|
||||
var i uint
|
||||
for i = 0; i < input_size; i++ {
|
||||
var key uint32 = hashBytesH5(data[i:], h.hash_shift_)
|
||||
num[key] = 0
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < int(h.bucket_size_); i++ {
|
||||
num[i] = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Look at 4 bytes at &data[ix & mask].
|
||||
Compute a hash from these, and store the value of ix at that position. */
|
||||
func (h *h5) Store(data []byte, mask uint, ix uint) {
|
||||
var num []uint16 = h.num
|
||||
var key uint32 = hashBytesH5(data[ix&mask:], h.hash_shift_)
|
||||
var minor_ix uint = uint(num[key]) & uint(h.block_mask_)
|
||||
var offset uint = minor_ix + uint(key<<uint(h.params.block_bits))
|
||||
h.buckets[offset] = uint32(ix)
|
||||
num[key]++
|
||||
}
|
||||
|
||||
func (h *h5) StoreRange(data []byte, mask uint, ix_start uint, ix_end uint) {
|
||||
var i uint
|
||||
for i = ix_start; i < ix_end; i++ {
|
||||
h.Store(data, mask, i)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *h5) StitchToPreviousBlock(num_bytes uint, position uint, ringbuffer []byte, ringbuffer_mask uint) {
|
||||
if num_bytes >= h.HashTypeLength()-1 && position >= 3 {
|
||||
/* Prepare the hashes for three last bytes of the last write.
|
||||
These could not be calculated before, since they require knowledge
|
||||
of both the previous and the current block. */
|
||||
h.Store(ringbuffer, ringbuffer_mask, position-3)
|
||||
h.Store(ringbuffer, ringbuffer_mask, position-2)
|
||||
h.Store(ringbuffer, ringbuffer_mask, position-1)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *h5) PrepareDistanceCache(distance_cache []int) {
|
||||
prepareDistanceCache(distance_cache, h.params.num_last_distances_to_check)
|
||||
}
|
||||
|
||||
/* Find a longest backward match of &data[cur_ix] up to the length of
|
||||
max_length and stores the position cur_ix in the hash table.
|
||||
|
||||
REQUIRES: PrepareDistanceCacheH5 must be invoked for current distance cache
|
||||
values; if this method is invoked repeatedly with the same distance
|
||||
cache values, it is enough to invoke PrepareDistanceCacheH5 once.
|
||||
|
||||
Does not look for matches longer than max_length.
|
||||
Does not look for matches further away than max_backward.
|
||||
Writes the best match into |out|.
|
||||
|out|->score is updated only if a better match is found. */
|
||||
func (h *h5) FindLongestMatch(dictionary *encoderDictionary, data []byte, ring_buffer_mask uint, distance_cache []int, cur_ix uint, max_length uint, max_backward uint, gap uint, max_distance uint, out *hasherSearchResult) {
|
||||
var num []uint16 = h.num
|
||||
var buckets []uint32 = h.buckets
|
||||
var cur_ix_masked uint = cur_ix & ring_buffer_mask
|
||||
var min_score uint = out.score
|
||||
var best_score uint = out.score
|
||||
var best_len uint = out.len
|
||||
var i uint
|
||||
var bucket []uint32
|
||||
/* Don't accept a short copy from far away. */
|
||||
out.len = 0
|
||||
|
||||
out.len_code_delta = 0
|
||||
|
||||
/* Try last distance first. */
|
||||
for i = 0; i < uint(h.params.num_last_distances_to_check); i++ {
|
||||
var backward uint = uint(distance_cache[i])
|
||||
var prev_ix uint = uint(cur_ix - backward)
|
||||
if prev_ix >= cur_ix {
|
||||
continue
|
||||
}
|
||||
|
||||
if backward > max_backward {
|
||||
continue
|
||||
}
|
||||
|
||||
prev_ix &= ring_buffer_mask
|
||||
|
||||
if cur_ix_masked+best_len > ring_buffer_mask || prev_ix+best_len > ring_buffer_mask || data[cur_ix_masked+best_len] != data[prev_ix+best_len] {
|
||||
continue
|
||||
}
|
||||
{
|
||||
var len uint = findMatchLengthWithLimit(data[prev_ix:], data[cur_ix_masked:], max_length)
|
||||
if len >= 3 || (len == 2 && i < 2) {
|
||||
/* Comparing for >= 2 does not change the semantics, but just saves for
|
||||
a few unnecessary binary logarithms in backward reference score,
|
||||
since we are not interested in such short matches. */
|
||||
var score uint = backwardReferenceScoreUsingLastDistance(uint(len))
|
||||
if best_score < score {
|
||||
if i != 0 {
|
||||
score -= backwardReferencePenaltyUsingLastDistance(i)
|
||||
}
|
||||
if best_score < score {
|
||||
best_score = score
|
||||
best_len = uint(len)
|
||||
out.len = best_len
|
||||
out.distance = backward
|
||||
out.score = best_score
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
var key uint32 = hashBytesH5(data[cur_ix_masked:], h.hash_shift_)
|
||||
bucket = buckets[key<<uint(h.params.block_bits):]
|
||||
var down uint
|
||||
if uint(num[key]) > h.block_size_ {
|
||||
down = uint(num[key]) - h.block_size_
|
||||
} else {
|
||||
down = 0
|
||||
}
|
||||
for i = uint(num[key]); i > down; {
|
||||
var prev_ix uint
|
||||
i--
|
||||
prev_ix = uint(bucket[uint32(i)&h.block_mask_])
|
||||
var backward uint = cur_ix - prev_ix
|
||||
if backward > max_backward {
|
||||
break
|
||||
}
|
||||
|
||||
prev_ix &= ring_buffer_mask
|
||||
if cur_ix_masked+best_len > ring_buffer_mask || prev_ix+best_len > ring_buffer_mask || data[cur_ix_masked+best_len] != data[prev_ix+best_len] {
|
||||
continue
|
||||
}
|
||||
{
|
||||
var len uint = findMatchLengthWithLimit(data[prev_ix:], data[cur_ix_masked:], max_length)
|
||||
if len >= 4 {
|
||||
/* Comparing for >= 3 does not change the semantics, but just saves
|
||||
for a few unnecessary binary logarithms in backward reference
|
||||
score, since we are not interested in such short matches. */
|
||||
var score uint = backwardReferenceScore(uint(len), backward)
|
||||
if best_score < score {
|
||||
best_score = score
|
||||
best_len = uint(len)
|
||||
out.len = best_len
|
||||
out.distance = backward
|
||||
out.score = best_score
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bucket[uint32(num[key])&h.block_mask_] = uint32(cur_ix)
|
||||
num[key]++
|
||||
}
|
||||
|
||||
if min_score == out.score {
|
||||
searchInStaticDictionary(dictionary, h, data[cur_ix_masked:], max_length, max_backward+gap, max_distance, out, false)
|
||||
}
|
||||
}
|
||||
+216
@@ -0,0 +1,216 @@
|
||||
package brotli
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
/* Copyright 2010 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* A (forgetful) hash table to the data seen by the compressor, to
|
||||
help create backward references to previous data.
|
||||
|
||||
This is a hash map of fixed size (bucket_size_) to a ring buffer of
|
||||
fixed size (block_size_). The ring buffer contains the last block_size_
|
||||
index positions of the given hash key in the compressed data. */
|
||||
func (*h6) HashTypeLength() uint {
|
||||
return 8
|
||||
}
|
||||
|
||||
func (*h6) StoreLookahead() uint {
|
||||
return 8
|
||||
}
|
||||
|
||||
/* HashBytes is the function that chooses the bucket to place the address in. */
|
||||
func hashBytesH6(data []byte, mask uint64, shift int) uint32 {
|
||||
var h uint64 = (binary.LittleEndian.Uint64(data) & mask) * kHashMul64Long
|
||||
|
||||
/* The higher bits contain more mixture from the multiplication,
|
||||
so we take our results from there. */
|
||||
return uint32(h >> uint(shift))
|
||||
}
|
||||
|
||||
type h6 struct {
|
||||
hasherCommon
|
||||
bucket_size_ uint
|
||||
block_size_ uint
|
||||
hash_shift_ int
|
||||
hash_mask_ uint64
|
||||
block_mask_ uint32
|
||||
num []uint16
|
||||
buckets []uint32
|
||||
}
|
||||
|
||||
func (h *h6) Initialize(params *encoderParams) {
|
||||
h.hash_shift_ = 64 - h.params.bucket_bits
|
||||
h.hash_mask_ = (^(uint64(0))) >> uint(64-8*h.params.hash_len)
|
||||
h.bucket_size_ = uint(1) << uint(h.params.bucket_bits)
|
||||
h.block_size_ = uint(1) << uint(h.params.block_bits)
|
||||
h.block_mask_ = uint32(h.block_size_ - 1)
|
||||
h.num = make([]uint16, h.bucket_size_)
|
||||
h.buckets = make([]uint32, h.block_size_*h.bucket_size_)
|
||||
}
|
||||
|
||||
func (h *h6) Prepare(one_shot bool, input_size uint, data []byte) {
|
||||
var num []uint16 = h.num
|
||||
var partial_prepare_threshold uint = h.bucket_size_ >> 6
|
||||
/* Partial preparation is 100 times slower (per socket). */
|
||||
if one_shot && input_size <= partial_prepare_threshold {
|
||||
var i uint
|
||||
for i = 0; i < input_size; i++ {
|
||||
var key uint32 = hashBytesH6(data[i:], h.hash_mask_, h.hash_shift_)
|
||||
num[key] = 0
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < int(h.bucket_size_); i++ {
|
||||
num[i] = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Look at 4 bytes at &data[ix & mask].
|
||||
Compute a hash from these, and store the value of ix at that position. */
|
||||
func (h *h6) Store(data []byte, mask uint, ix uint) {
|
||||
var num []uint16 = h.num
|
||||
var key uint32 = hashBytesH6(data[ix&mask:], h.hash_mask_, h.hash_shift_)
|
||||
var minor_ix uint = uint(num[key]) & uint(h.block_mask_)
|
||||
var offset uint = minor_ix + uint(key<<uint(h.params.block_bits))
|
||||
h.buckets[offset] = uint32(ix)
|
||||
num[key]++
|
||||
}
|
||||
|
||||
func (h *h6) StoreRange(data []byte, mask uint, ix_start uint, ix_end uint) {
|
||||
var i uint
|
||||
for i = ix_start; i < ix_end; i++ {
|
||||
h.Store(data, mask, i)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *h6) StitchToPreviousBlock(num_bytes uint, position uint, ringbuffer []byte, ringbuffer_mask uint) {
|
||||
if num_bytes >= h.HashTypeLength()-1 && position >= 3 {
|
||||
/* Prepare the hashes for three last bytes of the last write.
|
||||
These could not be calculated before, since they require knowledge
|
||||
of both the previous and the current block. */
|
||||
h.Store(ringbuffer, ringbuffer_mask, position-3)
|
||||
h.Store(ringbuffer, ringbuffer_mask, position-2)
|
||||
h.Store(ringbuffer, ringbuffer_mask, position-1)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *h6) PrepareDistanceCache(distance_cache []int) {
|
||||
prepareDistanceCache(distance_cache, h.params.num_last_distances_to_check)
|
||||
}
|
||||
|
||||
/* Find a longest backward match of &data[cur_ix] up to the length of
|
||||
max_length and stores the position cur_ix in the hash table.
|
||||
|
||||
REQUIRES: PrepareDistanceCacheH6 must be invoked for current distance cache
|
||||
values; if this method is invoked repeatedly with the same distance
|
||||
cache values, it is enough to invoke PrepareDistanceCacheH6 once.
|
||||
|
||||
Does not look for matches longer than max_length.
|
||||
Does not look for matches further away than max_backward.
|
||||
Writes the best match into |out|.
|
||||
|out|->score is updated only if a better match is found. */
|
||||
func (h *h6) FindLongestMatch(dictionary *encoderDictionary, data []byte, ring_buffer_mask uint, distance_cache []int, cur_ix uint, max_length uint, max_backward uint, gap uint, max_distance uint, out *hasherSearchResult) {
|
||||
var num []uint16 = h.num
|
||||
var buckets []uint32 = h.buckets
|
||||
var cur_ix_masked uint = cur_ix & ring_buffer_mask
|
||||
var min_score uint = out.score
|
||||
var best_score uint = out.score
|
||||
var best_len uint = out.len
|
||||
var i uint
|
||||
var bucket []uint32
|
||||
/* Don't accept a short copy from far away. */
|
||||
out.len = 0
|
||||
|
||||
out.len_code_delta = 0
|
||||
|
||||
/* Try last distance first. */
|
||||
for i = 0; i < uint(h.params.num_last_distances_to_check); i++ {
|
||||
var backward uint = uint(distance_cache[i])
|
||||
var prev_ix uint = uint(cur_ix - backward)
|
||||
if prev_ix >= cur_ix {
|
||||
continue
|
||||
}
|
||||
|
||||
if backward > max_backward {
|
||||
continue
|
||||
}
|
||||
|
||||
prev_ix &= ring_buffer_mask
|
||||
|
||||
if cur_ix_masked+best_len > ring_buffer_mask || prev_ix+best_len > ring_buffer_mask || data[cur_ix_masked+best_len] != data[prev_ix+best_len] {
|
||||
continue
|
||||
}
|
||||
{
|
||||
var len uint = findMatchLengthWithLimit(data[prev_ix:], data[cur_ix_masked:], max_length)
|
||||
if len >= 3 || (len == 2 && i < 2) {
|
||||
/* Comparing for >= 2 does not change the semantics, but just saves for
|
||||
a few unnecessary binary logarithms in backward reference score,
|
||||
since we are not interested in such short matches. */
|
||||
var score uint = backwardReferenceScoreUsingLastDistance(uint(len))
|
||||
if best_score < score {
|
||||
if i != 0 {
|
||||
score -= backwardReferencePenaltyUsingLastDistance(i)
|
||||
}
|
||||
if best_score < score {
|
||||
best_score = score
|
||||
best_len = uint(len)
|
||||
out.len = best_len
|
||||
out.distance = backward
|
||||
out.score = best_score
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
var key uint32 = hashBytesH6(data[cur_ix_masked:], h.hash_mask_, h.hash_shift_)
|
||||
bucket = buckets[key<<uint(h.params.block_bits):]
|
||||
var down uint
|
||||
if uint(num[key]) > h.block_size_ {
|
||||
down = uint(num[key]) - h.block_size_
|
||||
} else {
|
||||
down = 0
|
||||
}
|
||||
for i = uint(num[key]); i > down; {
|
||||
var prev_ix uint
|
||||
i--
|
||||
prev_ix = uint(bucket[uint32(i)&h.block_mask_])
|
||||
var backward uint = cur_ix - prev_ix
|
||||
if backward > max_backward {
|
||||
break
|
||||
}
|
||||
|
||||
prev_ix &= ring_buffer_mask
|
||||
if cur_ix_masked+best_len > ring_buffer_mask || prev_ix+best_len > ring_buffer_mask || data[cur_ix_masked+best_len] != data[prev_ix+best_len] {
|
||||
continue
|
||||
}
|
||||
{
|
||||
var len uint = findMatchLengthWithLimit(data[prev_ix:], data[cur_ix_masked:], max_length)
|
||||
if len >= 4 {
|
||||
/* Comparing for >= 3 does not change the semantics, but just saves
|
||||
for a few unnecessary binary logarithms in backward reference
|
||||
score, since we are not interested in such short matches. */
|
||||
var score uint = backwardReferenceScore(uint(len), backward)
|
||||
if best_score < score {
|
||||
best_score = score
|
||||
best_len = uint(len)
|
||||
out.len = best_len
|
||||
out.distance = backward
|
||||
out.score = best_score
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bucket[uint32(num[key])&h.block_mask_] = uint32(cur_ix)
|
||||
num[key]++
|
||||
}
|
||||
|
||||
if min_score == out.score {
|
||||
searchInStaticDictionary(dictionary, h, data[cur_ix_masked:], max_length, max_backward+gap, max_distance, out, false)
|
||||
}
|
||||
}
|
||||
+342
@@ -0,0 +1,342 @@
|
||||
package brotli
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type hasherCommon struct {
|
||||
params hasherParams
|
||||
is_prepared_ bool
|
||||
dict_num_lookups uint
|
||||
dict_num_matches uint
|
||||
}
|
||||
|
||||
func (h *hasherCommon) Common() *hasherCommon {
|
||||
return h
|
||||
}
|
||||
|
||||
type hasherHandle interface {
|
||||
Common() *hasherCommon
|
||||
Initialize(params *encoderParams)
|
||||
Prepare(one_shot bool, input_size uint, data []byte)
|
||||
StitchToPreviousBlock(num_bytes uint, position uint, ringbuffer []byte, ringbuffer_mask uint)
|
||||
HashTypeLength() uint
|
||||
StoreLookahead() uint
|
||||
PrepareDistanceCache(distance_cache []int)
|
||||
FindLongestMatch(dictionary *encoderDictionary, data []byte, ring_buffer_mask uint, distance_cache []int, cur_ix uint, max_length uint, max_backward uint, gap uint, max_distance uint, out *hasherSearchResult)
|
||||
StoreRange(data []byte, mask uint, ix_start uint, ix_end uint)
|
||||
Store(data []byte, mask uint, ix uint)
|
||||
}
|
||||
|
||||
const kCutoffTransformsCount uint32 = 10
|
||||
|
||||
/* 0, 12, 27, 23, 42, 63, 56, 48, 59, 64 */
|
||||
/* 0+0, 4+8, 8+19, 12+11, 16+26, 20+43, 24+32, 28+20, 32+27, 36+28 */
|
||||
const kCutoffTransforms uint64 = 0x071B520ADA2D3200
|
||||
|
||||
type hasherSearchResult struct {
|
||||
len uint
|
||||
distance uint
|
||||
score uint
|
||||
len_code_delta int
|
||||
}
|
||||
|
||||
/* kHashMul32 multiplier has these properties:
|
||||
* The multiplier must be odd. Otherwise we may lose the highest bit.
|
||||
* No long streaks of ones or zeros.
|
||||
* There is no effort to ensure that it is a prime, the oddity is enough
|
||||
for this use.
|
||||
* The number has been tuned heuristically against compression benchmarks. */
|
||||
const kHashMul32 uint32 = 0x1E35A7BD
|
||||
|
||||
const kHashMul64 uint64 = 0x1E35A7BD1E35A7BD
|
||||
|
||||
const kHashMul64Long uint64 = 0x1FE35A7BD3579BD3
|
||||
|
||||
func hash14(data []byte) uint32 {
|
||||
var h uint32 = binary.LittleEndian.Uint32(data) * kHashMul32
|
||||
|
||||
/* The higher bits contain more mixture from the multiplication,
|
||||
so we take our results from there. */
|
||||
return h >> (32 - 14)
|
||||
}
|
||||
|
||||
func prepareDistanceCache(distance_cache []int, num_distances int) {
|
||||
if num_distances > 4 {
|
||||
var last_distance int = distance_cache[0]
|
||||
distance_cache[4] = last_distance - 1
|
||||
distance_cache[5] = last_distance + 1
|
||||
distance_cache[6] = last_distance - 2
|
||||
distance_cache[7] = last_distance + 2
|
||||
distance_cache[8] = last_distance - 3
|
||||
distance_cache[9] = last_distance + 3
|
||||
if num_distances > 10 {
|
||||
var next_last_distance int = distance_cache[1]
|
||||
distance_cache[10] = next_last_distance - 1
|
||||
distance_cache[11] = next_last_distance + 1
|
||||
distance_cache[12] = next_last_distance - 2
|
||||
distance_cache[13] = next_last_distance + 2
|
||||
distance_cache[14] = next_last_distance - 3
|
||||
distance_cache[15] = next_last_distance + 3
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const literalByteScore = 135
|
||||
|
||||
const distanceBitPenalty = 30
|
||||
|
||||
/* Score must be positive after applying maximal penalty. */
|
||||
const scoreBase = (distanceBitPenalty * 8 * 8)
|
||||
|
||||
/* Usually, we always choose the longest backward reference. This function
|
||||
allows for the exception of that rule.
|
||||
|
||||
If we choose a backward reference that is further away, it will
|
||||
usually be coded with more bits. We approximate this by assuming
|
||||
log2(distance). If the distance can be expressed in terms of the
|
||||
last four distances, we use some heuristic constants to estimate
|
||||
the bits cost. For the first up to four literals we use the bit
|
||||
cost of the literals from the literal cost model, after that we
|
||||
use the average bit cost of the cost model.
|
||||
|
||||
This function is used to sometimes discard a longer backward reference
|
||||
when it is not much longer and the bit cost for encoding it is more
|
||||
than the saved literals.
|
||||
|
||||
backward_reference_offset MUST be positive. */
|
||||
func backwardReferenceScore(copy_length uint, backward_reference_offset uint) uint {
|
||||
return scoreBase + literalByteScore*uint(copy_length) - distanceBitPenalty*uint(log2FloorNonZero(backward_reference_offset))
|
||||
}
|
||||
|
||||
func backwardReferenceScoreUsingLastDistance(copy_length uint) uint {
|
||||
return literalByteScore*uint(copy_length) + scoreBase + 15
|
||||
}
|
||||
|
||||
func backwardReferencePenaltyUsingLastDistance(distance_short_code uint) uint {
|
||||
return uint(39) + ((0x1CA10 >> (distance_short_code & 0xE)) & 0xE)
|
||||
}
|
||||
|
||||
func testStaticDictionaryItem(dictionary *encoderDictionary, item uint, data []byte, max_length uint, max_backward uint, max_distance uint, out *hasherSearchResult) bool {
|
||||
var len uint
|
||||
var word_idx uint
|
||||
var offset uint
|
||||
var matchlen uint
|
||||
var backward uint
|
||||
var score uint
|
||||
len = item & 0x1F
|
||||
word_idx = item >> 5
|
||||
offset = uint(dictionary.words.offsets_by_length[len]) + len*word_idx
|
||||
if len > max_length {
|
||||
return false
|
||||
}
|
||||
|
||||
matchlen = findMatchLengthWithLimit(data, dictionary.words.data[offset:], uint(len))
|
||||
if matchlen+uint(dictionary.cutoffTransformsCount) <= len || matchlen == 0 {
|
||||
return false
|
||||
}
|
||||
{
|
||||
var cut uint = len - matchlen
|
||||
var transform_id uint = (cut << 2) + uint((dictionary.cutoffTransforms>>(cut*6))&0x3F)
|
||||
backward = max_backward + 1 + word_idx + (transform_id << dictionary.words.size_bits_by_length[len])
|
||||
}
|
||||
|
||||
if backward > max_distance {
|
||||
return false
|
||||
}
|
||||
|
||||
score = backwardReferenceScore(matchlen, backward)
|
||||
if score < out.score {
|
||||
return false
|
||||
}
|
||||
|
||||
out.len = matchlen
|
||||
out.len_code_delta = int(len) - int(matchlen)
|
||||
out.distance = backward
|
||||
out.score = score
|
||||
return true
|
||||
}
|
||||
|
||||
func searchInStaticDictionary(dictionary *encoderDictionary, handle hasherHandle, data []byte, max_length uint, max_backward uint, max_distance uint, out *hasherSearchResult, shallow bool) {
|
||||
var key uint
|
||||
var i uint
|
||||
var self *hasherCommon = handle.Common()
|
||||
if self.dict_num_matches < self.dict_num_lookups>>7 {
|
||||
return
|
||||
}
|
||||
|
||||
key = uint(hash14(data) << 1)
|
||||
for i = 0; ; (func() { i++; key++ })() {
|
||||
var tmp uint
|
||||
if shallow {
|
||||
tmp = 1
|
||||
} else {
|
||||
tmp = 2
|
||||
}
|
||||
if i >= tmp {
|
||||
break
|
||||
}
|
||||
var item uint = uint(dictionary.hash_table[key])
|
||||
self.dict_num_lookups++
|
||||
if item != 0 {
|
||||
var item_matches bool = testStaticDictionaryItem(dictionary, item, data, max_length, max_backward, max_distance, out)
|
||||
if item_matches {
|
||||
self.dict_num_matches++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type backwardMatch struct {
|
||||
distance uint32
|
||||
length_and_code uint32
|
||||
}
|
||||
|
||||
func initBackwardMatch(self *backwardMatch, dist uint, len uint) {
|
||||
self.distance = uint32(dist)
|
||||
self.length_and_code = uint32(len << 5)
|
||||
}
|
||||
|
||||
func initDictionaryBackwardMatch(self *backwardMatch, dist uint, len uint, len_code uint) {
|
||||
self.distance = uint32(dist)
|
||||
var tmp uint
|
||||
if len == len_code {
|
||||
tmp = 0
|
||||
} else {
|
||||
tmp = len_code
|
||||
}
|
||||
self.length_and_code = uint32(len<<5 | tmp)
|
||||
}
|
||||
|
||||
func backwardMatchLength(self *backwardMatch) uint {
|
||||
return uint(self.length_and_code >> 5)
|
||||
}
|
||||
|
||||
func backwardMatchLengthCode(self *backwardMatch) uint {
|
||||
var code uint = uint(self.length_and_code) & 31
|
||||
if code != 0 {
|
||||
return code
|
||||
} else {
|
||||
return backwardMatchLength(self)
|
||||
}
|
||||
}
|
||||
|
||||
func hasherReset(handle hasherHandle) {
|
||||
if handle == nil {
|
||||
return
|
||||
}
|
||||
handle.Common().is_prepared_ = false
|
||||
}
|
||||
|
||||
func newHasher(typ int) hasherHandle {
|
||||
switch typ {
|
||||
case 2:
|
||||
return &hashLongestMatchQuickly{
|
||||
bucketBits: 16,
|
||||
bucketSweep: 1,
|
||||
hashLen: 5,
|
||||
useDictionary: true,
|
||||
}
|
||||
case 3:
|
||||
return &hashLongestMatchQuickly{
|
||||
bucketBits: 16,
|
||||
bucketSweep: 2,
|
||||
hashLen: 5,
|
||||
useDictionary: false,
|
||||
}
|
||||
case 4:
|
||||
return &hashLongestMatchQuickly{
|
||||
bucketBits: 17,
|
||||
bucketSweep: 4,
|
||||
hashLen: 5,
|
||||
useDictionary: true,
|
||||
}
|
||||
case 5:
|
||||
return new(h5)
|
||||
case 6:
|
||||
return new(h6)
|
||||
case 10:
|
||||
return new(h10)
|
||||
case 35:
|
||||
return &hashComposite{
|
||||
ha: newHasher(3),
|
||||
hb: &hashRolling{jump: 4},
|
||||
}
|
||||
case 40:
|
||||
return &hashForgetfulChain{
|
||||
bucketBits: 15,
|
||||
numBanks: 1,
|
||||
bankBits: 16,
|
||||
numLastDistancesToCheck: 4,
|
||||
}
|
||||
case 41:
|
||||
return &hashForgetfulChain{
|
||||
bucketBits: 15,
|
||||
numBanks: 1,
|
||||
bankBits: 16,
|
||||
numLastDistancesToCheck: 10,
|
||||
}
|
||||
case 42:
|
||||
return &hashForgetfulChain{
|
||||
bucketBits: 15,
|
||||
numBanks: 512,
|
||||
bankBits: 9,
|
||||
numLastDistancesToCheck: 16,
|
||||
}
|
||||
case 54:
|
||||
return &hashLongestMatchQuickly{
|
||||
bucketBits: 20,
|
||||
bucketSweep: 4,
|
||||
hashLen: 7,
|
||||
useDictionary: false,
|
||||
}
|
||||
case 55:
|
||||
return &hashComposite{
|
||||
ha: newHasher(54),
|
||||
hb: &hashRolling{jump: 4},
|
||||
}
|
||||
case 65:
|
||||
return &hashComposite{
|
||||
ha: newHasher(6),
|
||||
hb: &hashRolling{jump: 1},
|
||||
}
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("unknown hasher type: %d", typ))
|
||||
}
|
||||
|
||||
func hasherSetup(handle *hasherHandle, params *encoderParams, data []byte, position uint, input_size uint, is_last bool) {
|
||||
var self hasherHandle = nil
|
||||
var common *hasherCommon = nil
|
||||
var one_shot bool = (position == 0 && is_last)
|
||||
if *handle == nil {
|
||||
chooseHasher(params, ¶ms.hasher)
|
||||
self = newHasher(params.hasher.type_)
|
||||
|
||||
*handle = self
|
||||
common = self.Common()
|
||||
common.params = params.hasher
|
||||
self.Initialize(params)
|
||||
}
|
||||
|
||||
self = *handle
|
||||
common = self.Common()
|
||||
if !common.is_prepared_ {
|
||||
self.Prepare(one_shot, input_size, data)
|
||||
|
||||
if position == 0 {
|
||||
common.dict_num_lookups = 0
|
||||
common.dict_num_matches = 0
|
||||
}
|
||||
|
||||
common.is_prepared_ = true
|
||||
}
|
||||
}
|
||||
|
||||
func initOrStitchToPreviousBlock(handle *hasherHandle, data []byte, mask uint, params *encoderParams, position uint, input_size uint, is_last bool) {
|
||||
var self hasherHandle
|
||||
hasherSetup(handle, params, data, position, input_size, is_last)
|
||||
self = *handle
|
||||
self.StitchToPreviousBlock(input_size, position, data, mask)
|
||||
}
|
||||
+93
@@ -0,0 +1,93 @@
|
||||
package brotli
|
||||
|
||||
/* Copyright 2018 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
func (h *hashComposite) HashTypeLength() uint {
|
||||
var a uint = h.ha.HashTypeLength()
|
||||
var b uint = h.hb.HashTypeLength()
|
||||
if a > b {
|
||||
return a
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
func (h *hashComposite) StoreLookahead() uint {
|
||||
var a uint = h.ha.StoreLookahead()
|
||||
var b uint = h.hb.StoreLookahead()
|
||||
if a > b {
|
||||
return a
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
/* Composite hasher: This hasher allows to combine two other hashers, HASHER_A
|
||||
and HASHER_B. */
|
||||
type hashComposite struct {
|
||||
hasherCommon
|
||||
ha hasherHandle
|
||||
hb hasherHandle
|
||||
params *encoderParams
|
||||
}
|
||||
|
||||
func (h *hashComposite) Initialize(params *encoderParams) {
|
||||
h.params = params
|
||||
}
|
||||
|
||||
/* TODO: Initialize of the hashers is defered to Prepare (and params
|
||||
remembered here) because we don't get the one_shot and input_size params
|
||||
here that are needed to know the memory size of them. Instead provide
|
||||
those params to all hashers InitializehashComposite */
|
||||
func (h *hashComposite) Prepare(one_shot bool, input_size uint, data []byte) {
|
||||
if h.ha == nil {
|
||||
var common_a *hasherCommon
|
||||
var common_b *hasherCommon
|
||||
|
||||
common_a = h.ha.Common()
|
||||
common_a.params = h.params.hasher
|
||||
common_a.is_prepared_ = false
|
||||
common_a.dict_num_lookups = 0
|
||||
common_a.dict_num_matches = 0
|
||||
h.ha.Initialize(h.params)
|
||||
|
||||
common_b = h.hb.Common()
|
||||
common_b.params = h.params.hasher
|
||||
common_b.is_prepared_ = false
|
||||
common_b.dict_num_lookups = 0
|
||||
common_b.dict_num_matches = 0
|
||||
h.hb.Initialize(h.params)
|
||||
}
|
||||
|
||||
h.ha.Prepare(one_shot, input_size, data)
|
||||
h.hb.Prepare(one_shot, input_size, data)
|
||||
}
|
||||
|
||||
func (h *hashComposite) Store(data []byte, mask uint, ix uint) {
|
||||
h.ha.Store(data, mask, ix)
|
||||
h.hb.Store(data, mask, ix)
|
||||
}
|
||||
|
||||
func (h *hashComposite) StoreRange(data []byte, mask uint, ix_start uint, ix_end uint) {
|
||||
h.ha.StoreRange(data, mask, ix_start, ix_end)
|
||||
h.hb.StoreRange(data, mask, ix_start, ix_end)
|
||||
}
|
||||
|
||||
func (h *hashComposite) StitchToPreviousBlock(num_bytes uint, position uint, ringbuffer []byte, ring_buffer_mask uint) {
|
||||
h.ha.StitchToPreviousBlock(num_bytes, position, ringbuffer, ring_buffer_mask)
|
||||
h.hb.StitchToPreviousBlock(num_bytes, position, ringbuffer, ring_buffer_mask)
|
||||
}
|
||||
|
||||
func (h *hashComposite) PrepareDistanceCache(distance_cache []int) {
|
||||
h.ha.PrepareDistanceCache(distance_cache)
|
||||
h.hb.PrepareDistanceCache(distance_cache)
|
||||
}
|
||||
|
||||
func (h *hashComposite) FindLongestMatch(dictionary *encoderDictionary, data []byte, ring_buffer_mask uint, distance_cache []int, cur_ix uint, max_length uint, max_backward uint, gap uint, max_distance uint, out *hasherSearchResult) {
|
||||
h.ha.FindLongestMatch(dictionary, data, ring_buffer_mask, distance_cache, cur_ix, max_length, max_backward, gap, max_distance, out)
|
||||
h.hb.FindLongestMatch(dictionary, data, ring_buffer_mask, distance_cache, cur_ix, max_length, max_backward, gap, max_distance, out)
|
||||
}
|
||||
+252
@@ -0,0 +1,252 @@
|
||||
package brotli
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
func (*hashForgetfulChain) HashTypeLength() uint {
|
||||
return 4
|
||||
}
|
||||
|
||||
func (*hashForgetfulChain) StoreLookahead() uint {
|
||||
return 4
|
||||
}
|
||||
|
||||
/* HashBytes is the function that chooses the bucket to place the address in.*/
|
||||
func (h *hashForgetfulChain) HashBytes(data []byte) uint {
|
||||
var hash uint32 = binary.LittleEndian.Uint32(data) * kHashMul32
|
||||
|
||||
/* The higher bits contain more mixture from the multiplication,
|
||||
so we take our results from there. */
|
||||
return uint(hash >> (32 - h.bucketBits))
|
||||
}
|
||||
|
||||
type slot struct {
|
||||
delta uint16
|
||||
next uint16
|
||||
}
|
||||
|
||||
/* A (forgetful) hash table to the data seen by the compressor, to
|
||||
help create backward references to previous data.
|
||||
|
||||
Hashes are stored in chains which are bucketed to groups. Group of chains
|
||||
share a storage "bank". When more than "bank size" chain nodes are added,
|
||||
oldest nodes are replaced; this way several chains may share a tail. */
|
||||
type hashForgetfulChain struct {
|
||||
hasherCommon
|
||||
|
||||
bucketBits uint
|
||||
numBanks uint
|
||||
bankBits uint
|
||||
numLastDistancesToCheck int
|
||||
|
||||
addr []uint32
|
||||
head []uint16
|
||||
tiny_hash [65536]byte
|
||||
banks [][]slot
|
||||
free_slot_idx []uint16
|
||||
max_hops uint
|
||||
}
|
||||
|
||||
func (h *hashForgetfulChain) Initialize(params *encoderParams) {
|
||||
var q uint
|
||||
if params.quality > 6 {
|
||||
q = 7
|
||||
} else {
|
||||
q = 8
|
||||
}
|
||||
h.max_hops = q << uint(params.quality-4)
|
||||
|
||||
bankSize := 1 << h.bankBits
|
||||
bucketSize := 1 << h.bucketBits
|
||||
|
||||
h.addr = make([]uint32, bucketSize)
|
||||
h.head = make([]uint16, bucketSize)
|
||||
h.banks = make([][]slot, h.numBanks)
|
||||
for i := range h.banks {
|
||||
h.banks[i] = make([]slot, bankSize)
|
||||
}
|
||||
h.free_slot_idx = make([]uint16, h.numBanks)
|
||||
}
|
||||
|
||||
func (h *hashForgetfulChain) Prepare(one_shot bool, input_size uint, data []byte) {
|
||||
var partial_prepare_threshold uint = (1 << h.bucketBits) >> 6
|
||||
/* Partial preparation is 100 times slower (per socket). */
|
||||
if one_shot && input_size <= partial_prepare_threshold {
|
||||
var i uint
|
||||
for i = 0; i < input_size; i++ {
|
||||
var bucket uint = h.HashBytes(data[i:])
|
||||
|
||||
/* See InitEmpty comment. */
|
||||
h.addr[bucket] = 0xCCCCCCCC
|
||||
|
||||
h.head[bucket] = 0xCCCC
|
||||
}
|
||||
} else {
|
||||
/* Fill |addr| array with 0xCCCCCCCC value. Because of wrapping, position
|
||||
processed by hasher never reaches 3GB + 64M; this makes all new chains
|
||||
to be terminated after the first node. */
|
||||
for i := range h.addr {
|
||||
h.addr[i] = 0xCCCCCCCC
|
||||
}
|
||||
|
||||
for i := range h.head {
|
||||
h.head[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
h.tiny_hash = [65536]byte{}
|
||||
for i := range h.free_slot_idx {
|
||||
h.free_slot_idx[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
/* Look at 4 bytes at &data[ix & mask]. Compute a hash from these, and prepend
|
||||
node to corresponding chain; also update tiny_hash for current position. */
|
||||
func (h *hashForgetfulChain) Store(data []byte, mask uint, ix uint) {
|
||||
var key uint = h.HashBytes(data[ix&mask:])
|
||||
var bank uint = key & (h.numBanks - 1)
|
||||
idx := uint(h.free_slot_idx[bank]) & ((1 << h.bankBits) - 1)
|
||||
h.free_slot_idx[bank]++
|
||||
var delta uint = ix - uint(h.addr[key])
|
||||
h.tiny_hash[uint16(ix)] = byte(key)
|
||||
if delta > 0xFFFF {
|
||||
delta = 0xFFFF
|
||||
}
|
||||
h.banks[bank][idx].delta = uint16(delta)
|
||||
h.banks[bank][idx].next = h.head[key]
|
||||
h.addr[key] = uint32(ix)
|
||||
h.head[key] = uint16(idx)
|
||||
}
|
||||
|
||||
func (h *hashForgetfulChain) StoreRange(data []byte, mask uint, ix_start uint, ix_end uint) {
|
||||
var i uint
|
||||
for i = ix_start; i < ix_end; i++ {
|
||||
h.Store(data, mask, i)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *hashForgetfulChain) StitchToPreviousBlock(num_bytes uint, position uint, ringbuffer []byte, ring_buffer_mask uint) {
|
||||
if num_bytes >= h.HashTypeLength()-1 && position >= 3 {
|
||||
/* Prepare the hashes for three last bytes of the last write.
|
||||
These could not be calculated before, since they require knowledge
|
||||
of both the previous and the current block. */
|
||||
h.Store(ringbuffer, ring_buffer_mask, position-3)
|
||||
h.Store(ringbuffer, ring_buffer_mask, position-2)
|
||||
h.Store(ringbuffer, ring_buffer_mask, position-1)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *hashForgetfulChain) PrepareDistanceCache(distance_cache []int) {
|
||||
prepareDistanceCache(distance_cache, h.numLastDistancesToCheck)
|
||||
}
|
||||
|
||||
/* Find a longest backward match of &data[cur_ix] up to the length of
|
||||
max_length and stores the position cur_ix in the hash table.
|
||||
|
||||
REQUIRES: PrepareDistanceCachehashForgetfulChain must be invoked for current distance cache
|
||||
values; if this method is invoked repeatedly with the same distance
|
||||
cache values, it is enough to invoke PrepareDistanceCachehashForgetfulChain once.
|
||||
|
||||
Does not look for matches longer than max_length.
|
||||
Does not look for matches further away than max_backward.
|
||||
Writes the best match into |out|.
|
||||
|out|->score is updated only if a better match is found. */
|
||||
func (h *hashForgetfulChain) FindLongestMatch(dictionary *encoderDictionary, data []byte, ring_buffer_mask uint, distance_cache []int, cur_ix uint, max_length uint, max_backward uint, gap uint, max_distance uint, out *hasherSearchResult) {
|
||||
var cur_ix_masked uint = cur_ix & ring_buffer_mask
|
||||
var min_score uint = out.score
|
||||
var best_score uint = out.score
|
||||
var best_len uint = out.len
|
||||
var key uint = h.HashBytes(data[cur_ix_masked:])
|
||||
var tiny_hash byte = byte(key)
|
||||
/* Don't accept a short copy from far away. */
|
||||
out.len = 0
|
||||
|
||||
out.len_code_delta = 0
|
||||
|
||||
/* Try last distance first. */
|
||||
for i := 0; i < h.numLastDistancesToCheck; i++ {
|
||||
var backward uint = uint(distance_cache[i])
|
||||
var prev_ix uint = (cur_ix - backward)
|
||||
|
||||
/* For distance code 0 we want to consider 2-byte matches. */
|
||||
if i > 0 && h.tiny_hash[uint16(prev_ix)] != tiny_hash {
|
||||
continue
|
||||
}
|
||||
if prev_ix >= cur_ix || backward > max_backward {
|
||||
continue
|
||||
}
|
||||
|
||||
prev_ix &= ring_buffer_mask
|
||||
{
|
||||
var len uint = findMatchLengthWithLimit(data[prev_ix:], data[cur_ix_masked:], max_length)
|
||||
if len >= 2 {
|
||||
var score uint = backwardReferenceScoreUsingLastDistance(uint(len))
|
||||
if best_score < score {
|
||||
if i != 0 {
|
||||
score -= backwardReferencePenaltyUsingLastDistance(uint(i))
|
||||
}
|
||||
if best_score < score {
|
||||
best_score = score
|
||||
best_len = uint(len)
|
||||
out.len = best_len
|
||||
out.distance = backward
|
||||
out.score = best_score
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
var bank uint = key & (h.numBanks - 1)
|
||||
var backward uint = 0
|
||||
var hops uint = h.max_hops
|
||||
var delta uint = cur_ix - uint(h.addr[key])
|
||||
var slot uint = uint(h.head[key])
|
||||
for {
|
||||
tmp6 := hops
|
||||
hops--
|
||||
if tmp6 == 0 {
|
||||
break
|
||||
}
|
||||
var prev_ix uint
|
||||
var last uint = slot
|
||||
backward += delta
|
||||
if backward > max_backward {
|
||||
break
|
||||
}
|
||||
prev_ix = (cur_ix - backward) & ring_buffer_mask
|
||||
slot = uint(h.banks[bank][last].next)
|
||||
delta = uint(h.banks[bank][last].delta)
|
||||
if cur_ix_masked+best_len > ring_buffer_mask || prev_ix+best_len > ring_buffer_mask || data[cur_ix_masked+best_len] != data[prev_ix+best_len] {
|
||||
continue
|
||||
}
|
||||
{
|
||||
var len uint = findMatchLengthWithLimit(data[prev_ix:], data[cur_ix_masked:], max_length)
|
||||
if len >= 4 {
|
||||
/* Comparing for >= 3 does not change the semantics, but just saves
|
||||
for a few unnecessary binary logarithms in backward reference
|
||||
score, since we are not interested in such short matches. */
|
||||
var score uint = backwardReferenceScore(uint(len), backward)
|
||||
if best_score < score {
|
||||
best_score = score
|
||||
best_len = uint(len)
|
||||
out.len = best_len
|
||||
out.distance = backward
|
||||
out.score = best_score
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
h.Store(data, ring_buffer_mask, cur_ix)
|
||||
}
|
||||
|
||||
if out.score == min_score {
|
||||
searchInStaticDictionary(dictionary, h, data[cur_ix_masked:], max_length, max_backward+gap, max_distance, out, false)
|
||||
}
|
||||
}
|
||||
+214
@@ -0,0 +1,214 @@
|
||||
package brotli
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
/* Copyright 2010 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* For BUCKET_SWEEP == 1, enabling the dictionary lookup makes compression
|
||||
a little faster (0.5% - 1%) and it compresses 0.15% better on small text
|
||||
and HTML inputs. */
|
||||
|
||||
func (*hashLongestMatchQuickly) HashTypeLength() uint {
|
||||
return 8
|
||||
}
|
||||
|
||||
func (*hashLongestMatchQuickly) StoreLookahead() uint {
|
||||
return 8
|
||||
}
|
||||
|
||||
/* HashBytes is the function that chooses the bucket to place
|
||||
the address in. The HashLongestMatch and hashLongestMatchQuickly
|
||||
classes have separate, different implementations of hashing. */
|
||||
func (h *hashLongestMatchQuickly) HashBytes(data []byte) uint32 {
|
||||
var hash uint64 = ((binary.LittleEndian.Uint64(data) << (64 - 8*h.hashLen)) * kHashMul64)
|
||||
|
||||
/* The higher bits contain more mixture from the multiplication,
|
||||
so we take our results from there. */
|
||||
return uint32(hash >> (64 - h.bucketBits))
|
||||
}
|
||||
|
||||
/* A (forgetful) hash table to the data seen by the compressor, to
|
||||
help create backward references to previous data.
|
||||
|
||||
This is a hash map of fixed size (1 << 16). Starting from the
|
||||
given index, 1 buckets are used to store values of a key. */
|
||||
type hashLongestMatchQuickly struct {
|
||||
hasherCommon
|
||||
|
||||
bucketBits uint
|
||||
bucketSweep int
|
||||
hashLen uint
|
||||
useDictionary bool
|
||||
|
||||
buckets []uint32
|
||||
}
|
||||
|
||||
func (h *hashLongestMatchQuickly) Initialize(params *encoderParams) {
|
||||
h.buckets = make([]uint32, 1<<h.bucketBits+h.bucketSweep)
|
||||
}
|
||||
|
||||
func (h *hashLongestMatchQuickly) Prepare(one_shot bool, input_size uint, data []byte) {
|
||||
var partial_prepare_threshold uint = (4 << h.bucketBits) >> 7
|
||||
/* Partial preparation is 100 times slower (per socket). */
|
||||
if one_shot && input_size <= partial_prepare_threshold {
|
||||
var i uint
|
||||
for i = 0; i < input_size; i++ {
|
||||
var key uint32 = h.HashBytes(data[i:])
|
||||
for j := 0; j < h.bucketSweep; j++ {
|
||||
h.buckets[key+uint32(j)] = 0
|
||||
}
|
||||
}
|
||||
} else {
|
||||
/* It is not strictly necessary to fill this buffer here, but
|
||||
not filling will make the results of the compression stochastic
|
||||
(but correct). This is because random data would cause the
|
||||
system to find accidentally good backward references here and there. */
|
||||
for i := range h.buckets {
|
||||
h.buckets[i] = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Look at 5 bytes at &data[ix & mask].
|
||||
Compute a hash from these, and store the value somewhere within
|
||||
[ix .. ix+3]. */
|
||||
func (h *hashLongestMatchQuickly) Store(data []byte, mask uint, ix uint) {
|
||||
var key uint32 = h.HashBytes(data[ix&mask:])
|
||||
var off uint32 = uint32(ix>>3) % uint32(h.bucketSweep)
|
||||
/* Wiggle the value with the bucket sweep range. */
|
||||
h.buckets[key+off] = uint32(ix)
|
||||
}
|
||||
|
||||
func (h *hashLongestMatchQuickly) StoreRange(data []byte, mask uint, ix_start uint, ix_end uint) {
|
||||
var i uint
|
||||
for i = ix_start; i < ix_end; i++ {
|
||||
h.Store(data, mask, i)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *hashLongestMatchQuickly) StitchToPreviousBlock(num_bytes uint, position uint, ringbuffer []byte, ringbuffer_mask uint) {
|
||||
if num_bytes >= h.HashTypeLength()-1 && position >= 3 {
|
||||
/* Prepare the hashes for three last bytes of the last write.
|
||||
These could not be calculated before, since they require knowledge
|
||||
of both the previous and the current block. */
|
||||
h.Store(ringbuffer, ringbuffer_mask, position-3)
|
||||
h.Store(ringbuffer, ringbuffer_mask, position-2)
|
||||
h.Store(ringbuffer, ringbuffer_mask, position-1)
|
||||
}
|
||||
}
|
||||
|
||||
func (*hashLongestMatchQuickly) PrepareDistanceCache(distance_cache []int) {
|
||||
}
|
||||
|
||||
/* Find a longest backward match of &data[cur_ix & ring_buffer_mask]
|
||||
up to the length of max_length and stores the position cur_ix in the
|
||||
hash table.
|
||||
|
||||
Does not look for matches longer than max_length.
|
||||
Does not look for matches further away than max_backward.
|
||||
Writes the best match into |out|.
|
||||
|out|->score is updated only if a better match is found. */
|
||||
func (h *hashLongestMatchQuickly) FindLongestMatch(dictionary *encoderDictionary, data []byte, ring_buffer_mask uint, distance_cache []int, cur_ix uint, max_length uint, max_backward uint, gap uint, max_distance uint, out *hasherSearchResult) {
|
||||
var best_len_in uint = out.len
|
||||
var cur_ix_masked uint = cur_ix & ring_buffer_mask
|
||||
var key uint32 = h.HashBytes(data[cur_ix_masked:])
|
||||
var compare_char int = int(data[cur_ix_masked+best_len_in])
|
||||
var min_score uint = out.score
|
||||
var best_score uint = out.score
|
||||
var best_len uint = best_len_in
|
||||
var cached_backward uint = uint(distance_cache[0])
|
||||
var prev_ix uint = cur_ix - cached_backward
|
||||
var bucket []uint32
|
||||
out.len_code_delta = 0
|
||||
if prev_ix < cur_ix {
|
||||
prev_ix &= uint(uint32(ring_buffer_mask))
|
||||
if compare_char == int(data[prev_ix+best_len]) {
|
||||
var len uint = findMatchLengthWithLimit(data[prev_ix:], data[cur_ix_masked:], max_length)
|
||||
if len >= 4 {
|
||||
var score uint = backwardReferenceScoreUsingLastDistance(uint(len))
|
||||
if best_score < score {
|
||||
best_score = score
|
||||
best_len = uint(len)
|
||||
out.len = uint(len)
|
||||
out.distance = cached_backward
|
||||
out.score = best_score
|
||||
compare_char = int(data[cur_ix_masked+best_len])
|
||||
if h.bucketSweep == 1 {
|
||||
h.buckets[key] = uint32(cur_ix)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if h.bucketSweep == 1 {
|
||||
var backward uint
|
||||
var len uint
|
||||
|
||||
/* Only one to look for, don't bother to prepare for a loop. */
|
||||
prev_ix = uint(h.buckets[key])
|
||||
|
||||
h.buckets[key] = uint32(cur_ix)
|
||||
backward = cur_ix - prev_ix
|
||||
prev_ix &= uint(uint32(ring_buffer_mask))
|
||||
if compare_char != int(data[prev_ix+best_len_in]) {
|
||||
return
|
||||
}
|
||||
|
||||
if backward == 0 || backward > max_backward {
|
||||
return
|
||||
}
|
||||
|
||||
len = findMatchLengthWithLimit(data[prev_ix:], data[cur_ix_masked:], max_length)
|
||||
if len >= 4 {
|
||||
var score uint = backwardReferenceScore(uint(len), backward)
|
||||
if best_score < score {
|
||||
out.len = uint(len)
|
||||
out.distance = backward
|
||||
out.score = score
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
bucket = h.buckets[key:]
|
||||
var i int
|
||||
prev_ix = uint(bucket[0])
|
||||
bucket = bucket[1:]
|
||||
for i = 0; i < h.bucketSweep; (func() { i++; tmp3 := bucket; bucket = bucket[1:]; prev_ix = uint(tmp3[0]) })() {
|
||||
var backward uint = cur_ix - prev_ix
|
||||
var len uint
|
||||
prev_ix &= uint(uint32(ring_buffer_mask))
|
||||
if compare_char != int(data[prev_ix+best_len]) {
|
||||
continue
|
||||
}
|
||||
|
||||
if backward == 0 || backward > max_backward {
|
||||
continue
|
||||
}
|
||||
|
||||
len = findMatchLengthWithLimit(data[prev_ix:], data[cur_ix_masked:], max_length)
|
||||
if len >= 4 {
|
||||
var score uint = backwardReferenceScore(uint(len), backward)
|
||||
if best_score < score {
|
||||
best_score = score
|
||||
best_len = uint(len)
|
||||
out.len = best_len
|
||||
out.distance = backward
|
||||
out.score = score
|
||||
compare_char = int(data[cur_ix_masked+best_len])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if h.useDictionary && min_score == out.score {
|
||||
searchInStaticDictionary(dictionary, h, data[cur_ix_masked:], max_length, max_backward+gap, max_distance, out, true)
|
||||
}
|
||||
|
||||
h.buckets[key+uint32((cur_ix>>3)%uint(h.bucketSweep))] = uint32(cur_ix)
|
||||
}
|
||||
+168
@@ -0,0 +1,168 @@
|
||||
package brotli
|
||||
|
||||
/* Copyright 2018 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* NOTE: this hasher does not search in the dictionary. It is used as
|
||||
backup-hasher, the main hasher already searches in it. */
|
||||
|
||||
const kRollingHashMul32 uint32 = 69069
|
||||
|
||||
const kInvalidPosHashRolling uint32 = 0xffffffff
|
||||
|
||||
/* This hasher uses a longer forward length, but returning a higher value here
|
||||
will hurt compression by the main hasher when combined with a composite
|
||||
hasher. The hasher tests for forward itself instead. */
|
||||
func (*hashRolling) HashTypeLength() uint {
|
||||
return 4
|
||||
}
|
||||
|
||||
func (*hashRolling) StoreLookahead() uint {
|
||||
return 4
|
||||
}
|
||||
|
||||
/* Computes a code from a single byte. A lookup table of 256 values could be
|
||||
used, but simply adding 1 works about as good. */
|
||||
func (*hashRolling) HashByte(b byte) uint32 {
|
||||
return uint32(b) + 1
|
||||
}
|
||||
|
||||
func (h *hashRolling) HashRollingFunctionInitial(state uint32, add byte, factor uint32) uint32 {
|
||||
return uint32(factor*state + h.HashByte(add))
|
||||
}
|
||||
|
||||
func (h *hashRolling) HashRollingFunction(state uint32, add byte, rem byte, factor uint32, factor_remove uint32) uint32 {
|
||||
return uint32(factor*state + h.HashByte(add) - factor_remove*h.HashByte(rem))
|
||||
}
|
||||
|
||||
/* Rolling hash for long distance long string matches. Stores one position
|
||||
per bucket, bucket key is computed over a long region. */
|
||||
type hashRolling struct {
|
||||
hasherCommon
|
||||
|
||||
jump int
|
||||
|
||||
state uint32
|
||||
table []uint32
|
||||
next_ix uint
|
||||
factor uint32
|
||||
factor_remove uint32
|
||||
}
|
||||
|
||||
func (h *hashRolling) Initialize(params *encoderParams) {
|
||||
h.state = 0
|
||||
h.next_ix = 0
|
||||
|
||||
h.factor = kRollingHashMul32
|
||||
|
||||
/* Compute the factor of the oldest byte to remove: factor**steps modulo
|
||||
0xffffffff (the multiplications rely on 32-bit overflow) */
|
||||
h.factor_remove = 1
|
||||
|
||||
for i := 0; i < 32; i += h.jump {
|
||||
h.factor_remove *= h.factor
|
||||
}
|
||||
|
||||
h.table = make([]uint32, 16777216)
|
||||
for i := 0; i < 16777216; i++ {
|
||||
h.table[i] = kInvalidPosHashRolling
|
||||
}
|
||||
}
|
||||
|
||||
func (h *hashRolling) Prepare(one_shot bool, input_size uint, data []byte) {
|
||||
/* Too small size, cannot use this hasher. */
|
||||
if input_size < 32 {
|
||||
return
|
||||
}
|
||||
h.state = 0
|
||||
for i := 0; i < 32; i += h.jump {
|
||||
h.state = h.HashRollingFunctionInitial(h.state, data[i], h.factor)
|
||||
}
|
||||
}
|
||||
|
||||
func (*hashRolling) Store(data []byte, mask uint, ix uint) {
|
||||
}
|
||||
|
||||
func (*hashRolling) StoreRange(data []byte, mask uint, ix_start uint, ix_end uint) {
|
||||
}
|
||||
|
||||
func (h *hashRolling) StitchToPreviousBlock(num_bytes uint, position uint, ringbuffer []byte, ring_buffer_mask uint) {
|
||||
var position_masked uint
|
||||
/* In this case we must re-initialize the hasher from scratch from the
|
||||
current position. */
|
||||
|
||||
var available uint = num_bytes
|
||||
if position&uint(h.jump-1) != 0 {
|
||||
var diff uint = uint(h.jump) - (position & uint(h.jump-1))
|
||||
if diff > available {
|
||||
available = 0
|
||||
} else {
|
||||
available = available - diff
|
||||
}
|
||||
position += diff
|
||||
}
|
||||
|
||||
position_masked = position & ring_buffer_mask
|
||||
|
||||
/* wrapping around ringbuffer not handled. */
|
||||
if available > ring_buffer_mask-position_masked {
|
||||
available = ring_buffer_mask - position_masked
|
||||
}
|
||||
|
||||
h.Prepare(false, available, ringbuffer[position&ring_buffer_mask:])
|
||||
h.next_ix = position
|
||||
}
|
||||
|
||||
func (*hashRolling) PrepareDistanceCache(distance_cache []int) {
|
||||
}
|
||||
|
||||
func (h *hashRolling) FindLongestMatch(dictionary *encoderDictionary, data []byte, ring_buffer_mask uint, distance_cache []int, cur_ix uint, max_length uint, max_backward uint, gap uint, max_distance uint, out *hasherSearchResult) {
|
||||
var cur_ix_masked uint = cur_ix & ring_buffer_mask
|
||||
var pos uint = h.next_ix
|
||||
|
||||
if cur_ix&uint(h.jump-1) != 0 {
|
||||
return
|
||||
}
|
||||
|
||||
/* Not enough lookahead */
|
||||
if max_length < 32 {
|
||||
return
|
||||
}
|
||||
|
||||
for pos = h.next_ix; pos <= cur_ix; pos += uint(h.jump) {
|
||||
var code uint32 = h.state & ((16777216 * 64) - 1)
|
||||
var rem byte = data[pos&ring_buffer_mask]
|
||||
var add byte = data[(pos+32)&ring_buffer_mask]
|
||||
var found_ix uint = uint(kInvalidPosHashRolling)
|
||||
|
||||
h.state = h.HashRollingFunction(h.state, add, rem, h.factor, h.factor_remove)
|
||||
|
||||
if code < 16777216 {
|
||||
found_ix = uint(h.table[code])
|
||||
h.table[code] = uint32(pos)
|
||||
if pos == cur_ix && uint32(found_ix) != kInvalidPosHashRolling {
|
||||
/* The cast to 32-bit makes backward distances up to 4GB work even
|
||||
if cur_ix is above 4GB, despite using 32-bit values in the table. */
|
||||
var backward uint = uint(uint32(cur_ix - found_ix))
|
||||
if backward <= max_backward {
|
||||
var found_ix_masked uint = found_ix & ring_buffer_mask
|
||||
var len uint = findMatchLengthWithLimit(data[found_ix_masked:], data[cur_ix_masked:], max_length)
|
||||
if len >= 4 && len > out.len {
|
||||
var score uint = backwardReferenceScore(uint(len), backward)
|
||||
if score > out.score {
|
||||
out.len = uint(len)
|
||||
out.distance = backward
|
||||
out.score = score
|
||||
out.len_code_delta = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
h.next_ix = cur_ix + uint(h.jump)
|
||||
}
|
||||
+226
@@ -0,0 +1,226 @@
|
||||
package brotli
|
||||
|
||||
import "math"
|
||||
|
||||
/* The distance symbols effectively used by "Large Window Brotli" (32-bit). */
|
||||
const numHistogramDistanceSymbols = 544
|
||||
|
||||
type histogramLiteral struct {
|
||||
data_ [numLiteralSymbols]uint32
|
||||
total_count_ uint
|
||||
bit_cost_ float64
|
||||
}
|
||||
|
||||
func histogramClearLiteral(self *histogramLiteral) {
|
||||
self.data_ = [numLiteralSymbols]uint32{}
|
||||
self.total_count_ = 0
|
||||
self.bit_cost_ = math.MaxFloat64
|
||||
}
|
||||
|
||||
func clearHistogramsLiteral(array []histogramLiteral, length uint) {
|
||||
var i uint
|
||||
for i = 0; i < length; i++ {
|
||||
histogramClearLiteral(&array[i:][0])
|
||||
}
|
||||
}
|
||||
|
||||
func histogramAddLiteral(self *histogramLiteral, val uint) {
|
||||
self.data_[val]++
|
||||
self.total_count_++
|
||||
}
|
||||
|
||||
func histogramAddVectorLiteral(self *histogramLiteral, p []byte, n uint) {
|
||||
self.total_count_ += n
|
||||
n += 1
|
||||
for {
|
||||
n--
|
||||
if n == 0 {
|
||||
break
|
||||
}
|
||||
self.data_[p[0]]++
|
||||
p = p[1:]
|
||||
}
|
||||
}
|
||||
|
||||
func histogramAddHistogramLiteral(self *histogramLiteral, v *histogramLiteral) {
|
||||
var i uint
|
||||
self.total_count_ += v.total_count_
|
||||
for i = 0; i < numLiteralSymbols; i++ {
|
||||
self.data_[i] += v.data_[i]
|
||||
}
|
||||
}
|
||||
|
||||
func histogramDataSizeLiteral() uint {
|
||||
return numLiteralSymbols
|
||||
}
|
||||
|
||||
type histogramCommand struct {
|
||||
data_ [numCommandSymbols]uint32
|
||||
total_count_ uint
|
||||
bit_cost_ float64
|
||||
}
|
||||
|
||||
func histogramClearCommand(self *histogramCommand) {
|
||||
self.data_ = [numCommandSymbols]uint32{}
|
||||
self.total_count_ = 0
|
||||
self.bit_cost_ = math.MaxFloat64
|
||||
}
|
||||
|
||||
func clearHistogramsCommand(array []histogramCommand, length uint) {
|
||||
var i uint
|
||||
for i = 0; i < length; i++ {
|
||||
histogramClearCommand(&array[i:][0])
|
||||
}
|
||||
}
|
||||
|
||||
func histogramAddCommand(self *histogramCommand, val uint) {
|
||||
self.data_[val]++
|
||||
self.total_count_++
|
||||
}
|
||||
|
||||
func histogramAddVectorCommand(self *histogramCommand, p []uint16, n uint) {
|
||||
self.total_count_ += n
|
||||
n += 1
|
||||
for {
|
||||
n--
|
||||
if n == 0 {
|
||||
break
|
||||
}
|
||||
self.data_[p[0]]++
|
||||
p = p[1:]
|
||||
}
|
||||
}
|
||||
|
||||
func histogramAddHistogramCommand(self *histogramCommand, v *histogramCommand) {
|
||||
var i uint
|
||||
self.total_count_ += v.total_count_
|
||||
for i = 0; i < numCommandSymbols; i++ {
|
||||
self.data_[i] += v.data_[i]
|
||||
}
|
||||
}
|
||||
|
||||
func histogramDataSizeCommand() uint {
|
||||
return numCommandSymbols
|
||||
}
|
||||
|
||||
type histogramDistance struct {
|
||||
data_ [numDistanceSymbols]uint32
|
||||
total_count_ uint
|
||||
bit_cost_ float64
|
||||
}
|
||||
|
||||
func histogramClearDistance(self *histogramDistance) {
|
||||
self.data_ = [numDistanceSymbols]uint32{}
|
||||
self.total_count_ = 0
|
||||
self.bit_cost_ = math.MaxFloat64
|
||||
}
|
||||
|
||||
func clearHistogramsDistance(array []histogramDistance, length uint) {
|
||||
var i uint
|
||||
for i = 0; i < length; i++ {
|
||||
histogramClearDistance(&array[i:][0])
|
||||
}
|
||||
}
|
||||
|
||||
func histogramAddDistance(self *histogramDistance, val uint) {
|
||||
self.data_[val]++
|
||||
self.total_count_++
|
||||
}
|
||||
|
||||
func histogramAddVectorDistance(self *histogramDistance, p []uint16, n uint) {
|
||||
self.total_count_ += n
|
||||
n += 1
|
||||
for {
|
||||
n--
|
||||
if n == 0 {
|
||||
break
|
||||
}
|
||||
self.data_[p[0]]++
|
||||
p = p[1:]
|
||||
}
|
||||
}
|
||||
|
||||
func histogramAddHistogramDistance(self *histogramDistance, v *histogramDistance) {
|
||||
var i uint
|
||||
self.total_count_ += v.total_count_
|
||||
for i = 0; i < numDistanceSymbols; i++ {
|
||||
self.data_[i] += v.data_[i]
|
||||
}
|
||||
}
|
||||
|
||||
func histogramDataSizeDistance() uint {
|
||||
return numDistanceSymbols
|
||||
}
|
||||
|
||||
type blockSplitIterator struct {
|
||||
split_ *blockSplit
|
||||
idx_ uint
|
||||
type_ uint
|
||||
length_ uint
|
||||
}
|
||||
|
||||
func initBlockSplitIterator(self *blockSplitIterator, split *blockSplit) {
|
||||
self.split_ = split
|
||||
self.idx_ = 0
|
||||
self.type_ = 0
|
||||
if len(split.lengths) > 0 {
|
||||
self.length_ = uint(split.lengths[0])
|
||||
} else {
|
||||
self.length_ = 0
|
||||
}
|
||||
}
|
||||
|
||||
func blockSplitIteratorNext(self *blockSplitIterator) {
|
||||
if self.length_ == 0 {
|
||||
self.idx_++
|
||||
self.type_ = uint(self.split_.types[self.idx_])
|
||||
self.length_ = uint(self.split_.lengths[self.idx_])
|
||||
}
|
||||
|
||||
self.length_--
|
||||
}
|
||||
|
||||
func buildHistogramsWithContext(cmds []command, literal_split *blockSplit, insert_and_copy_split *blockSplit, dist_split *blockSplit, ringbuffer []byte, start_pos uint, mask uint, prev_byte byte, prev_byte2 byte, context_modes []int, literal_histograms []histogramLiteral, insert_and_copy_histograms []histogramCommand, copy_dist_histograms []histogramDistance) {
|
||||
var pos uint = start_pos
|
||||
var literal_it blockSplitIterator
|
||||
var insert_and_copy_it blockSplitIterator
|
||||
var dist_it blockSplitIterator
|
||||
|
||||
initBlockSplitIterator(&literal_it, literal_split)
|
||||
initBlockSplitIterator(&insert_and_copy_it, insert_and_copy_split)
|
||||
initBlockSplitIterator(&dist_it, dist_split)
|
||||
for i := range cmds {
|
||||
var cmd *command = &cmds[i]
|
||||
var j uint
|
||||
blockSplitIteratorNext(&insert_and_copy_it)
|
||||
histogramAddCommand(&insert_and_copy_histograms[insert_and_copy_it.type_], uint(cmd.cmd_prefix_))
|
||||
|
||||
/* TODO: unwrap iterator blocks. */
|
||||
for j = uint(cmd.insert_len_); j != 0; j-- {
|
||||
var context uint
|
||||
blockSplitIteratorNext(&literal_it)
|
||||
context = literal_it.type_
|
||||
if context_modes != nil {
|
||||
var lut contextLUT = getContextLUT(context_modes[context])
|
||||
context = (context << literalContextBits) + uint(getContext(prev_byte, prev_byte2, lut))
|
||||
}
|
||||
|
||||
histogramAddLiteral(&literal_histograms[context], uint(ringbuffer[pos&mask]))
|
||||
prev_byte2 = prev_byte
|
||||
prev_byte = ringbuffer[pos&mask]
|
||||
pos++
|
||||
}
|
||||
|
||||
pos += uint(commandCopyLen(cmd))
|
||||
if commandCopyLen(cmd) != 0 {
|
||||
prev_byte2 = ringbuffer[(pos-2)&mask]
|
||||
prev_byte = ringbuffer[(pos-1)&mask]
|
||||
if cmd.cmd_prefix_ >= 128 {
|
||||
var context uint
|
||||
blockSplitIteratorNext(&dist_it)
|
||||
context = uint(uint32(dist_it.type_<<distanceContextBits) + commandDistanceContext(cmd))
|
||||
histogramAddDistance(©_dist_histograms[context], uint(cmd.dist_prefix_)&0x3FF)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
+184
@@ -0,0 +1,184 @@
|
||||
package brotli
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// HTTPCompressor chooses a compression method (brotli, gzip, or none) based on
|
||||
// the Accept-Encoding header, sets the Content-Encoding header, and returns a
|
||||
// WriteCloser that implements that compression. The Close method must be called
|
||||
// before the current HTTP handler returns.
|
||||
func HTTPCompressor(w http.ResponseWriter, r *http.Request) io.WriteCloser {
|
||||
if w.Header().Get("Vary") == "" {
|
||||
w.Header().Set("Vary", "Accept-Encoding")
|
||||
}
|
||||
|
||||
encoding := negotiateContentEncoding(r, []string{"br", "gzip"})
|
||||
switch encoding {
|
||||
case "br":
|
||||
w.Header().Set("Content-Encoding", "br")
|
||||
return NewWriter(w)
|
||||
case "gzip":
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
return gzip.NewWriter(w)
|
||||
}
|
||||
return nopCloser{w}
|
||||
}
|
||||
|
||||
// negotiateContentEncoding returns the best offered content encoding for the
|
||||
// request's Accept-Encoding header. If two offers match with equal weight and
|
||||
// then the offer earlier in the list is preferred. If no offers are
|
||||
// acceptable, then "" is returned.
|
||||
func negotiateContentEncoding(r *http.Request, offers []string) string {
|
||||
bestOffer := "identity"
|
||||
bestQ := -1.0
|
||||
specs := parseAccept(r.Header, "Accept-Encoding")
|
||||
for _, offer := range offers {
|
||||
for _, spec := range specs {
|
||||
if spec.Q > bestQ &&
|
||||
(spec.Value == "*" || spec.Value == offer) {
|
||||
bestQ = spec.Q
|
||||
bestOffer = offer
|
||||
}
|
||||
}
|
||||
}
|
||||
if bestQ == 0 {
|
||||
bestOffer = ""
|
||||
}
|
||||
return bestOffer
|
||||
}
|
||||
|
||||
// acceptSpec describes an Accept* header.
|
||||
type acceptSpec struct {
|
||||
Value string
|
||||
Q float64
|
||||
}
|
||||
|
||||
// parseAccept parses Accept* headers.
|
||||
func parseAccept(header http.Header, key string) (specs []acceptSpec) {
|
||||
loop:
|
||||
for _, s := range header[key] {
|
||||
for {
|
||||
var spec acceptSpec
|
||||
spec.Value, s = expectTokenSlash(s)
|
||||
if spec.Value == "" {
|
||||
continue loop
|
||||
}
|
||||
spec.Q = 1.0
|
||||
s = skipSpace(s)
|
||||
if strings.HasPrefix(s, ";") {
|
||||
s = skipSpace(s[1:])
|
||||
if !strings.HasPrefix(s, "q=") {
|
||||
continue loop
|
||||
}
|
||||
spec.Q, s = expectQuality(s[2:])
|
||||
if spec.Q < 0.0 {
|
||||
continue loop
|
||||
}
|
||||
}
|
||||
specs = append(specs, spec)
|
||||
s = skipSpace(s)
|
||||
if !strings.HasPrefix(s, ",") {
|
||||
continue loop
|
||||
}
|
||||
s = skipSpace(s[1:])
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func skipSpace(s string) (rest string) {
|
||||
i := 0
|
||||
for ; i < len(s); i++ {
|
||||
if octetTypes[s[i]]&isSpace == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return s[i:]
|
||||
}
|
||||
|
||||
func expectTokenSlash(s string) (token, rest string) {
|
||||
i := 0
|
||||
for ; i < len(s); i++ {
|
||||
b := s[i]
|
||||
if (octetTypes[b]&isToken == 0) && b != '/' {
|
||||
break
|
||||
}
|
||||
}
|
||||
return s[:i], s[i:]
|
||||
}
|
||||
|
||||
func expectQuality(s string) (q float64, rest string) {
|
||||
switch {
|
||||
case len(s) == 0:
|
||||
return -1, ""
|
||||
case s[0] == '0':
|
||||
q = 0
|
||||
case s[0] == '1':
|
||||
q = 1
|
||||
default:
|
||||
return -1, ""
|
||||
}
|
||||
s = s[1:]
|
||||
if !strings.HasPrefix(s, ".") {
|
||||
return q, s
|
||||
}
|
||||
s = s[1:]
|
||||
i := 0
|
||||
n := 0
|
||||
d := 1
|
||||
for ; i < len(s); i++ {
|
||||
b := s[i]
|
||||
if b < '0' || b > '9' {
|
||||
break
|
||||
}
|
||||
n = n*10 + int(b) - '0'
|
||||
d *= 10
|
||||
}
|
||||
return q + float64(n)/float64(d), s[i:]
|
||||
}
|
||||
|
||||
// Octet types from RFC 2616.
|
||||
var octetTypes [256]octetType
|
||||
|
||||
type octetType byte
|
||||
|
||||
const (
|
||||
isToken octetType = 1 << iota
|
||||
isSpace
|
||||
)
|
||||
|
||||
func init() {
|
||||
// OCTET = <any 8-bit sequence of data>
|
||||
// CHAR = <any US-ASCII character (octets 0 - 127)>
|
||||
// CTL = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
|
||||
// CR = <US-ASCII CR, carriage return (13)>
|
||||
// LF = <US-ASCII LF, linefeed (10)>
|
||||
// SP = <US-ASCII SP, space (32)>
|
||||
// HT = <US-ASCII HT, horizontal-tab (9)>
|
||||
// <"> = <US-ASCII double-quote mark (34)>
|
||||
// CRLF = CR LF
|
||||
// LWS = [CRLF] 1*( SP | HT )
|
||||
// TEXT = <any OCTET except CTLs, but including LWS>
|
||||
// separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <">
|
||||
// | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT
|
||||
// token = 1*<any CHAR except CTLs or separators>
|
||||
// qdtext = <any TEXT except <">>
|
||||
|
||||
for c := 0; c < 256; c++ {
|
||||
var t octetType
|
||||
isCtl := c <= 31 || c == 127
|
||||
isChar := 0 <= c && c <= 127
|
||||
isSeparator := strings.ContainsRune(" \t\"(),/:;<=>?@[]\\{}", rune(c))
|
||||
if strings.ContainsRune(" \t\r\n", rune(c)) {
|
||||
t |= isSpace
|
||||
}
|
||||
if isChar && !isCtl && !isSeparator {
|
||||
t |= isToken
|
||||
}
|
||||
octetTypes[c] = t
|
||||
}
|
||||
}
|
||||
+653
@@ -0,0 +1,653 @@
|
||||
package brotli
|
||||
|
||||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Utilities for building Huffman decoding tables. */
|
||||
|
||||
const huffmanMaxCodeLength = 15
|
||||
|
||||
/* Maximum possible Huffman table size for an alphabet size of (index * 32),
|
||||
max code length 15 and root table bits 8. */
|
||||
var kMaxHuffmanTableSize = []uint16{
|
||||
256,
|
||||
402,
|
||||
436,
|
||||
468,
|
||||
500,
|
||||
534,
|
||||
566,
|
||||
598,
|
||||
630,
|
||||
662,
|
||||
694,
|
||||
726,
|
||||
758,
|
||||
790,
|
||||
822,
|
||||
854,
|
||||
886,
|
||||
920,
|
||||
952,
|
||||
984,
|
||||
1016,
|
||||
1048,
|
||||
1080,
|
||||
1112,
|
||||
1144,
|
||||
1176,
|
||||
1208,
|
||||
1240,
|
||||
1272,
|
||||
1304,
|
||||
1336,
|
||||
1368,
|
||||
1400,
|
||||
1432,
|
||||
1464,
|
||||
1496,
|
||||
1528,
|
||||
}
|
||||
|
||||
/* BROTLI_NUM_BLOCK_LEN_SYMBOLS == 26 */
|
||||
const huffmanMaxSize26 = 396
|
||||
|
||||
/* BROTLI_MAX_BLOCK_TYPE_SYMBOLS == 258 */
|
||||
const huffmanMaxSize258 = 632
|
||||
|
||||
/* BROTLI_MAX_CONTEXT_MAP_SYMBOLS == 272 */
|
||||
const huffmanMaxSize272 = 646
|
||||
|
||||
const huffmanMaxCodeLengthCodeLength = 5
|
||||
|
||||
/* Do not create this struct directly - use the ConstructHuffmanCode
|
||||
* constructor below! */
|
||||
type huffmanCode struct {
|
||||
bits byte
|
||||
value uint16
|
||||
}
|
||||
|
||||
func constructHuffmanCode(bits byte, value uint16) huffmanCode {
|
||||
var h huffmanCode
|
||||
h.bits = bits
|
||||
h.value = value
|
||||
return h
|
||||
}
|
||||
|
||||
/* Builds Huffman lookup table assuming code lengths are in symbol order. */
|
||||
|
||||
/* Builds Huffman lookup table assuming code lengths are in symbol order.
|
||||
Returns size of resulting table. */
|
||||
|
||||
/* Builds a simple Huffman table. The |num_symbols| parameter is to be
|
||||
interpreted as follows: 0 means 1 symbol, 1 means 2 symbols,
|
||||
2 means 3 symbols, 3 means 4 symbols with lengths [2, 2, 2, 2],
|
||||
4 means 4 symbols with lengths [1, 2, 3, 3]. */
|
||||
|
||||
/* Contains a collection of Huffman trees with the same alphabet size. */
|
||||
/* max_symbol is needed due to simple codes since log2(alphabet_size) could be
|
||||
greater than log2(max_symbol). */
|
||||
type huffmanTreeGroup struct {
|
||||
htrees [][]huffmanCode
|
||||
codes []huffmanCode
|
||||
alphabet_size uint16
|
||||
max_symbol uint16
|
||||
num_htrees uint16
|
||||
}
|
||||
|
||||
const reverseBitsMax = 8
|
||||
|
||||
const reverseBitsBase = 0
|
||||
|
||||
var kReverseBits = [1 << reverseBitsMax]byte{
|
||||
0x00,
|
||||
0x80,
|
||||
0x40,
|
||||
0xC0,
|
||||
0x20,
|
||||
0xA0,
|
||||
0x60,
|
||||
0xE0,
|
||||
0x10,
|
||||
0x90,
|
||||
0x50,
|
||||
0xD0,
|
||||
0x30,
|
||||
0xB0,
|
||||
0x70,
|
||||
0xF0,
|
||||
0x08,
|
||||
0x88,
|
||||
0x48,
|
||||
0xC8,
|
||||
0x28,
|
||||
0xA8,
|
||||
0x68,
|
||||
0xE8,
|
||||
0x18,
|
||||
0x98,
|
||||
0x58,
|
||||
0xD8,
|
||||
0x38,
|
||||
0xB8,
|
||||
0x78,
|
||||
0xF8,
|
||||
0x04,
|
||||
0x84,
|
||||
0x44,
|
||||
0xC4,
|
||||
0x24,
|
||||
0xA4,
|
||||
0x64,
|
||||
0xE4,
|
||||
0x14,
|
||||
0x94,
|
||||
0x54,
|
||||
0xD4,
|
||||
0x34,
|
||||
0xB4,
|
||||
0x74,
|
||||
0xF4,
|
||||
0x0C,
|
||||
0x8C,
|
||||
0x4C,
|
||||
0xCC,
|
||||
0x2C,
|
||||
0xAC,
|
||||
0x6C,
|
||||
0xEC,
|
||||
0x1C,
|
||||
0x9C,
|
||||
0x5C,
|
||||
0xDC,
|
||||
0x3C,
|
||||
0xBC,
|
||||
0x7C,
|
||||
0xFC,
|
||||
0x02,
|
||||
0x82,
|
||||
0x42,
|
||||
0xC2,
|
||||
0x22,
|
||||
0xA2,
|
||||
0x62,
|
||||
0xE2,
|
||||
0x12,
|
||||
0x92,
|
||||
0x52,
|
||||
0xD2,
|
||||
0x32,
|
||||
0xB2,
|
||||
0x72,
|
||||
0xF2,
|
||||
0x0A,
|
||||
0x8A,
|
||||
0x4A,
|
||||
0xCA,
|
||||
0x2A,
|
||||
0xAA,
|
||||
0x6A,
|
||||
0xEA,
|
||||
0x1A,
|
||||
0x9A,
|
||||
0x5A,
|
||||
0xDA,
|
||||
0x3A,
|
||||
0xBA,
|
||||
0x7A,
|
||||
0xFA,
|
||||
0x06,
|
||||
0x86,
|
||||
0x46,
|
||||
0xC6,
|
||||
0x26,
|
||||
0xA6,
|
||||
0x66,
|
||||
0xE6,
|
||||
0x16,
|
||||
0x96,
|
||||
0x56,
|
||||
0xD6,
|
||||
0x36,
|
||||
0xB6,
|
||||
0x76,
|
||||
0xF6,
|
||||
0x0E,
|
||||
0x8E,
|
||||
0x4E,
|
||||
0xCE,
|
||||
0x2E,
|
||||
0xAE,
|
||||
0x6E,
|
||||
0xEE,
|
||||
0x1E,
|
||||
0x9E,
|
||||
0x5E,
|
||||
0xDE,
|
||||
0x3E,
|
||||
0xBE,
|
||||
0x7E,
|
||||
0xFE,
|
||||
0x01,
|
||||
0x81,
|
||||
0x41,
|
||||
0xC1,
|
||||
0x21,
|
||||
0xA1,
|
||||
0x61,
|
||||
0xE1,
|
||||
0x11,
|
||||
0x91,
|
||||
0x51,
|
||||
0xD1,
|
||||
0x31,
|
||||
0xB1,
|
||||
0x71,
|
||||
0xF1,
|
||||
0x09,
|
||||
0x89,
|
||||
0x49,
|
||||
0xC9,
|
||||
0x29,
|
||||
0xA9,
|
||||
0x69,
|
||||
0xE9,
|
||||
0x19,
|
||||
0x99,
|
||||
0x59,
|
||||
0xD9,
|
||||
0x39,
|
||||
0xB9,
|
||||
0x79,
|
||||
0xF9,
|
||||
0x05,
|
||||
0x85,
|
||||
0x45,
|
||||
0xC5,
|
||||
0x25,
|
||||
0xA5,
|
||||
0x65,
|
||||
0xE5,
|
||||
0x15,
|
||||
0x95,
|
||||
0x55,
|
||||
0xD5,
|
||||
0x35,
|
||||
0xB5,
|
||||
0x75,
|
||||
0xF5,
|
||||
0x0D,
|
||||
0x8D,
|
||||
0x4D,
|
||||
0xCD,
|
||||
0x2D,
|
||||
0xAD,
|
||||
0x6D,
|
||||
0xED,
|
||||
0x1D,
|
||||
0x9D,
|
||||
0x5D,
|
||||
0xDD,
|
||||
0x3D,
|
||||
0xBD,
|
||||
0x7D,
|
||||
0xFD,
|
||||
0x03,
|
||||
0x83,
|
||||
0x43,
|
||||
0xC3,
|
||||
0x23,
|
||||
0xA3,
|
||||
0x63,
|
||||
0xE3,
|
||||
0x13,
|
||||
0x93,
|
||||
0x53,
|
||||
0xD3,
|
||||
0x33,
|
||||
0xB3,
|
||||
0x73,
|
||||
0xF3,
|
||||
0x0B,
|
||||
0x8B,
|
||||
0x4B,
|
||||
0xCB,
|
||||
0x2B,
|
||||
0xAB,
|
||||
0x6B,
|
||||
0xEB,
|
||||
0x1B,
|
||||
0x9B,
|
||||
0x5B,
|
||||
0xDB,
|
||||
0x3B,
|
||||
0xBB,
|
||||
0x7B,
|
||||
0xFB,
|
||||
0x07,
|
||||
0x87,
|
||||
0x47,
|
||||
0xC7,
|
||||
0x27,
|
||||
0xA7,
|
||||
0x67,
|
||||
0xE7,
|
||||
0x17,
|
||||
0x97,
|
||||
0x57,
|
||||
0xD7,
|
||||
0x37,
|
||||
0xB7,
|
||||
0x77,
|
||||
0xF7,
|
||||
0x0F,
|
||||
0x8F,
|
||||
0x4F,
|
||||
0xCF,
|
||||
0x2F,
|
||||
0xAF,
|
||||
0x6F,
|
||||
0xEF,
|
||||
0x1F,
|
||||
0x9F,
|
||||
0x5F,
|
||||
0xDF,
|
||||
0x3F,
|
||||
0xBF,
|
||||
0x7F,
|
||||
0xFF,
|
||||
}
|
||||
|
||||
const reverseBitsLowest = (uint64(1) << (reverseBitsMax - 1 + reverseBitsBase))
|
||||
|
||||
/* Returns reverse(num >> BROTLI_REVERSE_BITS_BASE, BROTLI_REVERSE_BITS_MAX),
|
||||
where reverse(value, len) is the bit-wise reversal of the len least
|
||||
significant bits of value. */
|
||||
func reverseBits8(num uint64) uint64 {
|
||||
return uint64(kReverseBits[num])
|
||||
}
|
||||
|
||||
/* Stores code in table[0], table[step], table[2*step], ..., table[end] */
|
||||
/* Assumes that end is an integer multiple of step */
|
||||
func replicateValue(table []huffmanCode, step int, end int, code huffmanCode) {
|
||||
for {
|
||||
end -= step
|
||||
table[end] = code
|
||||
if end <= 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Returns the table width of the next 2nd level table. |count| is the histogram
|
||||
of bit lengths for the remaining symbols, |len| is the code length of the
|
||||
next processed symbol. */
|
||||
func nextTableBitSize(count []uint16, len int, root_bits int) int {
|
||||
var left int = 1 << uint(len-root_bits)
|
||||
for len < huffmanMaxCodeLength {
|
||||
left -= int(count[len])
|
||||
if left <= 0 {
|
||||
break
|
||||
}
|
||||
len++
|
||||
left <<= 1
|
||||
}
|
||||
|
||||
return len - root_bits
|
||||
}
|
||||
|
||||
func buildCodeLengthsHuffmanTable(table []huffmanCode, code_lengths []byte, count []uint16) {
|
||||
var code huffmanCode /* current table entry */ /* symbol index in original or sorted table */ /* prefix code */ /* prefix code addend */ /* step size to replicate values in current table */ /* size of current table */ /* symbols sorted by code length */
|
||||
var symbol int
|
||||
var key uint64
|
||||
var key_step uint64
|
||||
var step int
|
||||
var table_size int
|
||||
var sorted [codeLengthCodes]int
|
||||
var offset [huffmanMaxCodeLengthCodeLength + 1]int
|
||||
var bits int
|
||||
var bits_count int
|
||||
/* offsets in sorted table for each length */
|
||||
assert(huffmanMaxCodeLengthCodeLength <= reverseBitsMax)
|
||||
|
||||
/* Generate offsets into sorted symbol table by code length. */
|
||||
symbol = -1
|
||||
|
||||
bits = 1
|
||||
var i int
|
||||
for i = 0; i < huffmanMaxCodeLengthCodeLength; i++ {
|
||||
symbol += int(count[bits])
|
||||
offset[bits] = symbol
|
||||
bits++
|
||||
}
|
||||
|
||||
/* Symbols with code length 0 are placed after all other symbols. */
|
||||
offset[0] = codeLengthCodes - 1
|
||||
|
||||
/* Sort symbols by length, by symbol order within each length. */
|
||||
symbol = codeLengthCodes
|
||||
|
||||
for {
|
||||
var i int
|
||||
for i = 0; i < 6; i++ {
|
||||
symbol--
|
||||
sorted[offset[code_lengths[symbol]]] = symbol
|
||||
offset[code_lengths[symbol]]--
|
||||
}
|
||||
if symbol == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
table_size = 1 << huffmanMaxCodeLengthCodeLength
|
||||
|
||||
/* Special case: all symbols but one have 0 code length. */
|
||||
if offset[0] == 0 {
|
||||
code = constructHuffmanCode(0, uint16(sorted[0]))
|
||||
for key = 0; key < uint64(table_size); key++ {
|
||||
table[key] = code
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
/* Fill in table. */
|
||||
key = 0
|
||||
|
||||
key_step = reverseBitsLowest
|
||||
symbol = 0
|
||||
bits = 1
|
||||
step = 2
|
||||
for {
|
||||
for bits_count = int(count[bits]); bits_count != 0; bits_count-- {
|
||||
code = constructHuffmanCode(byte(bits), uint16(sorted[symbol]))
|
||||
symbol++
|
||||
replicateValue(table[reverseBits8(key):], step, table_size, code)
|
||||
key += key_step
|
||||
}
|
||||
|
||||
step <<= 1
|
||||
key_step >>= 1
|
||||
bits++
|
||||
if bits > huffmanMaxCodeLengthCodeLength {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildHuffmanTable(root_table []huffmanCode, root_bits int, symbol_lists symbolList, count []uint16) uint32 {
|
||||
var code huffmanCode /* current table entry */ /* next available space in table */ /* current code length */ /* symbol index in original or sorted table */ /* prefix code */ /* prefix code addend */ /* 2nd level table prefix code */ /* 2nd level table prefix code addend */ /* step size to replicate values in current table */ /* key length of current table */ /* size of current table */ /* sum of root table size and 2nd level table sizes */
|
||||
var table []huffmanCode
|
||||
var len int
|
||||
var symbol int
|
||||
var key uint64
|
||||
var key_step uint64
|
||||
var sub_key uint64
|
||||
var sub_key_step uint64
|
||||
var step int
|
||||
var table_bits int
|
||||
var table_size int
|
||||
var total_size int
|
||||
var max_length int = -1
|
||||
var bits int
|
||||
var bits_count int
|
||||
|
||||
assert(root_bits <= reverseBitsMax)
|
||||
assert(huffmanMaxCodeLength-root_bits <= reverseBitsMax)
|
||||
|
||||
for symbolListGet(symbol_lists, max_length) == 0xFFFF {
|
||||
max_length--
|
||||
}
|
||||
max_length += huffmanMaxCodeLength + 1
|
||||
|
||||
table = root_table
|
||||
table_bits = root_bits
|
||||
table_size = 1 << uint(table_bits)
|
||||
total_size = table_size
|
||||
|
||||
/* Fill in the root table. Reduce the table size to if possible,
|
||||
and create the repetitions by memcpy. */
|
||||
if table_bits > max_length {
|
||||
table_bits = max_length
|
||||
table_size = 1 << uint(table_bits)
|
||||
}
|
||||
|
||||
key = 0
|
||||
key_step = reverseBitsLowest
|
||||
bits = 1
|
||||
step = 2
|
||||
for {
|
||||
symbol = bits - (huffmanMaxCodeLength + 1)
|
||||
for bits_count = int(count[bits]); bits_count != 0; bits_count-- {
|
||||
symbol = int(symbolListGet(symbol_lists, symbol))
|
||||
code = constructHuffmanCode(byte(bits), uint16(symbol))
|
||||
replicateValue(table[reverseBits8(key):], step, table_size, code)
|
||||
key += key_step
|
||||
}
|
||||
|
||||
step <<= 1
|
||||
key_step >>= 1
|
||||
bits++
|
||||
if bits > table_bits {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
/* If root_bits != table_bits then replicate to fill the remaining slots. */
|
||||
for total_size != table_size {
|
||||
copy(table[table_size:], table[:uint(table_size)])
|
||||
table_size <<= 1
|
||||
}
|
||||
|
||||
/* Fill in 2nd level tables and add pointers to root table. */
|
||||
key_step = reverseBitsLowest >> uint(root_bits-1)
|
||||
|
||||
sub_key = reverseBitsLowest << 1
|
||||
sub_key_step = reverseBitsLowest
|
||||
len = root_bits + 1
|
||||
step = 2
|
||||
for ; len <= max_length; len++ {
|
||||
symbol = len - (huffmanMaxCodeLength + 1)
|
||||
for ; count[len] != 0; count[len]-- {
|
||||
if sub_key == reverseBitsLowest<<1 {
|
||||
table = table[table_size:]
|
||||
table_bits = nextTableBitSize(count, int(len), root_bits)
|
||||
table_size = 1 << uint(table_bits)
|
||||
total_size += table_size
|
||||
sub_key = reverseBits8(key)
|
||||
key += key_step
|
||||
root_table[sub_key] = constructHuffmanCode(byte(table_bits+root_bits), uint16(uint64(uint(-cap(table)+cap(root_table)))-sub_key))
|
||||
sub_key = 0
|
||||
}
|
||||
|
||||
symbol = int(symbolListGet(symbol_lists, symbol))
|
||||
code = constructHuffmanCode(byte(len-root_bits), uint16(symbol))
|
||||
replicateValue(table[reverseBits8(sub_key):], step, table_size, code)
|
||||
sub_key += sub_key_step
|
||||
}
|
||||
|
||||
step <<= 1
|
||||
sub_key_step >>= 1
|
||||
}
|
||||
|
||||
return uint32(total_size)
|
||||
}
|
||||
|
||||
func buildSimpleHuffmanTable(table []huffmanCode, root_bits int, val []uint16, num_symbols uint32) uint32 {
|
||||
var table_size uint32 = 1
|
||||
var goal_size uint32 = 1 << uint(root_bits)
|
||||
switch num_symbols {
|
||||
case 0:
|
||||
table[0] = constructHuffmanCode(0, val[0])
|
||||
|
||||
case 1:
|
||||
if val[1] > val[0] {
|
||||
table[0] = constructHuffmanCode(1, val[0])
|
||||
table[1] = constructHuffmanCode(1, val[1])
|
||||
} else {
|
||||
table[0] = constructHuffmanCode(1, val[1])
|
||||
table[1] = constructHuffmanCode(1, val[0])
|
||||
}
|
||||
|
||||
table_size = 2
|
||||
|
||||
case 2:
|
||||
table[0] = constructHuffmanCode(1, val[0])
|
||||
table[2] = constructHuffmanCode(1, val[0])
|
||||
if val[2] > val[1] {
|
||||
table[1] = constructHuffmanCode(2, val[1])
|
||||
table[3] = constructHuffmanCode(2, val[2])
|
||||
} else {
|
||||
table[1] = constructHuffmanCode(2, val[2])
|
||||
table[3] = constructHuffmanCode(2, val[1])
|
||||
}
|
||||
|
||||
table_size = 4
|
||||
|
||||
case 3:
|
||||
var i int
|
||||
var k int
|
||||
for i = 0; i < 3; i++ {
|
||||
for k = i + 1; k < 4; k++ {
|
||||
if val[k] < val[i] {
|
||||
var t uint16 = val[k]
|
||||
val[k] = val[i]
|
||||
val[i] = t
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
table[0] = constructHuffmanCode(2, val[0])
|
||||
table[2] = constructHuffmanCode(2, val[1])
|
||||
table[1] = constructHuffmanCode(2, val[2])
|
||||
table[3] = constructHuffmanCode(2, val[3])
|
||||
table_size = 4
|
||||
|
||||
case 4:
|
||||
if val[3] < val[2] {
|
||||
var t uint16 = val[3]
|
||||
val[3] = val[2]
|
||||
val[2] = t
|
||||
}
|
||||
|
||||
table[0] = constructHuffmanCode(1, val[0])
|
||||
table[1] = constructHuffmanCode(2, val[1])
|
||||
table[2] = constructHuffmanCode(1, val[0])
|
||||
table[3] = constructHuffmanCode(3, val[2])
|
||||
table[4] = constructHuffmanCode(1, val[0])
|
||||
table[5] = constructHuffmanCode(2, val[1])
|
||||
table[6] = constructHuffmanCode(1, val[0])
|
||||
table[7] = constructHuffmanCode(3, val[3])
|
||||
table_size = 8
|
||||
}
|
||||
|
||||
for table_size != goal_size {
|
||||
copy(table[table_size:], table[:uint(table_size)])
|
||||
table_size <<= 1
|
||||
}
|
||||
|
||||
return goal_size
|
||||
}
|
||||
+182
@@ -0,0 +1,182 @@
|
||||
package brotli
|
||||
|
||||
func utf8Position(last uint, c uint, clamp uint) uint {
|
||||
if c < 128 {
|
||||
return 0 /* Next one is the 'Byte 1' again. */
|
||||
} else if c >= 192 { /* Next one is the 'Byte 2' of utf-8 encoding. */
|
||||
return brotli_min_size_t(1, clamp)
|
||||
} else {
|
||||
/* Let's decide over the last byte if this ends the sequence. */
|
||||
if last < 0xE0 {
|
||||
return 0 /* Completed two or three byte coding. */ /* Next one is the 'Byte 3' of utf-8 encoding. */
|
||||
} else {
|
||||
return brotli_min_size_t(2, clamp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decideMultiByteStatsLevel(pos uint, len uint, mask uint, data []byte) uint {
|
||||
var counts = [3]uint{0} /* should be 2, but 1 compresses better. */
|
||||
var max_utf8 uint = 1
|
||||
var last_c uint = 0
|
||||
var i uint
|
||||
for i = 0; i < len; i++ {
|
||||
var c uint = uint(data[(pos+i)&mask])
|
||||
counts[utf8Position(last_c, c, 2)]++
|
||||
last_c = c
|
||||
}
|
||||
|
||||
if counts[2] < 500 {
|
||||
max_utf8 = 1
|
||||
}
|
||||
|
||||
if counts[1]+counts[2] < 25 {
|
||||
max_utf8 = 0
|
||||
}
|
||||
|
||||
return max_utf8
|
||||
}
|
||||
|
||||
func estimateBitCostsForLiteralsUTF8(pos uint, len uint, mask uint, data []byte, cost []float32) {
|
||||
var max_utf8 uint = decideMultiByteStatsLevel(pos, uint(len), mask, data)
|
||||
/* Bootstrap histograms. */
|
||||
var histogram = [3][256]uint{[256]uint{0}}
|
||||
var window_half uint = 495
|
||||
var in_window uint = brotli_min_size_t(window_half, uint(len))
|
||||
var in_window_utf8 = [3]uint{0}
|
||||
/* max_utf8 is 0 (normal ASCII single byte modeling),
|
||||
1 (for 2-byte UTF-8 modeling), or 2 (for 3-byte UTF-8 modeling). */
|
||||
|
||||
var i uint
|
||||
{
|
||||
var last_c uint = 0
|
||||
var utf8_pos uint = 0
|
||||
for i = 0; i < in_window; i++ {
|
||||
var c uint = uint(data[(pos+i)&mask])
|
||||
histogram[utf8_pos][c]++
|
||||
in_window_utf8[utf8_pos]++
|
||||
utf8_pos = utf8Position(last_c, c, max_utf8)
|
||||
last_c = c
|
||||
}
|
||||
}
|
||||
|
||||
/* Compute bit costs with sliding window. */
|
||||
for i = 0; i < len; i++ {
|
||||
if i >= window_half {
|
||||
var c uint
|
||||
var last_c uint
|
||||
if i < window_half+1 {
|
||||
c = 0
|
||||
} else {
|
||||
c = uint(data[(pos+i-window_half-1)&mask])
|
||||
}
|
||||
if i < window_half+2 {
|
||||
last_c = 0
|
||||
} else {
|
||||
last_c = uint(data[(pos+i-window_half-2)&mask])
|
||||
}
|
||||
/* Remove a byte in the past. */
|
||||
|
||||
var utf8_pos2 uint = utf8Position(last_c, c, max_utf8)
|
||||
histogram[utf8_pos2][data[(pos+i-window_half)&mask]]--
|
||||
in_window_utf8[utf8_pos2]--
|
||||
}
|
||||
|
||||
if i+window_half < len {
|
||||
var c uint = uint(data[(pos+i+window_half-1)&mask])
|
||||
var last_c uint = uint(data[(pos+i+window_half-2)&mask])
|
||||
/* Add a byte in the future. */
|
||||
|
||||
var utf8_pos2 uint = utf8Position(last_c, c, max_utf8)
|
||||
histogram[utf8_pos2][data[(pos+i+window_half)&mask]]++
|
||||
in_window_utf8[utf8_pos2]++
|
||||
}
|
||||
{
|
||||
var c uint
|
||||
var last_c uint
|
||||
if i < 1 {
|
||||
c = 0
|
||||
} else {
|
||||
c = uint(data[(pos+i-1)&mask])
|
||||
}
|
||||
if i < 2 {
|
||||
last_c = 0
|
||||
} else {
|
||||
last_c = uint(data[(pos+i-2)&mask])
|
||||
}
|
||||
var utf8_pos uint = utf8Position(last_c, c, max_utf8)
|
||||
var masked_pos uint = (pos + i) & mask
|
||||
var histo uint = histogram[utf8_pos][data[masked_pos]]
|
||||
var lit_cost float64
|
||||
if histo == 0 {
|
||||
histo = 1
|
||||
}
|
||||
|
||||
lit_cost = fastLog2(in_window_utf8[utf8_pos]) - fastLog2(histo)
|
||||
lit_cost += 0.02905
|
||||
if lit_cost < 1.0 {
|
||||
lit_cost *= 0.5
|
||||
lit_cost += 0.5
|
||||
}
|
||||
|
||||
/* Make the first bytes more expensive -- seems to help, not sure why.
|
||||
Perhaps because the entropy source is changing its properties
|
||||
rapidly in the beginning of the file, perhaps because the beginning
|
||||
of the data is a statistical "anomaly". */
|
||||
if i < 2000 {
|
||||
lit_cost += 0.7 - (float64(2000-i) / 2000.0 * 0.35)
|
||||
}
|
||||
|
||||
cost[i] = float32(lit_cost)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func estimateBitCostsForLiterals(pos uint, len uint, mask uint, data []byte, cost []float32) {
|
||||
if isMostlyUTF8(data, pos, mask, uint(len), kMinUTF8Ratio) {
|
||||
estimateBitCostsForLiteralsUTF8(pos, uint(len), mask, data, cost)
|
||||
return
|
||||
} else {
|
||||
var histogram = [256]uint{0}
|
||||
var window_half uint = 2000
|
||||
var in_window uint = brotli_min_size_t(window_half, uint(len))
|
||||
var i uint
|
||||
/* Bootstrap histogram. */
|
||||
for i = 0; i < in_window; i++ {
|
||||
histogram[data[(pos+i)&mask]]++
|
||||
}
|
||||
|
||||
/* Compute bit costs with sliding window. */
|
||||
for i = 0; i < len; i++ {
|
||||
var histo uint
|
||||
if i >= window_half {
|
||||
/* Remove a byte in the past. */
|
||||
histogram[data[(pos+i-window_half)&mask]]--
|
||||
|
||||
in_window--
|
||||
}
|
||||
|
||||
if i+window_half < len {
|
||||
/* Add a byte in the future. */
|
||||
histogram[data[(pos+i+window_half)&mask]]++
|
||||
|
||||
in_window++
|
||||
}
|
||||
|
||||
histo = histogram[data[(pos+i)&mask]]
|
||||
if histo == 0 {
|
||||
histo = 1
|
||||
}
|
||||
{
|
||||
var lit_cost float64 = fastLog2(in_window) - fastLog2(histo)
|
||||
lit_cost += 0.029
|
||||
if lit_cost < 1.0 {
|
||||
lit_cost *= 0.5
|
||||
lit_cost += 0.5
|
||||
}
|
||||
|
||||
cost[i] = float32(lit_cost)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
+66
@@ -0,0 +1,66 @@
|
||||
package brotli
|
||||
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/*
|
||||
Dynamically grows array capacity to at least the requested size
|
||||
T: data type
|
||||
A: array
|
||||
C: capacity
|
||||
R: requested size
|
||||
*/
|
||||
func brotli_ensure_capacity_uint8_t(a *[]byte, c *uint, r uint) {
|
||||
if *c < r {
|
||||
var new_size uint = *c
|
||||
if new_size == 0 {
|
||||
new_size = r
|
||||
}
|
||||
|
||||
for new_size < r {
|
||||
new_size *= 2
|
||||
}
|
||||
|
||||
if cap(*a) < int(new_size) {
|
||||
var new_array []byte = make([]byte, new_size)
|
||||
if *c != 0 {
|
||||
copy(new_array, (*a)[:*c])
|
||||
}
|
||||
|
||||
*a = new_array
|
||||
} else {
|
||||
*a = (*a)[:new_size]
|
||||
}
|
||||
|
||||
*c = new_size
|
||||
}
|
||||
}
|
||||
|
||||
func brotli_ensure_capacity_uint32_t(a *[]uint32, c *uint, r uint) {
|
||||
var new_array []uint32
|
||||
if *c < r {
|
||||
var new_size uint = *c
|
||||
if new_size == 0 {
|
||||
new_size = r
|
||||
}
|
||||
|
||||
for new_size < r {
|
||||
new_size *= 2
|
||||
}
|
||||
|
||||
if cap(*a) < int(new_size) {
|
||||
new_array = make([]uint32, new_size)
|
||||
if *c != 0 {
|
||||
copy(new_array, (*a)[:*c])
|
||||
}
|
||||
|
||||
*a = new_array
|
||||
} else {
|
||||
*a = (*a)[:new_size]
|
||||
}
|
||||
*c = new_size
|
||||
}
|
||||
}
|
||||
+574
@@ -0,0 +1,574 @@
|
||||
package brotli
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
/* Copyright 2014 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Algorithms for distributing the literals and commands of a metablock between
|
||||
block types and contexts. */
|
||||
|
||||
type metaBlockSplit struct {
|
||||
literal_split blockSplit
|
||||
command_split blockSplit
|
||||
distance_split blockSplit
|
||||
literal_context_map []uint32
|
||||
literal_context_map_size uint
|
||||
distance_context_map []uint32
|
||||
distance_context_map_size uint
|
||||
literal_histograms []histogramLiteral
|
||||
literal_histograms_size uint
|
||||
command_histograms []histogramCommand
|
||||
command_histograms_size uint
|
||||
distance_histograms []histogramDistance
|
||||
distance_histograms_size uint
|
||||
}
|
||||
|
||||
var metaBlockPool sync.Pool
|
||||
|
||||
func getMetaBlockSplit() *metaBlockSplit {
|
||||
mb, _ := metaBlockPool.Get().(*metaBlockSplit)
|
||||
|
||||
if mb == nil {
|
||||
mb = &metaBlockSplit{}
|
||||
} else {
|
||||
initBlockSplit(&mb.literal_split)
|
||||
initBlockSplit(&mb.command_split)
|
||||
initBlockSplit(&mb.distance_split)
|
||||
mb.literal_context_map = mb.literal_context_map[:0]
|
||||
mb.literal_context_map_size = 0
|
||||
mb.distance_context_map = mb.distance_context_map[:0]
|
||||
mb.distance_context_map_size = 0
|
||||
mb.literal_histograms = mb.literal_histograms[:0]
|
||||
mb.command_histograms = mb.command_histograms[:0]
|
||||
mb.distance_histograms = mb.distance_histograms[:0]
|
||||
}
|
||||
return mb
|
||||
}
|
||||
|
||||
func freeMetaBlockSplit(mb *metaBlockSplit) {
|
||||
metaBlockPool.Put(mb)
|
||||
}
|
||||
|
||||
func initDistanceParams(params *encoderParams, npostfix uint32, ndirect uint32) {
|
||||
var dist_params *distanceParams = ¶ms.dist
|
||||
var alphabet_size uint32
|
||||
var max_distance uint32
|
||||
|
||||
dist_params.distance_postfix_bits = npostfix
|
||||
dist_params.num_direct_distance_codes = ndirect
|
||||
|
||||
alphabet_size = uint32(distanceAlphabetSize(uint(npostfix), uint(ndirect), maxDistanceBits))
|
||||
max_distance = ndirect + (1 << (maxDistanceBits + npostfix + 2)) - (1 << (npostfix + 2))
|
||||
|
||||
if params.large_window {
|
||||
var bound = [maxNpostfix + 1]uint32{0, 4, 12, 28}
|
||||
var postfix uint32 = 1 << npostfix
|
||||
alphabet_size = uint32(distanceAlphabetSize(uint(npostfix), uint(ndirect), largeMaxDistanceBits))
|
||||
|
||||
/* The maximum distance is set so that no distance symbol used can encode
|
||||
a distance larger than BROTLI_MAX_ALLOWED_DISTANCE with all
|
||||
its extra bits set. */
|
||||
if ndirect < bound[npostfix] {
|
||||
max_distance = maxAllowedDistance - (bound[npostfix] - ndirect)
|
||||
} else if ndirect >= bound[npostfix]+postfix {
|
||||
max_distance = (3 << 29) - 4 + (ndirect - bound[npostfix])
|
||||
} else {
|
||||
max_distance = maxAllowedDistance
|
||||
}
|
||||
}
|
||||
|
||||
dist_params.alphabet_size = alphabet_size
|
||||
dist_params.max_distance = uint(max_distance)
|
||||
}
|
||||
|
||||
func recomputeDistancePrefixes(cmds []command, orig_params *distanceParams, new_params *distanceParams) {
|
||||
if orig_params.distance_postfix_bits == new_params.distance_postfix_bits && orig_params.num_direct_distance_codes == new_params.num_direct_distance_codes {
|
||||
return
|
||||
}
|
||||
|
||||
for i := range cmds {
|
||||
var cmd *command = &cmds[i]
|
||||
if commandCopyLen(cmd) != 0 && cmd.cmd_prefix_ >= 128 {
|
||||
prefixEncodeCopyDistance(uint(commandRestoreDistanceCode(cmd, orig_params)), uint(new_params.num_direct_distance_codes), uint(new_params.distance_postfix_bits), &cmd.dist_prefix_, &cmd.dist_extra_)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func computeDistanceCost(cmds []command, orig_params *distanceParams, new_params *distanceParams, cost *float64) bool {
|
||||
var equal_params bool = false
|
||||
var dist_prefix uint16
|
||||
var dist_extra uint32
|
||||
var extra_bits float64 = 0.0
|
||||
var histo histogramDistance
|
||||
histogramClearDistance(&histo)
|
||||
|
||||
if orig_params.distance_postfix_bits == new_params.distance_postfix_bits && orig_params.num_direct_distance_codes == new_params.num_direct_distance_codes {
|
||||
equal_params = true
|
||||
}
|
||||
|
||||
for i := range cmds {
|
||||
cmd := &cmds[i]
|
||||
if commandCopyLen(cmd) != 0 && cmd.cmd_prefix_ >= 128 {
|
||||
if equal_params {
|
||||
dist_prefix = cmd.dist_prefix_
|
||||
} else {
|
||||
var distance uint32 = commandRestoreDistanceCode(cmd, orig_params)
|
||||
if distance > uint32(new_params.max_distance) {
|
||||
return false
|
||||
}
|
||||
|
||||
prefixEncodeCopyDistance(uint(distance), uint(new_params.num_direct_distance_codes), uint(new_params.distance_postfix_bits), &dist_prefix, &dist_extra)
|
||||
}
|
||||
|
||||
histogramAddDistance(&histo, uint(dist_prefix)&0x3FF)
|
||||
extra_bits += float64(dist_prefix >> 10)
|
||||
}
|
||||
}
|
||||
|
||||
*cost = populationCostDistance(&histo) + extra_bits
|
||||
return true
|
||||
}
|
||||
|
||||
var buildMetaBlock_kMaxNumberOfHistograms uint = 256
|
||||
|
||||
func buildMetaBlock(ringbuffer []byte, pos uint, mask uint, params *encoderParams, prev_byte byte, prev_byte2 byte, cmds []command, literal_context_mode int, mb *metaBlockSplit) {
|
||||
var distance_histograms []histogramDistance
|
||||
var literal_histograms []histogramLiteral
|
||||
var literal_context_modes []int = nil
|
||||
var literal_histograms_size uint
|
||||
var distance_histograms_size uint
|
||||
var i uint
|
||||
var literal_context_multiplier uint = 1
|
||||
var npostfix uint32
|
||||
var ndirect_msb uint32 = 0
|
||||
var check_orig bool = true
|
||||
var best_dist_cost float64 = 1e99
|
||||
var orig_params encoderParams = *params
|
||||
/* Histogram ids need to fit in one byte. */
|
||||
|
||||
var new_params encoderParams = *params
|
||||
|
||||
for npostfix = 0; npostfix <= maxNpostfix; npostfix++ {
|
||||
for ; ndirect_msb < 16; ndirect_msb++ {
|
||||
var ndirect uint32 = ndirect_msb << npostfix
|
||||
var skip bool
|
||||
var dist_cost float64
|
||||
initDistanceParams(&new_params, npostfix, ndirect)
|
||||
if npostfix == orig_params.dist.distance_postfix_bits && ndirect == orig_params.dist.num_direct_distance_codes {
|
||||
check_orig = false
|
||||
}
|
||||
|
||||
skip = !computeDistanceCost(cmds, &orig_params.dist, &new_params.dist, &dist_cost)
|
||||
if skip || (dist_cost > best_dist_cost) {
|
||||
break
|
||||
}
|
||||
|
||||
best_dist_cost = dist_cost
|
||||
params.dist = new_params.dist
|
||||
}
|
||||
|
||||
if ndirect_msb > 0 {
|
||||
ndirect_msb--
|
||||
}
|
||||
ndirect_msb /= 2
|
||||
}
|
||||
|
||||
if check_orig {
|
||||
var dist_cost float64
|
||||
computeDistanceCost(cmds, &orig_params.dist, &orig_params.dist, &dist_cost)
|
||||
if dist_cost < best_dist_cost {
|
||||
/* NB: currently unused; uncomment when more param tuning is added. */
|
||||
/* best_dist_cost = dist_cost; */
|
||||
params.dist = orig_params.dist
|
||||
}
|
||||
}
|
||||
|
||||
recomputeDistancePrefixes(cmds, &orig_params.dist, ¶ms.dist)
|
||||
|
||||
splitBlock(cmds, ringbuffer, pos, mask, params, &mb.literal_split, &mb.command_split, &mb.distance_split)
|
||||
|
||||
if !params.disable_literal_context_modeling {
|
||||
literal_context_multiplier = 1 << literalContextBits
|
||||
literal_context_modes = make([]int, (mb.literal_split.num_types))
|
||||
for i = 0; i < mb.literal_split.num_types; i++ {
|
||||
literal_context_modes[i] = literal_context_mode
|
||||
}
|
||||
}
|
||||
|
||||
literal_histograms_size = mb.literal_split.num_types * literal_context_multiplier
|
||||
literal_histograms = make([]histogramLiteral, literal_histograms_size)
|
||||
clearHistogramsLiteral(literal_histograms, literal_histograms_size)
|
||||
|
||||
distance_histograms_size = mb.distance_split.num_types << distanceContextBits
|
||||
distance_histograms = make([]histogramDistance, distance_histograms_size)
|
||||
clearHistogramsDistance(distance_histograms, distance_histograms_size)
|
||||
|
||||
mb.command_histograms_size = mb.command_split.num_types
|
||||
if cap(mb.command_histograms) < int(mb.command_histograms_size) {
|
||||
mb.command_histograms = make([]histogramCommand, (mb.command_histograms_size))
|
||||
} else {
|
||||
mb.command_histograms = mb.command_histograms[:mb.command_histograms_size]
|
||||
}
|
||||
clearHistogramsCommand(mb.command_histograms, mb.command_histograms_size)
|
||||
|
||||
buildHistogramsWithContext(cmds, &mb.literal_split, &mb.command_split, &mb.distance_split, ringbuffer, pos, mask, prev_byte, prev_byte2, literal_context_modes, literal_histograms, mb.command_histograms, distance_histograms)
|
||||
literal_context_modes = nil
|
||||
|
||||
mb.literal_context_map_size = mb.literal_split.num_types << literalContextBits
|
||||
if cap(mb.literal_context_map) < int(mb.literal_context_map_size) {
|
||||
mb.literal_context_map = make([]uint32, (mb.literal_context_map_size))
|
||||
} else {
|
||||
mb.literal_context_map = mb.literal_context_map[:mb.literal_context_map_size]
|
||||
}
|
||||
|
||||
mb.literal_histograms_size = mb.literal_context_map_size
|
||||
if cap(mb.literal_histograms) < int(mb.literal_histograms_size) {
|
||||
mb.literal_histograms = make([]histogramLiteral, (mb.literal_histograms_size))
|
||||
} else {
|
||||
mb.literal_histograms = mb.literal_histograms[:mb.literal_histograms_size]
|
||||
}
|
||||
|
||||
clusterHistogramsLiteral(literal_histograms, literal_histograms_size, buildMetaBlock_kMaxNumberOfHistograms, mb.literal_histograms, &mb.literal_histograms_size, mb.literal_context_map)
|
||||
literal_histograms = nil
|
||||
|
||||
if params.disable_literal_context_modeling {
|
||||
/* Distribute assignment to all contexts. */
|
||||
for i = mb.literal_split.num_types; i != 0; {
|
||||
var j uint = 0
|
||||
i--
|
||||
for ; j < 1<<literalContextBits; j++ {
|
||||
mb.literal_context_map[(i<<literalContextBits)+j] = mb.literal_context_map[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mb.distance_context_map_size = mb.distance_split.num_types << distanceContextBits
|
||||
if cap(mb.distance_context_map) < int(mb.distance_context_map_size) {
|
||||
mb.distance_context_map = make([]uint32, (mb.distance_context_map_size))
|
||||
} else {
|
||||
mb.distance_context_map = mb.distance_context_map[:mb.distance_context_map_size]
|
||||
}
|
||||
|
||||
mb.distance_histograms_size = mb.distance_context_map_size
|
||||
if cap(mb.distance_histograms) < int(mb.distance_histograms_size) {
|
||||
mb.distance_histograms = make([]histogramDistance, (mb.distance_histograms_size))
|
||||
} else {
|
||||
mb.distance_histograms = mb.distance_histograms[:mb.distance_histograms_size]
|
||||
}
|
||||
|
||||
clusterHistogramsDistance(distance_histograms, mb.distance_context_map_size, buildMetaBlock_kMaxNumberOfHistograms, mb.distance_histograms, &mb.distance_histograms_size, mb.distance_context_map)
|
||||
distance_histograms = nil
|
||||
}
|
||||
|
||||
const maxStaticContexts = 13
|
||||
|
||||
/* Greedy block splitter for one block category (literal, command or distance).
|
||||
Gathers histograms for all context buckets. */
|
||||
type contextBlockSplitter struct {
|
||||
alphabet_size_ uint
|
||||
num_contexts_ uint
|
||||
max_block_types_ uint
|
||||
min_block_size_ uint
|
||||
split_threshold_ float64
|
||||
num_blocks_ uint
|
||||
split_ *blockSplit
|
||||
histograms_ []histogramLiteral
|
||||
histograms_size_ *uint
|
||||
target_block_size_ uint
|
||||
block_size_ uint
|
||||
curr_histogram_ix_ uint
|
||||
last_histogram_ix_ [2]uint
|
||||
last_entropy_ [2 * maxStaticContexts]float64
|
||||
merge_last_count_ uint
|
||||
}
|
||||
|
||||
func initContextBlockSplitter(self *contextBlockSplitter, alphabet_size uint, num_contexts uint, min_block_size uint, split_threshold float64, num_symbols uint, split *blockSplit, histograms *[]histogramLiteral, histograms_size *uint) {
|
||||
var max_num_blocks uint = num_symbols/min_block_size + 1
|
||||
var max_num_types uint
|
||||
assert(num_contexts <= maxStaticContexts)
|
||||
|
||||
self.alphabet_size_ = alphabet_size
|
||||
self.num_contexts_ = num_contexts
|
||||
self.max_block_types_ = maxNumberOfBlockTypes / num_contexts
|
||||
self.min_block_size_ = min_block_size
|
||||
self.split_threshold_ = split_threshold
|
||||
self.num_blocks_ = 0
|
||||
self.split_ = split
|
||||
self.histograms_size_ = histograms_size
|
||||
self.target_block_size_ = min_block_size
|
||||
self.block_size_ = 0
|
||||
self.curr_histogram_ix_ = 0
|
||||
self.merge_last_count_ = 0
|
||||
|
||||
/* We have to allocate one more histogram than the maximum number of block
|
||||
types for the current histogram when the meta-block is too big. */
|
||||
max_num_types = brotli_min_size_t(max_num_blocks, self.max_block_types_+1)
|
||||
|
||||
brotli_ensure_capacity_uint8_t(&split.types, &split.types_alloc_size, max_num_blocks)
|
||||
brotli_ensure_capacity_uint32_t(&split.lengths, &split.lengths_alloc_size, max_num_blocks)
|
||||
split.num_blocks = max_num_blocks
|
||||
*histograms_size = max_num_types * num_contexts
|
||||
if histograms == nil || cap(*histograms) < int(*histograms_size) {
|
||||
*histograms = make([]histogramLiteral, (*histograms_size))
|
||||
} else {
|
||||
*histograms = (*histograms)[:*histograms_size]
|
||||
}
|
||||
self.histograms_ = *histograms
|
||||
|
||||
/* Clear only current histogram. */
|
||||
clearHistogramsLiteral(self.histograms_[0:], num_contexts)
|
||||
|
||||
self.last_histogram_ix_[1] = 0
|
||||
self.last_histogram_ix_[0] = self.last_histogram_ix_[1]
|
||||
}
|
||||
|
||||
/* Does either of three things:
|
||||
(1) emits the current block with a new block type;
|
||||
(2) emits the current block with the type of the second last block;
|
||||
(3) merges the current block with the last block. */
|
||||
func contextBlockSplitterFinishBlock(self *contextBlockSplitter, is_final bool) {
|
||||
var split *blockSplit = self.split_
|
||||
var num_contexts uint = self.num_contexts_
|
||||
var last_entropy []float64 = self.last_entropy_[:]
|
||||
var histograms []histogramLiteral = self.histograms_
|
||||
|
||||
if self.block_size_ < self.min_block_size_ {
|
||||
self.block_size_ = self.min_block_size_
|
||||
}
|
||||
|
||||
if self.num_blocks_ == 0 {
|
||||
var i uint
|
||||
|
||||
/* Create first block. */
|
||||
split.lengths[0] = uint32(self.block_size_)
|
||||
|
||||
split.types[0] = 0
|
||||
|
||||
for i = 0; i < num_contexts; i++ {
|
||||
last_entropy[i] = bitsEntropy(histograms[i].data_[:], self.alphabet_size_)
|
||||
last_entropy[num_contexts+i] = last_entropy[i]
|
||||
}
|
||||
|
||||
self.num_blocks_++
|
||||
split.num_types++
|
||||
self.curr_histogram_ix_ += num_contexts
|
||||
if self.curr_histogram_ix_ < *self.histograms_size_ {
|
||||
clearHistogramsLiteral(self.histograms_[self.curr_histogram_ix_:], self.num_contexts_)
|
||||
}
|
||||
|
||||
self.block_size_ = 0
|
||||
} else if self.block_size_ > 0 {
|
||||
var entropy [maxStaticContexts]float64
|
||||
var combined_histo []histogramLiteral = make([]histogramLiteral, (2 * num_contexts))
|
||||
var combined_entropy [2 * maxStaticContexts]float64
|
||||
var diff = [2]float64{0.0}
|
||||
/* Try merging the set of histograms for the current block type with the
|
||||
respective set of histograms for the last and second last block types.
|
||||
Decide over the split based on the total reduction of entropy across
|
||||
all contexts. */
|
||||
|
||||
var i uint
|
||||
for i = 0; i < num_contexts; i++ {
|
||||
var curr_histo_ix uint = self.curr_histogram_ix_ + i
|
||||
var j uint
|
||||
entropy[i] = bitsEntropy(histograms[curr_histo_ix].data_[:], self.alphabet_size_)
|
||||
for j = 0; j < 2; j++ {
|
||||
var jx uint = j*num_contexts + i
|
||||
var last_histogram_ix uint = self.last_histogram_ix_[j] + i
|
||||
combined_histo[jx] = histograms[curr_histo_ix]
|
||||
histogramAddHistogramLiteral(&combined_histo[jx], &histograms[last_histogram_ix])
|
||||
combined_entropy[jx] = bitsEntropy(combined_histo[jx].data_[0:], self.alphabet_size_)
|
||||
diff[j] += combined_entropy[jx] - entropy[i] - last_entropy[jx]
|
||||
}
|
||||
}
|
||||
|
||||
if split.num_types < self.max_block_types_ && diff[0] > self.split_threshold_ && diff[1] > self.split_threshold_ {
|
||||
/* Create new block. */
|
||||
split.lengths[self.num_blocks_] = uint32(self.block_size_)
|
||||
|
||||
split.types[self.num_blocks_] = byte(split.num_types)
|
||||
self.last_histogram_ix_[1] = self.last_histogram_ix_[0]
|
||||
self.last_histogram_ix_[0] = split.num_types * num_contexts
|
||||
for i = 0; i < num_contexts; i++ {
|
||||
last_entropy[num_contexts+i] = last_entropy[i]
|
||||
last_entropy[i] = entropy[i]
|
||||
}
|
||||
|
||||
self.num_blocks_++
|
||||
split.num_types++
|
||||
self.curr_histogram_ix_ += num_contexts
|
||||
if self.curr_histogram_ix_ < *self.histograms_size_ {
|
||||
clearHistogramsLiteral(self.histograms_[self.curr_histogram_ix_:], self.num_contexts_)
|
||||
}
|
||||
|
||||
self.block_size_ = 0
|
||||
self.merge_last_count_ = 0
|
||||
self.target_block_size_ = self.min_block_size_
|
||||
} else if diff[1] < diff[0]-20.0 {
|
||||
split.lengths[self.num_blocks_] = uint32(self.block_size_)
|
||||
split.types[self.num_blocks_] = split.types[self.num_blocks_-2]
|
||||
/* Combine this block with second last block. */
|
||||
|
||||
var tmp uint = self.last_histogram_ix_[0]
|
||||
self.last_histogram_ix_[0] = self.last_histogram_ix_[1]
|
||||
self.last_histogram_ix_[1] = tmp
|
||||
for i = 0; i < num_contexts; i++ {
|
||||
histograms[self.last_histogram_ix_[0]+i] = combined_histo[num_contexts+i]
|
||||
last_entropy[num_contexts+i] = last_entropy[i]
|
||||
last_entropy[i] = combined_entropy[num_contexts+i]
|
||||
histogramClearLiteral(&histograms[self.curr_histogram_ix_+i])
|
||||
}
|
||||
|
||||
self.num_blocks_++
|
||||
self.block_size_ = 0
|
||||
self.merge_last_count_ = 0
|
||||
self.target_block_size_ = self.min_block_size_
|
||||
} else {
|
||||
/* Combine this block with last block. */
|
||||
split.lengths[self.num_blocks_-1] += uint32(self.block_size_)
|
||||
|
||||
for i = 0; i < num_contexts; i++ {
|
||||
histograms[self.last_histogram_ix_[0]+i] = combined_histo[i]
|
||||
last_entropy[i] = combined_entropy[i]
|
||||
if split.num_types == 1 {
|
||||
last_entropy[num_contexts+i] = last_entropy[i]
|
||||
}
|
||||
|
||||
histogramClearLiteral(&histograms[self.curr_histogram_ix_+i])
|
||||
}
|
||||
|
||||
self.block_size_ = 0
|
||||
self.merge_last_count_++
|
||||
if self.merge_last_count_ > 1 {
|
||||
self.target_block_size_ += self.min_block_size_
|
||||
}
|
||||
}
|
||||
|
||||
combined_histo = nil
|
||||
}
|
||||
|
||||
if is_final {
|
||||
*self.histograms_size_ = split.num_types * num_contexts
|
||||
split.num_blocks = self.num_blocks_
|
||||
}
|
||||
}
|
||||
|
||||
/* Adds the next symbol to the current block type and context. When the
|
||||
current block reaches the target size, decides on merging the block. */
|
||||
func contextBlockSplitterAddSymbol(self *contextBlockSplitter, symbol uint, context uint) {
|
||||
histogramAddLiteral(&self.histograms_[self.curr_histogram_ix_+context], symbol)
|
||||
self.block_size_++
|
||||
if self.block_size_ == self.target_block_size_ {
|
||||
contextBlockSplitterFinishBlock(self, false) /* is_final = */
|
||||
}
|
||||
}
|
||||
|
||||
func mapStaticContexts(num_contexts uint, static_context_map []uint32, mb *metaBlockSplit) {
|
||||
var i uint
|
||||
mb.literal_context_map_size = mb.literal_split.num_types << literalContextBits
|
||||
if cap(mb.literal_context_map) < int(mb.literal_context_map_size) {
|
||||
mb.literal_context_map = make([]uint32, (mb.literal_context_map_size))
|
||||
} else {
|
||||
mb.literal_context_map = mb.literal_context_map[:mb.literal_context_map_size]
|
||||
}
|
||||
|
||||
for i = 0; i < mb.literal_split.num_types; i++ {
|
||||
var offset uint32 = uint32(i * num_contexts)
|
||||
var j uint
|
||||
for j = 0; j < 1<<literalContextBits; j++ {
|
||||
mb.literal_context_map[(i<<literalContextBits)+j] = offset + static_context_map[j]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildMetaBlockGreedyInternal(ringbuffer []byte, pos uint, mask uint, prev_byte byte, prev_byte2 byte, literal_context_lut contextLUT, num_contexts uint, static_context_map []uint32, commands []command, mb *metaBlockSplit) {
|
||||
var lit_blocks struct {
|
||||
plain blockSplitterLiteral
|
||||
ctx contextBlockSplitter
|
||||
}
|
||||
var cmd_blocks blockSplitterCommand
|
||||
var dist_blocks blockSplitterDistance
|
||||
var num_literals uint = 0
|
||||
for i := range commands {
|
||||
num_literals += uint(commands[i].insert_len_)
|
||||
}
|
||||
|
||||
if num_contexts == 1 {
|
||||
initBlockSplitterLiteral(&lit_blocks.plain, 256, 512, 400.0, num_literals, &mb.literal_split, &mb.literal_histograms, &mb.literal_histograms_size)
|
||||
} else {
|
||||
initContextBlockSplitter(&lit_blocks.ctx, 256, num_contexts, 512, 400.0, num_literals, &mb.literal_split, &mb.literal_histograms, &mb.literal_histograms_size)
|
||||
}
|
||||
|
||||
initBlockSplitterCommand(&cmd_blocks, numCommandSymbols, 1024, 500.0, uint(len(commands)), &mb.command_split, &mb.command_histograms, &mb.command_histograms_size)
|
||||
initBlockSplitterDistance(&dist_blocks, 64, 512, 100.0, uint(len(commands)), &mb.distance_split, &mb.distance_histograms, &mb.distance_histograms_size)
|
||||
|
||||
for _, cmd := range commands {
|
||||
var j uint
|
||||
blockSplitterAddSymbolCommand(&cmd_blocks, uint(cmd.cmd_prefix_))
|
||||
for j = uint(cmd.insert_len_); j != 0; j-- {
|
||||
var literal byte = ringbuffer[pos&mask]
|
||||
if num_contexts == 1 {
|
||||
blockSplitterAddSymbolLiteral(&lit_blocks.plain, uint(literal))
|
||||
} else {
|
||||
var context uint = uint(getContext(prev_byte, prev_byte2, literal_context_lut))
|
||||
contextBlockSplitterAddSymbol(&lit_blocks.ctx, uint(literal), uint(static_context_map[context]))
|
||||
}
|
||||
|
||||
prev_byte2 = prev_byte
|
||||
prev_byte = literal
|
||||
pos++
|
||||
}
|
||||
|
||||
pos += uint(commandCopyLen(&cmd))
|
||||
if commandCopyLen(&cmd) != 0 {
|
||||
prev_byte2 = ringbuffer[(pos-2)&mask]
|
||||
prev_byte = ringbuffer[(pos-1)&mask]
|
||||
if cmd.cmd_prefix_ >= 128 {
|
||||
blockSplitterAddSymbolDistance(&dist_blocks, uint(cmd.dist_prefix_)&0x3FF)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if num_contexts == 1 {
|
||||
blockSplitterFinishBlockLiteral(&lit_blocks.plain, true) /* is_final = */
|
||||
} else {
|
||||
contextBlockSplitterFinishBlock(&lit_blocks.ctx, true) /* is_final = */
|
||||
}
|
||||
|
||||
blockSplitterFinishBlockCommand(&cmd_blocks, true) /* is_final = */
|
||||
blockSplitterFinishBlockDistance(&dist_blocks, true) /* is_final = */
|
||||
|
||||
if num_contexts > 1 {
|
||||
mapStaticContexts(num_contexts, static_context_map, mb)
|
||||
}
|
||||
}
|
||||
|
||||
func buildMetaBlockGreedy(ringbuffer []byte, pos uint, mask uint, prev_byte byte, prev_byte2 byte, literal_context_lut contextLUT, num_contexts uint, static_context_map []uint32, commands []command, mb *metaBlockSplit) {
|
||||
if num_contexts == 1 {
|
||||
buildMetaBlockGreedyInternal(ringbuffer, pos, mask, prev_byte, prev_byte2, literal_context_lut, 1, nil, commands, mb)
|
||||
} else {
|
||||
buildMetaBlockGreedyInternal(ringbuffer, pos, mask, prev_byte, prev_byte2, literal_context_lut, num_contexts, static_context_map, commands, mb)
|
||||
}
|
||||
}
|
||||
|
||||
func optimizeHistograms(num_distance_codes uint32, mb *metaBlockSplit) {
|
||||
var good_for_rle [numCommandSymbols]byte
|
||||
var i uint
|
||||
for i = 0; i < mb.literal_histograms_size; i++ {
|
||||
optimizeHuffmanCountsForRLE(256, mb.literal_histograms[i].data_[:], good_for_rle[:])
|
||||
}
|
||||
|
||||
for i = 0; i < mb.command_histograms_size; i++ {
|
||||
optimizeHuffmanCountsForRLE(numCommandSymbols, mb.command_histograms[i].data_[:], good_for_rle[:])
|
||||
}
|
||||
|
||||
for i = 0; i < mb.distance_histograms_size; i++ {
|
||||
optimizeHuffmanCountsForRLE(uint(num_distance_codes), mb.distance_histograms[i].data_[:], good_for_rle[:])
|
||||
}
|
||||
}
|
||||
+165
@@ -0,0 +1,165 @@
|
||||
package brotli
|
||||
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Greedy block splitter for one block category (literal, command or distance).
|
||||
*/
|
||||
type blockSplitterCommand struct {
|
||||
alphabet_size_ uint
|
||||
min_block_size_ uint
|
||||
split_threshold_ float64
|
||||
num_blocks_ uint
|
||||
split_ *blockSplit
|
||||
histograms_ []histogramCommand
|
||||
histograms_size_ *uint
|
||||
target_block_size_ uint
|
||||
block_size_ uint
|
||||
curr_histogram_ix_ uint
|
||||
last_histogram_ix_ [2]uint
|
||||
last_entropy_ [2]float64
|
||||
merge_last_count_ uint
|
||||
}
|
||||
|
||||
func initBlockSplitterCommand(self *blockSplitterCommand, alphabet_size uint, min_block_size uint, split_threshold float64, num_symbols uint, split *blockSplit, histograms *[]histogramCommand, histograms_size *uint) {
|
||||
var max_num_blocks uint = num_symbols/min_block_size + 1
|
||||
var max_num_types uint = brotli_min_size_t(max_num_blocks, maxNumberOfBlockTypes+1)
|
||||
/* We have to allocate one more histogram than the maximum number of block
|
||||
types for the current histogram when the meta-block is too big. */
|
||||
self.alphabet_size_ = alphabet_size
|
||||
|
||||
self.min_block_size_ = min_block_size
|
||||
self.split_threshold_ = split_threshold
|
||||
self.num_blocks_ = 0
|
||||
self.split_ = split
|
||||
self.histograms_size_ = histograms_size
|
||||
self.target_block_size_ = min_block_size
|
||||
self.block_size_ = 0
|
||||
self.curr_histogram_ix_ = 0
|
||||
self.merge_last_count_ = 0
|
||||
brotli_ensure_capacity_uint8_t(&split.types, &split.types_alloc_size, max_num_blocks)
|
||||
brotli_ensure_capacity_uint32_t(&split.lengths, &split.lengths_alloc_size, max_num_blocks)
|
||||
self.split_.num_blocks = max_num_blocks
|
||||
*histograms_size = max_num_types
|
||||
if histograms == nil || cap(*histograms) < int(*histograms_size) {
|
||||
*histograms = make([]histogramCommand, (*histograms_size))
|
||||
} else {
|
||||
*histograms = (*histograms)[:*histograms_size]
|
||||
}
|
||||
self.histograms_ = *histograms
|
||||
|
||||
/* Clear only current histogram. */
|
||||
histogramClearCommand(&self.histograms_[0])
|
||||
|
||||
self.last_histogram_ix_[1] = 0
|
||||
self.last_histogram_ix_[0] = self.last_histogram_ix_[1]
|
||||
}
|
||||
|
||||
/* Does either of three things:
|
||||
(1) emits the current block with a new block type;
|
||||
(2) emits the current block with the type of the second last block;
|
||||
(3) merges the current block with the last block. */
|
||||
func blockSplitterFinishBlockCommand(self *blockSplitterCommand, is_final bool) {
|
||||
var split *blockSplit = self.split_
|
||||
var last_entropy []float64 = self.last_entropy_[:]
|
||||
var histograms []histogramCommand = self.histograms_
|
||||
self.block_size_ = brotli_max_size_t(self.block_size_, self.min_block_size_)
|
||||
if self.num_blocks_ == 0 {
|
||||
/* Create first block. */
|
||||
split.lengths[0] = uint32(self.block_size_)
|
||||
|
||||
split.types[0] = 0
|
||||
last_entropy[0] = bitsEntropy(histograms[0].data_[:], self.alphabet_size_)
|
||||
last_entropy[1] = last_entropy[0]
|
||||
self.num_blocks_++
|
||||
split.num_types++
|
||||
self.curr_histogram_ix_++
|
||||
if self.curr_histogram_ix_ < *self.histograms_size_ {
|
||||
histogramClearCommand(&histograms[self.curr_histogram_ix_])
|
||||
}
|
||||
self.block_size_ = 0
|
||||
} else if self.block_size_ > 0 {
|
||||
var entropy float64 = bitsEntropy(histograms[self.curr_histogram_ix_].data_[:], self.alphabet_size_)
|
||||
var combined_histo [2]histogramCommand
|
||||
var combined_entropy [2]float64
|
||||
var diff [2]float64
|
||||
var j uint
|
||||
for j = 0; j < 2; j++ {
|
||||
var last_histogram_ix uint = self.last_histogram_ix_[j]
|
||||
combined_histo[j] = histograms[self.curr_histogram_ix_]
|
||||
histogramAddHistogramCommand(&combined_histo[j], &histograms[last_histogram_ix])
|
||||
combined_entropy[j] = bitsEntropy(combined_histo[j].data_[0:], self.alphabet_size_)
|
||||
diff[j] = combined_entropy[j] - entropy - last_entropy[j]
|
||||
}
|
||||
|
||||
if split.num_types < maxNumberOfBlockTypes && diff[0] > self.split_threshold_ && diff[1] > self.split_threshold_ {
|
||||
/* Create new block. */
|
||||
split.lengths[self.num_blocks_] = uint32(self.block_size_)
|
||||
|
||||
split.types[self.num_blocks_] = byte(split.num_types)
|
||||
self.last_histogram_ix_[1] = self.last_histogram_ix_[0]
|
||||
self.last_histogram_ix_[0] = uint(byte(split.num_types))
|
||||
last_entropy[1] = last_entropy[0]
|
||||
last_entropy[0] = entropy
|
||||
self.num_blocks_++
|
||||
split.num_types++
|
||||
self.curr_histogram_ix_++
|
||||
if self.curr_histogram_ix_ < *self.histograms_size_ {
|
||||
histogramClearCommand(&histograms[self.curr_histogram_ix_])
|
||||
}
|
||||
self.block_size_ = 0
|
||||
self.merge_last_count_ = 0
|
||||
self.target_block_size_ = self.min_block_size_
|
||||
} else if diff[1] < diff[0]-20.0 {
|
||||
split.lengths[self.num_blocks_] = uint32(self.block_size_)
|
||||
split.types[self.num_blocks_] = split.types[self.num_blocks_-2]
|
||||
/* Combine this block with second last block. */
|
||||
|
||||
var tmp uint = self.last_histogram_ix_[0]
|
||||
self.last_histogram_ix_[0] = self.last_histogram_ix_[1]
|
||||
self.last_histogram_ix_[1] = tmp
|
||||
histograms[self.last_histogram_ix_[0]] = combined_histo[1]
|
||||
last_entropy[1] = last_entropy[0]
|
||||
last_entropy[0] = combined_entropy[1]
|
||||
self.num_blocks_++
|
||||
self.block_size_ = 0
|
||||
histogramClearCommand(&histograms[self.curr_histogram_ix_])
|
||||
self.merge_last_count_ = 0
|
||||
self.target_block_size_ = self.min_block_size_
|
||||
} else {
|
||||
/* Combine this block with last block. */
|
||||
split.lengths[self.num_blocks_-1] += uint32(self.block_size_)
|
||||
|
||||
histograms[self.last_histogram_ix_[0]] = combined_histo[0]
|
||||
last_entropy[0] = combined_entropy[0]
|
||||
if split.num_types == 1 {
|
||||
last_entropy[1] = last_entropy[0]
|
||||
}
|
||||
|
||||
self.block_size_ = 0
|
||||
histogramClearCommand(&histograms[self.curr_histogram_ix_])
|
||||
self.merge_last_count_++
|
||||
if self.merge_last_count_ > 1 {
|
||||
self.target_block_size_ += self.min_block_size_
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if is_final {
|
||||
*self.histograms_size_ = split.num_types
|
||||
split.num_blocks = self.num_blocks_
|
||||
}
|
||||
}
|
||||
|
||||
/* Adds the next symbol to the current histogram. When the current histogram
|
||||
reaches the target size, decides on merging the block. */
|
||||
func blockSplitterAddSymbolCommand(self *blockSplitterCommand, symbol uint) {
|
||||
histogramAddCommand(&self.histograms_[self.curr_histogram_ix_], symbol)
|
||||
self.block_size_++
|
||||
if self.block_size_ == self.target_block_size_ {
|
||||
blockSplitterFinishBlockCommand(self, false) /* is_final = */
|
||||
}
|
||||
}
|
||||
+165
@@ -0,0 +1,165 @@
|
||||
package brotli
|
||||
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Greedy block splitter for one block category (literal, command or distance).
|
||||
*/
|
||||
type blockSplitterDistance struct {
|
||||
alphabet_size_ uint
|
||||
min_block_size_ uint
|
||||
split_threshold_ float64
|
||||
num_blocks_ uint
|
||||
split_ *blockSplit
|
||||
histograms_ []histogramDistance
|
||||
histograms_size_ *uint
|
||||
target_block_size_ uint
|
||||
block_size_ uint
|
||||
curr_histogram_ix_ uint
|
||||
last_histogram_ix_ [2]uint
|
||||
last_entropy_ [2]float64
|
||||
merge_last_count_ uint
|
||||
}
|
||||
|
||||
func initBlockSplitterDistance(self *blockSplitterDistance, alphabet_size uint, min_block_size uint, split_threshold float64, num_symbols uint, split *blockSplit, histograms *[]histogramDistance, histograms_size *uint) {
|
||||
var max_num_blocks uint = num_symbols/min_block_size + 1
|
||||
var max_num_types uint = brotli_min_size_t(max_num_blocks, maxNumberOfBlockTypes+1)
|
||||
/* We have to allocate one more histogram than the maximum number of block
|
||||
types for the current histogram when the meta-block is too big. */
|
||||
self.alphabet_size_ = alphabet_size
|
||||
|
||||
self.min_block_size_ = min_block_size
|
||||
self.split_threshold_ = split_threshold
|
||||
self.num_blocks_ = 0
|
||||
self.split_ = split
|
||||
self.histograms_size_ = histograms_size
|
||||
self.target_block_size_ = min_block_size
|
||||
self.block_size_ = 0
|
||||
self.curr_histogram_ix_ = 0
|
||||
self.merge_last_count_ = 0
|
||||
brotli_ensure_capacity_uint8_t(&split.types, &split.types_alloc_size, max_num_blocks)
|
||||
brotli_ensure_capacity_uint32_t(&split.lengths, &split.lengths_alloc_size, max_num_blocks)
|
||||
self.split_.num_blocks = max_num_blocks
|
||||
*histograms_size = max_num_types
|
||||
if histograms == nil || cap(*histograms) < int(*histograms_size) {
|
||||
*histograms = make([]histogramDistance, *histograms_size)
|
||||
} else {
|
||||
*histograms = (*histograms)[:*histograms_size]
|
||||
}
|
||||
self.histograms_ = *histograms
|
||||
|
||||
/* Clear only current histogram. */
|
||||
histogramClearDistance(&self.histograms_[0])
|
||||
|
||||
self.last_histogram_ix_[1] = 0
|
||||
self.last_histogram_ix_[0] = self.last_histogram_ix_[1]
|
||||
}
|
||||
|
||||
/* Does either of three things:
|
||||
(1) emits the current block with a new block type;
|
||||
(2) emits the current block with the type of the second last block;
|
||||
(3) merges the current block with the last block. */
|
||||
func blockSplitterFinishBlockDistance(self *blockSplitterDistance, is_final bool) {
|
||||
var split *blockSplit = self.split_
|
||||
var last_entropy []float64 = self.last_entropy_[:]
|
||||
var histograms []histogramDistance = self.histograms_
|
||||
self.block_size_ = brotli_max_size_t(self.block_size_, self.min_block_size_)
|
||||
if self.num_blocks_ == 0 {
|
||||
/* Create first block. */
|
||||
split.lengths[0] = uint32(self.block_size_)
|
||||
|
||||
split.types[0] = 0
|
||||
last_entropy[0] = bitsEntropy(histograms[0].data_[:], self.alphabet_size_)
|
||||
last_entropy[1] = last_entropy[0]
|
||||
self.num_blocks_++
|
||||
split.num_types++
|
||||
self.curr_histogram_ix_++
|
||||
if self.curr_histogram_ix_ < *self.histograms_size_ {
|
||||
histogramClearDistance(&histograms[self.curr_histogram_ix_])
|
||||
}
|
||||
self.block_size_ = 0
|
||||
} else if self.block_size_ > 0 {
|
||||
var entropy float64 = bitsEntropy(histograms[self.curr_histogram_ix_].data_[:], self.alphabet_size_)
|
||||
var combined_histo [2]histogramDistance
|
||||
var combined_entropy [2]float64
|
||||
var diff [2]float64
|
||||
var j uint
|
||||
for j = 0; j < 2; j++ {
|
||||
var last_histogram_ix uint = self.last_histogram_ix_[j]
|
||||
combined_histo[j] = histograms[self.curr_histogram_ix_]
|
||||
histogramAddHistogramDistance(&combined_histo[j], &histograms[last_histogram_ix])
|
||||
combined_entropy[j] = bitsEntropy(combined_histo[j].data_[0:], self.alphabet_size_)
|
||||
diff[j] = combined_entropy[j] - entropy - last_entropy[j]
|
||||
}
|
||||
|
||||
if split.num_types < maxNumberOfBlockTypes && diff[0] > self.split_threshold_ && diff[1] > self.split_threshold_ {
|
||||
/* Create new block. */
|
||||
split.lengths[self.num_blocks_] = uint32(self.block_size_)
|
||||
|
||||
split.types[self.num_blocks_] = byte(split.num_types)
|
||||
self.last_histogram_ix_[1] = self.last_histogram_ix_[0]
|
||||
self.last_histogram_ix_[0] = uint(byte(split.num_types))
|
||||
last_entropy[1] = last_entropy[0]
|
||||
last_entropy[0] = entropy
|
||||
self.num_blocks_++
|
||||
split.num_types++
|
||||
self.curr_histogram_ix_++
|
||||
if self.curr_histogram_ix_ < *self.histograms_size_ {
|
||||
histogramClearDistance(&histograms[self.curr_histogram_ix_])
|
||||
}
|
||||
self.block_size_ = 0
|
||||
self.merge_last_count_ = 0
|
||||
self.target_block_size_ = self.min_block_size_
|
||||
} else if diff[1] < diff[0]-20.0 {
|
||||
split.lengths[self.num_blocks_] = uint32(self.block_size_)
|
||||
split.types[self.num_blocks_] = split.types[self.num_blocks_-2]
|
||||
/* Combine this block with second last block. */
|
||||
|
||||
var tmp uint = self.last_histogram_ix_[0]
|
||||
self.last_histogram_ix_[0] = self.last_histogram_ix_[1]
|
||||
self.last_histogram_ix_[1] = tmp
|
||||
histograms[self.last_histogram_ix_[0]] = combined_histo[1]
|
||||
last_entropy[1] = last_entropy[0]
|
||||
last_entropy[0] = combined_entropy[1]
|
||||
self.num_blocks_++
|
||||
self.block_size_ = 0
|
||||
histogramClearDistance(&histograms[self.curr_histogram_ix_])
|
||||
self.merge_last_count_ = 0
|
||||
self.target_block_size_ = self.min_block_size_
|
||||
} else {
|
||||
/* Combine this block with last block. */
|
||||
split.lengths[self.num_blocks_-1] += uint32(self.block_size_)
|
||||
|
||||
histograms[self.last_histogram_ix_[0]] = combined_histo[0]
|
||||
last_entropy[0] = combined_entropy[0]
|
||||
if split.num_types == 1 {
|
||||
last_entropy[1] = last_entropy[0]
|
||||
}
|
||||
|
||||
self.block_size_ = 0
|
||||
histogramClearDistance(&histograms[self.curr_histogram_ix_])
|
||||
self.merge_last_count_++
|
||||
if self.merge_last_count_ > 1 {
|
||||
self.target_block_size_ += self.min_block_size_
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if is_final {
|
||||
*self.histograms_size_ = split.num_types
|
||||
split.num_blocks = self.num_blocks_
|
||||
}
|
||||
}
|
||||
|
||||
/* Adds the next symbol to the current histogram. When the current histogram
|
||||
reaches the target size, decides on merging the block. */
|
||||
func blockSplitterAddSymbolDistance(self *blockSplitterDistance, symbol uint) {
|
||||
histogramAddDistance(&self.histograms_[self.curr_histogram_ix_], symbol)
|
||||
self.block_size_++
|
||||
if self.block_size_ == self.target_block_size_ {
|
||||
blockSplitterFinishBlockDistance(self, false) /* is_final = */
|
||||
}
|
||||
}
|
||||
+165
@@ -0,0 +1,165 @@
|
||||
package brotli
|
||||
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Greedy block splitter for one block category (literal, command or distance).
|
||||
*/
|
||||
type blockSplitterLiteral struct {
|
||||
alphabet_size_ uint
|
||||
min_block_size_ uint
|
||||
split_threshold_ float64
|
||||
num_blocks_ uint
|
||||
split_ *blockSplit
|
||||
histograms_ []histogramLiteral
|
||||
histograms_size_ *uint
|
||||
target_block_size_ uint
|
||||
block_size_ uint
|
||||
curr_histogram_ix_ uint
|
||||
last_histogram_ix_ [2]uint
|
||||
last_entropy_ [2]float64
|
||||
merge_last_count_ uint
|
||||
}
|
||||
|
||||
func initBlockSplitterLiteral(self *blockSplitterLiteral, alphabet_size uint, min_block_size uint, split_threshold float64, num_symbols uint, split *blockSplit, histograms *[]histogramLiteral, histograms_size *uint) {
|
||||
var max_num_blocks uint = num_symbols/min_block_size + 1
|
||||
var max_num_types uint = brotli_min_size_t(max_num_blocks, maxNumberOfBlockTypes+1)
|
||||
/* We have to allocate one more histogram than the maximum number of block
|
||||
types for the current histogram when the meta-block is too big. */
|
||||
self.alphabet_size_ = alphabet_size
|
||||
|
||||
self.min_block_size_ = min_block_size
|
||||
self.split_threshold_ = split_threshold
|
||||
self.num_blocks_ = 0
|
||||
self.split_ = split
|
||||
self.histograms_size_ = histograms_size
|
||||
self.target_block_size_ = min_block_size
|
||||
self.block_size_ = 0
|
||||
self.curr_histogram_ix_ = 0
|
||||
self.merge_last_count_ = 0
|
||||
brotli_ensure_capacity_uint8_t(&split.types, &split.types_alloc_size, max_num_blocks)
|
||||
brotli_ensure_capacity_uint32_t(&split.lengths, &split.lengths_alloc_size, max_num_blocks)
|
||||
self.split_.num_blocks = max_num_blocks
|
||||
*histograms_size = max_num_types
|
||||
if histograms == nil || cap(*histograms) < int(*histograms_size) {
|
||||
*histograms = make([]histogramLiteral, *histograms_size)
|
||||
} else {
|
||||
*histograms = (*histograms)[:*histograms_size]
|
||||
}
|
||||
self.histograms_ = *histograms
|
||||
|
||||
/* Clear only current histogram. */
|
||||
histogramClearLiteral(&self.histograms_[0])
|
||||
|
||||
self.last_histogram_ix_[1] = 0
|
||||
self.last_histogram_ix_[0] = self.last_histogram_ix_[1]
|
||||
}
|
||||
|
||||
/* Does either of three things:
|
||||
(1) emits the current block with a new block type;
|
||||
(2) emits the current block with the type of the second last block;
|
||||
(3) merges the current block with the last block. */
|
||||
func blockSplitterFinishBlockLiteral(self *blockSplitterLiteral, is_final bool) {
|
||||
var split *blockSplit = self.split_
|
||||
var last_entropy []float64 = self.last_entropy_[:]
|
||||
var histograms []histogramLiteral = self.histograms_
|
||||
self.block_size_ = brotli_max_size_t(self.block_size_, self.min_block_size_)
|
||||
if self.num_blocks_ == 0 {
|
||||
/* Create first block. */
|
||||
split.lengths[0] = uint32(self.block_size_)
|
||||
|
||||
split.types[0] = 0
|
||||
last_entropy[0] = bitsEntropy(histograms[0].data_[:], self.alphabet_size_)
|
||||
last_entropy[1] = last_entropy[0]
|
||||
self.num_blocks_++
|
||||
split.num_types++
|
||||
self.curr_histogram_ix_++
|
||||
if self.curr_histogram_ix_ < *self.histograms_size_ {
|
||||
histogramClearLiteral(&histograms[self.curr_histogram_ix_])
|
||||
}
|
||||
self.block_size_ = 0
|
||||
} else if self.block_size_ > 0 {
|
||||
var entropy float64 = bitsEntropy(histograms[self.curr_histogram_ix_].data_[:], self.alphabet_size_)
|
||||
var combined_histo [2]histogramLiteral
|
||||
var combined_entropy [2]float64
|
||||
var diff [2]float64
|
||||
var j uint
|
||||
for j = 0; j < 2; j++ {
|
||||
var last_histogram_ix uint = self.last_histogram_ix_[j]
|
||||
combined_histo[j] = histograms[self.curr_histogram_ix_]
|
||||
histogramAddHistogramLiteral(&combined_histo[j], &histograms[last_histogram_ix])
|
||||
combined_entropy[j] = bitsEntropy(combined_histo[j].data_[0:], self.alphabet_size_)
|
||||
diff[j] = combined_entropy[j] - entropy - last_entropy[j]
|
||||
}
|
||||
|
||||
if split.num_types < maxNumberOfBlockTypes && diff[0] > self.split_threshold_ && diff[1] > self.split_threshold_ {
|
||||
/* Create new block. */
|
||||
split.lengths[self.num_blocks_] = uint32(self.block_size_)
|
||||
|
||||
split.types[self.num_blocks_] = byte(split.num_types)
|
||||
self.last_histogram_ix_[1] = self.last_histogram_ix_[0]
|
||||
self.last_histogram_ix_[0] = uint(byte(split.num_types))
|
||||
last_entropy[1] = last_entropy[0]
|
||||
last_entropy[0] = entropy
|
||||
self.num_blocks_++
|
||||
split.num_types++
|
||||
self.curr_histogram_ix_++
|
||||
if self.curr_histogram_ix_ < *self.histograms_size_ {
|
||||
histogramClearLiteral(&histograms[self.curr_histogram_ix_])
|
||||
}
|
||||
self.block_size_ = 0
|
||||
self.merge_last_count_ = 0
|
||||
self.target_block_size_ = self.min_block_size_
|
||||
} else if diff[1] < diff[0]-20.0 {
|
||||
split.lengths[self.num_blocks_] = uint32(self.block_size_)
|
||||
split.types[self.num_blocks_] = split.types[self.num_blocks_-2]
|
||||
/* Combine this block with second last block. */
|
||||
|
||||
var tmp uint = self.last_histogram_ix_[0]
|
||||
self.last_histogram_ix_[0] = self.last_histogram_ix_[1]
|
||||
self.last_histogram_ix_[1] = tmp
|
||||
histograms[self.last_histogram_ix_[0]] = combined_histo[1]
|
||||
last_entropy[1] = last_entropy[0]
|
||||
last_entropy[0] = combined_entropy[1]
|
||||
self.num_blocks_++
|
||||
self.block_size_ = 0
|
||||
histogramClearLiteral(&histograms[self.curr_histogram_ix_])
|
||||
self.merge_last_count_ = 0
|
||||
self.target_block_size_ = self.min_block_size_
|
||||
} else {
|
||||
/* Combine this block with last block. */
|
||||
split.lengths[self.num_blocks_-1] += uint32(self.block_size_)
|
||||
|
||||
histograms[self.last_histogram_ix_[0]] = combined_histo[0]
|
||||
last_entropy[0] = combined_entropy[0]
|
||||
if split.num_types == 1 {
|
||||
last_entropy[1] = last_entropy[0]
|
||||
}
|
||||
|
||||
self.block_size_ = 0
|
||||
histogramClearLiteral(&histograms[self.curr_histogram_ix_])
|
||||
self.merge_last_count_++
|
||||
if self.merge_last_count_ > 1 {
|
||||
self.target_block_size_ += self.min_block_size_
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if is_final {
|
||||
*self.histograms_size_ = split.num_types
|
||||
split.num_blocks = self.num_blocks_
|
||||
}
|
||||
}
|
||||
|
||||
/* Adds the next symbol to the current histogram. When the current histogram
|
||||
reaches the target size, decides on merging the block. */
|
||||
func blockSplitterAddSymbolLiteral(self *blockSplitterLiteral, symbol uint) {
|
||||
histogramAddLiteral(&self.histograms_[self.curr_histogram_ix_], symbol)
|
||||
self.block_size_++
|
||||
if self.block_size_ == self.target_block_size_ {
|
||||
blockSplitterFinishBlockLiteral(self, false) /* is_final = */
|
||||
}
|
||||
}
|
||||
+37
@@ -0,0 +1,37 @@
|
||||
package brotli
|
||||
|
||||
/* Copyright 2017 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Parameters for the Brotli encoder with chosen quality levels. */
|
||||
type hasherParams struct {
|
||||
type_ int
|
||||
bucket_bits int
|
||||
block_bits int
|
||||
hash_len int
|
||||
num_last_distances_to_check int
|
||||
}
|
||||
|
||||
type distanceParams struct {
|
||||
distance_postfix_bits uint32
|
||||
num_direct_distance_codes uint32
|
||||
alphabet_size uint32
|
||||
max_distance uint
|
||||
}
|
||||
|
||||
/* Encoding parameters */
|
||||
type encoderParams struct {
|
||||
mode int
|
||||
quality int
|
||||
lgwin uint
|
||||
lgblock int
|
||||
size_hint uint
|
||||
disable_literal_context_modeling bool
|
||||
large_window bool
|
||||
hasher hasherParams
|
||||
dist distanceParams
|
||||
dictionary encoderDictionary
|
||||
}
|
||||
+103
@@ -0,0 +1,103 @@
|
||||
package brotli
|
||||
|
||||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
func brotli_min_double(a float64, b float64) float64 {
|
||||
if a < b {
|
||||
return a
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
func brotli_max_double(a float64, b float64) float64 {
|
||||
if a > b {
|
||||
return a
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
func brotli_min_float(a float32, b float32) float32 {
|
||||
if a < b {
|
||||
return a
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
func brotli_max_float(a float32, b float32) float32 {
|
||||
if a > b {
|
||||
return a
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
func brotli_min_int(a int, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
func brotli_max_int(a int, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
func brotli_min_size_t(a uint, b uint) uint {
|
||||
if a < b {
|
||||
return a
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
func brotli_max_size_t(a uint, b uint) uint {
|
||||
if a > b {
|
||||
return a
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
func brotli_min_uint32_t(a uint32, b uint32) uint32 {
|
||||
if a < b {
|
||||
return a
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
func brotli_max_uint32_t(a uint32, b uint32) uint32 {
|
||||
if a > b {
|
||||
return a
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
func brotli_min_uint8_t(a byte, b byte) byte {
|
||||
if a < b {
|
||||
return a
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
func brotli_max_uint8_t(a byte, b byte) byte {
|
||||
if a > b {
|
||||
return a
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
}
|
||||
+30
@@ -0,0 +1,30 @@
|
||||
package brotli
|
||||
|
||||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Functions for encoding of integers into prefix codes the amount of extra
|
||||
bits, and the actual values of the extra bits. */
|
||||
|
||||
/* Here distance_code is an intermediate code, i.e. one of the special codes or
|
||||
the actual distance increased by BROTLI_NUM_DISTANCE_SHORT_CODES - 1. */
|
||||
func prefixEncodeCopyDistance(distance_code uint, num_direct_codes uint, postfix_bits uint, code *uint16, extra_bits *uint32) {
|
||||
if distance_code < numDistanceShortCodes+num_direct_codes {
|
||||
*code = uint16(distance_code)
|
||||
*extra_bits = 0
|
||||
return
|
||||
} else {
|
||||
var dist uint = (uint(1) << (postfix_bits + 2)) + (distance_code - numDistanceShortCodes - num_direct_codes)
|
||||
var bucket uint = uint(log2FloorNonZero(dist) - 1)
|
||||
var postfix_mask uint = (1 << postfix_bits) - 1
|
||||
var postfix uint = dist & postfix_mask
|
||||
var prefix uint = (dist >> bucket) & 1
|
||||
var offset uint = (2 + prefix) << bucket
|
||||
var nbits uint = bucket - postfix_bits
|
||||
*code = uint16(nbits<<10 | (numDistanceShortCodes + num_direct_codes + ((2*(nbits-1) + prefix) << postfix_bits) + postfix))
|
||||
*extra_bits = uint32((dist - offset) >> postfix_bits)
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user