OIDC: allow to extract custom fields from sub-structs

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino 2023-03-23 18:15:07 +01:00
parent e29f6857db
commit 354fc9b3d6
No known key found for this signature in database
GPG key ID: 935D2952DEC4EECF
6 changed files with 94 additions and 51 deletions

4
go.mod
View file

@ -79,7 +79,7 @@ require (
require (
cloud.google.com/go v0.110.0 // indirect
cloud.google.com/go/compute v1.18.0 // indirect
cloud.google.com/go/compute v1.19.0 // indirect
cloud.google.com/go/compute/metadata v0.2.3 // indirect
cloud.google.com/go/iam v0.13.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.2.0 // indirect
@ -157,7 +157,7 @@ require (
golang.org/x/tools v0.7.0 // indirect
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20230320184635-7606e756e683 // indirect
google.golang.org/genproto v0.0.0-20230322174352-cde4c949918d // indirect
google.golang.org/grpc v1.54.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect

7
go.sum
View file

@ -123,8 +123,9 @@ cloud.google.com/go/compute v1.12.1/go.mod h1:e8yNOBcBONZU1vJKCvCoDw/4JQsA0dpM4x
cloud.google.com/go/compute v1.13.0/go.mod h1:5aPTS0cUNMIc1CE546K+Th6weJUNQErARyZtRXDJ8GE=
cloud.google.com/go/compute v1.14.0/go.mod h1:YfLtxrj9sU4Yxv+sXzZkyPjEyPBZfXHUvjxega5vAdo=
cloud.google.com/go/compute v1.15.1/go.mod h1:bjjoF/NtFUrkD/urWfdHaKuOPDR5nWIs63rR+SXhcpA=
cloud.google.com/go/compute v1.18.0 h1:FEigFqoDbys2cvFkZ9Fjq4gnHBP55anJ0yQyau2f9oY=
cloud.google.com/go/compute v1.18.0/go.mod h1:1X7yHxec2Ga+Ss6jPyjxRxpu2uu7PLgsOVXvgU0yacs=
cloud.google.com/go/compute v1.19.0 h1:+9zda3WGgW1ZSTlVppLCYFIr48Pa35q1uG2N1itbCEQ=
cloud.google.com/go/compute v1.19.0/go.mod h1:rikpw2y+UMidAe9tISo04EHNOIf42RLYF/q8Bs93scU=
cloud.google.com/go/compute/metadata v0.1.0/go.mod h1:Z1VN+bulIf6bt4P/C37K4DyZYZEXYonfTBHHFPO/4UU=
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
cloud.google.com/go/compute/metadata v0.2.1/go.mod h1:jgHgmJd2RKBGzXqF5LR2EZMGxBkeanZ9wwa75XHJgOM=
@ -2802,8 +2803,8 @@ google.golang.org/genproto v0.0.0-20230113154510-dbe35b8444a5/go.mod h1:RGgjbofJ
google.golang.org/genproto v0.0.0-20230124163310-31e0e69b6fc2/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM=
google.golang.org/genproto v0.0.0-20230125152338-dcaf20b6aeaa/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM=
google.golang.org/genproto v0.0.0-20230209215440-0dfe4f8abfcc/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM=
google.golang.org/genproto v0.0.0-20230320184635-7606e756e683 h1:khxVcsk/FhnzxMKOyD+TDGwjbEOpcPuIpmafPGFmhMA=
google.golang.org/genproto v0.0.0-20230320184635-7606e756e683/go.mod h1:NWraEVixdDnqcqQ30jipen1STv2r/n24Wb7twVTGR4s=
google.golang.org/genproto v0.0.0-20230322174352-cde4c949918d h1:OE8TncEeAei3Tehf/P/Jdt/K+8GnTUrRY6wzYpbCes4=
google.golang.org/genproto v0.0.0-20230322174352-cde4c949918d/go.mod h1:NWraEVixdDnqcqQ30jipen1STv2r/n24Wb7twVTGR4s=
google.golang.org/grpc v0.0.0-20160317175043-d3ddb4469d5a/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=

View file

@ -1945,9 +1945,7 @@ func TestEventParamsCopy(t *testing.T) {
assert.Equal(t, "val1", v)
}
assert.Equal(t, params.IDPCustomFields, paramsCopy.IDPCustomFields)
paramsCopy.addIDPCustomFields(&map[string]any{
"field2": "val2",
})
(*paramsCopy.IDPCustomFields)["field1"] = "val2"
assert.NotEqual(t, params.IDPCustomFields, paramsCopy.IDPCustomFields)
}

View file

@ -223,9 +223,14 @@ func (t *oidcToken) parseClaims(claims map[string]any, usernameField, roleField
return keys
}
username, ok := claims[usernameField].(string)
var username string
val, ok := getOIDCFieldFromClaims(claims, usernameField)
if ok {
username, ok = val.(string)
}
if !ok || username == "" {
logger.Warn(logSender, "", "username field %q not found, claims fields: %+v", usernameField, getClaimsFields())
logger.Warn(logSender, "", "username field %q not found, empty or not a string, claims fields: %+v",
usernameField, getClaimsFields())
return errors.New("no username field")
}
t.Username = username
@ -237,7 +242,7 @@ func (t *oidcToken) parseClaims(claims map[string]any, usernameField, roleField
t.CustomFields = nil
if len(customFields) > 0 {
for _, field := range customFields {
if val, ok := claims[field]; ok {
if val, ok := getOIDCFieldFromClaims(claims, field); ok {
if t.CustomFields == nil {
customFields := make(map[string]any)
t.CustomFields = &customFields
@ -257,36 +262,8 @@ func (t *oidcToken) parseClaims(claims map[string]any, usernameField, roleField
}
func (t *oidcToken) getRoleFromField(claims map[string]any, roleField string) {
if roleField != "" {
role, ok := claims[roleField]
if ok {
t.Role = role
return
}
if !strings.Contains(roleField, ".") {
return
}
getStructValue := func(outer any, field string) (any, bool) {
switch val := outer.(type) {
case map[string]any:
res, ok := val[field]
return res, ok
}
return nil, false
}
for idx, field := range strings.Split(roleField, ".") {
if idx == 0 {
role, ok = getStructValue(claims, field)
} else {
role, ok = getStructValue(role, field)
}
if !ok {
return
}
}
role, ok := getOIDCFieldFromClaims(claims, roleField)
if ok {
t.Role = role
}
}
@ -719,6 +696,7 @@ func loginOIDCUser(w http.ResponseWriter, r *http.Request, token oidcToken) {
// we don't set a cookie expiration so we can refresh the token without setting a new cookie
// the cookie will be invalidated on browser close
http.SetCookie(w, &cookie)
w.Header().Add("Cache-Control", `no-cache="Set-Cookie"`)
if token.isAdmin() {
http.Redirect(w, r, webUsersPath, http.StatusFound)
return
@ -793,3 +771,38 @@ func isLoggedInWithOIDC(r *http.Request) bool {
_, ok := r.Context().Value(oidcTokenKey).(string)
return ok
}
func getOIDCFieldFromClaims(claims map[string]any, fieldName string) (any, bool) {
if fieldName == "" {
return nil, false
}
val, ok := claims[fieldName]
if ok {
return val, true
}
if !strings.Contains(fieldName, ".") {
return nil, false
}
getStructValue := func(outer any, field string) (any, bool) {
switch v := outer.(type) {
case map[string]any:
res, ok := v[field]
return res, ok
}
return nil, false
}
for idx, field := range strings.Split(fieldName, ".") {
if idx == 0 {
val, ok = getStructValue(claims, field)
} else {
val, ok = getStructValue(val, field)
}
if !ok {
return nil, false
}
}
return val, ok
}

View file

@ -281,7 +281,7 @@ func TestOIDCLoginLogout(t *testing.T) {
assert.Equal(t, http.StatusFound, rr.Code)
assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
require.Len(t, oidcMgr.pendingAuths, 0)
// invalid id token claims (no username)
// invalid id token claims: no username
authReq = newOIDCPendingAuth(tokenAudienceWebClient)
oidcMgr.addPendingAuth(authReq)
idToken := &oidc.IDToken{
@ -300,6 +300,25 @@ func TestOIDCLoginLogout(t *testing.T) {
assert.Equal(t, http.StatusFound, rr.Code)
assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
require.Len(t, oidcMgr.pendingAuths, 0)
// invalid id token clamims: username not a string
authReq = newOIDCPendingAuth(tokenAudienceWebClient)
oidcMgr.addPendingAuth(authReq)
idToken = &oidc.IDToken{
Nonce: authReq.Nonce,
Expiry: time.Now().Add(5 * time.Minute),
}
setIDTokenClaims(idToken, []byte(`{"aud": "my_client_id","preferred_username": 1}`))
server.binding.OIDC.verifier = &mockOIDCVerifier{
err: nil,
token: idToken,
}
rr = httptest.NewRecorder()
r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
assert.NoError(t, err)
server.router.ServeHTTP(rr, r)
assert.Equal(t, http.StatusFound, rr.Code)
assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
require.Len(t, oidcMgr.pendingAuths, 0)
// invalid audience
authReq = newOIDCPendingAuth(tokenAudienceWebClient)
oidcMgr.addPendingAuth(authReq)
@ -835,11 +854,16 @@ func TestOIDCToken(t *testing.T) {
token := oidcToken{
Username: admin.Username,
Role: "admin",
}
// role not initialized, user with the specified username does not exist
req, err := http.NewRequest(http.MethodGet, webUsersPath, nil)
assert.NoError(t, err)
err = token.getUser(req)
assert.ErrorIs(t, err, util.ErrNotFound)
token.Role = "admin"
req, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
assert.NoError(t, err)
err = token.getUser(req)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "is disabled")
}
@ -1143,10 +1167,11 @@ func TestOIDCEvMgrIntegration(t *testing.T) {
u := map[string]any{
"username": "{{Name}}",
"status": 1,
"home_dir": filepath.Join(os.TempDir(), "{{IDPFieldcustom1}}"),
"home_dir": filepath.Join(os.TempDir(), "{{IDPFieldcustom1.sub}}"),
"permissions": map[string][]string{
"/": {dataprovider.PermAny},
},
"description": "{{IDPFieldcustom2}}",
}
userTmpl, err := json.Marshal(u)
require.NoError(t, err)
@ -1196,7 +1221,7 @@ func TestOIDCEvMgrIntegration(t *testing.T) {
require.True(t, ok)
server := getTestOIDCServer()
server.binding.OIDC.ImplicitRoles = true
server.binding.OIDC.CustomFields = []string{"custom1", "custom2"}
server.binding.OIDC.CustomFields = []string{"custom1.sub", "custom2"}
err = server.binding.OIDC.initialize()
assert.NoError(t, err)
server.initializeRouter()
@ -1221,7 +1246,7 @@ func TestOIDCEvMgrIntegration(t *testing.T) {
Nonce: authReq.Nonce,
Expiry: time.Now().Add(5 * time.Minute),
}
setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+util.JSONEscape(username)+`","custom1":"val1"}`))
setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+util.JSONEscape(username)+`","custom1":{"sub":"val1"},"custom2":"desc"}`))
server.binding.OIDC.verifier = &mockOIDCVerifier{
err: nil,
token: idToken,
@ -1235,6 +1260,7 @@ func TestOIDCEvMgrIntegration(t *testing.T) {
user, err := dataprovider.UserExists(username, "")
assert.NoError(t, err)
assert.Equal(t, filepath.Join(os.TempDir(), "val1"), user.GetHomeDir())
assert.Equal(t, "desc", user.Description)
err = dataprovider.DeleteUser(username, "", "", "")
assert.NoError(t, err)
@ -1473,13 +1499,15 @@ func TestParseAdminRole(t *testing.T) {
type test struct {
input string
want bool
val any
}
tests := []test{
{input: "", want: false},
{input: "sftpgo_role", want: false},
{input: "params.sftpgo_role", want: true},
{input: "params.subparams.sftpgo_role", want: true},
{input: "params.subparams.inner.sftpgo_role", want: true},
{input: "params.sftpgo_role", want: true, val: "admin"},
{input: "params.subparams.sftpgo_role", want: true, val: "admin"},
{input: "params.subparams.inner.sftpgo_role", want: true, val: []any{"user", "admin"}},
{input: "email", want: false},
{input: "missing", want: false},
{input: "params.email", want: false},
@ -1492,6 +1520,9 @@ func TestParseAdminRole(t *testing.T) {
token := oidcToken{}
token.getRoleFromField(claims, tc.input)
assert.Equal(t, tc.want, token.isAdmin(), "%q should return %t", tc.input, tc.want)
if tc.want {
assert.Equal(t, tc.val, token.Role)
}
}
}

View file

@ -610,7 +610,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
if err != nil {
logger.Debug(logSender, "", "failed to accept an incoming connection: %v", err)
logger.Debug(logSender, "", "failed to accept an incoming connection from ip %q: %v", ipAddr, err)
checkAuthError(ipAddr, err)
return
}
@ -629,7 +629,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
defer user.CloseFs() //nolint:errcheck
if err = user.CheckFsRoot(connectionID); err != nil {
logger.Warn(logSender, connectionID, "unable to check fs root: %v", err)
logger.Warn(logSender, connectionID, "unable to check fs root for user %q: %v", user.Username, err)
return
}