@@ -147,7 +147,7 @@ public function authenticate(
147147 }
148148
149149 try {
150- $ decoded = self :: decodeAccessToken ($ session ['access_token ' ], $ clientId , $ baseUrl );
150+ $ decoded = $ this -> decodeAccessToken ($ session ['access_token ' ], $ clientId , $ baseUrl );
151151 } catch (\Exception $ e ) {
152152 return [
153153 'authenticated ' => false ,
@@ -302,20 +302,27 @@ public function fetchJwks(string $clientId): array
302302 );
303303 }
304304
305+ /**
306+ * Algorithms permitted on the JWS header. WorkOS access tokens are signed
307+ * with RS256; no other algorithm is accepted, in particular `none` is
308+ * always rejected.
309+ */
310+ private const ALLOWED_JWS_ALGORITHMS = ['RS256 ' ];
311+
305312 /**
306313 * Decode and validate an access token JWT.
307314 *
308- * This is a basic JWT decode. For production use, fetch JWKS and validate
309- * the signature properly. This helper decodes without signature verification
310- * for extracting claims when the token has already been validated upstream .
315+ * Verifies the JWS signature against the JWKS published for `$clientId`,
316+ * enforces an algorithm allow-list, and rejects expired tokens. This is
317+ * the only path used by {@see authenticate()}; callers must not bypass it .
311318 *
312319 * @param string $accessToken The JWT access token.
313- * @param string $clientId The WorkOS client ID (unused in basic decode ).
314- * @param string $baseUrl The WorkOS API base URL (unused in basic decode) .
320+ * @param string $clientId The WorkOS client ID (used to fetch JWKS ).
321+ * @param string $baseUrl The WorkOS API base URL.
315322 * @return array The decoded JWT claims.
316- * @throws \InvalidArgumentException If the token cannot be decoded.
323+ * @throws \InvalidArgumentException If the token cannot be decoded or fails verification .
317324 */
318- private static function decodeAccessToken (
325+ private function decodeAccessToken (
319326 string $ accessToken ,
320327 string $ clientId ,
321328 string $ baseUrl ,
@@ -325,21 +332,184 @@ private static function decodeAccessToken(
325332 throw new \InvalidArgumentException ('Invalid JWT format ' );
326333 }
327334
328- $ payload = base64_decode (strtr ($ parts [1 ], '-_ ' , '+/ ' ), true );
329- if ($ payload === false ) {
330- throw new \InvalidArgumentException ('Invalid JWT payload encoding ' );
335+ [$ headerB64 , $ payloadB64 , $ signatureB64 ] = $ parts ;
336+
337+ $ headerJson = self ::base64UrlDecode ($ headerB64 );
338+ if ($ headerJson === false ) {
339+ throw new \InvalidArgumentException ('Invalid JWT header encoding ' );
340+ }
341+ $ header = json_decode ($ headerJson , true );
342+ if (!is_array ($ header )) {
343+ throw new \InvalidArgumentException ('Invalid JWT header JSON ' );
344+ }
345+
346+ $ alg = $ header ['alg ' ] ?? null ;
347+ if (!is_string ($ alg ) || !in_array ($ alg , self ::ALLOWED_JWS_ALGORITHMS , true )) {
348+ throw new \InvalidArgumentException ('Unsupported JWT algorithm ' );
331349 }
332350
333- $ decoded = json_decode ($ payload , true );
334- if ($ decoded === null ) {
351+ $ payloadJson = self ::base64UrlDecode ($ payloadB64 );
352+ if ($ payloadJson === false ) {
353+ throw new \InvalidArgumentException ('Invalid JWT payload encoding ' );
354+ }
355+ $ decoded = json_decode ($ payloadJson , true );
356+ if (!is_array ($ decoded )) {
335357 throw new \InvalidArgumentException ('Invalid JWT payload JSON ' );
336358 }
337359
338- // Check expiration
339- if (isset ($ decoded ['exp ' ]) && $ decoded ['exp ' ] < time ()) {
360+ $ signature = self ::base64UrlDecode ($ signatureB64 );
361+ if ($ signature === false || $ signature === '' ) {
362+ throw new \InvalidArgumentException ('Invalid JWT signature encoding ' );
363+ }
364+
365+ // Resolve a JWK matching the header `kid`. Without a `kid` we won't
366+ // guess — refuse rather than try every key, which would mask key
367+ // rotation bugs.
368+ $ kid = $ header ['kid ' ] ?? null ;
369+ if (!is_string ($ kid ) || $ kid === '' ) {
370+ throw new \InvalidArgumentException ('JWT header missing kid ' );
371+ }
372+
373+ $ jwks = $ this ->fetchJwks ($ clientId );
374+ $ jwk = self ::findJwkByKid ($ jwks , $ kid );
375+ if ($ jwk === null ) {
376+ throw new \InvalidArgumentException ('No JWKS key matches JWT kid ' );
377+ }
378+
379+ $ publicKeyPem = self ::jwkToRsaPublicKeyPem ($ jwk );
380+ $ signingInput = $ headerB64 . '. ' . $ payloadB64 ;
381+
382+ $ verified = openssl_verify ($ signingInput , $ signature , $ publicKeyPem , OPENSSL_ALGO_SHA256 );
383+ if ($ verified !== 1 ) {
384+ throw new \InvalidArgumentException ('JWT signature verification failed ' );
385+ }
386+
387+ // Expiration check (after signature verification).
388+ if (isset ($ decoded ['exp ' ]) && is_numeric ($ decoded ['exp ' ]) && (int ) $ decoded ['exp ' ] < time ()) {
340389 throw new \InvalidArgumentException ('JWT has expired ' );
341390 }
342391
392+ // TODO(security-fix-plan.md, finding #60): enforce documented WorkOS
393+ // `iss` and `aud` values once empirically confirmed. The other WorkOS
394+ // SDKs (Ruby, Python) currently skip `aud` verification, so the
395+ // canonical values are not authoritatively documented in this repo.
396+ // Track resolution under "Open questions / follow-ups" in the plan.
397+
343398 return $ decoded ;
344399 }
400+
401+ /**
402+ * Decode a base64url-encoded segment, tolerating missing padding.
403+ *
404+ * @return string|false The decoded bytes, or false on malformed input.
405+ */
406+ private static function base64UrlDecode (string $ segment ): string |false
407+ {
408+ $ remainder = strlen ($ segment ) % 4 ;
409+ if ($ remainder !== 0 ) {
410+ $ segment .= str_repeat ('= ' , 4 - $ remainder );
411+ }
412+
413+ return base64_decode (strtr ($ segment , '-_ ' , '+/ ' ), true );
414+ }
415+
416+ /**
417+ * Locate a JWK in the JWKS response by `kid`.
418+ *
419+ * @param array<string, mixed> $jwks
420+ * @return array<string, mixed>|null
421+ */
422+ private static function findJwkByKid (array $ jwks , string $ kid ): ?array
423+ {
424+ $ keys = $ jwks ['keys ' ] ?? null ;
425+ if (!is_array ($ keys )) {
426+ return null ;
427+ }
428+ foreach ($ keys as $ jwk ) {
429+ if (is_array ($ jwk ) && ($ jwk ['kid ' ] ?? null ) === $ kid ) {
430+ return $ jwk ;
431+ }
432+ }
433+ return null ;
434+ }
435+
436+ /**
437+ * Convert an RSA JWK (`kty=RSA`, base64url `n`/`e`) to a PEM-encoded
438+ * public key suitable for {@see openssl_verify()}.
439+ *
440+ * @param array<string, mixed> $jwk
441+ */
442+ private static function jwkToRsaPublicKeyPem (array $ jwk ): string
443+ {
444+ if (($ jwk ['kty ' ] ?? null ) !== 'RSA ' ) {
445+ throw new \InvalidArgumentException ('Unsupported JWK key type ' );
446+ }
447+ $ n = $ jwk ['n ' ] ?? null ;
448+ $ e = $ jwk ['e ' ] ?? null ;
449+ if (!is_string ($ n ) || !is_string ($ e )) {
450+ throw new \InvalidArgumentException ('Malformed RSA JWK ' );
451+ }
452+
453+ $ modulus = self ::base64UrlDecode ($ n );
454+ $ exponent = self ::base64UrlDecode ($ e );
455+ if ($ modulus === false || $ exponent === false ) {
456+ throw new \InvalidArgumentException ('Malformed RSA JWK encoding ' );
457+ }
458+
459+ // Build a DER-encoded SubjectPublicKeyInfo for an RSA public key, then
460+ // wrap it as a PEM document. Avoids a hard dependency on a JWT library.
461+ $ modulusDer = self ::derEncodeUnsignedInteger ($ modulus );
462+ $ exponentDer = self ::derEncodeUnsignedInteger ($ exponent );
463+ $ rsaPublicKey = self ::derEncodeSequence ($ modulusDer . $ exponentDer );
464+ $ bitString = self ::derEncodeBitString ($ rsaPublicKey );
465+
466+ // AlgorithmIdentifier: SEQUENCE { OID 1.2.840.113549.1.1.1, NULL }.
467+ $ rsaOid = "\x06\x09\x2a\x86\x48\x86\xf7\x0d\x01\x01\x01" ;
468+ $ algorithmIdentifier = self ::derEncodeSequence ($ rsaOid . "\x05\x00" );
469+ $ spki = self ::derEncodeSequence ($ algorithmIdentifier . $ bitString );
470+
471+ $ pem = "-----BEGIN PUBLIC KEY----- \n"
472+ . chunk_split (base64_encode ($ spki ), 64 , "\n" )
473+ . "-----END PUBLIC KEY----- \n" ;
474+
475+ return $ pem ;
476+ }
477+
478+ private static function derEncodeLength (int $ length ): string
479+ {
480+ if ($ length < 0x80 ) {
481+ return chr ($ length );
482+ }
483+ $ bytes = '' ;
484+ while ($ length > 0 ) {
485+ $ bytes = chr ($ length & 0xff ) . $ bytes ;
486+ $ length >>= 8 ;
487+ }
488+ return chr (0x80 | strlen ($ bytes )) . $ bytes ;
489+ }
490+
491+ private static function derEncodeSequence (string $ contents ): string
492+ {
493+ return "\x30" . self ::derEncodeLength (strlen ($ contents )) . $ contents ;
494+ }
495+
496+ private static function derEncodeUnsignedInteger (string $ bytes ): string
497+ {
498+ // Strip leading zero bytes, then re-prepend a single 0x00 if the
499+ // high bit of the first byte is set so the value remains positive.
500+ $ bytes = ltrim ($ bytes , "\x00" );
501+ if ($ bytes === '' ) {
502+ $ bytes = "\x00" ;
503+ } elseif ((ord ($ bytes [0 ]) & 0x80 ) !== 0 ) {
504+ $ bytes = "\x00" . $ bytes ;
505+ }
506+ return "\x02" . self ::derEncodeLength (strlen ($ bytes )) . $ bytes ;
507+ }
508+
509+ private static function derEncodeBitString (string $ bytes ): string
510+ {
511+ // 0x00 = number of unused bits in the final octet (always zero here).
512+ $ contents = "\x00" . $ bytes ;
513+ return "\x03" . self ::derEncodeLength (strlen ($ contents )) . $ contents ;
514+ }
345515}
0 commit comments