Skip to content

Commit 394c85f

Browse files
committed
Use PSS padding instead of RSS
1 parent 5a4f9fc commit 394c85f

3 files changed

Lines changed: 395 additions & 59 deletions

File tree

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/JwtHelper.java

Lines changed: 169 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
package com.microsoft.aad.msal4j;
55

66
import java.nio.charset.StandardCharsets;
7+
import java.security.InvalidKeyException;
78
import java.security.Signature;
9+
import java.security.spec.MGF1ParameterSpec;
10+
import java.security.spec.PSSParameterSpec;
811
import java.util.ArrayList;
912
import java.util.Base64;
1013
import java.util.HashMap;
@@ -22,62 +25,183 @@ static ClientAssertion buildJwt(String clientId, final ClientCertificate credent
2225
ParameterValidationUtils.validateNotNull("credential", clientId);
2326

2427
try {
25-
final long time = System.currentTimeMillis();
26-
27-
// Build header
28-
Map<String, Object> header = new HashMap<>();
29-
header.put("alg", "RS256");
30-
header.put("typ", "JWT");
31-
32-
if (sendX5c) {
33-
List<String> certs = new ArrayList<>(credential.getEncodedPublicKeyCertificateChain());
34-
header.put("x5c", certs);
35-
}
36-
37-
//SHA-256 is preferred, however certain flows still require SHA-1 due to what is supported server-side. If SHA-256
38-
// is not supported or the IClientCredential.publicCertificateHash256() method is not implemented, the library will default to SHA-1.
39-
String hash256 = credential.publicCertificateHash256();
40-
if (useSha1 || hash256 == null) {
41-
header.put("x5t", credential.publicCertificateHash());
42-
} else {
43-
header.put("x5t#S256", hash256);
28+
// First try with PS256 (preferred)
29+
return generatePS256Jwt(clientId, credential, jwtAudience, sendX5c, useSha1);
30+
} catch (InvalidKeyException e) {
31+
// If the key isn't compatible with PSS, fall back to RS256.
32+
// This is for backwards compatibility, as the Signature instance created with SHA256withRSA
33+
// accepted key types that weren't RSAPrivateKey but the RSASSA-PSS signature does not.
34+
try {
35+
return generateRs256Jwt(clientId, credential, jwtAudience, sendX5c, useSha1);
36+
} catch (Exception fallbackException) {
37+
throw new MsalClientException(fallbackException);
4438
}
39+
} catch (Exception e) {
40+
throw new MsalClientException(e);
41+
}
42+
}
4543

46-
// Build payload
47-
Map<String, Object> payload = new HashMap<>();
48-
payload.put("aud", jwtAudience);
49-
payload.put("iss", clientId);
50-
payload.put("jti", UUID.randomUUID().toString());
51-
payload.put("nbf", time / 1000);
52-
payload.put("exp", time / 1000 + Constants.AAD_JWT_TOKEN_LIFETIME_SECONDS);
53-
payload.put("sub", clientId);
44+
/**
45+
* Generates a JWT signed using the PS256 algorithm (RSASSA-PSS with SHA-256).
46+
*
47+
* @param clientId The client ID to use as the issuer and subject
48+
* @param credential The certificate credential used for signing
49+
* @param jwtAudience The audience claim for the JWT
50+
* @param sendX5c Whether to include the x5c header with certificate chain
51+
* @param useSha1 Whether to use SHA-1 hash for thumbprint instead of SHA-256
52+
* @return A ClientAssertion containing the signed JWT
53+
* @throws Exception If JWT creation or signing fails
54+
*/
55+
private static ClientAssertion generatePS256Jwt(String clientId, ClientCertificate credential,
56+
String jwtAudience, boolean sendX5c,
57+
boolean useSha1) throws Exception {
58+
// Build header with PS256 algorithm
59+
Map<String, Object> header = createHeader(credential, sendX5c, useSha1, "PS256");
60+
61+
// Build payload
62+
Map<String, Object> payload = createPayload(clientId, jwtAudience, System.currentTimeMillis());
63+
64+
// Encode header and payload
65+
String jsonHeader = JsonHelper.writeJsonMap(header);
66+
String jsonPayload = JsonHelper.writeJsonMap(payload);
67+
String encodedHeader = base64UrlEncode(jsonHeader.getBytes(StandardCharsets.UTF_8));
68+
String encodedPayload = base64UrlEncode(jsonPayload.getBytes(StandardCharsets.UTF_8));
69+
String dataToSign = encodedHeader + "." + encodedPayload;
70+
71+
// Sign with PS256
72+
byte[] signatureBytes = signWithPS256(credential, dataToSign);
73+
String encodedSignature = base64UrlEncode(signatureBytes);
74+
75+
// Build the JWT
76+
String jwt = dataToSign + "." + encodedSignature;
77+
return new ClientAssertion(jwt);
78+
}
5479

55-
// Concatenate header and payload
56-
String jsonHeader = JsonHelper.writeJsonMap(header);
57-
String jsonPayload = JsonHelper.writeJsonMap(payload);
80+
/**
81+
* Generates a JWT signed using the RS256 algorithm (RSASSA-PKCS1-v1_5 with SHA-256).
82+
* This is used as a fallback when PS256 is not supported by the private key.
83+
*
84+
* @param clientId The client ID to use as the issuer and subject
85+
* @param credential The certificate credential used for signing
86+
* @param jwtAudience The audience claim for the JWT
87+
* @param sendX5c Whether to include the x5c header with certificate chain
88+
* @param useSha1 Whether to use SHA-1 hash for thumbprint instead of SHA-256
89+
* @return A ClientAssertion containing the signed JWT
90+
* @throws Exception If JWT creation or signing fails
91+
*/
92+
private static ClientAssertion generateRs256Jwt(String clientId, ClientCertificate credential,
93+
String jwtAudience, boolean sendX5c,
94+
boolean useSha1) throws Exception {
95+
// Build header with RS256 algorithm
96+
Map<String, Object> header = createHeader(credential, sendX5c, useSha1, "RS256");
97+
98+
// Build payload
99+
Map<String, Object> payload = createPayload(clientId, jwtAudience, System.currentTimeMillis());
100+
101+
// Encode header and payload
102+
String jsonHeader = JsonHelper.writeJsonMap(header);
103+
String jsonPayload = JsonHelper.writeJsonMap(payload);
104+
String encodedHeader = base64UrlEncode(jsonHeader.getBytes(StandardCharsets.UTF_8));
105+
String encodedPayload = base64UrlEncode(jsonPayload.getBytes(StandardCharsets.UTF_8));
106+
String dataToSign = encodedHeader + "." + encodedPayload;
107+
108+
// Sign with RS256
109+
byte[] signatureBytes = signWithRS256(credential, dataToSign);
110+
String encodedSignature = base64UrlEncode(signatureBytes);
111+
112+
// Build the JWT
113+
String jwt = dataToSign + "." + encodedSignature;
114+
return new ClientAssertion(jwt);
115+
}
58116

59-
String encodedHeader = base64UrlEncode(jsonHeader.getBytes(StandardCharsets.UTF_8));
60-
String encodedPayload = base64UrlEncode(jsonPayload.getBytes(StandardCharsets.UTF_8));
117+
/**
118+
* Creates the JWT header with the specified algorithm and certificate information.
119+
*
120+
* @param credential The certificate credential containing thumbprint and chain
121+
* @param sendX5c Whether to include the x5c header with certificate chain
122+
* @param useSha1 Whether to use SHA-1 hash for thumbprint instead of SHA-256
123+
* @param algorithm The signing algorithm to specify in the header (PS256 or RS256)
124+
* @return A map containing the JWT header claims
125+
* @throws Exception If certificate operations fail
126+
*/
127+
private static Map<String, Object> createHeader(ClientCertificate credential, boolean sendX5c,
128+
boolean useSha1, String algorithm) throws Exception {
129+
Map<String, Object> header = new HashMap<>();
130+
header.put("alg", algorithm);
131+
header.put("typ", "JWT");
132+
133+
if (sendX5c) {
134+
List<String> certs = new ArrayList<>(credential.getEncodedPublicKeyCertificateChain());
135+
header.put("x5c", certs);
136+
}
61137

62-
// Create signature
63-
String dataToSign = encodedHeader + "." + encodedPayload;
138+
// SHA-256 is preferred, however certain flows still require SHA-1
139+
String hash256 = credential.publicCertificateHash256();
140+
if (useSha1 || hash256 == null) {
141+
header.put("x5t", credential.publicCertificateHash());
142+
} else {
143+
header.put("x5t#S256", hash256);
144+
}
64145

65-
Signature sig = Signature.getInstance("SHA256withRSA");
66-
sig.initSign(credential.privateKey());
67-
sig.update(dataToSign.getBytes(StandardCharsets.UTF_8));
68-
byte[] signatureBytes = sig.sign();
146+
return header;
147+
}
69148

70-
String encodedSignature = base64UrlEncode(signatureBytes);
149+
/**
150+
* Creates the JWT payload with standard claims.
151+
*
152+
* @param clientId The client ID to use as the issuer and subject
153+
* @param audience The audience claim for the JWT
154+
* @param time The current time in milliseconds
155+
* @return A map containing the JWT payload claims
156+
*/
157+
private static Map<String, Object> createPayload(String clientId, String audience, long time) {
158+
Map<String, Object> payload = new HashMap<>();
159+
payload.put("aud", audience);
160+
payload.put("iss", clientId);
161+
payload.put("jti", UUID.randomUUID().toString());
162+
payload.put("nbf", time / 1000);
163+
payload.put("exp", time / 1000 + Constants.AAD_JWT_TOKEN_LIFETIME_SECONDS);
164+
payload.put("sub", clientId);
165+
return payload;
166+
}
71167

72-
// Build the JWT
73-
String jwt = dataToSign + "." + encodedSignature;
168+
/**
169+
* Signs data using the PS256 algorithm (RSASSA-PSS with SHA-256).
170+
*
171+
* @param credential The certificate credential containing the private key
172+
* @param dataToSign The data to sign
173+
* @return The signature bytes
174+
* @throws Exception If signing fails
175+
*/
176+
private static byte[] signWithPS256(ClientCertificate credential, String dataToSign) throws Exception {
177+
Signature sig = Signature.getInstance("RSASSA-PSS");
178+
sig.setParameter(new PSSParameterSpec("SHA-256", "MGF1", MGF1ParameterSpec.SHA256, 32, 1));
179+
sig.initSign(credential.privateKey());
180+
sig.update(dataToSign.getBytes(StandardCharsets.UTF_8));
181+
return sig.sign();
182+
}
74183

75-
return new ClientAssertion(jwt);
76-
} catch (final Exception e) {
77-
throw new MsalClientException(e);
78-
}
184+
/**
185+
* Signs data using the RS256 algorithm (RSASSA-PKCS1-v1_5 with SHA-256).
186+
*
187+
* @param credential The certificate credential containing the private key
188+
* @param dataToSign The data to sign
189+
* @return The signature bytes
190+
* @throws Exception If signing fails
191+
*/
192+
private static byte[] signWithRS256(ClientCertificate credential, String dataToSign) throws Exception {
193+
Signature sig = Signature.getInstance("SHA256withRSA");
194+
sig.initSign(credential.privateKey());
195+
sig.update(dataToSign.getBytes(StandardCharsets.UTF_8));
196+
return sig.sign();
79197
}
80198

