Преглед на файлове

fix #2122 retrieve parent groups

Shinsuke Sugaya преди 6 години
родител
ревизия
ec0eec9af2
променени са 1 файла, в които са добавени 110 реда и са изтрити 4 реда
  1. 110 4
      src/main/java/org/codelibs/fess/sso/aad/AzureAdAuthenticator.java

+ 110 - 4
src/main/java/org/codelibs/fess/sso/aad/AzureAdAuthenticator.java

@@ -26,6 +26,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.List;
 import java.util.Locale;
 import java.util.Locale;
 import java.util.Map;
 import java.util.Map;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.concurrent.Future;
@@ -38,7 +39,9 @@ import javax.servlet.http.HttpSession;
 
 
 import org.apache.commons.lang3.StringUtils;
 import org.apache.commons.lang3.StringUtils;
 import org.codelibs.core.lang.StringUtil;
 import org.codelibs.core.lang.StringUtil;
+import org.codelibs.core.misc.Pair;
 import org.codelibs.core.net.UuidUtil;
 import org.codelibs.core.net.UuidUtil;
+import org.codelibs.core.stream.StreamUtil;
 import org.codelibs.curl.Curl;
 import org.codelibs.curl.Curl;
 import org.codelibs.curl.CurlResponse;
 import org.codelibs.curl.CurlResponse;
 import org.codelibs.elasticsearch.runner.net.EcrCurl;
 import org.codelibs.elasticsearch.runner.net.EcrCurl;
@@ -50,6 +53,7 @@ import org.codelibs.fess.crawler.Constants;
 import org.codelibs.fess.exception.SsoLoginException;
 import org.codelibs.fess.exception.SsoLoginException;
 import org.codelibs.fess.sso.SsoAuthenticator;
 import org.codelibs.fess.sso.SsoAuthenticator;
 import org.codelibs.fess.util.ComponentUtil;
 import org.codelibs.fess.util.ComponentUtil;
+import org.codelibs.fess.util.DocumentUtil;
 import org.dbflute.optional.OptionalEntity;
 import org.dbflute.optional.OptionalEntity;
 import org.lastaflute.web.login.credential.LoginCredential;
 import org.lastaflute.web.login.credential.LoginCredential;
 import org.lastaflute.web.response.HtmlResponse;
 import org.lastaflute.web.response.HtmlResponse;
@@ -57,6 +61,8 @@ import org.lastaflute.web.util.LaRequestUtil;
 import org.slf4j.Logger;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.slf4j.LoggerFactory;
 
 
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
 import com.microsoft.aad.adal4j.AuthenticationContext;
 import com.microsoft.aad.adal4j.AuthenticationContext;
 import com.microsoft.aad.adal4j.AuthenticationResult;
 import com.microsoft.aad.adal4j.AuthenticationResult;
 import com.microsoft.aad.adal4j.ClientCredential;
 import com.microsoft.aad.adal4j.ClientCredential;
