From 8537cac991baf495ec5d5846e51742be67abd6a0 Mon Sep 17 00:00:00 2001 From: Ray Lee Date: Mon, 7 Oct 2024 14:19:00 -0400 Subject: [PATCH] DRYD-1518: Associate a CSpace user to a SAML login using an ID from the assertion. (#427) --- .../src/main/resources/log4j2.xml | 1 + .../db/postgresql/authentication.sql | 5 + .../CSpaceAuthenticationSuccessEvent.java | 68 +++++- .../authentication/CSpaceUser.java | 18 ++ .../jackson2/CSpaceUserDeserializer.java | 3 +- ...l2AuthenticatedCSpaceUserDeserializer.java | 3 +- .../authentication/realm/CSpaceRealm.java | 29 ++- .../realm/db/CSpaceDbRealm.java | 229 ++++++++++++++---- .../spring/CSpaceUserDetailsService.java | 35 ++- .../spring/Saml2AuthenticatedCSpaceUser.java | 4 +- .../services/authorization/AuthZ.java | 2 +- ...eSaml2ResponseAuthenticationConverter.java | 145 +++++++++++ .../common/security/SecurityConfig.java | 51 +--- .../common/security/SecurityUtils.java | 85 ++++++- .../src/main/resources/service-config.xsd | 14 +- 15 files changed, 577 insertions(+), 115 deletions(-) create mode 100644 services/common/src/main/java/org/collectionspace/services/common/security/CSpaceSaml2ResponseAuthenticationConverter.java diff --git a/services/JaxRsServiceProvider/src/main/resources/log4j2.xml b/services/JaxRsServiceProvider/src/main/resources/log4j2.xml index a2e7ee8d2..a6db5100a 100644 --- a/services/JaxRsServiceProvider/src/main/resources/log4j2.xml +++ b/services/JaxRsServiceProvider/src/main/resources/log4j2.xml @@ -34,6 +34,7 @@ + diff --git a/services/authentication/pstore/src/main/resources/db/postgresql/authentication.sql b/services/authentication/pstore/src/main/resources/db/postgresql/authentication.sql index 4e4df2240..3472b5204 100644 --- a/services/authentication/pstore/src/main/resources/db/postgresql/authentication.sql +++ b/services/authentication/pstore/src/main/resources/db/postgresql/authentication.sql @@ -4,6 +4,7 @@ CREATE TABLE IF NOT EXISTS users ( lastlogin TIMESTAMP, passwd VARCHAR(128) NOT NULL, salt VARCHAR(128), + sso_id VARCHAR(512), updated_at TIMESTAMP ); @@ -25,6 +26,10 @@ SET passwd = concat( ) WHERE left(passwd, 1) <> '{'; +-- Upgrade older users tables to 8.1 + +ALTER TABLE users ADD COLUMN IF NOT EXISTS sso_id VARCHAR(512); + -- Create tokens table required in 8.0 CREATE TABLE IF NOT EXISTS tokens ( diff --git a/services/authentication/service/src/main/java/org/collectionspace/authentication/CSpaceAuthenticationSuccessEvent.java b/services/authentication/service/src/main/java/org/collectionspace/authentication/CSpaceAuthenticationSuccessEvent.java index 363719df3..078a6ef34 100644 --- a/services/authentication/service/src/main/java/org/collectionspace/authentication/CSpaceAuthenticationSuccessEvent.java +++ b/services/authentication/service/src/main/java/org/collectionspace/authentication/CSpaceAuthenticationSuccessEvent.java @@ -9,6 +9,7 @@ import javax.security.auth.login.AccountException; import javax.security.auth.login.AccountNotFoundException; import org.collectionspace.authentication.realm.db.CSpaceDbRealm; +import org.collectionspace.authentication.spring.CSpaceSaml2Authentication; import org.postgresql.util.PSQLState; import org.springframework.context.ApplicationListener; import org.springframework.security.authentication.event.AuthenticationSuccessEvent; @@ -18,7 +19,10 @@ import org.springframework.security.oauth2.server.resource.authentication.JwtAut public class CSpaceAuthenticationSuccessEvent implements ApplicationListener { - private static final String UPDATE_USER_SQL = + private static final String UPDATE_USER_SSO_ID_SQL = + "UPDATE users SET sso_id = ? WHERE username = ?"; + + private static final String UPDATE_USER_LAST_LOGIN_SQL = "UPDATE users SET lastlogin = now() WHERE username = ?"; private static final String DELETE_EXPIRED_AUTHORIZATIONS_SQL = @@ -42,6 +46,14 @@ public class CSpaceAuthenticationSuccessEvent implements ApplicationListener tenants; private CSpaceTenant primaryTenant; private boolean requireSSO; + private String ssoId; private String salt; /** @@ -47,6 +48,7 @@ public class CSpaceUser extends User { * @param authorities the authorities that have been granted to the user */ public CSpaceUser(String username, String password, String salt, + String ssoId, boolean requireSSO, Set tenants, Set authorities) { @@ -59,6 +61,7 @@ public class CSpaceUser extends User { authorities); this.tenants = tenants; + this.ssoId = ssoId; this.requireSSO = requireSSO; this.salt = salt; @@ -93,6 +96,21 @@ public class CSpaceUser extends User { return salt != null ? salt : ""; } + /** + * Returns the ID from the user's SSO provider, if the user signed in via SSO + * @return the SSO ID + */ + public String getSsoId() { + return ssoId; + } + + /** + * Sets the ID from the user's SSO provider. + */ + public void setSsoId(String ssoId) { + this.ssoId = ssoId; + } + /** * Determines if the user is required to log in using single sign-on. * @return true if SSO is required, false otherwise diff --git a/services/authentication/service/src/main/java/org/collectionspace/authentication/jackson2/CSpaceUserDeserializer.java b/services/authentication/service/src/main/java/org/collectionspace/authentication/jackson2/CSpaceUserDeserializer.java index 70061cc39..71285bc48 100644 --- a/services/authentication/service/src/main/java/org/collectionspace/authentication/jackson2/CSpaceUserDeserializer.java +++ b/services/authentication/service/src/main/java/org/collectionspace/authentication/jackson2/CSpaceUserDeserializer.java @@ -35,10 +35,11 @@ public class CSpaceUserDeserializer extends JsonDeserializer { JsonNode passwordNode = readJsonNode(jsonNode, "password"); String username = readJsonNode(jsonNode, "username").asText(); String password = passwordNode.asText(""); + String ssoId = readJsonNode(jsonNode, "ssoId").asText(); boolean requireSSO = readJsonNode(jsonNode, "requireSSO").asBoolean(); String salt = readJsonNode(jsonNode, "salt").asText(); - CSpaceUser result = new CSpaceUser(username, password, salt, requireSSO, tenants, authorities); + CSpaceUser result = new CSpaceUser(username, password, salt, ssoId, requireSSO, tenants, authorities); if (passwordNode.asText(null) == null) { result.eraseCredentials(); diff --git a/services/authentication/service/src/main/java/org/collectionspace/authentication/jackson2/Saml2AuthenticatedCSpaceUserDeserializer.java b/services/authentication/service/src/main/java/org/collectionspace/authentication/jackson2/Saml2AuthenticatedCSpaceUserDeserializer.java index 473838b8f..6286ae7d9 100644 --- a/services/authentication/service/src/main/java/org/collectionspace/authentication/jackson2/Saml2AuthenticatedCSpaceUserDeserializer.java +++ b/services/authentication/service/src/main/java/org/collectionspace/authentication/jackson2/Saml2AuthenticatedCSpaceUserDeserializer.java @@ -37,10 +37,11 @@ public class Saml2AuthenticatedCSpaceUserDeserializer extends JsonDeserializer getTenants(String username, boolean includeDisabledTenants) throws AccountException; + /** + * Retrieves the ID from the SSO provider, if the user is associated with one. + * + * @param username + * @return the ID from the SSO provider, or null + * @throws AccountException + */ + public String getSsoId(String username) throws AccountException; + /** * Determines if the user is required to login using single sign-on. - * + * * @param username * @return true if SSO is required, false otherwise * @throws AccountException diff --git a/services/authentication/service/src/main/java/org/collectionspace/authentication/realm/db/CSpaceDbRealm.java b/services/authentication/service/src/main/java/org/collectionspace/authentication/realm/db/CSpaceDbRealm.java index 2fef70989..5b685b646 100644 --- a/services/authentication/service/src/main/java/org/collectionspace/authentication/realm/db/CSpaceDbRealm.java +++ b/services/authentication/service/src/main/java/org/collectionspace/authentication/realm/db/CSpaceDbRealm.java @@ -74,16 +74,18 @@ import org.slf4j.LoggerFactory; /** * CSpaceDbRealm provides access to user, password, role, tenant database - * @author + * @author */ public class CSpaceDbRealm implements CSpaceRealm { public static String DEFAULT_DATASOURCE_NAME = "CspaceDS"; - + private Logger logger = LoggerFactory.getLogger(CSpaceDbRealm.class); - + private String datasourceName; + private String usernameForSsoIdQuery; private String principalsQuery; private String saltQuery; + private String ssoIdQuery; private String requireSSOQuery; private String rolesQuery; private String tenantsQueryNoDisabled; @@ -97,7 +99,7 @@ public class CSpaceDbRealm implements CSpaceRealm { private long delayBetweenAttemptsMillis = DELAY_BETWEEN_ATTEMPTS_MILLISECONDS; private static final String DELAY_BETWEEN_ATTEMPTS_MILLISECONDS_STR = "delayBetweenAttemptsMillis"; private static final long DELAY_BETWEEN_ATTEMPTS_MILLISECONDS = 200; - + protected void setMaxRetrySeconds(Map options) { Object optionsObj = options.get(MAX_RETRY_SECONDS_STR); if (optionsObj != null) { @@ -110,11 +112,11 @@ public class CSpaceDbRealm implements CSpaceRealm { } } } - + protected long getMaxRetrySeconds() { return this.maxRetrySeconds; } - + protected void setDelayBetweenAttemptsMillis(Map options) { Object optionsObj = options.get(DELAY_BETWEEN_ATTEMPTS_MILLISECONDS_STR); if (optionsObj != null) { @@ -127,15 +129,15 @@ public class CSpaceDbRealm implements CSpaceRealm { } } } - + protected long getDelayBetweenAttemptsMillis() { return this.delayBetweenAttemptsMillis; } - + public CSpaceDbRealm() { datasourceName = DEFAULT_DATASOURCE_NAME; } - + /** * CSpace Database Realm * @param datasourceName datasource name @@ -145,7 +147,11 @@ public class CSpaceDbRealm implements CSpaceRealm { if (datasourceName == null) { datasourceName = DEFAULT_DATASOURCE_NAME; } - Object tmp = options.get("principalsQuery"); + Object tmp = options.get("usernameForSsoIdQuery"); + if (tmp != null) { + usernameForSsoIdQuery = tmp.toString(); + } + tmp = options.get("principalsQuery"); if (tmp != null) { principalsQuery = tmp.toString(); } @@ -153,6 +159,10 @@ public class CSpaceDbRealm implements CSpaceRealm { if (tmp != null) { saltQuery = tmp.toString(); } + tmp = options.get("ssoIdQuery"); + if (tmp != null) { + ssoIdQuery = tmp.toString(); + } tmp = options.get("requireSSOQuery"); if (tmp != null) { requireSSOQuery = tmp.toString(); @@ -173,10 +183,10 @@ public class CSpaceDbRealm implements CSpaceRealm { if (tmp != null) { suspendResume = Boolean.valueOf(tmp.toString()).booleanValue(); } - + this.setMaxRetrySeconds(options); this.setDelayBetweenAttemptsMillis(options); - + if (logger.isTraceEnabled()) { logger.trace("DatabaseServerLoginModule, dsJndiName=" + datasourceName); logger.trace("principalsQuery=" + principalsQuery); @@ -185,6 +195,66 @@ public class CSpaceDbRealm implements CSpaceRealm { } } + @Override + public String getUsernameForSsoId(String ssoId) throws AccountException { + String username = null; + Connection conn = null; + PreparedStatement ps = null; + ResultSet rs = null; + + try { + conn = getConnection(); + // Get the username + if (logger.isDebugEnabled()) { + logger.debug("Executing query: " + usernameForSsoIdQuery + ", with sso id: " + ssoId); + } + ps = conn.prepareStatement(usernameForSsoIdQuery); + ps.setString(1, ssoId); + rs = ps.executeQuery(); + if (rs.next() == false) { + if (logger.isDebugEnabled()) { + logger.debug(usernameForSsoIdQuery + " returned no matches from db"); + } + throw new AccountNotFoundException("No matching sso id found"); + } + + username = rs.getString(1); + } catch (SQLException ex) { + if (logger.isTraceEnabled() == true) { + logger.error("Could not open database to read AuthN tables.", ex); + } + AccountException ae = new AccountException("Authentication query failed: " + ex.getLocalizedMessage()); + ae.initCause(ex); + throw ae; + } catch (AccountNotFoundException ex) { + throw ex; + } catch (Exception ex) { + AccountException ae = new AccountException("Unknown Exception"); + ae.initCause(ex); + throw ae; + } finally { + if (rs != null) { + try { + rs.close(); + } catch (SQLException e) { + } + } + if (ps != null) { + try { + ps.close(); + } catch (SQLException e) { + } + } + if (conn != null) { + try { + conn.close(); + } catch (SQLException ex) { + } + } + } + return username; + } + @Override public String getPassword(String username) throws AccountException { @@ -275,14 +345,14 @@ public class CSpaceDbRealm implements CSpaceRealm { if (logger.isDebugEnabled()) { logger.debug("No roles found"); } - + return roles; } do { String roleName = rs.getString(1); roles.add(roleName); - + } while (rs.next()); } catch (SQLException ex) { AccountException ae = new AccountException("Query failed"); @@ -321,7 +391,7 @@ public class CSpaceDbRealm implements CSpaceRealm { public Set getTenants(String username) throws AccountException { return getTenants(username, false); } - + private boolean userIsTenantManager(Connection conn, String username) { String acctQuery = "SELECT csid FROM accounts_common WHERE userid=?"; PreparedStatement ps = null; @@ -361,7 +431,7 @@ public class CSpaceDbRealm implements CSpaceRealm { } return accountIsTenantManager; } - + /** * Execute the tenantsQuery against the datasourceName to obtain the tenants for * the authenticated user. @@ -371,13 +441,13 @@ public class CSpaceDbRealm implements CSpaceRealm { public Set getTenants(String username, boolean includeDisabledTenants) throws AccountException { String tenantsQuery = getTenantQuery(includeDisabledTenants); - + if (logger.isDebugEnabled()) { logger.debug("getTenants using tenantsQuery: " + tenantsQuery + ", username: " + username); } Set tenants = new LinkedHashSet(); - + Connection conn = null; PreparedStatement ps = null; ResultSet rs = null; @@ -398,7 +468,7 @@ public class CSpaceDbRealm implements CSpaceRealm { if (logger.isDebugEnabled()) { logger.debug("GetTenants called with tenantManager - synthesizing the pseudo-tenant"); } - + tenants.add(new CSpaceTenant(AuthN.TENANT_MANAGER_ACCT_ID, "PseudoTenant")); } else { if (logger.isDebugEnabled()) { @@ -408,7 +478,7 @@ public class CSpaceDbRealm implements CSpaceRealm { // empty Tenants set. // FIXME should this be allowed? } - + return tenants; } @@ -466,7 +536,7 @@ public class CSpaceDbRealm implements CSpaceRealm { if (requestAttempts > 0) { Thread.sleep(getDelayBetweenAttemptsMillis()); // Wait a little time between reattempts. } - + try { // proceed to the original request by calling doFilter() result = this.getConnection(getDataSourceName()); @@ -487,7 +557,7 @@ public class CSpaceDbRealm implements CSpaceRealm { requestAttempts++; // keep track of how many times we've tried the request } } while (System.currentTimeMillis() < quittingTime); // keep trying until we run out of time - + // // Add a warning to the logs if we encountered *any* failures on our re-attempts. Only add the warning // if we were eventually successful. @@ -503,10 +573,10 @@ public class CSpaceDbRealm implements CSpaceRealm { // If we get here, it means all of our attempts to get a successful call to chain.doFilter() have failed. throw lastException; } - + return result; } - + /* * Don't call this method directly. Instead, use the getConnection() method that take no arguments. */ @@ -514,52 +584,52 @@ public class CSpaceDbRealm implements CSpaceRealm { InitialContext ctx = null; Connection conn = null; DataSource ds = null; - + try { ctx = new InitialContext(); try { ds = (DataSource) ctx.lookup(dataSourceName); } catch (Exception e) {} - + try { Context envCtx = (Context) ctx.lookup("java:comp/env"); ds = (DataSource) envCtx.lookup(dataSourceName); } catch (Exception e) {} - + try { Context envCtx = (Context) ctx.lookup("java:comp"); ds = (DataSource) envCtx.lookup(dataSourceName); } catch (Exception e) {} - + try { Context envCtx = (Context) ctx.lookup("java:"); ds = (DataSource) envCtx.lookup(dataSourceName); } catch (Exception e) {} - + try { Context envCtx = (Context) ctx.lookup("java"); ds = (DataSource) envCtx.lookup(dataSourceName); } catch (Exception e) {} - + try { ds = (DataSource) ctx.lookup("java:/" + dataSourceName); - } catch (Exception e) {} + } catch (Exception e) {} if (ds == null) { ds = AuthN.getDataSource(); } - + if (ds == null) { throw new IllegalArgumentException("datasource not found: " + dataSourceName); } - + conn = ds.getConnection(); if (conn == null) { conn = AuthN.getDataSource().getConnection(); //FIXME:REM - This is the result of some type of JNDI mess. Should try to solve this problem and clean up this code. } - + return conn; - + } catch (NamingException ex) { AccountException ae = new AccountException("Error looking up DataSource from: " + dataSourceName); ae.initCause(ex); @@ -583,6 +653,20 @@ public class CSpaceDbRealm implements CSpaceRealm { return datasourceName; } + /** + * @return the usernameForSsoIdQuery + */ + public String getUsernameForSsoIdQuery() { + return usernameForSsoIdQuery; + } + + /** + * @param usernameForSsoIdQuery the usernameForSsoIdQuery to set + */ + public void setUsernameForSsoIdQuery(String usernameForSsoIdQuery) { + this.usernameForSsoIdQuery = usernameForSsoIdQuery; + } + /** * @return the principalQuery */ @@ -624,7 +708,7 @@ public class CSpaceDbRealm implements CSpaceRealm { this.tenantsQueryNoDisabled = tenantQuery; } */ - + /* * This method crawls the exception chain looking for network related exceptions and * returns 'true' if it finds one. @@ -638,13 +722,13 @@ public class CSpaceDbRealm implements CSpaceRealm { result = true; break; } - + cause = cause.getCause(); } return result; } - + /* * Return 'true' if the exception is in the "java.net" package. */ @@ -718,10 +802,73 @@ public class CSpaceDbRealm implements CSpaceRealm { } } } - + return salt; } + @Override + public String getSsoId(String username) throws AccountException { + String ssoId = null; + Connection conn = null; + PreparedStatement ps = null; + ResultSet rs = null; + try { + conn = getConnection(); + // Get the SSO ID + if (logger.isDebugEnabled()) { + logger.debug("Executing query: " + ssoIdQuery + ", with username: " + username); + } + ps = conn.prepareStatement(ssoIdQuery); + ps.setString(1, username); + rs = ps.executeQuery(); + if (rs.next() == false) { + if (logger.isDebugEnabled()) { + logger.debug(ssoIdQuery + " returned no matches from db"); + } + throw new AccountNotFoundException("No matching username found"); + } + + ssoId = rs.getString(1); + } catch (SQLException ex) { + // Assuming PostgreSQL + if (PSQLState.UNDEFINED_COLUMN.getState().equals(ex.getSQLState())) { + String msg = "'users' table is missing 'sso_id' column."; + logger.warn(msg); + } else { + AccountException ae = new AccountException("Authentication query failed: " + ex.getLocalizedMessage()); + ae.initCause(ex); + throw ae; + } + } catch (AccountNotFoundException ex) { + throw ex; + } catch (Exception ex) { + AccountException ae = new AccountException("Unknown Exception"); + ae.initCause(ex); + throw ae; + } finally { + if (rs != null) { + try { + rs.close(); + } catch (SQLException e) { + } + } + if (ps != null) { + try { + ps.close(); + } catch (SQLException e) { + } + } + if (conn != null) { + try { + conn.close(); + } catch (SQLException ex) { + } + } + } + + return ssoId; + } + @Override public boolean isRequireSSO(String username) throws AccountException { Boolean requireSSO = null; @@ -759,7 +906,7 @@ public class CSpaceDbRealm implements CSpaceRealm { AccountException ae = new AccountException("Authentication query failed: " + ex.getLocalizedMessage()); ae.initCause(ex); - + throw ae; } catch (AccountNotFoundException ex) { throw ex; @@ -767,7 +914,7 @@ public class CSpaceDbRealm implements CSpaceRealm { AccountException ae = new AccountException("Unknown Exception"); ae.initCause(ex); - + throw ae; } finally { if (rs != null) { diff --git a/services/authentication/service/src/main/java/org/collectionspace/authentication/spring/CSpaceUserDetailsService.java b/services/authentication/service/src/main/java/org/collectionspace/authentication/spring/CSpaceUserDetailsService.java index a0974dc77..cb9344dff 100644 --- a/services/authentication/service/src/main/java/org/collectionspace/authentication/spring/CSpaceUserDetailsService.java +++ b/services/authentication/service/src/main/java/org/collectionspace/authentication/spring/CSpaceUserDetailsService.java @@ -49,17 +49,33 @@ public class CSpaceUserDetailsService implements UserDetailsService { this.realm = realm; } + public UserDetails loadUserBySsoId(String ssoId) throws UsernameNotFoundException { + try { + String username = realm.getUsernameForSsoId(ssoId); + + return loadUserByUsername(username); + } + catch (AccountNotFoundException e) { + throw new UsernameNotFoundException(e.getMessage(), e); + } + catch (AccountException e) { + throw new AuthenticationServiceException(e.getMessage(), e); + } + } + @Override public UserDetails loadUserByUsername(String username) throws UsernameNotFoundException { String password = null; String salt = null; + String ssoId = null; Boolean requireSSO = null; Set tenants = null; Set grantedAuthorities = null; - + try { password = realm.getPassword(username); salt = realm.getSalt(username); + ssoId = realm.getSsoId(username); requireSSO = realm.isRequireSSO(username); tenants = getTenants(username); if (tenants == null || tenants.isEmpty()) { @@ -75,33 +91,34 @@ public class CSpaceUserDetailsService implements UserDetailsService { catch (AccountException e) { throw new AuthenticationServiceException(e.getMessage(), e); } - - CSpaceUser cspaceUser = + + CSpaceUser cspaceUser = new CSpaceUser( username, password, salt, + ssoId, requireSSO, tenants, grantedAuthorities); - + return cspaceUser; } - + protected Set getAuthorities(String username) throws AccountException { Set roles = realm.getRoles(username); Set authorities = new LinkedHashSet(roles.size()); - + for (String role : roles) { authorities.add(new SimpleGrantedAuthority(role)); } - + return authorities; } - + protected Set getTenants(String username) throws AccountException { Set tenants = realm.getTenants(username); - + return tenants; } } diff --git a/services/authentication/service/src/main/java/org/collectionspace/authentication/spring/Saml2AuthenticatedCSpaceUser.java b/services/authentication/service/src/main/java/org/collectionspace/authentication/spring/Saml2AuthenticatedCSpaceUser.java index bfe263508..ecabed601 100644 --- a/services/authentication/service/src/main/java/org/collectionspace/authentication/spring/Saml2AuthenticatedCSpaceUser.java +++ b/services/authentication/service/src/main/java/org/collectionspace/authentication/spring/Saml2AuthenticatedCSpaceUser.java @@ -37,6 +37,7 @@ public class Saml2AuthenticatedCSpaceUser extends CSpaceUser implements Saml2Aut user.getUsername(), user.getPassword(), user.getSalt(), + user.getSsoId(), user.isRequireSSO(), user.getTenants(), (Set) user.getAuthorities() @@ -48,11 +49,12 @@ public class Saml2AuthenticatedCSpaceUser extends CSpaceUser implements Saml2Aut String username, String password, String salt, + String ssoId, boolean requireSSO, Set tenants, Set authorities ) { - super(username, password, salt, requireSSO, tenants, authorities); + super(username, password, salt, ssoId, requireSSO, tenants, authorities); this.principal = principal; } diff --git a/services/authorization/service/src/main/java/org/collectionspace/services/authorization/AuthZ.java b/services/authorization/service/src/main/java/org/collectionspace/services/authorization/AuthZ.java index 4e15399f5..106261134 100644 --- a/services/authorization/service/src/main/java/org/collectionspace/services/authorization/AuthZ.java +++ b/services/authorization/service/src/main/java/org/collectionspace/services/authorization/AuthZ.java @@ -289,7 +289,7 @@ public class AuthZ { HashSet tenantSet = new HashSet(); tenantSet.add(tenant); - CSpaceUser principal = new CSpaceUser(user, password, null, false, tenantSet, grantedAuthorities); + CSpaceUser principal = new CSpaceUser(user, password, null, null, false, tenantSet, grantedAuthorities); Authentication authRequest = new UsernamePasswordAuthenticationToken(principal, password, grantedAuthorities); SecurityContextHolder.getContext().setAuthentication(authRequest); diff --git a/services/common/src/main/java/org/collectionspace/services/common/security/CSpaceSaml2ResponseAuthenticationConverter.java b/services/common/src/main/java/org/collectionspace/services/common/security/CSpaceSaml2ResponseAuthenticationConverter.java new file mode 100644 index 000000000..0a2aa3564 --- /dev/null +++ b/services/common/src/main/java/org/collectionspace/services/common/security/CSpaceSaml2ResponseAuthenticationConverter.java @@ -0,0 +1,145 @@ +package org.collectionspace.services.common.security; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import org.apache.commons.lang3.StringUtils; +import org.collectionspace.authentication.CSpaceUser; +import org.collectionspace.authentication.spring.CSpaceSaml2Authentication; +import org.collectionspace.authentication.spring.CSpaceUserDetailsService; +import org.collectionspace.services.common.config.ConfigUtils; +import org.collectionspace.services.config.AssertionProbesType; +import org.collectionspace.services.config.SAMLRelyingPartyType; +import org.collectionspace.services.config.ServiceConfig; +import org.collectionspace.services.common.ServiceMain; +import org.opensaml.saml.saml2.core.Assertion; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.core.userdetails.UsernameNotFoundException; +import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider; +import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider.ResponseToken; +import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication; + +public class CSpaceSaml2ResponseAuthenticationConverter implements Converter { + private final Logger logger = LoggerFactory.getLogger(CSpaceSaml2ResponseAuthenticationConverter.class); + + private CSpaceUserDetailsService userDetailsService; + + public CSpaceSaml2ResponseAuthenticationConverter(CSpaceUserDetailsService userDetailsService) { + this.userDetailsService = userDetailsService; + } + + @Override + public CSpaceSaml2Authentication convert(ResponseToken responseToken) { + Saml2Authentication authentication = OpenSamlAuthenticationProvider + .createDefaultResponseAuthenticationConverter() + .convert(responseToken); + + String registrationId = responseToken.getToken().getRelyingPartyRegistration().getRegistrationId(); + ServiceConfig serviceConfig = ServiceMain.getInstance().getServiceConfig(); + SAMLRelyingPartyType relyingPartyRegistration = ConfigUtils.getSAMLRelyingPartyRegistration(serviceConfig, registrationId); + CSpaceUser user = findUser(relyingPartyRegistration, responseToken); + + if (user != null) { + return new CSpaceSaml2Authentication(user, authentication); + } + + return null; + } + + /** + * Attempt to find a CSpace user for a SAML response. + * + * @param relyingPartyRegistration + * @param responseToken + * @return + */ + private CSpaceUser findUser(SAMLRelyingPartyType relyingPartyRegistration, ResponseToken responseToken) { + AssertionProbesType assertionSsoIdProbes = ( + relyingPartyRegistration != null + ? relyingPartyRegistration.getAssertionSsoIdProbes() + : null + ); + + AssertionProbesType assertionUsernameProbes = ( + relyingPartyRegistration != null + ? relyingPartyRegistration.getAssertionUsernameProbes() + : null + ); + + List attemptedUsernames = new ArrayList<>(); + List assertions = responseToken.getResponse().getAssertions(); + + SecurityUtils.logSamlAssertions(assertions); + + for (Assertion assertion : assertions) { + CSpaceUser user = null; + String ssoId = SecurityUtils.getSamlAssertionSsoId(assertion, assertionSsoIdProbes); + + // First, look for a CSpace user whose SSO ID is the ID in the assertion. + + if (ssoId != null) { + try { + user = (CSpaceUser) userDetailsService.loadUserBySsoId(ssoId); + } + catch (UsernameNotFoundException e) { + } + } + + if (user != null) { + return user; + } + + // Next, look for a CSpace user whose username is the email address in the assertion. + + Set candidateUsernames = SecurityUtils.findSamlAssertionCandidateUsernames(assertion, assertionUsernameProbes); + + for (String candidateUsername : candidateUsernames) { + try { + user = (CSpaceUser) userDetailsService.loadUserByUsername(candidateUsername); + + if (user != null) { + String expectedSsoId = user.getSsoId(); + + if (expectedSsoId == null) { + // Store the ID from the IdP to use in future log ins. Note that this does not save + // the SSO ID to the database. That happens in CSpaceAuthenticationSuccessEvent. + + user.setSsoId(ssoId); + + // TODO: If the email address in the assertion differs from the CSpace user's email, + // update the CSpace user. + } else if (!StringUtils.equals(expectedSsoId, ssoId)) { + // If the user previously logged in via SSO, but they had a different ID from the + // IdP, something's wrong. (Did an account on the IdP get assigned an email that + // previously belonged to a different account on the IdP?) + + logger.warn("User with username {} has expected SSO ID {}, but received {} in SAML assertion", + candidateUsername, expectedSsoId, ssoId); + + user = null; + } + + if (user != null) { + return user; + } + } + } + catch(UsernameNotFoundException e) { + } + } + + attemptedUsernames.addAll(candidateUsernames); + } + + // No CSpace user was found for this SAML response. + // TODO: Auto-create a CSpace user, using the display name, email address, and ID in the response. + + String errorMessage = attemptedUsernames.size() == 0 + ? "The SAML response did not contain a CollectionSpace username." + : "No CollectionSpace account found for " + StringUtils.join(attemptedUsernames, " / ") + "."; + + throw(new UsernameNotFoundException(errorMessage)); + } +} diff --git a/services/common/src/main/java/org/collectionspace/services/common/security/SecurityConfig.java b/services/common/src/main/java/org/collectionspace/services/common/security/SecurityConfig.java index f37c9a7fe..88d6869a0 100644 --- a/services/common/src/main/java/org/collectionspace/services/common/security/SecurityConfig.java +++ b/services/common/src/main/java/org/collectionspace/services/common/security/SecurityConfig.java @@ -33,7 +33,6 @@ import javax.servlet.http.HttpServletRequest; import javax.sql.DataSource; import org.apache.commons.io.IOUtils; -import org.apache.commons.lang3.StringUtils; import org.collectionspace.authentication.CSpaceUser; import org.collectionspace.authentication.spring.CSpaceDaoAuthenticationProvider; import org.collectionspace.authentication.spring.CSpaceJwtAuthenticationToken; @@ -48,7 +47,6 @@ import org.collectionspace.services.common.ServiceMain; import org.collectionspace.services.common.config.ConfigUtils; import org.collectionspace.services.common.config.TenantBindingConfigReaderImpl; import org.collectionspace.services.config.AssertingPartyDetailsType; -import org.collectionspace.services.config.AssertionProbesType; import org.collectionspace.services.config.OAuthAuthorizationGrantTypeEnum; import org.collectionspace.services.config.OAuthClientAuthenticationMethodEnum; import org.collectionspace.services.config.OAuthClientSettingsType; @@ -63,7 +61,6 @@ import org.collectionspace.services.config.X509CertificateType; import org.collectionspace.services.config.X509CredentialType; import org.collectionspace.services.config.tenant.TenantBindingType; import org.collectionspace.authentication.realm.db.CSpaceDbRealm; -import org.opensaml.saml.saml2.core.Assertion; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.context.ApplicationEventPublisher; @@ -116,7 +113,6 @@ import org.springframework.security.saml2.core.Saml2X509Credential; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider; import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication; import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest; -import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider.ResponseToken; import org.springframework.security.saml2.provider.service.metadata.OpenSamlMetadataResolver; import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; @@ -133,7 +129,6 @@ import org.springframework.security.saml2.provider.service.web.authentication.lo import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint; import org.springframework.security.web.authentication.logout.LogoutFilter; -import org.springframework.security.web.context.SecurityContextPersistenceFilter; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.OrRequestMatcher; import org.springframework.web.cors.CorsConfiguration; @@ -535,48 +530,8 @@ public class SecurityConfig { // TODO: Use OpenSaml4AuthenticationProvider (requires Java 11) instead of deprecated OpenSamlAuthenticationProvider. final OpenSamlAuthenticationProvider samlAuthenticationProvider = new OpenSamlAuthenticationProvider(); - samlAuthenticationProvider.setResponseAuthenticationConverter(new Converter() { - @Override - public CSpaceSaml2Authentication convert(ResponseToken responseToken) { - Saml2Authentication authentication = OpenSamlAuthenticationProvider - .createDefaultResponseAuthenticationConverter() - .convert(responseToken); - - String registrationId = responseToken.getToken().getRelyingPartyRegistration().getRegistrationId(); - ServiceConfig serviceConfig = ServiceMain.getInstance().getServiceConfig(); - SAMLRelyingPartyType registration = ConfigUtils.getSAMLRelyingPartyRegistration(serviceConfig, registrationId); - - AssertionProbesType assertionProbes = ( - registration != null - ? registration.getAssertionUsernameProbes() - : null - ); - - List attemptedUsernames = new ArrayList<>(); - - for (Assertion assertion : responseToken.getResponse().getAssertions()) { - Set candidateUsernames = SecurityUtils.findSamlAssertionCandidateUsernames(assertion, assertionProbes); - - for (String candidateUsername : candidateUsernames) { - try { - CSpaceUser user = (CSpaceUser) userDetailsService.loadUserByUsername(candidateUsername); - - return new CSpaceSaml2Authentication(user, authentication); - } - catch(UsernameNotFoundException e) { - } - } - - attemptedUsernames.addAll(candidateUsernames); - } - - String errorMessage = attemptedUsernames.size() == 0 - ? "The SAML assertion did not contain a CollectionSpace username." - : "No CollectionSpace account found for " + StringUtils.join(attemptedUsernames, " / ") + "."; - - throw(new UsernameNotFoundException(errorMessage)); - } - }); + samlAuthenticationProvider.setResponseAuthenticationConverter( + new CSpaceSaml2ResponseAuthenticationConverter((CSpaceUserDetailsService) userDetailsService)); http .saml2Login(new Customizer>() { @@ -956,8 +911,10 @@ public class SecurityConfig { Map options = new HashMap(); options.put("dsJndiName", "CspaceDS"); + options.put("usernameForSsoIdQuery", "select username from users where sso_id=?"); options.put("principalsQuery", "select passwd from users where username=?"); options.put("saltQuery", "select salt from users where username=?"); + options.put("ssoIdQuery", "select sso_id from users where username=?"); options.put("requireSSOQuery", "select require_sso from accounts_common where userid=?"); options.put("rolesQuery", "select r.rolename from roles as r, accounts_roles as ar where ar.user_id=? and ar.role_id=r.csid"); options.put("tenantsQueryWithDisabled", "select t.id, t.name from accounts_common as a, accounts_tenants as at, tenants as t where a.userid=? and a.csid = at.TENANTS_ACCOUNTS_COMMON_CSID and at.tenant_id = t.id order by t.id"); diff --git a/services/common/src/main/java/org/collectionspace/services/common/security/SecurityUtils.java b/services/common/src/main/java/org/collectionspace/services/common/security/SecurityUtils.java index ab89e8917..adc275388 100644 --- a/services/common/src/main/java/org/collectionspace/services/common/security/SecurityUtils.java +++ b/services/common/src/main/java/org/collectionspace/services/common/security/SecurityUtils.java @@ -73,13 +73,14 @@ public class SecurityUtils { public static final String BASE16_ENCODING = "HEX"; public static final String RFC2617_ENCODING = "RFC2617"; + private static final List DEFAULT_SAML_ASSERTION_SSO_ID_PROBES = new ArrayList<>(); private static final List DEFAULT_SAML_ASSERTION_USERNAME_PROBES = new ArrayList<>(); static { - DEFAULT_SAML_ASSERTION_USERNAME_PROBES.add(new AssertionNameIDProbeType()); + DEFAULT_SAML_ASSERTION_SSO_ID_PROBES.add(new AssertionNameIDProbeType()); String[] attributeNames = new String[]{ - "urn:oid:0.9.2342.19200300.100.1.3", + "urn:oid:0.9.2342.19200300.100.1.3", // https://www.educause.edu/fidm/attributes "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress", "email", "mail" @@ -363,7 +364,7 @@ public class SecurityUtils { if (probe instanceof AssertionNameIDProbeType) { String subjectNameID = assertion.getSubject().getNameID().getValue(); - if (subjectNameID != null && subjectNameID.contains("@")) { + if (subjectNameID != null) { candidateUsernames.add(subjectNameID); } } else if (probe instanceof AssertionAttributeProbeType) { @@ -376,7 +377,53 @@ public class SecurityUtils { } } - return candidateUsernames; + // Filter out values that don't look like an email. + + Set filteredCandidateUsernames = new LinkedHashSet<>(); + + for (String username : candidateUsernames) { + if (username.contains("@")) { + filteredCandidateUsernames.add(username); + } + } + + return filteredCandidateUsernames; + } + + /* + * Retrieve the SSO ID from a SAML assertion. + */ + public static String getSamlAssertionSsoId(Assertion assertion, AssertionProbesType assertionProbes) { + List probes = null; + + if (assertionProbes != null) { + probes = assertionProbes.getNameIdOrAttribute(); + } + + if (probes == null || probes.size() == 0) { + probes = DEFAULT_SAML_ASSERTION_SSO_ID_PROBES; + } + + for (Object probe : probes) { + String ssoId = null; + + if (probe instanceof AssertionNameIDProbeType) { + ssoId = assertion.getSubject().getNameID().getValue(); + } else if (probe instanceof AssertionAttributeProbeType) { + String attributeName = ((AssertionAttributeProbeType) probe).getName(); + List values = getSamlAssertionAttributeValues(assertion, attributeName); + + if (values != null && values.size() > 0) { + ssoId = values.get(0); + } + } + + if (ssoId != null) { + return ssoId; + } + } + + return null; } private static List getSamlAssertionAttributeValues(Assertion assertion, String attributeName) { @@ -395,7 +442,7 @@ public class SecurityUtils { XSString stringValue = (XSString) value; String candidateValue = stringValue.getValue(); - if (candidateValue != null && candidateValue.contains("@")) { + if (candidateValue != null) { values.add(candidateValue); } } @@ -407,4 +454,32 @@ public class SecurityUtils { return values; } + + public static void logSamlAssertions(List assertions) { + logger.info("Received {} SAML assertion(s)", assertions.size()); + + for (Assertion assertion : assertions) { + String nameId = assertion.getSubject().getNameID().getValue(); + + logger.info("NameID: {}", nameId); + + for (AttributeStatement statement : assertion.getAttributeStatements()) { + for (Attribute attribute : statement.getAttributes()) { + String attributeName = attribute.getName(); + List stringValues = new ArrayList<>(); + List attributeValues = attribute.getAttributeValues(); + + if (attributeValues != null) { + for (XMLObject value : attributeValues) { + if (value instanceof XSString) { + stringValues.add(((XSString) value).getValue()); + } + } + } + + logger.info("Attribute: {}={}", attributeName, stringValues); + } + } + } + } } diff --git a/services/config/src/main/resources/service-config.xsd b/services/config/src/main/resources/service-config.xsd index f67f4a9b1..62ba1c878 100644 --- a/services/config/src/main/resources/service-config.xsd +++ b/services/config/src/main/resources/service-config.xsd @@ -277,7 +277,6 @@ Configures how a SAML assertion is probed to find the CollectionSpace username. Defaults to: - @@ -286,6 +285,19 @@ + + + + + + ]]> + + + -- 2.47.3