199+
/**
200+
* Encodes bytes using Base64URL encoding without padding.
201+
*
202+
* @param data The data to encode
203+
* @return The Base64URL encoded string
204+
*/
81205
private static String base64UrlEncode(byte[] data) {
82206
return Base64.getUrlEncoder().withoutPadding().encodeToString(data);
83207
}

msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/HelperAndUtilityTests.java

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33

44
package com.microsoft.aad.msal4j;
55

6+
import com.nimbusds.jwt.SignedJWT;
67
import org.junit.jupiter.api.Test;
78

89
import java.nio.charset.StandardCharsets;
910
import java.security.NoSuchAlgorithmException;
11+
import java.security.PrivateKey;
1012
import java.security.cert.CertificateEncodingException;
13+
import java.security.interfaces.RSAPrivateKey;
1114
import java.util.*;
1215

1316
import static org.junit.jupiter.api.Assertions.*;
@@ -115,7 +118,7 @@ void JwtHelper_buildJwt_ValidSha1AndSha256Assertions() throws MsalClientExceptio
115118

116119
// Decode and verify headers
117120
String headerJson = new String(Base64.getUrlDecoder().decode(jwtParts[0]));
118-
assertTrue(headerJson.contains("\"alg\":\"RS256\""), "Header should specify RS256 algorithm");
121+
assertTrue(headerJson.contains("\"alg\":\"PS256\""), "Header should specify RS256 algorithm");
119122
assertTrue(headerJson.contains("\"typ\":\"JWT\""), "Header should specify JWT type");
120123
assertTrue(headerJson.contains("\"x5t#S256\":\"certificateHash256\""), "Header should contain x5t#S256");
121124
assertTrue(headerJson.contains("\"x5c\":[\"cert1\",\"cert2\"]"), "Header should contain x5c");
@@ -187,4 +190,99 @@ void JsonHelper_createIdTokenFromEncodedTokenString_InvalidJsonInToken() {
187190

188191
assertEquals(AuthenticationErrorCode.INVALID_JSON, exception.errorCode());
189192
}
193+
194+
@Test
195+
void JwtHelper_buildJwt_UsesPSS256WhenSupported() throws Exception {
196+
// Create a certificate mock with an RSAPrivateKey that supports PSS
197+
RSAPrivateKey rsaPrivateKey = (RSAPrivateKey) TestHelper.getPrivateKey();
198+
199+
ClientCertificate clientCertificateMock = mock(ClientCertificate.class);
200+
when(clientCertificateMock.privateKey()).thenReturn(rsaPrivateKey);
201+
when(clientCertificateMock.publicCertificateHash()).thenReturn("certificateHash");
202+
when(clientCertificateMock.publicCertificateHash256()).thenReturn("certificateHash256");
203+
when(clientCertificateMock.getEncodedPublicKeyCertificateChain()).thenReturn(Arrays.asList("cert1", "cert2"));
204+
205+
String clientId = "clientId";
206+
String audience = "https://login.microsoftonline.com/common/oauth2/v2.0/token";
207+
208+
// Create the JWT
209+
ClientAssertion clientAssertion = JwtHelper.buildJwt(clientId, clientCertificateMock, audience, true, false);
210+
211+
assertNotNull(clientAssertion);
212+
String jwt = clientAssertion.assertion();
213+
String[] jwtParts = jwt.split("\\.");
214+
assertEquals(3, jwtParts.length, "JWT should have three parts");
215+
216+
// Decode and verify header uses PS256
217+
String headerJson = new String(Base64.getUrlDecoder().decode(jwtParts[0]));
218+
assertTrue(headerJson.contains("\"alg\":\"PS256\""), "Header should specify PS256 algorithm");
219+
220+
// Parse the JWT to verify the algorithm is PS256
221+
SignedJWT signedJWT = SignedJWT.parse(jwt);
222+
assertEquals("PS256", signedJWT.getHeader().getAlgorithm().getName(), "JWT should use PS256 algorithm");
223+
}
224+
225+
@Test
226+
void JwtHelper_buildJwt_FallsBackToRS256WhenPSSNotSupported() throws Exception {
227+
// When loaded from the Windows-MY keystore the PrivateKey will be a sun.security.mscapi.CPrivateKey,
228+
// which for some reason works with the library's older RS256 signature but not the newer PSS signature.
229+
PrivateKey nonRsaCompatibleKey = TestHelper.getPrivateKeyFromKeystore();
230+
231+
// This key should cause the PSS code to fail with an InvalidKeyException
232+
ClientCertificate clientCertificateMock = mock(ClientCertificate.class);
233+
when(clientCertificateMock.privateKey()).thenReturn(nonRsaCompatibleKey);
234+
when(clientCertificateMock.publicCertificateHash()).thenReturn("certificateHash");
235+
when(clientCertificateMock.publicCertificateHash256()).thenReturn("certificateHash256");
236+
when(clientCertificateMock.getEncodedPublicKeyCertificateChain()).thenReturn(Arrays.asList("cert1", "cert2"));
237+
238+
String clientId = "clientId";
239+
String audience = "https://login.microsoftonline.com/common/oauth2/v2.0/token";
240+
241+
// Create the JWT - this should fallback to RS256
242+
ClientAssertion clientAssertion = JwtHelper.buildJwt(clientId, clientCertificateMock, audience, true, false);
243+
244+
assertNotNull(clientAssertion);
245+
String jwt = clientAssertion.assertion();
246+
String[] jwtParts = jwt.split("\\.");
247+
assertEquals(3, jwtParts.length, "JWT should have three parts");
248+
249+
// Decode and verify header uses RS256 as fallback
250+
String headerJson = new String(Base64.getUrlDecoder().decode(jwtParts[0]));
251+
assertTrue(headerJson.contains("\"alg\":\"RS256\""), "Header should specify RS256 algorithm as fallback");
252+
}
253+
254+
@Test
255+
void JwtHelper_buildJwt_UsesCorrectSignatureAlgorithmsBasedOnKeyType() throws Exception {
256+
// Use real keys for both RSA and non-RSA tests
257+
RSAPrivateKey rsaPrivateKey = (RSAPrivateKey) TestHelper.getPrivateKey();
258+
PrivateKey nonRsaPrivateKey = TestHelper.privateKeyFromKeystore;
259+
260+
ClientCertificate rsaCertMock = mock(ClientCertificate.class);
261+
when(rsaCertMock.privateKey()).thenReturn(rsaPrivateKey);
262+
when(rsaCertMock.publicCertificateHash256()).thenReturn("certHash256");
263+
when(rsaCertMock.getEncodedPublicKeyCertificateChain()).thenReturn(Arrays.asList("cert1", "cert2"));
264+
265+
ClientCertificate nonRsaCertMock = mock(ClientCertificate.class);
266+
when(nonRsaCertMock.privateKey()).thenReturn(nonRsaPrivateKey);
267+
when(nonRsaCertMock.publicCertificateHash256()).thenReturn("certHash256");
268+
when(nonRsaCertMock.getEncodedPublicKeyCertificateChain()).thenReturn(Arrays.asList("cert1", "cert2"));
269+
270+
String clientId = "clientId";
271+
String audience = "https://login.microsoftonline.com/common/oauth2/v2.0/token";
272+
273+
// Test RSA key -> should use PS256
274+
ClientAssertion rsaAssertion = JwtHelper.buildJwt(clientId, rsaCertMock, audience, true, false);
275+
String rsaJwt = rsaAssertion.assertion();
276+
String rsaHeader = new String(Base64.getUrlDecoder().decode(rsaJwt.split("\\.")[0]));
277+
assertTrue(rsaHeader.contains("\"alg\":\"PS256\""), "RSA key should produce PS256 algorithm");
278+
279+
// Test non-RSA key -> should fallback to RS256
280+
ClientAssertion nonRsaAssertion = JwtHelper.buildJwt(clientId, nonRsaCertMock, audience, true, false);
281+
String nonRsaJwt = nonRsaAssertion.assertion();
282+
String nonRsaHeader = new String(Base64.getUrlDecoder().decode(nonRsaJwt.split("\\.")[0]));
283+
assertTrue(nonRsaHeader.contains("\"alg\":\"RS256\""), "Non-RSA key should fallback to RS256 algorithm");
284+
285+
// Verify we're actually using different keys for the different tests
286+
assertNotEquals(rsaJwt, nonRsaJwt, "The two assertions should be different");
287+
}
190288
}

0 commit comments

Comments
 (0)