Browse Source

kms: remember if a secret was saved without a master key

So we will be able to decrypt secret stored without a master key if a
such key is provided later
Nicola Murino 4 năm trước cách đây
mục cha
commit
87b51a6fd5
6 tập tin đã thay đổi với 144 bổ sung5 xóa
  1. 108 0
      httpd/httpd_test.go
  2. 4 1
      httpd/schema/openapi.yaml
  3. 5 0
      kms/basegocloud.go
  4. 6 0
      kms/basesecret.go
  5. 8 0
      kms/kms.go
  6. 13 4
      kms/local.go

+ 108 - 0
httpd/httpd_test.go

@@ -1390,6 +1390,114 @@ func TestSecretObject(t *testing.T) {
 	require.Empty(t, s.GetKey())
 }
 
+func TestSecretObjectCompatibility(t *testing.T) {
+	// this is manually tested against vault too
+	testPayload := "test payload"
+	s := kms.NewPlainSecret(testPayload)
+	require.True(t, s.IsValid())
+	err := s.Encrypt()
+	require.NoError(t, err)
+	localAsJSON, err := json.Marshal(s)
+	assert.NoError(t, err)
+
+	for _, provider := range []string{kms.SecretStatusRedacted} {
+		kmsConfig := config.GetKMSConfig()
+		assert.Empty(t, kmsConfig.Secrets.MasterKeyPath)
+		if provider == kms.SecretStatusVaultTransit {
+			os.Setenv("VAULT_SERVER_URL", "http://127.0.0.1:8200")
+			os.Setenv("VAULT_SERVER_TOKEN", "s.9lYGq83MbgG5KR5kfebXVyhJ")
+			kmsConfig.Secrets.URL = "hashivault://mykey"
+		}
+		err := kmsConfig.Initialize()
+		assert.NoError(t, err)
+		// encrypt without a master key
+		secret := kms.NewPlainSecret(testPayload)
+		secret.SetAdditionalData("add data")
+		err = secret.Encrypt()
+		assert.NoError(t, err)
+		assert.Equal(t, 0, secret.GetMode())
+		secretClone := secret.Clone()
+		err = secretClone.Decrypt()
+		assert.NoError(t, err)
+		assert.Equal(t, testPayload, secretClone.GetPayload())
+		if provider == kms.SecretStatusVaultTransit {
+			// decrypt the local secret now that the provider is vault
+			secretLocal := kms.NewEmptySecret()
+			err = json.Unmarshal(localAsJSON, secretLocal)
+			assert.NoError(t, err)
+			assert.Equal(t, kms.SecretStatusSecretBox, secretLocal.GetStatus())
+			assert.Equal(t, 0, secretLocal.GetMode())
+			err = secretLocal.Decrypt()
+			assert.NoError(t, err)
+			assert.Equal(t, testPayload, secretLocal.GetPayload())
+			assert.Equal(t, kms.SecretStatusPlain, secretLocal.GetStatus())
+			err = secretLocal.Encrypt()
+			assert.NoError(t, err)
+			assert.Equal(t, kms.SecretStatusSecretBox, secretLocal.GetStatus())
+			assert.Equal(t, 0, secretLocal.GetMode())
+		}
+
+		asJSON, err := json.Marshal(secret)
+		assert.NoError(t, err)
+
+		masterKeyPath := filepath.Join(os.TempDir(), "mkey")
+		err = ioutil.WriteFile(masterKeyPath, []byte("test key"), os.ModePerm)
+		assert.NoError(t, err)
+		config := kms.Configuration{
+			Secrets: kms.Secrets{
+				MasterKeyPath: masterKeyPath,
+			},
+		}
+		if provider == kms.SecretStatusVaultTransit {
+			config.Secrets.URL = "hashivault://mykey"
+		}
+		err = config.Initialize()
+		assert.NoError(t, err)
+
+		// now build the secret from JSON
+		secret = kms.NewEmptySecret()
+		err = json.Unmarshal(asJSON, secret)
+		assert.NoError(t, err)
+		assert.Equal(t, 0, secret.GetMode())
+		err = secret.Decrypt()
+		assert.NoError(t, err)
+		assert.Equal(t, testPayload, secret.GetPayload())
+		err = secret.Encrypt()
+		assert.NoError(t, err)
+		assert.Equal(t, 1, secret.GetMode())
+		err = secret.Decrypt()
+		assert.NoError(t, err)
+		assert.Equal(t, testPayload, secret.GetPayload())
+		if provider == kms.SecretStatusVaultTransit {
+			// decrypt the local secret encryped without a master key now that
+			// the provider is vault and a master key is set.
+			// The provider will not change, the master key will be used
+			secretLocal := kms.NewEmptySecret()
+			err = json.Unmarshal(localAsJSON, secretLocal)
+			assert.NoError(t, err)
+			assert.Equal(t, kms.SecretStatusSecretBox, secretLocal.GetStatus())
+			assert.Equal(t, 0, secretLocal.GetMode())
+			err = secretLocal.Decrypt()
+			assert.NoError(t, err)
+			assert.Equal(t, testPayload, secretLocal.GetPayload())
+			assert.Equal(t, kms.SecretStatusPlain, secretLocal.GetStatus())
+			err = secretLocal.Encrypt()
+			assert.NoError(t, err)
+			assert.Equal(t, kms.SecretStatusSecretBox, secretLocal.GetStatus())
+			assert.Equal(t, 1, secretLocal.GetMode())
+		}
+
+		err = kmsConfig.Initialize()
+		assert.NoError(t, err)
+		err = os.Remove(masterKeyPath)
+		assert.NoError(t, err)
+		if provider == kms.SecretStatusVaultTransit {
+			os.Unsetenv("VAULT_SERVER_URL")
+			os.Unsetenv("VAULT_SERVER_TOKEN")
+		}
+	}
+}
+
 func TestUpdateUserNoCredentials(t *testing.T) {
 	user, _, err := httpd.AddUser(getTestUser(), http.StatusOK)
 	assert.NoError(t, err)

+ 4 - 1
httpd/schema/openapi.yaml

@@ -2,7 +2,7 @@ openapi: 3.0.3
 info:
   title: SFTPGo
   description: 'SFTPGo REST API'
-  version: 2.1.2
+  version: 2.1.3
 
 servers:
   - url: /api/v1
@@ -977,6 +977,9 @@ components:
           type: string
         additional_data:
           type: string
+        mode:
+          type: integer
+          description: 1 means encrypted using a master key
     S3Config:
       type: object
       properties:

+ 5 - 0
kms/basegocloud.go

@@ -30,6 +30,7 @@ func (s *baseGCloudSecret) Encrypt() error {
 
 	payload := s.Payload
 	key := ""
+	mode := 0
 	if s.masterKey != "" {
 		localSecret := newLocalSecret(s.baseSecret, s.masterKey)
 		err := localSecret.Encrypt()
@@ -38,6 +39,7 @@ func (s *baseGCloudSecret) Encrypt() error {
 		}
 		payload = localSecret.GetPayload()
 		key = localSecret.GetKey()
+		mode = localSecret.GetMode()
 	}
 
 	ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(defaultTimeout))
@@ -55,6 +57,7 @@ func (s *baseGCloudSecret) Encrypt() error {
 	}
 	s.Payload = base64.StdEncoding.EncodeToString(ciphertext)
 	s.Key = key
+	s.Mode = mode
 	return nil
 }
 
