276 lines
6.9 KiB
Go
276 lines
6.9 KiB
Go
package auth
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"runtime"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// authURL is the OAuth 2.0 authorization endpoint on which both sign-in URLs
// (browser-callback and manual copy/paste) are built.
const authURL = "https://claude.com/cai/oauth/authorize"

// manualRedirect is the redirect_uri used for the manual flow, in which the
// user pastes the resulting authorization code into the terminal instead of
// having it delivered to the local callback server.
const manualRedirect = "https://platform.claude.com/oauth/code/callback"
|
|
|
|
// base64URLEncode encodes data with the URL-safe base64 alphabet and no '='
// padding, the encoding used for the PKCE verifier, challenge and state.
func base64URLEncode(data []byte) string {
	enc := base64.RawURLEncoding
	out := make([]byte, enc.EncodedLen(len(data)))
	enc.Encode(out, data)
	return string(out)
}
|
|
|
|
func generateCodeVerifier() string {
|
|
buf := make([]byte, 32)
|
|
_, _ = rand.Read(buf)
|
|
return base64URLEncode(buf)
|
|
}
|
|
|
|
func generateCodeChallenge(verifier string) string {
|
|
h := sha256.Sum256([]byte(verifier))
|
|
return base64URLEncode(h[:])
|
|
}
|
|
|
|
func generateState() string {
|
|
buf := make([]byte, 32)
|
|
_, _ = rand.Read(buf)
|
|
return base64URLEncode(buf)
|
|
}
|
|
|
|
func buildAuthURL(port int, codeChallenge, state string) string {
|
|
u, _ := url.Parse(authURL)
|
|
q := u.Query()
|
|
q.Set("client_id", clientID)
|
|
q.Set("response_type", "code")
|
|
q.Set("redirect_uri", fmt.Sprintf("http://localhost:%d/callback", port))
|
|
q.Set("scope", oauthScopes)
|
|
q.Set("code_challenge", codeChallenge)
|
|
q.Set("code_challenge_method", "S256")
|
|
q.Set("state", state)
|
|
u.RawQuery = q.Encode()
|
|
return u.String()
|
|
}
|
|
|
|
func buildManualAuthURL(codeChallenge, state string) string {
|
|
u, _ := url.Parse(authURL)
|
|
q := u.Query()
|
|
q.Set("client_id", clientID)
|
|
q.Set("response_type", "code")
|
|
q.Set("redirect_uri", manualRedirect)
|
|
q.Set("scope", oauthScopes)
|
|
q.Set("code_challenge", codeChallenge)
|
|
q.Set("code_challenge_method", "S256")
|
|
q.Set("state", state)
|
|
u.RawQuery = q.Encode()
|
|
return u.String()
|
|
}
|
|
|
|
// startCallbackServer starts an HTTP server on an ephemeral localhost port
// to receive the OAuth redirect at /callback.
//
// Handling of requests:
//   - A path other than /callback gets 404.
//   - A mismatched state parameter gets 400.
//   - An OAuth error response (error=... and no code) gets 400.
//   - A missing code gets 400.
//
// The first valid code is delivered on codeChan, which has a buffer of one.
// Later valid requests are acknowledged in the browser, but their codes are
// dropped. cleanup closes the server and its listener.
func startCallbackServer(expectedState string) (port int, codeChan <-chan string, cleanup func(), err error) {
	ln, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		return 0, nil, nil, err
	}
	port = ln.Addr().(*net.TCPAddr).Port

	ch := make(chan string, 1)
	srv := &http.Server{
		// Bound header reads so a stalled client cannot hold a connection
		// open indefinitely (gosec G112).
		ReadHeaderTimeout: 10 * time.Second,
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path != "/callback" {
				http.NotFound(w, r)
				return
			}
			query := r.URL.Query()
			if query.Get("state") != expectedState {
				http.Error(w, "invalid state", http.StatusBadRequest)
				return
			}
			code := query.Get("code")
			if code == "" {
				// Authorization failures (e.g. the user denying consent) arrive
				// as an "error" parameter instead of a code; surface it.
				// http.Error serves text/plain, so echoing it is not an XSS risk.
				if oauthErr := query.Get("error"); oauthErr != "" {
					http.Error(w, "authorization failed: "+oauthErr, http.StatusBadRequest)
					return
				}
				http.Error(w, "missing code", http.StatusBadRequest)
				return
			}
			w.Header().Set("Content-Type", "text/html")
			select {
			case ch <- code:
				fmt.Fprintln(w, "<html><body><h2>Login successful! You can close this tab.</h2></body></html>")
			default:
				fmt.Fprintln(w, "<html><body><h2>Already received. You can close this tab.</h2></body></html>")
			}
		}),
	}

	// Serve returns http.ErrServerClosed once cleanup runs; nothing to report.
	go func() { _ = srv.Serve(ln) }()

	cleanup = func() {
		_ = srv.Close()
	}
	return port, ch, cleanup, nil
}
|
|
|
|
// DefaultCredentialPath returns the path to the Claude credentials file:
// .claude/.credentials.json inside the current user's home directory.
// If the home directory cannot be determined, the error from
// os.UserHomeDir is returned unchanged along with an empty path.
func DefaultCredentialPath() (string, error) {
	homeDir, lookupErr := os.UserHomeDir()
	if lookupErr == nil {
		return filepath.Join(homeDir, ".claude", ".credentials.json"), nil
	}
	return "", lookupErr
}
|
|
|
|
// Login performs the full OAuth 2.0 PKCE login flow and returns a Credential.
|
|
func Login(ctx context.Context) (*Credential, error) {
|
|
verifier := generateCodeVerifier()
|
|
challenge := generateCodeChallenge(verifier)
|
|
state := generateState()
|
|
|
|
port, codeChan, cleanup, err := startCallbackServer(state)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("start callback server: %w", err)
|
|
}
|
|
defer cleanup()
|
|
|
|
autoURL := buildAuthURL(port, challenge, state)
|
|
manualURL := buildManualAuthURL(challenge, state)
|
|
|
|
fmt.Printf("\nTo sign in, visit:\n %s\n\n", manualURL)
|
|
openBrowser(autoURL)
|
|
|
|
var authCode string
|
|
var isManual bool
|
|
|
|
stdinCh := make(chan string, 1)
|
|
fi, statErr := os.Stdin.Stat()
|
|
if statErr == nil && (fi.Mode()&os.ModeCharDevice) != 0 {
|
|
fmt.Print("If browser didn't open, paste the authorization code here: ")
|
|
go func() {
|
|
var line string
|
|
scanner := bufio.NewScanner(os.Stdin)
|
|
if scanner.Scan() {
|
|
line = strings.TrimSpace(scanner.Text())
|
|
}
|
|
if line != "" {
|
|
stdinCh <- line
|
|
}
|
|
}()
|
|
}
|
|
|
|
timeout := time.NewTimer(120 * time.Second)
|
|
defer timeout.Stop()
|
|
|
|
select {
|
|
case code := <-codeChan:
|
|
authCode = code
|
|
isManual = false
|
|
case code := <-stdinCh:
|
|
authCode = code
|
|
isManual = true
|
|
case <-timeout.C:
|
|
return nil, fmt.Errorf("login timed out after 120 seconds")
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
|
|
credPath, err := DefaultCredentialPath()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("credential path: %w", err)
|
|
}
|
|
return exchangeAuthCode(ctx, authCode, state, verifier, port, isManual, credPath)
|
|
}
|
|
|
|
// authCodeRequest is the JSON body that exchangeAuthCode POSTs to
// tokenEndpoint to redeem an authorization code together with its PKCE
// code_verifier. Field order is also the order of keys in the encoded JSON.
type authCodeRequest struct {
	GrantType    string `json:"grant_type"`    // exchangeAuthCode always sets "authorization_code"
	Code         string `json:"code"`          // code from the callback server or pasted on stdin
	RedirectURI  string `json:"redirect_uri"`  // same redirect URI as in the authorization URL the code came from
	ClientID     string `json:"client_id"`
	CodeVerifier string `json:"code_verifier"` // PKCE verifier whose S256 challenge was sent in the authorization URL
	State        string `json:"state"`         // the state value generated by Login
}
|
|
|
|
func exchangeAuthCode(ctx context.Context, code, state, verifier string, port int, isManual bool, credPath string) (*Credential, error) {
|
|
redirectURI := fmt.Sprintf("http://localhost:%d/callback", port)
|
|
if isManual {
|
|
redirectURI = manualRedirect
|
|
}
|
|
|
|
reqBody, _ := json.Marshal(authCodeRequest{
|
|
GrantType: "authorization_code",
|
|
Code: code,
|
|
RedirectURI: redirectURI,
|
|
ClientID: clientID,
|
|
CodeVerifier: verifier,
|
|
State: state,
|
|
})
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenEndpoint, bytes.NewReader(reqBody))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := utlsClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("token exchange: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("token exchange returned %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var tokenResp tokenResponse
|
|
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
|
return nil, fmt.Errorf("decode token response: %w", err)
|
|
}
|
|
|
|
cred := &Credential{
|
|
ID: "claude-native",
|
|
Email: tokenResp.Account.EmailAddress,
|
|
AccessToken: tokenResp.AccessToken,
|
|
RefreshToken: tokenResp.RefreshToken,
|
|
ExpiresAt: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second),
|
|
FilePath: credPath,
|
|
}
|
|
|
|
if err := ensureCredentialFile(credPath); err != nil {
|
|
return nil, fmt.Errorf("ensure credential file: %w", err)
|
|
}
|
|
|
|
if err := persistCredential(cred); err != nil {
|
|
return nil, fmt.Errorf("save credential: %w", err)
|
|
}
|
|
|
|
log.Printf("login successful, credentials saved to %s", credPath)
|
|
return cred, nil
|
|
}
|
|
|
|
// ensureCredentialFile makes sure a credentials file exists at path.
//
// When the file is missing, it creates the parent directory (mode 0700) and
// a file containing the empty JSON object "{}" (mode 0600). An existing file
// is never modified. A Stat failure other than "does not exist" is returned
// as-is.
func ensureCredentialFile(path string) error {
	_, err := os.Stat(path)
	if err == nil {
		return nil
	}
	if !os.IsNotExist(err) {
		return err
	}

	if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
		return err
	}

	// O_EXCL closes the race between the Stat above and this create. If
	// another process wrote the file in the meantime, it is left untouched
	// instead of being truncated to "{}".
	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0600)
	if err != nil {
		if os.IsExist(err) {
			return nil
		}
		return err
	}
	if _, err := f.WriteString("{}"); err != nil {
		_ = f.Close()
		return err
	}
	return f.Close()
}
|
|
|
|
// openBrowser asks the desktop environment to open target in the default
// browser. It is best-effort: on unsupported platforms, or when the launcher
// cannot be started, it silently does nothing. Callers should therefore also
// display the URL to the user.
func openBrowser(target string) {
	// The parameter is not named "url" so it does not shadow the net/url import.
	var cmd *exec.Cmd
	switch runtime.GOOS {
	case "darwin":
		cmd = exec.Command("open", target)
	case "linux", "freebsd", "openbsd", "netbsd":
		cmd = exec.Command("xdg-open", target)
	case "windows":
		// rundll32 hands the URL to the registered protocol handler without
		// going through cmd.exe, whose parser would split the URL at '&'.
		cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", target)
	default:
		return
	}
	if err := cmd.Start(); err != nil {
		return
	}
	// Reap the launcher when it exits so it does not linger as a zombie
	// process for the rest of this program's lifetime.
	go func() { _ = cmd.Wait() }()
}
|