diff --git a/src/main/java/org/codelibs/fess/sso/aad/AzureAdAuthenticator.java b/src/main/java/org/codelibs/fess/sso/aad/AzureAdAuthenticator.java index 7f8c812a7..51c0178b2 100644 --- a/src/main/java/org/codelibs/fess/sso/aad/AzureAdAuthenticator.java +++ b/src/main/java/org/codelibs/fess/sso/aad/AzureAdAuthenticator.java @@ -26,6 +26,7 @@ import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -38,7 +39,9 @@ import javax.servlet.http.HttpSession; import org.apache.commons.lang3.StringUtils; import org.codelibs.core.lang.StringUtil; +import org.codelibs.core.misc.Pair; import org.codelibs.core.net.UuidUtil; +import org.codelibs.core.stream.StreamUtil; import org.codelibs.curl.Curl; import org.codelibs.curl.CurlResponse; import org.codelibs.elasticsearch.runner.net.EcrCurl; @@ -50,6 +53,7 @@ import org.codelibs.fess.crawler.Constants; import org.codelibs.fess.exception.SsoLoginException; import org.codelibs.fess.sso.SsoAuthenticator; import org.codelibs.fess.util.ComponentUtil; +import org.codelibs.fess.util.DocumentUtil; import org.dbflute.optional.OptionalEntity; import org.lastaflute.web.login.credential.LoginCredential; import org.lastaflute.web.response.HtmlResponse; @@ -57,6 +61,8 @@ import org.lastaflute.web.util.LaRequestUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; import com.microsoft.aad.adal4j.AuthenticationContext; import com.microsoft.aad.adal4j.AuthenticationResult; import com.microsoft.aad.adal4j.ClientCredential; @@ -100,9 +106,14 @@ public class AzureAdAuthenticator implements SsoAuthenticator { protected long acquisitionTimeout = 30 * 1000L; + protected Cache> groupCache; + + protected long groupCacheExpiry = 10 * 60L; + @PostConstruct public void init() { ComponentUtil.getSsoManager().register(this); + groupCache = CacheBuilder.newBuilder().expireAfterWrite(groupCacheExpiry, TimeUnit.SECONDS).build(); } @Override @@ -333,11 +344,22 @@ public class AzureAdAuthenticator implements SsoAuthenticator { final List roleList = new ArrayList<>(); groupList.addAll(getDefaultGroupList()); roleList.addAll(getDefaultRoleList()); + processMemberOf(user, groupList, roleList, "https://graph.microsoft.com/v1.0/me/memberOf"); + user.setGroups(groupList.stream().distinct().toArray(n -> new String[n])); + user.setRoles(roleList.stream().distinct().toArray(n -> new String[n])); + } + + protected void processMemberOf(final AzureAdUser user, final List groupList, final List roleList, final String url) { + if (logger.isDebugEnabled()) { + logger.debug("url: {}", url); + } try (CurlResponse response = - Curl.get("https://graph.microsoft.com/v1.0/me/memberOf") - .header("Authorization", "Bearer " + user.getAuthenticationResult().getAccessToken()) + Curl.get(url).header("Authorization", "Bearer " + user.getAuthenticationResult().getAccessToken()) .header("Accept", "application/json").execute()) { final Map contentMap = response.getContent(EcrCurl.jsonParser); + if (logger.isDebugEnabled()) { + logger.debug("response: {}", contentMap); + } if (contentMap.containsKey("value")) { @SuppressWarnings("unchecked") final List> memberOfList = (List>) contentMap.get("value"); @@ -363,6 +385,7 @@ public class AzureAdAuthenticator implements SsoAuthenticator { } groupList.add(id); } + processParentGroup(user, groupList, roleList, id); } else { logger.warn("id is empty: {}", memberOf); } @@ -382,15 +405,94 @@ public class AzureAdAuthenticator implements SsoAuthenticator { logger.debug("mail is empty: {}", memberOf); } } + final String nextLink = (String) contentMap.get("@odata.nextLink"); + if (StringUtil.isNotBlank(nextLink)) { + processMemberOf(user, groupList, roleList, nextLink); + } } else if (contentMap.containsKey("error")) { logger.warn("Failed to access groups/roles: {}", contentMap); } } catch (final IOException e) { logger.warn("Failed to access groups/roles in AzureAD.", e); } + } - user.setGroups(groupList.toArray(n -> new String[n])); - user.setRoles(roleList.toArray(n -> new String[n])); + protected void processParentGroup(final AzureAdUser user, final List groupList, final List roleList, final String id) { + final Pair groupsAndRoles = getParentGroup(user, id); + StreamUtil.stream(groupsAndRoles.getFirst()).of(stream -> stream.forEach(groupList::add)); + StreamUtil.stream(groupsAndRoles.getSecond()).of(stream -> stream.forEach(roleList::add)); + } + + protected Pair getParentGroup(final AzureAdUser user, final String id) { + try { + return groupCache.get( + id, + () -> { + final List groupList = new ArrayList<>(); + final List roleList = new ArrayList<>(); + final String url = "https://graph.microsoft.com/v1.0/groups/" + id + "/getMemberGroups"; + if (logger.isDebugEnabled()) { + logger.debug("url: {}", url); + } + try (CurlResponse response = + Curl.post(url).header("Authorization", "Bearer " + user.getAuthenticationResult().getAccessToken()) + .header("Accept", "application/json").header("Content-type", "application/json") + .body("{\"securityEnabledOnly\":false}").execute()) { + final Map contentMap = response.getContent(EcrCurl.jsonParser); + if (logger.isDebugEnabled()) { + logger.debug("response: {}", contentMap); + } + if (contentMap.containsKey("value")) { + final String[] values = DocumentUtil.getValue(contentMap, "value", String[].class); + if (values != null) { + for (final String value : values) { + processGroup(user, groupList, roleList, value); + if (!groupList.contains(value) && !roleList.contains(value)) { + final Pair groupsAndRoles = getParentGroup(user, value); + StreamUtil.stream(groupsAndRoles.getFirst()).of(stream1 -> stream1.forEach(groupList::add)); + StreamUtil.stream(groupsAndRoles.getSecond()).of(stream2 -> stream2.forEach(roleList::add)); + } + } + } + } else if (contentMap.containsKey("error")) { + logger.warn("Failed to access parent groups: {}", contentMap); + } + } catch (final IOException e) { + logger.warn("Failed to access groups/roles in AzureAD.", e); + } + return new Pair<>(groupList.stream().distinct().toArray(n1 -> new String[n1]), roleList.stream().distinct() + .toArray(n2 -> new String[n2])); + }); + } catch (final ExecutionException e) { + logger.warn("Failed to process a group cache.", e); + return new Pair<>(StringUtil.EMPTY_STRINGS, StringUtil.EMPTY_STRINGS); + } + } + + protected void processGroup(final AzureAdUser user, final List groupList, final List roleList, final String id) { + final String url = "https://graph.microsoft.com/v1.0/groups/" + id; + if (logger.isDebugEnabled()) { + logger.debug("url: {}", url); + } + try (CurlResponse response = + Curl.get(url).header("Authorization", "Bearer " + user.getAuthenticationResult().getAccessToken()) + .header("Accept", "application/json").execute()) { + final Map contentMap = response.getContent(EcrCurl.jsonParser); + if (logger.isDebugEnabled()) { + logger.debug("response: {}", contentMap); + } + groupList.add(id); + if (contentMap.containsKey("error")) { + logger.warn("Failed to access parent groups: {}", contentMap); + } else { + final String mail = (String) contentMap.get("mail"); + if (StringUtil.isNotBlank(mail)) { + groupList.add(mail); + } + } + } catch (final IOException e) { + logger.warn("Failed to access groups/roles in AzureAD.", e); + } } protected List getDefaultGroupList() { @@ -472,4 +574,8 @@ public class AzureAdAuthenticator implements SsoAuthenticator { public void setAcquisitionTimeout(final long acquisitionTimeout) { this.acquisitionTimeout = acquisitionTimeout; } + + public void setGroupCacheExpiry(long groupCacheExpiry) { + this.groupCacheExpiry = groupCacheExpiry; + } }