diff --git a/plugins/auth-oauth2/package.json b/plugins/auth-oauth2/package.json index 43641dd6..c9d07a0f 100644 --- a/plugins/auth-oauth2/package.json +++ b/plugins/auth-oauth2/package.json @@ -12,6 +12,7 @@ "scripts": { "build": "yaakcli build", "dev": "yaakcli dev", - "lint":"tsc --noEmit && eslint . --ext .ts,.tsx" + "lint":"tsc --noEmit && eslint . --ext .ts,.tsx", + "test": "vitest --run tests" } } diff --git a/plugins/auth-oauth2/src/grants/authorizationCode.ts b/plugins/auth-oauth2/src/grants/authorizationCode.ts index 7ba00b7c..60373337 100644 --- a/plugins/auth-oauth2/src/grants/authorizationCode.ts +++ b/plugins/auth-oauth2/src/grants/authorizationCode.ts @@ -4,6 +4,7 @@ import { fetchAccessToken } from '../fetchAccessToken'; import { getOrRefreshAccessToken } from '../getOrRefreshAccessToken'; import type { AccessToken, TokenStoreArgs } from '../store'; import { getDataDirKey, storeToken } from '../store'; +import { extractCode } from '../util'; export const PKCE_SHA256 = 'S256'; export const PKCE_PLAIN = 'plain'; @@ -79,7 +80,6 @@ export async function getAuthorizationCode( authorizationUrl.searchParams.set('code_challenge_method', pkce.challengeMethod); } - const logsEnabled = (await ctx.store.get('enable_logs')) ?? false; const dataDirKey = await getDataDirKey(ctx, contextId); const authorizationUrlStr = authorizationUrl.toString(); console.log('[oauth2] Authorizing', authorizationUrlStr); @@ -97,18 +97,17 @@ export async function getAuthorizationCode( } }, async onNavigate({ url: urlStr }) { - const url = new URL(urlStr); - if (logsEnabled) console.log('[oauth2] Navigated to', urlStr); - - if (url.searchParams.has('error')) { + let code; + try { + code = extractCode(urlStr, redirectUri); + } catch (err) { + reject(err); close(); - return reject(new Error(`Failed to authorize: ${url.searchParams.get('error')}`)); + return; } - const code = url.searchParams.get('code'); if (!code) { - console.log('[oauth2] Code not found'); - return; // Could be one of many redirects in a chain, so skip it + return; } // Close the window here, because we don't need it anymore! diff --git a/plugins/auth-oauth2/src/index.ts b/plugins/auth-oauth2/src/index.ts index 28e16e86..b2834d77 100644 --- a/plugins/auth-oauth2/src/index.ts +++ b/plugins/auth-oauth2/src/index.ts @@ -6,8 +6,8 @@ import type { PluginDefinition, } from '@yaakapp/api'; import { - genPkceCodeVerifier, DEFAULT_PKCE_METHOD, + genPkceCodeVerifier, getAuthorizationCode, PKCE_PLAIN, PKCE_SHA256, @@ -125,17 +125,6 @@ export const plugin: PluginDefinition = { await resetDataDirKey(ctx, contextId); }, }, - { - label: 'Toggle Debug Logs', - async onSelect(ctx) { - const enableLogs = !(await ctx.store.get('enable_logs')); - await ctx.store.set('enable_logs', enableLogs); - await ctx.toast.show({ - message: `Debug logs ${enableLogs ? 'enabled' : 'disabled'}`, - color: 'info', - }); - }, - }, ], args: [ { diff --git a/plugins/auth-oauth2/src/util.ts b/plugins/auth-oauth2/src/util.ts index 42a5b05b..1269854f 100644 --- a/plugins/auth-oauth2/src/util.ts +++ b/plugins/auth-oauth2/src/util.ts @@ -3,3 +3,83 @@ import type { AccessToken } from './store'; export function isTokenExpired(token: AccessToken) { return token.expiresAt && Date.now() > token.expiresAt; } + +export function extractCode(urlStr: string, redirectUri: string | null): string | null { + const url = new URL(urlStr); + + if (!urlMatchesRedirect(url, redirectUri)) { + console.log('[oauth2] URL does not match redirect origin/path; skipping.'); + return null; + } + + // Prefer query param; fall back to fragment if query lacks it + + const query = url.searchParams; + const queryError = query.get('error'); + const queryDesc = query.get('error_description'); + const queryUri = query.get('error_uri'); + + let hashParams: URLSearchParams | null = null; + if (url.hash && url.hash.length > 1) { + hashParams = new URLSearchParams(url.hash.slice(1)); + } + const hashError = hashParams?.get('error'); + const hashDesc = hashParams?.get('error_description'); + const hashUri = hashParams?.get('error_uri'); + + const error = queryError || hashError; + if (error) { + const desc = queryDesc || hashDesc; + const uri = queryUri || hashUri; + let message = `Failed to authorize: ${error}`; + if (desc) message += ` (${desc})`; + if (uri) message += ` [${uri}]`; + throw new Error(message); + } + + const queryCode = query.get('code'); + if (queryCode) return queryCode; + + const hashCode = hashParams?.get('code'); + if (hashCode) return hashCode; + + console.log('[oauth2] Code not found'); + return null; +} + +export function urlMatchesRedirect(url: URL, redirectUrl: string | null): boolean { + if (!redirectUrl) return true; + + let redirect; + try { + redirect = new URL(redirectUrl); + } catch { + console.log('[oauth2] Invalid redirect URI; skipping.'); + return false; + } + + const sameProtocol = url.protocol === redirect.protocol; + + const sameHost = url.hostname.toLowerCase() === redirect.hostname.toLowerCase(); + + const normalizePort = (u: URL) => + (u.protocol === 'https:' && (!u.port || u.port === '443')) || + (u.protocol === 'http:' && (!u.port || u.port === '80')) + ? '' + : u.port; + + const samePort = normalizePort(url) === normalizePort(redirect); + + const normPath = (p: string) => { + const withLeading = p.startsWith('/') ? p : `/${p}`; + // strip trailing slashes, keep root as "/" + return withLeading.replace(/\/+$/g, '') || '/'; + }; + + // Require redirect path to be a prefix of the navigated URL path + const urlPath = normPath(url.pathname); + const redirectPath = normPath(redirect.pathname); + const pathMatches = urlPath === redirectPath || urlPath.startsWith(`${redirectPath}/`); + + return sameProtocol && sameHost && samePort && pathMatches; +} diff --git a/plugins/auth-oauth2/tests/util.test.ts b/plugins/auth-oauth2/tests/util.test.ts new file mode 100644 index 00000000..c2141a89 --- /dev/null +++ b/plugins/auth-oauth2/tests/util.test.ts @@ -0,0 +1,109 @@ +import { describe, test, expect } from 'vitest'; +import { extractCode } from '../src/util'; + +describe('extractCode', () => { + test('extracts code from query when same origin + path', () => { + const url = 'https://app.example.com/cb?code=abc123&state=xyz'; + const redirect = 'https://app.example.com/cb'; + expect(extractCode(url, redirect)).toBe('abc123'); + }); + + test('extracts code from query with weird path', () => { + const url = 'https://app.example.com/cbwithextra?code=abc123&state=xyz'; + const redirect = 'https://app.example.com/cb'; + expect(extractCode(url, redirect)).toBeNull(); + }); + + test('allows trailing slash differences', () => { + expect(extractCode('https://app.example.com/cb/?code=abc', 'https://app.example.com/cb')).toBe( + 'abc', + ); + expect(extractCode('https://app.example.com/cb?code=abc', 'https://app.example.com/cb/')).toBe( + 'abc', + ); + }); + + test('treats default ports as equal (https:443, http:80)', () => { + expect( + extractCode('https://app.example.com/cb?code=abc', 'https://app.example.com:443/cb'), + ).toBe('abc'); + expect(extractCode('http://app.example.com/cb?code=abc', 'http://app.example.com:80/cb')).toBe( + 'abc', + ); + }); + + test('rejects different port', () => { + expect( + extractCode('https://app.example.com/cb?code=abc', 'https://app.example.com:8443/cb'), + ).toBeNull(); + }); + + test('rejects different hostname (including subdomain changes)', () => { + expect( + extractCode('https://evil.example.com/cb?code=abc', 'https://app.example.com/cb'), + ).toBeNull(); + }); + + test('requires path to start with redirect path (ignoring query/hash)', () => { + // same origin but wrong path -> null + expect( + extractCode('https://app.example.com/other?code=abc', 'https://app.example.com/cb'), + ).toBeNull(); + + // deeper subpath under the redirect path -> allowed (prefix match) + expect( + extractCode('https://app.example.com/cb/deep?code=abc', 'https://app.example.com/cb'), + ).toBe('abc'); + }); + + test('works with custom schemes', () => { + expect(extractCode('myapp://cb?code=abc', 'myapp://cb')).toBe('abc'); + }); + + test('prefers query over fragment when both present', () => { + const url = 'https://app.example.com/cb?code=queryCode#code=hashCode'; + const redirect = 'https://app.example.com/cb'; + expect(extractCode(url, redirect)).toBe('queryCode'); + }); + + test('extracts code from fragment when query lacks code', () => { + const url = 'https://app.example.com/cb#code=fromHash&state=xyz'; + const redirect = 'https://app.example.com/cb'; + expect(extractCode(url, redirect)).toBe('fromHash'); + }); + + test('returns null if no code present (query or fragment)', () => { + const url = 'https://app.example.com/cb?state=only'; + const redirect = 'https://app.example.com/cb'; + expect(extractCode(url, redirect)).toBeNull(); + }); + + test('returns null when provider reports an error', () => { + const url = 'https://app.example.com/cb?error=access_denied&error_description=oopsy'; + const redirect = 'https://app.example.com/cb'; + expect(() => extractCode(url, redirect)).toThrow('Failed to authorize: access_denied'); + }); + + test('when redirectUri is null, extracts code from any URL', () => { + expect(extractCode('https://random.example.com/whatever?code=abc', null)).toBe('abc'); + }); + + test('handles extra params gracefully', () => { + const url = 'https://app.example.com/cb?foo=1&bar=2&code=abc&baz=3'; + const redirect = 'https://app.example.com/cb'; + expect(extractCode(url, redirect)).toBe('abc'); + }); + + test('ignores fragment noise when code is in query', () => { + const url = 'https://app.example.com/cb?code=abc#some=thing'; + const redirect = 'https://app.example.com/cb'; + expect(extractCode(url, redirect)).toBe('abc'); + }); + + // If you decide NOT to support fragment-based codes, flip these to expect null or mark as .skip + test('supports fragment-only code for response_mode=fragment providers', () => { + const url = 'https://app.example.com/cb#state=xyz&code=abc'; + const redirect = 'https://app.example.com/cb'; + expect(extractCode(url, redirect)).toBe('abc'); + }); +});