Files
anthropic-proxy/internal/auth/login.go
T

276 lines
6.9 KiB
Go

package auth
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
)
const (
	// authURL is the OAuth 2.0 authorization endpoint. Both the loopback and
	// the manual (copy/paste) authorization URLs are built from it.
	authURL = "https://claude.com/cai/oauth/authorize"

	// manualRedirect is the redirect_uri used by the copy/paste flow, in which
	// the user pastes the authorization code into the terminal instead of the
	// browser reaching the local callback server.
	manualRedirect = "https://platform.claude.com/oauth/code/callback"
)
// base64URLEncode returns the unpadded base64url encoding (RFC 4648 §5) of
// data, the form PKCE uses for code verifiers and code challenges.
func base64URLEncode(data []byte) string {
	enc := base64.RawURLEncoding
	out := enc.EncodeToString(data)
	return out
}
// generateCodeVerifier returns a fresh PKCE code verifier (RFC 7636 §4.1):
// 32 bytes from crypto/rand, base64url-encoded without padding, which gives
// 43 characters from the unreserved set.
//
// It panics if the system random source fails. Continuing with the
// zero-filled buffer would produce a predictable verifier and silently
// defeat PKCE.
func generateCodeVerifier() string {
	buf := make([]byte, 32)
	if _, err := rand.Read(buf); err != nil {
		panic(fmt.Errorf("auth: reading random bytes for PKCE code verifier: %w", err))
	}
	return base64URLEncode(buf)
}
// generateCodeChallenge derives the S256 PKCE code challenge for verifier
// (RFC 7636 §4.2): BASE64URL(SHA256(verifier)), without padding.
func generateCodeChallenge(verifier string) string {
	hasher := sha256.New()
	hasher.Write([]byte(verifier)) // hash.Hash.Write never returns an error
	return base64URLEncode(hasher.Sum(nil))
}
// generateState returns a fresh OAuth state value: 32 bytes from crypto/rand,
// base64url-encoded without padding (43 characters). The callback handler
// rejects any redirect whose state differs from this value.
//
// It panics if the system random source fails. A zero-filled, predictable
// state would make that check worthless.
func generateState() string {
	buf := make([]byte, 32)
	if _, err := rand.Read(buf); err != nil {
		panic(fmt.Errorf("auth: reading random bytes for OAuth state: %w", err))
	}
	return base64URLEncode(buf)
}
func buildAuthURL(port int, codeChallenge, state string) string {
u, _ := url.Parse(authURL)
q := u.Query()
q.Set("client_id", clientID)
q.Set("response_type", "code")
q.Set("redirect_uri", fmt.Sprintf("http://localhost:%d/callback", port))
q.Set("scope", oauthScopes)
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
q.Set("state", state)
u.RawQuery = q.Encode()
return u.String()
}
func buildManualAuthURL(codeChallenge, state string) string {
u, _ := url.Parse(authURL)
q := u.Query()
q.Set("client_id", clientID)
q.Set("response_type", "code")
q.Set("redirect_uri", manualRedirect)
q.Set("scope", oauthScopes)
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
q.Set("state", state)
u.RawQuery = q.Encode()
return u.String()
}
// startCallbackServer starts an HTTP server on an ephemeral localhost port to
// receive the OAuth redirect at /callback.
//
// A request is accepted only when its state query parameter equals
// expectedState and it carries a non-empty code. The first accepted code is
// delivered on codeChan (capacity 1); later ones are acknowledged but dropped.
// If the authorization server redirects with an error parameter instead of a
// code (for example, the user denied access), that error is shown in the
// browser.
//
// The caller must call cleanup to stop the server.
func startCallbackServer(expectedState string) (port int, codeChan <-chan string, cleanup func(), err error) {
	ch := make(chan string, 1)

	ln, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		return 0, nil, nil, err
	}
	port = ln.Addr().(*net.TCPAddr).Port

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/callback" {
			http.NotFound(w, r)
			return
		}
		q := r.URL.Query()

		// Verify state before acting on or echoing anything else, so a forged
		// redirect is rejected outright.
		if q.Get("state") != expectedState {
			http.Error(w, "invalid state", http.StatusBadRequest)
			return
		}

		// Denial or failure is reported via "error" instead of "code"
		// (RFC 6749 §4.1.2.1).
		if oauthErr := q.Get("error"); oauthErr != "" {
			msg := "authorization failed: " + oauthErr
			if desc := q.Get("error_description"); desc != "" {
				msg += " (" + desc + ")"
			}
			// http.Error replies as text/plain with nosniff, so echoing is safe.
			http.Error(w, msg, http.StatusBadRequest)
			return
		}

		code := q.Get("code")
		if code == "" {
			http.Error(w, "missing code", http.StatusBadRequest)
			return
		}

		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		select {
		case ch <- code:
			fmt.Fprintln(w, "<html><body><h2>Login successful! You can close this tab.</h2></body></html>")
		default:
			fmt.Fprintln(w, "<html><body><h2>Already received. You can close this tab.</h2></body></html>")
		}
	})

	srv := &http.Server{
		Handler: handler,
		// Bound header reads so a stalled client can't hold a connection open forever.
		ReadHeaderTimeout: 10 * time.Second,
	}
	go func() {
		// Serve returns http.ErrServerClosed once cleanup runs. Any other error
		// just means no callback arrives, which the caller's timeout covers.
		_ = srv.Serve(ln)
	}()

	cleanup = func() { _ = srv.Close() }
	return port, ch, cleanup, nil
}
// DefaultCredentialPath reports where the Claude credentials file lives:
// <home>/.claude/.credentials.json, with <home> taken from os.UserHomeDir.
// The only failure mode is an undeterminable home directory.
func DefaultCredentialPath() (string, error) {
	const dirName, fileName = ".claude", ".credentials.json"

	home, err := os.UserHomeDir()
	if err != nil {
		return "", err
	}
	return filepath.Join(home, dirName, fileName), nil
}
// Login performs the full OAuth 2.0 authorization-code + PKCE login flow and
// returns the resulting Credential. The credential is also saved to
// DefaultCredentialPath() by the token exchange.
//
// Two completion paths race each other:
//   - loopback: the browser is opened on a URL that redirects to a local
//     callback server started for this login;
//   - manual: the printed URL redirects to manualRedirect, and the user pastes
//     the authorization code into the terminal. This path is offered only when
//     stdin is a terminal.
//
// The first code to arrive is exchanged for tokens. Login fails if no code
// arrives within 120 seconds, or when ctx is done.
func Login(ctx context.Context) (*Credential, error) {
	verifier := generateCodeVerifier()
	challenge := generateCodeChallenge(verifier)
	state := generateState()

	port, codeChan, cleanup, err := startCallbackServer(state)
	if err != nil {
		return nil, fmt.Errorf("start callback server: %w", err)
	}
	defer cleanup()

	autoURL := buildAuthURL(port, challenge, state)
	manualURL := buildManualAuthURL(challenge, state)
	fmt.Printf("\nTo sign in, visit:\n %s\n\n", manualURL)
	openBrowser(autoURL)

	stdinCh := make(chan string, 1)
	if fi, statErr := os.Stdin.Stat(); statErr == nil && fi.Mode()&os.ModeCharDevice != 0 {
		fmt.Print("If browser didn't open, paste the authorization code here: ")
		// This goroutine may stay blocked in a terminal read after Login
		// returns (e.g. when the browser path wins). A blocking stdin read
		// cannot be interrupted portably. The buffered channel guarantees it
		// never blocks on send.
		go readPastedCode(os.Stdin, stdinCh)
	}

	timeout := time.NewTimer(120 * time.Second)
	defer timeout.Stop()

	var authCode string
	isManual := false
	select {
	case authCode = <-codeChan:
	case pasted := <-stdinCh:
		code, err := parsePastedCode(pasted, state)
		if err != nil {
			return nil, err
		}
		authCode, isManual = code, true
	case <-timeout.C:
		return nil, fmt.Errorf("login timed out after 120 seconds")
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	credPath, err := DefaultCredentialPath()
	if err != nil {
		return nil, fmt.Errorf("credential path: %w", err)
	}
	return exchangeAuthCode(ctx, authCode, state, verifier, port, isManual, credPath)
}

// readPastedCode reads lines from r until one is non-blank, then sends it,
// trimmed, on out. Blank lines (e.g. an accidental Enter) are skipped rather
// than abandoning the manual path. Nothing is sent on EOF or a read error.
func readPastedCode(r io.Reader, out chan<- string) {
	scanner := bufio.NewScanner(r)
	for scanner.Scan() {
		if line := strings.TrimSpace(scanner.Text()); line != "" {
			out <- line
			return
		}
	}
}

// parsePastedCode extracts the authorization code from user-pasted text.
// NOTE(review): the manual callback page may present the value as
// "<code>#<state>" — confirm against the page. When a '#' is present, the
// suffix must equal expectedState. Text without '#' is returned unchanged.
func parsePastedCode(input, expectedState string) (string, error) {
	i := strings.IndexByte(input, '#')
	if i < 0 {
		return input, nil
	}
	code, pastedState := input[:i], input[i+1:]
	if pastedState != expectedState {
		return "", fmt.Errorf("pasted authorization code belongs to a different login attempt (state mismatch)")
	}
	if code == "" {
		return "", fmt.Errorf("pasted authorization code is empty")
	}
	return code, nil
}
// authCodeRequest is the JSON body that exchangeAuthCode POSTs to
// tokenEndpoint to redeem an authorization code. encoding/json emits fields in
// declaration order, so the field order below is the key order on the wire.
type authCodeRequest struct {
	// GrantType is always "authorization_code" as set by exchangeAuthCode.
	GrantType string `json:"grant_type"`
	// Code is the authorization code from the callback server or pasted by the user.
	Code string `json:"code"`
	// RedirectURI repeats the redirect_uri of the authorization request:
	// the loopback callback URL or manualRedirect.
	RedirectURI string `json:"redirect_uri"`
	ClientID    string `json:"client_id"`
	// CodeVerifier is the PKCE verifier whose S256 challenge was sent with the
	// authorization request.
	CodeVerifier string `json:"code_verifier"`
	// State is the value generated for this login attempt.
	State string `json:"state"`
}
// exchangeAuthCode redeems an authorization code at tokenEndpoint, saves the
// resulting credential to credPath, and returns it.
//
// redirect_uri is set to the one the authorization request used: the loopback
// callback on port, or manualRedirect when isManual is true. Non-200
// responses, undecodable bodies, and responses without an access token are
// errors.
func exchangeAuthCode(ctx context.Context, code, state, verifier string, port int, isManual bool, credPath string) (*Credential, error) {
	// Upper bound on how much of the token response is buffered.
	const maxTokenResponseBytes = 1 << 20

	redirectURI := fmt.Sprintf("http://localhost:%d/callback", port)
	if isManual {
		redirectURI = manualRedirect
	}

	reqBody, err := json.Marshal(authCodeRequest{
		GrantType:    "authorization_code",
		Code:         code,
		RedirectURI:  redirectURI,
		ClientID:     clientID,
		CodeVerifier: verifier,
		State:        state,
	})
	if err != nil {
		return nil, fmt.Errorf("encode token request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenEndpoint, bytes.NewReader(reqBody))
	if err != nil {
		return nil, fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	// Taken before the round-trip so ExpiresAt errs early rather than late.
	issuedAt := time.Now()
	resp, err := utlsClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("token exchange: %w", err)
	}
	defer resp.Body.Close()

	body, readErr := io.ReadAll(io.LimitReader(resp.Body, maxTokenResponseBytes))
	if resp.StatusCode != http.StatusOK {
		// Report the status even if the body could only be partially read.
		return nil, fmt.Errorf("token exchange returned %d: %s", resp.StatusCode, string(body))
	}
	if readErr != nil {
		return nil, fmt.Errorf("read token response: %w", readErr)
	}

	var tokenResp tokenResponse
	if err := json.Unmarshal(body, &tokenResp); err != nil {
		return nil, fmt.Errorf("decode token response: %w", err)
	}
	if tokenResp.AccessToken == "" {
		// Don't persist a credential that can never authenticate.
		return nil, fmt.Errorf("token response did not include an access token")
	}

	cred := &Credential{
		ID:           "claude-native",
		Email:        tokenResp.Account.EmailAddress,
		AccessToken:  tokenResp.AccessToken,
		RefreshToken: tokenResp.RefreshToken,
		ExpiresAt:    issuedAt.Add(time.Duration(tokenResp.ExpiresIn) * time.Second),
		FilePath:     credPath,
	}
	if err := ensureCredentialFile(credPath); err != nil {
		return nil, fmt.Errorf("ensure credential file: %w", err)
	}
	if err := persistCredential(cred); err != nil {
		return nil, fmt.Errorf("save credential: %w", err)
	}
	log.Printf("login successful, credentials saved to %s", credPath)
	return cred, nil
}
// ensureCredentialFile makes sure a credentials file exists at path. If it
// does not, it creates the parent directories (mode 0700) and a file holding
// an empty JSON object "{}" (mode 0600). An existing file is never modified.
// Stat failures other than "does not exist" (e.g. permission denied) are
// returned as-is.
func ensureCredentialFile(path string) error {
	_, err := os.Stat(path)
	if err == nil {
		return nil
	}
	if !os.IsNotExist(err) {
		return err
	}

	if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
		return err
	}

	// O_EXCL: if another process created the file since the Stat above,
	// leave its contents alone instead of truncating them to "{}".
	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0600)
	if os.IsExist(err) {
		return nil
	}
	if err != nil {
		return err
	}
	if _, err := f.Write([]byte("{}")); err != nil {
		f.Close()
		return err
	}
	return f.Close()
}
// openBrowser asks the operating system to open target in the default
// browser. It is best-effort: on unsupported platforms, or if the launcher
// cannot be started, it does nothing and reports nothing.
func openBrowser(target string) {
	var cmd *exec.Cmd
	switch runtime.GOOS {
	case "darwin":
		cmd = exec.Command("open", target)
	case "windows":
		// rundll32 avoids cmd.exe, which would treat '&' in the query string
		// as a command separator.
		cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", target)
	case "linux", "freebsd", "openbsd", "netbsd":
		cmd = exec.Command("xdg-open", target)
	default:
		return
	}
	if err := cmd.Start(); err != nil {
		return
	}
	// Reap the launcher when it exits so it doesn't linger as a zombie.
	go func() { _ = cmd.Wait() }()
}