OIDC: allow to extract custom fields from sub-structs
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
parent
e29f6857db
commit
354fc9b3d6
6 changed files with 94 additions and 51 deletions
4
go.mod
4
go.mod
|
@ -79,7 +79,7 @@ require (
|
|||
|
||||
require (
|
||||
cloud.google.com/go v0.110.0 // indirect
|
||||
cloud.google.com/go/compute v1.18.0 // indirect
|
||||
cloud.google.com/go/compute v1.19.0 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.2.3 // indirect
|
||||
cloud.google.com/go/iam v0.13.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.2.0 // indirect
|
||||
|
@ -157,7 +157,7 @@ require (
|
|||
golang.org/x/tools v0.7.0 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
google.golang.org/genproto v0.0.0-20230320184635-7606e756e683 // indirect
|
||||
google.golang.org/genproto v0.0.0-20230322174352-cde4c949918d // indirect
|
||||
google.golang.org/grpc v1.54.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
|
|
7
go.sum
7
go.sum
|
@ -123,8 +123,9 @@ cloud.google.com/go/compute v1.12.1/go.mod h1:e8yNOBcBONZU1vJKCvCoDw/4JQsA0dpM4x
|
|||
cloud.google.com/go/compute v1.13.0/go.mod h1:5aPTS0cUNMIc1CE546K+Th6weJUNQErARyZtRXDJ8GE=
|
||||
cloud.google.com/go/compute v1.14.0/go.mod h1:YfLtxrj9sU4Yxv+sXzZkyPjEyPBZfXHUvjxega5vAdo=
|
||||
cloud.google.com/go/compute v1.15.1/go.mod h1:bjjoF/NtFUrkD/urWfdHaKuOPDR5nWIs63rR+SXhcpA=
|
||||
cloud.google.com/go/compute v1.18.0 h1:FEigFqoDbys2cvFkZ9Fjq4gnHBP55anJ0yQyau2f9oY=
|
||||
cloud.google.com/go/compute v1.18.0/go.mod h1:1X7yHxec2Ga+Ss6jPyjxRxpu2uu7PLgsOVXvgU0yacs=
|
||||
cloud.google.com/go/compute v1.19.0 h1:+9zda3WGgW1ZSTlVppLCYFIr48Pa35q1uG2N1itbCEQ=
|
||||
cloud.google.com/go/compute v1.19.0/go.mod h1:rikpw2y+UMidAe9tISo04EHNOIf42RLYF/q8Bs93scU=
|
||||
cloud.google.com/go/compute/metadata v0.1.0/go.mod h1:Z1VN+bulIf6bt4P/C37K4DyZYZEXYonfTBHHFPO/4UU=
|
||||
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
||||
cloud.google.com/go/compute/metadata v0.2.1/go.mod h1:jgHgmJd2RKBGzXqF5LR2EZMGxBkeanZ9wwa75XHJgOM=
|
||||
|
@ -2802,8 +2803,8 @@ google.golang.org/genproto v0.0.0-20230113154510-dbe35b8444a5/go.mod h1:RGgjbofJ
|
|||
google.golang.org/genproto v0.0.0-20230124163310-31e0e69b6fc2/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM=
|
||||
google.golang.org/genproto v0.0.0-20230125152338-dcaf20b6aeaa/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM=
|
||||
google.golang.org/genproto v0.0.0-20230209215440-0dfe4f8abfcc/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM=
|
||||
google.golang.org/genproto v0.0.0-20230320184635-7606e756e683 h1:khxVcsk/FhnzxMKOyD+TDGwjbEOpcPuIpmafPGFmhMA=
|
||||
google.golang.org/genproto v0.0.0-20230320184635-7606e756e683/go.mod h1:NWraEVixdDnqcqQ30jipen1STv2r/n24Wb7twVTGR4s=
|
||||
google.golang.org/genproto v0.0.0-20230322174352-cde4c949918d h1:OE8TncEeAei3Tehf/P/Jdt/K+8GnTUrRY6wzYpbCes4=
|
||||
google.golang.org/genproto v0.0.0-20230322174352-cde4c949918d/go.mod h1:NWraEVixdDnqcqQ30jipen1STv2r/n24Wb7twVTGR4s=
|
||||
google.golang.org/grpc v0.0.0-20160317175043-d3ddb4469d5a/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
|
||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
|
||||
|
|
|
@ -1945,9 +1945,7 @@ func TestEventParamsCopy(t *testing.T) {
|
|||
assert.Equal(t, "val1", v)
|
||||
}
|
||||
assert.Equal(t, params.IDPCustomFields, paramsCopy.IDPCustomFields)
|
||||
paramsCopy.addIDPCustomFields(&map[string]any{
|
||||
"field2": "val2",
|
||||
})
|
||||
(*paramsCopy.IDPCustomFields)["field1"] = "val2"
|
||||
assert.NotEqual(t, params.IDPCustomFields, paramsCopy.IDPCustomFields)
|
||||
}
|
||||
|
||||
|
|
|
@ -223,9 +223,14 @@ func (t *oidcToken) parseClaims(claims map[string]any, usernameField, roleField
|
|||
return keys
|
||||
}
|
||||
|
||||
username, ok := claims[usernameField].(string)
|
||||
var username string
|
||||
val, ok := getOIDCFieldFromClaims(claims, usernameField)
|
||||
if ok {
|
||||
username, ok = val.(string)
|
||||
}
|
||||
if !ok || username == "" {
|
||||
logger.Warn(logSender, "", "username field %q not found, claims fields: %+v", usernameField, getClaimsFields())
|
||||
logger.Warn(logSender, "", "username field %q not found, empty or not a string, claims fields: %+v",
|
||||
usernameField, getClaimsFields())
|
||||
return errors.New("no username field")
|
||||
}
|
||||
t.Username = username
|
||||
|
@ -237,7 +242,7 @@ func (t *oidcToken) parseClaims(claims map[string]any, usernameField, roleField
|
|||
t.CustomFields = nil
|
||||
if len(customFields) > 0 {
|
||||
for _, field := range customFields {
|
||||
if val, ok := claims[field]; ok {
|
||||
if val, ok := getOIDCFieldFromClaims(claims, field); ok {
|
||||
if t.CustomFields == nil {
|
||||
customFields := make(map[string]any)
|
||||
t.CustomFields = &customFields
|
||||
|
@ -257,36 +262,8 @@ func (t *oidcToken) parseClaims(claims map[string]any, usernameField, roleField
|
|||
}
|
||||
|
||||
func (t *oidcToken) getRoleFromField(claims map[string]any, roleField string) {
|
||||
if roleField != "" {
|
||||
role, ok := claims[roleField]
|
||||
if ok {
|
||||
t.Role = role
|
||||
return
|
||||
}
|
||||
if !strings.Contains(roleField, ".") {
|
||||
return
|
||||
}
|
||||
|
||||
getStructValue := func(outer any, field string) (any, bool) {
|
||||
switch val := outer.(type) {
|
||||
case map[string]any:
|
||||
res, ok := val[field]
|
||||
return res, ok
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
for idx, field := range strings.Split(roleField, ".") {
|
||||
if idx == 0 {
|
||||
role, ok = getStructValue(claims, field)
|
||||
} else {
|
||||
role, ok = getStructValue(role, field)
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
role, ok := getOIDCFieldFromClaims(claims, roleField)
|
||||
if ok {
|
||||
t.Role = role
|
||||
}
|
||||
}
|
||||
|
@ -719,6 +696,7 @@ func loginOIDCUser(w http.ResponseWriter, r *http.Request, token oidcToken) {
|
|||
// we don't set a cookie expiration so we can refresh the token without setting a new cookie
|
||||
// the cookie will be invalidated on browser close
|
||||
http.SetCookie(w, &cookie)
|
||||
w.Header().Add("Cache-Control", `no-cache="Set-Cookie"`)
|
||||
if token.isAdmin() {
|
||||
http.Redirect(w, r, webUsersPath, http.StatusFound)
|
||||
return
|
||||
|
@ -793,3 +771,38 @@ func isLoggedInWithOIDC(r *http.Request) bool {
|
|||
_, ok := r.Context().Value(oidcTokenKey).(string)
|
||||
return ok
|
||||
}
|
||||
|
||||
func getOIDCFieldFromClaims(claims map[string]any, fieldName string) (any, bool) {
|
||||
if fieldName == "" {
|
||||
return nil, false
|
||||
}
|
||||
val, ok := claims[fieldName]
|
||||
if ok {
|
||||
return val, true
|
||||
}
|
||||
if !strings.Contains(fieldName, ".") {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
getStructValue := func(outer any, field string) (any, bool) {
|
||||
switch v := outer.(type) {
|
||||
case map[string]any:
|
||||
res, ok := v[field]
|
||||
return res, ok
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
for idx, field := range strings.Split(fieldName, ".") {
|
||||
if idx == 0 {
|
||||
val, ok = getStructValue(claims, field)
|
||||
} else {
|
||||
val, ok = getStructValue(val, field)
|
||||
}
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
return val, ok
|
||||
}
|
||||
|
|
|
@ -281,7 +281,7 @@ func TestOIDCLoginLogout(t *testing.T) {
|
|||
assert.Equal(t, http.StatusFound, rr.Code)
|
||||
assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
|
||||
require.Len(t, oidcMgr.pendingAuths, 0)
|
||||
// invalid id token claims (no username)
|
||||
// invalid id token claims: no username
|
||||
authReq = newOIDCPendingAuth(tokenAudienceWebClient)
|
||||
oidcMgr.addPendingAuth(authReq)
|
||||
idToken := &oidc.IDToken{
|
||||
|
@ -300,6 +300,25 @@ func TestOIDCLoginLogout(t *testing.T) {
|
|||
assert.Equal(t, http.StatusFound, rr.Code)
|
||||
assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
|
||||
require.Len(t, oidcMgr.pendingAuths, 0)
|
||||
// invalid id token clamims: username not a string
|
||||
authReq = newOIDCPendingAuth(tokenAudienceWebClient)
|
||||
oidcMgr.addPendingAuth(authReq)
|
||||
idToken = &oidc.IDToken{
|
||||
Nonce: authReq.Nonce,
|
||||
Expiry: time.Now().Add(5 * time.Minute),
|
||||
}
|
||||
setIDTokenClaims(idToken, []byte(`{"aud": "my_client_id","preferred_username": 1}`))
|
||||
server.binding.OIDC.verifier = &mockOIDCVerifier{
|
||||
err: nil,
|
||||
token: idToken,
|
||||
}
|
||||
rr = httptest.NewRecorder()
|
||||
r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
|
||||
assert.NoError(t, err)
|
||||
server.router.ServeHTTP(rr, r)
|
||||
assert.Equal(t, http.StatusFound, rr.Code)
|
||||
assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
|
||||
require.Len(t, oidcMgr.pendingAuths, 0)
|
||||
// invalid audience
|
||||
authReq = newOIDCPendingAuth(tokenAudienceWebClient)
|
||||
oidcMgr.addPendingAuth(authReq)
|
||||
|
@ -835,11 +854,16 @@ func TestOIDCToken(t *testing.T) {
|
|||
|
||||
token := oidcToken{
|
||||
Username: admin.Username,
|
||||
Role: "admin",
|
||||
}
|
||||
// role not initialized, user with the specified username does not exist
|
||||
req, err := http.NewRequest(http.MethodGet, webUsersPath, nil)
|
||||
assert.NoError(t, err)
|
||||
err = token.getUser(req)
|
||||
assert.ErrorIs(t, err, util.ErrNotFound)
|
||||
token.Role = "admin"
|
||||
req, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
|
||||
assert.NoError(t, err)
|
||||
err = token.getUser(req)
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "is disabled")
|
||||
}
|
||||
|
@ -1143,10 +1167,11 @@ func TestOIDCEvMgrIntegration(t *testing.T) {
|
|||
u := map[string]any{
|
||||
"username": "{{Name}}",
|
||||
"status": 1,
|
||||
"home_dir": filepath.Join(os.TempDir(), "{{IDPFieldcustom1}}"),
|
||||
"home_dir": filepath.Join(os.TempDir(), "{{IDPFieldcustom1.sub}}"),
|
||||
"permissions": map[string][]string{
|
||||
"/": {dataprovider.PermAny},
|
||||
},
|
||||
"description": "{{IDPFieldcustom2}}",
|
||||
}
|
||||
userTmpl, err := json.Marshal(u)
|
||||
require.NoError(t, err)
|
||||
|
@ -1196,7 +1221,7 @@ func TestOIDCEvMgrIntegration(t *testing.T) {
|
|||
require.True(t, ok)
|
||||
server := getTestOIDCServer()
|
||||
server.binding.OIDC.ImplicitRoles = true
|
||||
server.binding.OIDC.CustomFields = []string{"custom1", "custom2"}
|
||||
server.binding.OIDC.CustomFields = []string{"custom1.sub", "custom2"}
|
||||
err = server.binding.OIDC.initialize()
|
||||
assert.NoError(t, err)
|
||||
server.initializeRouter()
|
||||
|
@ -1221,7 +1246,7 @@ func TestOIDCEvMgrIntegration(t *testing.T) {
|
|||
Nonce: authReq.Nonce,
|
||||
Expiry: time.Now().Add(5 * time.Minute),
|
||||
}
|
||||
setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+util.JSONEscape(username)+`","custom1":"val1"}`))
|
||||
setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+util.JSONEscape(username)+`","custom1":{"sub":"val1"},"custom2":"desc"}`))
|
||||
server.binding.OIDC.verifier = &mockOIDCVerifier{
|
||||
err: nil,
|
||||
token: idToken,
|
||||
|
@ -1235,6 +1260,7 @@ func TestOIDCEvMgrIntegration(t *testing.T) {
|
|||
user, err := dataprovider.UserExists(username, "")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, filepath.Join(os.TempDir(), "val1"), user.GetHomeDir())
|
||||
assert.Equal(t, "desc", user.Description)
|
||||
|
||||
err = dataprovider.DeleteUser(username, "", "", "")
|
||||
assert.NoError(t, err)
|
||||
|
@ -1473,13 +1499,15 @@ func TestParseAdminRole(t *testing.T) {
|
|||
type test struct {
|
||||
input string
|
||||
want bool
|
||||
val any
|
||||
}
|
||||
|
||||
tests := []test{
|
||||
{input: "", want: false},
|
||||
{input: "sftpgo_role", want: false},
|
||||
{input: "params.sftpgo_role", want: true},
|
||||
{input: "params.subparams.sftpgo_role", want: true},
|
||||
{input: "params.subparams.inner.sftpgo_role", want: true},
|
||||
{input: "params.sftpgo_role", want: true, val: "admin"},
|
||||
{input: "params.subparams.sftpgo_role", want: true, val: "admin"},
|
||||
{input: "params.subparams.inner.sftpgo_role", want: true, val: []any{"user", "admin"}},
|
||||
{input: "email", want: false},
|
||||
{input: "missing", want: false},
|
||||
{input: "params.email", want: false},
|
||||
|
@ -1492,6 +1520,9 @@ func TestParseAdminRole(t *testing.T) {
|
|||
token := oidcToken{}
|
||||
token.getRoleFromField(claims, tc.input)
|
||||
assert.Equal(t, tc.want, token.isAdmin(), "%q should return %t", tc.input, tc.want)
|
||||
if tc.want {
|
||||
assert.Equal(t, tc.val, token.Role)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -610,7 +610,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
|
|||
|
||||
sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
|
||||
if err != nil {
|
||||
logger.Debug(logSender, "", "failed to accept an incoming connection: %v", err)
|
||||
logger.Debug(logSender, "", "failed to accept an incoming connection from ip %q: %v", ipAddr, err)
|
||||
checkAuthError(ipAddr, err)
|
||||
return
|
||||
}
|
||||
|
@ -629,7 +629,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
|
|||
|
||||
defer user.CloseFs() //nolint:errcheck
|
||||
if err = user.CheckFsRoot(connectionID); err != nil {
|
||||
logger.Warn(logSender, connectionID, "unable to check fs root: %v", err)
|
||||
logger.Warn(logSender, connectionID, "unable to check fs root for user %q: %v", user.Username, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue