mirror of
https://github.com/drakkan/sftpgo.git
synced 2024-11-25 00:50:31 +00:00
add support for inter-node communications
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
parent
a538255034
commit
76e89d07d4
25 changed files with 847 additions and 59 deletions
|
@ -253,6 +253,10 @@ The configuration file contains the following sections:
|
|||
- `create_default_admin`, boolean. Before you can use SFTPGo you need to create an admin account. If you open the admin web UI, a setup screen will guide you in creating the first admin account. You can automatically create the first admin account by enabling this setting and setting the environment variables `SFTPGO_DEFAULT_ADMIN_USERNAME` and `SFTPGO_DEFAULT_ADMIN_PASSWORD`. You can also create the first admin by loading initial data. This setting has no effect if an admin account is already found within the data provider. Default `false`.
|
||||
- `naming_rules`, integer. Naming rules for usernames, folder and group names. `0` means no rules. `1` means you can use any UTF-8 character. The names are used in URIs for REST API and Web admin. If not set only unreserved URI characters are allowed: ALPHA / DIGIT / "-" / "." / "_" / "~". `2` means names are converted to lowercase before saving/matching and so case insensitive matching is possible. `3` means trimming trailing and leading white spaces before saving/matching. Rules can be combined, for example `3` means both converting to lowercase and allowing any UTF-8 character. Enabling these options for existing installations could be backward incompatible, some users could be unable to login, for example existing users with mixed cases in their usernames. You have to ensure that all existing users respect the defined rules. Default: `1`.
|
||||
- `is_shared`, integer. If the data provider is shared across multiple SFTPGo instances, set this parameter to `1`. `MySQL`, `PostgreSQL` and `CockroachDB` can be shared, this setting is ignored for other data providers. For shared data providers, active transfers are persisted in the database and thus quota checks between ongoing transfers will work cross multiple instances. Password reset requests and OIDC tokens/states are also persisted in the database if the provider is shared. For shared data providers, scheduled event actions are only executed on a single SFTPGo instance by default, you can override this behavior on a per-action basis. The database table `shared_sessions` is used only to store temporary sessions. In performance critical installations, you might consider using a database-specific optimization, for example you might use an `UNLOGGED` table for PostgreSQL. This optimization in only required in very limited use cases. Default: `0`.
|
||||
- `node`, struct. Node-specific configurations to allow inter-node communications. If your provider is shared across multiple nodes, the nodes can exchange information to present a uniform view for node-specific data. The current implementation allows to obtain active connections from all nodes. Nodes connect to each other using the REST API.
|
||||
- `host`, string. IP address or hostname that other nodes can use to connect to this node via REST API. Empty means inter-node communications disabled. Default: empty.
|
||||
- `port`, integer. The port that other nodes can use to connect to this node via REST API. Default: `0`
|
||||
- `proto`, string. Supported values `http` or `https`. For `https` the configurations for http clients is used, so you can, for example, enable mutual TLS authentication. Default: `http`
|
||||
- `backups_path`, string. Path to the backup directory. This can be an absolute path or a path relative to the config dir. We don't allow backups in arbitrary paths for security reasons.
|
||||
- **"httpd"**, the configuration for the HTTP server used to serve REST API and to expose the built-in web interface
|
||||
- `bindings`, list of structs. Each struct has the following fields:
|
||||
|
|
2
go.mod
2
go.mod
|
@ -156,7 +156,7 @@ require (
|
|||
golang.org/x/tools v0.1.12 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
google.golang.org/genproto v0.0.0-20220921223823-23cae91e6737 // indirect
|
||||
google.golang.org/genproto v0.0.0-20220923205249-dd2d53f1fffc // indirect
|
||||
google.golang.org/grpc v1.49.0 // indirect
|
||||
google.golang.org/protobuf v1.28.1 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
|
|
4
go.sum
4
go.sum
|
@ -1229,8 +1229,8 @@ google.golang.org/genproto v0.0.0-20220523171625-347a074981d8/go.mod h1:RAyBrSAP
|
|||
google.golang.org/genproto v0.0.0-20220608133413-ed9918b62aac/go.mod h1:KEWEmljWE5zPzLBa/oHl6DaEt9LmfH6WtH1OHIvleBA=
|
||||
google.golang.org/genproto v0.0.0-20220616135557-88e70c0c3a90/go.mod h1:KEWEmljWE5zPzLBa/oHl6DaEt9LmfH6WtH1OHIvleBA=
|
||||
google.golang.org/genproto v0.0.0-20220624142145-8cd45d7dbd1f/go.mod h1:KEWEmljWE5zPzLBa/oHl6DaEt9LmfH6WtH1OHIvleBA=
|
||||
google.golang.org/genproto v0.0.0-20220921223823-23cae91e6737 h1:K1zaaMdYBXRyX+cwFnxj7M6zwDyumLQMZ5xqwGvjreQ=
|
||||
google.golang.org/genproto v0.0.0-20220921223823-23cae91e6737/go.mod h1:2r/26NEF3bFmT3eC3aZreahSal0C3Shl8Gi6vyDYqOQ=
|
||||
google.golang.org/genproto v0.0.0-20220923205249-dd2d53f1fffc h1:saaNe2+SBQxandnzcD/qB1JEBQ2Pqew+KlFLLdA/XcM=
|
||||
google.golang.org/genproto v0.0.0-20220923205249-dd2d53f1fffc/go.mod h1:yEEpwVWKMZZzo81NwRgyEJnA2fQvpXAYPVisv8EgDVs=
|
||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
|
||||
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
|
||||
|
|
|
@ -1076,6 +1076,7 @@ func (conns *ActiveConnections) GetStats() []ConnectionStatus {
|
|||
defer conns.RUnlock()
|
||||
|
||||
stats := make([]ConnectionStatus, 0, len(conns.connections))
|
||||
node := dataprovider.GetNodeName()
|
||||
for _, c := range conns.connections {
|
||||
stat := ConnectionStatus{
|
||||
Username: c.GetUsername(),
|
||||
|
@ -1087,6 +1088,7 @@ func (conns *ActiveConnections) GetStats() []ConnectionStatus {
|
|||
Protocol: c.GetProtocol(),
|
||||
Command: c.GetCommand(),
|
||||
Transfers: c.GetTransfers(),
|
||||
Node: node,
|
||||
}
|
||||
stats = append(stats, stat)
|
||||
}
|
||||
|
@ -1113,6 +1115,8 @@ type ConnectionStatus struct {
|
|||
Transfers []ConnectionTransfer `json:"active_transfers,omitempty"`
|
||||
// SSH command or WebDAV method
|
||||
Command string `json:"command,omitempty"`
|
||||
// Node identifier, omitted for single node installations
|
||||
Node string `json:"node,omitempty"`
|
||||
}
|
||||
|
||||
// GetConnectionDuration returns the connection duration as string
|
||||
|
|
|
@ -363,7 +363,12 @@ func Init() {
|
|||
CreateDefaultAdmin: false,
|
||||
NamingRules: 1,
|
||||
IsShared: 0,
|
||||
BackupsPath: "backups",
|
||||
Node: dataprovider.NodeConfig{
|
||||
Host: "",
|
||||
Port: 0,
|
||||
Proto: "http",
|
||||
},
|
||||
BackupsPath: "backups",
|
||||
},
|
||||
HTTPDConfig: httpd.Conf{
|
||||
Bindings: []httpd.Binding{defaultHTTPDBinding},
|
||||
|
@ -1967,6 +1972,9 @@ func setViperDefaults() {
|
|||
viper.SetDefault("data_provider.create_default_admin", globalConf.ProviderConf.CreateDefaultAdmin)
|
||||
viper.SetDefault("data_provider.naming_rules", globalConf.ProviderConf.NamingRules)
|
||||
viper.SetDefault("data_provider.is_shared", globalConf.ProviderConf.IsShared)
|
||||
viper.SetDefault("data_provider.node.host", globalConf.ProviderConf.Node.Host)
|
||||
viper.SetDefault("data_provider.node.port", globalConf.ProviderConf.Node.Port)
|
||||
viper.SetDefault("data_provider.node.proto", globalConf.ProviderConf.Node.Proto)
|
||||
viper.SetDefault("data_provider.backups_path", globalConf.ProviderConf.BackupsPath)
|
||||
viper.SetDefault("httpd.templates_path", globalConf.HTTPDConfig.TemplatesPath)
|
||||
viper.SetDefault("httpd.static_files_path", globalConf.HTTPDConfig.StaticFilesPath)
|
||||
|
|
|
@ -35,7 +35,7 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
boltDatabaseVersion = 22
|
||||
boltDatabaseVersion = 23
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -2483,19 +2483,39 @@ func (p *BoltProvider) deleteEventRule(rule EventRule, softDelete bool) error {
|
|||
})
|
||||
}
|
||||
|
||||
func (p *BoltProvider) getTaskByName(name string) (Task, error) {
|
||||
func (*BoltProvider) getTaskByName(name string) (Task, error) {
|
||||
return Task{}, ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *BoltProvider) addTask(name string) error {
|
||||
func (*BoltProvider) addTask(name string) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *BoltProvider) updateTask(name string, version int64) error {
|
||||
func (*BoltProvider) updateTask(name string, version int64) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *BoltProvider) updateTaskTimestamp(name string) error {
|
||||
func (*BoltProvider) updateTaskTimestamp(name string) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (*BoltProvider) addNode() error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (*BoltProvider) getNodeByName(name string) (Node, error) {
|
||||
return Node{}, ErrNotImplemented
|
||||
}
|
||||
|
||||
func (*BoltProvider) getNodes() ([]Node, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
func (*BoltProvider) updateNodeTimestamp() error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (*BoltProvider) cleanupNodes() error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
|
@ -2583,10 +2603,10 @@ func (p *BoltProvider) migrateDatabase() error {
|
|||
providerLog(logger.LevelError, "%v", err)
|
||||
logger.ErrorToConsole("%v", err)
|
||||
return err
|
||||
case version == 19, version == 20, version == 21:
|
||||
logger.InfoToConsole(fmt.Sprintf("updating database schema version: %d -> 22", version))
|
||||
providerLog(logger.LevelInfo, "updating database schema version: %d -> 22", version)
|
||||
return updateBoltDatabaseVersion(p.dbHandle, 22)
|
||||
case version == 19, version == 20, version == 21, version == 22:
|
||||
logger.InfoToConsole(fmt.Sprintf("updating database schema version: %d -> 23", version))
|
||||
providerLog(logger.LevelInfo, "updating database schema version: %d -> 23", version)
|
||||
return updateBoltDatabaseVersion(p.dbHandle, 23)
|
||||
default:
|
||||
if version > boltDatabaseVersion {
|
||||
providerLog(logger.LevelError, "database schema version %v is newer than the supported one: %v", version,
|
||||
|
@ -2608,7 +2628,7 @@ func (p *BoltProvider) revertDatabase(targetVersion int) error {
|
|||
return errors.New("current version match target version, nothing to do")
|
||||
}
|
||||
switch dbVersion.Version {
|
||||
case 20, 21:
|
||||
case 20, 21, 22, 23:
|
||||
logger.InfoToConsole("downgrading database schema version: %d -> 19", dbVersion.Version)
|
||||
providerLog(logger.LevelInfo, "downgrading database schema version: %d -> 19", dbVersion.Version)
|
||||
err := p.dbHandle.Update(func(tx *bolt.Tx) error {
|
||||
|
|
|
@ -187,6 +187,7 @@ var (
|
|||
sqlTableEventsRules string
|
||||
sqlTableRulesActionsMapping string
|
||||
sqlTableTasks string
|
||||
sqlTableNodes string
|
||||
sqlTableSchemaVersion string
|
||||
argon2Params *argon2id.Params
|
||||
lastLoginMinDelay = 10 * time.Minute
|
||||
|
@ -216,6 +217,7 @@ func initSQLTables() {
|
|||
sqlTableEventsRules = "events_rules"
|
||||
sqlTableRulesActionsMapping = "rules_actions_mapping"
|
||||
sqlTableTasks = "tasks"
|
||||
sqlTableNodes = "nodes"
|
||||
sqlTableSchemaVersion = "schema_version"
|
||||
}
|
||||
|
||||
|
@ -311,7 +313,7 @@ type ProviderStatus struct {
|
|||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
// Config provider configuration
|
||||
// Config defines the provider configuration
|
||||
type Config struct {
|
||||
// Driver name, must be one of the SupportedProviders
|
||||
Driver string `json:"driver" mapstructure:"driver"`
|
||||
|
@ -460,6 +462,9 @@ type Config struct {
|
|||
// For shared data providers, active transfers are persisted in the database and thus
|
||||
// quota checks between ongoing transfers will work cross multiple instances
|
||||
IsShared int `json:"is_shared" mapstructure:"is_shared"`
|
||||
// Node defines the configuration for this cluster node.
|
||||
// Ignored if the provider is not shared/shareable
|
||||
Node NodeConfig `json:"node" mapstructure:"node"`
|
||||
// Path to the backup directory. This can be an absolute path or a path relative to the config dir
|
||||
BackupsPath string `json:"backups_path" mapstructure:"backups_path"`
|
||||
}
|
||||
|
@ -778,6 +783,11 @@ type Provider interface {
|
|||
updateTaskTimestamp(name string) error
|
||||
setFirstDownloadTimestamp(username string) error
|
||||
setFirstUploadTimestamp(username string) error
|
||||
addNode() error
|
||||
getNodeByName(name string) (Node, error)
|
||||
getNodes() ([]Node, error)
|
||||
updateNodeTimestamp() error
|
||||
cleanupNodes() error
|
||||
checkAvailability() error
|
||||
close() error
|
||||
reloadConfig() error
|
||||
|
@ -801,7 +811,6 @@ func checkSharedMode() {
|
|||
// Initialize the data provider.
|
||||
// An error is returned if the configured driver is invalid or if the data provider cannot be initialized
|
||||
func Initialize(cnf Config, basePath string, checkAdmins bool) error {
|
||||
var err error
|
||||
config = cnf
|
||||
checkSharedMode()
|
||||
config.Actions.ExecuteOn = util.RemoveDuplicates(config.Actions.ExecuteOn, true)
|
||||
|
@ -812,19 +821,33 @@ func Initialize(cnf Config, basePath string, checkAdmins bool) error {
|
|||
return fmt.Errorf("required directory is invalid, backup path %#v", cnf.BackupsPath)
|
||||
}
|
||||
|
||||
if err = initializeHashingAlgo(&cnf); err != nil {
|
||||
if err := initializeHashingAlgo(&cnf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = validateHooks(); err != nil {
|
||||
if err := validateHooks(); err != nil {
|
||||
return err
|
||||
}
|
||||
err = createProvider(basePath)
|
||||
if err := createProvider(basePath); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := checkDatabase(checkAdmins); err != nil {
|
||||
return err
|
||||
}
|
||||
admins, err := provider.getAdmins(1, 0, OrderASC)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cnf.UpdateMode == 0 {
|
||||
err = provider.initializeDatabase()
|
||||
isAdminCreated.Store(len(admins) > 0)
|
||||
if err := config.Node.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
delayedQuotaUpdater.start()
|
||||
return startScheduler()
|
||||
}
|
||||
|
||||
func checkDatabase(checkAdmins bool) error {
|
||||
if config.UpdateMode == 0 {
|
||||
err := provider.initializeDatabase()
|
||||
if err != nil && err != ErrNoInitRequired {
|
||||
logger.WarnToConsole("Unable to initialize data provider: %v", err)
|
||||
providerLog(logger.LevelError, "Unable to initialize data provider: %v", err)
|
||||
|
@ -838,7 +861,7 @@ func Initialize(cnf Config, basePath string, checkAdmins bool) error {
|
|||
providerLog(logger.LevelError, "database migration error: %v", err)
|
||||
return err
|
||||
}
|
||||
if checkAdmins && cnf.CreateDefaultAdmin {
|
||||
if checkAdmins && config.CreateDefaultAdmin {
|
||||
err = checkDefaultAdmin()
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "erro checking the default admin: %v", err)
|
||||
|
@ -848,13 +871,7 @@ func Initialize(cnf Config, basePath string, checkAdmins bool) error {
|
|||
} else {
|
||||
providerLog(logger.LevelInfo, "database initialization/migration skipped, manual mode is configured")
|
||||
}
|
||||
admins, err := provider.getAdmins(1, 0, OrderASC)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
isAdminCreated.Store(len(admins) > 0)
|
||||
delayedQuotaUpdater.start()
|
||||
return startScheduler()
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateHooks() error {
|
||||
|
@ -937,15 +954,17 @@ func validateSQLTablesPrefix() error {
|
|||
sqlTableEventsRules = config.SQLTablesPrefix + sqlTableEventsRules
|
||||
sqlTableRulesActionsMapping = config.SQLTablesPrefix + sqlTableRulesActionsMapping
|
||||
sqlTableTasks = config.SQLTablesPrefix + sqlTableTasks
|
||||
sqlTableNodes = config.SQLTablesPrefix + sqlTableNodes
|
||||
sqlTableSchemaVersion = config.SQLTablesPrefix + sqlTableSchemaVersion
|
||||
providerLog(logger.LevelDebug, "sql table for users %q, folders %q users folders mapping %q admins %q "+
|
||||
"api keys %q shares %q defender hosts %q defender events %q transfers %q groups %q "+
|
||||
"users groups mapping %q admins groups mapping %q groups folders mapping %q shared sessions %q "+
|
||||
"schema version %q events actions %q events rules %q rules actions mapping %q tasks %q",
|
||||
"schema version %q events actions %q events rules %q rules actions mapping %q tasks %q nodes %q",
|
||||
sqlTableUsers, sqlTableFolders, sqlTableUsersFoldersMapping, sqlTableAdmins, sqlTableAPIKeys,
|
||||
sqlTableShares, sqlTableDefenderHosts, sqlTableDefenderEvents, sqlTableActiveTransfers, sqlTableGroups,
|
||||
sqlTableUsersGroupsMapping, sqlTableAdminsGroupsMapping, sqlTableGroupsFoldersMapping, sqlTableSharedSessions,
|
||||
sqlTableSchemaVersion, sqlTableEventsActions, sqlTableEventsRules, sqlTableRulesActionsMapping, sqlTableTasks)
|
||||
sqlTableSchemaVersion, sqlTableEventsActions, sqlTableEventsRules, sqlTableRulesActionsMapping,
|
||||
sqlTableTasks, sqlTableNodes)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -1728,6 +1747,29 @@ func UpdateTaskTimestamp(name string) error {
|
|||
return provider.updateTaskTimestamp(name)
|
||||
}
|
||||
|
||||
// GetNodes returns the other cluster nodes
|
||||
func GetNodes() ([]Node, error) {
|
||||
if currentNode == nil {
|
||||
return nil, nil
|
||||
}
|
||||
nodes, err := provider.getNodes()
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "unable to get other cluster nodes %v", err)
|
||||
}
|
||||
return nodes, err
|
||||
}
|
||||
|
||||
// GetNodeByName returns a node, different from the current one, by name
|
||||
func GetNodeByName(name string) (Node, error) {
|
||||
if currentNode == nil {
|
||||
return Node{}, util.NewRecordNotFoundError(errNoClusterNodes.Error())
|
||||
}
|
||||
if name == currentNode.Name {
|
||||
return Node{}, util.NewValidationError(fmt.Sprintf("%s is the current node, it must refer to other nodes", name))
|
||||
}
|
||||
return provider.getNodeByName(name)
|
||||
}
|
||||
|
||||
// HasAdmin returns true if the first admin has been created
|
||||
// and so SFTPGo is ready to be used
|
||||
func HasAdmin() bool {
|
||||
|
|
|
@ -2286,19 +2286,39 @@ func (p *MemoryProvider) deleteEventRule(rule EventRule, softDelete bool) error
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) getTaskByName(name string) (Task, error) {
|
||||
func (*MemoryProvider) getTaskByName(name string) (Task, error) {
|
||||
return Task{}, ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) addTask(name string) error {
|
||||
func (*MemoryProvider) addTask(name string) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) updateTask(name string, version int64) error {
|
||||
func (*MemoryProvider) updateTask(name string, version int64) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) updateTaskTimestamp(name string) error {
|
||||
func (*MemoryProvider) updateTaskTimestamp(name string) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (*MemoryProvider) addNode() error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (*MemoryProvider) getNodeByName(name string) (Node, error) {
|
||||
return Node{}, ErrNotImplemented
|
||||
}
|
||||
|
||||
func (*MemoryProvider) getNodes() ([]Node, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
func (*MemoryProvider) updateNodeTimestamp() error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (*MemoryProvider) cleanupNodes() error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
|
|
|
@ -56,6 +56,7 @@ const (
|
|||
"DROP TABLE IF EXISTS `{{events_actions}}` CASCADE;" +
|
||||
"DROP TABLE IF EXISTS `{{events_rules}}` CASCADE;" +
|
||||
"DROP TABLE IF EXISTS `{{tasks}}` CASCADE;" +
|
||||
"DROP TABLE IF EXISTS `{{nodes}}` CASCADE;" +
|
||||
"DROP TABLE IF EXISTS `{{schema_version}}` CASCADE;"
|
||||
mysqlInitialSQL = "CREATE TABLE `{{schema_version}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `version` integer NOT NULL);" +
|
||||
"CREATE TABLE `{{admins}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `username` varchar(255) NOT NULL UNIQUE, " +
|
||||
|
@ -182,6 +183,10 @@ const (
|
|||
"FOREIGN KEY (`group_id`) REFERENCES `{{groups}}` (`id`) ON DELETE CASCADE;"
|
||||
mysqlV22DownSQL = "ALTER TABLE `{{admins_groups_mapping}}` DROP INDEX `{{prefix}}unique_admin_group_mapping`;" +
|
||||
"DROP TABLE `{{admins_groups_mapping}}` CASCADE;"
|
||||
mysqlV23SQL = "CREATE TABLE `{{nodes}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
|
||||
"`name` varchar(255) NOT NULL UNIQUE, `data` longtext NOT NULL, `created_at` bigint NOT NULL, " +
|
||||
"`updated_at` bigint NOT NULL);"
|
||||
mysqlV23DownSQL = "DROP TABLE `{{nodes}}` CASCADE;"
|
||||
)
|
||||
|
||||
// MySQLProvider defines the auth provider for MySQL/MariaDB database
|
||||
|
@ -644,6 +649,26 @@ func (p *MySQLProvider) updateTaskTimestamp(name string) error {
|
|||
return sqlCommonUpdateTaskTimestamp(name, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *MySQLProvider) addNode() error {
|
||||
return sqlCommonAddNode(p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *MySQLProvider) getNodeByName(name string) (Node, error) {
|
||||
return sqlCommonGetNodeByName(name, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *MySQLProvider) getNodes() ([]Node, error) {
|
||||
return sqlCommonGetNodes(p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *MySQLProvider) updateNodeTimestamp() error {
|
||||
return sqlCommonUpdateNodeTimestamp(p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *MySQLProvider) cleanupNodes() error {
|
||||
return sqlCommonCleanupNodes(p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *MySQLProvider) setFirstDownloadTimestamp(username string) error {
|
||||
return sqlCommonSetFirstDownloadTimestamp(username, p.dbHandle)
|
||||
}
|
||||
|
@ -697,6 +722,8 @@ func (p *MySQLProvider) migrateDatabase() error { //nolint:dupl
|
|||
return updateMySQLDatabaseFromV20(p.dbHandle)
|
||||
case version == 21:
|
||||
return updateMySQLDatabaseFromV21(p.dbHandle)
|
||||
case version == 22:
|
||||
return updateMySQLDatabaseFromV22(p.dbHandle)
|
||||
default:
|
||||
if version > sqlDatabaseVersion {
|
||||
providerLog(logger.LevelError, "database schema version %v is newer than the supported one: %v", version,
|
||||
|
@ -725,6 +752,8 @@ func (p *MySQLProvider) revertDatabase(targetVersion int) error {
|
|||
return downgradeMySQLDatabaseFromV21(p.dbHandle)
|
||||
case 22:
|
||||
return downgradeMySQLDatabaseFromV22(p.dbHandle)
|
||||
case 23:
|
||||
return downgradeMySQLDatabaseFromV23(p.dbHandle)
|
||||
default:
|
||||
return fmt.Errorf("database schema version not handled: %v", dbVersion.Version)
|
||||
}
|
||||
|
@ -750,7 +779,14 @@ func updateMySQLDatabaseFromV20(dbHandle *sql.DB) error {
|
|||
}
|
||||
|
||||
func updateMySQLDatabaseFromV21(dbHandle *sql.DB) error {
|
||||
return updateMySQLDatabaseFrom21To22(dbHandle)
|
||||
if err := updateMySQLDatabaseFrom21To22(dbHandle); err != nil {
|
||||
return err
|
||||
}
|
||||
return updateMySQLDatabaseFromV22(dbHandle)
|
||||
}
|
||||
|
||||
func updateMySQLDatabaseFromV22(dbHandle *sql.DB) error {
|
||||
return updateMySQLDatabaseFrom22To23(dbHandle)
|
||||
}
|
||||
|
||||
func downgradeMySQLDatabaseFromV20(dbHandle *sql.DB) error {
|
||||
|
@ -771,6 +807,13 @@ func downgradeMySQLDatabaseFromV22(dbHandle *sql.DB) error {
|
|||
return downgradeMySQLDatabaseFromV21(dbHandle)
|
||||
}
|
||||
|
||||
func downgradeMySQLDatabaseFromV23(dbHandle *sql.DB) error {
|
||||
if err := downgradeMySQLDatabaseFrom23To22(dbHandle); err != nil {
|
||||
return err
|
||||
}
|
||||
return downgradeMySQLDatabaseFromV22(dbHandle)
|
||||
}
|
||||
|
||||
func updateMySQLDatabaseFrom19To20(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("updating database schema version: 19 -> 20")
|
||||
providerLog(logger.LevelInfo, "updating database schema version: 19 -> 20")
|
||||
|
@ -800,6 +843,13 @@ func updateMySQLDatabaseFrom21To22(dbHandle *sql.DB) error {
|
|||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 22, true)
|
||||
}
|
||||
|
||||
func updateMySQLDatabaseFrom22To23(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("updating database schema version: 22 -> 23")
|
||||
providerLog(logger.LevelInfo, "updating database schema version: 22 -> 23")
|
||||
sql := strings.ReplaceAll(mysqlV23SQL, "{{nodes}}", sqlTableNodes)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 23, true)
|
||||
}
|
||||
|
||||
func downgradeMySQLDatabaseFrom20To19(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("downgrading database schema version: 20 -> 19")
|
||||
providerLog(logger.LevelInfo, "downgrading database schema version: 20 -> 19")
|
||||
|
@ -825,3 +875,10 @@ func downgradeMySQLDatabaseFrom22To21(dbHandle *sql.DB) error {
|
|||
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 21, false)
|
||||
}
|
||||
|
||||
func downgradeMySQLDatabaseFrom23To22(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("downgrading database schema version: 23 -> 22")
|
||||
providerLog(logger.LevelInfo, "downgrading database schema version: 23 -> 22")
|
||||
sql := strings.ReplaceAll(mysqlV23DownSQL, "{{nodes}}", sqlTableNodes)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 22, false)
|
||||
}
|
||||
|
|
240
internal/dataprovider/node.go
Normal file
240
internal/dataprovider/node.go
Normal file
|
@ -0,0 +1,240 @@
|
|||
package dataprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/jwa"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/drakkan/sftpgo/v2/internal/httpclient"
|
||||
"github.com/drakkan/sftpgo/v2/internal/kms"
|
||||
"github.com/drakkan/sftpgo/v2/internal/logger"
|
||||
"github.com/drakkan/sftpgo/v2/internal/util"
|
||||
)
|
||||
|
||||
// Supported protocols for connecting to other nodes
|
||||
const (
|
||||
NodeProtoHTTP = "http"
|
||||
NodeProtoHTTPS = "https"
|
||||
)
|
||||
|
||||
const (
|
||||
// NodeTokenHeader defines the header to use for the node auth token
|
||||
NodeTokenHeader = "X-SFTPGO-Node"
|
||||
)
|
||||
|
||||
var (
|
||||
// current node
|
||||
currentNode *Node
|
||||
errNoClusterNodes = errors.New("no cluster node defined")
|
||||
activeNodeTimeDiff = -2 * time.Minute
|
||||
nodeReqTimeout = 8 * time.Second
|
||||
)
|
||||
|
||||
// NodeConfig defines the node configuration
|
||||
type NodeConfig struct {
|
||||
Host string `json:"host" mapstructure:"host"`
|
||||
Port int `json:"port" mapstructure:"port"`
|
||||
Proto string `json:"proto" mapstructure:"proto"`
|
||||
}
|
||||
|
||||
func (n *NodeConfig) validate() error {
|
||||
currentNode = nil
|
||||
if config.IsShared != 1 {
|
||||
return nil
|
||||
}
|
||||
if n.Host == "" {
|
||||
return nil
|
||||
}
|
||||
currentNode = &Node{
|
||||
Data: NodeData{
|
||||
Host: n.Host,
|
||||
Port: n.Port,
|
||||
Proto: n.Proto,
|
||||
},
|
||||
}
|
||||
return provider.addNode()
|
||||
}
|
||||
|
||||
// NodeData defines the details to connect to a cluster node
|
||||
type NodeData struct {
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Proto string `json:"proto"`
|
||||
Key *kms.Secret `json:"api_key"`
|
||||
}
|
||||
|
||||
func (n *NodeData) validate() error {
|
||||
if n.Host == "" {
|
||||
return util.NewValidationError("node host is mandatory")
|
||||
}
|
||||
if n.Port < 0 || n.Port > 65535 {
|
||||
return util.NewValidationError(fmt.Sprintf("invalid node port: %d", n.Port))
|
||||
}
|
||||
if n.Proto != NodeProtoHTTP && n.Proto != NodeProtoHTTPS {
|
||||
return util.NewValidationError(fmt.Sprintf("invalid node proto: %s", n.Proto))
|
||||
}
|
||||
n.Key = kms.NewPlainSecret(string(util.GenerateRandomBytes(32)))
|
||||
n.Key.SetAdditionalData(n.Host)
|
||||
if err := n.Key.Encrypt(); err != nil {
|
||||
return fmt.Errorf("unable to encrypt node key: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Node defines a cluster node
|
||||
type Node struct {
|
||||
Name string `json:"name"`
|
||||
Data NodeData `json:"data"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
}
|
||||
|
||||
func (n *Node) validate() error {
|
||||
if n.Name == "" {
|
||||
n.Name = n.Data.Host
|
||||
}
|
||||
return n.Data.validate()
|
||||
}
|
||||
|
||||
func (n *Node) authenticate(token string) error {
|
||||
if err := n.Data.Key.TryDecrypt(); err != nil {
|
||||
providerLog(logger.LevelError, "unable to decrypt node key: %v", err)
|
||||
return err
|
||||
}
|
||||
if token == "" {
|
||||
return ErrInvalidCredentials
|
||||
}
|
||||
t, err := jwt.Parse([]byte(token), jwt.WithVerify(jwa.HS256, []byte(n.Data.Key.GetPayload())))
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to parse token: %v", err)
|
||||
}
|
||||
if err := jwt.Validate(t); err != nil {
|
||||
return fmt.Errorf("unable to validate token: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getBaseURL returns the base URL for this node
|
||||
func (n *Node) getBaseURL() string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(n.Data.Proto)
|
||||
sb.WriteString("://")
|
||||
sb.WriteString(n.Data.Host)
|
||||
if n.Data.Port > 0 {
|
||||
sb.WriteString(":")
|
||||
sb.WriteString(strconv.Itoa(n.Data.Port))
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// generateAuthToken generates a new auth token
|
||||
func (n *Node) generateAuthToken() (string, error) {
|
||||
if err := n.Data.Key.TryDecrypt(); err != nil {
|
||||
return "", fmt.Errorf("unable to decrypt node key: %w", err)
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
|
||||
t := jwt.New()
|
||||
t.Set(jwt.JwtIDKey, xid.New().String()) //nolint:errcheck
|
||||
t.Set(jwt.NotBeforeKey, now.Add(-30*time.Second)) //nolint:errcheck
|
||||
t.Set(jwt.ExpirationKey, now.Add(1*time.Minute)) //nolint:errcheck
|
||||
|
||||
payload, err := jwt.Sign(t, jwa.HS256, []byte(n.Data.Key.GetPayload()))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to sign authentication token: %w", err)
|
||||
}
|
||||
return string(payload), nil
|
||||
}
|
||||
|
||||
func (n *Node) prepareRequest(ctx context.Context, relativeURL, method string, body io.Reader) (*http.Request, error) {
|
||||
url := fmt.Sprintf("%s%s", n.getBaseURL(), relativeURL)
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token, err := n.generateAuthToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set(NodeTokenHeader, fmt.Sprintf("Bearer %s", token))
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// SendGetRequest sends an HTTP GET request to this node.
|
||||
// The responseHolder must be a pointer
|
||||
func (n *Node) SendGetRequest(relativeURL string, responseHolder any) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), nodeReqTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := n.prepareRequest(ctx, relativeURL, http.MethodGet, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
client := httpclient.GetHTTPClient()
|
||||
defer client.CloseIdleConnections()
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to send HTTP GET to node %s: %w", n.Name, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusNoContent {
|
||||
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(responseHolder)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to decode response as json")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendDeleteRequest sends an HTTP DELETE request to this node
|
||||
func (n *Node) SendDeleteRequest(relativeURL string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), nodeReqTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := n.prepareRequest(ctx, relativeURL, http.MethodDelete, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
client := httpclient.GetHTTPClient()
|
||||
defer client.CloseIdleConnections()
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to send HTTP DELETE to node %s: %w", n.Name, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusNoContent {
|
||||
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthenticateNodeToken check the validity of the provided token
|
||||
func AuthenticateNodeToken(token string) error {
|
||||
if currentNode == nil {
|
||||
return errNoClusterNodes
|
||||
}
|
||||
return currentNode.authenticate(token)
|
||||
}
|
||||
|
||||
// GetNodeName returns the node name or an empty string
|
||||
func GetNodeName() string {
|
||||
if currentNode == nil {
|
||||
return ""
|
||||
}
|
||||
return currentNode.Name
|
||||
}
|
|
@ -54,6 +54,7 @@ DROP TABLE IF EXISTS "{{rules_actions_mapping}}" CASCADE;
|
|||
DROP TABLE IF EXISTS "{{events_actions}}" CASCADE;
|
||||
DROP TABLE IF EXISTS "{{events_rules}}" CASCADE;
|
||||
DROP TABLE IF EXISTS "{{tasks}}" CASCADE;
|
||||
DROP TABLE IF EXISTS "{{nodes}}" CASCADE;
|
||||
DROP TABLE IF EXISTS "{{schema_version}}" CASCADE;
|
||||
`
|
||||
pgsqlInitial = `CREATE TABLE "{{schema_version}}" ("id" serial NOT NULL PRIMARY KEY, "version" integer NOT NULL);
|
||||
|
@ -198,6 +199,9 @@ CREATE INDEX "{{prefix}}admins_groups_mapping_group_id_idx" ON "{{admins_groups_
|
|||
pgsqlV22DownSQL = `ALTER TABLE "{{admins_groups_mapping}}" DROP CONSTRAINT "{{prefix}}unique_admin_group_mapping";
|
||||
DROP TABLE "{{admins_groups_mapping}}" CASCADE;
|
||||
`
|
||||
pgsqlV23SQL = `CREATE TABLE "{{nodes}}" ("id" serial NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE,
|
||||
"data" text NOT NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL);`
|
||||
pgsqlV23DownSQL = `DROP TABLE "{{nodes}}" CASCADE;`
|
||||
)
|
||||
|
||||
// PGSQLProvider defines the auth provider for PostgreSQL database
|
||||
|
@ -616,6 +620,26 @@ func (p *PGSQLProvider) updateTaskTimestamp(name string) error {
|
|||
return sqlCommonUpdateTaskTimestamp(name, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *PGSQLProvider) addNode() error {
|
||||
return sqlCommonAddNode(p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *PGSQLProvider) getNodeByName(name string) (Node, error) {
|
||||
return sqlCommonGetNodeByName(name, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *PGSQLProvider) getNodes() ([]Node, error) {
|
||||
return sqlCommonGetNodes(p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *PGSQLProvider) updateNodeTimestamp() error {
|
||||
return sqlCommonUpdateNodeTimestamp(p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *PGSQLProvider) cleanupNodes() error {
|
||||
return sqlCommonCleanupNodes(p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *PGSQLProvider) setFirstDownloadTimestamp(username string) error {
|
||||
return sqlCommonSetFirstDownloadTimestamp(username, p.dbHandle)
|
||||
}
|
||||
|
@ -669,6 +693,8 @@ func (p *PGSQLProvider) migrateDatabase() error { //nolint:dupl
|
|||
return updatePgSQLDatabaseFromV20(p.dbHandle)
|
||||
case version == 21:
|
||||
return updatePgSQLDatabaseFromV21(p.dbHandle)
|
||||
case version == 22:
|
||||
return updatePgSQLDatabaseFromV21(p.dbHandle)
|
||||
default:
|
||||
if version > sqlDatabaseVersion {
|
||||
providerLog(logger.LevelError, "database schema version %v is newer than the supported one: %v", version,
|
||||
|
@ -697,6 +723,8 @@ func (p *PGSQLProvider) revertDatabase(targetVersion int) error {
|
|||
return downgradePgSQLDatabaseFromV21(p.dbHandle)
|
||||
case 22:
|
||||
return downgradePgSQLDatabaseFromV22(p.dbHandle)
|
||||
case 23:
|
||||
return downgradePgSQLDatabaseFromV23(p.dbHandle)
|
||||
default:
|
||||
return fmt.Errorf("database schema version not handled: %v", dbVersion.Version)
|
||||
}
|
||||
|
@ -722,7 +750,14 @@ func updatePgSQLDatabaseFromV20(dbHandle *sql.DB) error {
|
|||
}
|
||||
|
||||
func updatePgSQLDatabaseFromV21(dbHandle *sql.DB) error {
|
||||
return updatePgSQLDatabaseFrom21To22(dbHandle)
|
||||
if err := updatePgSQLDatabaseFrom21To22(dbHandle); err != nil {
|
||||
return err
|
||||
}
|
||||
return updatePgSQLDatabaseFromV22(dbHandle)
|
||||
}
|
||||
|
||||
func updatePgSQLDatabaseFromV22(dbHandle *sql.DB) error {
|
||||
return updatePgSQLDatabaseFrom22To23(dbHandle)
|
||||
}
|
||||
|
||||
func downgradePgSQLDatabaseFromV20(dbHandle *sql.DB) error {
|
||||
|
@ -743,6 +778,13 @@ func downgradePgSQLDatabaseFromV22(dbHandle *sql.DB) error {
|
|||
return downgradePgSQLDatabaseFromV21(dbHandle)
|
||||
}
|
||||
|
||||
func downgradePgSQLDatabaseFromV23(dbHandle *sql.DB) error {
|
||||
if err := downgradePgSQLDatabaseFrom23To22(dbHandle); err != nil {
|
||||
return err
|
||||
}
|
||||
return downgradePgSQLDatabaseFromV22(dbHandle)
|
||||
}
|
||||
|
||||
func updatePgSQLDatabaseFrom19To20(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("updating database schema version: 19 -> 20")
|
||||
providerLog(logger.LevelInfo, "updating database schema version: 19 -> 20")
|
||||
|
@ -772,6 +814,13 @@ func updatePgSQLDatabaseFrom21To22(dbHandle *sql.DB) error {
|
|||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 22, true)
|
||||
}
|
||||
|
||||
func updatePgSQLDatabaseFrom22To23(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("updating database schema version: 22 -> 23")
|
||||
providerLog(logger.LevelInfo, "updating database schema version: 22 -> 23")
|
||||
sql := strings.ReplaceAll(pgsqlV23SQL, "{{nodes}}", sqlTableNodes)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 23, true)
|
||||
}
|
||||
|
||||
func downgradePgSQLDatabaseFrom20To19(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("downgrading database schema version: 20 -> 19")
|
||||
providerLog(logger.LevelInfo, "downgrading database schema version: 20 -> 19")
|
||||
|
@ -797,3 +846,10 @@ func downgradePgSQLDatabaseFrom22To21(dbHandle *sql.DB) error {
|
|||
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 21, false)
|
||||
}
|
||||
|
||||
func downgradePgSQLDatabaseFrom23To22(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("downgrading database schema version: 23 -> 22")
|
||||
providerLog(logger.LevelInfo, "downgrading database schema version: 23 -> 22")
|
||||
sql := strings.ReplaceAll(pgsqlV23DownSQL, "{{nodes}}", sqlTableNodes)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 22, false)
|
||||
}
|
||||
|
|
|
@ -46,7 +46,7 @@ func startScheduler() error {
|
|||
stopScheduler()
|
||||
|
||||
scheduler = cron.New(cron.WithLocation(time.UTC))
|
||||
_, err := scheduler.AddFunc("@every 60s", checkDataprovider)
|
||||
_, err := scheduler.AddFunc("@every 55s", checkDataprovider)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to schedule dataprovider availability check: %w", err)
|
||||
}
|
||||
|
@ -57,6 +57,19 @@ func startScheduler() error {
|
|||
if fnReloadRules != nil {
|
||||
fnReloadRules()
|
||||
}
|
||||
if currentNode != nil {
|
||||
_, err = scheduler.AddFunc("@every 30m", func() {
|
||||
err := provider.cleanupNodes()
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "unable to cleanup nodes: %v", err)
|
||||
} else {
|
||||
providerLog(logger.LevelDebug, "cleanup nodes ok")
|
||||
}
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to schedule nodes cleanup: %w", err)
|
||||
}
|
||||
scheduler.Start()
|
||||
return nil
|
||||
}
|
||||
|
@ -71,6 +84,13 @@ func addScheduledCacheUpdates() error {
|
|||
}
|
||||
|
||||
func checkDataprovider() {
|
||||
if currentNode != nil {
|
||||
if err := provider.updateNodeTimestamp(); err != nil {
|
||||
providerLog(logger.LevelError, "unable to update node timestamp: %v", err)
|
||||
} else {
|
||||
providerLog(logger.LevelDebug, "node timestamp updated")
|
||||
}
|
||||
}
|
||||
err := provider.checkAvailability()
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "check availability error: %v", err)
|
||||
|
|
|
@ -34,7 +34,7 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
sqlDatabaseVersion = 22
|
||||
sqlDatabaseVersion = 23
|
||||
defaultSQLQueryTimeout = 10 * time.Second
|
||||
longSQLQueryTimeout = 60 * time.Second
|
||||
)
|
||||
|
@ -77,6 +77,7 @@ func sqlReplaceAll(sql string) string {
|
|||
sql = strings.ReplaceAll(sql, "{{events_rules}}", sqlTableEventsRules)
|
||||
sql = strings.ReplaceAll(sql, "{{rules_actions_mapping}}", sqlTableRulesActionsMapping)
|
||||
sql = strings.ReplaceAll(sql, "{{tasks}}", sqlTableTasks)
|
||||
sql = strings.ReplaceAll(sql, "{{nodes}}", sqlTableNodes)
|
||||
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
|
||||
return sql
|
||||
}
|
||||
|
@ -3250,6 +3251,101 @@ func sqlCommonDeleteTask(name string, dbHandle sqlQuerier) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func sqlCommonAddNode(dbHandle *sql.DB) error {
|
||||
if err := currentNode.validate(); err != nil {
|
||||
return fmt.Errorf("unable to register cluster node: %w", err)
|
||||
}
|
||||
data, err := json.Marshal(currentNode.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getAddNodeQuery()
|
||||
_, err = dbHandle.ExecContext(ctx, q, currentNode.Name, string(data), util.GetTimeAsMsSinceEpoch(time.Now()),
|
||||
util.GetTimeAsMsSinceEpoch(time.Now()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to register cluster node: %w", err)
|
||||
}
|
||||
providerLog(logger.LevelInfo, "registered as cluster node %q, port: %d, proto: %s",
|
||||
currentNode.Name, currentNode.Data.Port, currentNode.Data.Proto)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqlCommonGetNodeByName(name string, dbHandle *sql.DB) (Node, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
var data []byte
|
||||
var node Node
|
||||
|
||||
q := getNodeByNameQuery()
|
||||
row := dbHandle.QueryRowContext(ctx, q, name, util.GetTimeAsMsSinceEpoch(time.Now().Add(activeNodeTimeDiff)))
|
||||
err := row.Scan(&node.Name, &data, &node.CreatedAt, &node.UpdatedAt)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return node, util.NewRecordNotFoundError(err.Error())
|
||||
}
|
||||
return node, err
|
||||
}
|
||||
err = json.Unmarshal(data, &node.Data)
|
||||
return node, err
|
||||
}
|
||||
|
||||
func sqlCommonGetNodes(dbHandle *sql.DB) ([]Node, error) {
|
||||
var nodes []Node
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getNodesQuery()
|
||||
rows, err := dbHandle.QueryContext(ctx, q, currentNode.Name,
|
||||
util.GetTimeAsMsSinceEpoch(time.Now().Add(activeNodeTimeDiff)))
|
||||
if err != nil {
|
||||
return nodes, err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var node Node
|
||||
var data []byte
|
||||
|
||||
err = rows.Scan(&node.Name, &data, &node.CreatedAt, &node.UpdatedAt)
|
||||
if err != nil {
|
||||
return nodes, err
|
||||
}
|
||||
err = json.Unmarshal(data, &node.Data)
|
||||
if err != nil {
|
||||
return nodes, err
|
||||
}
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
|
||||
return nodes, rows.Err()
|
||||
}
|
||||
|
||||
func sqlCommonUpdateNodeTimestamp(dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getUpdateNodeTimestampQuery()
|
||||
res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), currentNode.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlCommonRequireRowAffected(res)
|
||||
}
|
||||
|
||||
func sqlCommonCleanupNodes(dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getCleanupNodesQuery()
|
||||
_, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now().Add(10*activeNodeTimeDiff)))
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonGetDatabaseVersion(dbHandle sqlQuerier, showInitWarn bool) (schemaVersion, error) {
|
||||
var result schemaVersion
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
|
|
|
@ -56,6 +56,7 @@ DROP TABLE IF EXISTS "{{rules_actions_mapping}}";
|
|||
DROP TABLE IF EXISTS "{{events_rules}}";
|
||||
DROP TABLE IF EXISTS "{{events_actions}}";
|
||||
DROP TABLE IF EXISTS "{{tasks}}";
|
||||
DROP TABLE IF EXISTS "{{nodes}}";
|
||||
DROP TABLE IF EXISTS "{{schema_version}}";
|
||||
`
|
||||
sqliteInitialSQL = `CREATE TABLE "{{schema_version}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "version" integer NOT NULL);
|
||||
|
@ -176,8 +177,11 @@ ALTER TABLE "{{users}}" DROP COLUMN "first_download";
|
|||
CREATE INDEX "{{prefix}}admins_groups_mapping_admin_id_idx" ON "{{admins_groups_mapping}}" ("admin_id");
|
||||
CREATE INDEX "{{prefix}}admins_groups_mapping_group_id_idx" ON "{{admins_groups_mapping}}" ("group_id");
|
||||
`
|
||||
sqliteV22DownSQL = `DROP TABLE "{{admins_groups_mapping}}";
|
||||
`
|
||||
sqliteV22DownSQL = `DROP TABLE "{{admins_groups_mapping}}";`
|
||||
sqliteV23SQL = `CREATE TABLE "{{nodes}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"name" varchar(255) NOT NULL UNIQUE, "data" text NOT NULL, "created_at" bigint NOT NULL,
|
||||
"updated_at" bigint NOT NULL);`
|
||||
sqliteV23DownSQL = `DROP TABLE "{{nodes}}";`
|
||||
)
|
||||
|
||||
// SQLiteProvider defines the auth provider for SQLite database
|
||||
|
@ -579,6 +583,26 @@ func (p *SQLiteProvider) updateTaskTimestamp(name string) error {
|
|||
return sqlCommonUpdateTaskTimestamp(name, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) addNode() error {
|
||||
return sqlCommonAddNode(p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) getNodeByName(name string) (Node, error) {
|
||||
return sqlCommonGetNodeByName(name, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) getNodes() ([]Node, error) {
|
||||
return sqlCommonGetNodes(p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) updateNodeTimestamp() error {
|
||||
return sqlCommonUpdateNodeTimestamp(p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) cleanupNodes() error {
|
||||
return sqlCommonCleanupNodes(p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) setFirstDownloadTimestamp(username string) error {
|
||||
return sqlCommonSetFirstDownloadTimestamp(username, p.dbHandle)
|
||||
}
|
||||
|
@ -632,6 +656,8 @@ func (p *SQLiteProvider) migrateDatabase() error { //nolint:dupl
|
|||
return updateSQLiteDatabaseFromV20(p.dbHandle)
|
||||
case version == 21:
|
||||
return updateSQLiteDatabaseFromV21(p.dbHandle)
|
||||
case version == 22:
|
||||
return updateSQLiteDatabaseFromV22(p.dbHandle)
|
||||
default:
|
||||
if version > sqlDatabaseVersion {
|
||||
providerLog(logger.LevelError, "database schema version %v is newer than the supported one: %v", version,
|
||||
|
@ -660,6 +686,8 @@ func (p *SQLiteProvider) revertDatabase(targetVersion int) error {
|
|||
return downgradeSQLiteDatabaseFromV21(p.dbHandle)
|
||||
case 22:
|
||||
return downgradeSQLiteDatabaseFromV22(p.dbHandle)
|
||||
case 23:
|
||||
return downgradeSQLiteDatabaseFromV23(p.dbHandle)
|
||||
default:
|
||||
return fmt.Errorf("database schema version not handled: %v", dbVersion.Version)
|
||||
}
|
||||
|
@ -685,7 +713,14 @@ func updateSQLiteDatabaseFromV20(dbHandle *sql.DB) error {
|
|||
}
|
||||
|
||||
func updateSQLiteDatabaseFromV21(dbHandle *sql.DB) error {
|
||||
return updateSQLiteDatabaseFrom21To22(dbHandle)
|
||||
if err := updateSQLiteDatabaseFrom21To22(dbHandle); err != nil {
|
||||
return err
|
||||
}
|
||||
return updateSQLiteDatabaseFromV22(dbHandle)
|
||||
}
|
||||
|
||||
func updateSQLiteDatabaseFromV22(dbHandle *sql.DB) error {
|
||||
return updateSQLiteDatabaseFrom22To23(dbHandle)
|
||||
}
|
||||
|
||||
func downgradeSQLiteDatabaseFromV20(dbHandle *sql.DB) error {
|
||||
|
@ -706,6 +741,13 @@ func downgradeSQLiteDatabaseFromV22(dbHandle *sql.DB) error {
|
|||
return downgradeSQLiteDatabaseFromV21(dbHandle)
|
||||
}
|
||||
|
||||
func downgradeSQLiteDatabaseFromV23(dbHandle *sql.DB) error {
|
||||
if err := downgradeSQLiteDatabaseFrom23To22(dbHandle); err != nil {
|
||||
return err
|
||||
}
|
||||
return downgradeSQLiteDatabaseFromV22(dbHandle)
|
||||
}
|
||||
|
||||
func updateSQLiteDatabaseFrom19To20(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("updating database schema version: 19 -> 20")
|
||||
providerLog(logger.LevelInfo, "updating database schema version: 19 -> 20")
|
||||
|
@ -735,6 +777,13 @@ func updateSQLiteDatabaseFrom21To22(dbHandle *sql.DB) error {
|
|||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 22, true)
|
||||
}
|
||||
|
||||
func updateSQLiteDatabaseFrom22To23(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("updating database schema version: 22 -> 23")
|
||||
providerLog(logger.LevelInfo, "updating database schema version: 22 -> 23")
|
||||
sql := strings.ReplaceAll(sqliteV23SQL, "{{nodes}}", sqlTableNodes)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 23, true)
|
||||
}
|
||||
|
||||
func downgradeSQLiteDatabaseFrom20To19(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("downgrading database schema version: 20 -> 19")
|
||||
providerLog(logger.LevelInfo, "downgrading database schema version: 20 -> 19")
|
||||
|
@ -761,6 +810,13 @@ func downgradeSQLiteDatabaseFrom22To21(dbHandle *sql.DB) error {
|
|||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 21, false)
|
||||
}
|
||||
|
||||
func downgradeSQLiteDatabaseFrom23To22(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("downgrading database schema version: 23 -> 22")
|
||||
providerLog(logger.LevelInfo, "downgrading database schema version: 23 -> 22")
|
||||
sql := strings.ReplaceAll(sqliteV23DownSQL, "{{nodes}}", sqlTableNodes)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 22, false)
|
||||
}
|
||||
|
||||
/*func setPragmaFK(dbHandle *sql.DB, value string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
|
|
@ -950,6 +950,36 @@ func getDeleteTaskQuery() string {
|
|||
return fmt.Sprintf(`DELETE FROM %s WHERE name = %s`, sqlTableTasks, sqlPlaceholders[0])
|
||||
}
|
||||
|
||||
func getAddNodeQuery() string {
|
||||
if config.Driver == MySQLDataProviderName {
|
||||
return fmt.Sprintf("INSERT INTO %s (`name`,`data`,created_at,`updated_at`) VALUES (%s,%s,%s,%s) ON DUPLICATE KEY UPDATE "+
|
||||
"`data`=VALUES(`data`), `created_at`=VALUES(`created_at`), `updated_at`=VALUES(`updated_at`)",
|
||||
sqlTableNodes, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
|
||||
}
|
||||
return fmt.Sprintf(`INSERT INTO %s (name,data,created_at,updated_at) VALUES (%s,%s,%s,%s) ON CONFLICT(name)
|
||||
DO UPDATE SET data=EXCLUDED.data, created_at=EXCLUDED.created_at, updated_at=EXCLUDED.updated_at`,
|
||||
sqlTableNodes, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
|
||||
}
|
||||
|
||||
func getUpdateNodeTimestampQuery() string {
|
||||
return fmt.Sprintf(`UPDATE %s SET updated_at=%s WHERE name = %s`,
|
||||
sqlTableNodes, sqlPlaceholders[0], sqlPlaceholders[1])
|
||||
}
|
||||
|
||||
func getNodeByNameQuery() string {
|
||||
return fmt.Sprintf(`SELECT name,data,created_at,updated_at FROM %s WHERE name = %s AND updated_at > %s`,
|
||||
sqlTableNodes, sqlPlaceholders[0], sqlPlaceholders[1])
|
||||
}
|
||||
|
||||
func getNodesQuery() string {
|
||||
return fmt.Sprintf(`SELECT name,data,created_at,updated_at FROM %s WHERE name != %s AND updated_at > %s`,
|
||||
sqlTableNodes, sqlPlaceholders[0], sqlPlaceholders[1])
|
||||
}
|
||||
|
||||
func getCleanupNodesQuery() string {
|
||||
return fmt.Sprintf(`DELETE FROM %s WHERE updated_at < %s`, sqlTableNodes, sqlPlaceholders[0])
|
||||
}
|
||||
|
||||
func getDatabaseVersionQuery() string {
|
||||
return fmt.Sprintf("SELECT version from %s LIMIT 1", sqlTableSchemaVersion)
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@ import (
|
|||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
@ -152,6 +153,20 @@ func getBoolQueryParam(r *http.Request, param string) bool {
|
|||
return r.URL.Query().Get(param) == "true"
|
||||
}
|
||||
|
||||
func getActiveConnections(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
|
||||
claims, err := getTokenClaims(r)
|
||||
if err != nil || claims.Username == "" {
|
||||
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
stats := common.Connections.GetStats()
|
||||
if claims.NodeID == "" {
|
||||
stats = append(stats, getNodesConnections()...)
|
||||
}
|
||||
render.JSON(w, r, stats)
|
||||
}
|
||||
|
||||
func handleCloseConnection(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
|
||||
connectionID := getURLParam(r, "connectionID")
|
||||
|
@ -159,11 +174,61 @@ func handleCloseConnection(w http.ResponseWriter, r *http.Request) {
|
|||
sendAPIResponse(w, r, nil, "connectionID is mandatory", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if common.Connections.Close(connectionID) {
|
||||
sendAPIResponse(w, r, nil, "Connection closed", http.StatusOK)
|
||||
} else {
|
||||
sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound)
|
||||
node := r.URL.Query().Get("node")
|
||||
if node == "" || node == dataprovider.GetNodeName() {
|
||||
if common.Connections.Close(connectionID) {
|
||||
sendAPIResponse(w, r, nil, "Connection closed", http.StatusOK)
|
||||
} else {
|
||||
sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound)
|
||||
}
|
||||
return
|
||||
}
|
||||
n, err := dataprovider.GetNodeByName(node)
|
||||
if err != nil {
|
||||
logger.Warn(logSender, "", "unable to get node with name %q: %v", node, err)
|
||||
status := getRespStatus(err)
|
||||
sendAPIResponse(w, r, nil, http.StatusText(status), status)
|
||||
return
|
||||
}
|
||||
if err := n.SendDeleteRequest(fmt.Sprintf("%s/%s", activeConnectionsPath, connectionID)); err != nil {
|
||||
logger.Warn(logSender, "", "unable to delete connection id %q from node %q: %v", connectionID, n.Name, err)
|
||||
sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
sendAPIResponse(w, r, nil, "Connection closed", http.StatusOK)
|
||||
}
|
||||
|
||||
// getNodesConnections returns the active connections from other nodes.
|
||||
// Errors are silently ignored
|
||||
func getNodesConnections() []common.ConnectionStatus {
|
||||
nodes, err := dataprovider.GetNodes()
|
||||
if err != nil || len(nodes) == 0 {
|
||||
return nil
|
||||
}
|
||||
var results []common.ConnectionStatus
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for _, n := range nodes {
|
||||
wg.Add(1)
|
||||
|
||||
go func(node dataprovider.Node) {
|
||||
defer wg.Done()
|
||||
|
||||
var stats []common.ConnectionStatus
|
||||
if err := node.SendGetRequest(activeConnectionsPath, &stats); err != nil {
|
||||
logger.Warn(logSender, "", "unable to get connections from node %s: %v", node.Name, err)
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
results = append(results, stats...)
|
||||
mu.Unlock()
|
||||
}(n)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
func getSearchFilters(w http.ResponseWriter, r *http.Request) (int, int, string, error) {
|
||||
|
|
|
@ -52,6 +52,7 @@ const (
|
|||
claimUsernameKey = "username"
|
||||
claimPermissionsKey = "permissions"
|
||||
claimAPIKey = "api_key"
|
||||
claimNodeID = "node_id"
|
||||
claimMustSetSecondFactorKey = "2fa_required"
|
||||
claimRequiredTwoFactorProtocols = "2fa_protos"
|
||||
claimHideUserPageSection = "hus"
|
||||
|
@ -74,6 +75,7 @@ type jwtTokenClaims struct {
|
|||
Signature string
|
||||
Audience []string
|
||||
APIKeyID string
|
||||
NodeID string
|
||||
MustSetTwoFactorAuth bool
|
||||
RequiredTwoFactorProtocols []string
|
||||
HideUserPageSections int
|
||||
|
@ -97,6 +99,9 @@ func (c *jwtTokenClaims) asMap() map[string]any {
|
|||
if c.APIKeyID != "" {
|
||||
claims[claimAPIKey] = c.APIKeyID
|
||||
}
|
||||
if c.NodeID != "" {
|
||||
claims[claimNodeID] = c.NodeID
|
||||
}
|
||||
claims[jwt.SubjectKey] = c.Signature
|
||||
if c.MustSetTwoFactorAuth {
|
||||
claims[claimMustSetSecondFactorKey] = c.MustSetTwoFactorAuth
|
||||
|
@ -157,6 +162,13 @@ func (c *jwtTokenClaims) Decode(token map[string]any) {
|
|||
}
|
||||
}
|
||||
|
||||
if val, ok := token[claimNodeID]; ok {
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
c.NodeID = v
|
||||
}
|
||||
}
|
||||
|
||||
permissions := token[claimPermissionsKey]
|
||||
c.Permissions = c.decodeSliceString(permissions)
|
||||
|
||||
|
|
|
@ -10342,6 +10342,15 @@ func TestDeleteActiveConnectionMock(t *testing.T) {
|
|||
setBearerForReq(req, token)
|
||||
rr := executeRequest(req)
|
||||
checkResponseCode(t, http.StatusNotFound, rr)
|
||||
req.Header.Set(dataprovider.NodeTokenHeader, "Bearer abc")
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusUnauthorized, rr)
|
||||
assert.Contains(t, rr.Body.String(), "the provided token cannot be authenticated")
|
||||
req, err = http.NewRequest(http.MethodDelete, activeConnectionsPath+"/connectionID?node=node1", nil)
|
||||
assert.NoError(t, err)
|
||||
setBearerForReq(req, token)
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusNotFound, rr)
|
||||
}
|
||||
|
||||
func TestNotFoundMock(t *testing.T) {
|
||||
|
|
|
@ -551,6 +551,11 @@ func TestInvalidToken(t *testing.T) {
|
|||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
assert.Contains(t, rr.Body.String(), "Invalid token claims")
|
||||
|
||||
rr = httptest.NewRecorder()
|
||||
getActiveConnections(rr, req)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
assert.Contains(t, rr.Body.String(), "Invalid token claims")
|
||||
|
||||
rr = httptest.NewRecorder()
|
||||
server.handleWebRestore(rr, req)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
|
|
|
@ -309,6 +309,41 @@ func verifyCSRFHeader(next http.Handler) http.Handler {
|
|||
})
|
||||
}
|
||||
|
||||
func checkNodeToken(tokenAuth *jwtauth.JWTAuth) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := r.Header.Get(dataprovider.NodeTokenHeader)
|
||||
if token == "" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
if len(token) > 7 && strings.ToUpper(token[0:6]) == "BEARER" {
|
||||
token = token[7:]
|
||||
}
|
||||
if err := dataprovider.AuthenticateNodeToken(token); err != nil {
|
||||
logger.Debug(logSender, "", "unable to authenticate node token %q: %v", token, err)
|
||||
sendAPIResponse(w, r, fmt.Errorf("the provided token cannot be authenticated"), "", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
c := jwtTokenClaims{
|
||||
Username: fmt.Sprintf("node %s", dataprovider.GetNodeName()),
|
||||
Permissions: []string{dataprovider.PermAdminViewConnections, dataprovider.PermAdminCloseConnections},
|
||||
NodeID: dataprovider.GetNodeName(),
|
||||
}
|
||||
|
||||
resp, err := c.createTokenResponse(tokenAuth, tokenAudienceAPI, util.GetIPFromRemoteAddress(r.RemoteAddr))
|
||||
if err != nil {
|
||||
sendAPIResponse(w, r, err, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", resp["access_token"]))
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func checkAPIKeyAuth(tokenAuth *jwtauth.JWTAuth, scope dataprovider.APIKeyScope) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -337,7 +372,7 @@ func checkAPIKeyAuth(tokenAuth *jwtauth.JWTAuth, scope dataprovider.APIKeyScope)
|
|||
return
|
||||
}
|
||||
if err := k.Authenticate(key); err != nil {
|
||||
logger.Debug(logSender, "unable to authenticate api key %#v: %v", apiKey, err)
|
||||
logger.Debug(logSender, "", "unable to authenticate api key %#v: %v", apiKey, err)
|
||||
sendAPIResponse(w, r, fmt.Errorf("the provided api key cannot be authenticated"), "", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1195,6 +1195,7 @@ func (s *httpdServer) initializeRouter() {
|
|||
s.router.Post(userPath+"/{username}/reset-password", resetUserPassword)
|
||||
|
||||
s.router.Group(func(router chi.Router) {
|
||||
router.Use(checkNodeToken(s.tokenAuth))
|
||||
router.Use(checkAPIKeyAuth(s.tokenAuth, dataprovider.APIKeyScopeAdmin))
|
||||
router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromHeader))
|
||||
router.Use(jwtAuthenticatorAPI)
|
||||
|
@ -1222,12 +1223,7 @@ func (s *httpdServer) initializeRouter() {
|
|||
render.JSON(w, r, getServicesStatus())
|
||||
})
|
||||
|
||||
router.With(s.checkPerm(dataprovider.PermAdminViewConnections)).
|
||||
Get(activeConnectionsPath, func(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
|
||||
render.JSON(w, r, common.Connections.GetStats())
|
||||
})
|
||||
|
||||
router.With(s.checkPerm(dataprovider.PermAdminViewConnections)).Get(activeConnectionsPath, getActiveConnections)
|
||||
router.With(s.checkPerm(dataprovider.PermAdminCloseConnections)).
|
||||
Delete(activeConnectionsPath+"/{connectionID}", handleCloseConnection)
|
||||
router.With(s.checkPerm(dataprovider.PermAdminQuotaScans)).Get(quotasBasePath+"/users/scans", getUsersQuotaScans)
|
||||
|
|
|
@ -2845,6 +2845,7 @@ func (s *httpdServer) handleWebGetStatus(w http.ResponseWriter, r *http.Request)
|
|||
func (s *httpdServer) handleWebGetConnections(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
|
||||
connectionStats := common.Connections.GetStats()
|
||||
connectionStats = append(connectionStats, getNodesConnections()...)
|
||||
data := connectionsPage{
|
||||
basePage: s.getBasePageData(pageConnectionsTitle, webConnectionsPath, r),
|
||||
Connections: connectionStats,
|
||||
|
|
|
@ -5459,6 +5459,9 @@ components:
|
|||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/Transfer'
|
||||
node:
|
||||
type: string
|
||||
description: 'Node identifier, omitted for single node installations'
|
||||
FolderRetention:
|
||||
type: object
|
||||
properties:
|
||||
|
|
|
@ -236,6 +236,11 @@
|
|||
"create_default_admin": false,
|
||||
"naming_rules": 1,
|
||||
"is_shared": 0,
|
||||
"node": {
|
||||
"host": "",
|
||||
"port": 0,
|
||||
"proto": "http"
|
||||
},
|
||||
"backups_path": "backups"
|
||||
},
|
||||
"httpd": {
|
||||
|
|
|
@ -40,6 +40,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|||
<thead>
|
||||
<tr>
|
||||
<th>ID</th>
|
||||
<th>Node</th>
|
||||
<th>Username</th>
|
||||
<th>Time</th>
|
||||
<th>Info</th>
|
||||
|
@ -50,6 +51,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|||
{{range .Connections}}
|
||||
<tr>
|
||||
<td>{{.ConnectionID}}</td>
|
||||
<td>{{.Node}}</td>
|
||||
<td>{{.Username}}</td>
|
||||
<td>{{.GetConnectionDuration}}</td>
|
||||
<td>{{.GetConnectionInfo}}</td>
|
||||
|
@ -105,8 +107,10 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|||
function disconnectAction() {
|
||||
var table = $('#dataTable').DataTable();
|
||||
table.button('disconnect:name').enable(false);
|
||||
var connectionID = table.row({ selected: true }).data()[0];
|
||||
var path = '{{.ConnectionsURL}}' + "/" + connectionID;
|
||||
var selectedData = table.row({ selected: true }).data()
|
||||
var connectionID = selectedData[0];
|
||||
var nodeID = selectedData[1];
|
||||
var path = '{{.ConnectionsURL}}' + "/" + fixedEncodeURIComponent(connectionID)+"?node="+encodeURIComponent(nodeID);
|
||||
$('#disconnectModal').modal('hide');
|
||||
$.ajax({
|
||||
url: path,
|
||||
|
@ -174,13 +178,13 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|||
"lengthChange": true,
|
||||
"columnDefs": [
|
||||
{
|
||||
"targets": [0],
|
||||
"targets": [0, 1],
|
||||
"visible": false,
|
||||
"searchable": false,
|
||||
"className": "noVis"
|
||||
},
|
||||
{
|
||||
"targets": [1],
|
||||
"targets": [2],
|
||||
"className": "noVis"
|
||||
}
|
||||
],
|
||||
|
@ -190,7 +194,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|||
"language": {
|
||||
"emptyTable": "No user connected"
|
||||
},
|
||||
"order": [[1, 'asc']]
|
||||
"order": [[2, 'asc']]
|
||||
});
|
||||
|
||||
new $.fn.dataTable.FixedHeader( table );
|
||||
|
|
Loading…
Reference in a new issue