diff --git a/kafka-ui-api/src/main/java/com/provectus/kafka/ui/config/auth/OAuthSecurityConfig.java b/kafka-ui-api/src/main/java/com/provectus/kafka/ui/config/auth/OAuthSecurityConfig.java index a0117bc2cc..d170a7338c 100644 --- a/kafka-ui-api/src/main/java/com/provectus/kafka/ui/config/auth/OAuthSecurityConfig.java +++ b/kafka-ui-api/src/main/java/com/provectus/kafka/ui/config/auth/OAuthSecurityConfig.java @@ -72,13 +72,13 @@ public class OAuthSecurityConfig extends AbstractAuthSecurityConfig { final OidcReactiveOAuth2UserService delegate = new OidcReactiveOAuth2UserService(); return request -> delegate.loadUser(request) .flatMap(user -> { - String providerId = request.getClientRegistration().getRegistrationId(); - final var extractor = getExtractor(providerId, acs); + var provider = getProviderByProviderId(request.getClientRegistration().getRegistrationId()); + final var extractor = getExtractor(provider, acs); if (extractor == null) { return Mono.just(user); } - return extractor.extract(acs, user, Map.of("request", request)) + return extractor.extract(acs, user, Map.of("request", request, "provider", provider)) .map(groups -> new RbacOidcUser(user, groups)); }); } @@ -88,13 +88,13 @@ public class OAuthSecurityConfig extends AbstractAuthSecurityConfig { final DefaultReactiveOAuth2UserService delegate = new DefaultReactiveOAuth2UserService(); return request -> delegate.loadUser(request) .flatMap(user -> { - String providerId = request.getClientRegistration().getRegistrationId(); - final var extractor = getExtractor(providerId, acs); + var provider = getProviderByProviderId(request.getClientRegistration().getRegistrationId()); + final var extractor = getExtractor(provider, acs); if (extractor == null) { return Mono.just(user); } - return extractor.extract(acs, user, Map.of("request", request)) + return extractor.extract(acs, user, Map.of("request", request, "provider", provider)) .map(groups -> new RbacOAuth2User(user, groups)); }); } @@ -113,8 +113,8 @@ public class OAuthSecurityConfig extends AbstractAuthSecurityConfig { } @Nullable - private ProviderAuthorityExtractor getExtractor(final String providerId, AccessControlService acs) { - final var provider = getProviderByProviderId(providerId); + private ProviderAuthorityExtractor getExtractor(final OAuthProperties.OAuth2Provider provider, + AccessControlService acs) { Optional extractor = acs.getOauthExtractors() .stream() .filter(e -> e.isApplicable(provider.getProvider(), provider.getCustomParams())) diff --git a/kafka-ui-api/src/main/java/com/provectus/kafka/ui/service/rbac/extractor/OauthAuthorityExtractor.java b/kafka-ui-api/src/main/java/com/provectus/kafka/ui/service/rbac/extractor/OauthAuthorityExtractor.java index f412c59d88..2d4c01a0a2 100644 --- a/kafka-ui-api/src/main/java/com/provectus/kafka/ui/service/rbac/extractor/OauthAuthorityExtractor.java +++ b/kafka-ui-api/src/main/java/com/provectus/kafka/ui/service/rbac/extractor/OauthAuthorityExtractor.java @@ -4,6 +4,7 @@ import static com.provectus.kafka.ui.model.rbac.provider.Provider.Name.OAUTH; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import com.provectus.kafka.ui.config.auth.OAuthProperties; import com.provectus.kafka.ui.model.rbac.Role; import com.provectus.kafka.ui.model.rbac.provider.Provider; import com.provectus.kafka.ui.service.rbac.AccessControlService; @@ -21,8 +22,16 @@ import reactor.core.publisher.Mono; @Slf4j public class OauthAuthorityExtractor implements ProviderAuthorityExtractor { + public static final String ROLES_FIELD_PARAM_NAME = "roles-field"; + @Override public boolean isApplicable(String provider, Map customParams) { + var containsRolesFieldNameParam = customParams.containsKey(ROLES_FIELD_PARAM_NAME); + if (!containsRolesFieldNameParam) { + log.debug("Provider [{}] doesn't contain a roles field param name, mapping won't be performed", provider); + return false; + } + return OAUTH.equalsIgnoreCase(provider) || OAUTH.equalsIgnoreCase(customParams.get(TYPE)); } @@ -38,7 +47,10 @@ public class OauthAuthorityExtractor implements ProviderAuthorityExtractor { throw new RuntimeException(); } - Set groupsByUsername = acs.getRoles() + var provider = (OAuthProperties.OAuth2Provider) additionalParams.get("provider"); + var rolesFieldName = provider.getCustomParams().get(ROLES_FIELD_PARAM_NAME); + + Set rolesByUsername = acs.getRoles() .stream() .filter(r -> r.getSubjects() .stream() @@ -48,37 +60,46 @@ public class OauthAuthorityExtractor implements ProviderAuthorityExtractor { .map(Role::getName) .collect(Collectors.toSet()); - Set groupsByGroupField = acs.getRoles() + Set rolesByRolesField = acs.getRoles() .stream() .filter(role -> role.getSubjects() .stream() .filter(s -> s.getProvider().equals(Provider.OAUTH)) - .filter(s -> s.getType().equals("groupsfield")) - .anyMatch(subject -> convertGroups(principal.getAttribute(subject.getValue())).contains(role.getName())) + .filter(s -> s.getType().equals("role")) + .anyMatch(subject + -> { + var principalRoles = convertRoles(principal.getAttribute(rolesFieldName)); + var roleName = subject.getValue(); + return principalRoles.contains(roleName); + }) ) - //subject.getValue() .map(Role::getName) .collect(Collectors.toSet()); - return Mono.just(Stream.concat(groupsByUsername.stream(), groupsByGroupField.stream()).collect(Collectors.toSet())); + return Mono.just(Stream.concat(rolesByUsername.stream(), rolesByRolesField.stream()).collect(Collectors.toSet())); } @SuppressWarnings("unchecked") - private Collection convertGroups(Object groups) { + private Collection convertRoles(Object roles) { + if (roles == null) { + log.debug("Param missing from attributes, skipping"); + return Collections.emptySet(); + } + try { - if ((groups instanceof List) || (groups instanceof Set)) { + if ((roles instanceof List) || (roles instanceof Set)) { log.trace("The field is either a set or a list, returning as is"); - return (Collection) groups; + return (Collection) roles; } - if (!(groups instanceof String)) { + if (!(roles instanceof String)) { log.debug("The field is not a string, skipping"); return Collections.emptySet(); } log.trace("Trying to deserialize the field"); //@formatter:off - return new ObjectMapper().readValue((String) groups, new TypeReference<>() {}); + return new ObjectMapper().readValue((String) roles, new TypeReference<>() {}); //@formatter:on } catch (Exception e) { log.error("Error deserializing field", e);