add support for inter-node communications

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino 2022-09-25 19:48:55 +02:00
parent a538255034
commit 76e89d07d4
No known key found for this signature in database
GPG key ID: 2F1FB59433D5A8CB
25 changed files with 847 additions and 59 deletions

View file

@ -253,6 +253,10 @@ The configuration file contains the following sections:
- `create_default_admin`, boolean. Before you can use SFTPGo you need to create an admin account. If you open the admin web UI, a setup screen will guide you in creating the first admin account. You can automatically create the first admin account by enabling this setting and setting the environment variables `SFTPGO_DEFAULT_ADMIN_USERNAME` and `SFTPGO_DEFAULT_ADMIN_PASSWORD`. You can also create the first admin by loading initial data. This setting has no effect if an admin account is already found within the data provider. Default `false`.
- `naming_rules`, integer. Naming rules for usernames, folder and group names. `0` means no rules. `1` means you can use any UTF-8 character. The names are used in URIs for REST API and Web admin. If not set only unreserved URI characters are allowed: ALPHA / DIGIT / "-" / "." / "_" / "~". `2` means names are converted to lowercase before saving/matching and so case insensitive matching is possible. `3` means trimming trailing and leading white spaces before saving/matching. Rules can be combined, for example `3` means both converting to lowercase and allowing any UTF-8 character. Enabling these options for existing installations could be backward incompatible, some users could be unable to login, for example existing users with mixed cases in their usernames. You have to ensure that all existing users respect the defined rules. Default: `1`.
- `is_shared`, integer. If the data provider is shared across multiple SFTPGo instances, set this parameter to `1`. `MySQL`, `PostgreSQL` and `CockroachDB` can be shared, this setting is ignored for other data providers. For shared data providers, active transfers are persisted in the database and thus quota checks between ongoing transfers will work cross multiple instances. Password reset requests and OIDC tokens/states are also persisted in the database if the provider is shared. For shared data providers, scheduled event actions are only executed on a single SFTPGo instance by default, you can override this behavior on a per-action basis. The database table `shared_sessions` is used only to store temporary sessions. In performance critical installations, you might consider using a database-specific optimization, for example you might use an `UNLOGGED` table for PostgreSQL. This optimization in only required in very limited use cases. Default: `0`.
- `node`, struct. Node-specific configurations to allow inter-node communications. If your provider is shared across multiple nodes, the nodes can exchange information to present a uniform view for node-specific data. The current implementation allows to obtain active connections from all nodes. Nodes connect to each other using the REST API.
- `host`, string. IP address or hostname that other nodes can use to connect to this node via REST API. Empty means inter-node communications disabled. Default: empty.
- `port`, integer. The port that other nodes can use to connect to this node via REST API. Default: `0`
- `proto`, string. Supported values `http` or `https`. For `https` the configurations for http clients is used, so you can, for example, enable mutual TLS authentication. Default: `http`
- `backups_path`, string. Path to the backup directory. This can be an absolute path or a path relative to the config dir. We don't allow backups in arbitrary paths for security reasons.
- **"httpd"**, the configuration for the HTTP server used to serve REST API and to expose the built-in web interface
- `bindings`, list of structs. Each struct has the following fields:

2
go.mod
View file

