diff --git a/app/server/appsmith-server/src/main/java/com/appsmith/server/configurations/SecurityConfig.java b/app/server/appsmith-server/src/main/java/com/appsmith/server/configurations/SecurityConfig.java index 59912e69ed..2b03873455 100644 --- a/app/server/appsmith-server/src/main/java/com/appsmith/server/configurations/SecurityConfig.java +++ b/app/server/appsmith-server/src/main/java/com/appsmith/server/configurations/SecurityConfig.java @@ -1,6 +1,7 @@ package com.appsmith.server.configurations; +import com.appsmith.server.filters.CustomServerOAuth2AuthorizationRequestResolver; import com.appsmith.server.services.UserService; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; @@ -8,6 +9,7 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.http.HttpMethod; import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity; import org.springframework.security.config.web.server.ServerHttpSecurity; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.ServerAuthenticationEntryPoint; import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler; @@ -42,6 +44,9 @@ public class SecurityConfig { @Autowired private ServerAuthenticationEntryPoint authenticationEntryPoint; + @Autowired + private ReactiveClientRegistrationRepository reactiveClientRegistrationRepository; + /** * This routerFunction is required to map /public/** endpoints to the src/main/resources/public folder * This is to allow static resources to be served by the server. Couldn't find an easier way to do this, @@ -107,6 +112,7 @@ public class SecurityConfig { .authenticationSuccessHandler(authenticationSuccessHandler) .authenticationFailureHandler(authenticationFailureHandler) .and().oauth2Login() + .authorizationRequestResolver(new CustomServerOAuth2AuthorizationRequestResolver(reactiveClientRegistrationRepository)) .authenticationSuccessHandler(authenticationSuccessHandler) .authenticationFailureHandler(authenticationFailureHandler) .authorizedClientRepository(new ClientUserRepository(userService, commonConfig)) diff --git a/app/server/appsmith-server/src/main/java/com/appsmith/server/constants/Security.java b/app/server/appsmith-server/src/main/java/com/appsmith/server/constants/Security.java index a4beff71ac..54aa5909e9 100644 --- a/app/server/appsmith-server/src/main/java/com/appsmith/server/constants/Security.java +++ b/app/server/appsmith-server/src/main/java/com/appsmith/server/constants/Security.java @@ -2,4 +2,7 @@ package com.appsmith.server.constants; public interface Security { String USER_ROLE = "USER_ROLE"; + String QUERY_PARAMETER_STATE = "state"; + String REFERER_HEADER = "Referer"; + String STATE_PARAMETER_ORIGIN = "origin="; } diff --git a/app/server/appsmith-server/src/main/java/com/appsmith/server/filters/AuthenticationSuccessHandler.java b/app/server/appsmith-server/src/main/java/com/appsmith/server/filters/AuthenticationSuccessHandler.java index aac389131d..1c5e7d1c52 100644 --- a/app/server/appsmith-server/src/main/java/com/appsmith/server/filters/AuthenticationSuccessHandler.java +++ b/app/server/appsmith-server/src/main/java/com/appsmith/server/filters/AuthenticationSuccessHandler.java @@ -1,6 +1,7 @@ package com.appsmith.server.filters; import com.appsmith.server.constants.AclConstants; +import com.appsmith.server.constants.Security; import com.appsmith.server.domains.LoginSource; import com.appsmith.server.domains.User; import com.appsmith.server.domains.UserState; @@ -8,6 +9,7 @@ import com.appsmith.server.services.UserService; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpHeaders; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.web.server.DefaultServerRedirectStrategy; @@ -46,22 +48,53 @@ public class AuthenticationSuccessHandler implements ServerAuthenticationSuccess public Mono onAuthenticationSuccess(WebFilterExchange webFilterExchange, Authentication authentication) { log.debug("Login succeeded for user: {}", authentication.getPrincipal()); - if(authentication instanceof OAuth2AuthenticationToken) { + if (authentication instanceof OAuth2AuthenticationToken) { OAuth2AuthenticationToken oauthAuthentication = (OAuth2AuthenticationToken) authentication; return checkAndCreateUser(oauthAuthentication) - .then(handleRedirect(webFilterExchange)); + .then(handleOAuth2Redirect(webFilterExchange)); } return handleRedirect(webFilterExchange); } + /** + * This function redirects the back to the client's page after a successful sign in/sign up attempt by the user + * This is to transfer control back to the client because the OAuth2 dance would have been performed by the server. + *

+ * We extract the redirect url from the `state` key present in the request exchange object. This is state variable + * contains a random generated key along with the referer header set in the + * {@link CustomServerOAuth2AuthorizationRequestResolver#generateKey(HttpHeaders)} function. + * + * @param webFilterExchange + * @return + */ + private Mono handleOAuth2Redirect(WebFilterExchange webFilterExchange) { + ServerWebExchange exchange = webFilterExchange.getExchange(); + String state = exchange.getRequest().getQueryParams().getFirst(Security.QUERY_PARAMETER_STATE); + String originHeader = "/"; + if (state != null && !state.isEmpty()) { + String line; + String[] stateArray = state.split(","); + for (int i = 0; i < stateArray.length; i++) { + String stateVar = stateArray[i]; + if (stateVar != null && stateVar.startsWith(Security.STATE_PARAMETER_ORIGIN) && stateVar.contains("=")) { + // This is the origin of the request that we want to redirect to + originHeader = stateVar.split("=")[1]; + } + } + } + + URI defaultRedirectLocation = URI.create(originHeader); + return this.redirectStrategy.sendRedirect(exchange, defaultRedirectLocation); + } + private Mono handleRedirect(WebFilterExchange webFilterExchange) { ServerWebExchange exchange = webFilterExchange.getExchange(); // On authentication success, we send a redirect to the client's home page. This ensures that the session // is set in the cookie on the browser. String originHeader = exchange.getRequest().getHeaders().getOrigin(); - if(originHeader == null || originHeader.isEmpty()) { + if (originHeader == null || originHeader.isEmpty()) { originHeader = "/"; } diff --git a/app/server/appsmith-server/src/main/java/com/appsmith/server/filters/CustomServerOAuth2AuthorizationRequestResolver.java b/app/server/appsmith-server/src/main/java/com/appsmith/server/filters/CustomServerOAuth2AuthorizationRequestResolver.java new file mode 100644 index 0000000000..309513bdd7 --- /dev/null +++ b/app/server/appsmith-server/src/main/java/com/appsmith/server/filters/CustomServerOAuth2AuthorizationRequestResolver.java @@ -0,0 +1,285 @@ +package com.appsmith.server.filters; + +import com.appsmith.server.constants.Security; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; +import org.springframework.security.crypto.keygen.StringKeyGenerator; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; +import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher; +import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.server.ResponseStatusException; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.util.UriComponents; +import org.springframework.web.util.UriComponentsBuilder; +import reactor.core.publisher.Mono; + +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.Base64; +import java.util.HashMap; +import java.util.Map; + +/** + * This class is a copy of {@link org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizationRequestResolver} + * It has been copied so as to override the creation of the `state` query parameter sent to the OAuth2 authentication server + * The only 2 functions that have been overriden from the base class are: {@link #generateKey(HttpHeaders)} and + * {@link #authorizationRequest(ServerWebExchange, ClientRegistration)}. + * We couldn't simply extend the base class because of the use of private variables and methods to invoke these functions. + * + */ +@Slf4j +public class CustomServerOAuth2AuthorizationRequestResolver implements ServerOAuth2AuthorizationRequestResolver { + + /** + * The name of the path variable that contains the {@link ClientRegistration#getRegistrationId()} + */ + public static final String DEFAULT_REGISTRATION_ID_URI_VARIABLE_NAME = "registrationId"; + + /** + * The default pattern used to resolve the {@link ClientRegistration#getRegistrationId()} + */ + public static final String DEFAULT_AUTHORIZATION_REQUEST_PATTERN = "/oauth2/authorization/{" + DEFAULT_REGISTRATION_ID_URI_VARIABLE_NAME + "}"; + + private static final char PATH_DELIMITER = '/'; + + private final ServerWebExchangeMatcher authorizationRequestMatcher; + + private final ReactiveClientRegistrationRepository clientRegistrationRepository; + + private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder()); + + private final StringKeyGenerator secureKeyGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96); + + /** + * Creates a new instance + * @param clientRegistrationRepository the repository to resolve the {@link ClientRegistration} + */ + public CustomServerOAuth2AuthorizationRequestResolver(ReactiveClientRegistrationRepository clientRegistrationRepository) { + this(clientRegistrationRepository, new PathPatternParserServerWebExchangeMatcher( + DEFAULT_AUTHORIZATION_REQUEST_PATTERN)); + } + + /** + * Creates a new instance + * @param clientRegistrationRepository the repository to resolve the {@link ClientRegistration} + * @param authorizationRequestMatcher the matcher that determines if the request is a match and extracts the + * {@link #DEFAULT_REGISTRATION_ID_URI_VARIABLE_NAME} from the path variables. + */ + public CustomServerOAuth2AuthorizationRequestResolver(ReactiveClientRegistrationRepository clientRegistrationRepository, + ServerWebExchangeMatcher authorizationRequestMatcher) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + Assert.notNull(authorizationRequestMatcher, "authorizationRequestMatcher cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + this.authorizationRequestMatcher = authorizationRequestMatcher; + } + + @Override + public Mono resolve(ServerWebExchange exchange) { + return this.authorizationRequestMatcher.matches(exchange) + .filter(matchResult -> matchResult.isMatch()) + .map(ServerWebExchangeMatcher.MatchResult::getVariables) + .map(variables -> variables.get(DEFAULT_REGISTRATION_ID_URI_VARIABLE_NAME)) + .cast(String.class) + .flatMap(clientRegistrationId -> resolve(exchange, clientRegistrationId)); + } + + @Override + public Mono resolve(ServerWebExchange exchange, + String clientRegistrationId) { + return this.findByRegistrationId(exchange, clientRegistrationId) + .map(clientRegistration -> authorizationRequest(exchange, clientRegistration)); + } + + private Mono findByRegistrationId(ServerWebExchange exchange, String clientRegistration) { + return this.clientRegistrationRepository.findByRegistrationId(clientRegistration) + .switchIfEmpty(Mono.error(() -> new ResponseStatusException(HttpStatus.BAD_REQUEST, "Invalid client registration id"))); + } + + private OAuth2AuthorizationRequest authorizationRequest(ServerWebExchange exchange, + ClientRegistration clientRegistration) { + String redirectUriStr = expandRedirectUri(exchange.getRequest(), clientRegistration); + + Map attributes = new HashMap<>(); + attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()); + + OAuth2AuthorizationRequest.Builder builder; + if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) { + builder = OAuth2AuthorizationRequest.authorizationCode(); + Map additionalParameters = new HashMap<>(); + if (!CollectionUtils.isEmpty(clientRegistration.getScopes()) && + clientRegistration.getScopes().contains(OidcScopes.OPENID)) { + // Section 3.1.2.1 Authentication Request - https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest + // scope + // REQUIRED. OpenID Connect requests MUST contain the "openid" scope value. + addNonceParameters(attributes, additionalParameters); + } + if (ClientAuthenticationMethod.NONE.equals(clientRegistration.getClientAuthenticationMethod())) { + addPkceParameters(attributes, additionalParameters); + } + builder.additionalParameters(additionalParameters); + } else if (AuthorizationGrantType.IMPLICIT.equals(clientRegistration.getAuthorizationGrantType())) { + builder = OAuth2AuthorizationRequest.implicit(); + } else { + throw new IllegalArgumentException( + "Invalid Authorization Grant Type (" + clientRegistration.getAuthorizationGrantType().getValue() + + ") for Client Registration with Id: " + clientRegistration.getRegistrationId()); + } + return builder + .clientId(clientRegistration.getClientId()) + .authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri()) + .redirectUri(redirectUriStr).scopes(clientRegistration.getScopes()) + .state(this.generateKey(exchange.getRequest().getHeaders())) + .attributes(attributes) + .build(); + } + + /** + * This function sets the state query parameter sent to the OAuth2 resource server along with the parameter of the + * referer which initiated this OAuth2 login. On successful login, we will redirect back to the client's index page + * based on the referer so as to transfer control back to it. If the referer is not available, we default to + * redirecting to the server's index page. + * + * @param httpHeaders + * @return + */ + private String generateKey(HttpHeaders httpHeaders) { + String stateKey = this.stateGenerator.generateKey(); + String originHeader = httpHeaders.getOrigin(); + if(originHeader == null || originHeader.isBlank()) { + String refererHeader = httpHeaders.getFirst(Security.REFERER_HEADER); + if(refererHeader != null && !refererHeader.isBlank()) { + URI uri = null; + try { + uri = new URI(refererHeader); + String authority = uri.getAuthority(); + String scheme = uri.getScheme(); + originHeader = scheme + "://" + authority; + } catch (URISyntaxException e) { + originHeader = "/"; + } + } else { + originHeader = "/"; + } + } + stateKey = stateKey + "," + Security.STATE_PARAMETER_ORIGIN + originHeader; + return stateKey; + } + + /** + * Expands the {@link ClientRegistration#getRedirectUriTemplate()} with following provided variables:
+ * - baseUrl (e.g. https://localhost/app)
+ * - baseScheme (e.g. https)
+ * - baseHost (e.g. localhost)
+ * - basePort (e.g. :8080)
+ * - basePath (e.g. /app)
+ * - registrationId (e.g. google)
+ * - action (e.g. login)
+ *

+ * Null variables are provided as empty strings. + *

+ * Default redirectUriTemplate is: {@link org.springframework.security.config.oauth2.client}.CommonOAuth2Provider#DEFAULT_REDIRECT_URL + * + * @return expanded URI + */ + private static String expandRedirectUri(ServerHttpRequest request, ClientRegistration clientRegistration) { + Map uriVariables = new HashMap<>(); + uriVariables.put("registrationId", clientRegistration.getRegistrationId()); + + UriComponents uriComponents = UriComponentsBuilder.fromUri(request.getURI()) + .replacePath(request.getPath().contextPath().value()) + .replaceQuery(null) + .fragment(null) + .build(); + String scheme = uriComponents.getScheme(); + uriVariables.put("baseScheme", scheme == null ? "" : scheme); + String host = uriComponents.getHost(); + uriVariables.put("baseHost", host == null ? "" : host); + // following logic is based on HierarchicalUriComponents#toUriString() + int port = uriComponents.getPort(); + uriVariables.put("basePort", port == -1 ? "" : ":" + port); + String path = uriComponents.getPath(); + if (StringUtils.hasLength(path)) { + if (path.charAt(0) != PATH_DELIMITER) { + path = PATH_DELIMITER + path; + } + } + uriVariables.put("basePath", path == null ? "" : path); + uriVariables.put("baseUrl", uriComponents.toUriString()); + + String action = ""; + if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) { + action = "login"; + } + uriVariables.put("action", action); + + return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUriTemplate()) + .buildAndExpand(uriVariables) + .toUriString(); + } + + /** + * Creates nonce and its hash for use in OpenID Connect 1.0 Authentication Requests. + * + * @param attributes where the {@link OidcParameterNames#NONCE} is stored for the authentication request + * @param additionalParameters where the {@link OidcParameterNames#NONCE} hash is added for the authentication request + * + * @since 5.2 + * @see 3.1.2.1. Authentication Request + */ + private void addNonceParameters(Map attributes, Map additionalParameters) { + try { + String nonce = this.secureKeyGenerator.generateKey(); + String nonceHash = createHash(nonce); + attributes.put(OidcParameterNames.NONCE, nonce); + additionalParameters.put(OidcParameterNames.NONCE, nonceHash); + } catch (NoSuchAlgorithmException e) { } + } + + /** + * Creates and adds additional PKCE parameters for use in the OAuth 2.0 Authorization and Access Token Requests + * + * @param attributes where {@link PkceParameterNames#CODE_VERIFIER} is stored for the token request + * @param additionalParameters where {@link PkceParameterNames#CODE_CHALLENGE} and, usually, + * {@link PkceParameterNames#CODE_CHALLENGE_METHOD} are added to be used in the authorization request. + * + * @since 5.2 + * @see 1.1. Protocol Flow + * @see 4.1. Client Creates a Code Verifier + * @see 4.2. Client Creates the Code Challenge + */ + private void addPkceParameters(Map attributes, Map additionalParameters) { + String codeVerifier = this.secureKeyGenerator.generateKey(); + attributes.put(PkceParameterNames.CODE_VERIFIER, codeVerifier); + try { + String codeChallenge = createHash(codeVerifier); + additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, codeChallenge); + additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); + } catch (NoSuchAlgorithmException e) { + additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, codeVerifier); + } + } + + private static String createHash(String value) throws NoSuchAlgorithmException { + MessageDigest md = MessageDigest.getInstance("SHA-256"); + byte[] digest = md.digest(value.getBytes(StandardCharsets.US_ASCII)); + return Base64.getUrlEncoder().withoutPadding().encodeToString(digest); + } +}