Browse Source

Accept a config file path instead of a config name

Config name is a Viper concept used for searching a specific file
in various paths with various extensions.

Making it configurable is usually not a useful feature
as users mostly want to define a full or relative path
to a config file.

This change replaces config name with config file.
Márk Sági-Kazár 4 years ago
parent
commit
2a9ed0abca
6 changed files with 77 additions and 78 deletions
  1. 1 1
      cmd/install_windows.go
  2. 0 1
      cmd/portable.go
  3. 8 19
      cmd/root.go
  4. 1 26
      cmd/startsubsys.go
  5. 15 8
      config/config.go
  6. 52 23
      config/config_test.go

+ 1 - 1
cmd/install_windows.go

@@ -64,7 +64,7 @@ func getCustomServeFlags() []string {
 		result = append(result, "--"+configDirFlag)
 		result = append(result, configDir)
 	}
-	if configFile != defaultConfigName {
+	if configFile != "" {
 		result = append(result, "--"+configFileFlag)
 		result = append(result, configFile)
 	}

+ 0 - 1
cmd/portable.go

@@ -124,7 +124,6 @@ Please take a look at the usage below to customize the serving parameters`,
 			}
 			service := service.Service{
 				ConfigDir:     filepath.Clean(defaultConfigDir),
-				ConfigFile:    defaultConfigName,
 				LogFilePath:   portableLogFile,
 				LogMaxSize:    defaultLogMaxSize,
 				LogMaxBackups: defaultLogMaxBackup,

+ 8 - 19
cmd/root.go

@@ -8,7 +8,6 @@ import (
 	"github.com/spf13/cobra"
 	"github.com/spf13/viper"
 
-	"github.com/drakkan/sftpgo/config"
 	"github.com/drakkan/sftpgo/version"
 )
 
@@ -16,7 +15,6 @@ const (
 	configDirFlag            = "config-dir"
 	configDirKey             = "config_dir"
 	configFileFlag           = "config-file"
-	configFileKey            = "config_file"
 	logFilePathFlag          = "log-file-path"
 	logFilePathKey           = "log_file_path"
 	logMaxSizeFlag           = "log-max-size"
@@ -40,7 +38,6 @@ const (
 	loadDataCleanFlag        = "loaddata-clean"
 	loadDataCleanKey         = "loaddata_clean"
 	defaultConfigDir         = "."
-	defaultConfigName        = config.DefaultConfigName
 	defaultLogFile           = "sftpgo.log"
 	defaultLogMaxSize        = 10
 	defaultLogMaxBackup      = 5
@@ -96,29 +93,21 @@ func addConfigFlags(cmd *cobra.Command) {
 	viper.BindEnv(configDirKey, "SFTPGO_CONFIG_DIR") //nolint:errcheck // err is not nil only if the key to bind is missing
 	cmd.Flags().StringVarP(&configDir, configDirFlag, "c", viper.GetString(configDirKey),
 		`Location for SFTPGo config dir. This directory
-should contain the "sftpgo" configuration file
-or the configured config-file and it is used as
-the base for files with a relative path (eg. the
-private keys for the SFTP server, the SQLite
+should contain the "sftpgo" configuration file.
+It is used as the base for files with a relative path
+(eg. the private keys for the SFTP server, the SQLite
 database if you use SQLite as data provider).
 This flag can be set using SFTPGO_CONFIG_DIR
 env var too.`)
 	viper.BindPFlag(configDirKey, cmd.Flags().Lookup(configDirFlag)) //nolint:errcheck
 
-	viper.SetDefault(configFileKey, defaultConfigName)
-	viper.BindEnv(configFileKey, "SFTPGO_CONFIG_FILE") //nolint:errcheck
-	cmd.Flags().StringVarP(&configFile, configFileFlag, "f", viper.GetString(configFileKey),
-		`Name for SFTPGo configuration file. It must be
-the name of a file stored in config-dir not the
-absolute path to the configuration file. The
-specified file name must have no extension we
-automatically load JSON, YAML, TOML, HCL and
-Java properties. Therefore if you set "sftpgo"
-then "sftpgo.json", "sftpgo.yaml" and so on
-are searched.
+	cmd.Flags().StringVar(&configFile, configFileFlag, os.Getenv("SFTPGO_CONFIG_FILE"),
+		`Path to SFTPGo configuration file. It must be
+an absolute path to a file or a path relative to the working directory.
+The specified file name must have a supported extension
+(JSON, YAML, TOML or Java properties).
 This flag can be set using SFTPGO_CONFIG_FILE
 env var too.`)
-	viper.BindPFlag(configFileKey, cmd.Flags().Lookup(configFileFlag)) //nolint:errcheck
 }
 
 func addServeFlags(cmd *cobra.Command) {

+ 1 - 26
cmd/startsubsys.go

@@ -145,33 +145,8 @@ $ journalctl -o verbose -f
 To see full logs.
 If not set, the logs will be sent to the standard
 error`)
-	viper.SetDefault(configDirKey, defaultConfigDir)
-	viper.BindEnv(configDirKey, "SFTPGO_CONFIG_DIR") //nolint:errcheck // err is not nil only if the key to bind is missing
-	subsystemCmd.Flags().StringVarP(&configDir, configDirFlag, "c", viper.GetString(configDirKey),
-		`Location for SFTPGo config dir. This directory
-should contain the "sftpgo" configuration file
-or the configured config-file and it is used as
-the base for files with a relative path (eg. the
-private keys for the SFTP server, the SQLite
-database if you use SQLite as data provider).
-This flag can be set using SFTPGO_CONFIG_DIR
-env var too.`)
-	viper.BindPFlag(configDirKey, subsystemCmd.Flags().Lookup(configDirFlag)) //nolint:errcheck
 
-	viper.SetDefault(configFileKey, defaultConfigName)
-	viper.BindEnv(configFileKey, "SFTPGO_CONFIG_FILE") //nolint:errcheck
-	subsystemCmd.Flags().StringVarP(&configFile, configFileFlag, "f", viper.GetString(configFileKey),
-		`Name for SFTPGo configuration file. It must be
-the name of a file stored in config-dir not the
-absolute path to the configuration file. The
-specified file name must have no extension we
-automatically load JSON, YAML, TOML, HCL and
-Java properties. Therefore if you set "sftpgo"
-then "sftpgo.json", "sftpgo.yaml" and so on
-are searched.
-This flag can be set using SFTPGO_CONFIG_FILE
-env var too.`)
-	viper.BindPFlag(configFileKey, subsystemCmd.Flags().Lookup(configFileFlag)) //nolint:errcheck
+	addConfigFlags(subsystemCmd)
 
 	viper.SetDefault(logVerboseKey, defaultLogVerbose)
 	viper.BindEnv(logVerboseKey, "SFTPGO_LOG_VERBOSE") //nolint:errcheck

+ 15 - 8
config/config.go

@@ -22,10 +22,10 @@ import (
 
 const (
 	logSender = "config"
-	// DefaultConfigName defines the name for the default config file.
+	// configName defines the name for the default config file.
 	// This is the file name without extension, we use viper and so we
 	// support all the config files format supported by viper
-	DefaultConfigName = "sftpgo"
+	configName = "sftpgo"
 	// ConfigEnvPrefix defines a prefix that ENVIRONMENT variables will use
 	configEnvPrefix = "sftpgo"
 )
@@ -48,6 +48,13 @@ type globalConfig struct {
 }
 
 func init() {
+	Init()
+}
+
+// Init initializes the global configuration.
+// It is not supposed to be called outside of this package.
+// It is exported to minimize refactoring efforts. Will eventually disappear.
+func Init() {
 	// create a default configuration to use if no config file is provided
 	globalConf = globalConfig{
 		Common: common.Configuration{
@@ -177,7 +184,7 @@ func init() {
 	viper.SetEnvPrefix(configEnvPrefix)
 	replacer := strings.NewReplacer(".", "__")
 	viper.SetEnvKeyReplacer(replacer)
-	viper.SetConfigName(DefaultConfigName)
+	viper.SetConfigName(configName)
 	setViperDefaults()
 	viper.AutomaticEnv()
 	viper.AllowEmptyEnv(true)
@@ -233,12 +240,12 @@ func SetHTTPDConfig(config httpd.Conf) {
 	globalConf.HTTPDConfig = config
 }
 
-//GetProviderConf returns the configuration for the data provider
+// GetProviderConf returns the configuration for the data provider
 func GetProviderConf() dataprovider.Config {
 	return globalConf.ProviderConf
 }
 
-//SetProviderConf sets the configuration for the data provider
+// SetProviderConf sets the configuration for the data provider
 func SetProviderConf(config dataprovider.Config) {
 	globalConf.ProviderConf = config
 }
@@ -283,13 +290,13 @@ func getRedactedGlobalConf() globalConfig {
 // configDir will be added to the configuration search paths.
 // The search path contains by default the current directory and on linux it contains
 // $HOME/.config/sftpgo and /etc/sftpgo too.
-// configName is the name of the configuration to search without extension
-func LoadConfig(configDir, configName string) error {
+// configFile is an absolute or relative path (to the working directory) to the configuration file.
+func LoadConfig(configDir, configFile string) error {
 	var err error
 	viper.AddConfigPath(configDir)
 	setViperAdditionalConfigPaths()
 	viper.AddConfigPath(".")
-	viper.SetConfigName(configName)
+	viper.SetConfigFile(configFile)
 	if err = viper.ReadInConfig(); err != nil {
 		logger.Warn(logSender, "", "error loading configuration file: %v", err)
 		logger.WarnToConsole("error loading configuration file: %v", err)

+ 52 - 23
config/config_test.go

@@ -8,6 +8,7 @@ import (
 	"strings"
 	"testing"
 
+	"github.com/spf13/viper"
 	"github.com/stretchr/testify/assert"
 
 	"github.com/drakkan/sftpgo/common"
@@ -22,12 +23,18 @@ import (
 
 const (
 	tempConfigName = "temp"
-	configName     = "sftpgo"
 )
 
+func reset() {
+	viper.Reset()
+	config.Init()
+}
+
 func TestLoadConfigTest(t *testing.T) {
+	reset()
+
 	configDir := ".."
-	err := config.LoadConfig(configDir, configName)
+	err := config.LoadConfig(configDir, "")
 	assert.NoError(t, err)
 	assert.NotEqual(t, httpd.Conf{}, config.GetHTTPConfig())
 	assert.NotEqual(t, dataprovider.Config{}, config.GetProviderConf())
@@ -35,25 +42,27 @@ func TestLoadConfigTest(t *testing.T) {
 	assert.NotEqual(t, httpclient.Config{}, config.GetHTTPConfig())
 	confName := tempConfigName + ".json"
 	configFilePath := filepath.Join(configDir, confName)
-	err = config.LoadConfig(configDir, tempConfigName)
+	err = config.LoadConfig(configDir, configFilePath)
 	assert.NoError(t, err)
 	err = ioutil.WriteFile(configFilePath, []byte("{invalid json}"), os.ModePerm)
 	assert.NoError(t, err)
-	err = config.LoadConfig(configDir, tempConfigName)
+	err = config.LoadConfig(configDir, configFilePath)
 	assert.NoError(t, err)
 	err = ioutil.WriteFile(configFilePath, []byte("{\"sftpd\": {\"bind_port\": \"a\"}}"), os.ModePerm)
 	assert.NoError(t, err)
-	err = config.LoadConfig(configDir, tempConfigName)
+	err = config.LoadConfig(configDir, configFilePath)
 	assert.NotNil(t, err)
 	err = os.Remove(configFilePath)
 	assert.NoError(t, err)
 }
 
 func TestEmptyBanner(t *testing.T) {
+	reset()
+
 	configDir := ".."
 	confName := tempConfigName + ".json"
 	configFilePath := filepath.Join(configDir, confName)
-	err := config.LoadConfig(configDir, configName)
+	err := config.LoadConfig(configDir, "")
 	assert.NoError(t, err)
 	sftpdConf := config.GetSFTPDConfig()
 	sftpdConf.Banner = " "
@@ -62,7 +71,7 @@ func TestEmptyBanner(t *testing.T) {
 	jsonConf, _ := json.Marshal(c)
 	err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
 	assert.NoError(t, err)
-	err = config.LoadConfig(configDir, tempConfigName)
+	err = config.LoadConfig(configDir, configFilePath)
 	assert.NoError(t, err)
 	sftpdConf = config.GetSFTPDConfig()
 	assert.NotEmpty(t, strings.TrimSpace(sftpdConf.Banner))
@@ -76,7 +85,7 @@ func TestEmptyBanner(t *testing.T) {
 	jsonConf, _ = json.Marshal(c1)
 	err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
 	assert.NoError(t, err)
-	err = config.LoadConfig(configDir, tempConfigName)
+	err = config.LoadConfig(configDir, configFilePath)
 	assert.NoError(t, err)
 	ftpdConf = config.GetFTPDConfig()
 	assert.NotEmpty(t, strings.TrimSpace(ftpdConf.Banner))
@@ -85,10 +94,12 @@ func TestEmptyBanner(t *testing.T) {
 }
 
 func TestInvalidUploadMode(t *testing.T) {
+	reset()
+
 	configDir := ".."
 	confName := tempConfigName + ".json"
 	configFilePath := filepath.Join(configDir, confName)
-	err := config.LoadConfig(configDir, configName)
+	err := config.LoadConfig(configDir, "")
 	assert.NoError(t, err)
 	commonConf := config.GetCommonConfig()
 	commonConf.UploadMode = 10
@@ -98,17 +109,19 @@ func TestInvalidUploadMode(t *testing.T) {
 	assert.NoError(t, err)
 	err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
 	assert.NoError(t, err)
-	err = config.LoadConfig(configDir, tempConfigName)
+	err = config.LoadConfig(configDir, configFilePath)
 	assert.NotNil(t, err)
 	err = os.Remove(configFilePath)
 	assert.NoError(t, err)
 }
 
 func TestInvalidExternalAuthScope(t *testing.T) {
+	reset()
+
 	configDir := ".."
 	confName := tempConfigName + ".json"
 	configFilePath := filepath.Join(configDir, confName)
-	err := config.LoadConfig(configDir, configName)
+	err := config.LoadConfig(configDir, "")
 	assert.NoError(t, err)
 	providerConf := config.GetProviderConf()
 	providerConf.ExternalAuthScope = 10
@@ -118,17 +131,19 @@ func TestInvalidExternalAuthScope(t *testing.T) {
 	assert.NoError(t, err)
 	err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
 	assert.NoError(t, err)
-	err = config.LoadConfig(configDir, tempConfigName)
+	err = config.LoadConfig(configDir, configFilePath)
 	assert.NotNil(t, err)
 	err = os.Remove(configFilePath)
 	assert.NoError(t, err)
 }
 
 func TestInvalidCredentialsPath(t *testing.T) {
+	reset()
+
 	configDir := ".."
 	confName := tempConfigName + ".json"
 	configFilePath := filepath.Join(configDir, confName)
-	err := config.LoadConfig(configDir, configName)
+	err := config.LoadConfig(configDir, "")
 	assert.NoError(t, err)
 	providerConf := config.GetProviderConf()
 	providerConf.CredentialsPath = ""
@@ -138,17 +153,19 @@ func TestInvalidCredentialsPath(t *testing.T) {
 	assert.NoError(t, err)
 	err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
 	assert.NoError(t, err)
-	err = config.LoadConfig(configDir, tempConfigName)
+	err = config.LoadConfig(configDir, configFilePath)
 	assert.NotNil(t, err)
 	err = os.Remove(configFilePath)
 	assert.NoError(t, err)
 }
 
 func TestInvalidProxyProtocol(t *testing.T) {
+	reset()
+
 	configDir := ".."
 	confName := tempConfigName + ".json"
 	configFilePath := filepath.Join(configDir, confName)
-	err := config.LoadConfig(configDir, configName)
+	err := config.LoadConfig(configDir, "")
 	assert.NoError(t, err)
 	commonConf := config.GetCommonConfig()
 	commonConf.ProxyProtocol = 10
@@ -158,17 +175,19 @@ func TestInvalidProxyProtocol(t *testing.T) {
 	assert.NoError(t, err)
 	err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
 	assert.NoError(t, err)
-	err = config.LoadConfig(configDir, tempConfigName)
+	err = config.LoadConfig(configDir, configFilePath)
 	assert.NotNil(t, err)
 	err = os.Remove(configFilePath)
 	assert.NoError(t, err)
 }
 
 func TestInvalidUsersBaseDir(t *testing.T) {
+	reset()
+
 	configDir := ".."
 	confName := tempConfigName + ".json"
 	configFilePath := filepath.Join(configDir, confName)
-	err := config.LoadConfig(configDir, configName)
+	err := config.LoadConfig(configDir, "")
 	assert.NoError(t, err)
 	providerConf := config.GetProviderConf()
 	providerConf.UsersBaseDir = "."
@@ -178,17 +197,19 @@ func TestInvalidUsersBaseDir(t *testing.T) {
 	assert.NoError(t, err)
 	err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
 	assert.NoError(t, err)
-	err = config.LoadConfig(configDir, tempConfigName)
+	err = config.LoadConfig(configDir, configFilePath)
 	assert.NotNil(t, err)
 	err = os.Remove(configFilePath)
 	assert.NoError(t, err)
 }
 
 func TestCommonParamsCompatibility(t *testing.T) {
+	reset()
+
 	configDir := ".."
 	confName := tempConfigName + ".json"
 	configFilePath := filepath.Join(configDir, confName)
-	err := config.LoadConfig(configDir, configName)
+	err := config.LoadConfig(configDir, "")
 	assert.NoError(t, err)
 	sftpdConf := config.GetSFTPDConfig()
 	sftpdConf.IdleTimeout = 21 //nolint:staticcheck
@@ -204,7 +225,7 @@ func TestCommonParamsCompatibility(t *testing.T) {
 	assert.NoError(t, err)
 	err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
 	assert.NoError(t, err)
-	err = config.LoadConfig(configDir, tempConfigName)
+	err = config.LoadConfig(configDir, configFilePath)
 	assert.NoError(t, err)
 	commonConf := config.GetCommonConfig()
 	assert.Equal(t, 21, commonConf.IdleTimeout)
@@ -220,10 +241,12 @@ func TestCommonParamsCompatibility(t *testing.T) {
 }
 
 func TestHostKeyCompatibility(t *testing.T) {
+	reset()
+
 	configDir := ".."
 	confName := tempConfigName + ".json"
 	configFilePath := filepath.Join(configDir, confName)
-	err := config.LoadConfig(configDir, configName)
+	err := config.LoadConfig(configDir, "")
 	assert.NoError(t, err)
 	sftpdConf := config.GetSFTPDConfig()
 	sftpdConf.Keys = []sftpd.Key{ //nolint:staticcheck
@@ -240,7 +263,7 @@ func TestHostKeyCompatibility(t *testing.T) {
 	assert.NoError(t, err)
 	err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
 	assert.NoError(t, err)
-	err = config.LoadConfig(configDir, tempConfigName)
+	err = config.LoadConfig(configDir, configFilePath)
 	assert.NoError(t, err)
 	sftpdConf = config.GetSFTPDConfig()
 	assert.Equal(t, 2, len(sftpdConf.HostKeys))
@@ -251,6 +274,8 @@ func TestHostKeyCompatibility(t *testing.T) {
 }
 
 func TestSetGetConfig(t *testing.T) {
+	reset()
+
 	sftpdConf := config.GetSFTPDConfig()
 	sftpdConf.MaxAuthTries = 10
 	config.SetSFTPDConfig(sftpdConf)
@@ -288,8 +313,10 @@ func TestSetGetConfig(t *testing.T) {
 }
 
 func TestServiceToStart(t *testing.T) {
+	reset()
+
 	configDir := ".."
-	err := config.LoadConfig(configDir, configName)
+	err := config.LoadConfig(configDir, "")
 	assert.NoError(t, err)
 	assert.True(t, config.HasServicesToStart())
 	sftpdConf := config.GetSFTPDConfig()
@@ -315,6 +342,8 @@ func TestServiceToStart(t *testing.T) {
 }
 
 func TestConfigFromEnv(t *testing.T) {
+	reset()
+
 	os.Setenv("SFTPGO_SFTPD__BIND_ADDRESS", "127.0.0.1")
 	os.Setenv("SFTPGO_DATA_PROVIDER__PASSWORD_HASHING__ARGON2_OPTIONS__ITERATIONS", "41")
 	os.Setenv("SFTPGO_DATA_PROVIDER__POOL_SIZE", "10")