@ -156,7 +156,7 @@ require (
golang.org/x/tools v0.1.12 // indirect
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20220921223823-23cae91e6737 // indirect
google.golang.org/genproto v0.0.0-20220923205249-dd2d53f1fffc // indirect
google.golang.org/grpc v1.49.0 // indirect
google.golang.org/protobuf v1.28.1 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect

4
go.sum
View file

@ -1229,8 +1229,8 @@ google.golang.org/genproto v0.0.0-20220523171625-347a074981d8/go.mod h1:RAyBrSAP
google.golang.org/genproto v0.0.0-20220608133413-ed9918b62aac/go.mod h1:KEWEmljWE5zPzLBa/oHl6DaEt9LmfH6WtH1OHIvleBA=
google.golang.org/genproto v0.0.0-20220616135557-88e70c0c3a90/go.mod h1:KEWEmljWE5zPzLBa/oHl6DaEt9LmfH6WtH1OHIvleBA=
google.golang.org/genproto v0.0.0-20220624142145-8cd45d7dbd1f/go.mod h1:KEWEmljWE5zPzLBa/oHl6DaEt9LmfH6WtH1OHIvleBA=
google.golang.org/genproto v0.0.0-20220921223823-23cae91e6737 h1:K1zaaMdYBXRyX+cwFnxj7M6zwDyumLQMZ5xqwGvjreQ=
google.golang.org/genproto v0.0.0-20220921223823-23cae91e6737/go.mod h1:2r/26NEF3bFmT3eC3aZreahSal0C3Shl8Gi6vyDYqOQ=
google.golang.org/genproto v0.0.0-20220923205249-dd2d53f1fffc h1:saaNe2+SBQxandnzcD/qB1JEBQ2Pqew+KlFLLdA/XcM=
google.golang.org/genproto v0.0.0-20220923205249-dd2d53f1fffc/go.mod h1:yEEpwVWKMZZzo81NwRgyEJnA2fQvpXAYPVisv8EgDVs=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=

View file

@ -1076,6 +1076,7 @@ func (conns *ActiveConnections) GetStats() []ConnectionStatus {
defer conns.RUnlock()
stats := make([]ConnectionStatus, 0, len(conns.connections))
node := dataprovider.GetNodeName()
for _, c := range conns.connections {
stat := ConnectionStatus{
Username: c.GetUsername(),
@ -1087,6 +1088,7 @@ func (conns *ActiveConnections) GetStats() []ConnectionStatus {
Protocol: c.GetProtocol(),
Command: c.GetCommand(),
Transfers: c.GetTransfers(),
Node: node,
}
stats = append(stats, stat)
}
@ -1113,6 +1115,8 @@ type ConnectionStatus struct {
Transfers []ConnectionTransfer `json:"active_transfers,omitempty"`
// SSH command or WebDAV method
Command string `json:"command,omitempty"`
// Node identifier, omitted for single node installations
Node string `json:"node,omitempty"`
}
// GetConnectionDuration returns the connection duration as string

View file

@ -363,6 +363,11 @@ func Init() {
CreateDefaultAdmin: false,
NamingRules: 1,
IsShared: 0,
Node: dataprovider.NodeConfig{
Host: "",
Port: 0,
Proto: "http",
},
BackupsPath: "backups",
},
HTTPDConfig: httpd.Conf{
@ -1967,6 +1972,9 @@ func setViperDefaults() {
viper.SetDefault("data_provider.create_default_admin", globalConf.ProviderConf.CreateDefaultAdmin)
viper.SetDefault("data_provider.naming_rules", globalConf.ProviderConf.NamingRules)
viper.SetDefault("data_provider.is_shared", globalConf.ProviderConf.IsShared)
viper.SetDefault("data_provider.node.host", globalConf.ProviderConf.Node.Host)
viper.SetDefault("data_provider.node.port", globalConf.ProviderConf.Node.Port)
viper.SetDefault("data_provider.node.proto", globalConf.ProviderConf.Node.Proto)
viper.SetDefault("data_provider.backups_path", globalConf.ProviderConf.BackupsPath)
viper.SetDefault("httpd.templates_path", globalConf.HTTPDConfig.TemplatesPath)
viper.SetDefault("httpd.static_files_path", globalConf.HTTPDConfig.StaticFilesPath)

View file

@ -35,7 +35,7 @@ import (
)
const (
boltDatabaseVersion = 22
boltDatabaseVersion = 23
)
var (
@ -2483,19 +2483,39 @@ func (p *BoltProvider) deleteEventRule(rule EventRule, softDelete bool) error {
})
}
func (p *BoltProvider) getTaskByName(name string) (Task, error) {
func (*BoltProvider) getTaskByName(name string) (Task, error) {
return Task{}, ErrNotImplemented
}
func (p *BoltProvider) addTask(name string) error {
func (*BoltProvider) addTask(name string) error {
return ErrNotImplemented
}
func (p *BoltProvider) updateTask(name string, version int64) error {
func (*BoltProvider) updateTask(name string, version int64) error {
return ErrNotImplemented
}
func (p *BoltProvider) updateTaskTimestamp(name string) error {
func (*BoltProvider) updateTaskTimestamp(name string) error {
return ErrNotImplemented
}
func (*BoltProvider) addNode() error {
return ErrNotImplemented
}
func (*BoltProvider) getNodeByName(name string) (Node, error) {
return Node{}, ErrNotImplemented
}
func (*BoltProvider) getNodes() ([]Node, error) {
return nil, ErrNotImplemented
}
func (*BoltProvider) updateNodeTimestamp() error {
return ErrNotImplemented
}
func (*BoltProvider) cleanupNodes() error {
return ErrNotImplemented
}
@ -2583,10 +2603,10 @@ func (p *BoltProvider) migrateDatabase() error {
providerLog(logger.LevelError, "%v", err)
logger.ErrorToConsole("%v", err)
return err
case version == 19, version == 20, version == 21:
logger.InfoToConsole(fmt.Sprintf("updating database schema version: %d -> 22", version))
providerLog(logger.LevelInfo, "updating database schema version: %d -> 22", version)
return updateBoltDatabaseVersion(p.dbHandle, 22)
case version == 19, version == 20, version == 21, version == 22:
logger.InfoToConsole(fmt.Sprintf("updating database schema version: %d -> 23", version))
providerLog(logger.LevelInfo, "updating database schema version: %d -> 23", version)
return updateBoltDatabaseVersion(p.dbHandle, 23)
default:
if version > boltDatabaseVersion {
providerLog(logger.LevelError, "database schema version %v is newer than the supported one: %v", version,
@ -2608,7 +2628,7 @@ func (p *BoltProvider) revertDatabase(targetVersion int) error {
return errors.New("current version match target version, nothing to do")
}
switch dbVersion.Version {
case 20, 21:
case 20, 21, 22, 23:
logger.InfoToConsole("downgrading database schema version: %d -> 19", dbVersion.Version)
providerLog(logger.LevelInfo, "downgrading database schema version: %d -> 19", dbVersion.Version)
err := p.dbHandle.Update(func(tx *bolt.Tx) error {

View file

@ -187,6 +187,7 @@ var (
sqlTableEventsRules string
sqlTableRulesActionsMapping string
sqlTableTasks string
sqlTableNodes string
sqlTableSchemaVersion string
argon2Params *argon2id.Params
lastLoginMinDelay = 10 * time.Minute
@ -216,6 +217,7 @@ func initSQLTables() {
sqlTableEventsRules = "events_rules"
sqlTableRulesActionsMapping = "rules_actions_mapping"
sqlTableTasks = "tasks"
sqlTableNodes = "nodes"
sqlTableSchemaVersion = "schema_version"
}
@ -311,7 +313,7 @@ type ProviderStatus struct {
Error string `json:"error"`
}
// Config provider configuration
// Config defines the provider configuration
type Config struct {
// Driver name, must be one of the SupportedProviders
Driver string `json:"driver" mapstructure:"driver"`
@ -460,6 +462,9 @@ type Config struct {
// For shared data providers, active transfers are persisted in the database and thus
// quota checks between ongoing transfers will work cross multiple instances
IsShared int `json:"is_shared" mapstructure:"is_shared"`
// Node defines the configuration for this cluster node.
// Ignored if the provider is not shared/shareable
Node NodeConfig `json:"node" mapstructure:"node"`
// Path to the backup directory. This can be an absolute path or a path relative to the config dir
BackupsPath string `json:"backups_path" mapstructure:"backups_path"`
}
@ -778,6 +783,11 @@ type Provider interface {
updateTaskTimestamp(name string) error
setFirstDownloadTimestamp(username string) error
setFirstUploadTimestamp(username string) error
addNode() error
getNodeByName(name string) (Node, error)
getNodes() ([]Node, error)
updateNodeTimestamp() error
cleanupNodes() error
checkAvailability() error
close() error
reloadConfig() error
@ -801,7 +811,6 @@ func checkSharedMode() {
// Initialize the data provider.
// An error is returned if the configured driver is invalid or if the data provider cannot be initialized
func Initialize(cnf Config, basePath string, checkAdmins bool) error {
var err error
config = cnf
checkSharedMode()
config.Actions.ExecuteOn = util.RemoveDuplicates(config.Actions.ExecuteOn, true)
@ -812,19 +821,33 @@ func Initialize(cnf Config, basePath string, checkAdmins bool) error {
return fmt.Errorf("required directory is invalid, backup path %#v", cnf.BackupsPath)
}
if err = initializeHashingAlgo(&cnf); err != nil {
if err := initializeHashingAlgo(&cnf); err != nil {
return err
}
if err = validateHooks(); err != nil {
if err := validateHooks(); err != nil {
return err
}
err = createProvider(basePath)
if err := createProvider(basePath); err != nil {
return err
}
if err := checkDatabase(checkAdmins); err != nil {
return err
}
admins, err := provider.getAdmins(1, 0, OrderASC)
if err != nil {
return err
}
if cnf.UpdateMode == 0 {
err = provider.initializeDatabase()
isAdminCreated.Store(len(admins) > 0)
if err := config.Node.validate(); err != nil {
return err
}
delayedQuotaUpdater.start()
return startScheduler()
}
func checkDatabase(checkAdmins bool) error {
if config.UpdateMode == 0 {
err := provider.initializeDatabase()
if err != nil && err != ErrNoInitRequired {
logger.WarnToConsole("Unable to initialize data provider: %v", err)
providerLog(logger.LevelError, "Unable to initialize data provider: %v", err)
@ -838,7 +861,7 @@ func Initialize(cnf Config, basePath string, checkAdmins bool) error {
providerLog(logger.LevelError, "database migration error: %v", err)
return err
}
if checkAdmins && cnf.CreateDefaultAdmin {
if checkAdmins && config.CreateDefaultAdmin {
err = checkDefaultAdmin()
if err != nil {
providerLog(logger.LevelError, "erro checking the default admin: %v", err)
@ -848,13 +871,7 @@ func Initialize(cnf Config, basePath string, checkAdmins bool) error {
} else {
providerLog(logger.LevelInfo, "database initialization/migration skipped, manual mode is configured")
}
admins, err := provider.getAdmins(1, 0, OrderASC)
if err != nil {
return err
}
isAdminCreated.Store(len(admins) > 0)
delayedQuotaUpdater.start()
return startScheduler()
return nil
}
func validateHooks() error {
@ -937,15 +954,17 @@ func validateSQLTablesPrefix() error {
sqlTableEventsRules = config.SQLTablesPrefix + sqlTableEventsRules
sqlTableRulesActionsMapping = config.SQLTablesPrefix + sqlTableRulesActionsMapping
sqlTableTasks = config.SQLTablesPrefix + sqlTableTasks
sqlTableNodes = config.SQLTablesPrefix + sqlTableNodes
sqlTableSchemaVersion = config.SQLTablesPrefix + sqlTableSchemaVersion
providerLog(logger.LevelDebug, "sql table for users %q, folders %q users folders mapping %q admins %q "+
"api keys %q shares %q defender hosts %q defender events %q transfers %q groups %q "+
"users groups mapping %q admins groups mapping %q groups folders mapping %q shared sessions %q "+
"schema version %q events actions %q events rules %q rules actions mapping %q tasks %q",
"schema version %q events actions %q events rules %q rules actions mapping %q tasks %q nodes %q",
sqlTableUsers, sqlTableFolders, sqlTableUsersFoldersMapping, sqlTableAdmins, sqlTableAPIKeys,
sqlTableShares, sqlTableDefenderHosts, sqlTableDefenderEvents, sqlTableActiveTransfers, sqlTableGroups,
sqlTableUsersGroupsMapping, sqlTableAdminsGroupsMapping, sqlTableGroupsFoldersMapping, sqlTableSharedSessions,
sqlTableSchemaVersion, sqlTableEventsActions, sqlTableEventsRules, sqlTableRulesActionsMapping, sqlTableTasks)
sqlTableSchemaVersion, sqlTableEventsActions, sqlTableEventsRules, sqlTableRulesActionsMapping,
sqlTableTasks, sqlTableNodes)
}
return nil
}
@ -1728,6 +1747,29 @@ func UpdateTaskTimestamp(name string) error {
return provider.updateTaskTimestamp(name)
}
// GetNodes returns the other cluster nodes
func GetNodes() ([]Node, error) {
if currentNode == nil {
return nil, nil
}
nodes, err := provider.getNodes()
if err != nil {
providerLog(logger.LevelError, "unable to get other cluster nodes %v", err)
}
return nodes, err
}
// GetNodeByName returns a node, different from the current one, by name
func GetNodeByName(name string) (Node, error) {
if currentNode == nil {
return Node{}, util.NewRecordNotFoundError(errNoClusterNodes.Error())
}
if name == currentNode.Name {
return Node{}, util.NewValidationError(fmt.Sprintf("%s is the current node, it must refer to other nodes", name))
}
return provider.getNodeByName(name)
}
// HasAdmin returns true if the first admin has been created
// and so SFTPGo is ready to be used
func HasAdmin() bool {

View file

@ -2286,19 +2286,39 @@ func (p *MemoryProvider) deleteEventRule(rule EventRule, softDelete bool) error
return nil
}
func (p *MemoryProvider) getTaskByName(name string) (Task, error) {
func (*MemoryProvider) getTaskByName(name string) (Task, error) {
return Task{}, ErrNotImplemented
}
func (p *MemoryProvider) addTask(name string) error {
func (*MemoryProvider) addTask(name string) error {
return ErrNotImplemented
}
func (p *MemoryProvider) updateTask(name string, version int64) error {
func (*MemoryProvider) updateTask(name string, version int64) error {
return ErrNotImplemented
}
func (p *MemoryProvider) updateTaskTimestamp(name string) error {
func (*MemoryProvider) updateTaskTimestamp(name string) error {
return ErrNotImplemented
}
func (*MemoryProvider) addNode() error {
return ErrNotImplemented
}
func (*MemoryProvider) getNodeByName(name string) (Node, error) {
return Node{}, ErrNotImplemented
}
func (*MemoryProvider) getNodes() ([]Node, error) {
return nil, ErrNotImplemented
}
func (*MemoryProvider) updateNodeTimestamp() error {
return ErrNotImplemented
}
func (*MemoryProvider) cleanupNodes() error {
return ErrNotImplemented
}

View file

@ -56,6 +56,7 @@ const (
"DROP TABLE IF EXISTS `{{events_actions}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{events_rules}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{tasks}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{nodes}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{schema_version}}` CASCADE;"
mysqlInitialSQL = "CREATE TABLE `{{schema_version}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `version` integer NOT NULL);" +
"CREATE TABLE `{{admins}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `username` varchar(255) NOT NULL UNIQUE, " +
@ -182,6 +183,10 @@ const (
"FOREIGN KEY (`group_id`) REFERENCES `{{groups}}` (`id`) ON DELETE CASCADE;"
mysqlV22DownSQL = "ALTER TABLE `{{admins_groups_mapping}}` DROP INDEX `{{prefix}}unique_admin_group_mapping`;" +
"DROP TABLE `{{admins_groups_mapping}}` CASCADE;"
mysqlV23SQL = "CREATE TABLE `{{nodes}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`name` varchar(255) NOT NULL UNIQUE, `data` longtext NOT NULL, `created_at` bigint NOT NULL, " +
"`updated_at` bigint NOT NULL);"
mysqlV23DownSQL = "DROP TABLE `{{nodes}}` CASCADE;"
)
// MySQLProvider defines the auth provider for MySQL/MariaDB database
@ -644,6 +649,26 @@ func (p *MySQLProvider) updateTaskTimestamp(name string) error {
return sqlCommonUpdateTaskTimestamp(name, p.dbHandle)
}
func (p *MySQLProvider) addNode() error {
return sqlCommonAddNode(p.dbHandle)
}
func (p *MySQLProvider) getNodeByName(name string) (Node, error) {
return sqlCommonGetNodeByName(name, p.dbHandle)
}
func (p *MySQLProvider) getNodes() ([]Node, error) {
return sqlCommonGetNodes(p.dbHandle)
}
func (p *MySQLProvider) updateNodeTimestamp() error {
return sqlCommonUpdateNodeTimestamp(p.dbHandle)
}
func (p *MySQLProvider) cleanupNodes() error {
return sqlCommonCleanupNodes(p.dbHandle)
}
func (p *MySQLProvider) setFirstDownloadTimestamp(username string) error {
return sqlCommonSetFirstDownloadTimestamp(username, p.dbHandle)
}
@ -697,6 +722,8 @@ func (p *MySQLProvider) migrateDatabase() error { //nolint:dupl
return updateMySQLDatabaseFromV20(p.dbHandle)
case version == 21:
return updateMySQLDatabaseFromV21(p.dbHandle)
case version == 22:
return updateMySQLDatabaseFromV22(p.dbHandle)
default:
if version > sqlDatabaseVersion {
providerLog(logger.LevelError, "database schema version %v is newer than the supported one: %v", version,
@ -725,6 +752,8 @@ func (p *MySQLProvider) revertDatabase(targetVersion int) error {
return downgradeMySQLDatabaseFromV21(p.dbHandle)
case 22:
return downgradeMySQLDatabaseFromV22(p.dbHandle)
case 23:
return downgradeMySQLDatabaseFromV23(p.dbHandle)
default:
return fmt.Errorf("database schema version not handled: %v", dbVersion.Version)
}
@ -750,7 +779,14 @@ func updateMySQLDatabaseFromV20(dbHandle *sql.DB) error {
}
func updateMySQLDatabaseFromV21(dbHandle *sql.DB) error {
return updateMySQLDatabaseFrom21To22(dbHandle)
if err := updateMySQLDatabaseFrom21To22(dbHandle); err != nil {
return err
}
return updateMySQLDatabaseFromV22(dbHandle)
}
func updateMySQLDatabaseFromV22(dbHandle *sql.DB) error {
return updateMySQLDatabaseFrom22To23(dbHandle)
}
func downgradeMySQLDatabaseFromV20(dbHandle *sql.DB) error {
@ -771,6 +807,13 @@ func downgradeMySQLDatabaseFromV22(dbHandle *sql.DB) error {
return downgradeMySQLDatabaseFromV21(dbHandle)
}
func downgradeMySQLDatabaseFromV23(dbHandle *sql.DB) error {
if err := downgradeMySQLDatabaseFrom23To22(dbHandle); err != nil {
return err
}
return downgradeMySQLDatabaseFromV22(dbHandle)
}
func updateMySQLDatabaseFrom19To20(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database schema version: 19 -> 20")
providerLog(logger.LevelInfo, "updating database schema version: 19 -> 20")
@ -800,6 +843,13 @@ func updateMySQLDatabaseFrom21To22(dbHandle *sql.DB) error {
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 22, true)
}
func updateMySQLDatabaseFrom22To23(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database schema version: 22 -> 23")
providerLog(logger.LevelInfo, "updating database schema version: 22 -> 23")
sql := strings.ReplaceAll(mysqlV23SQL, "{{nodes}}", sqlTableNodes)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 23, true)
}
func downgradeMySQLDatabaseFrom20To19(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 20 -> 19")
providerLog(logger.LevelInfo, "downgrading database schema version: 20 -> 19")
@ -825,3 +875,10 @@ func downgradeMySQLDatabaseFrom22To21(dbHandle *sql.DB) error {
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 21, false)
}
func downgradeMySQLDatabaseFrom23To22(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 23 -> 22")
providerLog(logger.LevelInfo, "downgrading database schema version: 23 -> 22")
sql := strings.ReplaceAll(mysqlV23DownSQL, "{{nodes}}", sqlTableNodes)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 22, false)
}

View file

@ -0,0 +1,240 @@
package dataprovider
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
"github.com/rs/xid"
"github.com/drakkan/sftpgo/v2/internal/httpclient"
"github.com/drakkan/sftpgo/v2/internal/kms"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
)
// Supported protocols for connecting to other nodes
const (
NodeProtoHTTP = "http"
NodeProtoHTTPS = "https"
)
const (
// NodeTokenHeader defines the header to use for the node auth token
NodeTokenHeader = "X-SFTPGO-Node"
)
var (
// current node
currentNode *Node
errNoClusterNodes = errors.New("no cluster node defined")
activeNodeTimeDiff = -2 * time.Minute
nodeReqTimeout = 8 * time.Second
)
// NodeConfig defines the node configuration
type NodeConfig struct {
Host string `json:"host" mapstructure:"host"`
Port int `json:"port" mapstructure:"port"`
Proto string `json:"proto" mapstructure:"proto"`
}
func (n *NodeConfig) validate() error {
currentNode = nil
if config.IsShared != 1 {
return nil
}
if n.Host == "" {
return nil
}
currentNode = &Node{
Data: NodeData{
Host: n.Host,
Port: n.Port,
Proto: n.Proto,
},
}
return provider.addNode()
}
// NodeData defines the details to connect to a cluster node
type NodeData struct {
Host string `json:"host"`
Port int `json:"port"`
Proto string `json:"proto"`
Key *kms.Secret `json:"api_key"`
}
func (n *NodeData) validate() error {
if n.Host == "" {
return util.NewValidationError("node host is mandatory")
}
if n.Port < 0 || n.Port > 65535 {
return util.NewValidationError(fmt.Sprintf("invalid node port: %d", n.Port))
}
if n.Proto != NodeProtoHTTP && n.Proto != NodeProtoHTTPS {
return util.NewValidationError(fmt.Sprintf("invalid node proto: %s", n.Proto))
}
n.Key = kms.NewPlainSecret(string(util.GenerateRandomBytes(32)))
n.Key.SetAdditionalData(n.Host)
if err := n.Key.Encrypt(); err != nil {
return fmt.Errorf("unable to encrypt node key: %w", err)
}
return nil
}
// Node defines a cluster node
type Node struct {
Name string `json:"name"`
Data NodeData `json:"data"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
}
func (n *Node) validate() error {
if n.Name == "" {
n.Name = n.Data.Host
}
return n.Data.validate()
}
func (n *Node) authenticate(token string) error {
if err := n.Data.Key.TryDecrypt(); err != nil {
providerLog(logger.LevelError, "unable to decrypt node key: %v", err)
return err
}
if token == "" {
return ErrInvalidCredentials
}
t, err := jwt.Parse([]byte(token), jwt.WithVerify(jwa.HS256, []byte(n.Data.Key.GetPayload())))
if err != nil {
return fmt.Errorf("unable to parse token: %v", err)
}
if err := jwt.Validate(t); err != nil {
return fmt.Errorf("unable to validate token: %v", err)
}
return nil
}
// getBaseURL returns the base URL for this node
func (n *Node) getBaseURL() string {
var sb strings.Builder
sb.WriteString(n.Data.Proto)
sb.WriteString("://")
sb.WriteString(n.Data.Host)
if n.Data.Port > 0 {
sb.WriteString(":")
sb.WriteString(strconv.Itoa(n.Data.Port))
}
return sb.String()
}
// generateAuthToken generates a new auth token
func (n *Node) generateAuthToken() (string, error) {
if err := n.Data.Key.TryDecrypt(); err != nil {
return "", fmt.Errorf("unable to decrypt node key: %w", err)
}
now := time.Now().UTC()
t := jwt.New()
t.Set(jwt.JwtIDKey, xid.New().String()) //nolint:errcheck
t.Set(jwt.NotBeforeKey, now.Add(-30*time.Second)) //nolint:errcheck
t.Set(jwt.ExpirationKey, now.Add(1*time.Minute)) //nolint:errcheck
payload, err := jwt.Sign(t, jwa.HS256, []byte(n.Data.Key.GetPayload()))
if err != nil {
return "", fmt.Errorf("unable to sign authentication token: %w", err)
}
return string(payload), nil
}
func (n *Node) prepareRequest(ctx context.Context, relativeURL, method string, body io.Reader) (*http.Request, error) {
url := fmt.Sprintf("%s%s", n.getBaseURL(), relativeURL)
req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil {
return nil, err
}
token, err := n.generateAuthToken()
if err != nil {
return nil, err
}
req.Header.Set(NodeTokenHeader, fmt.Sprintf("Bearer %s", token))
return req, nil
}
// SendGetRequest sends an HTTP GET request to this node.
// The responseHolder must be a pointer
func (n *Node) SendGetRequest(relativeURL string, responseHolder any) error {
ctx, cancel := context.WithTimeout(context.Background(), nodeReqTimeout)
defer cancel()
req, err := n.prepareRequest(ctx, relativeURL, http.MethodGet, nil)
if err != nil {
return err
}
client := httpclient.GetHTTPClient()
defer client.CloseIdleConnections()
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("unable to send HTTP GET to node %s: %w", n.Name, err)
}
defer resp.Body.Close()
if resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusNoContent {
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
err = json.NewDecoder(resp.Body).Decode(responseHolder)
if err != nil {
return fmt.Errorf("unable to decode response as json")
}
return nil
}
// SendDeleteRequest sends an HTTP DELETE request to this node
func (n *Node) SendDeleteRequest(relativeURL string) error {
ctx, cancel := context.WithTimeout(context.Background(), nodeReqTimeout)
defer cancel()
req, err := n.prepareRequest(ctx, relativeURL, http.MethodDelete, nil)
if err != nil {
return err
}
client := httpclient.GetHTTPClient()
defer client.CloseIdleConnections()
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("unable to send HTTP DELETE to node %s: %w", n.Name, err)
}
defer resp.Body.Close()
if resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusNoContent {
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
return nil
}
// AuthenticateNodeToken check the validity of the provided token
func AuthenticateNodeToken(token string) error {
if currentNode == nil {
return errNoClusterNodes
}
return currentNode.authenticate(token)
}
// GetNodeName returns the node name or an empty string
func GetNodeName() string {
if currentNode == nil {
return ""
}
return currentNode.Name
}

View file

@ -54,6 +54,7 @@ DROP TABLE IF EXISTS "{{rules_actions_mapping}}" CASCADE;
DROP TABLE IF EXISTS "{{events_actions}}" CASCADE;
DROP TABLE IF EXISTS "{{events_rules}}" CASCADE;
DROP TABLE IF EXISTS "{{tasks}}" CASCADE;
DROP TABLE IF EXISTS "{{nodes}}" CASCADE;
DROP TABLE IF EXISTS "{{schema_version}}" CASCADE;
`
pgsqlInitial = `CREATE TABLE "{{schema_version}}" ("id" serial NOT NULL PRIMARY KEY, "version" integer NOT NULL);
@ -198,6 +199,9 @@ CREATE INDEX "{{prefix}}admins_groups_mapping_group_id_idx" ON "{{admins_groups_
pgsqlV22DownSQL = `ALTER TABLE "{{admins_groups_mapping}}" DROP CONSTRAINT "{{prefix}}unique_admin_group_mapping";
DROP TABLE "{{admins_groups_mapping}}" CASCADE;
`
pgsqlV23SQL = `CREATE TABLE "{{nodes}}" ("id" serial NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE,
"data" text NOT NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL);`
pgsqlV23DownSQL = `DROP TABLE "{{nodes}}" CASCADE;`
)
// PGSQLProvider defines the auth provider for PostgreSQL database
@ -616,6 +620,26 @@ func (p *PGSQLProvider) updateTaskTimestamp(name string) error {
return sqlCommonUpdateTaskTimestamp(name, p.dbHandle)
}
func (p *PGSQLProvider) addNode() error {
return sqlCommonAddNode(p.dbHandle)
}
func (p *PGSQLProvider) getNodeByName(name string) (Node, error) {
return sqlCommonGetNodeByName(name, p.dbHandle)
}
func (p *PGSQLProvider) getNodes() ([]Node, error) {
return sqlCommonGetNodes(p.dbHandle)
}
func (p *PGSQLProvider) updateNodeTimestamp() error {
return sqlCommonUpdateNodeTimestamp(p.dbHandle)
}
func (p *PGSQLProvider) cleanupNodes() error {
return sqlCommonCleanupNodes(p.dbHandle)
}
func (p *PGSQLProvider) setFirstDownloadTimestamp(username string) error {
return sqlCommonSetFirstDownloadTimestamp(username, p.dbHandle)
}
@ -669,6 +693,8 @@ func (p *PGSQLProvider) migrateDatabase() error { //nolint:dupl
return updatePgSQLDatabaseFromV20(p.dbHandle)
case version == 21:
return updatePgSQLDatabaseFromV21(p.dbHandle)
case version == 22:
return updatePgSQLDatabaseFromV21(p.dbHandle)
default:
if version > sqlDatabaseVersion {
providerLog(logger.LevelError, "database schema version %v is newer than the supported one: %v", version,
@ -697,6 +723,8 @@ func (p *PGSQLProvider) revertDatabase(targetVersion int) error {
return downgradePgSQLDatabaseFromV21(p.dbHandle)
case 22:
return downgradePgSQLDatabaseFromV22(p.dbHandle)
case 23:
return downgradePgSQLDatabaseFromV23(p.dbHandle)
default:
return fmt.Errorf("database schema version not handled: %v", dbVersion.Version)
}
@ -722,7 +750,14 @@ func updatePgSQLDatabaseFromV20(dbHandle *sql.DB) error {
}
func updatePgSQLDatabaseFromV21(dbHandle *sql.DB) error {
return updatePgSQLDatabaseFrom21To22(dbHandle)
if err := updatePgSQLDatabaseFrom21To22(dbHandle); err != nil {
return err
}
return updatePgSQLDatabaseFromV22(dbHandle)
}
func updatePgSQLDatabaseFromV22(dbHandle *sql.DB) error {
return updatePgSQLDatabaseFrom22To23(dbHandle)
}
func downgradePgSQLDatabaseFromV20(dbHandle *sql.DB) error {
@ -743,6 +778,13 @@ func downgradePgSQLDatabaseFromV22(dbHandle *sql.DB) error {
return downgradePgSQLDatabaseFromV21(dbHandle)
}
func downgradePgSQLDatabaseFromV23(dbHandle *sql.DB) error {
if err := downgradePgSQLDatabaseFrom23To22(dbHandle); err != nil {
return err
}
return downgradePgSQLDatabaseFromV22(dbHandle)
}
func updatePgSQLDatabaseFrom19To20(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database schema version: 19 -> 20")
providerLog(logger.LevelInfo, "updating database schema version: 19 -> 20")
@ -772,6 +814,13 @@ func updatePgSQLDatabaseFrom21To22(dbHandle *sql.DB) error {
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 22, true)
}
func updatePgSQLDatabaseFrom22To23(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database schema version: 22 -> 23")
providerLog(logger.LevelInfo, "updating database schema version: 22 -> 23")
sql := strings.ReplaceAll(pgsqlV23SQL, "{{nodes}}", sqlTableNodes)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 23, true)
}
func downgradePgSQLDatabaseFrom20To19(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 20 -> 19")
providerLog(logger.LevelInfo, "downgrading database schema version: 20 -> 19")
@ -797,3 +846,10 @@ func downgradePgSQLDatabaseFrom22To21(dbHandle *sql.DB) error {
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 21, false)
}
func downgradePgSQLDatabaseFrom23To22(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 23 -> 22")
providerLog(logger.LevelInfo, "downgrading database schema version: 23 -> 22")
sql := strings.ReplaceAll(pgsqlV23DownSQL, "{{nodes}}", sqlTableNodes)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 22, false)
}

View file

@ -46,7 +46,7 @@ func startScheduler() error {
stopScheduler()
scheduler = cron.New(cron.WithLocation(time.UTC))
_, err := scheduler.AddFunc("@every 60s", checkDataprovider)
_, err := scheduler.AddFunc("@every 55s", checkDataprovider)
if err != nil {
return fmt.Errorf("unable to schedule dataprovider availability check: %w", err)
}
@ -57,6 +57,19 @@ func startScheduler() error {
if fnReloadRules != nil {
fnReloadRules()
}
if currentNode != nil {
_, err = scheduler.AddFunc("@every 30m", func() {
err := provider.cleanupNodes()
if err != nil {
providerLog(logger.LevelError, "unable to cleanup nodes: %v", err)
} else {
providerLog(logger.LevelDebug, "cleanup nodes ok")
}
})
}
if err != nil {
return fmt.Errorf("unable to schedule nodes cleanup: %w", err)
}
scheduler.Start()
return nil
}
@ -71,6 +84,13 @@ func addScheduledCacheUpdates() error {
}
func checkDataprovider() {
if currentNode != nil {
if err := provider.updateNodeTimestamp(); err != nil {
providerLog(logger.LevelError, "unable to update node timestamp: %v", err)
} else {
providerLog(logger.LevelDebug, "node timestamp updated")
}
}
err := provider.checkAvailability()
if err != nil {
providerLog(logger.LevelError, "check availability error: %v", err)

View file

@ -34,7 +34,7 @@ import (
)
const (
sqlDatabaseVersion = 22
sqlDatabaseVersion = 23
defaultSQLQueryTimeout = 10 * time.Second
longSQLQueryTimeout = 60 * time.Second
)
@ -77,6 +77,7 @@ func sqlReplaceAll(sql string) string {
sql = strings.ReplaceAll(sql, "{{events_rules}}", sqlTableEventsRules)
sql = strings.ReplaceAll(sql, "{{rules_actions_mapping}}", sqlTableRulesActionsMapping)
sql = strings.ReplaceAll(sql, "{{tasks}}", sqlTableTasks)
sql = strings.ReplaceAll(sql, "{{nodes}}", sqlTableNodes)
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
return sql
}
@ -3250,6 +3251,101 @@ func sqlCommonDeleteTask(name string, dbHandle sqlQuerier) error {
return err
}
func sqlCommonAddNode(dbHandle *sql.DB) error {
if err := currentNode.validate(); err != nil {
return fmt.Errorf("unable to register cluster node: %w", err)
}
data, err := json.Marshal(currentNode.Data)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getAddNodeQuery()
_, err = dbHandle.ExecContext(ctx, q, currentNode.Name, string(data), util.GetTimeAsMsSinceEpoch(time.Now()),
util.GetTimeAsMsSinceEpoch(time.Now()))
if err != nil {
return fmt.Errorf("unable to register cluster node: %w", err)
}
providerLog(logger.LevelInfo, "registered as cluster node %q, port: %d, proto: %s",
currentNode.Name, currentNode.Data.Port, currentNode.Data.Proto)
return nil
}
func sqlCommonGetNodeByName(name string, dbHandle *sql.DB) (Node, error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
var data []byte
var node Node
q := getNodeByNameQuery()
row := dbHandle.QueryRowContext(ctx, q, name, util.GetTimeAsMsSinceEpoch(time.Now().Add(activeNodeTimeDiff)))
err := row.Scan(&node.Name, &data, &node.CreatedAt, &node.UpdatedAt)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return node, util.NewRecordNotFoundError(err.Error())
}
return node, err
}
err = json.Unmarshal(data, &node.Data)
return node, err
}
func sqlCommonGetNodes(dbHandle *sql.DB) ([]Node, error) {
var nodes []Node
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getNodesQuery()
rows, err := dbHandle.QueryContext(ctx, q, currentNode.Name,
util.GetTimeAsMsSinceEpoch(time.Now().Add(activeNodeTimeDiff)))
if err != nil {
return nodes, err
}
defer rows.Close()
for rows.Next() {
var node Node
var data []byte
err = rows.Scan(&node.Name, &data, &node.CreatedAt, &node.UpdatedAt)
if err != nil {
return nodes, err
}
err = json.Unmarshal(data, &node.Data)
if err != nil {
return nodes, err
}
nodes = append(nodes, node)
}
return nodes, rows.Err()
}
func sqlCommonUpdateNodeTimestamp(dbHandle *sql.DB) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getUpdateNodeTimestampQuery()
res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), currentNode.Name)
if err != nil {
return err
}
return sqlCommonRequireRowAffected(res)
}
func sqlCommonCleanupNodes(dbHandle *sql.DB) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getCleanupNodesQuery()
_, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now().Add(10*activeNodeTimeDiff)))
return err
}
func sqlCommonGetDatabaseVersion(dbHandle sqlQuerier, showInitWarn bool) (schemaVersion, error) {
var result schemaVersion
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)

View file

@ -56,6 +56,7 @@ DROP TABLE IF EXISTS "{{rules_actions_mapping}}";
DROP TABLE IF EXISTS "{{events_rules}}";
DROP TABLE IF EXISTS "{{events_actions}}";
DROP TABLE IF EXISTS "{{tasks}}";
DROP TABLE IF EXISTS "{{nodes}}";
DROP TABLE IF EXISTS "{{schema_version}}";
`
sqliteInitialSQL = `CREATE TABLE "{{schema_version}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "version" integer NOT NULL);
@ -176,8 +177,11 @@ ALTER TABLE "{{users}}" DROP COLUMN "first_download";
CREATE INDEX "{{prefix}}admins_groups_mapping_admin_id_idx" ON "{{admins_groups_mapping}}" ("admin_id");
CREATE INDEX "{{prefix}}admins_groups_mapping_group_id_idx" ON "{{admins_groups_mapping}}" ("group_id");
`
sqliteV22DownSQL = `DROP TABLE "{{admins_groups_mapping}}";
`
sqliteV22DownSQL = `DROP TABLE "{{admins_groups_mapping}}";`
sqliteV23SQL = `CREATE TABLE "{{nodes}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"name" varchar(255) NOT NULL UNIQUE, "data" text NOT NULL, "created_at" bigint NOT NULL,
"updated_at" bigint NOT NULL);`
sqliteV23DownSQL = `DROP TABLE "{{nodes}}";`
)
// SQLiteProvider defines the auth provider for SQLite database
@ -579,6 +583,26 @@ func (p *SQLiteProvider) updateTaskTimestamp(name string) error {
return sqlCommonUpdateTaskTimestamp(name, p.dbHandle)
}
func (p *SQLiteProvider) addNode() error {
return sqlCommonAddNode(p.dbHandle)
}
func (p *SQLiteProvider) getNodeByName(name string) (Node, error) {
return sqlCommonGetNodeByName(name, p.dbHandle)
}
func (p *SQLiteProvider) getNodes() ([]Node, error) {
return sqlCommonGetNodes(p.dbHandle)
}
func (p *SQLiteProvider) updateNodeTimestamp() error {
return sqlCommonUpdateNodeTimestamp(p.dbHandle)
}
func (p *SQLiteProvider) cleanupNodes() error {
return sqlCommonCleanupNodes(p.dbHandle)
}
func (p *SQLiteProvider) setFirstDownloadTimestamp(username string) error {
return sqlCommonSetFirstDownloadTimestamp(username, p.dbHandle)
}
@ -632,6 +656,8 @@ func (p *SQLiteProvider) migrateDatabase() error { //nolint:dupl
return updateSQLiteDatabaseFromV20(p.dbHandle)
case version == 21:
return updateSQLiteDatabaseFromV21(p.dbHandle)
case version == 22:
return updateSQLiteDatabaseFromV22(p.dbHandle)
default:
if version > sqlDatabaseVersion {
providerLog(logger.LevelError, "database schema version %v is newer than the supported one: %v", version,
@ -660,6 +686,8 @@ func (p *SQLiteProvider) revertDatabase(targetVersion int) error {
return downgradeSQLiteDatabaseFromV21(p.dbHandle)
case 22:
return downgradeSQLiteDatabaseFromV22(p.dbHandle)
case 23:
return downgradeSQLiteDatabaseFromV23(p.dbHandle)
default:
return fmt.Errorf("database schema version not handled: %v", dbVersion.Version)
}
@ -685,7 +713,14 @@ func updateSQLiteDatabaseFromV20(dbHandle *sql.DB) error {
}
func updateSQLiteDatabaseFromV21(dbHandle *sql.DB) error {
return updateSQLiteDatabaseFrom21To22(dbHandle)
if err := updateSQLiteDatabaseFrom21To22(dbHandle); err != nil {
return err
}
return updateSQLiteDatabaseFromV22(dbHandle)
}
func updateSQLiteDatabaseFromV22(dbHandle *sql.DB) error {
return updateSQLiteDatabaseFrom22To23(dbHandle)
}
func downgradeSQLiteDatabaseFromV20(dbHandle *sql.DB) error {
@ -706,6 +741,13 @@ func downgradeSQLiteDatabaseFromV22(dbHandle *sql.DB) error {
return downgradeSQLiteDatabaseFromV21(dbHandle)
}
func downgradeSQLiteDatabaseFromV23(dbHandle *sql.DB) error {
if err := downgradeSQLiteDatabaseFrom23To22(dbHandle); err != nil {
return err
}
return downgradeSQLiteDatabaseFromV22(dbHandle)
}
func updateSQLiteDatabaseFrom19To20(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database schema version: 19 -> 20")
providerLog(logger.LevelInfo, "updating database schema version: 19 -> 20")
@ -735,6 +777,13 @@ func updateSQLiteDatabaseFrom21To22(dbHandle *sql.DB) error {
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 22, true)
}
func updateSQLiteDatabaseFrom22To23(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database schema version: 22 -> 23")
providerLog(logger.LevelInfo, "updating database schema version: 22 -> 23")
sql := strings.ReplaceAll(sqliteV23SQL, "{{nodes}}", sqlTableNodes)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 23, true)
}
func downgradeSQLiteDatabaseFrom20To19(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 20 -> 19")
providerLog(logger.LevelInfo, "downgrading database schema version: 20 -> 19")
@ -761,6 +810,13 @@ func downgradeSQLiteDatabaseFrom22To21(dbHandle *sql.DB) error {
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 21, false)
}
func downgradeSQLiteDatabaseFrom23To22(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 23 -> 22")
providerLog(logger.LevelInfo, "downgrading database schema version: 23 -> 22")
sql := strings.ReplaceAll(sqliteV23DownSQL, "{{nodes}}", sqlTableNodes)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 22, false)
}
/*func setPragmaFK(dbHandle *sql.DB, value string) error {
ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
defer cancel()

View file

@ -950,6 +950,36 @@ func getDeleteTaskQuery() string {
return fmt.Sprintf(`DELETE FROM %s WHERE name = %s`, sqlTableTasks, sqlPlaceholders[0])
}
func getAddNodeQuery() string {
if config.Driver == MySQLDataProviderName {
return fmt.Sprintf("INSERT INTO %s (`name`,`data`,created_at,`updated_at`) VALUES (%s,%s,%s,%s) ON DUPLICATE KEY UPDATE "+
"`data`=VALUES(`data`), `created_at`=VALUES(`created_at`), `updated_at`=VALUES(`updated_at`)",
sqlTableNodes, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
return fmt.Sprintf(`INSERT INTO %s (name,data,created_at,updated_at) VALUES (%s,%s,%s,%s) ON CONFLICT(name)
DO UPDATE SET data=EXCLUDED.data, created_at=EXCLUDED.created_at, updated_at=EXCLUDED.updated_at`,
sqlTableNodes, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
func getUpdateNodeTimestampQuery() string {
return fmt.Sprintf(`UPDATE %s SET updated_at=%s WHERE name = %s`,
sqlTableNodes, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getNodeByNameQuery() string {
return fmt.Sprintf(`SELECT name,data,created_at,updated_at FROM %s WHERE name = %s AND updated_at > %s`,
sqlTableNodes, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getNodesQuery() string {
return fmt.Sprintf(`SELECT name,data,created_at,updated_at FROM %s WHERE name != %s AND updated_at > %s`,
sqlTableNodes, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getCleanupNodesQuery() string {
return fmt.Sprintf(`DELETE FROM %s WHERE updated_at < %s`, sqlTableNodes, sqlPlaceholders[0])
}
func getDatabaseVersionQuery() string {
return fmt.Sprintf("SELECT version from %s LIMIT 1", sqlTableSchemaVersion)
}

View file

@ -28,6 +28,7 @@ import (
"path"
"strconv"
"strings"
"sync"
"time"
"github.com/go-chi/chi/v5"
@ -152,6 +153,20 @@ func getBoolQueryParam(r *http.Request, param string) bool {
return r.URL.Query().Get(param) == "true"
}
func getActiveConnections(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
}
stats := common.Connections.GetStats()
if claims.NodeID == "" {
stats = append(stats, getNodesConnections()...)
}
render.JSON(w, r, stats)
}
func handleCloseConnection(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
connectionID := getURLParam(r, "connectionID")
@ -159,11 +174,61 @@ func handleCloseConnection(w http.ResponseWriter, r *http.Request) {
sendAPIResponse(w, r, nil, "connectionID is mandatory", http.StatusBadRequest)
return
}
node := r.URL.Query().Get("node")
if node == "" || node == dataprovider.GetNodeName() {
if common.Connections.Close(connectionID) {
sendAPIResponse(w, r, nil, "Connection closed", http.StatusOK)
} else {
sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound)
}
return
}
n, err := dataprovider.GetNodeByName(node)
if err != nil {
logger.Warn(logSender, "", "unable to get node with name %q: %v", node, err)
status := getRespStatus(err)
sendAPIResponse(w, r, nil, http.StatusText(status), status)
return
}
if err := n.SendDeleteRequest(fmt.Sprintf("%s/%s", activeConnectionsPath, connectionID)); err != nil {
logger.Warn(logSender, "", "unable to delete connection id %q from node %q: %v", connectionID, n.Name, err)
sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound)
return
}
sendAPIResponse(w, r, nil, "Connection closed", http.StatusOK)
}
// getNodesConnections returns the active connections from other nodes.
// Errors are silently ignored
func getNodesConnections() []common.ConnectionStatus {
nodes, err := dataprovider.GetNodes()
if err != nil || len(nodes) == 0 {
return nil
}
var results []common.ConnectionStatus
var mu sync.Mutex
var wg sync.WaitGroup
for _, n := range nodes {
wg.Add(1)
go func(node dataprovider.Node) {
defer wg.Done()
var stats []common.ConnectionStatus
if err := node.SendGetRequest(activeConnectionsPath, &stats); err != nil {
logger.Warn(logSender, "", "unable to get connections from node %s: %v", node.Name, err)
return
}
mu.Lock()
results = append(results, stats...)
mu.Unlock()
}(n)
}
wg.Wait()
return results
}
func getSearchFilters(w http.ResponseWriter, r *http.Request) (int, int, string, error) {

View file

@ -52,6 +52,7 @@ const (
claimUsernameKey = "username"
claimPermissionsKey = "permissions"
claimAPIKey = "api_key"
claimNodeID = "node_id"
claimMustSetSecondFactorKey = "2fa_required"
claimRequiredTwoFactorProtocols = "2fa_protos"
claimHideUserPageSection = "hus"
@ -74,6 +75,7 @@ type jwtTokenClaims struct {
Signature string
Audience []string
APIKeyID string
NodeID string
MustSetTwoFactorAuth bool
RequiredTwoFactorProtocols []string
HideUserPageSections int
@ -97,6 +99,9 @@ func (c *jwtTokenClaims) asMap() map[string]any {
if c.APIKeyID != "" {
claims[claimAPIKey] = c.APIKeyID
}
if c.NodeID != "" {
claims[claimNodeID] = c.NodeID
}
claims[jwt.SubjectKey] = c.Signature
if c.MustSetTwoFactorAuth {
claims[claimMustSetSecondFactorKey] = c.MustSetTwoFactorAuth
@ -157,6 +162,13 @@ func (c *jwtTokenClaims) Decode(token map[string]any) {
}
}
if val, ok := token[claimNodeID]; ok {
switch v := val.(type) {
case string:
c.NodeID = v
}
}
permissions := token[claimPermissionsKey]
c.Permissions = c.decodeSliceString(permissions)

View file

@ -10342,6 +10342,15 @@ func TestDeleteActiveConnectionMock(t *testing.T) {
setBearerForReq(req, token)
rr := executeRequest(req)
checkResponseCode(t, http.StatusNotFound, rr)
req.Header.Set(dataprovider.NodeTokenHeader, "Bearer abc")
rr = executeRequest(req)
checkResponseCode(t, http.StatusUnauthorized, rr)
assert.Contains(t, rr.Body.String(), "the provided token cannot be authenticated")
req, err = http.NewRequest(http.MethodDelete, activeConnectionsPath+"/connectionID?node=node1", nil)
assert.NoError(t, err)
setBearerForReq(req, token)
rr = executeRequest(req)
checkResponseCode(t, http.StatusNotFound, rr)
}
func TestNotFoundMock(t *testing.T) {

View file

@ -551,6 +551,11 @@ func TestInvalidToken(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, rr.Code)
assert.Contains(t, rr.Body.String(), "Invalid token claims")
rr = httptest.NewRecorder()
getActiveConnections(rr, req)
assert.Equal(t, http.StatusBadRequest, rr.Code)
assert.Contains(t, rr.Body.String(), "Invalid token claims")
rr = httptest.NewRecorder()
server.handleWebRestore(rr, req)
assert.Equal(t, http.StatusBadRequest, rr.Code)

View file

@ -309,6 +309,41 @@ func verifyCSRFHeader(next http.Handler) http.Handler {
})
}
func checkNodeToken(tokenAuth *jwtauth.JWTAuth) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get(dataprovider.NodeTokenHeader)
if token == "" {
next.ServeHTTP(w, r)
return
}
if len(token) > 7 && strings.ToUpper(token[0:6]) == "BEARER" {
token = token[7:]
}
if err := dataprovider.AuthenticateNodeToken(token); err != nil {
logger.Debug(logSender, "", "unable to authenticate node token %q: %v", token, err)
sendAPIResponse(w, r, fmt.Errorf("the provided token cannot be authenticated"), "", http.StatusUnauthorized)
return
}
c := jwtTokenClaims{
Username: fmt.Sprintf("node %s", dataprovider.GetNodeName()),
Permissions: []string{dataprovider.PermAdminViewConnections, dataprovider.PermAdminCloseConnections},
NodeID: dataprovider.GetNodeName(),
}
resp, err := c.createTokenResponse(tokenAuth, tokenAudienceAPI, util.GetIPFromRemoteAddress(r.RemoteAddr))
if err != nil {
sendAPIResponse(w, r, err, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", resp["access_token"]))
next.ServeHTTP(w, r)
})
}
}
func checkAPIKeyAuth(tokenAuth *jwtauth.JWTAuth, scope dataprovider.APIKeyScope) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -337,7 +372,7 @@ func checkAPIKeyAuth(tokenAuth *jwtauth.JWTAuth, scope dataprovider.APIKeyScope)
return
}
if err := k.Authenticate(key); err != nil {
logger.Debug(logSender, "unable to authenticate api key %#v: %v", apiKey, err)
logger.Debug(logSender, "", "unable to authenticate api key %#v: %v", apiKey, err)
sendAPIResponse(w, r, fmt.Errorf("the provided api key cannot be authenticated"), "", http.StatusUnauthorized)
return
}

View file

@ -1195,6 +1195,7 @@ func (s *httpdServer) initializeRouter() {
s.router.Post(userPath+"/{username}/reset-password", resetUserPassword)
s.router.Group(func(router chi.Router) {
router.Use(checkNodeToken(s.tokenAuth))
router.Use(checkAPIKeyAuth(s.tokenAuth, dataprovider.APIKeyScopeAdmin))
router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromHeader))
router.Use(jwtAuthenticatorAPI)
@ -1222,12 +1223,7 @@ func (s *httpdServer) initializeRouter() {
render.JSON(w, r, getServicesStatus())
})
router.With(s.checkPerm(dataprovider.PermAdminViewConnections)).
Get(activeConnectionsPath, func(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
render.JSON(w, r, common.Connections.GetStats())
})
router.With(s.checkPerm(dataprovider.PermAdminViewConnections)).Get(activeConnectionsPath, getActiveConnections)
router.With(s.checkPerm(dataprovider.PermAdminCloseConnections)).
Delete(activeConnectionsPath+"/{connectionID}", handleCloseConnection)
router.With(s.checkPerm(dataprovider.PermAdminQuotaScans)).Get(quotasBasePath+"/users/scans", getUsersQuotaScans)

View file

@ -2845,6 +2845,7 @@ func (s *httpdServer) handleWebGetStatus(w http.ResponseWriter, r *http.Request)
func (s *httpdServer) handleWebGetConnections(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
connectionStats := common.Connections.GetStats()
connectionStats = append(connectionStats, getNodesConnections()...)
data := connectionsPage{
basePage: s.getBasePageData(pageConnectionsTitle, webConnectionsPath, r),
Connections: connectionStats,

View file

@ -5459,6 +5459,9 @@ components:
type: array
items:
$ref: '#/components/schemas/Transfer'
node:
type: string
description: 'Node identifier, omitted for single node installations'
FolderRetention:
type: object
properties:

View file

@ -236,6 +236,11 @@
"create_default_admin": false,
"naming_rules": 1,
"is_shared": 0,
"node": {
"host": "",
"port": 0,
"proto": "http"
},
"backups_path": "backups"
},
"httpd": {

View file

@ -40,6 +40,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
<thead>
<tr>
<th>ID</th>
<th>Node</th>
<th>Username</th>
<th>Time</th>
<th>Info</th>
@ -50,6 +51,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
{{range .Connections}}
<tr>
<td>{{.ConnectionID}}</td>
<td>{{.Node}}</td>
<td>{{.Username}}</td>
<td>{{.GetConnectionDuration}}</td>
<td>{{.GetConnectionInfo}}</td>
@ -105,8 +107,10 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
function disconnectAction() {
var table = $('#dataTable').DataTable();
table.button('disconnect:name').enable(false);
var connectionID = table.row({ selected: true }).data()[0];
var path = '{{.ConnectionsURL}}' + "/" + connectionID;
var selectedData = table.row({ selected: true }).data()
var connectionID = selectedData[0];
var nodeID = selectedData[1];
var path = '{{.ConnectionsURL}}' + "/" + fixedEncodeURIComponent(connectionID)+"?node="+encodeURIComponent(nodeID);
$('#disconnectModal').modal('hide');
$.ajax({
url: path,
@ -174,13 +178,13 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
"lengthChange": true,
"columnDefs": [
{
"targets": [0],
"targets": [0, 1],
"visible": false,
"searchable": false,
"className": "noVis"
},
{
"targets": [1],
"targets": [2],
"className": "noVis"
}
],
@ -190,7 +194,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
"language": {
"emptyTable": "No user connected"
},
"order": [[1, 'asc']]
"order": [[2, 'asc']]
});
new $.fn.dataTable.FixedHeader( table );