From 354fc9b3d67d851fca84e80d0f2948d160395bd3 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Thu, 23 Mar 2023 18:15:07 +0100 Subject: [PATCH] OIDC: allow to extract custom fields from sub-structs Signed-off-by: Nicola Murino --- go.mod | 4 +- go.sum | 7 +-- internal/common/eventmanager_test.go | 4 +- internal/httpd/oidc.go | 79 ++++++++++++++++------------ internal/httpd/oidc_test.go | 47 ++++++++++++++--- internal/sftpd/server.go | 4 +- 6 files changed, 94 insertions(+), 51 deletions(-) diff --git a/go.mod b/go.mod index 49723622..609faf92 100644 --- a/go.mod +++ b/go.mod @@ -79,7 +79,7 @@ require ( require ( cloud.google.com/go v0.110.0 // indirect - cloud.google.com/go/compute v1.18.0 // indirect + cloud.google.com/go/compute v1.19.0 // indirect cloud.google.com/go/compute/metadata v0.2.3 // indirect cloud.google.com/go/iam v0.13.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.2.0 // indirect @@ -157,7 +157,7 @@ require ( golang.org/x/tools v0.7.0 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto v0.0.0-20230320184635-7606e756e683 // indirect + google.golang.org/genproto v0.0.0-20230322174352-cde4c949918d // indirect google.golang.org/grpc v1.54.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/go.sum b/go.sum index 8cc8e483..6422a2a9 100644 --- a/go.sum +++ b/go.sum @@ -123,8 +123,9 @@ cloud.google.com/go/compute v1.12.1/go.mod h1:e8yNOBcBONZU1vJKCvCoDw/4JQsA0dpM4x cloud.google.com/go/compute v1.13.0/go.mod h1:5aPTS0cUNMIc1CE546K+Th6weJUNQErARyZtRXDJ8GE= cloud.google.com/go/compute v1.14.0/go.mod h1:YfLtxrj9sU4Yxv+sXzZkyPjEyPBZfXHUvjxega5vAdo= cloud.google.com/go/compute v1.15.1/go.mod h1:bjjoF/NtFUrkD/urWfdHaKuOPDR5nWIs63rR+SXhcpA= -cloud.google.com/go/compute v1.18.0 h1:FEigFqoDbys2cvFkZ9Fjq4gnHBP55anJ0yQyau2f9oY= cloud.google.com/go/compute v1.18.0/go.mod h1:1X7yHxec2Ga+Ss6jPyjxRxpu2uu7PLgsOVXvgU0yacs= +cloud.google.com/go/compute v1.19.0 h1:+9zda3WGgW1ZSTlVppLCYFIr48Pa35q1uG2N1itbCEQ= +cloud.google.com/go/compute v1.19.0/go.mod h1:rikpw2y+UMidAe9tISo04EHNOIf42RLYF/q8Bs93scU= cloud.google.com/go/compute/metadata v0.1.0/go.mod h1:Z1VN+bulIf6bt4P/C37K4DyZYZEXYonfTBHHFPO/4UU= cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= cloud.google.com/go/compute/metadata v0.2.1/go.mod h1:jgHgmJd2RKBGzXqF5LR2EZMGxBkeanZ9wwa75XHJgOM= @@ -2802,8 +2803,8 @@ google.golang.org/genproto v0.0.0-20230113154510-dbe35b8444a5/go.mod h1:RGgjbofJ google.golang.org/genproto v0.0.0-20230124163310-31e0e69b6fc2/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM= google.golang.org/genproto v0.0.0-20230125152338-dcaf20b6aeaa/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM= google.golang.org/genproto v0.0.0-20230209215440-0dfe4f8abfcc/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM= -google.golang.org/genproto v0.0.0-20230320184635-7606e756e683 h1:khxVcsk/FhnzxMKOyD+TDGwjbEOpcPuIpmafPGFmhMA= -google.golang.org/genproto v0.0.0-20230320184635-7606e756e683/go.mod h1:NWraEVixdDnqcqQ30jipen1STv2r/n24Wb7twVTGR4s= +google.golang.org/genproto v0.0.0-20230322174352-cde4c949918d h1:OE8TncEeAei3Tehf/P/Jdt/K+8GnTUrRY6wzYpbCes4= +google.golang.org/genproto v0.0.0-20230322174352-cde4c949918d/go.mod h1:NWraEVixdDnqcqQ30jipen1STv2r/n24Wb7twVTGR4s= google.golang.org/grpc v0.0.0-20160317175043-d3ddb4469d5a/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= diff --git a/internal/common/eventmanager_test.go b/internal/common/eventmanager_test.go index 0bfffd31..70b5bc76 100644 --- a/internal/common/eventmanager_test.go +++ b/internal/common/eventmanager_test.go @@ -1945,9 +1945,7 @@ func TestEventParamsCopy(t *testing.T) { assert.Equal(t, "val1", v) } assert.Equal(t, params.IDPCustomFields, paramsCopy.IDPCustomFields) - paramsCopy.addIDPCustomFields(&map[string]any{ - "field2": "val2", - }) + (*paramsCopy.IDPCustomFields)["field1"] = "val2" assert.NotEqual(t, params.IDPCustomFields, paramsCopy.IDPCustomFields) } diff --git a/internal/httpd/oidc.go b/internal/httpd/oidc.go index bfa33c2a..1110e419 100644 --- a/internal/httpd/oidc.go +++ b/internal/httpd/oidc.go @@ -223,9 +223,14 @@ func (t *oidcToken) parseClaims(claims map[string]any, usernameField, roleField return keys } - username, ok := claims[usernameField].(string) + var username string + val, ok := getOIDCFieldFromClaims(claims, usernameField) + if ok { + username, ok = val.(string) + } if !ok || username == "" { - logger.Warn(logSender, "", "username field %q not found, claims fields: %+v", usernameField, getClaimsFields()) + logger.Warn(logSender, "", "username field %q not found, empty or not a string, claims fields: %+v", + usernameField, getClaimsFields()) return errors.New("no username field") } t.Username = username @@ -237,7 +242,7 @@ func (t *oidcToken) parseClaims(claims map[string]any, usernameField, roleField t.CustomFields = nil if len(customFields) > 0 { for _, field := range customFields { - if val, ok := claims[field]; ok { + if val, ok := getOIDCFieldFromClaims(claims, field); ok { if t.CustomFields == nil { customFields := make(map[string]any) t.CustomFields = &customFields @@ -257,36 +262,8 @@ func (t *oidcToken) parseClaims(claims map[string]any, usernameField, roleField } func (t *oidcToken) getRoleFromField(claims map[string]any, roleField string) { - if roleField != "" { - role, ok := claims[roleField] - if ok { - t.Role = role - return - } - if !strings.Contains(roleField, ".") { - return - } - - getStructValue := func(outer any, field string) (any, bool) { - switch val := outer.(type) { - case map[string]any: - res, ok := val[field] - return res, ok - } - return nil, false - } - - for idx, field := range strings.Split(roleField, ".") { - if idx == 0 { - role, ok = getStructValue(claims, field) - } else { - role, ok = getStructValue(role, field) - } - if !ok { - return - } - } - + role, ok := getOIDCFieldFromClaims(claims, roleField) + if ok { t.Role = role } } @@ -719,6 +696,7 @@ func loginOIDCUser(w http.ResponseWriter, r *http.Request, token oidcToken) { // we don't set a cookie expiration so we can refresh the token without setting a new cookie // the cookie will be invalidated on browser close http.SetCookie(w, &cookie) + w.Header().Add("Cache-Control", `no-cache="Set-Cookie"`) if token.isAdmin() { http.Redirect(w, r, webUsersPath, http.StatusFound) return @@ -793,3 +771,38 @@ func isLoggedInWithOIDC(r *http.Request) bool { _, ok := r.Context().Value(oidcTokenKey).(string) return ok } + +func getOIDCFieldFromClaims(claims map[string]any, fieldName string) (any, bool) { + if fieldName == "" { + return nil, false + } + val, ok := claims[fieldName] + if ok { + return val, true + } + if !strings.Contains(fieldName, ".") { + return nil, false + } + + getStructValue := func(outer any, field string) (any, bool) { + switch v := outer.(type) { + case map[string]any: + res, ok := v[field] + return res, ok + } + return nil, false + } + + for idx, field := range strings.Split(fieldName, ".") { + if idx == 0 { + val, ok = getStructValue(claims, field) + } else { + val, ok = getStructValue(val, field) + } + if !ok { + return nil, false + } + } + + return val, ok +} diff --git a/internal/httpd/oidc_test.go b/internal/httpd/oidc_test.go index 9ab9e94c..4d699f9d 100644 --- a/internal/httpd/oidc_test.go +++ b/internal/httpd/oidc_test.go @@ -281,7 +281,7 @@ func TestOIDCLoginLogout(t *testing.T) { assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) require.Len(t, oidcMgr.pendingAuths, 0) - // invalid id token claims (no username) + // invalid id token claims: no username authReq = newOIDCPendingAuth(tokenAudienceWebClient) oidcMgr.addPendingAuth(authReq) idToken := &oidc.IDToken{ @@ -300,6 +300,25 @@ func TestOIDCLoginLogout(t *testing.T) { assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) require.Len(t, oidcMgr.pendingAuths, 0) + // invalid id token clamims: username not a string + authReq = newOIDCPendingAuth(tokenAudienceWebClient) + oidcMgr.addPendingAuth(authReq) + idToken = &oidc.IDToken{ + Nonce: authReq.Nonce, + Expiry: time.Now().Add(5 * time.Minute), + } + setIDTokenClaims(idToken, []byte(`{"aud": "my_client_id","preferred_username": 1}`)) + server.binding.OIDC.verifier = &mockOIDCVerifier{ + err: nil, + token: idToken, + } + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil) + assert.NoError(t, err) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) + require.Len(t, oidcMgr.pendingAuths, 0) // invalid audience authReq = newOIDCPendingAuth(tokenAudienceWebClient) oidcMgr.addPendingAuth(authReq) @@ -835,11 +854,16 @@ func TestOIDCToken(t *testing.T) { token := oidcToken{ Username: admin.Username, - Role: "admin", } + // role not initialized, user with the specified username does not exist req, err := http.NewRequest(http.MethodGet, webUsersPath, nil) assert.NoError(t, err) err = token.getUser(req) + assert.ErrorIs(t, err, util.ErrNotFound) + token.Role = "admin" + req, err = http.NewRequest(http.MethodGet, webUsersPath, nil) + assert.NoError(t, err) + err = token.getUser(req) if assert.Error(t, err) { assert.Contains(t, err.Error(), "is disabled") } @@ -1143,10 +1167,11 @@ func TestOIDCEvMgrIntegration(t *testing.T) { u := map[string]any{ "username": "{{Name}}", "status": 1, - "home_dir": filepath.Join(os.TempDir(), "{{IDPFieldcustom1}}"), + "home_dir": filepath.Join(os.TempDir(), "{{IDPFieldcustom1.sub}}"), "permissions": map[string][]string{ "/": {dataprovider.PermAny}, }, + "description": "{{IDPFieldcustom2}}", } userTmpl, err := json.Marshal(u) require.NoError(t, err) @@ -1196,7 +1221,7 @@ func TestOIDCEvMgrIntegration(t *testing.T) { require.True(t, ok) server := getTestOIDCServer() server.binding.OIDC.ImplicitRoles = true - server.binding.OIDC.CustomFields = []string{"custom1", "custom2"} + server.binding.OIDC.CustomFields = []string{"custom1.sub", "custom2"} err = server.binding.OIDC.initialize() assert.NoError(t, err) server.initializeRouter() @@ -1221,7 +1246,7 @@ func TestOIDCEvMgrIntegration(t *testing.T) { Nonce: authReq.Nonce, Expiry: time.Now().Add(5 * time.Minute), } - setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+util.JSONEscape(username)+`","custom1":"val1"}`)) + setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+util.JSONEscape(username)+`","custom1":{"sub":"val1"},"custom2":"desc"}`)) server.binding.OIDC.verifier = &mockOIDCVerifier{ err: nil, token: idToken, @@ -1235,6 +1260,7 @@ func TestOIDCEvMgrIntegration(t *testing.T) { user, err := dataprovider.UserExists(username, "") assert.NoError(t, err) assert.Equal(t, filepath.Join(os.TempDir(), "val1"), user.GetHomeDir()) + assert.Equal(t, "desc", user.Description) err = dataprovider.DeleteUser(username, "", "", "") assert.NoError(t, err) @@ -1473,13 +1499,15 @@ func TestParseAdminRole(t *testing.T) { type test struct { input string want bool + val any } tests := []test{ + {input: "", want: false}, {input: "sftpgo_role", want: false}, - {input: "params.sftpgo_role", want: true}, - {input: "params.subparams.sftpgo_role", want: true}, - {input: "params.subparams.inner.sftpgo_role", want: true}, + {input: "params.sftpgo_role", want: true, val: "admin"}, + {input: "params.subparams.sftpgo_role", want: true, val: "admin"}, + {input: "params.subparams.inner.sftpgo_role", want: true, val: []any{"user", "admin"}}, {input: "email", want: false}, {input: "missing", want: false}, {input: "params.email", want: false}, @@ -1492,6 +1520,9 @@ func TestParseAdminRole(t *testing.T) { token := oidcToken{} token.getRoleFromField(claims, tc.input) assert.Equal(t, tc.want, token.isAdmin(), "%q should return %t", tc.input, tc.want) + if tc.want { + assert.Equal(t, tc.val, token.Role) + } } } diff --git a/internal/sftpd/server.go b/internal/sftpd/server.go index 8bdb6106..72f357dd 100644 --- a/internal/sftpd/server.go +++ b/internal/sftpd/server.go @@ -610,7 +610,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve sconn, chans, reqs, err := ssh.NewServerConn(conn, config) if err != nil { - logger.Debug(logSender, "", "failed to accept an incoming connection: %v", err) + logger.Debug(logSender, "", "failed to accept an incoming connection from ip %q: %v", ipAddr, err) checkAuthError(ipAddr, err) return } @@ -629,7 +629,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve defer user.CloseFs() //nolint:errcheck if err = user.CheckFsRoot(connectionID); err != nil { - logger.Warn(logSender, connectionID, "unable to check fs root: %v", err) + logger.Warn(logSender, connectionID, "unable to check fs root for user %q: %v", user.Username, err) return }