diff --git a/internal/auth/login.go b/internal/auth/login.go new file mode 100644 index 0000000..c86f2c3 --- /dev/null +++ b/internal/auth/login.go @@ -0,0 +1,275 @@ +package auth + +import ( + "bufio" + "bytes" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log" + "net" + "net/http" + "net/url" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "time" +) + +const ( + authURL = "https://claude.com/cai/oauth/authorize" + manualRedirect = "https://platform.claude.com/oauth/code/callback" +) + +func base64URLEncode(data []byte) string { + return base64.RawURLEncoding.EncodeToString(data) +} + +func generateCodeVerifier() string { + buf := make([]byte, 32) + _, _ = rand.Read(buf) + return base64URLEncode(buf) +} + +func generateCodeChallenge(verifier string) string { + h := sha256.Sum256([]byte(verifier)) + return base64URLEncode(h[:]) +} + +func generateState() string { + buf := make([]byte, 32) + _, _ = rand.Read(buf) + return base64URLEncode(buf) +} + +func buildAuthURL(port int, codeChallenge, state string) string { + u, _ := url.Parse(authURL) + q := u.Query() + q.Set("client_id", clientID) + q.Set("response_type", "code") + q.Set("redirect_uri", fmt.Sprintf("http://localhost:%d/callback", port)) + q.Set("scope", oauthScopes) + q.Set("code_challenge", codeChallenge) + q.Set("code_challenge_method", "S256") + q.Set("state", state) + u.RawQuery = q.Encode() + return u.String() +} + +func buildManualAuthURL(codeChallenge, state string) string { + u, _ := url.Parse(authURL) + q := u.Query() + q.Set("client_id", clientID) + q.Set("response_type", "code") + q.Set("redirect_uri", manualRedirect) + q.Set("scope", oauthScopes) + q.Set("code_challenge", codeChallenge) + q.Set("code_challenge_method", "S256") + q.Set("state", state) + u.RawQuery = q.Encode() + return u.String() +} + +func startCallbackServer(expectedState string) (port int, codeChan <-chan string, cleanup func(), err error) { + ch := make(chan string, 1) + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return 0, nil, nil, err + } + port = ln.Addr().(*net.TCPAddr).Port + srv := &http.Server{} + srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/callback" { + http.NotFound(w, r) + return + } + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + if state != expectedState { + http.Error(w, "invalid state", http.StatusBadRequest) + return + } + if code == "" { + http.Error(w, "missing code", http.StatusBadRequest) + return + } + select { + case ch <- code: + w.Header().Set("Content-Type", "text/html") + fmt.Fprintln(w, "

Login successful! You can close this tab.

") + default: + fmt.Fprintln(w, "

Already received. You can close this tab.

") + } + }) + go srv.Serve(ln) + cleanup = func() { + srv.Close() + } + return port, ch, cleanup, nil +} + +// DefaultCredentialPath returns the path to the Claude credentials file. +func DefaultCredentialPath() (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + return filepath.Join(home, ".claude", ".credentials.json"), nil +} + +// Login performs the full OAuth 2.0 PKCE login flow and returns a Credential. +func Login(ctx context.Context) (*Credential, error) { + verifier := generateCodeVerifier() + challenge := generateCodeChallenge(verifier) + state := generateState() + + port, codeChan, cleanup, err := startCallbackServer(state) + if err != nil { + return nil, fmt.Errorf("start callback server: %w", err) + } + defer cleanup() + + autoURL := buildAuthURL(port, challenge, state) + manualURL := buildManualAuthURL(challenge, state) + + fmt.Printf("\nTo sign in, visit:\n %s\n\n", manualURL) + openBrowser(autoURL) + + var authCode string + var isManual bool + + stdinCh := make(chan string, 1) + fi, _ := os.Stdin.Stat() + if (fi.Mode() & os.ModeCharDevice) != 0 { + fmt.Print("If browser didn't open, paste the authorization code here: ") + go func() { + var line string + scanner := bufio.NewScanner(os.Stdin) + if scanner.Scan() { + line = strings.TrimSpace(scanner.Text()) + } + if line != "" { + stdinCh <- line + } + }() + } + + timeout := time.NewTimer(120 * time.Second) + defer timeout.Stop() + + select { + case code := <-codeChan: + authCode = code + isManual = false + case code := <-stdinCh: + authCode = code + isManual = true + case <-timeout.C: + return nil, fmt.Errorf("login timed out after 120 seconds") + case <-ctx.Done(): + return nil, ctx.Err() + } + + credPath, err := DefaultCredentialPath() + if err != nil { + return nil, fmt.Errorf("credential path: %w", err) + } + return exchangeAuthCode(ctx, authCode, state, verifier, port, isManual, credPath) +} + +type authCodeRequest struct { + GrantType string `json:"grant_type"` + Code string `json:"code"` + RedirectURI string `json:"redirect_uri"` + ClientID string `json:"client_id"` + CodeVerifier string `json:"code_verifier"` + State string `json:"state"` +} + +func exchangeAuthCode(ctx context.Context, code, state, verifier string, port int, isManual bool, credPath string) (*Credential, error) { + redirectURI := fmt.Sprintf("http://localhost:%d/callback", port) + if isManual { + redirectURI = manualRedirect + } + + reqBody, _ := json.Marshal(authCodeRequest{ + GrantType: "authorization_code", + Code: code, + RedirectURI: redirectURI, + ClientID: clientID, + CodeVerifier: verifier, + State: state, + }) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenEndpoint, bytes.NewReader(reqBody)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := utlsClient.Do(req) + if err != nil { + return nil, fmt.Errorf("token exchange: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token exchange returned %d: %s", resp.StatusCode, string(body)) + } + + var tokenResp tokenResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, fmt.Errorf("decode token response: %w", err) + } + + cred := &Credential{ + ID: "claude-native", + Email: tokenResp.Account.EmailAddress, + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ExpiresAt: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second), + FilePath: credPath, + } + + if err := ensureCredentialFile(credPath); err != nil { + return nil, fmt.Errorf("ensure credential file: %w", err) + } + + if err := persistCredential(cred); err != nil { + return nil, fmt.Errorf("save credential: %w", err) + } + + log.Printf("login successful, credentials saved to %s", credPath) + return cred, nil +} + +func ensureCredentialFile(path string) error { + if _, err := os.Stat(path); err == nil { + return nil + } + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0700); err != nil { + return err + } + return os.WriteFile(path, []byte("{}"), 0600) +} + +func openBrowser(url string) { + var cmd *exec.Cmd + switch runtime.GOOS { + case "darwin": + cmd = exec.Command("open", url) + case "linux": + cmd = exec.Command("xdg-open", url) + default: + return + } + _ = cmd.Start() +}