Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/main/java/org/ohdsi/webapi/OidcConfCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/
package org.ohdsi.webapi;

import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.oauth2.sdk.pkce.CodeChallengeMethod;
import org.pac4j.oidc.config.OidcConfiguration;
import org.slf4j.Logger;
Expand Down Expand Up @@ -107,7 +106,7 @@ public OidcConfiguration build() {
scopes += extraScopes;
}
conf.setScope(scopes);
conf.setPreferredJwsAlgorithm(JWSAlgorithm.RS256);
// Use all algorithms from provider metadata (supports RS256, ES384, etc.)
conf.setPkceMethod(CodeChallengeMethod.S256);

try {
Expand Down
229 changes: 229 additions & 0 deletions src/main/java/org/ohdsi/webapi/shiro/filters/OidcJwtAuthFilter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
package org.ohdsi.webapi.shiro.filters;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.JWSVerifier;
import com.nimbusds.jose.crypto.ECDSAVerifier;
import com.nimbusds.jose.crypto.RSASSAVerifier;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import org.apache.shiro.authc.AuthenticationException;
import org.ohdsi.webapi.shiro.PermissionManager;
import org.ohdsi.webapi.shiro.ServletBridge;
import org.ohdsi.webapi.shiro.tokens.JwtAuthToken;
import org.pac4j.oidc.config.OidcConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.URI;
import java.security.interfaces.ECPublicKey;
import java.security.interfaces.RSAPublicKey;
import java.text.ParseException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
* Validates OIDC JWT bearer tokens using the provider's JWKS.
* Used for token exchange: external OIDC token -> WebAPI JWT.
*/
public class OidcJwtAuthFilter extends AtlasAuthFilter {

private static final Logger logger = LoggerFactory.getLogger(OidcJwtAuthFilter.class);
private static final String AUTHORIZATION_HEADER = "Authorization";
private static final String BEARER_PREFIX = "Bearer ";
private static final long JWKS_CACHE_DURATION_MS = 300_000;

public static final String OIDC_EXTERNAL_TOKEN = "oidc_external_token";

private final OidcConfiguration oidcConfiguration;
private final PermissionManager authorizer;
private final Set<String> defaultRoles;
private final Map<String, JWK> keyCache = new ConcurrentHashMap<>();
private volatile long lastJwksFetch = 0;

public OidcJwtAuthFilter(OidcConfiguration oidcConfiguration,
PermissionManager authorizer,
Set<String> defaultRoles,
int tokenExpirationIntervalInSeconds) {
this.oidcConfiguration = oidcConfiguration;
this.authorizer = authorizer;
this.defaultRoles = defaultRoles;
}

@Override
protected JwtAuthToken createToken(ServletRequest request, ServletResponse response) throws Exception {
String bearerToken = extractBearerToken(request);
if (bearerToken == null) {
throw new AuthenticationException("No bearer token found");
}
return new JwtAuthToken(verifyAndExtractSubject(bearerToken));
}

@Override
protected boolean onAccessDenied(ServletRequest request, ServletResponse response) throws Exception {
String bearerToken = extractBearerToken(request);
if (bearerToken == null) {
return true;
}

try {
String subject = verifyAndExtractSubject(bearerToken);
String name = extractName(bearerToken, subject);
authorizer.registerUser(subject, name, defaultRoles);
request.setAttribute(OIDC_EXTERNAL_TOKEN, true);
return executeLogin(request, response);
} catch (AuthenticationException e) {
logger.warn("OIDC JWT authentication failed for request from {}: {}",
request.getRemoteAddr(), e.getMessage());
return true;
}
}

private String extractBearerToken(ServletRequest request) {
HttpServletRequest httpRequest = ServletBridge.toHttp(request);
String authHeader = httpRequest.getHeader(AUTHORIZATION_HEADER);
if (authHeader != null && authHeader.startsWith(BEARER_PREFIX)) {
return authHeader.substring(BEARER_PREFIX.length());
}
return null;
}

private String verifyAndExtractSubject(String jwtToken) throws AuthenticationException {
try {
SignedJWT signedJwt = SignedJWT.parse(jwtToken);
JWSHeader header = signedJwt.getHeader();
JWTClaimsSet claims = signedJwt.getJWTClaimsSet();

Date now = new Date();
Date expiration = claims.getExpirationTime();
if (expiration != null && expiration.before(now)) {
throw new AuthenticationException("Token expired");
}

Date notBefore = claims.getNotBeforeTime();
if (notBefore != null && notBefore.after(now)) {
throw new AuthenticationException("Token not yet valid");
}

String expectedIssuer = getExpectedIssuer();
if (expectedIssuer != null && !expectedIssuer.equals(claims.getIssuer())) {
throw new AuthenticationException("Invalid token issuer");
}

String expectedAudience = oidcConfiguration.getClientId();
List<String> audiences = claims.getAudience();
if (expectedAudience != null && (audiences == null || !audiences.contains(expectedAudience))) {
throw new AuthenticationException("Invalid token audience");
}

JWK jwk = getKey(header.getKeyID());
if (jwk == null) {
throw new AuthenticationException("Signing key not found");
}

if (!signedJwt.verify(createVerifier(jwk))) {
throw new AuthenticationException("Invalid signature");
}

String email = (String) claims.getClaim("email");
return (email != null && !email.isEmpty()) ? email : claims.getSubject();

} catch (ParseException | JOSEException e) {
throw new AuthenticationException("JWT validation failed: " + e.getMessage(), e);
}
}

private String extractName(String jwtToken, String fallback) {
try {
SignedJWT signedJwt = SignedJWT.parse(jwtToken);
String name = (String) signedJwt.getJWTClaimsSet().getClaim("name");
return (name != null && !name.isEmpty()) ? name : fallback;
} catch (ParseException e) {
return fallback;
}
}

private String getExpectedIssuer() {
try {
var resolver = oidcConfiguration.getOpMetadataResolver();
if (resolver != null) {
var metadata = resolver.load();
if (metadata != null) {
return metadata.getIssuer().getValue();
}
}
} catch (Exception e) {
logger.warn("Failed to get OIDC issuer: {}", e.getMessage());
}
return null;
}

private JWK getKey(String kid) {
JWK jwk = keyCache.get(kid);
if (jwk == null) {
long currentTime = System.currentTimeMillis();
if (currentTime - lastJwksFetch > JWKS_CACHE_DURATION_MS) {
synchronized (this) {
if (currentTime - lastJwksFetch > JWKS_CACHE_DURATION_MS) {
refreshJwks();
}
}
jwk = keyCache.get(kid);
}
}
return jwk;
}

private void refreshJwks() {
try {
URI jwksUri = getJwksUri();
if (jwksUri == null) {
logger.error("No JWKS URI available");
return;
}

JWKSet jwkSet = JWKSet.load(jwksUri.toURL());
keyCache.clear();
for (JWK key : jwkSet.getKeys()) {
if (key.getKeyID() != null) {
keyCache.put(key.getKeyID(), key);
}
}
} catch (Exception e) {
logger.error("Failed to fetch JWKS: {}", e.getMessage());
} finally {
lastJwksFetch = System.currentTimeMillis();
}
}

private URI getJwksUri() {
try {
var resolver = oidcConfiguration.getOpMetadataResolver();
if (resolver != null) {
var metadata = resolver.load();
if (metadata != null) {
return metadata.getJWKSetURI();
}
}
} catch (Exception e) {
logger.warn("Failed to get JWKS URI: {}", e.getMessage());
}
return null;
}

private JWSVerifier createVerifier(JWK jwk) throws JOSEException {
if (jwk instanceof ECKey) {
return new ECDSAVerifier(((ECKey) jwk).toECPublicKey());
} else if (jwk instanceof RSAKey) {
return new RSASSAVerifier(((RSAKey) jwk).toRSAPublicKey());
}
throw new JOSEException("Unsupported key type: " + jwk.getKeyType());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,12 @@ protected boolean preHandle(ServletRequest request, ServletResponse response) th

String sessionId = (String) request.getAttribute(Constants.SESSION_ID);
if (sessionId == null) {
final String token = TokenManager.extractToken(request);
if (token != null) {
sessionId = (String) TokenManager.getBody(token).get(Constants.SESSION_ID);
Boolean isOidcToken = (Boolean) request.getAttribute(OidcJwtAuthFilter.OIDC_EXTERNAL_TOKEN);
if (!Boolean.TRUE.equals(isOidcToken)) {
final String token = TokenManager.extractToken(request);
if (token != null) {
sessionId = (String) TokenManager.getBody(token).get(Constants.SESSION_ID);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.ohdsi.webapi.shiro.Entities.UserRepository;
import org.ohdsi.webapi.shiro.PermissionManager;
import org.ohdsi.webapi.shiro.filters.*;
import org.ohdsi.webapi.shiro.filters.OidcJwtAuthFilter;
import org.ohdsi.webapi.shiro.filters.auth.ActiveDirectoryAuthFilter;
import org.ohdsi.webapi.shiro.filters.auth.AtlasJwtAuthFilter;
import org.ohdsi.webapi.shiro.filters.auth.JdbcAuthFilter;
Expand Down Expand Up @@ -49,8 +50,6 @@
import org.pac4j.oauth.client.Google2Client;
import org.pac4j.oidc.client.OidcClient;
import org.pac4j.oidc.config.OidcConfiguration;
import org.pac4j.oidc.credentials.authenticator.OidcAuthenticator;
import org.pac4j.http.client.direct.DirectBearerAuthClient;
import org.pac4j.saml.client.SAML2Client;
import org.pac4j.saml.config.SAML2Configuration;
import org.slf4j.Logger;
Expand Down Expand Up @@ -331,17 +330,18 @@ public Map<FilterTemplates, Filter> getFilters() {
clients.add(githubClient);
}

OidcConfiguration oidcConfiguration = null;
if (this.openidAuthEnabled) {
OidcConfiguration configuration = oidcConfCreator.build();
if (StringUtils.isNotBlank(configuration.getClientId())) {
oidcConfiguration = oidcConfCreator.build();
if (StringUtils.isNotBlank(oidcConfiguration.getClientId())) {
// https://www.pac4j.org/4.0.x/docs/clients/openid-connect.html
// OidcClient allows indirect login through UI with code flow
OidcClient oidcClient = new OidcClient(configuration);
OidcClient oidcClient = new OidcClient(oidcConfiguration);
oidcClient.setCallbackUrl(oauthApiCallback);
oidcClient.setCallbackUrlResolver(urlResolver);

// URL rewriting: discovery from internal URL, redirect to external URL
String internalUrl = configuration.getDiscoveryURI();
String internalUrl = oidcConfiguration.getDiscoveryURI();
String externalUrl = oidcConfCreator.getExternalUrl();
if (externalUrl != null && !externalUrl.isEmpty()) {
org.ohdsi.webapi.shiro.filters.ExternalUrlOidcRedirectionActionBuilder redirectBuilder =
Expand All @@ -353,12 +353,6 @@ public Map<FilterTemplates, Filter> getFilters() {

// Configuration already initialized; pac4j handles lazy init
clients.add(oidcClient);

// Bearer token authentication for API access (pac4j 6.x)
// OidcAuthenticator requires both configuration and client
OidcAuthenticator authenticator = new OidcAuthenticator(configuration, oidcClient);
DirectBearerAuthClient bearerClient = new DirectBearerAuthClient(authenticator);
clients.add(bearerClient);
} else {
logger.warn("openidAuth is enabled but no client id is provided");
}
Expand Down Expand Up @@ -405,11 +399,6 @@ public Map<FilterTemplates, Filter> getFilters() {
oidcFilter.setConfig(cfg);
oidcFilter.setClients("OidcClient");
filters.put(OIDC_AUTH, oidcFilter);

SecurityFilter oidcDirectFilter = new SecurityFilter();
oidcDirectFilter.setConfig(cfg);
oidcDirectFilter.setClients("HeaderClient");
filters.put(OIDC_DIRECT_AUTH, oidcDirectFilter);
}

io.buji.pac4j.filter.CallbackFilter callbackFilter = new io.buji.pac4j.filter.CallbackFilter();
Expand All @@ -427,6 +416,17 @@ public Map<FilterTemplates, Filter> getFilters() {
filters.put(HANDLE_UNSUCCESSFUL_OAUTH, new RedirectOnFailedOAuthFilter(this.oauthUiCallback));
}

// OIDC token exchange filter
if (this.openidAuthEnabled && oidcConfiguration != null) {
OidcJwtAuthFilter oidcJwtFilter = new OidcJwtAuthFilter(
oidcConfiguration,
this.authorizer,
this.defaultRoles,
this.tokenExpirationIntervalInSeconds
);
filters.put(OIDC_DIRECT_AUTH, oidcJwtFilter);
}

if (this.casAuthEnabled) {
this.setUpCAS(filters);
}
Expand All @@ -440,8 +440,12 @@ public Map<FilterTemplates, Filter> getFilters() {
@Override
protected FilterChainBuilder getFilterChainBuilder() {

List<FilterTemplates> authcFilters = googleAccessTokenEnabled ? Arrays.asList(ACCESS_AUTHC, JWT_AUTHC) :
Collections.singletonList(JWT_AUTHC);
// Build authentication filter chain: try JWT first, then OIDC if enabled
List<FilterTemplates> authcFilters = new ArrayList<>();
if (googleAccessTokenEnabled) {
authcFilters.add(ACCESS_AUTHC);
}
authcFilters.add(JWT_AUTHC);
// the order does matter - first match wins
FilterChainBuilder filterChainBuilder = new FilterChainBuilder()
.setRestFilters(SSL, NO_SESSION_CREATION, CORS, NO_CACHE)
Expand Down