浏览代码

fix loading enabled_ssh_commands config key

Nicola Murino 4 年之前
父节点
当前提交
4781921336
共有 2 个文件被更改,包括 50 次插入2 次删除
  1. 2 2
      config/config.go
  2. 48 0
      config/config_test.go

+ 2 - 2
config/config.go

@@ -150,7 +150,7 @@ func Init() {
 			MACs:                    []string{},
 			MACs:                    []string{},
 			TrustedUserCAKeys:       []string{},
 			TrustedUserCAKeys:       []string{},
 			LoginBannerFile:         "",
 			LoginBannerFile:         "",
-			EnabledSSHCommands:      sftpd.GetDefaultSSHCommands(),
+			EnabledSSHCommands:      []string{},
 			KeyboardInteractiveHook: "",
 			KeyboardInteractiveHook: "",
 			PasswordAuthentication:  true,
 			PasswordAuthentication:  true,
 			FolderPrefix:            "",
 			FolderPrefix:            "",
@@ -975,7 +975,7 @@ func setViperDefaults() {
 	viper.SetDefault("sftpd.macs", globalConf.SFTPD.MACs)
 	viper.SetDefault("sftpd.macs", globalConf.SFTPD.MACs)
 	viper.SetDefault("sftpd.trusted_user_ca_keys", globalConf.SFTPD.TrustedUserCAKeys)
 	viper.SetDefault("sftpd.trusted_user_ca_keys", globalConf.SFTPD.TrustedUserCAKeys)
 	viper.SetDefault("sftpd.login_banner_file", globalConf.SFTPD.LoginBannerFile)
 	viper.SetDefault("sftpd.login_banner_file", globalConf.SFTPD.LoginBannerFile)
-	viper.SetDefault("sftpd.enabled_ssh_commands", globalConf.SFTPD.EnabledSSHCommands)
+	viper.SetDefault("sftpd.enabled_ssh_commands", sftpd.GetDefaultSSHCommands())
 	viper.SetDefault("sftpd.keyboard_interactive_auth_hook", globalConf.SFTPD.KeyboardInteractiveHook)
 	viper.SetDefault("sftpd.keyboard_interactive_auth_hook", globalConf.SFTPD.KeyboardInteractiveHook)
 	viper.SetDefault("sftpd.password_authentication", globalConf.SFTPD.PasswordAuthentication)
 	viper.SetDefault("sftpd.password_authentication", globalConf.SFTPD.PasswordAuthentication)
 	viper.SetDefault("sftpd.folder_prefix", globalConf.SFTPD.FolderPrefix)
 	viper.SetDefault("sftpd.folder_prefix", globalConf.SFTPD.FolderPrefix)

+ 48 - 0
config/config_test.go

@@ -102,6 +102,35 @@ func TestEmptyBanner(t *testing.T) {
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 }
 }
 
 
+func TestEnabledSSHCommands(t *testing.T) {
+	reset()
+
+	configDir := ".."
+	confName := tempConfigName + ".json"
+	configFilePath := filepath.Join(configDir, confName)
+	err := config.LoadConfig(configDir, "")
+	assert.NoError(t, err)
+
+	reset()
+
+	sftpdConf := config.GetSFTPDConfig()
+	sftpdConf.EnabledSSHCommands = []string{"scp"}
+	c := make(map[string]sftpd.Configuration)
+	c["sftpd"] = sftpdConf
+	jsonConf, err := json.Marshal(c)
+	assert.NoError(t, err)
+	err = os.WriteFile(configFilePath, jsonConf, os.ModePerm)
+	assert.NoError(t, err)
+	err = config.LoadConfig(configDir, confName)
+	assert.NoError(t, err)
+	sftpdConf = config.GetSFTPDConfig()
+	if assert.Len(t, sftpdConf.EnabledSSHCommands, 1) {
+		assert.Equal(t, "scp", sftpdConf.EnabledSSHCommands[0])
+	}
+	err = os.Remove(configFilePath)
+	assert.NoError(t, err)
+}
+
 func TestInvalidUploadMode(t *testing.T) {
 func TestInvalidUploadMode(t *testing.T) {
 	reset()
 	reset()
 
 
@@ -291,6 +320,25 @@ func TestServiceToStart(t *testing.T) {
 	assert.True(t, config.HasServicesToStart())
 	assert.True(t, config.HasServicesToStart())
 }
 }
 
 
+func TestSSHCommandsFromEnv(t *testing.T) {
+	reset()
+
+	os.Setenv("SFTPGO_SFTPD__ENABLED_SSH_COMMANDS", "cd,scp")
+	t.Cleanup(func() {
+		os.Unsetenv("SFTPGO_SFTPD__ENABLED_SSH_COMMANDS")
+	})
+
+	configDir := ".."
+	err := config.LoadConfig(configDir, "")
+	assert.NoError(t, err)
+
+	sftpdConf := config.GetSFTPDConfig()
+	if assert.Len(t, sftpdConf.EnabledSSHCommands, 2) {
+		assert.Equal(t, "cd", sftpdConf.EnabledSSHCommands[0])
+		assert.Equal(t, "scp", sftpdConf.EnabledSSHCommands[1])
+	}
+}
+
 func TestPluginsFromEnv(t *testing.T) {
 func TestPluginsFromEnv(t *testing.T) {
 	reset()
 	reset()