mirror of
https://github.com/drakkan/sftpgo.git
synced 2024-11-22 07:30:25 +00:00
796ea1dde9
so we can persist password reset codes, OIDC auth sessions and tokens. These features will also work in multi-node setups without sticky sessions now. Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
1161 lines
37 KiB
Go
1161 lines
37 KiB
Go
package httpd
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"reflect"
|
|
"runtime"
|
|
"testing"
|
|
"time"
|
|
"unsafe"
|
|
|
|
"github.com/coreos/go-oidc/v3/oidc"
|
|
"github.com/go-chi/jwtauth/v5"
|
|
"github.com/rs/xid"
|
|
"github.com/sftpgo/sdk"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/oauth2"
|
|
|
|
"github.com/drakkan/sftpgo/v2/common"
|
|
"github.com/drakkan/sftpgo/v2/dataprovider"
|
|
"github.com/drakkan/sftpgo/v2/kms"
|
|
"github.com/drakkan/sftpgo/v2/util"
|
|
"github.com/drakkan/sftpgo/v2/vfs"
|
|
)
|
|
|
|
const (
	// oidcMockAddr is the host:port the tests in this file use to build the
	// OIDC provider config URL and the post connect hook URL.
	oidcMockAddr = "127.0.0.1:11111"
)
|
|
|
|
type mockTokenSource struct {
|
|
token *oauth2.Token
|
|
err error
|
|
}
|
|
|
|
func (t *mockTokenSource) Token() (*oauth2.Token, error) {
|
|
return t.token, t.err
|
|
}
|
|
|
|
type mockOAuth2Config struct {
|
|
tokenSource *mockTokenSource
|
|
authCodeURL string
|
|
token *oauth2.Token
|
|
err error
|
|
}
|
|
|
|
func (c *mockOAuth2Config) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
|
|
return c.authCodeURL
|
|
}
|
|
|
|
func (c *mockOAuth2Config) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
|
return c.token, c.err
|
|
}
|
|
|
|
func (c *mockOAuth2Config) TokenSource(ctx context.Context, t *oauth2.Token) oauth2.TokenSource {
|
|
return c.tokenSource
|
|
}
|
|
|
|
type mockOIDCVerifier struct {
|
|
token *oidc.IDToken
|
|
err error
|
|
}
|
|
|
|
func (v *mockOIDCVerifier) Verify(ctx context.Context, rawIDToken string) (*oidc.IDToken, error) {
|
|
return v.token, v.err
|
|
}
|
|
|
|
// hack because the field is unexported
|
|
func setIDTokenClaims(idToken *oidc.IDToken, claims []byte) {
|
|
pointerVal := reflect.ValueOf(idToken)
|
|
val := reflect.Indirect(pointerVal)
|
|
member := val.FieldByName("claims")
|
|
ptr := unsafe.Pointer(member.UnsafeAddr())
|
|
realPtr := (*[]byte)(ptr)
|
|
*realPtr = claims
|
|
}
|
|
|
|
func TestOIDCInitialization(t *testing.T) {
|
|
config := OIDC{}
|
|
err := config.initialize()
|
|
assert.NoError(t, err)
|
|
config = OIDC{
|
|
ClientID: "sftpgo-client",
|
|
ClientSecret: "jRsmE0SWnuZjP7djBqNq0mrf8QN77j2c",
|
|
ConfigURL: fmt.Sprintf("http://%v/", oidcMockAddr),
|
|
RedirectBaseURL: "http://127.0.0.1:8081/",
|
|
UsernameField: "preferred_username",
|
|
RoleField: "sftpgo_role",
|
|
}
|
|
err = config.initialize()
|
|
if assert.Error(t, err) {
|
|
assert.Contains(t, err.Error(), "oidc: unable to initialize provider")
|
|
}
|
|
config.ConfigURL = fmt.Sprintf("http://%v/auth/realms/sftpgo", oidcMockAddr)
|
|
err = config.initialize()
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "http://127.0.0.1:8081"+webOIDCRedirectPath, config.getRedirectURL())
|
|
}
|
|
|
|
func TestOIDCLoginLogout(t *testing.T) {
|
|
oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
|
|
require.True(t, ok)
|
|
server := getTestOIDCServer()
|
|
err := server.binding.OIDC.initialize()
|
|
assert.NoError(t, err)
|
|
server.initializeRouter()
|
|
|
|
rr := httptest.NewRecorder()
|
|
r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath, nil)
|
|
assert.NoError(t, err)
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
|
assert.Contains(t, rr.Body.String(), "Authentication state did not match")
|
|
|
|
expiredAuthReq := oidcPendingAuth{
|
|
State: xid.New().String(),
|
|
Nonce: xid.New().String(),
|
|
Audience: tokenAudienceWebClient,
|
|
IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)),
|
|
}
|
|
oidcMgr.addPendingAuth(expiredAuthReq)
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+expiredAuthReq.State, nil)
|
|
assert.NoError(t, err)
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
|
assert.Contains(t, rr.Body.String(), "Authentication state did not match")
|
|
oidcMgr.removePendingAuth(expiredAuthReq.State)
|
|
|
|
server.binding.OIDC.oauth2Config = &mockOAuth2Config{
|
|
tokenSource: &mockTokenSource{},
|
|
authCodeURL: webOIDCRedirectPath,
|
|
err: common.ErrGenericFailure,
|
|
}
|
|
server.binding.OIDC.verifier = &mockOIDCVerifier{
|
|
err: common.ErrGenericFailure,
|
|
}
|
|
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webAdminOIDCLoginPath, nil)
|
|
assert.NoError(t, err)
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webOIDCRedirectPath, rr.Header().Get("Location"))
|
|
require.Len(t, oidcMgr.pendingAuths, 1)
|
|
var state string
|
|
for k := range oidcMgr.pendingAuths {
|
|
state = k
|
|
}
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+state, nil)
|
|
assert.NoError(t, err)
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
|
|
require.Len(t, oidcMgr.pendingAuths, 0)
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil)
|
|
assert.NoError(t, err)
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusOK, rr.Code)
|
|
// now the same for the web client
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webClientOIDCLoginPath, nil)
|
|
assert.NoError(t, err)
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webOIDCRedirectPath, rr.Header().Get("Location"))
|
|
require.Len(t, oidcMgr.pendingAuths, 1)
|
|
for k := range oidcMgr.pendingAuths {
|
|
state = k
|
|
}
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+state, nil)
|
|
assert.NoError(t, err)
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
|
|
require.Len(t, oidcMgr.pendingAuths, 0)
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webClientLoginPath, nil)
|
|
assert.NoError(t, err)
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusOK, rr.Code)
|
|
// now return an OAuth2 token without the id_token
|
|
server.binding.OIDC.oauth2Config = &mockOAuth2Config{
|
|
tokenSource: &mockTokenSource{},
|
|
authCodeURL: webOIDCRedirectPath,
|
|
token: &oauth2.Token{
|
|
AccessToken: "123",
|
|
Expiry: time.Now().Add(5 * time.Minute),
|
|
},
|
|
err: nil,
|
|
}
|
|
authReq := newOIDCPendingAuth(tokenAudienceWebClient)
|
|
oidcMgr.addPendingAuth(authReq)
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
|
|
assert.NoError(t, err)
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
|
|
require.Len(t, oidcMgr.pendingAuths, 0)
|
|
// now fail to verify the id token
|
|
token := &oauth2.Token{
|
|
AccessToken: "123",
|
|
Expiry: time.Now().Add(5 * time.Minute),
|
|
}
|
|
token = token.WithExtra(map[string]any{
|
|
"id_token": "id_token_val",
|
|
})
|
|
server.binding.OIDC.oauth2Config = &mockOAuth2Config{
|
|
tokenSource: &mockTokenSource{},
|
|
authCodeURL: webOIDCRedirectPath,
|
|
token: token,
|
|
err: nil,
|
|
}
|
|
authReq = newOIDCPendingAuth(tokenAudienceWebClient)
|
|
oidcMgr.addPendingAuth(authReq)
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
|
|
assert.NoError(t, err)
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
|
|
require.Len(t, oidcMgr.pendingAuths, 0)
|
|
// id token nonce does not match
|
|
server.binding.OIDC.verifier = &mockOIDCVerifier{
|
|
err: nil,
|
|
token: &oidc.IDToken{},
|
|
}
|
|
authReq = newOIDCPendingAuth(tokenAudienceWebClient)
|
|
oidcMgr.addPendingAuth(authReq)
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
|
|
assert.NoError(t, err)
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
|
|
require.Len(t, oidcMgr.pendingAuths, 0)
|
|
// null id token claims
|
|
authReq = newOIDCPendingAuth(tokenAudienceWebClient)
|
|
oidcMgr.addPendingAuth(authReq)
|
|
server.binding.OIDC.verifier = &mockOIDCVerifier{
|
|
err: nil,
|
|
token: &oidc.IDToken{
|
|
Nonce: authReq.Nonce,
|
|
},
|
|
}
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
|
|
assert.NoError(t, err)
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
|
|
require.Len(t, oidcMgr.pendingAuths, 0)
|
|
// invalid id token claims (no username)
|
|
authReq = newOIDCPendingAuth(tokenAudienceWebClient)
|
|
oidcMgr.addPendingAuth(authReq)
|
|
idToken := &oidc.IDToken{
|
|
Nonce: authReq.Nonce,
|
|
Expiry: time.Now().Add(5 * time.Minute),
|
|
}
|
|
setIDTokenClaims(idToken, []byte(`{"aud": "my_client_id"}`))
|
|
server.binding.OIDC.verifier = &mockOIDCVerifier{
|
|
err: nil,
|
|
token: idToken,
|
|
}
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
|
|
assert.NoError(t, err)
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
|
|
require.Len(t, oidcMgr.pendingAuths, 0)
|
|
// invalid audience
|
|
authReq = newOIDCPendingAuth(tokenAudienceWebClient)
|
|
oidcMgr.addPendingAuth(authReq)
|
|
idToken = &oidc.IDToken{
|
|
Nonce: authReq.Nonce,
|
|
Expiry: time.Now().Add(5 * time.Minute),
|
|
}
|
|
setIDTokenClaims(idToken, []byte(`{"preferred_username":"test","sftpgo_role":"admin"}`))
|
|
server.binding.OIDC.verifier = &mockOIDCVerifier{
|
|
err: nil,
|
|
token: idToken,
|
|
}
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
|
|
assert.NoError(t, err)
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
|
|
require.Len(t, oidcMgr.pendingAuths, 0)
|
|
// invalid audience
|
|
authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
|
|
oidcMgr.addPendingAuth(authReq)
|
|
idToken = &oidc.IDToken{
|
|
Nonce: authReq.Nonce,
|
|
Expiry: time.Now().Add(5 * time.Minute),
|
|
}
|
|
setIDTokenClaims(idToken, []byte(`{"preferred_username":"test"}`))
|
|
server.binding.OIDC.verifier = &mockOIDCVerifier{
|
|
err: nil,
|
|
token: idToken,
|
|
}
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
|
|
assert.NoError(t, err)
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
|
|
require.Len(t, oidcMgr.pendingAuths, 0)
|
|
// mapped user not found
|
|
authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
|
|
oidcMgr.addPendingAuth(authReq)
|
|
idToken = &oidc.IDToken{
|
|
Nonce: authReq.Nonce,
|
|
Expiry: time.Now().Add(5 * time.Minute),
|
|
}
|
|
setIDTokenClaims(idToken, []byte(`{"preferred_username":"test","sftpgo_role":"admin"}`))
|
|
server.binding.OIDC.verifier = &mockOIDCVerifier{
|
|
err: nil,
|
|
token: idToken,
|
|
}
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
|
|
assert.NoError(t, err)
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
|
|
require.Len(t, oidcMgr.pendingAuths, 0)
|
|
// admin login ok
|
|
authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
|
|
oidcMgr.addPendingAuth(authReq)
|
|
idToken = &oidc.IDToken{
|
|
Nonce: authReq.Nonce,
|
|
Expiry: time.Now().Add(5 * time.Minute),
|
|
}
|
|
setIDTokenClaims(idToken, []byte(`{"preferred_username":"admin","sftpgo_role":"admin","sid":"sid123"}`))
|
|
server.binding.OIDC.verifier = &mockOIDCVerifier{
|
|
err: nil,
|
|
token: idToken,
|
|
}
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
|
|
assert.NoError(t, err)
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webUsersPath, rr.Header().Get("Location"))
|
|
require.Len(t, oidcMgr.pendingAuths, 0)
|
|
require.Len(t, oidcMgr.tokens, 1)
|
|
// admin profile is not available
|
|
var tokenCookie string
|
|
for k := range oidcMgr.tokens {
|
|
tokenCookie = k
|
|
}
|
|
oidcToken, err := oidcMgr.getToken(tokenCookie)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "sid123", oidcToken.SessionID)
|
|
assert.True(t, oidcToken.isAdmin())
|
|
assert.False(t, oidcToken.isExpired())
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webAdminProfilePath, nil)
|
|
assert.NoError(t, err)
|
|
r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusForbidden, rr.Code)
|
|
// the admin can access the allowed pages
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
|
|
assert.NoError(t, err)
|
|
r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusOK, rr.Code)
|
|
// try with an invalid cookie
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
|
|
assert.NoError(t, err)
|
|
r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, xid.New().String()))
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
|
|
// Web Client is not available with an admin token
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
|
|
assert.NoError(t, err)
|
|
r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
|
|
// logout the admin user
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webLogoutPath, nil)
|
|
assert.NoError(t, err)
|
|
r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
|
|
require.Len(t, oidcMgr.pendingAuths, 0)
|
|
require.Len(t, oidcMgr.tokens, 0)
|
|
// now login and logout a user
|
|
username := "test_oidc_user"
|
|
user := dataprovider.User{
|
|
BaseUser: sdk.BaseUser{
|
|
Username: username,
|
|
Password: "pwd",
|
|
HomeDir: filepath.Join(os.TempDir(), username),
|
|
Status: 1,
|
|
Permissions: map[string][]string{
|
|
"/": {dataprovider.PermAny},
|
|
},
|
|
},
|
|
Filters: dataprovider.UserFilters{
|
|
BaseUserFilters: sdk.BaseUserFilters{
|
|
WebClient: []string{sdk.WebClientSharesDisabled},
|
|
},
|
|
},
|
|
}
|
|
err = dataprovider.AddUser(&user, "", "")
|
|
assert.NoError(t, err)
|
|
|
|
authReq = newOIDCPendingAuth(tokenAudienceWebClient)
|
|
oidcMgr.addPendingAuth(authReq)
|
|
idToken = &oidc.IDToken{
|
|
Nonce: authReq.Nonce,
|
|
Expiry: time.Now().Add(5 * time.Minute),
|
|
}
|
|
setIDTokenClaims(idToken, []byte(`{"preferred_username":"test_oidc_user"}`))
|
|
server.binding.OIDC.verifier = &mockOIDCVerifier{
|
|
err: nil,
|
|
token: idToken,
|
|
}
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
|
|
assert.NoError(t, err)
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
|
|
require.Len(t, oidcMgr.pendingAuths, 0)
|
|
require.Len(t, oidcMgr.tokens, 1)
|
|
// user profile is not available
|
|
for k := range oidcMgr.tokens {
|
|
tokenCookie = k
|
|
}
|
|
oidcToken, err = oidcMgr.getToken(tokenCookie)
|
|
assert.NoError(t, err)
|
|
assert.Empty(t, oidcToken.SessionID)
|
|
assert.False(t, oidcToken.isAdmin())
|
|
assert.False(t, oidcToken.isExpired())
|
|
if assert.Len(t, oidcToken.Permissions, 1) {
|
|
assert.Equal(t, sdk.WebClientSharesDisabled, oidcToken.Permissions[0])
|
|
}
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webClientProfilePath, nil)
|
|
assert.NoError(t, err)
|
|
r.RequestURI = webClientProfilePath
|
|
r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusForbidden, rr.Code)
|
|
// the user can access the allowed pages
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
|
|
assert.NoError(t, err)
|
|
r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusOK, rr.Code)
|
|
// try with an invalid cookie
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
|
|
assert.NoError(t, err)
|
|
r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, xid.New().String()))
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
|
|
// Web Admin is not available with a client cookie
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
|
|
assert.NoError(t, err)
|
|
r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
|
|
// logout the user
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
|
|
assert.NoError(t, err)
|
|
r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
|
|
require.Len(t, oidcMgr.pendingAuths, 0)
|
|
require.Len(t, oidcMgr.tokens, 0)
|
|
|
|
err = os.RemoveAll(user.GetHomeDir())
|
|
assert.NoError(t, err)
|
|
err = dataprovider.DeleteUser(username, "", "")
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
func TestOIDCRefreshToken(t *testing.T) {
|
|
oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
|
|
require.True(t, ok)
|
|
token := oidcToken{
|
|
Cookie: xid.New().String(),
|
|
AccessToken: xid.New().String(),
|
|
TokenType: "Bearer",
|
|
ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-1 * time.Minute)),
|
|
Nonce: xid.New().String(),
|
|
}
|
|
config := mockOAuth2Config{
|
|
tokenSource: &mockTokenSource{
|
|
err: common.ErrGenericFailure,
|
|
},
|
|
}
|
|
verifier := mockOIDCVerifier{
|
|
err: common.ErrGenericFailure,
|
|
}
|
|
err := token.refresh(&config, &verifier)
|
|
if assert.Error(t, err) {
|
|
assert.Contains(t, err.Error(), "refresh token not set")
|
|
}
|
|
token.RefreshToken = xid.New().String()
|
|
err = token.refresh(&config, &verifier)
|
|
assert.ErrorIs(t, err, common.ErrGenericFailure)
|
|
|
|
newToken := &oauth2.Token{
|
|
AccessToken: xid.New().String(),
|
|
RefreshToken: xid.New().String(),
|
|
Expiry: time.Now().Add(5 * time.Minute),
|
|
}
|
|
config = mockOAuth2Config{
|
|
tokenSource: &mockTokenSource{
|
|
token: newToken,
|
|
},
|
|
}
|
|
verifier = mockOIDCVerifier{
|
|
token: &oidc.IDToken{},
|
|
}
|
|
err = token.refresh(&config, &verifier)
|
|
if assert.Error(t, err) {
|
|
assert.Contains(t, err.Error(), "the refreshed token has no id token")
|
|
}
|
|
newToken = newToken.WithExtra(map[string]any{
|
|
"id_token": "id_token_val",
|
|
})
|
|
newToken.Expiry = time.Time{}
|
|
config = mockOAuth2Config{
|
|
tokenSource: &mockTokenSource{
|
|
token: newToken,
|
|
},
|
|
}
|
|
verifier = mockOIDCVerifier{
|
|
err: common.ErrGenericFailure,
|
|
}
|
|
err = token.refresh(&config, &verifier)
|
|
assert.ErrorIs(t, err, common.ErrGenericFailure)
|
|
|
|
newToken = newToken.WithExtra(map[string]any{
|
|
"id_token": "id_token_val",
|
|
})
|
|
newToken.Expiry = time.Now().Add(5 * time.Minute)
|
|
config = mockOAuth2Config{
|
|
tokenSource: &mockTokenSource{
|
|
token: newToken,
|
|
},
|
|
}
|
|
verifier = mockOIDCVerifier{
|
|
token: &oidc.IDToken{},
|
|
}
|
|
err = token.refresh(&config, &verifier)
|
|
if assert.Error(t, err) {
|
|
assert.Contains(t, err.Error(), "the refreshed token nonce mismatch")
|
|
}
|
|
verifier = mockOIDCVerifier{
|
|
token: &oidc.IDToken{
|
|
Nonce: token.Nonce,
|
|
},
|
|
}
|
|
err = token.refresh(&config, &verifier)
|
|
if assert.Error(t, err) {
|
|
assert.Contains(t, err.Error(), "oidc: claims not set")
|
|
}
|
|
idToken := &oidc.IDToken{
|
|
Nonce: token.Nonce,
|
|
}
|
|
setIDTokenClaims(idToken, []byte(`{"sid":"id_token_sid"}`))
|
|
verifier = mockOIDCVerifier{
|
|
token: idToken,
|
|
}
|
|
err = token.refresh(&config, &verifier)
|
|
assert.NoError(t, err)
|
|
require.Len(t, oidcMgr.tokens, 1)
|
|
oidcMgr.removeToken(token.Cookie)
|
|
require.Len(t, oidcMgr.tokens, 0)
|
|
}
|
|
|
|
func TestValidateOIDCToken(t *testing.T) {
|
|
oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
|
|
require.True(t, ok)
|
|
server := getTestOIDCServer()
|
|
err := server.binding.OIDC.initialize()
|
|
assert.NoError(t, err)
|
|
server.initializeRouter()
|
|
|
|
rr := httptest.NewRecorder()
|
|
r, err := http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
|
|
assert.NoError(t, err)
|
|
_, err = server.validateOIDCToken(rr, r, false)
|
|
assert.ErrorIs(t, err, errInvalidToken)
|
|
// expired token and refresh error
|
|
server.binding.OIDC.oauth2Config = &mockOAuth2Config{
|
|
tokenSource: &mockTokenSource{
|
|
err: common.ErrGenericFailure,
|
|
},
|
|
}
|
|
token := oidcToken{
|
|
Cookie: xid.New().String(),
|
|
AccessToken: xid.New().String(),
|
|
ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-2 * time.Minute)),
|
|
}
|
|
oidcMgr.addToken(token)
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
|
|
assert.NoError(t, err)
|
|
r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie))
|
|
_, err = server.validateOIDCToken(rr, r, false)
|
|
assert.ErrorIs(t, err, errInvalidToken)
|
|
oidcMgr.removeToken(token.Cookie)
|
|
assert.Len(t, oidcMgr.tokens, 0)
|
|
|
|
server.tokenAuth = jwtauth.New("PS256", util.GenerateRandomBytes(32), nil)
|
|
token = oidcToken{
|
|
Cookie: xid.New().String(),
|
|
AccessToken: xid.New().String(),
|
|
}
|
|
oidcMgr.addToken(token)
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
|
|
assert.NoError(t, err)
|
|
r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie))
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
|
|
oidcMgr.removeToken(token.Cookie)
|
|
assert.Len(t, oidcMgr.tokens, 0)
|
|
|
|
token = oidcToken{
|
|
Cookie: xid.New().String(),
|
|
AccessToken: xid.New().String(),
|
|
Role: "admin",
|
|
}
|
|
oidcMgr.addToken(token)
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webLogoutPath, nil)
|
|
assert.NoError(t, err)
|
|
r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie))
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
|
|
oidcMgr.removeToken(token.Cookie)
|
|
assert.Len(t, oidcMgr.tokens, 0)
|
|
}
|
|
|
|
func TestSkipOIDCAuth(t *testing.T) {
|
|
server := getTestOIDCServer()
|
|
err := server.binding.OIDC.initialize()
|
|
assert.NoError(t, err)
|
|
server.initializeRouter()
|
|
jwtTokenClaims := jwtTokenClaims{
|
|
Username: "user",
|
|
}
|
|
_, tokenString, err := jwtTokenClaims.createToken(server.tokenAuth, tokenAudienceWebClient, "")
|
|
assert.NoError(t, err)
|
|
rr := httptest.NewRecorder()
|
|
r, err := http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
|
|
assert.NoError(t, err)
|
|
r.Header.Set("Cookie", fmt.Sprintf("%v=%v", jwtCookieKey, tokenString))
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
|
|
}
|
|
|
|
func TestOIDCLogoutErrors(t *testing.T) {
|
|
server := getTestOIDCServer()
|
|
assert.Empty(t, server.binding.OIDC.providerLogoutURL)
|
|
server.logoutFromOIDCOP("")
|
|
server.binding.OIDC.providerLogoutURL = "http://foo\x7f.com/"
|
|
server.doOIDCFromLogout("")
|
|
server.binding.OIDC.providerLogoutURL = "http://127.0.0.1:11234"
|
|
server.doOIDCFromLogout("")
|
|
}
|
|
|
|
func TestOIDCToken(t *testing.T) {
|
|
admin := dataprovider.Admin{
|
|
Username: "test_oidc_admin",
|
|
Password: "p",
|
|
Permissions: []string{dataprovider.PermAdminAny},
|
|
Status: 0,
|
|
}
|
|
err := dataprovider.AddAdmin(&admin, "", "")
|
|
assert.NoError(t, err)
|
|
|
|
token := oidcToken{
|
|
Username: admin.Username,
|
|
Role: "admin",
|
|
}
|
|
req, err := http.NewRequest(http.MethodGet, webUsersPath, nil)
|
|
assert.NoError(t, err)
|
|
err = token.getUser(req)
|
|
if assert.Error(t, err) {
|
|
assert.Contains(t, err.Error(), "is disabled")
|
|
}
|
|
err = dataprovider.DeleteAdmin(admin.Username, "", "")
|
|
assert.NoError(t, err)
|
|
|
|
username := "test_oidc_user"
|
|
token.Username = username
|
|
token.Role = ""
|
|
err = token.getUser(req)
|
|
if assert.Error(t, err) {
|
|
_, ok := err.(*util.RecordNotFoundError)
|
|
assert.True(t, ok)
|
|
}
|
|
|
|
user := dataprovider.User{
|
|
BaseUser: sdk.BaseUser{
|
|
Username: username,
|
|
Password: "p",
|
|
HomeDir: filepath.Join(os.TempDir(), username),
|
|
Status: 0,
|
|
Permissions: map[string][]string{
|
|
"/": {dataprovider.PermAny},
|
|
},
|
|
},
|
|
Filters: dataprovider.UserFilters{
|
|
BaseUserFilters: sdk.BaseUserFilters{
|
|
DeniedProtocols: []string{common.ProtocolHTTP},
|
|
},
|
|
},
|
|
}
|
|
err = dataprovider.AddUser(&user, "", "")
|
|
assert.NoError(t, err)
|
|
err = token.getUser(req)
|
|
if assert.Error(t, err) {
|
|
assert.Contains(t, err.Error(), "is disabled")
|
|
}
|
|
user, err = dataprovider.UserExists(username)
|
|
assert.NoError(t, err)
|
|
user.Status = 1
|
|
user.Password = "np"
|
|
err = dataprovider.UpdateUser(&user, "", "")
|
|
assert.NoError(t, err)
|
|
|
|
err = token.getUser(req)
|
|
if assert.Error(t, err) {
|
|
assert.Contains(t, err.Error(), "protocol HTTP is not allowed")
|
|
}
|
|
|
|
user.Filters.DeniedProtocols = nil
|
|
user.FsConfig.Provider = sdk.SFTPFilesystemProvider
|
|
user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{
|
|
BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
|
|
Endpoint: "127.0.0.1:8022",
|
|
Username: username,
|
|
},
|
|
Password: kms.NewPlainSecret("np"),
|
|
}
|
|
err = dataprovider.UpdateUser(&user, "", "")
|
|
assert.NoError(t, err)
|
|
err = token.getUser(req)
|
|
if assert.Error(t, err) {
|
|
assert.Contains(t, err.Error(), "SFTP loop")
|
|
}
|
|
|
|
common.Config.PostConnectHook = fmt.Sprintf("http://%v/404", oidcMockAddr)
|
|
|
|
err = token.getUser(req)
|
|
if assert.Error(t, err) {
|
|
assert.Contains(t, err.Error(), "access denied by post connect hook")
|
|
}
|
|
|
|
common.Config.PostConnectHook = ""
|
|
|
|
err = os.RemoveAll(user.GetHomeDir())
|
|
assert.NoError(t, err)
|
|
err = dataprovider.DeleteUser(username, "", "")
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
func TestMemoryOIDCManager(t *testing.T) {
|
|
oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
|
|
require.True(t, ok)
|
|
require.Len(t, oidcMgr.pendingAuths, 0)
|
|
authReq := newOIDCPendingAuth(tokenAudienceWebAdmin)
|
|
oidcMgr.addPendingAuth(authReq)
|
|
require.Len(t, oidcMgr.pendingAuths, 1)
|
|
_, err := oidcMgr.getPendingAuth(authReq.State)
|
|
assert.NoError(t, err)
|
|
oidcMgr.removePendingAuth(authReq.State)
|
|
require.Len(t, oidcMgr.pendingAuths, 0)
|
|
authReq.IssuedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-61 * time.Second))
|
|
oidcMgr.addPendingAuth(authReq)
|
|
require.Len(t, oidcMgr.pendingAuths, 1)
|
|
_, err = oidcMgr.getPendingAuth(authReq.State)
|
|
if assert.Error(t, err) {
|
|
assert.Contains(t, err.Error(), "too old")
|
|
}
|
|
oidcMgr.cleanup()
|
|
require.Len(t, oidcMgr.pendingAuths, 0)
|
|
|
|
token := oidcToken{
|
|
AccessToken: xid.New().String(),
|
|
Nonce: xid.New().String(),
|
|
SessionID: xid.New().String(),
|
|
Cookie: xid.New().String(),
|
|
Username: xid.New().String(),
|
|
Role: "admin",
|
|
Permissions: []string{dataprovider.PermAdminAny},
|
|
}
|
|
require.Len(t, oidcMgr.tokens, 0)
|
|
oidcMgr.addToken(token)
|
|
require.Len(t, oidcMgr.tokens, 1)
|
|
_, err = oidcMgr.getToken(xid.New().String())
|
|
assert.Error(t, err)
|
|
storedToken, err := oidcMgr.getToken(token.Cookie)
|
|
assert.NoError(t, err)
|
|
token.UsedAt = 0 // ensure we don't modify the stored token
|
|
assert.Greater(t, storedToken.UsedAt, int64(0))
|
|
token.UsedAt = storedToken.UsedAt
|
|
assert.Equal(t, token, storedToken)
|
|
// the usage will not be updated, it is recent
|
|
oidcMgr.updateTokenUsage(storedToken)
|
|
storedToken, err = oidcMgr.getToken(token.Cookie)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, token, storedToken)
|
|
usedAt := util.GetTimeAsMsSinceEpoch(time.Now().Add(-5 * time.Minute))
|
|
storedToken.UsedAt = usedAt
|
|
oidcMgr.tokens[token.Cookie] = storedToken
|
|
storedToken, err = oidcMgr.getToken(token.Cookie)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, usedAt, storedToken.UsedAt)
|
|
token.UsedAt = storedToken.UsedAt
|
|
assert.Equal(t, token, storedToken)
|
|
oidcMgr.updateTokenUsage(storedToken)
|
|
storedToken, err = oidcMgr.getToken(token.Cookie)
|
|
assert.NoError(t, err)
|
|
assert.Greater(t, storedToken.UsedAt, usedAt)
|
|
token.UsedAt = storedToken.UsedAt
|
|
assert.Equal(t, token, storedToken)
|
|
storedToken.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now()) - tokenDeleteInterval - 1
|
|
oidcMgr.tokens[token.Cookie] = storedToken
|
|
storedToken, err = oidcMgr.getToken(token.Cookie)
|
|
if assert.Error(t, err) {
|
|
assert.Contains(t, err.Error(), "token is too old")
|
|
}
|
|
oidcMgr.removeToken(xid.New().String())
|
|
require.Len(t, oidcMgr.tokens, 1)
|
|
oidcMgr.removeToken(token.Cookie)
|
|
require.Len(t, oidcMgr.tokens, 0)
|
|
oidcMgr.addToken(token)
|
|
usedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-6 * time.Hour))
|
|
token.UsedAt = usedAt
|
|
oidcMgr.tokens[token.Cookie] = token
|
|
newToken := oidcToken{
|
|
Cookie: xid.New().String(),
|
|
}
|
|
oidcMgr.addToken(newToken)
|
|
oidcMgr.cleanup()
|
|
require.Len(t, oidcMgr.tokens, 1)
|
|
_, err = oidcMgr.getToken(token.Cookie)
|
|
assert.Error(t, err)
|
|
_, err = oidcMgr.getToken(newToken.Cookie)
|
|
assert.NoError(t, err)
|
|
oidcMgr.removeToken(newToken.Cookie)
|
|
require.Len(t, oidcMgr.tokens, 0)
|
|
}
|
|
|
|
func TestOIDCPreLoginHook(t *testing.T) {
|
|
if runtime.GOOS == osWindows {
|
|
t.Skip("this test is not available on Windows")
|
|
}
|
|
oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
|
|
require.True(t, ok)
|
|
username := "test_oidc_user_prelogin"
|
|
u := dataprovider.User{
|
|
BaseUser: sdk.BaseUser{
|
|
Username: username,
|
|
HomeDir: filepath.Join(os.TempDir(), username),
|
|
Status: 1,
|
|
Permissions: map[string][]string{
|
|
"/": {dataprovider.PermAny},
|
|
},
|
|
},
|
|
}
|
|
preLoginPath := filepath.Join(os.TempDir(), "prelogin.sh")
|
|
providerConf := dataprovider.GetProviderConfig()
|
|
err := dataprovider.Close()
|
|
assert.NoError(t, err)
|
|
err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm)
|
|
assert.NoError(t, err)
|
|
newProviderConf := providerConf
|
|
newProviderConf.PreLoginHook = preLoginPath
|
|
err = dataprovider.Initialize(newProviderConf, "..", true)
|
|
assert.NoError(t, err)
|
|
server := getTestOIDCServer()
|
|
server.binding.OIDC.CustomFields = []string{"field1", "field2"}
|
|
err = server.binding.OIDC.initialize()
|
|
assert.NoError(t, err)
|
|
server.initializeRouter()
|
|
|
|
_, err = dataprovider.UserExists(username)
|
|
_, ok = err.(*util.RecordNotFoundError)
|
|
assert.True(t, ok)
|
|
// now login with OIDC
|
|
authReq := newOIDCPendingAuth(tokenAudienceWebClient)
|
|
oidcMgr.addPendingAuth(authReq)
|
|
token := &oauth2.Token{
|
|
AccessToken: "1234",
|
|
Expiry: time.Now().Add(5 * time.Minute),
|
|
}
|
|
token = token.WithExtra(map[string]any{
|
|
"id_token": "id_token_val",
|
|
})
|
|
server.binding.OIDC.oauth2Config = &mockOAuth2Config{
|
|
tokenSource: &mockTokenSource{},
|
|
authCodeURL: webOIDCRedirectPath,
|
|
token: token,
|
|
}
|
|
idToken := &oidc.IDToken{
|
|
Nonce: authReq.Nonce,
|
|
Expiry: time.Now().Add(5 * time.Minute),
|
|
}
|
|
setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+username+`"}`))
|
|
server.binding.OIDC.verifier = &mockOIDCVerifier{
|
|
err: nil,
|
|
token: idToken,
|
|
}
|
|
rr := httptest.NewRecorder()
|
|
r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
|
|
assert.NoError(t, err)
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
|
|
_, err = dataprovider.UserExists(username)
|
|
assert.NoError(t, err)
|
|
|
|
err = dataprovider.DeleteUser(username, "", "")
|
|
assert.NoError(t, err)
|
|
|
|
err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, true), os.ModePerm)
|
|
assert.NoError(t, err)
|
|
|
|
authReq = newOIDCPendingAuth(tokenAudienceWebClient)
|
|
oidcMgr.addPendingAuth(authReq)
|
|
idToken = &oidc.IDToken{
|
|
Nonce: authReq.Nonce,
|
|
Expiry: time.Now().Add(5 * time.Minute),
|
|
}
|
|
setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+username+`","field1":"value1","field2":"value2","field3":"value3"}`))
|
|
server.binding.OIDC.verifier = &mockOIDCVerifier{
|
|
err: nil,
|
|
token: idToken,
|
|
}
|
|
rr = httptest.NewRecorder()
|
|
r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
|
|
assert.NoError(t, err)
|
|
server.router.ServeHTTP(rr, r)
|
|
assert.Equal(t, http.StatusFound, rr.Code)
|
|
assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
|
|
_, err = dataprovider.UserExists(username)
|
|
_, ok = err.(*util.RecordNotFoundError)
|
|
assert.True(t, ok)
|
|
if assert.Len(t, oidcMgr.tokens, 1) {
|
|
for k := range oidcMgr.tokens {
|
|
oidcMgr.removeToken(k)
|
|
}
|
|
}
|
|
require.Len(t, oidcMgr.pendingAuths, 0)
|
|
require.Len(t, oidcMgr.tokens, 0)
|
|
|
|
err = dataprovider.Close()
|
|
assert.NoError(t, err)
|
|
err = dataprovider.Initialize(providerConf, "..", true)
|
|
assert.NoError(t, err)
|
|
err = os.Remove(preLoginPath)
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
func TestOIDCIsAdmin(t *testing.T) {
|
|
type test struct {
|
|
input any
|
|
want bool
|
|
}
|
|
|
|
emptySlice := make([]any, 0)
|
|
|
|
tests := []test{
|
|
{input: "admin", want: true},
|
|
{input: append(emptySlice, "admin"), want: true},
|
|
{input: append(emptySlice, "user", "admin"), want: true},
|
|
{input: "user", want: false},
|
|
{input: emptySlice, want: false},
|
|
{input: append(emptySlice, 1), want: false},
|
|
{input: 1, want: false},
|
|
{input: nil, want: false},
|
|
}
|
|
for _, tc := range tests {
|
|
token := oidcToken{
|
|
Role: tc.input,
|
|
}
|
|
assert.Equal(t, tc.want, token.isAdmin(), "%v should return %t", tc.input, tc.want)
|
|
}
|
|
}
|
|
|
|
func TestDbOIDCManager(t *testing.T) {
|
|
if !isSharedProviderSupported() {
|
|
t.Skip("this test it is not available with this provider")
|
|
}
|
|
mgr := newOIDCManager(1)
|
|
pendingAuth := newOIDCPendingAuth(tokenAudienceWebAdmin)
|
|
mgr.addPendingAuth(pendingAuth)
|
|
authReq, err := mgr.getPendingAuth(pendingAuth.State)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, pendingAuth, authReq)
|
|
pendingAuth.IssuedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour))
|
|
mgr.addPendingAuth(pendingAuth)
|
|
_, err = mgr.getPendingAuth(pendingAuth.State)
|
|
if assert.Error(t, err) {
|
|
assert.Contains(t, err.Error(), "auth request is too old")
|
|
}
|
|
mgr.removePendingAuth(pendingAuth.State)
|
|
_, err = mgr.getPendingAuth(pendingAuth.State)
|
|
if assert.Error(t, err) {
|
|
assert.Contains(t, err.Error(), "unable to get the auth request for the specified state")
|
|
}
|
|
mgr.addPendingAuth(pendingAuth)
|
|
_, err = mgr.getPendingAuth(pendingAuth.State)
|
|
if assert.Error(t, err) {
|
|
assert.Contains(t, err.Error(), "auth request is too old")
|
|
}
|
|
mgr.cleanup()
|
|
_, err = mgr.getPendingAuth(pendingAuth.State)
|
|
if assert.Error(t, err) {
|
|
assert.Contains(t, err.Error(), "unable to get the auth request for the specified state")
|
|
}
|
|
|
|
token := oidcToken{
|
|
Cookie: xid.New().String(),
|
|
AccessToken: xid.New().String(),
|
|
TokenType: "Bearer",
|
|
RefreshToken: xid.New().String(),
|
|
ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-2 * time.Minute)),
|
|
SessionID: xid.New().String(),
|
|
IDToken: xid.New().String(),
|
|
Nonce: xid.New().String(),
|
|
Username: xid.New().String(),
|
|
Permissions: []string{dataprovider.PermAdminAny},
|
|
Role: "admin",
|
|
}
|
|
mgr.addToken(token)
|
|
tokenGet, err := mgr.getToken(token.Cookie)
|
|
assert.NoError(t, err)
|
|
assert.Greater(t, tokenGet.UsedAt, int64(0))
|
|
token.UsedAt = tokenGet.UsedAt
|
|
assert.Equal(t, token, tokenGet)
|
|
time.Sleep(100 * time.Millisecond)
|
|
mgr.updateTokenUsage(token)
|
|
// no change
|
|
tokenGet, err = mgr.getToken(token.Cookie)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, token.UsedAt, tokenGet.UsedAt)
|
|
tokenGet.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour))
|
|
tokenGet.RefreshToken = xid.New().String()
|
|
mgr.updateTokenUsage(tokenGet)
|
|
tokenGet, err = mgr.getToken(token.Cookie)
|
|
assert.NoError(t, err)
|
|
assert.NotEmpty(t, tokenGet.RefreshToken)
|
|
assert.NotEqual(t, token.RefreshToken, tokenGet.RefreshToken)
|
|
assert.Greater(t, tokenGet.UsedAt, token.UsedAt)
|
|
mgr.removeToken(token.Cookie)
|
|
tokenGet, err = mgr.getToken(token.Cookie)
|
|
if assert.Error(t, err) {
|
|
assert.Contains(t, err.Error(), "unable to get the token for the specified session")
|
|
}
|
|
// add an expired token
|
|
token.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour))
|
|
session := dataprovider.Session{
|
|
Key: token.Cookie,
|
|
Data: token,
|
|
Type: dataprovider.SessionTypeOIDCToken,
|
|
Timestamp: token.UsedAt + tokenDeleteInterval,
|
|
}
|
|
err = dataprovider.AddSharedSession(session)
|
|
assert.NoError(t, err)
|
|
_, err = mgr.getToken(token.Cookie)
|
|
if assert.Error(t, err) {
|
|
assert.Contains(t, err.Error(), "token is too old")
|
|
}
|
|
mgr.cleanup()
|
|
_, err = mgr.getToken(token.Cookie)
|
|
if assert.Error(t, err) {
|
|
assert.Contains(t, err.Error(), "unable to get the token for the specified session")
|
|
}
|
|
// adding a session without a key should fail
|
|
session.Key = ""
|
|
err = dataprovider.AddSharedSession(session)
|
|
if assert.Error(t, err) {
|
|
assert.Contains(t, err.Error(), "unable to save a session with an empty key")
|
|
}
|
|
session.Key = xid.New().String()
|
|
session.Type = 1000
|
|
err = dataprovider.AddSharedSession(session)
|
|
if assert.Error(t, err) {
|
|
assert.Contains(t, err.Error(), "invalid session type")
|
|
}
|
|
|
|
dbMgr, ok := mgr.(*dbOIDCManager)
|
|
if assert.True(t, ok) {
|
|
_, err = dbMgr.decodePendingAuthData(2)
|
|
assert.Error(t, err)
|
|
_, err = dbMgr.decodeTokenData(true)
|
|
assert.Error(t, err)
|
|
}
|
|
}
|
|
|
|
func getTestOIDCServer() *httpdServer {
|
|
return &httpdServer{
|
|
binding: Binding{
|
|
OIDC: OIDC{
|
|
ClientID: "sftpgo-client",
|
|
ClientSecret: "jRsmE0SWnuZjP7djBqNq0mrf8QN77j2c",
|
|
ConfigURL: fmt.Sprintf("http://%v/auth/realms/sftpgo", oidcMockAddr),
|
|
RedirectBaseURL: "http://127.0.0.1:8081/",
|
|
UsernameField: "preferred_username",
|
|
RoleField: "sftpgo_role",
|
|
CustomFields: nil,
|
|
},
|
|
},
|
|
enableWebAdmin: true,
|
|
enableWebClient: true,
|
|
}
|
|
}
|
|
|
|
func getPreLoginScriptContent(user dataprovider.User, nonJSONResponse bool) []byte {
|
|
content := []byte("#!/bin/sh\n\n")
|
|
if nonJSONResponse {
|
|
content = append(content, []byte("echo 'text response'\n")...)
|
|
return content
|
|
}
|
|
if len(user.Username) > 0 {
|
|
u, _ := json.Marshal(user)
|
|
content = append(content, []byte(fmt.Sprintf("echo '%v'\n", string(u)))...)
|
|
}
|
|
return content
|
|
}
|