@@ -100,9 +106,14 @@ public class AzureAdAuthenticator implements SsoAuthenticator {
 
 
     protected long acquisitionTimeout = 30 * 1000L;
     protected long acquisitionTimeout = 30 * 1000L;
 
 
+    protected Cache<String, Pair<String[], String[]>> groupCache;
+
+    protected long groupCacheExpiry = 10 * 60L;
+
     @PostConstruct
     @PostConstruct
     public void init() {
     public void init() {
         ComponentUtil.getSsoManager().register(this);
         ComponentUtil.getSsoManager().register(this);
+        groupCache = CacheBuilder.newBuilder().expireAfterWrite(groupCacheExpiry, TimeUnit.SECONDS).build();
     }
     }
 
 
     @Override
     @Override
@@ -333,11 +344,22 @@ public class AzureAdAuthenticator implements SsoAuthenticator {
         final List<String> roleList = new ArrayList<>();
         final List<String> roleList = new ArrayList<>();
         groupList.addAll(getDefaultGroupList());
         groupList.addAll(getDefaultGroupList());
         roleList.addAll(getDefaultRoleList());
         roleList.addAll(getDefaultRoleList());
+        processMemberOf(user, groupList, roleList, "https://graph.microsoft.com/v1.0/me/memberOf");
+        user.setGroups(groupList.stream().distinct().toArray(n -> new String[n]));
+        user.setRoles(roleList.stream().distinct().toArray(n -> new String[n]));
+    }
+
+    protected void processMemberOf(final AzureAdUser user, final List<String> groupList, final List<String> roleList, final String url) {
+        if (logger.isDebugEnabled()) {
+            logger.debug("url: {}", url);
+        }
         try (CurlResponse response =
         try (CurlResponse response =
-                Curl.get("https://graph.microsoft.com/v1.0/me/memberOf")
-                        .header("Authorization", "Bearer " + user.getAuthenticationResult().getAccessToken())
+                Curl.get(url).header("Authorization", "Bearer " + user.getAuthenticationResult().getAccessToken())
                         .header("Accept", "application/json").execute()) {
                         .header("Accept", "application/json").execute()) {
             final Map<String, Object> contentMap = response.getContent(EcrCurl.jsonParser);
             final Map<String, Object> contentMap = response.getContent(EcrCurl.jsonParser);
+            if (logger.isDebugEnabled()) {
+                logger.debug("response: {}", contentMap);
+            }
             if (contentMap.containsKey("value")) {
             if (contentMap.containsKey("value")) {
                 @SuppressWarnings("unchecked")
                 @SuppressWarnings("unchecked")
                 final List<Map<String, Object>> memberOfList = (List<Map<String, Object>>) contentMap.get("value");
                 final List<Map<String, Object>> memberOfList = (List<Map<String, Object>>) contentMap.get("value");
@@ -363,6 +385,7 @@ public class AzureAdAuthenticator implements SsoAuthenticator {
                             }
                             }
                             groupList.add(id);
                             groupList.add(id);
                         }
                         }
+                        processParentGroup(user, groupList, roleList, id);
                     } else {
                     } else {
                         logger.warn("id is empty: {}", memberOf);
                         logger.warn("id is empty: {}", memberOf);
                     }
                     }
@@ -382,15 +405,94 @@ public class AzureAdAuthenticator implements SsoAuthenticator {
                         logger.debug("mail is empty: {}", memberOf);
                         logger.debug("mail is empty: {}", memberOf);
                     }
                     }
                 }
                 }
+                final String nextLink = (String) contentMap.get("@odata.nextLink");
+                if (StringUtil.isNotBlank(nextLink)) {
+                    processMemberOf(user, groupList, roleList, nextLink);
+                }
             } else if (contentMap.containsKey("error")) {
             } else if (contentMap.containsKey("error")) {
                 logger.warn("Failed to access groups/roles: {}", contentMap);
                 logger.warn("Failed to access groups/roles: {}", contentMap);
             }
             }
         } catch (final IOException e) {
         } catch (final IOException e) {
             logger.warn("Failed to access groups/roles in AzureAD.", e);
             logger.warn("Failed to access groups/roles in AzureAD.", e);
         }
         }
