From 9851ca67272f69b5ad9ba2646df719206b90f2ea Mon Sep 17 00:00:00 2001 From: Gregory Schier Date: Sun, 22 Feb 2026 07:48:08 -0800 Subject: [PATCH] Add auth resource with styled CLI output --- Cargo.lock | 49 +++ crates-cli/yaak-cli/Cargo.toml | 10 +- crates-cli/yaak-cli/src/cli.rs | 21 + crates-cli/yaak-cli/src/commands/auth.rs | 492 +++++++++++++++++++++++ crates-cli/yaak-cli/src/commands/mod.rs | 1 + crates-cli/yaak-cli/src/main.rs | 2 + crates-cli/yaak-cli/src/ui.rs | 34 ++ 7 files changed, 608 insertions(+), 1 deletion(-) create mode 100644 crates-cli/yaak-cli/src/commands/auth.rs create mode 100644 crates-cli/yaak-cli/src/ui.rs diff --git a/Cargo.lock b/Cargo.lock index 8ae3b89b..9a4f1f63 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1057,6 +1057,19 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width", + "windows-sys 0.59.0", +] + [[package]] name = "const-random" version = "0.1.18" @@ -1569,6 +1582,12 @@ version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ef6b89e5b37196644d8796de5268852ff179b44e96276cf4290264843743bb7" +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "encoding_rs" version = "0.8.35" @@ -7146,6 +7165,12 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + [[package]] name = "universal-hash" version = "0.5.1" @@ -7510,6 +7535,22 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webbrowser" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f00bb839c1cf1e3036066614cbdcd035ecf215206691ea646aa3c60a24f68f2" +dependencies = [ + "core-foundation 0.10.1", + "jni", + "log 0.4.29", + "ndk-context", + "objc2 0.6.1", + "objc2-foundation 0.3.1", + "url", + "web-sys", +] + [[package]] name = "webkit2gtk" version = "2.0.1" @@ -8353,17 +8394,25 @@ name = "yaak-cli" version = "0.1.0" dependencies = [ "assert_cmd", + "base64 0.22.1", "clap", + "console", "dirs", "env_logger", "futures", + "hex", + "keyring", "log 0.4.29", "predicates", + "rand 0.8.5", + "reqwest", "schemars", "serde", "serde_json", + "sha2", "tempfile", "tokio", + "webbrowser", "yaak", "yaak-crypto", "yaak-http", diff --git a/crates-cli/yaak-cli/Cargo.toml b/crates-cli/yaak-cli/Cargo.toml index 54e7385e..149da3be 100644 --- a/crates-cli/yaak-cli/Cargo.toml +++ b/crates-cli/yaak-cli/Cargo.toml @@ -9,15 +9,23 @@ name = "yaakcli" path = "src/main.rs" [dependencies] +base64 = "0.22" clap = { version = "4", features = ["derive"] } +console = "0.15" dirs = "6" env_logger = "0.11" futures = "0.3" +hex = { workspace = true } +keyring = { workspace = true, features = ["apple-native", "windows-native", "sync-secret-service"] } log = { workspace = true } +rand = "0.8" +reqwest = { workspace = true } schemars = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } -tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } +sha2 = { workspace = true } +tokio = { workspace = true, features = ["rt-multi-thread", "macros", "io-util", "net", "signal", "time"] } +webbrowser = "1" yaak = { workspace = true } yaak-crypto = { workspace = true } yaak-http = { workspace = true } diff --git a/crates-cli/yaak-cli/src/cli.rs b/crates-cli/yaak-cli/src/cli.rs index 2b74ca9b..c45150cb 100644 --- a/crates-cli/yaak-cli/src/cli.rs +++ b/crates-cli/yaak-cli/src/cli.rs @@ -23,6 +23,9 @@ pub struct Cli { #[derive(Subcommand)] pub enum Commands { + /// Authentication commands + Auth(AuthArgs), + /// Send a request, folder, or workspace by ID Send(SendArgs), @@ -305,3 +308,21 @@ pub enum EnvironmentCommands { yes: bool, }, } + +#[derive(Args)] +pub struct AuthArgs { + #[command(subcommand)] + pub command: AuthCommands, +} + +#[derive(Subcommand)] +pub enum AuthCommands { + /// Login to Yaak via web browser + Login, + + /// Sign out of the Yaak CLI + Logout, + + /// Print the current logged-in user's info + Whoami, +} diff --git a/crates-cli/yaak-cli/src/commands/auth.rs b/crates-cli/yaak-cli/src/commands/auth.rs new file mode 100644 index 00000000..b724da8b --- /dev/null +++ b/crates-cli/yaak-cli/src/commands/auth.rs @@ -0,0 +1,492 @@ +use crate::cli::{AuthArgs, AuthCommands}; +use crate::ui; +use base64::Engine as _; +use keyring::Entry; +use rand::RngCore; +use rand::rngs::OsRng; +use reqwest::Url; +use serde_json::Value; +use sha2::{Digest, Sha256}; +use std::io::{self, IsTerminal, Write}; +use std::time::Duration; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; + +const OAUTH_CLIENT_ID: &str = "a1fe44800c2d7e803cad1b4bf07a291c"; +const KEYRING_USER: &str = "yaak"; +const AUTH_TIMEOUT: Duration = Duration::from_secs(300); +const MAX_REQUEST_BYTES: usize = 16 * 1024; + +type CommandResult = std::result::Result; + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum Environment { + Production, + Staging, + Development, +} + +impl Environment { + fn app_base_url(self) -> &'static str { + match self { + Environment::Production => "https://yaak.app", + Environment::Staging => "https://todo.yaak.app", + Environment::Development => "http://localhost:9444", + } + } + + fn api_base_url(self) -> &'static str { + match self { + Environment::Production => "https://api.yaak.app", + Environment::Staging => "https://todo.yaak.app", + Environment::Development => "http://localhost:9444", + } + } + + fn keyring_service(self) -> &'static str { + match self { + Environment::Production => "app.yaak.cli.Token", + Environment::Staging => "app.yaak.cli.staging.Token", + Environment::Development => "app.yaak.cli.dev.Token", + } + } +} + +struct OAuthFlow { + app_base_url: String, + auth_url: Url, + token_url: String, + redirect_url: String, + state: String, + code_verifier: String, +} + +pub async fn run(args: AuthArgs) -> i32 { + let result = match args.command { + AuthCommands::Login => login().await, + AuthCommands::Logout => logout(), + AuthCommands::Whoami => whoami().await, + }; + + match result { + Ok(()) => 0, + Err(error) => { + ui::error(&error); + 1 + } + } +} + +async fn login() -> CommandResult { + let environment = current_environment(); + delete_auth_token(environment)?; + + let listener = TcpListener::bind("127.0.0.1:0") + .await + .map_err(|e| format!("Failed to start OAuth callback server: {e}"))?; + let port = listener + .local_addr() + .map_err(|e| format!("Failed to determine callback server port: {e}"))? + .port(); + + let oauth = build_oauth_flow(environment, port)?; + + ui::info(&format!("Initiating login to {}", oauth.auth_url)); + if !confirm_open_browser()? { + ui::info("Login canceled"); + return Ok(()); + } + + if let Err(err) = webbrowser::open(oauth.auth_url.as_ref()) { + ui::warning(&format!("Failed to open browser: {err}")); + ui::info(&format!("Open this URL manually:\n{}", oauth.auth_url)); + } + ui::info("Waiting for authentication..."); + + let code = tokio::select! { + result = receive_oauth_code(listener, &oauth.state, &oauth.app_base_url) => result?, + _ = tokio::signal::ctrl_c() => { + return Err("Interrupted by user".to_string()); + } + _ = tokio::time::sleep(AUTH_TIMEOUT) => { + return Err("Timeout waiting for authentication".to_string()); + } + }; + + let token = exchange_access_token(&oauth, &code).await?; + store_auth_token(environment, &token)?; + ui::success("Authentication successful!"); + Ok(()) +} + +fn logout() -> CommandResult { + delete_auth_token(current_environment())?; + ui::success("Signed out of Yaak"); + Ok(()) +} + +async fn whoami() -> CommandResult { + let environment = current_environment(); + let token = match get_auth_token(environment)? { + Some(token) => token, + None => { + ui::warning("Not logged in"); + ui::info("Please run `yaakcli auth login`"); + return Ok(()); + } + }; + + let url = format!("{}/api/v1/whoami", environment.api_base_url()); + let response = reqwest::Client::new() + .get(url) + .header("X-Yaak-Session", token) + .header(reqwest::header::USER_AGENT, user_agent()) + .send() + .await + .map_err(|e| format!("Failed to call whoami endpoint: {e}"))?; + + let status = response.status(); + let body = + response.text().await.map_err(|e| format!("Failed to read whoami response body: {e}"))?; + + if !status.is_success() { + if status.as_u16() == 401 { + let _ = delete_auth_token(environment); + return Err( + "Unauthorized to access CLI. Run `yaakcli auth login` to refresh credentials." + .to_string(), + ); + } + return Err(parse_api_error(status.as_u16(), &body)); + } + + println!("{body}"); + Ok(()) +} + +fn current_environment() -> Environment { + let value = std::env::var("ENVIRONMENT").ok(); + parse_environment(value.as_deref()) +} + +fn parse_environment(value: Option<&str>) -> Environment { + match value { + Some("staging") => Environment::Staging, + Some("development") => Environment::Development, + _ => Environment::Production, + } +} + +fn build_oauth_flow(environment: Environment, callback_port: u16) -> CommandResult { + let code_verifier = random_hex(32); + let state = random_hex(24); + let redirect_url = format!("http://127.0.0.1:{callback_port}/oauth/callback"); + + let code_challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(Sha256::digest(code_verifier.as_bytes())); + + let mut auth_url = Url::parse(&format!("{}/login/oauth/authorize", environment.app_base_url())) + .map_err(|e| format!("Failed to build OAuth authorize URL: {e}"))?; + auth_url + .query_pairs_mut() + .append_pair("response_type", "code") + .append_pair("client_id", OAUTH_CLIENT_ID) + .append_pair("redirect_uri", &redirect_url) + .append_pair("state", &state) + .append_pair("code_challenge_method", "S256") + .append_pair("code_challenge", &code_challenge); + + Ok(OAuthFlow { + app_base_url: environment.app_base_url().to_string(), + auth_url, + token_url: format!("{}/login/oauth/access_token", environment.app_base_url()), + redirect_url, + state, + code_verifier, + }) +} + +async fn receive_oauth_code( + listener: TcpListener, + expected_state: &str, + app_base_url: &str, +) -> CommandResult { + loop { + let (mut stream, _) = listener + .accept() + .await + .map_err(|e| format!("OAuth callback server accept error: {e}"))?; + + match parse_callback_request(&mut stream).await { + Ok((state, code)) => { + if state != expected_state { + let _ = write_bad_request(&mut stream, "Invalid OAuth state").await; + continue; + } + + let success_redirect = format!("{app_base_url}/login/oauth/success"); + write_redirect(&mut stream, &success_redirect) + .await + .map_err(|e| format!("Failed responding to OAuth callback: {e}"))?; + return Ok(code); + } + Err(error) => { + let _ = write_bad_request(&mut stream, &error).await; + } + } + } +} + +async fn parse_callback_request(stream: &mut TcpStream) -> CommandResult<(String, String)> { + let target = read_http_target(stream).await?; + if !target.starts_with("/oauth/callback") { + return Err("Expected /oauth/callback path".to_string()); + } + + let url = Url::parse(&format!("http://127.0.0.1{target}")) + .map_err(|e| format!("Failed to parse callback URL: {e}"))?; + let mut state: Option = None; + let mut code: Option = None; + + for (k, v) in url.query_pairs() { + if k == "state" { + state = Some(v.into_owned()); + } else if k == "code" { + code = Some(v.into_owned()); + } + } + + let state = state.ok_or_else(|| "Missing 'state' query parameter".to_string())?; + let code = code.ok_or_else(|| "Missing 'code' query parameter".to_string())?; + + if code.is_empty() { + return Err("Missing 'code' query parameter".to_string()); + } + + Ok((state, code)) +} + +async fn read_http_target(stream: &mut TcpStream) -> CommandResult { + let mut buf = vec![0_u8; MAX_REQUEST_BYTES]; + let mut total_read = 0_usize; + + loop { + let n = stream + .read(&mut buf[total_read..]) + .await + .map_err(|e| format!("Failed reading callback request: {e}"))?; + if n == 0 { + break; + } + total_read += n; + + if buf[..total_read].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + + if total_read == MAX_REQUEST_BYTES { + return Err("OAuth callback request too large".to_string()); + } + } + + let req = String::from_utf8_lossy(&buf[..total_read]); + let request_line = + req.lines().next().ok_or_else(|| "Invalid callback request line".to_string())?; + let mut parts = request_line.split_whitespace(); + let method = parts.next().unwrap_or_default(); + let target = parts.next().unwrap_or_default(); + + if method != "GET" { + return Err(format!("Expected GET callback request, got '{method}'")); + } + if target.is_empty() { + return Err("Missing callback request target".to_string()); + } + + Ok(target.to_string()) +} + +async fn write_bad_request(stream: &mut TcpStream, message: &str) -> std::io::Result<()> { + let body = format!("Failed to authenticate: {message}"); + let response = format!( + "HTTP/1.1 400 Bad Request\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes()).await?; + stream.shutdown().await +} + +async fn write_redirect(stream: &mut TcpStream, location: &str) -> std::io::Result<()> { + let response = format!( + "HTTP/1.1 302 Found\r\nLocation: {location}\r\nContent-Length: 0\r\nConnection: close\r\n\r\n" + ); + stream.write_all(response.as_bytes()).await?; + stream.shutdown().await +} + +async fn exchange_access_token(oauth: &OAuthFlow, code: &str) -> CommandResult { + let response = reqwest::Client::new() + .post(&oauth.token_url) + .header(reqwest::header::USER_AGENT, user_agent()) + .form(&[ + ("grant_type", "authorization_code"), + ("client_id", OAUTH_CLIENT_ID), + ("code", code), + ("redirect_uri", oauth.redirect_url.as_str()), + ("code_verifier", oauth.code_verifier.as_str()), + ]) + .send() + .await + .map_err(|e| format!("Failed to exchange OAuth code for access token: {e}"))?; + + let status = response.status(); + let body = + response.text().await.map_err(|e| format!("Failed to read token response body: {e}"))?; + + if !status.is_success() { + return Err(format!( + "Failed to fetch access token: status={} body={}", + status.as_u16(), + body + )); + } + + let parsed: Value = + serde_json::from_str(&body).map_err(|e| format!("Invalid token response JSON: {e}"))?; + let token = parsed + .get("access_token") + .and_then(Value::as_str) + .filter(|s| !s.is_empty()) + .ok_or_else(|| format!("Token response missing access_token: {body}"))?; + + Ok(token.to_string()) +} + +fn keyring_entry(environment: Environment) -> CommandResult { + Entry::new(environment.keyring_service(), KEYRING_USER) + .map_err(|e| format!("Failed to initialize auth keyring entry: {e}")) +} + +fn get_auth_token(environment: Environment) -> CommandResult> { + let entry = keyring_entry(environment)?; + match entry.get_password() { + Ok(token) => Ok(Some(token)), + Err(keyring::Error::NoEntry) => Ok(None), + Err(err) => Err(format!("Failed to read auth token: {err}")), + } +} + +fn store_auth_token(environment: Environment, token: &str) -> CommandResult { + let entry = keyring_entry(environment)?; + entry.set_password(token).map_err(|e| format!("Failed to store auth token: {e}")) +} + +fn delete_auth_token(environment: Environment) -> CommandResult { + let entry = keyring_entry(environment)?; + match entry.delete_credential() { + Ok(()) | Err(keyring::Error::NoEntry) => Ok(()), + Err(err) => Err(format!("Failed to delete auth token: {err}")), + } +} + +fn parse_api_error(status: u16, body: &str) -> String { + if let Ok(value) = serde_json::from_str::(body) { + if let Some(message) = value.get("message").and_then(Value::as_str) { + return message.to_string(); + } + if let Some(error) = value.get("error").and_then(Value::as_str) { + return error.to_string(); + } + } + + format!("API error {status}: {body}") +} + +fn random_hex(bytes: usize) -> String { + let mut data = vec![0_u8; bytes]; + OsRng.fill_bytes(&mut data); + hex::encode(data) +} + +fn user_agent() -> String { + format!("YaakCli/{} ({})", env!("CARGO_PKG_VERSION"), ua_platform()) +} + +fn ua_platform() -> &'static str { + match std::env::consts::OS { + "windows" => "Win", + "darwin" => "Mac", + "linux" => "Linux", + _ => "Unknown", + } +} + +fn confirm_open_browser() -> CommandResult { + if !io::stdin().is_terminal() { + return Ok(true); + } + + loop { + print!("Open default browser? [Y/n]: "); + io::stdout().flush().map_err(|e| format!("Failed to flush stdout: {e}"))?; + + let mut input = String::new(); + io::stdin().read_line(&mut input).map_err(|e| format!("Failed to read input: {e}"))?; + + match input.trim().to_ascii_lowercase().as_str() { + "" | "y" | "yes" => return Ok(true), + "n" | "no" => return Ok(false), + _ => ui::warning("Please answer y or n"), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn environment_mapping() { + assert_eq!(parse_environment(Some("staging")), Environment::Staging); + assert_eq!(parse_environment(Some("development")), Environment::Development); + assert_eq!(parse_environment(Some("production")), Environment::Production); + assert_eq!(parse_environment(None), Environment::Production); + } + + #[tokio::test] + async fn parses_callback_request() { + let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind"); + let addr = listener.local_addr().expect("local addr"); + + let server = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.expect("accept"); + parse_callback_request(&mut stream).await + }); + + let mut client = TcpStream::connect(addr).await.expect("connect"); + client + .write_all( + b"GET /oauth/callback?code=abc123&state=xyz HTTP/1.1\r\nHost: localhost\r\n\r\n", + ) + .await + .expect("write"); + + let parsed = server.await.expect("join").expect("parse"); + assert_eq!(parsed.0, "xyz"); + assert_eq!(parsed.1, "abc123"); + } + + #[test] + fn builds_oauth_flow_with_pkce() { + let flow = build_oauth_flow(Environment::Development, 8080).expect("flow"); + assert!(flow.auth_url.as_str().contains("code_challenge_method=S256")); + assert!( + flow.auth_url + .as_str() + .contains("redirect_uri=http%3A%2F%2F127.0.0.1%3A8080%2Foauth%2Fcallback") + ); + assert_eq!(flow.redirect_url, "http://127.0.0.1:8080/oauth/callback"); + assert_eq!(flow.token_url, "http://localhost:9444/login/oauth/access_token"); + } +} diff --git a/crates-cli/yaak-cli/src/commands/mod.rs b/crates-cli/yaak-cli/src/commands/mod.rs index 76c1ac0a..cfd6034f 100644 --- a/crates-cli/yaak-cli/src/commands/mod.rs +++ b/crates-cli/yaak-cli/src/commands/mod.rs @@ -1,3 +1,4 @@ +pub mod auth; pub mod environment; pub mod folder; pub mod request; diff --git a/crates-cli/yaak-cli/src/main.rs b/crates-cli/yaak-cli/src/main.rs index 58b83bb1..b109fee4 100644 --- a/crates-cli/yaak-cli/src/main.rs +++ b/crates-cli/yaak-cli/src/main.rs @@ -2,6 +2,7 @@ mod cli; mod commands; mod context; mod plugin_events; +mod ui; mod utils; use clap::Parser; @@ -33,6 +34,7 @@ async fn main() { let context = CliContext::initialize(data_dir, app_id, needs_plugins).await; let exit_code = match command { + Commands::Auth(args) => commands::auth::run(args).await, Commands::Send(args) => { commands::send::run(&context, args, environment.as_deref(), verbose).await } diff --git a/crates-cli/yaak-cli/src/ui.rs b/crates-cli/yaak-cli/src/ui.rs new file mode 100644 index 00000000..18a48bb1 --- /dev/null +++ b/crates-cli/yaak-cli/src/ui.rs @@ -0,0 +1,34 @@ +use console::style; +use std::io::{self, IsTerminal}; + +pub fn info(message: &str) { + if io::stdout().is_terminal() { + println!("{:<8} {}", style("INFO").cyan().bold(), style(message).cyan()); + } else { + println!("INFO {message}"); + } +} + +pub fn warning(message: &str) { + if io::stdout().is_terminal() { + println!("{:<8} {}", style("WARNING").yellow().bold(), style(message).yellow()); + } else { + println!("WARNING {message}"); + } +} + +pub fn success(message: &str) { + if io::stdout().is_terminal() { + println!("{:<8} {}", style("SUCCESS").green().bold(), style(message).green()); + } else { + println!("SUCCESS {message}"); + } +} + +pub fn error(message: &str) { + if io::stderr().is_terminal() { + eprintln!("{:<8} {}", style("ERROR").red().bold(), style(message).red()); + } else { + eprintln!("Error: {message}"); + } +}