Kaynağa Gözat

OIDC: allow to extract custom fields from sub-structs

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
Nicola Murino 2 yıl önce
ebeveyn
işleme
354fc9b3d6

+ 2 - 2
go.mod

@@ -79,7 +79,7 @@ require (
 
 require (
 	cloud.google.com/go v0.110.0 // indirect
-	cloud.google.com/go/compute v1.18.0 // indirect
+	cloud.google.com/go/compute v1.19.0 // indirect
 	cloud.google.com/go/compute/metadata v0.2.3 // indirect
 	cloud.google.com/go/iam v0.13.0 // indirect
 	github.com/Azure/azure-sdk-for-go/sdk/internal v1.2.0 // indirect
@@ -157,7 +157,7 @@ require (
 	golang.org/x/tools v0.7.0 // indirect
 	golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
 	google.golang.org/appengine v1.6.7 // indirect
-	google.golang.org/genproto v0.0.0-20230320184635-7606e756e683 // indirect
+	google.golang.org/genproto v0.0.0-20230322174352-cde4c949918d // indirect
 	google.golang.org/grpc v1.54.0 // indirect
 	google.golang.org/protobuf v1.30.0 // indirect
 	gopkg.in/ini.v1 v1.67.0 // indirect

+ 4 - 3
go.sum

@@ -123,8 +123,9 @@ cloud.google.com/go/compute v1.12.1/go.mod h1:e8yNOBcBONZU1vJKCvCoDw/4JQsA0dpM4x
 cloud.google.com/go/compute v1.13.0/go.mod h1:5aPTS0cUNMIc1CE546K+Th6weJUNQErARyZtRXDJ8GE=
 cloud.google.com/go/compute v1.14.0/go.mod h1:YfLtxrj9sU4Yxv+sXzZkyPjEyPBZfXHUvjxega5vAdo=
 cloud.google.com/go/compute v1.15.1/go.mod h1:bjjoF/NtFUrkD/urWfdHaKuOPDR5nWIs63rR+SXhcpA=
-cloud.google.com/go/compute v1.18.0 h1:FEigFqoDbys2cvFkZ9Fjq4gnHBP55anJ0yQyau2f9oY=
 cloud.google.com/go/compute v1.18.0/go.mod h1:1X7yHxec2Ga+Ss6jPyjxRxpu2uu7PLgsOVXvgU0yacs=
+cloud.google.com/go/compute v1.19.0 h1:+9zda3WGgW1ZSTlVppLCYFIr48Pa35q1uG2N1itbCEQ=
+cloud.google.com/go/compute v1.19.0/go.mod h1:rikpw2y+UMidAe9tISo04EHNOIf42RLYF/q8Bs93scU=
 cloud.google.com/go/compute/metadata v0.1.0/go.mod h1:Z1VN+bulIf6bt4P/C37K4DyZYZEXYonfTBHHFPO/4UU=
 cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
 cloud.google.com/go/compute/metadata v0.2.1/go.mod h1:jgHgmJd2RKBGzXqF5LR2EZMGxBkeanZ9wwa75XHJgOM=
@@ -2802,8 +2803,8 @@ google.golang.org/genproto v0.0.0-20230113154510-dbe35b8444a5/go.mod h1:RGgjbofJ
 google.golang.org/genproto v0.0.0-20230124163310-31e0e69b6fc2/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM=
 google.golang.org/genproto v0.0.0-20230125152338-dcaf20b6aeaa/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM=
 google.golang.org/genproto v0.0.0-20230209215440-0dfe4f8abfcc/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM=
-google.golang.org/genproto v0.0.0-20230320184635-7606e756e683 h1:khxVcsk/FhnzxMKOyD+TDGwjbEOpcPuIpmafPGFmhMA=
-google.golang.org/genproto v0.0.0-20230320184635-7606e756e683/go.mod h1:NWraEVixdDnqcqQ30jipen1STv2r/n24Wb7twVTGR4s=
+google.golang.org/genproto v0.0.0-20230322174352-cde4c949918d h1:OE8TncEeAei3Tehf/P/Jdt/K+8GnTUrRY6wzYpbCes4=
+google.golang.org/genproto v0.0.0-20230322174352-cde4c949918d/go.mod h1:NWraEVixdDnqcqQ30jipen1STv2r/n24Wb7twVTGR4s=
 google.golang.org/grpc v0.0.0-20160317175043-d3ddb4469d5a/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
 google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
 google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=

+ 1 - 3
internal/common/eventmanager_test.go

@@ -1945,9 +1945,7 @@ func TestEventParamsCopy(t *testing.T) {
 		assert.Equal(t, "val1", v)
 	}
 	assert.Equal(t, params.IDPCustomFields, paramsCopy.IDPCustomFields)
-	paramsCopy.addIDPCustomFields(&map[string]any{
-		"field2": "val2",
-	})
+	(*paramsCopy.IDPCustomFields)["field1"] = "val2"
 	assert.NotEqual(t, params.IDPCustomFields, paramsCopy.IDPCustomFields)
 }
 

+ 46 - 33
internal/httpd/oidc.go

@@ -223,9 +223,14 @@ func (t *oidcToken) parseClaims(claims map[string]any, usernameField, roleField
 		return keys
 	}
 
-	username, ok := claims[usernameField].(string)
+	var username string
+	val, ok := getOIDCFieldFromClaims(claims, usernameField)
+	if ok {
+		username, ok = val.(string)
+	}
 	if !ok || username == "" {
-		logger.Warn(logSender, "", "username field %q not found, claims fields: %+v", usernameField, getClaimsFields())
+		logger.Warn(logSender, "", "username field %q not found, empty or not a string, claims fields: %+v",
+			usernameField, getClaimsFields())
 		return errors.New("no username field")
 	}
 	t.Username = username
@@ -237,7 +242,7 @@ func (t *oidcToken) parseClaims(claims map[string]any, usernameField, roleField
 	t.CustomFields = nil
 	if len(customFields) > 0 {
 		for _, field := range customFields {
-			if val, ok := claims[field]; ok {
+			if val, ok := getOIDCFieldFromClaims(claims, field); ok {
 				if t.CustomFields == nil {
 					customFields := make(map[string]any)
 					t.CustomFields = &customFields
@@ -257,36 +262,8 @@ func (t *oidcToken) parseClaims(claims map[string]any, usernameField, roleField
 }
 
 func (t *oidcToken) getRoleFromField(claims map[string]any, roleField string) {
-	if roleField != "" {
-		role, ok := claims[roleField]
-		if ok {
-			t.Role = role
-			return
-		}
-		if !strings.Contains(roleField, ".") {
-			return
-		}
-
-		getStructValue := func(outer any, field string) (any, bool) {
-			switch val := outer.(type) {
-			case map[string]any:
-				res, ok := val[field]
-				return res, ok
-			}
-			return nil, false
-		}
-
-		for idx, field := range strings.Split(roleField, ".") {
-			if idx == 0 {
-				role, ok = getStructValue(claims, field)
-			} else {
-				role, ok = getStructValue(role, field)
-			}
-			if !ok {
-				return
-			}
-		}
-
+	role, ok := getOIDCFieldFromClaims(claims, roleField)
+	if ok {
 		t.Role = role
 	}
 }
@@ -719,6 +696,7 @@ func loginOIDCUser(w http.ResponseWriter, r *http.Request, token oidcToken) {
 	// we don't set a cookie expiration so we can refresh the token without setting a new cookie
 	// the cookie will be invalidated on browser close
 	http.SetCookie(w, &cookie)
+	w.Header().Add("Cache-Control", `no-cache="Set-Cookie"`)
 	if token.isAdmin() {
 		http.Redirect(w, r, webUsersPath, http.StatusFound)
 		return
@@ -793,3 +771,38 @@ func isLoggedInWithOIDC(r *http.Request) bool {
 	_, ok := r.Context().Value(oidcTokenKey).(string)
 	return ok
 }
+
+func getOIDCFieldFromClaims(claims map[string]any, fieldName string) (any, bool) {
+	if fieldName == "" {
+		return nil, false
+	}
+	val, ok := claims[fieldName]
+	if ok {
+		return val, true
+	}
+	if !strings.Contains(fieldName, ".") {
+		return nil, false
+	}
+
+	getStructValue := func(outer any, field string) (any, bool) {
+		switch v := outer.(type) {
+		case map[string]any:
+			res, ok := v[field]
+			return res, ok
+		}
+		return nil, false
+	}
+
+	for idx, field := range strings.Split(fieldName, ".") {
+		if idx == 0 {
+			val, ok = getStructValue(claims, field)
+		} else {
+			val, ok = getStructValue(val, field)
+		}
+		if !ok {
+			return nil, false
+		}
+	}
+
+	return val, ok
+}

+ 39 - 8
internal/httpd/oidc_test.go

@@ -281,7 +281,7 @@ func TestOIDCLoginLogout(t *testing.T) {
 	assert.Equal(t, http.StatusFound, rr.Code)
 	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
 	require.Len(t, oidcMgr.pendingAuths, 0)
-	// invalid id token claims (no username)
+	// invalid id token claims: no username
 	authReq = newOIDCPendingAuth(tokenAudienceWebClient)
 	oidcMgr.addPendingAuth(authReq)
 	idToken := &oidc.IDToken{
@@ -300,6 +300,25 @@ func TestOIDCLoginLogout(t *testing.T) {
 	assert.Equal(t, http.StatusFound, rr.Code)
 	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
 	require.Len(t, oidcMgr.pendingAuths, 0)
+	// invalid id token clamims: username not a string
+	authReq = newOIDCPendingAuth(tokenAudienceWebClient)
+	oidcMgr.addPendingAuth(authReq)
+	idToken = &oidc.IDToken{
+		Nonce:  authReq.Nonce,
+		Expiry: time.Now().Add(5 * time.Minute),
+	}
+	setIDTokenClaims(idToken, []byte(`{"aud": "my_client_id","preferred_username": 1}`))
+	server.binding.OIDC.verifier = &mockOIDCVerifier{
+		err:   nil,
+		token: idToken,
+	}
+	rr = httptest.NewRecorder()
+	r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
+	assert.NoError(t, err)
+	server.router.ServeHTTP(rr, r)
+	assert.Equal(t, http.StatusFound, rr.Code)
+	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
+	require.Len(t, oidcMgr.pendingAuths, 0)
 	// invalid audience
 	authReq = newOIDCPendingAuth(tokenAudienceWebClient)
 	oidcMgr.addPendingAuth(authReq)
@@ -835,11 +854,16 @@ func TestOIDCToken(t *testing.T) {
 
 	token := oidcToken{
 		Username: admin.Username,
-		Role:     "admin",
 	}
+	// role not initialized, user with the specified username does not exist
 	req, err := http.NewRequest(http.MethodGet, webUsersPath, nil)
 	assert.NoError(t, err)
 	err = token.getUser(req)
+	assert.ErrorIs(t, err, util.ErrNotFound)
+	token.Role = "admin"
+	req, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
+	assert.NoError(t, err)
+	err = token.getUser(req)
 	if assert.Error(t, err) {
 		assert.Contains(t, err.Error(), "is disabled")
 	}
@@ -1143,10 +1167,11 @@ func TestOIDCEvMgrIntegration(t *testing.T) {
 	u := map[string]any{
 		"username": "{{Name}}",
 		"status":   1,
-		"home_dir": filepath.Join(os.TempDir(), "{{IDPFieldcustom1}}"),
+		"home_dir": filepath.Join(os.TempDir(), "{{IDPFieldcustom1.sub}}"),
 		"permissions": map[string][]string{
 			"/": {dataprovider.PermAny},
 		},
+		"description": "{{IDPFieldcustom2}}",
 	}
 	userTmpl, err := json.Marshal(u)
 	require.NoError(t, err)
@@ -1196,7 +1221,7 @@ func TestOIDCEvMgrIntegration(t *testing.T) {
 	require.True(t, ok)
 	server := getTestOIDCServer()
 	server.binding.OIDC.ImplicitRoles = true
-	server.binding.OIDC.CustomFields = []string{"custom1", "custom2"}
+	server.binding.OIDC.CustomFields = []string{"custom1.sub", "custom2"}
 	err = server.binding.OIDC.initialize()
 	assert.NoError(t, err)
 	server.initializeRouter()
@@ -1221,7 +1246,7 @@ func TestOIDCEvMgrIntegration(t *testing.T) {
 		Nonce:  authReq.Nonce,
 		Expiry: time.Now().Add(5 * time.Minute),
 	}
-	setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+util.JSONEscape(username)+`","custom1":"val1"}`))
+	setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+util.JSONEscape(username)+`","custom1":{"sub":"val1"},"custom2":"desc"}`))
 	server.binding.OIDC.verifier = &mockOIDCVerifier{
 		err:   nil,
 		token: idToken,
@@ -1235,6 +1260,7 @@ func TestOIDCEvMgrIntegration(t *testing.T) {
 	user, err := dataprovider.UserExists(username, "")
 	assert.NoError(t, err)
 	assert.Equal(t, filepath.Join(os.TempDir(), "val1"), user.GetHomeDir())
+	assert.Equal(t, "desc", user.Description)
 
 	err = dataprovider.DeleteUser(username, "", "", "")
 	assert.NoError(t, err)
@@ -1473,13 +1499,15 @@ func TestParseAdminRole(t *testing.T) {
 	type test struct {
 		input string
 		want  bool
+		val   any
 	}
 
 	tests := []test{
+		{input: "", want: false},
 		{input: "sftpgo_role", want: false},
-		{input: "params.sftpgo_role", want: true},
-		{input: "params.subparams.sftpgo_role", want: true},
-		{input: "params.subparams.inner.sftpgo_role", want: true},
+		{input: "params.sftpgo_role", want: true, val: "admin"},
+		{input: "params.subparams.sftpgo_role", want: true, val: "admin"},
+		{input: "params.subparams.inner.sftpgo_role", want: true, val: []any{"user", "admin"}},
 		{input: "email", want: false},
 		{input: "missing", want: false},
 		{input: "params.email", want: false},
@@ -1492,6 +1520,9 @@ func TestParseAdminRole(t *testing.T) {
 		token := oidcToken{}
 		token.getRoleFromField(claims, tc.input)
 		assert.Equal(t, tc.want, token.isAdmin(), "%q should return %t", tc.input, tc.want)
+		if tc.want {
+			assert.Equal(t, tc.val, token.Role)
+		}
 	}
 }
 

+ 2 - 2
internal/sftpd/server.go

@@ -610,7 +610,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
 
 	sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
 	if err != nil {
-		logger.Debug(logSender, "", "failed to accept an incoming connection: %v", err)
+		logger.Debug(logSender, "", "failed to accept an incoming connection from ip %q: %v", ipAddr, err)
 		checkAuthError(ipAddr, err)
 		return
 	}
@@ -629,7 +629,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
 
 	defer user.CloseFs() //nolint:errcheck
 	if err = user.CheckFsRoot(connectionID); err != nil {
-		logger.Warn(logSender, connectionID, "unable to check fs root: %v", err)
+		logger.Warn(logSender, connectionID, "unable to check fs root for user %q: %v", user.Username, err)
 		return
 	}