+    }
 
 
-        user.setGroups(groupList.toArray(n -> new String[n]));
-        user.setRoles(roleList.toArray(n -> new String[n]));
+    protected void processParentGroup(final AzureAdUser user, final List<String> groupList, final List<String> roleList, final String id) {
+        final Pair<String[], String[]> groupsAndRoles = getParentGroup(user, id);
+        StreamUtil.stream(groupsAndRoles.getFirst()).of(stream -> stream.forEach(groupList::add));
+        StreamUtil.stream(groupsAndRoles.getSecond()).of(stream -> stream.forEach(roleList::add));
+    }
+
+    protected Pair<String[], String[]> getParentGroup(final AzureAdUser user, final String id) {
+        try {
+            return groupCache.get(
+                    id,
+                    () -> {
+                        final List<String> groupList = new ArrayList<>();
+                        final List<String> roleList = new ArrayList<>();
+                        final String url = "https://graph.microsoft.com/v1.0/groups/" + id + "/getMemberGroups";
+                        if (logger.isDebugEnabled()) {
+                            logger.debug("url: {}", url);
+                        }
+                        try (CurlResponse response =
+                                Curl.post(url).header("Authorization", "Bearer " + user.getAuthenticationResult().getAccessToken())
+                                        .header("Accept", "application/json").header("Content-type", "application/json")
+                                        .body("{\"securityEnabledOnly\":false}").execute()) {
+                            final Map<String, Object> contentMap = response.getContent(EcrCurl.jsonParser);
+                            if (logger.isDebugEnabled()) {
+                                logger.debug("response: {}", contentMap);
+                            }
+                            if (contentMap.containsKey("value")) {
+                                final String[] values = DocumentUtil.getValue(contentMap, "value", String[].class);
+                                if (values != null) {
+                                    for (final String value : values) {
+                                        processGroup(user, groupList, roleList, value);
+                                        if (!groupList.contains(value) && !roleList.contains(value)) {
+                                            final Pair<String[], String[]> groupsAndRoles = getParentGroup(user, value);
+                                            StreamUtil.stream(groupsAndRoles.getFirst()).of(stream1 -> stream1.forEach(groupList::add));
+                                            StreamUtil.stream(groupsAndRoles.getSecond()).of(stream2 -> stream2.forEach(roleList::add));
+                                        }
+                                    }
+                                }
+                            } else if (contentMap.containsKey("error")) {
+                                logger.warn("Failed to access parent groups: {}", contentMap);
+                            }
+                        } catch (final IOException e) {
+                            logger.warn("Failed to access groups/roles in AzureAD.", e);
+                        }
+                        return new Pair<>(groupList.stream().distinct().toArray(n1 -> new String[n1]), roleList.stream().distinct()
+                                .toArray(n2 -> new String[n2]));
+                    });
+        } catch (final ExecutionException e) {
+            logger.warn("Failed to process a group cache.", e);
+            return new Pair<>(StringUtil.EMPTY_STRINGS, StringUtil.EMPTY_STRINGS);
+        }
+    }
+
+    protected void processGroup(final AzureAdUser user, final List<String> groupList, final List<String> roleList, final String id) {
+        final String url = "https://graph.microsoft.com/v1.0/groups/" + id;
+        if (logger.isDebugEnabled()) {
+            logger.debug("url: {}", url);
+        }
+        try (CurlResponse response =
+                Curl.get(url).header("Authorization", "Bearer " + user.getAuthenticationResult().getAccessToken())
+                        .header("Accept", "application/json").execute()) {
+            final Map<String, Object> contentMap = response.getContent(EcrCurl.jsonParser);
+            if (logger.isDebugEnabled()) {
+                logger.debug("response: {}", contentMap);
+            }
+            groupList.add(id);
+            if (contentMap.containsKey("error")) {
+                logger.warn("Failed to access parent groups: {}", contentMap);
+            } else {
+                final String mail = (String) contentMap.get("mail");
+                if (StringUtil.isNotBlank(mail)) {
+                    groupList.add(mail);
+                }
+            }
+        } catch (final IOException e) {
+            logger.warn("Failed to access groups/roles in AzureAD.", e);
+        }
     }
     }
 
 
     protected List<String> getDefaultGroupList() {
     protected List<String> getDefaultGroupList() {
@@ -472,4 +574,8 @@ public class AzureAdAuthenticator implements SsoAuthenticator {
     public void setAcquisitionTimeout(final long acquisitionTimeout) {
     public void setAcquisitionTimeout(final long acquisitionTimeout) {
         this.acquisitionTimeout = acquisitionTimeout;
         this.acquisitionTimeout = acquisitionTimeout;
     }
     }
+
+    public void setGroupCacheExpiry(long groupCacheExpiry) {
+        this.groupCacheExpiry = groupCacheExpiry;
+    }
 }
 }