From: Ray Lee Date: Sat, 23 Sep 2023 02:49:01 +0000 (-0400) Subject: Add SAML providers to CORS allowed hosts. (#367) X-Git-Url: https://git.aero2k.de/?a=commitdiff_plain;h=808629d97f327dc726cfb2b5151b69505be31a89;p=tmp%2Fjakarta-migration.git Add SAML providers to CORS allowed hosts. (#367) --- diff --git a/services/common/src/main/java/org/collectionspace/services/common/security/SecurityConfig.java b/services/common/src/main/java/org/collectionspace/services/common/security/SecurityConfig.java index 9056f5140..569f9f5eb 100644 --- a/services/common/src/main/java/org/collectionspace/services/common/security/SecurityConfig.java +++ b/services/common/src/main/java/org/collectionspace/services/common/security/SecurityConfig.java @@ -4,6 +4,7 @@ import java.io.InputStream; import java.io.InputStreamReader; import java.io.Reader; import java.net.MalformedURLException; +import java.net.URL; import java.security.cert.X509Certificate; import java.security.KeyFactory; import java.security.KeyPair; @@ -20,6 +21,7 @@ import java.util.Base64; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -62,7 +64,6 @@ import org.collectionspace.authentication.realm.db.CSpaceDbRealm; import org.opensaml.saml.saml2.core.Assertion; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.beans.factory.annotation.Value; import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -131,6 +132,7 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.OrRequestMatcher; import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.cors.CorsConfigurationSource; +import org.springframework.web.cors.UrlBasedCorsConfigurationSource; import com.google.common.io.CharStreams; import com.nimbusds.jose.jwk.JWKSet; @@ -152,12 +154,32 @@ public class SecurityConfig { private CorsConfiguration defaultCorsConfiguration = null; private CorsConfiguration oauthServerCorsConfiguration = null; + private Map samlCorsConfigurations = null; - private void initializeCorsConfigurations() { + private void initializeCorsConfigurations(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { ServiceConfig serviceConfig = ServiceMain.getInstance().getServiceConfig(); Duration maxAge = ConfigUtils.getCorsMaxAge(serviceConfig); + + // Read explicitly configured allowed origins from service config. + List allowedOrigins = ConfigUtils.getCorsAllowedOrigins(serviceConfig); + // Automatically add UI locations as allowed origins. + + TenantBindingConfigReaderImpl tenantBindingConfigReader = ServiceMain.getInstance().getTenantBindingConfigReader(); + + for (TenantBindingType tenantBinding : tenantBindingConfigReader.getTenantBindings().values()) { + URL uiBaseUrl = null; + try { + uiBaseUrl = new URL(ConfigUtils.getUIBaseUrl(tenantBinding)); + } catch (MalformedURLException e) { + } + + if (uiBaseUrl != null) { + allowedOrigins.add(uiBaseUrl.getProtocol() + "://" + uiBaseUrl.getAuthority()); + } + } + if (this.defaultCorsConfiguration == null) { this.defaultCorsConfiguration = defaultCorsConfiguration(allowedOrigins, maxAge); } @@ -165,6 +187,12 @@ public class SecurityConfig { if (this.oauthServerCorsConfiguration == null) { this.oauthServerCorsConfiguration = oauthServerCorsConfiguration(allowedOrigins, maxAge); } + + if (relyingPartyRegistrationRepository != null && this.samlCorsConfigurations == null) { + // Automatically add SAML providers as allowed origins for SAML response endpoints. + + this.samlCorsConfigurations = samlCorsConfigurations(relyingPartyRegistrationRepository, allowedOrigins, maxAge); + } } private CorsConfiguration defaultCorsConfiguration(List allowedOrigins, Duration maxAge) { @@ -214,6 +242,64 @@ public class SecurityConfig { return configuration; } + /** + * Generate CORS configurations for SAML. For each registered SAML provider, POST requests to the + * SAML response endpoint are allowed from the provider's sign on location. + * + * @param relyingPartyRegistrationRepository + * @param allowedOrigins + * @param maxAge + * @return + */ + private Map samlCorsConfigurations( + RelyingPartyRegistrationRepository relyingPartyRegistrationRepository, + List allowedOrigins, + Duration maxAge) + { + ServiceConfig serviceConfig = ServiceMain.getInstance().getServiceConfig(); + List relyingPartiesConfig = ConfigUtils.getSAMLRelyingPartyRegistrations(serviceConfig); + Map corsConfigurations = new LinkedHashMap<>(); + + if (relyingPartiesConfig != null) { + for (final SAMLRelyingPartyType relyingPartyConfig : relyingPartiesConfig) { + String id = relyingPartyConfig.getId(); + RelyingPartyRegistration registration = relyingPartyRegistrationRepository.findByRegistrationId(id); + + if (registration == null) { + continue; + } + + URL providerUrl = null; + + try { + providerUrl = new URL(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation()); + } catch (MalformedURLException e) { + } + + if (providerUrl != null) { + CorsConfiguration configuration = new CorsConfiguration(); + String responseUrl = "/login/saml2/sso/" + id; + String providerOrigin = providerUrl.getProtocol() + "://" + providerUrl.getAuthority(); + + configuration.setAllowedOrigins(allowedOrigins); + configuration.addAllowedOrigin(providerOrigin); + + if (maxAge != null) { + configuration.setMaxAge(maxAge); + } + + configuration.setAllowedMethods(Arrays.asList( + HttpMethod.POST.toString() + )); + + corsConfigurations.put(responseUrl, configuration); + } + } + } + + return corsConfigurations; + } + @Bean public JdbcOperations jdbcOperations(DataSource cspaceDataSource) { return new JdbcTemplate(cspaceDataSource); @@ -232,7 +318,7 @@ public class SecurityConfig { @Bean @Order(Ordered.HIGHEST_PRECEDENCE) public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity http) throws Exception { - this.initializeCorsConfigurations(); + this.initializeCorsConfigurations(null); OAuth2AuthorizationServerConfiguration.applyDefaultSecurity(http); @@ -269,10 +355,12 @@ public class SecurityConfig { final Optional optionalRelyingPartyRegistrationRepository ) throws Exception { + final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository = optionalRelyingPartyRegistrationRepository.orElse(null); + ServiceConfig serviceConfig = ServiceMain.getInstance().getServiceConfig(); SAMLType saml = ConfigUtils.getSAML(serviceConfig); - this.initializeCorsConfigurations(); + this.initializeCorsConfigurations(relyingPartyRegistrationRepository); http .authorizeHttpRequests(new Customizer.AuthorizationManagerRequestMatcherRegistry>() { @@ -391,20 +479,24 @@ public class SecurityConfig { .cors(new Customizer>() { @Override public void customize(CorsConfigurer configurer) { - configurer.configurationSource(new CorsConfigurationSource() { - @Override - @Nullable - public CorsConfiguration getCorsConfiguration(HttpServletRequest request) { - return SecurityConfig.this.defaultCorsConfiguration; + UrlBasedCorsConfigurationSource configurationSource = new UrlBasedCorsConfigurationSource(); + Map urlMappings = new LinkedHashMap<>(); + + if (SecurityConfig.this.samlCorsConfigurations != null) { + for (Map.Entry entry : SecurityConfig.this.samlCorsConfigurations.entrySet()) { + urlMappings.put(entry.getKey(), entry.getValue()); } - }); + } + + urlMappings.put("/**", SecurityConfig.this.defaultCorsConfiguration); + + configurationSource.setCorsConfigurations(urlMappings); + configurer.configurationSource(configurationSource); } }) // Insert the username from the security context into a request attribute for logging. .addFilterBefore(new CSpaceUserAttributeFilter(), LogoutFilter.class); - RelyingPartyRegistrationRepository relyingPartyRegistrationRepository = optionalRelyingPartyRegistrationRepository.orElse(null); - if (relyingPartyRegistrationRepository != null) { RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository); diff --git a/services/config/src/main/java/org/collectionspace/services/common/config/ConfigUtils.java b/services/config/src/main/java/org/collectionspace/services/common/config/ConfigUtils.java index e3ec282b4..121b426d6 100644 --- a/services/config/src/main/java/org/collectionspace/services/common/config/ConfigUtils.java +++ b/services/config/src/main/java/org/collectionspace/services/common/config/ConfigUtils.java @@ -207,6 +207,16 @@ public class ConfigUtils { return (samlRegistrations != null && samlRegistrations.size() > 0); } + public static String getUIBaseUrl(TenantBindingType tenantBinding) { + UIConfig uiConfig = tenantBinding.getUiConfig(); + + if (uiConfig != null) { + return uiConfig.getBaseUrl(); + } + + return null; + } + public static String getUILoginSuccessUrl(TenantBindingType tenantBinding) throws MalformedURLException { UIConfig uiConfig = tenantBinding.getUiConfig();