@@ -83,6 +86,7 @@ func (s *baseGCloudSecret) Decrypt() error {
 			Payload:        string(plaintext),
 			Key:            s.Key,
 			AdditionalData: s.AdditionalData,
+			Mode:           s.Mode,
 		}
 		localSecret := newLocalSecret(baseSecret, s.masterKey)
 		err = localSecret.Decrypt()
@@ -95,5 +99,6 @@ func (s *baseGCloudSecret) Decrypt() error {
 	s.Payload = payload
 	s.Key = ""
 	s.AdditionalData = ""
+	s.Mode = 0
 	return nil
 }

+ 6 - 0
kms/basesecret.go

@@ -6,6 +6,8 @@ type baseSecret struct {
 	Payload        string       `json:"payload,omitempty"`
 	Key            string       `json:"key,omitempty"`
 	AdditionalData string       `json:"additional_data,omitempty"`
+	// 1 means encrypted using a master key
+	Mode int `json:"mode,omitempty"`
 }
 
 func (s *baseSecret) GetStatus() SecretStatus {
@@ -20,6 +22,10 @@ func (s *baseSecret) GetKey() string {
 	return s.Key
 }
 
+func (s *baseSecret) GetMode() int {
+	return s.Mode
+}
+
 func (s *baseSecret) GetAdditionalData() string {
 	return s.AdditionalData
 }

+ 8 - 0
kms/kms.go

@@ -21,6 +21,7 @@ type SecretProvider interface {
 	GetPayload() string
 	GetKey() string
 	GetAdditionalData() string
+	GetMode() int
 	SetKey(string)
 	SetAdditionalData(string)
 	SetStatus(SecretStatus)
@@ -145,6 +146,7 @@ func (s *Secret) MarshalJSON() ([]byte, error) {
 		Payload:        s.provider.GetPayload(),
 		Key:            s.provider.GetKey(),
 		AdditionalData: s.provider.GetAdditionalData(),
+		Mode:           s.provider.GetMode(),
 	})
 }
 
