Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KNOX-3073 - Token verification fallback to Knox keys behavior should be configurable #949

Merged
merged 1 commit into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package org.apache.knox.gateway.hadoopauth.filter;

import static org.apache.knox.gateway.provider.federation.jwt.filter.AbstractJWTFilter.JWT_INSTANCE_KEY_FALLBACK;
import static org.easymock.EasyMock.anyString;
import static org.easymock.EasyMock.capture;
import static org.easymock.EasyMock.captureInt;
Expand Down Expand Up @@ -577,6 +578,7 @@ private HadoopAuthFilter testIfJwtSupported(String supportJwt) throws Exception
expect(filterConfig.getInitParameter("support.jwt")).andReturn(supportJwt).anyTimes();
expect(filterConfig.getInitParameter("hadoop.auth.unauthenticated.path.list")).andReturn(null).anyTimes();
expect(filterConfig.getInitParameter("clusterName")).andReturn("topology1").anyTimes();
expect(filterConfig.getInitParameter(JWT_INSTANCE_KEY_FALLBACK)).andReturn("false").anyTimes();
final boolean isJwtSupported = Boolean.parseBoolean(supportJwt);
if (isJwtSupported) {
expect(filterConfig.getInitParameter(JWTFederationFilter.KNOX_TOKEN_AUDIENCES)).andReturn(null).anyTimes();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.util.Arrays;
import java.util.Date;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
Expand Down Expand Up @@ -102,6 +103,9 @@ public abstract class AbstractJWTFilter implements Filter {
public static final String JWT_EXPECTED_SIGALG = "jwt.expected.sigalg";
public static final String JWT_DEFAULT_SIGALG = "RS256";

public static final String JWT_INSTANCE_KEY_FALLBACK = "jwt.instance.key.fallback";
public static final boolean JWT_INSTANCE_KEY_FALLBACK_DEFAULT = false;

static JWTMessages log = MessagesFactory.get( JWTMessages.class );

private static AuditService auditService = AuditServiceFactory.getAuditService();
Expand All @@ -116,13 +120,14 @@ public abstract class AbstractJWTFilter implements Filter {
private String expectedIssuer;
private String expectedSigAlg;
protected String expectedPrincipalClaim;
protected Set<URI> expectedJWKSUrls = new HashSet();
protected Set<URI> expectedJWKSUrls = new LinkedHashSet();
protected Set<JOSEObjectType> allowedJwsTypes;

private TokenStateService tokenStateService;
private TokenMAC tokenMAC;
protected long idleTimeoutSeconds = -1;
protected String topologyName;
protected boolean isJwtInstanceKeyFallback = JWT_INSTANCE_KEY_FALLBACK_DEFAULT;

@Override
public abstract void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
Expand Down Expand Up @@ -158,6 +163,9 @@ public void init( FilterConfig filterConfig ) throws ServletException {
// Setup the verified tokens cache
topologyName = context != null ? (String) context.getAttribute(GatewayServices.GATEWAY_CLUSTER_ATTRIBUTE) : null;
signatureVerificationCache = SignatureVerificationCache.getInstance(topologyName, filterConfig);

String fallbackConfig = filterConfig.getInitParameter(JWT_INSTANCE_KEY_FALLBACK);
isJwtInstanceKeyFallback = fallbackConfig != null ? Boolean.parseBoolean(fallbackConfig) : JWT_INSTANCE_KEY_FALLBACK_DEFAULT;
}

protected void configureExpectedParameters(FilterConfig filterConfig) {
Expand Down Expand Up @@ -512,17 +520,22 @@ protected boolean verifyTokenSignature(final JWT token) {
// If it has not yet been verified, then perform the verification now
if (!verified) {
try {
boolean attemptedPEMVerification = false;
boolean attemptedJWKSVerification = false;

if (publicKey != null) {
attemptedPEMVerification = true;
verified = authority.verifyToken(token, publicKey);
log.pemVerificationResultMessage(verified);
}

if (!verified && expectedJWKSUrls != null && !expectedJWKSUrls.isEmpty()) {
attemptedJWKSVerification = true;
verified = authority.verifyToken(token, expectedJWKSUrls, expectedSigAlg, allowedJwsTypes);
log.jwksVerificationResultMessage(verified);
}

if(!verified) {
if(!verified && ((!attemptedPEMVerification && !attemptedJWKSVerification) || isJwtInstanceKeyFallback)) {
verified = authority.verifyToken(token);
log.signingKeyVerificationResultMessage(verified);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@
import java.util.UUID;
import java.util.concurrent.TimeUnit;

import static org.apache.knox.gateway.provider.federation.jwt.filter.AbstractJWTFilter.JWT_INSTANCE_KEY_FALLBACK;
import static org.apache.knox.gateway.provider.federation.jwt.filter.JWTFederationFilter.JWKS_URL;
import static org.junit.Assert.fail;

public abstract class AbstractJWTFilterTest {
Expand Down Expand Up @@ -627,6 +629,9 @@ public void testSignatureVerificationChain() throws Exception {
/* Add a failing PEM */
props.put(getVerificationPemProperty(), failingPem);

/* Turn fallback to signing key on */
props.put(JWT_INSTANCE_KEY_FALLBACK, "true");

/* This handler is setup with a publicKey, corresponding privateKey is used to sign the JWT below */
handler.init(new TestFilterConfig(props));

Expand Down Expand Up @@ -660,14 +665,15 @@ public void testSignatureVerificationChain() throws Exception {
* This will test the signature verification chain.
* Specifically the flow when provided PEM is not invalid and
* knox signing key is valid.
*
* AND JWT_INSTANCE_KEY_FALLBACK is true
* NOTE: here valid means can validate JWT.
* @throws Exception
*/
@Test
public void testSignatureVerificationChainWithPEMandSignature() throws Exception {
try {
Properties props = getProperties();
props.put(JWT_INSTANCE_KEY_FALLBACK, "true");
KeyPairGenerator kpg = KeyPairGenerator.getInstance("RSA");
kpg.initialize(2048);

Expand Down Expand Up @@ -709,6 +715,128 @@ public void testSignatureVerificationChainWithPEMandSignature() throws Exception
}
}

@Test
public void testNoPEMOrJwksWithoutFallback() throws Exception {
// Test fallback disabled, but not PEM configured.
// You can't disable key fallback without specifying an explicit verification method.
boolean verified = doTestSignatureVerificationChain(null, null, false);
Assert.assertTrue("Token should have been verified.", verified);
}

@Test
public void testNoPEMOrJwksWithFallback() throws Exception {
boolean verified = doTestSignatureVerificationChain(null, null, true);
Assert.assertTrue("Token should have been verified by falling back to keys.", verified);
}

@Test
public void testInvalidPEMNoJwksWithFallback() throws Exception {
boolean verified = doTestSignatureVerificationChain(pem, null, true);
Assert.assertTrue("Token should have been verified by falling back to keys.", verified);
}

@Test
public void testInvalidPEMNoJwksWithoutFallback() throws Exception {
String invalidPEM = generateInvalidPEM();
boolean verified = doTestSignatureVerificationChain(invalidPEM, null, false);
Assert.assertFalse("Token should NOT have been verified.", verified);
}

@Test
public void testNoPEMInvalidJwksWithoutFallback() throws Exception {
boolean verified = doTestSignatureVerificationChain(null, "https://localhost/nonesense", false);
Assert.assertFalse("Token should have NOT been verified.", verified);
}

@Test
public void testNoPEMInvalidJwksWithFallback() throws Exception {
boolean verified = doTestSignatureVerificationChain(null, "https://localhost/nonesense", true);
Assert.assertTrue("Token should have been verified by falling back to keys.", verified);
}

@Test
public void testInvalidPEMInvalidJwksWithoutFallback() throws Exception {
String invalidPEM = generateInvalidPEM();
boolean verified = doTestSignatureVerificationChain(invalidPEM, "https://localhost/nonesense", false);
Assert.assertFalse("Token should NOT have been verified.", verified);
}

@Test
public void testInvalidPEMInvalidJwksWithFallback() throws Exception {
String invalidPEM = generateInvalidPEM();
boolean verified = doTestSignatureVerificationChain(invalidPEM, "https://localhost/nonesense", true);
Assert.assertTrue("Token should have been verified by falling back to keys.", verified);
}

protected String generateInvalidPEM() throws Exception {
KeyPairGenerator kpg = KeyPairGenerator.getInstance("RSA");
kpg.initialize(2048);

KeyPair KPair = kpg.generateKeyPair();
String dn = buildDistinguishedName(InetAddress.getLocalHost().getHostName());
Certificate cert = X509CertificateUtil.generateCertificate(dn, KPair, 365, "SHA1withRSA");
byte[] data = cert.getEncoded();
Base64 encoder = new Base64( 76, "\n".getBytes( StandardCharsets.US_ASCII ) );
return new String(encoder.encodeToString( data ).getBytes( StandardCharsets.US_ASCII ), StandardCharsets.US_ASCII).trim();
}

/**
* This will test the signature verification chain in the following order
* 1. PEM - check if PEM is configured and signature is validated
* 2. JWKS - check if endpoint id configured if not skip
* 3. Knox signing key - if the above two fail try to validate using knox signing cert
* @throws Exception
*/
public boolean doTestSignatureVerificationChain(final String testPEM,
final String testJwks,
final boolean fallbackToKeys) throws Exception {
boolean isVerified = false;

try {
Properties props = getProperties();
props.put(getAudienceProperty(), "bar");

if (testPEM != null) {
// Add a test PEM
props.put(getVerificationPemProperty(), testPEM);
}

if (testJwks != null) {
// Add the test JWKS URL
props.put(JWKS_URL, testJwks);
}

// Configure fallback to signing key on
props.put(JWT_INSTANCE_KEY_FALLBACK, String.valueOf(fallbackToKeys));

// This handler is setup with a publicKey, corresponding privateKey is used to sign the JWT below
handler.init(new TestFilterConfig(props));

SignedJWT jwt = getJWT(AbstractJWTFilter.JWT_DEFAULT_ISSUER, "alice",
new Date(new Date().getTime() + TimeUnit.MINUTES.toMillis(10)), privateKey);

HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class);
setTokenOnRequest(request, jwt);

EasyMock.expect(request.getRequestURL()).andReturn(new StringBuffer(SERVICE_URL)).anyTimes();
EasyMock.expect(request.getPathInfo()).andReturn("resource").anyTimes();
EasyMock.expect(request.getQueryString()).andReturn(null);
HttpServletResponse response = EasyMock.createNiceMock(HttpServletResponse.class);
EasyMock.expect(response.encodeRedirectURL(SERVICE_URL)).andReturn(SERVICE_URL);
EasyMock.expect(response.getOutputStream()).andAnswer(DummyServletOutputStream::new).anyTimes();
EasyMock.replay(request, response);

TestFilterChain chain = new TestFilterChain();
handler.doFilter(request, response, chain);
isVerified = chain.doFilterCalled;

} catch (ServletException se) {
fail("Should NOT have thrown a ServletException.");
}

return isVerified;
}

@Test
public void testInvalidIssuer() throws Exception {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,39 @@ public void testAlternativeCaseUsername() throws Exception {
}
}

@Override
@Test
public void testNoPEMInvalidJwksWithoutFallback() throws Exception {
// No-op: This filter does not appear to support the JWKS URL(s) config like the
// JWTFederationFilter does, so this test does not apply
}

@Override
@Test
public void testNoPEMInvalidJwksWithFallback() throws Exception {
// No-op: This filter does not appear to support the JWKS URL(s) config like the
// JWTFederationFilter does, so this test does not apply
}

@Override
@Test
public void testInvalidPEMNoJwksWithoutFallback() throws Exception {
// No-op: This filter does not appear to support the JWKS URL(s) config like the
// JWTFederationFilter does, so this test does not apply
}

@Override
@Test
public void testInvalidPEMInvalidJwksWithoutFallback() throws Exception {
// No-op: This filter does not appear to support the JWKS URL(s) config like the
// JWTFederationFilter does, so this test does not apply
}

@Override
@Test
public void testInvalidPEMInvalidJwksWithFallback() throws Exception {
// No-op: This filter does not appear to support the JWKS URL(s) config like the
// JWTFederationFilter does, so this test does not apply
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,34 @@ public void testProxiedDefaultAuthenticationProviderURLWithoutMismatchInXForward
Assert.assertEquals(loginURL, "https://remotehost/notgateway/knoxsso/api/v1/websso?originalUrl=" + "https://remotehost/resource");
}

@Override
@Test
public void testNoPEMInvalidJwksWithoutFallback() throws Exception {
// No-op: The SSOCookieProvider does not appear to support the JWKS URL(s) config like the
// JWTFederationFilter does, so this test does not apply
}

@Override
@Test
public void testNoPEMInvalidJwksWithFallback() throws Exception {
// No-op: The SSOCookieProvider does not appear to support the JWKS URL(s) config like the
// JWTFederationFilter does, so this test does not apply
}

@Override
@Test
public void testInvalidPEMInvalidJwksWithoutFallback() throws Exception {
// No-op: The SSOCookieProvider does not appear to support the JWKS URL(s) config like the
// JWTFederationFilter does, so this test does not apply
}

@Override
@Test
public void testInvalidPEMInvalidJwksWithFallback() throws Exception {
// No-op: The SSOCookieProvider does not appear to support the JWKS URL(s) config like the
// JWTFederationFilter does, so this test does not apply
}

@Test
public void testIdleTimoutExceeded() throws Exception {
final TokenStateService tokenStateService = EasyMock.createNiceMock(TokenStateService.class);
Expand Down