mirror of
https://github.com/drakkan/sftpgo.git
synced 2024-11-25 17:10:28 +00:00
kms: remember if a secret was saved without a master key
So we will be able to decrypt secret stored without a master key if a such key is provided later
This commit is contained in:
parent
940836b25b
commit
87b51a6fd5
6 changed files with 144 additions and 5 deletions
|
@ -1390,6 +1390,114 @@ func TestSecretObject(t *testing.T) {
|
||||||
require.Empty(t, s.GetKey())
|
require.Empty(t, s.GetKey())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSecretObjectCompatibility(t *testing.T) {
|
||||||
|
// this is manually tested against vault too
|
||||||
|
testPayload := "test payload"
|
||||||
|
s := kms.NewPlainSecret(testPayload)
|
||||||
|
require.True(t, s.IsValid())
|
||||||
|
err := s.Encrypt()
|
||||||
|
require.NoError(t, err)
|
||||||
|
localAsJSON, err := json.Marshal(s)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
for _, provider := range []string{kms.SecretStatusRedacted} {
|
||||||
|
kmsConfig := config.GetKMSConfig()
|
||||||
|
assert.Empty(t, kmsConfig.Secrets.MasterKeyPath)
|
||||||
|
if provider == kms.SecretStatusVaultTransit {
|
||||||
|
os.Setenv("VAULT_SERVER_URL", "http://127.0.0.1:8200")
|
||||||
|
os.Setenv("VAULT_SERVER_TOKEN", "s.9lYGq83MbgG5KR5kfebXVyhJ")
|
||||||
|
kmsConfig.Secrets.URL = "hashivault://mykey"
|
||||||
|
}
|
||||||
|
err := kmsConfig.Initialize()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// encrypt without a master key
|
||||||
|
secret := kms.NewPlainSecret(testPayload)
|
||||||
|
secret.SetAdditionalData("add data")
|
||||||
|
err = secret.Encrypt()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, secret.GetMode())
|
||||||
|
secretClone := secret.Clone()
|
||||||
|
err = secretClone.Decrypt()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, testPayload, secretClone.GetPayload())
|
||||||
|
if provider == kms.SecretStatusVaultTransit {
|
||||||
|
// decrypt the local secret now that the provider is vault
|
||||||
|
secretLocal := kms.NewEmptySecret()
|
||||||
|
err = json.Unmarshal(localAsJSON, secretLocal)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, kms.SecretStatusSecretBox, secretLocal.GetStatus())
|
||||||
|
assert.Equal(t, 0, secretLocal.GetMode())
|
||||||
|
err = secretLocal.Decrypt()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, testPayload, secretLocal.GetPayload())
|
||||||
|
assert.Equal(t, kms.SecretStatusPlain, secretLocal.GetStatus())
|
||||||
|
err = secretLocal.Encrypt()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, kms.SecretStatusSecretBox, secretLocal.GetStatus())
|
||||||
|
assert.Equal(t, 0, secretLocal.GetMode())
|
||||||
|
}
|
||||||
|
|
||||||
|
asJSON, err := json.Marshal(secret)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
masterKeyPath := filepath.Join(os.TempDir(), "mkey")
|
||||||
|
err = ioutil.WriteFile(masterKeyPath, []byte("test key"), os.ModePerm)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
config := kms.Configuration{
|
||||||
|
Secrets: kms.Secrets{
|
||||||
|
MasterKeyPath: masterKeyPath,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if provider == kms.SecretStatusVaultTransit {
|
||||||
|
config.Secrets.URL = "hashivault://mykey"
|
||||||
|
}
|
||||||
|
err = config.Initialize()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// now build the secret from JSON
|
||||||
|
secret = kms.NewEmptySecret()
|
||||||
|
err = json.Unmarshal(asJSON, secret)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, secret.GetMode())
|
||||||
|
err = secret.Decrypt()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, testPayload, secret.GetPayload())
|
||||||
|
err = secret.Encrypt()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, secret.GetMode())
|
||||||
|
err = secret.Decrypt()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, testPayload, secret.GetPayload())
|
||||||
|
if provider == kms.SecretStatusVaultTransit {
|
||||||
|
// decrypt the local secret encryped without a master key now that
|
||||||
|
// the provider is vault and a master key is set.
|
||||||
|
// The provider will not change, the master key will be used
|
||||||
|
secretLocal := kms.NewEmptySecret()
|
||||||
|
err = json.Unmarshal(localAsJSON, secretLocal)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, kms.SecretStatusSecretBox, secretLocal.GetStatus())
|
||||||
|
assert.Equal(t, 0, secretLocal.GetMode())
|
||||||
|
err = secretLocal.Decrypt()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, testPayload, secretLocal.GetPayload())
|
||||||
|
assert.Equal(t, kms.SecretStatusPlain, secretLocal.GetStatus())
|
||||||
|
err = secretLocal.Encrypt()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, kms.SecretStatusSecretBox, secretLocal.GetStatus())
|
||||||
|
assert.Equal(t, 1, secretLocal.GetMode())
|
||||||
|
}
|
||||||
|
|
||||||
|
err = kmsConfig.Initialize()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
err = os.Remove(masterKeyPath)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
if provider == kms.SecretStatusVaultTransit {
|
||||||
|
os.Unsetenv("VAULT_SERVER_URL")
|
||||||
|
os.Unsetenv("VAULT_SERVER_TOKEN")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestUpdateUserNoCredentials(t *testing.T) {
|
func TestUpdateUserNoCredentials(t *testing.T) {
|
||||||
user, _, err := httpd.AddUser(getTestUser(), http.StatusOK)
|
user, _, err := httpd.AddUser(getTestUser(), http.StatusOK)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
|
@ -2,7 +2,7 @@ openapi: 3.0.3
|
||||||
info:
|
info:
|
||||||
title: SFTPGo
|
title: SFTPGo
|
||||||
description: 'SFTPGo REST API'
|
description: 'SFTPGo REST API'
|
||||||
version: 2.1.2
|
version: 2.1.3
|
||||||
|
|
||||||
servers:
|
servers:
|
||||||
- url: /api/v1
|
- url: /api/v1
|
||||||
|
@ -977,6 +977,9 @@ components:
|
||||||
type: string
|
type: string
|
||||||
additional_data:
|
additional_data:
|
||||||
type: string
|
type: string
|
||||||
|
mode:
|
||||||
|
type: integer
|
||||||
|
description: 1 means encrypted using a master key
|
||||||
S3Config:
|
S3Config:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|
|
@ -30,6 +30,7 @@ func (s *baseGCloudSecret) Encrypt() error {
|
||||||
|
|
||||||
payload := s.Payload
|
payload := s.Payload
|
||||||
key := ""
|
key := ""
|
||||||
|
mode := 0
|
||||||
if s.masterKey != "" {
|
if s.masterKey != "" {
|
||||||
localSecret := newLocalSecret(s.baseSecret, s.masterKey)
|
localSecret := newLocalSecret(s.baseSecret, s.masterKey)
|
||||||
err := localSecret.Encrypt()
|
err := localSecret.Encrypt()
|
||||||
|
@ -38,6 +39,7 @@ func (s *baseGCloudSecret) Encrypt() error {
|
||||||
}
|
}
|
||||||
payload = localSecret.GetPayload()
|
payload = localSecret.GetPayload()
|
||||||
key = localSecret.GetKey()
|
key = localSecret.GetKey()
|
||||||
|
mode = localSecret.GetMode()
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(defaultTimeout))
|
ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(defaultTimeout))
|
||||||
|
@ -55,6 +57,7 @@ func (s *baseGCloudSecret) Encrypt() error {
|
||||||
}
|
}
|
||||||
s.Payload = base64.StdEncoding.EncodeToString(ciphertext)
|
s.Payload = base64.StdEncoding.EncodeToString(ciphertext)
|
||||||
s.Key = key
|
s.Key = key
|
||||||
|
s.Mode = mode
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,6 +86,7 @@ func (s *baseGCloudSecret) Decrypt() error {
|
||||||
Payload: string(plaintext),
|
Payload: string(plaintext),
|
||||||
Key: s.Key,
|
Key: s.Key,
|
||||||
AdditionalData: s.AdditionalData,
|
AdditionalData: s.AdditionalData,
|
||||||
|
Mode: s.Mode,
|
||||||
}
|
}
|
||||||
localSecret := newLocalSecret(baseSecret, s.masterKey)
|
localSecret := newLocalSecret(baseSecret, s.masterKey)
|
||||||
err = localSecret.Decrypt()
|
err = localSecret.Decrypt()
|
||||||
|
@ -95,5 +99,6 @@ func (s *baseGCloudSecret) Decrypt() error {
|
||||||
s.Payload = payload
|
s.Payload = payload
|
||||||
s.Key = ""
|
s.Key = ""
|
||||||
s.AdditionalData = ""
|
s.AdditionalData = ""
|
||||||
|
s.Mode = 0
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,8 @@ type baseSecret struct {
|
||||||
Payload string `json:"payload,omitempty"`
|
Payload string `json:"payload,omitempty"`
|
||||||
Key string `json:"key,omitempty"`
|
Key string `json:"key,omitempty"`
|
||||||
AdditionalData string `json:"additional_data,omitempty"`
|
AdditionalData string `json:"additional_data,omitempty"`
|
||||||
|
// 1 means encrypted using a master key
|
||||||
|
Mode int `json:"mode,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *baseSecret) GetStatus() SecretStatus {
|
func (s *baseSecret) GetStatus() SecretStatus {
|
||||||
|
@ -20,6 +22,10 @@ func (s *baseSecret) GetKey() string {
|
||||||
return s.Key
|
return s.Key
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *baseSecret) GetMode() int {
|
||||||
|
return s.Mode
|
||||||
|
}
|
||||||
|
|
||||||
func (s *baseSecret) GetAdditionalData() string {
|
func (s *baseSecret) GetAdditionalData() string {
|
||||||
return s.AdditionalData
|
return s.AdditionalData
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@ type SecretProvider interface {
|
||||||
GetPayload() string
|
GetPayload() string
|
||||||
GetKey() string
|
GetKey() string
|
||||||
GetAdditionalData() string
|
GetAdditionalData() string
|
||||||
|
GetMode() int
|
||||||
SetKey(string)
|
SetKey(string)
|
||||||
SetAdditionalData(string)
|
SetAdditionalData(string)
|
||||||
SetStatus(SecretStatus)
|
SetStatus(SecretStatus)
|
||||||
|
@ -145,6 +146,7 @@ func (s *Secret) MarshalJSON() ([]byte, error) {
|
||||||
Payload: s.provider.GetPayload(),
|
Payload: s.provider.GetPayload(),
|
||||||
Key: s.provider.GetKey(),
|
Key: s.provider.GetKey(),
|
||||||
AdditionalData: s.provider.GetAdditionalData(),
|
AdditionalData: s.provider.GetAdditionalData(),
|
||||||
|
Mode: s.provider.GetMode(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -186,6 +188,7 @@ func (s *Secret) Clone() *Secret {
|
||||||
Payload: s.provider.GetPayload(),
|
Payload: s.provider.GetPayload(),
|
||||||
Key: s.provider.GetKey(),
|
Key: s.provider.GetKey(),
|
||||||
AdditionalData: s.provider.GetAdditionalData(),
|
AdditionalData: s.provider.GetAdditionalData(),
|
||||||
|
Mode: s.provider.GetMode(),
|
||||||
}
|
}
|
||||||
switch s.provider.Name() {
|
switch s.provider.Name() {
|
||||||
case builtinProviderName:
|
case builtinProviderName:
|
||||||
|
@ -249,6 +252,11 @@ func (s *Secret) GetKey() string {
|
||||||
return s.provider.GetKey()
|
return s.provider.GetKey()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetMode returns the secret mode
|
||||||
|
func (s *Secret) GetMode() int {
|
||||||
|
return s.provider.GetMode()
|
||||||
|
}
|
||||||
|
|
||||||
// SetAdditionalData sets the given additional data
|
// SetAdditionalData sets the given additional data
|
||||||
func (s *Secret) SetAdditionalData(value string) {
|
func (s *Secret) SetAdditionalData(value string) {
|
||||||
s.provider.SetAdditionalData(value)
|
s.provider.SetAdditionalData(value)
|
||||||
|
|
17
kms/local.go
17
kms/local.go
|
@ -46,7 +46,7 @@ func (s *localSecret) Encrypt() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
key, err := s.deriveKey(secretKey[:])
|
key, err := s.deriveKey(secretKey[:], false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -60,6 +60,7 @@ func (s *localSecret) Encrypt() error {
|
||||||
s.Key = hex.EncodeToString(secretKey[:])
|
s.Key = hex.EncodeToString(secretKey[:])
|
||||||
s.Payload = base64.StdEncoding.EncodeToString(ciphertext)
|
s.Payload = base64.StdEncoding.EncodeToString(ciphertext)
|
||||||
s.Status = SecretStatusSecretBox
|
s.Status = SecretStatusSecretBox
|
||||||
|
s.Mode = s.getEncryptionMode()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -75,7 +76,7 @@ func (s *localSecret) Decrypt() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
key, err := s.deriveKey(secretKey[:])
|
key, err := s.deriveKey(secretKey[:], true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -90,12 +91,13 @@ func (s *localSecret) Decrypt() error {
|
||||||
s.Payload = string(plaintext)
|
s.Payload = string(plaintext)
|
||||||
s.Key = ""
|
s.Key = ""
|
||||||
s.AdditionalData = ""
|
s.AdditionalData = ""
|
||||||
|
s.Mode = 0
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *localSecret) deriveKey(key []byte) ([32]byte, error) {
|
func (s *localSecret) deriveKey(key []byte, isForDecryption bool) ([32]byte, error) {
|
||||||
var masterKey []byte
|
var masterKey []byte
|
||||||
if s.masterKey == "" {
|
if s.masterKey == "" || (isForDecryption && s.Mode == 0) {
|
||||||
var combined []byte
|
var combined []byte
|
||||||
combined = append(combined, key...)
|
combined = append(combined, key...)
|
||||||
if s.AdditionalData != "" {
|
if s.AdditionalData != "" {
|
||||||
|
@ -118,3 +120,10 @@ func (s *localSecret) deriveKey(key []byte) ([32]byte, error) {
|
||||||
}
|
}
|
||||||
return derivedKey, nil
|
return derivedKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *localSecret) getEncryptionMode() int {
|
||||||
|
if s.masterKey == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue