feat(auth): add OAuth PKCE login flow with browser + manual fallback
This commit is contained in:
@@ -0,0 +1,275 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
authURL = "https://claude.com/cai/oauth/authorize"
|
||||
manualRedirect = "https://platform.claude.com/oauth/code/callback"
|
||||
)
|
||||
|
||||
func base64URLEncode(data []byte) string {
|
||||
return base64.RawURLEncoding.EncodeToString(data)
|
||||
}
|
||||
|
||||
func generateCodeVerifier() string {
|
||||
buf := make([]byte, 32)
|
||||
_, _ = rand.Read(buf)
|
||||
return base64URLEncode(buf)
|
||||
}
|
||||
|
||||
func generateCodeChallenge(verifier string) string {
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
return base64URLEncode(h[:])
|
||||
}
|
||||
|
||||
func generateState() string {
|
||||
buf := make([]byte, 32)
|
||||
_, _ = rand.Read(buf)
|
||||
return base64URLEncode(buf)
|
||||
}
|
||||
|
||||
func buildAuthURL(port int, codeChallenge, state string) string {
|
||||
u, _ := url.Parse(authURL)
|
||||
q := u.Query()
|
||||
q.Set("client_id", clientID)
|
||||
q.Set("response_type", "code")
|
||||
q.Set("redirect_uri", fmt.Sprintf("http://localhost:%d/callback", port))
|
||||
q.Set("scope", oauthScopes)
|
||||
q.Set("code_challenge", codeChallenge)
|
||||
q.Set("code_challenge_method", "S256")
|
||||
q.Set("state", state)
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func buildManualAuthURL(codeChallenge, state string) string {
|
||||
u, _ := url.Parse(authURL)
|
||||
q := u.Query()
|
||||
q.Set("client_id", clientID)
|
||||
q.Set("response_type", "code")
|
||||
q.Set("redirect_uri", manualRedirect)
|
||||
q.Set("scope", oauthScopes)
|
||||
q.Set("code_challenge", codeChallenge)
|
||||
q.Set("code_challenge_method", "S256")
|
||||
q.Set("state", state)
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func startCallbackServer(expectedState string) (port int, codeChan <-chan string, cleanup func(), err error) {
|
||||
ch := make(chan string, 1)
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return 0, nil, nil, err
|
||||
}
|
||||
port = ln.Addr().(*net.TCPAddr).Port
|
||||
srv := &http.Server{}
|
||||
srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/callback" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
if state != expectedState {
|
||||
http.Error(w, "invalid state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if code == "" {
|
||||
http.Error(w, "missing code", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case ch <- code:
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
fmt.Fprintln(w, "<html><body><h2>Login successful! You can close this tab.</h2></body></html>")
|
||||
default:
|
||||
fmt.Fprintln(w, "<html><body><h2>Already received. You can close this tab.</h2></body></html>")
|
||||
}
|
||||
})
|
||||
go srv.Serve(ln)
|
||||
cleanup = func() {
|
||||
srv.Close()
|
||||
}
|
||||
return port, ch, cleanup, nil
|
||||
}
|
||||
|
||||
// DefaultCredentialPath returns the path to the Claude credentials file.
|
||||
func DefaultCredentialPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".claude", ".credentials.json"), nil
|
||||
}
|
||||
|
||||
// Login performs the full OAuth 2.0 PKCE login flow and returns a Credential.
|
||||
func Login(ctx context.Context) (*Credential, error) {
|
||||
verifier := generateCodeVerifier()
|
||||
challenge := generateCodeChallenge(verifier)
|
||||
state := generateState()
|
||||
|
||||
port, codeChan, cleanup, err := startCallbackServer(state)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("start callback server: %w", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
autoURL := buildAuthURL(port, challenge, state)
|
||||
manualURL := buildManualAuthURL(challenge, state)
|
||||
|
||||
fmt.Printf("\nTo sign in, visit:\n %s\n\n", manualURL)
|
||||
openBrowser(autoURL)
|
||||
|
||||
var authCode string
|
||||
var isManual bool
|
||||
|
||||
stdinCh := make(chan string, 1)
|
||||
fi, _ := os.Stdin.Stat()
|
||||
if (fi.Mode() & os.ModeCharDevice) != 0 {
|
||||
fmt.Print("If browser didn't open, paste the authorization code here: ")
|
||||
go func() {
|
||||
var line string
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
if scanner.Scan() {
|
||||
line = strings.TrimSpace(scanner.Text())
|
||||
}
|
||||
if line != "" {
|
||||
stdinCh <- line
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
timeout := time.NewTimer(120 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
select {
|
||||
case code := <-codeChan:
|
||||
authCode = code
|
||||
isManual = false
|
||||
case code := <-stdinCh:
|
||||
authCode = code
|
||||
isManual = true
|
||||
case <-timeout.C:
|
||||
return nil, fmt.Errorf("login timed out after 120 seconds")
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
credPath, err := DefaultCredentialPath()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("credential path: %w", err)
|
||||
}
|
||||
return exchangeAuthCode(ctx, authCode, state, verifier, port, isManual, credPath)
|
||||
}
|
||||
|
||||
type authCodeRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
Code string `json:"code"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ClientID string `json:"client_id"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
func exchangeAuthCode(ctx context.Context, code, state, verifier string, port int, isManual bool, credPath string) (*Credential, error) {
|
||||
redirectURI := fmt.Sprintf("http://localhost:%d/callback", port)
|
||||
if isManual {
|
||||
redirectURI = manualRedirect
|
||||
}
|
||||
|
||||
reqBody, _ := json.Marshal(authCodeRequest{
|
||||
GrantType: "authorization_code",
|
||||
Code: code,
|
||||
RedirectURI: redirectURI,
|
||||
ClientID: clientID,
|
||||
CodeVerifier: verifier,
|
||||
State: state,
|
||||
})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenEndpoint, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := utlsClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token exchange: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token exchange returned %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tokenResp tokenResponse
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("decode token response: %w", err)
|
||||
}
|
||||
|
||||
cred := &Credential{
|
||||
ID: "claude-native",
|
||||
Email: tokenResp.Account.EmailAddress,
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ExpiresAt: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second),
|
||||
FilePath: credPath,
|
||||
}
|
||||
|
||||
if err := ensureCredentialFile(credPath); err != nil {
|
||||
return nil, fmt.Errorf("ensure credential file: %w", err)
|
||||
}
|
||||
|
||||
if err := persistCredential(cred); err != nil {
|
||||
return nil, fmt.Errorf("save credential: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("login successful, credentials saved to %s", credPath)
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
func ensureCredentialFile(path string) error {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return nil
|
||||
}
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, []byte("{}"), 0600)
|
||||
}
|
||||
|
||||
func openBrowser(url string) {
|
||||
var cmd *exec.Cmd
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
cmd = exec.Command("open", url)
|
||||
case "linux":
|
||||
cmd = exec.Command("xdg-open", url)
|
||||
default:
|
||||
return
|
||||
}
|
||||
_ = cmd.Start()
|
||||
}
|
||||
Reference in New Issue
Block a user