sftpgo/internal/httpd/oauth2_test.go

136 lines
4.5 KiB
Go
Raw Normal View History

// Copyright (C) 2019-2023 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package httpd
import (
"encoding/json"
"testing"
"time"
"github.com/rs/xid"
sdkkms "github.com/sftpgo/sdk/kms"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/kms"
"github.com/drakkan/sftpgo/v2/internal/util"
)
func TestMemoryOAuth2Manager(t *testing.T) {
mgr := newOAuth2Manager(0)
m, ok := mgr.(*memoryOAuth2Manager)
require.True(t, ok)
require.Len(t, m.pendingAuths, 0)
_, err := m.getPendingAuth(xid.New().String())
require.Error(t, err)
assert.Contains(t, err.Error(), "no auth request found")
auth := newOAuth2PendingAuth(1, "https://...", "cid", kms.NewPlainSecret("mysecret"))
m.addPendingAuth(auth)
require.Len(t, m.pendingAuths, 1)
a, err := m.getPendingAuth(auth.State)
assert.NoError(t, err)
assert.Equal(t, auth.State, a.State)
assert.Equal(t, sdkkms.SecretStatusPlain, a.ClientSecret.GetStatus())
m.removePendingAuth(auth.State)
_, err = m.getPendingAuth(auth.State)
require.Error(t, err)
assert.Contains(t, err.Error(), "no auth request found")
require.Len(t, m.pendingAuths, 0)
state := xid.New().String()
auth = oauth2PendingAuth{
State: state,
Provider: 1,
IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
}
m.addPendingAuth(auth)
auth = oauth2PendingAuth{
State: xid.New().String(),
Provider: 1,
IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)),
}
m.addPendingAuth(auth)
require.Len(t, m.pendingAuths, 2)
_, err = m.getPendingAuth(auth.State)
require.Error(t, err)
assert.Contains(t, err.Error(), "auth request is too old")
m.cleanup()
require.Len(t, m.pendingAuths, 1)
m.removePendingAuth(state)
require.Len(t, m.pendingAuths, 0)
}
func TestDbOAuth2Manager(t *testing.T) {
if !isSharedProviderSupported() {
t.Skip("this test it is not available with this provider")
}
mgr := newOAuth2Manager(1)
m, ok := mgr.(*dbOAuth2Manager)
require.True(t, ok)
_, err := m.getPendingAuth(xid.New().String())
require.Error(t, err)
auth := newOAuth2PendingAuth(1, "https://...", "client_id", kms.NewPlainSecret("my db secret"))
m.addPendingAuth(auth)
a, err := m.getPendingAuth(auth.State)
assert.NoError(t, err)
assert.Equal(t, sdkkms.SecretStatusPlain, a.ClientSecret.GetStatus())
session, err := dataprovider.GetSharedSession(auth.State)
assert.NoError(t, err)
authReq := oauth2PendingAuth{}
err = json.Unmarshal(session.Data.([]byte), &authReq)
assert.NoError(t, err)
assert.Equal(t, sdkkms.SecretStatusSecretBox, authReq.ClientSecret.GetStatus())
m.cleanup()
_, err = m.getPendingAuth(auth.State)
assert.NoError(t, err)
m.removePendingAuth(auth.State)
_, err = m.getPendingAuth(auth.State)
assert.Error(t, err)
auth = oauth2PendingAuth{
State: xid.New().String(),
Provider: 1,
IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)),
ClientSecret: kms.NewPlainSecret("db secret"),
}
m.addPendingAuth(auth)
_, err = m.getPendingAuth(auth.State)
assert.Error(t, err)
_, err = dataprovider.GetSharedSession(auth.State)
assert.NoError(t, err)
m.cleanup()
_, err = dataprovider.GetSharedSession(auth.State)
assert.Error(t, err)
_, err = m.decodePendingAuthData("not a byte array")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid auth request data")
_, err = m.decodePendingAuthData([]byte("{not a json"))
require.Error(t, err)
// adding a request with a non plain secret will fail
auth = oauth2PendingAuth{
State: xid.New().String(),
Provider: 1,
IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)),
ClientSecret: kms.NewPlainSecret("db secret"),
}
auth.ClientSecret.SetStatus(sdkkms.SecretStatusSecretBox)
m.addPendingAuth(auth)
_, err = dataprovider.GetSharedSession(auth.State)
assert.Error(t, err)
asJSON, err := json.Marshal(auth)
assert.NoError(t, err)
_, err = m.decodePendingAuthData(asJSON)
assert.Error(t, err)
}