فهرست منبع

Move the role field to provider custom params, use subject value as a role name

Roman Zabaluev 2 سال پیش
والد
کامیت
e0b910cb0a

+ 8 - 8
kafka-ui-api/src/main/java/com/provectus/kafka/ui/config/auth/OAuthSecurityConfig.java

@@ -72,13 +72,13 @@ public class OAuthSecurityConfig extends AbstractAuthSecurityConfig {
     final OidcReactiveOAuth2UserService delegate = new OidcReactiveOAuth2UserService();
     return request -> delegate.loadUser(request)
         .flatMap(user -> {
-          String providerId = request.getClientRegistration().getRegistrationId();
-          final var extractor = getExtractor(providerId, acs);
+          var provider = getProviderByProviderId(request.getClientRegistration().getRegistrationId());
+          final var extractor = getExtractor(provider, acs);
           if (extractor == null) {
             return Mono.just(user);
           }
 
-          return extractor.extract(acs, user, Map.of("request", request))
+          return extractor.extract(acs, user, Map.of("request", request, "provider", provider))
               .map(groups -> new RbacOidcUser(user, groups));
         });
   }
@@ -88,13 +88,13 @@ public class OAuthSecurityConfig extends AbstractAuthSecurityConfig {
     final DefaultReactiveOAuth2UserService delegate = new DefaultReactiveOAuth2UserService();
     return request -> delegate.loadUser(request)
         .flatMap(user -> {
-          String providerId = request.getClientRegistration().getRegistrationId();
-          final var extractor = getExtractor(providerId, acs);
+          var provider = getProviderByProviderId(request.getClientRegistration().getRegistrationId());
+          final var extractor = getExtractor(provider, acs);
           if (extractor == null) {
             return Mono.just(user);
           }
 
-          return extractor.extract(acs, user, Map.of("request", request))
+          return extractor.extract(acs, user, Map.of("request", request, "provider", provider))
               .map(groups -> new RbacOAuth2User(user, groups));
         });
   }
@@ -113,8 +113,8 @@ public class OAuthSecurityConfig extends AbstractAuthSecurityConfig {
   }
 
   @Nullable
-  private ProviderAuthorityExtractor getExtractor(final String providerId, AccessControlService acs) {
-    final var provider = getProviderByProviderId(providerId);
+  private ProviderAuthorityExtractor getExtractor(final OAuthProperties.OAuth2Provider provider,
+                                                  AccessControlService acs) {
     Optional<ProviderAuthorityExtractor> extractor = acs.getOauthExtractors()
         .stream()
         .filter(e -> e.isApplicable(provider.getProvider(), provider.getCustomParams()))

+ 32 - 11
kafka-ui-api/src/main/java/com/provectus/kafka/ui/service/rbac/extractor/OauthAuthorityExtractor.java

@@ -4,6 +4,7 @@ import static com.provectus.kafka.ui.model.rbac.provider.Provider.Name.OAUTH;
 
 import com.fasterxml.jackson.core.type.TypeReference;
 import com.fasterxml.jackson.databind.ObjectMapper;
+import com.provectus.kafka.ui.config.auth.OAuthProperties;
 import com.provectus.kafka.ui.model.rbac.Role;
 import com.provectus.kafka.ui.model.rbac.provider.Provider;
 import com.provectus.kafka.ui.service.rbac.AccessControlService;
@@ -21,8 +22,16 @@ import reactor.core.publisher.Mono;
 @Slf4j
 public class OauthAuthorityExtractor implements ProviderAuthorityExtractor {
 
+  public static final String ROLES_FIELD_PARAM_NAME = "roles-field";
+
   @Override
   public boolean isApplicable(String provider, Map<String, String> customParams) {
+    var containsRolesFieldNameParam = customParams.containsKey(ROLES_FIELD_PARAM_NAME);
+    if (!containsRolesFieldNameParam) {
+      log.debug("Provider [{}] doesn't contain a roles field param name, mapping won't be performed", provider);
+      return false;
+    }
+
     return OAUTH.equalsIgnoreCase(provider) || OAUTH.equalsIgnoreCase(customParams.get(TYPE));
   }
 
@@ -38,7 +47,10 @@ public class OauthAuthorityExtractor implements ProviderAuthorityExtractor {
       throw new RuntimeException();
     }
 
-    Set<String> groupsByUsername = acs.getRoles()
+    var provider = (OAuthProperties.OAuth2Provider) additionalParams.get("provider");
+    var rolesFieldName = provider.getCustomParams().get(ROLES_FIELD_PARAM_NAME);
+
+    Set<String> rolesByUsername = acs.getRoles()
         .stream()
         .filter(r -> r.getSubjects()
             .stream()
@@ -48,37 +60,46 @@ public class OauthAuthorityExtractor implements ProviderAuthorityExtractor {
         .map(Role::getName)
         .collect(Collectors.toSet());
 
-    Set<String> groupsByGroupField = acs.getRoles()
+    Set<String> rolesByRolesField = acs.getRoles()
         .stream()
         .filter(role -> role.getSubjects()
             .stream()
             .filter(s -> s.getProvider().equals(Provider.OAUTH))
-            .filter(s -> s.getType().equals("groupsfield"))
-            .anyMatch(subject -> convertGroups(principal.getAttribute(subject.getValue())).contains(role.getName()))
+            .filter(s -> s.getType().equals("role"))
+            .anyMatch(subject
+                -> {
+              var principalRoles = convertRoles(principal.getAttribute(rolesFieldName));
+              var roleName = subject.getValue();
+              return principalRoles.contains(roleName);
+            })
         )
-        //subject.getValue()
         .map(Role::getName)
         .collect(Collectors.toSet());
 
-    return Mono.just(Stream.concat(groupsByUsername.stream(), groupsByGroupField.stream()).collect(Collectors.toSet()));
+    return Mono.just(Stream.concat(rolesByUsername.stream(), rolesByRolesField.stream()).collect(Collectors.toSet()));
   }
 
   @SuppressWarnings("unchecked")
-  private Collection<String> convertGroups(Object groups) {
+  private Collection<String> convertRoles(Object roles) {
+    if (roles == null) {
+      log.debug("Param missing from attributes, skipping");
+      return Collections.emptySet();
+    }
+
     try {
-      if ((groups instanceof List<?>) || (groups instanceof Set<?>)) {
+      if ((roles instanceof List<?>) || (roles instanceof Set<?>)) {
         log.trace("The field is either a set or a list, returning as is");
-        return (Collection<String>) groups;
+        return (Collection<String>) roles;
       }
 
-      if (!(groups instanceof String)) {
+      if (!(roles instanceof String)) {
         log.debug("The field is not a string, skipping");
         return Collections.emptySet();
       }
 
       log.trace("Trying to deserialize the field");
       //@formatter:off
-      return new ObjectMapper().readValue((String) groups, new TypeReference<>() {});
+      return new ObjectMapper().readValue((String) roles, new TypeReference<>() {});
       //@formatter:on
     } catch (Exception e) {
       log.error("Error deserializing field", e);