@@ -186,6 +188,7 @@ func (s *Secret) Clone() *Secret {
 		Payload:        s.provider.GetPayload(),
 		Key:            s.provider.GetKey(),
 		AdditionalData: s.provider.GetAdditionalData(),
+		Mode:           s.provider.GetMode(),
 	}
 	switch s.provider.Name() {
 	case builtinProviderName:
@@ -249,6 +252,11 @@ func (s *Secret) GetKey() string {
 	return s.provider.GetKey()
 }
 
+// GetMode returns the secret mode
+func (s *Secret) GetMode() int {
+	return s.provider.GetMode()
+}
+
 // SetAdditionalData sets the given additional data
 func (s *Secret) SetAdditionalData(value string) {
 	s.provider.SetAdditionalData(value)

+ 13 - 4
kms/local.go

@@ -46,7 +46,7 @@ func (s *localSecret) Encrypt() error {
 	if err != nil {
 		return err
 	}
-	key, err := s.deriveKey(secretKey[:])
+	key, err := s.deriveKey(secretKey[:], false)
 	if err != nil {
 		return err
 	}
@@ -60,6 +60,7 @@ func (s *localSecret) Encrypt() error {
 	s.Key = hex.EncodeToString(secretKey[:])
 	s.Payload = base64.StdEncoding.EncodeToString(ciphertext)
 	s.Status = SecretStatusSecretBox
+	s.Mode = s.getEncryptionMode()
 	return nil
 }
 
@@ -75,7 +76,7 @@ func (s *localSecret) Decrypt() error {
 	if err != nil {
 		return err
 	}
-	key, err := s.deriveKey(secretKey[:])
+	key, err := s.deriveKey(secretKey[:], true)
 	if err != nil {
 		return err
 	}
@@ -90,12 +91,13 @@ func (s *localSecret) Decrypt() error {
 	s.Payload = string(plaintext)
 	s.Key = ""
 	s.AdditionalData = ""
+	s.Mode = 0
 	return nil
 }
 
-func (s *localSecret) deriveKey(key []byte) ([32]byte, error) {
+func (s *localSecret) deriveKey(key []byte, isForDecryption bool) ([32]byte, error) {
 	var masterKey []byte
-	if s.masterKey == "" {
+	if s.masterKey == "" || (isForDecryption && s.Mode == 0) {
 		var combined []byte
 		combined = append(combined, key...)
 		if s.AdditionalData != "" {
@@ -118,3 +120,10 @@ func (s *localSecret) deriveKey(key []byte) ([32]byte, error) {
 	}
 	return derivedKey, nil
 }
+
+func (s *localSecret) getEncryptionMode() int {
+	if s.masterKey == "" {
+		return 0
+	}
+	return 1
+}