diff --git a/src/main/java/org/ohdsi/webapi/OidcConfCreator.java b/src/main/java/org/ohdsi/webapi/OidcConfCreator.java index 92aa5699d..35db51181 100644 --- a/src/main/java/org/ohdsi/webapi/OidcConfCreator.java +++ b/src/main/java/org/ohdsi/webapi/OidcConfCreator.java @@ -18,7 +18,6 @@ */ package org.ohdsi.webapi; -import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.oauth2.sdk.pkce.CodeChallengeMethod; import org.pac4j.oidc.config.OidcConfiguration; import org.slf4j.Logger; @@ -107,7 +106,7 @@ public OidcConfiguration build() { scopes += extraScopes; } conf.setScope(scopes); - conf.setPreferredJwsAlgorithm(JWSAlgorithm.RS256); + // Use all algorithms from provider metadata (supports RS256, ES384, etc.) conf.setPkceMethod(CodeChallengeMethod.S256); try { diff --git a/src/main/java/org/ohdsi/webapi/shiro/filters/OidcJwtAuthFilter.java b/src/main/java/org/ohdsi/webapi/shiro/filters/OidcJwtAuthFilter.java new file mode 100644 index 000000000..764eb3a0c --- /dev/null +++ b/src/main/java/org/ohdsi/webapi/shiro/filters/OidcJwtAuthFilter.java @@ -0,0 +1,229 @@ +package org.ohdsi.webapi.shiro.filters; + +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JWSHeader; +import com.nimbusds.jose.JWSVerifier; +import com.nimbusds.jose.crypto.ECDSAVerifier; +import com.nimbusds.jose.crypto.RSASSAVerifier; +import com.nimbusds.jose.jwk.ECKey; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jose.jwk.RSAKey; +import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.jwt.SignedJWT; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.HttpServletRequest; +import org.apache.shiro.authc.AuthenticationException; +import org.ohdsi.webapi.shiro.PermissionManager; +import org.ohdsi.webapi.shiro.ServletBridge; +import org.ohdsi.webapi.shiro.tokens.JwtAuthToken; +import org.pac4j.oidc.config.OidcConfiguration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.URI; +import java.security.interfaces.ECPublicKey; +import java.security.interfaces.RSAPublicKey; +import java.text.ParseException; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Validates OIDC JWT bearer tokens using the provider's JWKS. + * Used for token exchange: external OIDC token -> WebAPI JWT. + */ +public class OidcJwtAuthFilter extends AtlasAuthFilter { + + private static final Logger logger = LoggerFactory.getLogger(OidcJwtAuthFilter.class); + private static final String AUTHORIZATION_HEADER = "Authorization"; + private static final String BEARER_PREFIX = "Bearer "; + private static final long JWKS_CACHE_DURATION_MS = 300_000; + + public static final String OIDC_EXTERNAL_TOKEN = "oidc_external_token"; + + private final OidcConfiguration oidcConfiguration; + private final PermissionManager authorizer; + private final Set defaultRoles; + private final Map keyCache = new ConcurrentHashMap<>(); + private volatile long lastJwksFetch = 0; + + public OidcJwtAuthFilter(OidcConfiguration oidcConfiguration, + PermissionManager authorizer, + Set defaultRoles, + int tokenExpirationIntervalInSeconds) { + this.oidcConfiguration = oidcConfiguration; + this.authorizer = authorizer; + this.defaultRoles = defaultRoles; + } + + @Override + protected JwtAuthToken createToken(ServletRequest request, ServletResponse response) throws Exception { + String bearerToken = extractBearerToken(request); + if (bearerToken == null) { + throw new AuthenticationException("No bearer token found"); + } + return new JwtAuthToken(verifyAndExtractSubject(bearerToken)); + } + + @Override + protected boolean onAccessDenied(ServletRequest request, ServletResponse response) throws Exception { + String bearerToken = extractBearerToken(request); + if (bearerToken == null) { + return true; + } + + try { + String subject = verifyAndExtractSubject(bearerToken); + String name = extractName(bearerToken, subject); + authorizer.registerUser(subject, name, defaultRoles); + request.setAttribute(OIDC_EXTERNAL_TOKEN, true); + return executeLogin(request, response); + } catch (AuthenticationException e) { + logger.warn("OIDC JWT authentication failed for request from {}: {}", + request.getRemoteAddr(), e.getMessage()); + return true; + } + } + + private String extractBearerToken(ServletRequest request) { + HttpServletRequest httpRequest = ServletBridge.toHttp(request); + String authHeader = httpRequest.getHeader(AUTHORIZATION_HEADER); + if (authHeader != null && authHeader.startsWith(BEARER_PREFIX)) { + return authHeader.substring(BEARER_PREFIX.length()); + } + return null; + } + + private String verifyAndExtractSubject(String jwtToken) throws AuthenticationException { + try { + SignedJWT signedJwt = SignedJWT.parse(jwtToken); + JWSHeader header = signedJwt.getHeader(); + JWTClaimsSet claims = signedJwt.getJWTClaimsSet(); + + Date now = new Date(); + Date expiration = claims.getExpirationTime(); + if (expiration != null && expiration.before(now)) { + throw new AuthenticationException("Token expired"); + } + + Date notBefore = claims.getNotBeforeTime(); + if (notBefore != null && notBefore.after(now)) { + throw new AuthenticationException("Token not yet valid"); + } + + String expectedIssuer = getExpectedIssuer(); + if (expectedIssuer != null && !expectedIssuer.equals(claims.getIssuer())) { + throw new AuthenticationException("Invalid token issuer"); + } + + String expectedAudience = oidcConfiguration.getClientId(); + List audiences = claims.getAudience(); + if (expectedAudience != null && (audiences == null || !audiences.contains(expectedAudience))) { + throw new AuthenticationException("Invalid token audience"); + } + + JWK jwk = getKey(header.getKeyID()); + if (jwk == null) { + throw new AuthenticationException("Signing key not found"); + } + + if (!signedJwt.verify(createVerifier(jwk))) { + throw new AuthenticationException("Invalid signature"); + } + + String email = (String) claims.getClaim("email"); + return (email != null && !email.isEmpty()) ? email : claims.getSubject(); + + } catch (ParseException | JOSEException e) { + throw new AuthenticationException("JWT validation failed: " + e.getMessage(), e); + } + } + + private String extractName(String jwtToken, String fallback) { + try { + SignedJWT signedJwt = SignedJWT.parse(jwtToken); + String name = (String) signedJwt.getJWTClaimsSet().getClaim("name"); + return (name != null && !name.isEmpty()) ? name : fallback; + } catch (ParseException e) { + return fallback; + } + } + + private String getExpectedIssuer() { + try { + var resolver = oidcConfiguration.getOpMetadataResolver(); + if (resolver != null) { + var metadata = resolver.load(); + if (metadata != null) { + return metadata.getIssuer().getValue(); + } + } + } catch (Exception e) { + logger.warn("Failed to get OIDC issuer: {}", e.getMessage()); + } + return null; + } + + private JWK getKey(String kid) { + JWK jwk = keyCache.get(kid); + if (jwk == null) { + long currentTime = System.currentTimeMillis(); + if (currentTime - lastJwksFetch > JWKS_CACHE_DURATION_MS) { + synchronized (this) { + if (currentTime - lastJwksFetch > JWKS_CACHE_DURATION_MS) { + refreshJwks(); + } + } + jwk = keyCache.get(kid); + } + } + return jwk; + } + + private void refreshJwks() { + try { + URI jwksUri = getJwksUri(); + if (jwksUri == null) { + logger.error("No JWKS URI available"); + return; + } + + JWKSet jwkSet = JWKSet.load(jwksUri.toURL()); + keyCache.clear(); + for (JWK key : jwkSet.getKeys()) { + if (key.getKeyID() != null) { + keyCache.put(key.getKeyID(), key); + } + } + } catch (Exception e) { + logger.error("Failed to fetch JWKS: {}", e.getMessage()); + } finally { + lastJwksFetch = System.currentTimeMillis(); + } + } + + private URI getJwksUri() { + try { + var resolver = oidcConfiguration.getOpMetadataResolver(); + if (resolver != null) { + var metadata = resolver.load(); + if (metadata != null) { + return metadata.getJWKSetURI(); + } + } + } catch (Exception e) { + logger.warn("Failed to get JWKS URI: {}", e.getMessage()); + } + return null; + } + + private JWSVerifier createVerifier(JWK jwk) throws JOSEException { + if (jwk instanceof ECKey) { + return new ECDSAVerifier(((ECKey) jwk).toECPublicKey()); + } else if (jwk instanceof RSAKey) { + return new RSASSAVerifier(((RSAKey) jwk).toRSAPublicKey()); + } + throw new JOSEException("Unsupported key type: " + jwk.getKeyType()); + } +} diff --git a/src/main/java/org/ohdsi/webapi/shiro/filters/UpdateAccessTokenFilter.java b/src/main/java/org/ohdsi/webapi/shiro/filters/UpdateAccessTokenFilter.java index 26773cf15..4c78d6025 100644 --- a/src/main/java/org/ohdsi/webapi/shiro/filters/UpdateAccessTokenFilter.java +++ b/src/main/java/org/ohdsi/webapi/shiro/filters/UpdateAccessTokenFilter.java @@ -134,9 +134,12 @@ protected boolean preHandle(ServletRequest request, ServletResponse response) th String sessionId = (String) request.getAttribute(Constants.SESSION_ID); if (sessionId == null) { - final String token = TokenManager.extractToken(request); - if (token != null) { - sessionId = (String) TokenManager.getBody(token).get(Constants.SESSION_ID); + Boolean isOidcToken = (Boolean) request.getAttribute(OidcJwtAuthFilter.OIDC_EXTERNAL_TOKEN); + if (!Boolean.TRUE.equals(isOidcToken)) { + final String token = TokenManager.extractToken(request); + if (token != null) { + sessionId = (String) TokenManager.getBody(token).get(Constants.SESSION_ID); + } } } diff --git a/src/main/java/org/ohdsi/webapi/shiro/management/AtlasRegularSecurity.java b/src/main/java/org/ohdsi/webapi/shiro/management/AtlasRegularSecurity.java index 5290807eb..93d0c6a39 100644 --- a/src/main/java/org/ohdsi/webapi/shiro/management/AtlasRegularSecurity.java +++ b/src/main/java/org/ohdsi/webapi/shiro/management/AtlasRegularSecurity.java @@ -15,6 +15,7 @@ import org.ohdsi.webapi.shiro.Entities.UserRepository; import org.ohdsi.webapi.shiro.PermissionManager; import org.ohdsi.webapi.shiro.filters.*; +import org.ohdsi.webapi.shiro.filters.OidcJwtAuthFilter; import org.ohdsi.webapi.shiro.filters.auth.ActiveDirectoryAuthFilter; import org.ohdsi.webapi.shiro.filters.auth.AtlasJwtAuthFilter; import org.ohdsi.webapi.shiro.filters.auth.JdbcAuthFilter; @@ -49,8 +50,6 @@ import org.pac4j.oauth.client.Google2Client; import org.pac4j.oidc.client.OidcClient; import org.pac4j.oidc.config.OidcConfiguration; -import org.pac4j.oidc.credentials.authenticator.OidcAuthenticator; -import org.pac4j.http.client.direct.DirectBearerAuthClient; import org.pac4j.saml.client.SAML2Client; import org.pac4j.saml.config.SAML2Configuration; import org.slf4j.Logger; @@ -331,17 +330,18 @@ public Map getFilters() { clients.add(githubClient); } + OidcConfiguration oidcConfiguration = null; if (this.openidAuthEnabled) { - OidcConfiguration configuration = oidcConfCreator.build(); - if (StringUtils.isNotBlank(configuration.getClientId())) { + oidcConfiguration = oidcConfCreator.build(); + if (StringUtils.isNotBlank(oidcConfiguration.getClientId())) { // https://www.pac4j.org/4.0.x/docs/clients/openid-connect.html // OidcClient allows indirect login through UI with code flow - OidcClient oidcClient = new OidcClient(configuration); + OidcClient oidcClient = new OidcClient(oidcConfiguration); oidcClient.setCallbackUrl(oauthApiCallback); oidcClient.setCallbackUrlResolver(urlResolver); // URL rewriting: discovery from internal URL, redirect to external URL - String internalUrl = configuration.getDiscoveryURI(); + String internalUrl = oidcConfiguration.getDiscoveryURI(); String externalUrl = oidcConfCreator.getExternalUrl(); if (externalUrl != null && !externalUrl.isEmpty()) { org.ohdsi.webapi.shiro.filters.ExternalUrlOidcRedirectionActionBuilder redirectBuilder = @@ -353,12 +353,6 @@ public Map getFilters() { // Configuration already initialized; pac4j handles lazy init clients.add(oidcClient); - - // Bearer token authentication for API access (pac4j 6.x) - // OidcAuthenticator requires both configuration and client - OidcAuthenticator authenticator = new OidcAuthenticator(configuration, oidcClient); - DirectBearerAuthClient bearerClient = new DirectBearerAuthClient(authenticator); - clients.add(bearerClient); } else { logger.warn("openidAuth is enabled but no client id is provided"); } @@ -405,11 +399,6 @@ public Map getFilters() { oidcFilter.setConfig(cfg); oidcFilter.setClients("OidcClient"); filters.put(OIDC_AUTH, oidcFilter); - - SecurityFilter oidcDirectFilter = new SecurityFilter(); - oidcDirectFilter.setConfig(cfg); - oidcDirectFilter.setClients("HeaderClient"); - filters.put(OIDC_DIRECT_AUTH, oidcDirectFilter); } io.buji.pac4j.filter.CallbackFilter callbackFilter = new io.buji.pac4j.filter.CallbackFilter(); @@ -427,6 +416,17 @@ public Map getFilters() { filters.put(HANDLE_UNSUCCESSFUL_OAUTH, new RedirectOnFailedOAuthFilter(this.oauthUiCallback)); } + // OIDC token exchange filter + if (this.openidAuthEnabled && oidcConfiguration != null) { + OidcJwtAuthFilter oidcJwtFilter = new OidcJwtAuthFilter( + oidcConfiguration, + this.authorizer, + this.defaultRoles, + this.tokenExpirationIntervalInSeconds + ); + filters.put(OIDC_DIRECT_AUTH, oidcJwtFilter); + } + if (this.casAuthEnabled) { this.setUpCAS(filters); } @@ -440,8 +440,12 @@ public Map getFilters() { @Override protected FilterChainBuilder getFilterChainBuilder() { - List authcFilters = googleAccessTokenEnabled ? Arrays.asList(ACCESS_AUTHC, JWT_AUTHC) : - Collections.singletonList(JWT_AUTHC); + // Build authentication filter chain: try JWT first, then OIDC if enabled + List authcFilters = new ArrayList<>(); + if (googleAccessTokenEnabled) { + authcFilters.add(ACCESS_AUTHC); + } + authcFilters.add(JWT_AUTHC); // the order does matter - first match wins FilterChainBuilder filterChainBuilder = new FilterChainBuilder() .setRestFilters(SSL, NO_SESSION_CREATION, CORS, NO_CACHE)