Move the role field to provider custom params, use subject value as a role name

This commit is contained in:
Roman Zabaluev 2023-05-02 16:23:10 +08:00
parent 29bd456a0b
commit e0b910cb0a
2 changed files with 40 additions and 19 deletions

View file

@ -72,13 +72,13 @@ public class OAuthSecurityConfig extends AbstractAuthSecurityConfig {
final OidcReactiveOAuth2UserService delegate = new OidcReactiveOAuth2UserService(); final OidcReactiveOAuth2UserService delegate = new OidcReactiveOAuth2UserService();
return request -> delegate.loadUser(request) return request -> delegate.loadUser(request)
.flatMap(user -> { .flatMap(user -> {
String providerId = request.getClientRegistration().getRegistrationId(); var provider = getProviderByProviderId(request.getClientRegistration().getRegistrationId());
final var extractor = getExtractor(providerId, acs); final var extractor = getExtractor(provider, acs);
if (extractor == null) { if (extractor == null) {
return Mono.just(user); return Mono.just(user);
} }
return extractor.extract(acs, user, Map.of("request", request)) return extractor.extract(acs, user, Map.of("request", request, "provider", provider))
.map(groups -> new RbacOidcUser(user, groups)); .map(groups -> new RbacOidcUser(user, groups));
}); });
} }
@ -88,13 +88,13 @@ public class OAuthSecurityConfig extends AbstractAuthSecurityConfig {
final DefaultReactiveOAuth2UserService delegate = new DefaultReactiveOAuth2UserService(); final DefaultReactiveOAuth2UserService delegate = new DefaultReactiveOAuth2UserService();
return request -> delegate.loadUser(request) return request -> delegate.loadUser(request)
.flatMap(user -> { .flatMap(user -> {
String providerId = request.getClientRegistration().getRegistrationId(); var provider = getProviderByProviderId(request.getClientRegistration().getRegistrationId());
final var extractor = getExtractor(providerId, acs); final var extractor = getExtractor(provider, acs);
if (extractor == null) { if (extractor == null) {
return Mono.just(user); return Mono.just(user);
} }
return extractor.extract(acs, user, Map.of("request", request)) return extractor.extract(acs, user, Map.of("request", request, "provider", provider))
.map(groups -> new RbacOAuth2User(user, groups)); .map(groups -> new RbacOAuth2User(user, groups));
}); });
} }
@ -113,8 +113,8 @@ public class OAuthSecurityConfig extends AbstractAuthSecurityConfig {
} }
@Nullable @Nullable
private ProviderAuthorityExtractor getExtractor(final String providerId, AccessControlService acs) { private ProviderAuthorityExtractor getExtractor(final OAuthProperties.OAuth2Provider provider,
final var provider = getProviderByProviderId(providerId); AccessControlService acs) {
Optional<ProviderAuthorityExtractor> extractor = acs.getOauthExtractors() Optional<ProviderAuthorityExtractor> extractor = acs.getOauthExtractors()
.stream() .stream()
.filter(e -> e.isApplicable(provider.getProvider(), provider.getCustomParams())) .filter(e -> e.isApplicable(provider.getProvider(), provider.getCustomParams()))

View file

@ -4,6 +4,7 @@ import static com.provectus.kafka.ui.model.rbac.provider.Provider.Name.OAUTH;
import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.provectus.kafka.ui.config.auth.OAuthProperties;
import com.provectus.kafka.ui.model.rbac.Role; import com.provectus.kafka.ui.model.rbac.Role;
import com.provectus.kafka.ui.model.rbac.provider.Provider; import com.provectus.kafka.ui.model.rbac.provider.Provider;
import com.provectus.kafka.ui.service.rbac.AccessControlService; import com.provectus.kafka.ui.service.rbac.AccessControlService;
@ -21,8 +22,16 @@ import reactor.core.publisher.Mono;
@Slf4j @Slf4j
public class OauthAuthorityExtractor implements ProviderAuthorityExtractor { public class OauthAuthorityExtractor implements ProviderAuthorityExtractor {
public static final String ROLES_FIELD_PARAM_NAME = "roles-field";
@Override @Override
public boolean isApplicable(String provider, Map<String, String> customParams) { public boolean isApplicable(String provider, Map<String, String> customParams) {
var containsRolesFieldNameParam = customParams.containsKey(ROLES_FIELD_PARAM_NAME);
if (!containsRolesFieldNameParam) {
log.debug("Provider [{}] doesn't contain a roles field param name, mapping won't be performed", provider);
return false;
}
return OAUTH.equalsIgnoreCase(provider) || OAUTH.equalsIgnoreCase(customParams.get(TYPE)); return OAUTH.equalsIgnoreCase(provider) || OAUTH.equalsIgnoreCase(customParams.get(TYPE));
} }
@ -38,7 +47,10 @@ public class OauthAuthorityExtractor implements ProviderAuthorityExtractor {
throw new RuntimeException(); throw new RuntimeException();
} }
Set<String> groupsByUsername = acs.getRoles() var provider = (OAuthProperties.OAuth2Provider) additionalParams.get("provider");
var rolesFieldName = provider.getCustomParams().get(ROLES_FIELD_PARAM_NAME);
Set<String> rolesByUsername = acs.getRoles()
.stream() .stream()
.filter(r -> r.getSubjects() .filter(r -> r.getSubjects()
.stream() .stream()
@ -48,37 +60,46 @@ public class OauthAuthorityExtractor implements ProviderAuthorityExtractor {
.map(Role::getName) .map(Role::getName)
.collect(Collectors.toSet()); .collect(Collectors.toSet());
Set<String> groupsByGroupField = acs.getRoles() Set<String> rolesByRolesField = acs.getRoles()
.stream() .stream()
.filter(role -> role.getSubjects() .filter(role -> role.getSubjects()
.stream() .stream()
.filter(s -> s.getProvider().equals(Provider.OAUTH)) .filter(s -> s.getProvider().equals(Provider.OAUTH))
.filter(s -> s.getType().equals("groupsfield")) .filter(s -> s.getType().equals("role"))
.anyMatch(subject -> convertGroups(principal.getAttribute(subject.getValue())).contains(role.getName())) .anyMatch(subject
-> {
var principalRoles = convertRoles(principal.getAttribute(rolesFieldName));
var roleName = subject.getValue();
return principalRoles.contains(roleName);
})
) )
//subject.getValue()
.map(Role::getName) .map(Role::getName)
.collect(Collectors.toSet()); .collect(Collectors.toSet());
return Mono.just(Stream.concat(groupsByUsername.stream(), groupsByGroupField.stream()).collect(Collectors.toSet())); return Mono.just(Stream.concat(rolesByUsername.stream(), rolesByRolesField.stream()).collect(Collectors.toSet()));
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private Collection<String> convertGroups(Object groups) { private Collection<String> convertRoles(Object roles) {
try { if (roles == null) {
if ((groups instanceof List<?>) || (groups instanceof Set<?>)) { log.debug("Param missing from attributes, skipping");
log.trace("The field is either a set or a list, returning as is"); return Collections.emptySet();
return (Collection<String>) groups;
} }
if (!(groups instanceof String)) { try {
if ((roles instanceof List<?>) || (roles instanceof Set<?>)) {
log.trace("The field is either a set or a list, returning as is");
return (Collection<String>) roles;
}
if (!(roles instanceof String)) {
log.debug("The field is not a string, skipping"); log.debug("The field is not a string, skipping");
return Collections.emptySet(); return Collections.emptySet();
} }
log.trace("Trying to deserialize the field"); log.trace("Trying to deserialize the field");
//@formatter:off //@formatter:off
return new ObjectMapper().readValue((String) groups, new TypeReference<>() {}); return new ObjectMapper().readValue((String) roles, new TypeReference<>() {});
//@formatter:on //@formatter:on
} catch (Exception e) { } catch (Exception e) {
log.error("Error deserializing field", e); log.error("Error deserializing field", e);