diff --git a/httpd/httpd_test.go b/httpd/httpd_test.go index 99aac3ed..21bea07e 100644 --- a/httpd/httpd_test.go +++ b/httpd/httpd_test.go @@ -1390,6 +1390,114 @@ func TestSecretObject(t *testing.T) { require.Empty(t, s.GetKey()) } +func TestSecretObjectCompatibility(t *testing.T) { + // this is manually tested against vault too + testPayload := "test payload" + s := kms.NewPlainSecret(testPayload) + require.True(t, s.IsValid()) + err := s.Encrypt() + require.NoError(t, err) + localAsJSON, err := json.Marshal(s) + assert.NoError(t, err) + + for _, provider := range []string{kms.SecretStatusRedacted} { + kmsConfig := config.GetKMSConfig() + assert.Empty(t, kmsConfig.Secrets.MasterKeyPath) + if provider == kms.SecretStatusVaultTransit { + os.Setenv("VAULT_SERVER_URL", "http://127.0.0.1:8200") + os.Setenv("VAULT_SERVER_TOKEN", "s.9lYGq83MbgG5KR5kfebXVyhJ") + kmsConfig.Secrets.URL = "hashivault://mykey" + } + err := kmsConfig.Initialize() + assert.NoError(t, err) + // encrypt without a master key + secret := kms.NewPlainSecret(testPayload) + secret.SetAdditionalData("add data") + err = secret.Encrypt() + assert.NoError(t, err) + assert.Equal(t, 0, secret.GetMode()) + secretClone := secret.Clone() + err = secretClone.Decrypt() + assert.NoError(t, err) + assert.Equal(t, testPayload, secretClone.GetPayload()) + if provider == kms.SecretStatusVaultTransit { + // decrypt the local secret now that the provider is vault + secretLocal := kms.NewEmptySecret() + err = json.Unmarshal(localAsJSON, secretLocal) + assert.NoError(t, err) + assert.Equal(t, kms.SecretStatusSecretBox, secretLocal.GetStatus()) + assert.Equal(t, 0, secretLocal.GetMode()) + err = secretLocal.Decrypt() + assert.NoError(t, err) + assert.Equal(t, testPayload, secretLocal.GetPayload()) + assert.Equal(t, kms.SecretStatusPlain, secretLocal.GetStatus()) + err = secretLocal.Encrypt() + assert.NoError(t, err) + assert.Equal(t, kms.SecretStatusSecretBox, secretLocal.GetStatus()) + assert.Equal(t, 0, secretLocal.GetMode()) + } + + asJSON, err := json.Marshal(secret) + assert.NoError(t, err) + + masterKeyPath := filepath.Join(os.TempDir(), "mkey") + err = ioutil.WriteFile(masterKeyPath, []byte("test key"), os.ModePerm) + assert.NoError(t, err) + config := kms.Configuration{ + Secrets: kms.Secrets{ + MasterKeyPath: masterKeyPath, + }, + } + if provider == kms.SecretStatusVaultTransit { + config.Secrets.URL = "hashivault://mykey" + } + err = config.Initialize() + assert.NoError(t, err) + + // now build the secret from JSON + secret = kms.NewEmptySecret() + err = json.Unmarshal(asJSON, secret) + assert.NoError(t, err) + assert.Equal(t, 0, secret.GetMode()) + err = secret.Decrypt() + assert.NoError(t, err) + assert.Equal(t, testPayload, secret.GetPayload()) + err = secret.Encrypt() + assert.NoError(t, err) + assert.Equal(t, 1, secret.GetMode()) + err = secret.Decrypt() + assert.NoError(t, err) + assert.Equal(t, testPayload, secret.GetPayload()) + if provider == kms.SecretStatusVaultTransit { + // decrypt the local secret encryped without a master key now that + // the provider is vault and a master key is set. + // The provider will not change, the master key will be used + secretLocal := kms.NewEmptySecret() + err = json.Unmarshal(localAsJSON, secretLocal) + assert.NoError(t, err) + assert.Equal(t, kms.SecretStatusSecretBox, secretLocal.GetStatus()) + assert.Equal(t, 0, secretLocal.GetMode()) + err = secretLocal.Decrypt() + assert.NoError(t, err) + assert.Equal(t, testPayload, secretLocal.GetPayload()) + assert.Equal(t, kms.SecretStatusPlain, secretLocal.GetStatus()) + err = secretLocal.Encrypt() + assert.NoError(t, err) + assert.Equal(t, kms.SecretStatusSecretBox, secretLocal.GetStatus()) + assert.Equal(t, 1, secretLocal.GetMode()) + } + + err = kmsConfig.Initialize() + assert.NoError(t, err) + err = os.Remove(masterKeyPath) + assert.NoError(t, err) + if provider == kms.SecretStatusVaultTransit { + os.Unsetenv("VAULT_SERVER_URL") + os.Unsetenv("VAULT_SERVER_TOKEN") + } + } +} + func TestUpdateUserNoCredentials(t *testing.T) { user, _, err := httpd.AddUser(getTestUser(), http.StatusOK) assert.NoError(t, err) diff --git a/httpd/schema/openapi.yaml b/httpd/schema/openapi.yaml index 2b5a9352..dc2d2730 100644 --- a/httpd/schema/openapi.yaml +++ b/httpd/schema/openapi.yaml @@ -2,7 +2,7 @@ openapi: 3.0.3 info: title: SFTPGo description: 'SFTPGo REST API' - version: 2.1.2 + version: 2.1.3 servers: - url: /api/v1 @@ -977,6 +977,9 @@ components: type: string additional_data: type: string + mode: + type: integer + description: 1 means encrypted using a master key S3Config: type: object properties: diff --git a/kms/basegocloud.go b/kms/basegocloud.go index 442ed6ec..4a4930cb 100644 --- a/kms/basegocloud.go +++ b/kms/basegocloud.go @@ -30,6 +30,7 @@ func (s *baseGCloudSecret) Encrypt() error { payload := s.Payload key := "" + mode := 0 if s.masterKey != "" { localSecret := newLocalSecret(s.baseSecret, s.masterKey) err := localSecret.Encrypt() @@ -38,6 +39,7 @@ func (s *baseGCloudSecret) Encrypt() error { } payload = localSecret.GetPayload() key = localSecret.GetKey() + mode = localSecret.GetMode() } ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(defaultTimeout)) @@ -55,6 +57,7 @@ func (s *baseGCloudSecret) Encrypt() error { } s.Payload = base64.StdEncoding.EncodeToString(ciphertext) s.Key = key + s.Mode = mode return nil } @@ -83,6 +86,7 @@ func (s *baseGCloudSecret) Decrypt() error { Payload: string(plaintext), Key: s.Key, AdditionalData: s.AdditionalData, + Mode: s.Mode, } localSecret := newLocalSecret(baseSecret, s.masterKey) err = localSecret.Decrypt() @@ -95,5 +99,6 @@ func (s *baseGCloudSecret) Decrypt() error { s.Payload = payload s.Key = "" s.AdditionalData = "" + s.Mode = 0 return nil } diff --git a/kms/basesecret.go b/kms/basesecret.go index a5790da0..1fc401a3 100644 --- a/kms/basesecret.go +++ b/kms/basesecret.go @@ -6,6 +6,8 @@ type baseSecret struct { Payload string `json:"payload,omitempty"` Key string `json:"key,omitempty"` AdditionalData string `json:"additional_data,omitempty"` + // 1 means encrypted using a master key + Mode int `json:"mode,omitempty"` } func (s *baseSecret) GetStatus() SecretStatus { @@ -20,6 +22,10 @@ func (s *baseSecret) GetKey() string { return s.Key } +func (s *baseSecret) GetMode() int { + return s.Mode +} + func (s *baseSecret) GetAdditionalData() string { return s.AdditionalData } diff --git a/kms/kms.go b/kms/kms.go index e845c53b..b150c100 100644 --- a/kms/kms.go +++ b/kms/kms.go @@ -21,6 +21,7 @@ type SecretProvider interface { GetPayload() string GetKey() string GetAdditionalData() string + GetMode() int SetKey(string) SetAdditionalData(string) SetStatus(SecretStatus) @@ -145,6 +146,7 @@ func (s *Secret) MarshalJSON() ([]byte, error) { Payload: s.provider.GetPayload(), Key: s.provider.GetKey(), AdditionalData: s.provider.GetAdditionalData(), + Mode: s.provider.GetMode(), }) } @@ -186,6 +188,7 @@ func (s *Secret) Clone() *Secret { Payload: s.provider.GetPayload(), Key: s.provider.GetKey(), AdditionalData: s.provider.GetAdditionalData(), + Mode: s.provider.GetMode(), } switch s.provider.Name() { case builtinProviderName: @@ -249,6 +252,11 @@ func (s *Secret) GetKey() string { return s.provider.GetKey() } +// GetMode returns the secret mode +func (s *Secret) GetMode() int { + return s.provider.GetMode() +} + // SetAdditionalData sets the given additional data func (s *Secret) SetAdditionalData(value string) { s.provider.SetAdditionalData(value) diff --git a/kms/local.go b/kms/local.go index bff8395d..ceef4e81 100644 --- a/kms/local.go +++ b/kms/local.go @@ -46,7 +46,7 @@ func (s *localSecret) Encrypt() error { if err != nil { return err } - key, err := s.deriveKey(secretKey[:]) + key, err := s.deriveKey(secretKey[:], false) if err != nil { return err } @@ -60,6 +60,7 @@ func (s *localSecret) Encrypt() error { s.Key = hex.EncodeToString(secretKey[:]) s.Payload = base64.StdEncoding.EncodeToString(ciphertext) s.Status = SecretStatusSecretBox + s.Mode = s.getEncryptionMode() return nil } @@ -75,7 +76,7 @@ func (s *localSecret) Decrypt() error { if err != nil { return err } - key, err := s.deriveKey(secretKey[:]) + key, err := s.deriveKey(secretKey[:], true) if err != nil { return err } @@ -90,12 +91,13 @@ func (s *localSecret) Decrypt() error { s.Payload = string(plaintext) s.Key = "" s.AdditionalData = "" + s.Mode = 0 return nil } -func (s *localSecret) deriveKey(key []byte) ([32]byte, error) { +func (s *localSecret) deriveKey(key []byte, isForDecryption bool) ([32]byte, error) { var masterKey []byte - if s.masterKey == "" { + if s.masterKey == "" || (isForDecryption && s.Mode == 0) { var combined []byte combined = append(combined, key...) if s.AdditionalData != "" { @@ -118,3 +120,10 @@ func (s *localSecret) deriveKey(key []byte) ([32]byte, error) { } return derivedKey, nil } + +func (s *localSecret) getEncryptionMode() int { + if s.masterKey == "" { + return 0 + } + return 1 +}