Compare commits

..

68 commits
2.6.x ... main

Author SHA1 Message Date
Nicola Murino
00155eaaf6
update deps
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-07-26 19:04:08 +02:00
Nicola Murino
d94f80c8da
replace utils.Contains with slices.Contains
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-07-24 18:27:13 +02:00
Nicola Murino
bd5eb03d9c
replace hand-written slice utilities with methods from slices package
SFTPGo depends on Go 1.22 so we can use slices package

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-07-24 18:17:55 +02:00
Nicola Murino
6ba1198c47
sftpd: remove unused folder prefix from Connection struct
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-07-24 16:44:25 +02:00
Nicola Murino
b5c821795a
allow to customize name and log from the WebUI
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-07-24 09:14:27 +02:00
Nicola Murino
b2926377b7
WebUI: switch favicon from ico to png
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-07-20 16:11:21 +02:00
Nicola Murino
99f47ca4e7
sftpfs: cache and reuse parsed private keys
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-07-16 19:20:28 +02:00
Nicola Murino
fef388d8cb
don't track quota for private virtual folders
they are included within the user quota.
This is a backward incompatible change.

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-07-13 21:02:40 +02:00
Nicola Murino
92849ca473
quota: move user and folder management to a common method
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-07-13 19:30:40 +02:00
Nicola Murino
0952887157
update deps
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-07-13 17:44:13 +02:00
Nicola Murino
d010b26e1c
deb packages: update copyright year
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-07-13 14:06:11 +02:00
Nicola Murino
58de410850
nt: fix unused write warnings
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-07-03 20:42:51 +02:00
Nicola Murino
54bc3ea87d
restore: fix quota scan for users with folders associated via groups
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-07-03 20:35:12 +02:00
Nicola Murino
64a2f7aa4f
oidc refresh token: validate nonce only if set
As clarified in OpenID core spec errata 2, section 12.2

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-07-01 19:06:11 +02:00
Nicola Murino
55be9f0b9c
EventManager: allow to configure the timezone to use for the scheduler
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-30 18:52:59 +02:00
Nicola Murino
97ffa0394f
update deps
adapt smtp configuration to changes in upstream library

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-30 09:18:04 +02:00
dependabot[bot]
dc91ec2056
Bump docker/build-push-action from 5 to 6 (#1668)
Bumps [docker/build-push-action](https://github.com/docker/build-push-action) from 5 to 6.
- [Release notes](https://github.com/docker/build-push-action/releases)
- [Commits](https://github.com/docker/build-push-action/compare/v5...v6)

---
updated-dependencies:
- dependency-name: docker/build-push-action
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-06-24 09:29:56 +02:00
Nicola Murino
356795f8b0
add a test case for listing files with long names
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-22 19:23:02 +02:00
Nicola Murino
3efcd94e14
back to development
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-21 20:25:35 +02:00
Nicola Murino
34bc21b3b7
update deps
fixes a bug in chi compressor Handler

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-21 18:32:59 +02:00
Nicola Murino
37845c2936
smtp: hide commit hash in user agent
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-21 18:31:42 +02:00
Nicola Murino
47924716c1
back to development
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-19 21:04:35 +02:00
Nicola Murino
1d60505629
fix test case failure on macOS with bolt provider
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-19 10:45:14 +02:00
Nicola Murino
9daf0ba767
update swagger UI to 5.17.14
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-18 20:47:39 +02:00
Nicola Murino
bdae378569
update deps
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-18 19:13:21 +02:00
Nicola Murino
363770ab84
WebClient shares: add a logout button
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-18 19:10:32 +02:00
Nicola Murino
8bc08b25dc
sftp: limit max file list
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-17 19:24:03 +02:00
Nicola Murino
e0c1b974c9
add cgo to build constraints
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-16 09:46:17 +02:00
Nicola Murino
39cf9f6943
Web UIs: remove duplication of supported languages
Fixes #1660

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-15 19:35:23 +02:00
Nicola Murino
d650defa08
remove duplicated jwt tokens validation
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-15 16:19:37 +02:00
Nicola Murino
c5c42f072b
squash database migrations
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-15 16:02:09 +02:00
Nicola Murino
bd5b32101f
csrf: reuse the cookie in reset password
no need to generate a new cookie each time.

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-15 15:18:17 +02:00
Nicola Murino
8208ac817d
html pages: add robots meta tag
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-15 10:17:37 +02:00
Nicola Murino
a99c4879de
update dependencies
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-15 10:10:15 +02:00
Nicola Murino
01b666a78f
WebUIs: check login conditions before allowing password reset
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-14 19:34:42 +02:00
Nicola Murino
8294952474
WebUIs: refactor CSRF
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-14 18:09:32 +02:00
Nicola Murino
7fb5b1b996
reduce share token duration
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-08 12:13:38 +02:00
Nicola Murino
2749a98f26
CI: update workflow to 1.22.4
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-07 18:19:52 +02:00
Nicola Murino
08526da153
REST API: fix token invalidation after password change
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-07 18:19:05 +02:00
Nicola Murino
8269adf176
Windows: allow to override most of the "serve" flags from env files
The Windows specific code path was missing in 07710ad98

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-05 17:34:28 +02:00
Nicola Murino
0cddcba5a7
EventManager: add an action to rotate the log file
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-04 19:51:52 +02:00
Nicola Murino
3bd1eeacc1
make sure to return a fully populated user after plugin auth
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-06-04 18:14:09 +02:00
Nicola Murino
1698ec2eb3
EventManager: fix adding ObjectDataString for provider events
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-31 20:01:38 +02:00
Nicola Murino
07710ad98d
allow to override most of the "serve" flags from env files
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-31 18:49:23 +02:00
Nicola Murino
f63bf7093c
logs: redact plugin arguments
may contain sensitive data

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-30 18:10:12 +02:00
Nicola Murino
0597bf1047
Windows setup: update MinVersion
Starting from Go version 1.21, Windows 10 or Windows Server 2016 are
required

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-30 18:07:17 +02:00
Nicola Murino
5bde4b92a2
fix test cases
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-29 19:35:42 +02:00
Nicola Murino
faa994e3b3
update UI theme and dependencies
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-29 19:20:56 +02:00
Nicola Murino
68cc1a8e2c
fix proxy protocol policy
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-28 19:40:37 +02:00
Nicola Murino
9c775e2213
transfer logs: add error field
Fixes #1638

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-27 19:35:48 +02:00
Nicola Murino
6c94173ca1
WebUI branding: remove unused login_image_path from config
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-27 18:43:44 +02:00
Nicola Murino
d1e0560d28
WebAdmin status page: update the color of the labels
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-26 19:34:29 +02:00
Nicola Murino
52a94b2593
docker: build Alpine based image using golang:1.22-alpine3.20
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-25 16:53:13 +02:00
dependabot[bot]
9550fd2921
Bump alpine from 3.19 to 3.20 (#1636)
Bumps alpine from 3.19 to 3.20.

---
updated-dependencies:
- dependency-name: alpine
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-05-25 16:51:48 +02:00
Nicola Murino
a6549b08f9
dependabot: remove gomod
it is not really required, we update Go dependencies regularly

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-25 16:40:31 +02:00
Nicola Murino
ba3e2ecb5f
WebAdmin events page: fix rendering of some nullable strings
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-25 16:17:33 +02:00
Nicola Murino
2bd3b46e3f
update swagger ui
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-25 16:14:42 +02:00
Nicola Murino
7831ddaede
WebAdmin events page: set fixed sizes for potentially long fields
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-24 18:24:05 +02:00
Nicola Murino
613f2f1c24
WebUIs: set the lang attribute based on the chosen language
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-24 18:23:41 +02:00
Nicola Murino
525f33a07a
WebUIs: fix css loading order
Fixes #1628

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-24 18:22:58 +02:00
Nicola Murino
3f2604d33f
ssh: use 3072-bits for the auto-generated RSA key
This is the same as ssh-keygen

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-24 18:22:36 +02:00
Nicola Murino
b823bb04d2
WebAdmin: make the description visible in IP lists page
Fixes #1631

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-23 20:07:49 +02:00
Nicola Murino
9ba92d9495
WebUIs: fix datatables processing class name
was changed to dt-processing in datatables 2.0

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-23 19:47:45 +02:00
Nicola Murino
0127fc188b
SSH: allow to configure minimum key size for DHGEX
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-23 18:08:16 +02:00
Nicola Murino
3c7a651d27
plugin: don't consider file extension for env prefix
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-18 13:10:16 +02:00
Nicola Murino
50a3c0d911
defender: allow to impose a delay between login attempts
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-18 10:35:54 +02:00
Nicola Murino
b2bea85add
update README
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-16 10:40:48 +02:00
Nicola Murino
61bc0065f9
back to development
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-05-16 04:54:46 +02:00
99 changed files with 2994 additions and 1360 deletions

View file

@ -2,7 +2,7 @@ name: CI
on:
push:
branches: [2.6.x]
branches: [main]
pull_request:
jobs:

View file

@ -5,7 +5,7 @@ on:
# - cron: '0 4 * * *' # everyday at 4:00 AM UTC
push:
branches:
- 2.6.x
- main
tags:
- v*
pull_request:
@ -163,7 +163,7 @@ jobs:
if: ${{ github.event_name != 'pull_request' }}
- name: Build and push
uses: docker/build-push-action@v5
uses: docker/build-push-action@v6
with:
context: .
builder: ${{ steps.builder.outputs.name }}

View file

@ -1,4 +1,4 @@
FROM golang:1.22-alpine3.19 AS builder
FROM golang:1.22-alpine3.20 AS builder
ENV GOFLAGS="-mod=readonly"
@ -25,7 +25,7 @@ RUN set -xe && \
export COMMIT_SHA=${COMMIT_SHA:-$(git describe --always --abbrev=8 --dirty)} && \
go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${COMMIT_SHA} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -v -o sftpgo
FROM alpine:3.19
FROM alpine:3.20
# Set to "true" to install jq and the optional git and rsync dependencies
ARG INSTALL_OPTIONAL_PACKAGES=false

View file

@ -30,6 +30,7 @@ import (
"net/url"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"time"
@ -249,7 +250,7 @@ func (c *Configuration) Initialize(configDir string) error {
if c.RenewDays < 1 {
return fmt.Errorf("invalid number of days remaining before renewal: %d", c.RenewDays)
}
if !util.Contains(supportedKeyTypes, c.KeyType) {
if !slices.Contains(supportedKeyTypes, c.KeyType) {
return fmt.Errorf("invalid key type %q", c.KeyType)
}
caURL, err := url.Parse(c.CAEndpoint)

View file

@ -40,8 +40,8 @@ Please take a look at the usage below to customize the options.`,
Run: func(_ *cobra.Command, _ []string) {
logger.DisableLogger()
logger.EnableConsoleLogger(zerolog.DebugLevel)
if revertProviderTargetVersion != 28 {
logger.WarnToConsole("Unsupported target version, 28 is the only supported one")
if revertProviderTargetVersion != 29 {
logger.WarnToConsole("Unsupported target version, 29 is the only supported one")
os.Exit(1)
}
configDir = util.CleanDirInput(configDir)
@ -71,7 +71,7 @@ Please take a look at the usage below to customize the options.`,
func init() {
addConfigFlags(revertProviderCmd)
revertProviderCmd.Flags().IntVar(&revertProviderTargetVersion, "to-version", 28, `28 means the version supported in v2.5.x`)
revertProviderCmd.Flags().IntVar(&revertProviderTargetVersion, "to-version", 29, `29 means the version supported in v2.6.x`)
rootCmd.AddCommand(revertProviderCmd)
}

View file

@ -17,10 +17,9 @@ package command
import (
"fmt"
"slices"
"strings"
"time"
"github.com/drakkan/sftpgo/v2/internal/util"
)
const (
@ -117,7 +116,7 @@ func (c Config) Initialize() error {
}
// don't validate args, we allow to pass empty arguments
if cmd.Hook != "" {
if !util.Contains(supportedHooks, cmd.Hook) {
if !slices.Contains(supportedHooks, cmd.Hook) {
return fmt.Errorf("invalid hook name %q, supported values: %+v", cmd.Hook, supportedHooks)
}
}

View file

@ -25,6 +25,7 @@ import (
"os/exec"
"path"
"path/filepath"
"slices"
"strings"
"sync/atomic"
"time"
@ -86,7 +87,7 @@ func InitializeActionHandler(handler ActionHandler) {
func ExecutePreAction(conn *BaseConnection, operation, filePath, virtualPath string, fileSize int64, openFlags int) (int, error) {
var event *notifier.FsEvent
hasNotifiersPlugin := plugin.Handler.HasNotifiers()
hasHook := util.Contains(Config.Actions.ExecuteOn, operation)
hasHook := slices.Contains(Config.Actions.ExecuteOn, operation)
hasRules := eventManager.hasFsRules()
if !hasHook && !hasNotifiersPlugin && !hasRules {
return 0, nil
@ -132,7 +133,7 @@ func ExecuteActionNotification(conn *BaseConnection, operation, filePath, virtua
fileSize int64, err error, elapsed int64, metadata map[string]string,
) error {
hasNotifiersPlugin := plugin.Handler.HasNotifiers()
hasHook := util.Contains(Config.Actions.ExecuteOn, operation)
hasHook := slices.Contains(Config.Actions.ExecuteOn, operation)
hasRules := eventManager.hasFsRules()
if !hasHook && !hasNotifiersPlugin && !hasRules {
return nil
@ -173,7 +174,7 @@ func ExecuteActionNotification(conn *BaseConnection, operation, filePath, virtua
}
}
if hasHook {
if util.Contains(Config.Actions.ExecuteSync, operation) {
if slices.Contains(Config.Actions.ExecuteSync, operation) {
_, err := actionHandler.Handle(notification)
return err
}
@ -247,7 +248,7 @@ func newActionNotification(
type defaultActionHandler struct{}
func (h *defaultActionHandler) Handle(event *notifier.FsEvent) (int, error) {
if !util.Contains(Config.Actions.ExecuteOn, event.Action) {
if !slices.Contains(Config.Actions.ExecuteOn, event.Action) {
return 0, nil
}

View file

@ -25,6 +25,7 @@ import (
"os"
"os/exec"
"path/filepath"
"slices"
"strconv"
"strings"
"sync"
@ -163,13 +164,20 @@ var (
rateLimiters map[string][]*rateLimiter
isShuttingDown atomic.Bool
ftpLoginCommands = []string{"PASS", "USER"}
fnUpdateBranding func(*dataprovider.BrandingConfigs)
)
// SetUpdateBrandingFn sets the function to call to update branding configs.
func SetUpdateBrandingFn(fn func(*dataprovider.BrandingConfigs)) {
fnUpdateBranding = fn
}
// Initialize sets the common configuration
func Initialize(c Configuration, isShared int) error {
isShuttingDown.Store(false)
util.SetUmask(c.Umask)
version.SetConfig(c.ServerVersion)
dataprovider.SetTZ(c.TZ)
Config = c
Config.Actions.ExecuteOn = util.RemoveDuplicates(Config.Actions.ExecuteOn, true)
Config.Actions.ExecuteSync = util.RemoveDuplicates(Config.Actions.ExecuteSync, true)
@ -200,7 +208,7 @@ func Initialize(c Configuration, isShared int) error {
Config.rateLimitersList = rateLimitersList
}
if c.DefenderConfig.Enabled {
if !util.Contains(supportedDefenderDrivers, c.DefenderConfig.Driver) {
if !slices.Contains(supportedDefenderDrivers, c.DefenderConfig.Driver) {
return fmt.Errorf("unsupported defender driver %q", c.DefenderConfig.Driver)
}
var defender Defender
@ -402,6 +410,23 @@ func AddDefenderEvent(ip, protocol string, event HostEvent) bool {
return Config.defender.AddEvent(ip, protocol, event)
}
func reloadProviderConfigs() {
configs, err := dataprovider.GetConfigs()
if err != nil {
logger.Error(logSender, "", "unable to load config from provider: %v", err)
return
}
configs.SetNilsToEmpty()
if fnUpdateBranding != nil {
fnUpdateBranding(configs.Branding)
}
if err := configs.SMTP.TryDecrypt(); err != nil {
logger.Error(logSender, "", "unable to decrypt smtp config: %v", err)
return
}
smtp.Activate(configs.SMTP)
}
func startPeriodicChecks(duration time.Duration, isShared int) {
startEventScheduler()
spec := fmt.Sprintf("@every %s", duration)
@ -410,7 +435,7 @@ func startPeriodicChecks(duration time.Duration, isShared int) {
logger.Info(logSender, "", "scheduled overquota transfers check, schedule %q", spec)
if isShared == 1 {
logger.Info(logSender, "", "add reload configs task")
_, err := eventScheduler.AddFunc("@every 10m", smtp.ReloadProviderConf)
_, err := eventScheduler.AddFunc("@every 10m", reloadProviderConfigs)
util.PanicOnError(err)
}
if Config.IdleTimeout > 0 {
@ -588,6 +613,10 @@ type Configuration struct {
Umask string `json:"umask" mapstructure:"umask"`
// Defines the server version
ServerVersion string `json:"server_version" mapstructure:"server_version"`
// TZ defines the time zone to use for the EventManager scheduler and to
// control time-based access restrictions. Set to "local" to use the
// server's local time, otherwise UTC will be used.
TZ string `json:"tz" mapstructure:"tz"`
// Metadata configuration
Metadata MetadataConfig `json:"metadata" mapstructure:"metadata"`
idleTimeoutAsDuration time.Duration
@ -749,7 +778,7 @@ func (c *Configuration) checkPostDisconnectHook(remoteAddr, protocol, username,
if c.PostDisconnectHook == "" {
return
}
if !util.Contains(disconnHookProtocols, protocol) {
if !slices.Contains(disconnHookProtocols, protocol) {
return
}
go c.executePostDisconnectHook(remoteAddr, protocol, username, connID, connectionTime)
@ -991,7 +1020,7 @@ func (conns *ActiveConnections) Remove(connectionID string) {
metric.UpdateActiveConnectionsSize(lastIdx)
logger.Debug(conn.GetProtocol(), conn.GetID(), "connection removed, local address %q, remote address %q close fs error: %v, num open connections: %d",
conn.GetLocalAddress(), conn.GetRemoteAddress(), err, lastIdx)
if conn.GetProtocol() == ProtocolFTP && conn.GetUsername() == "" && !util.Contains(ftpLoginCommands, conn.GetCommand()) {
if conn.GetProtocol() == ProtocolFTP && conn.GetUsername() == "" && !slices.Contains(ftpLoginCommands, conn.GetCommand()) {
ip := util.GetIPFromRemoteAddress(conn.GetRemoteAddress())
logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTried, ProtocolFTP,
dataprovider.ErrNoAuthTried.Error())

View file

@ -23,6 +23,7 @@ import (
"os/exec"
"path/filepath"
"runtime"
"slices"
"sync"
"testing"
"time"
@ -1226,8 +1227,8 @@ func TestFolderCopy(t *testing.T) {
folder.ID = 2
folder.Users = []string{"user3"}
require.Len(t, folderCopy.Users, 2)
require.True(t, util.Contains(folderCopy.Users, "user1"))
require.True(t, util.Contains(folderCopy.Users, "user2"))
require.True(t, slices.Contains(folderCopy.Users, "user1"))
require.True(t, slices.Contains(folderCopy.Users, "user2"))
require.Equal(t, int64(1), folderCopy.ID)
require.Equal(t, folder.Name, folderCopy.Name)
require.Equal(t, folder.MappedPath, folderCopy.MappedPath)
@ -1243,7 +1244,7 @@ func TestFolderCopy(t *testing.T) {
folderCopy = folder.GetACopy()
folder.FsConfig.CryptConfig.Passphrase = kms.NewEmptySecret()
require.Len(t, folderCopy.Users, 1)
require.True(t, util.Contains(folderCopy.Users, "user3"))
require.True(t, slices.Contains(folderCopy.Users, "user3"))
require.Equal(t, int64(2), folderCopy.ID)
require.Equal(t, folder.Name, folderCopy.Name)
require.Equal(t, folder.MappedPath, folderCopy.MappedPath)

View file

@ -63,7 +63,7 @@ type BaseConnection struct {
// NewBaseConnection returns a new BaseConnection
func NewBaseConnection(id, protocol, localAddr, remoteAddr string, user dataprovider.User) *BaseConnection {
connID := id
if util.Contains(supportedProtocols, protocol) {
if slices.Contains(supportedProtocols, protocol) {
connID = fmt.Sprintf("%s_%s", protocol, id)
}
user.UploadBandwidth, user.DownloadBandwidth = user.GetBandwidthForIP(util.GetIPFromRemoteAddress(remoteAddr), connID)
@ -132,7 +132,7 @@ func (c *BaseConnection) GetRemoteIP() string {
// SetProtocol sets the protocol for this connection
func (c *BaseConnection) SetProtocol(protocol string) {
c.protocol = protocol
if util.Contains(supportedProtocols, c.protocol) {
if slices.Contains(supportedProtocols, c.protocol) {
c.ID = fmt.Sprintf("%v_%v", c.protocol, c.ID)
}
}
@ -450,10 +450,7 @@ func (c *BaseConnection) RemoveFile(fs vfs.Fs, fsPath, virtualPath string, info
if updateQuota && info.Mode()&os.ModeSymlink == 0 {
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath))
if err == nil {
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, -1, -size, false) //nolint:errcheck
if vfolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(&c.User, -1, -size, false) //nolint:errcheck
}
dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, -1, -size, false)
} else {
dataprovider.UpdateUserQuota(&c.User, -1, -size, false) //nolint:errcheck
}
@ -1122,10 +1119,7 @@ func (c *BaseConnection) truncateFile(fs vfs.Fs, fsPath, virtualPath string, siz
sizeDiff := initialSize - size
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath))
if err == nil {
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -sizeDiff, false) //nolint:errcheck
if vfolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(&c.User, 0, -sizeDiff, false) //nolint:errcheck
}
dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -sizeDiff, false)
} else {
dataprovider.UpdateUserQuota(&c.User, 0, -sizeDiff, false) //nolint:errcheck
}
@ -1519,61 +1513,40 @@ func (c *BaseConnection) updateQuotaMoveBetweenVFolders(sourceFolder, dstFolder
if sourceFolder.Name == dstFolder.Name {
// both files are inside the same virtual folder
if initialSize != -1 {
dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, -numFiles, -initialSize, false) //nolint:errcheck
if dstFolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(&c.User, -numFiles, -initialSize, false) //nolint:errcheck
}
dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, -numFiles, -initialSize, false)
}
return
}
// files are inside different virtual folders
dataprovider.UpdateVirtualFolderQuota(&sourceFolder.BaseVirtualFolder, -numFiles, -filesSize, false) //nolint:errcheck
if sourceFolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(&c.User, -numFiles, -filesSize, false) //nolint:errcheck
}
dataprovider.UpdateUserFolderQuota(sourceFolder, &c.User, -numFiles, -filesSize, false)
if initialSize == -1 {
dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, numFiles, filesSize, false) //nolint:errcheck
if dstFolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(&c.User, numFiles, filesSize, false) //nolint:errcheck
dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, numFiles, filesSize, false)
return
}
} else {
// we cannot have a directory here, initialSize != -1 only for files
dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, 0, filesSize-initialSize, false) //nolint:errcheck
if dstFolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck
}
}
dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, 0, filesSize-initialSize, false)
}
func (c *BaseConnection) updateQuotaMoveFromVFolder(sourceFolder *vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) {
// move between a virtual folder and the user home dir
dataprovider.UpdateVirtualFolderQuota(&sourceFolder.BaseVirtualFolder, -numFiles, -filesSize, false) //nolint:errcheck
if sourceFolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(&c.User, -numFiles, -filesSize, false) //nolint:errcheck
}
dataprovider.UpdateUserFolderQuota(sourceFolder, &c.User, -numFiles, -filesSize, false)
if initialSize == -1 {
dataprovider.UpdateUserQuota(&c.User, numFiles, filesSize, false) //nolint:errcheck
} else {
return
}
// we cannot have a directory here, initialSize != -1 only for files
dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck
}
}
func (c *BaseConnection) updateQuotaMoveToVFolder(dstFolder *vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) {
// move between the user home dir and a virtual folder
dataprovider.UpdateUserQuota(&c.User, -numFiles, -filesSize, false) //nolint:errcheck
if initialSize == -1 {
dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, numFiles, filesSize, false) //nolint:errcheck
if dstFolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(&c.User, numFiles, filesSize, false) //nolint:errcheck
dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, numFiles, filesSize, false)
return
}
} else {
// we cannot have a directory here, initialSize != -1 only for files
dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, 0, filesSize-initialSize, false) //nolint:errcheck
if dstFolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck
}
}
dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, 0, filesSize-initialSize, false)
}
func (c *BaseConnection) updateQuotaAfterRename(fs vfs.Fs, virtualSourcePath, virtualTargetPath, targetPath string,
@ -1823,8 +1796,8 @@ type DirListerAt struct {
lister vfs.DirLister
}
// Add adds the given os.FileInfo to the internal cache
func (l *DirListerAt) Add(fi os.FileInfo) {
// Prepend adds the given os.FileInfo as first element of the internal cache
func (l *DirListerAt) Prepend(fi os.FileInfo) {
l.mu.Lock()
defer l.mu.Unlock()

View file

@ -22,6 +22,7 @@ import (
"path"
"path/filepath"
"runtime"
"slices"
"strconv"
"testing"
"time"
@ -389,7 +390,7 @@ func TestErrorsMapping(t *testing.T) {
err := conn.GetFsError(fs, os.ErrNotExist)
if protocol == ProtocolSFTP {
assert.ErrorIs(t, err, sftp.ErrSSHFxNoSuchFile)
} else if util.Contains(osErrorsProtocols, protocol) {
} else if slices.Contains(osErrorsProtocols, protocol) {
assert.EqualError(t, err, os.ErrNotExist.Error())
} else {
assert.EqualError(t, err, ErrNotExist.Error())
@ -1134,8 +1135,8 @@ func TestListerAt(t *testing.T) {
require.Equal(t, 0, n)
lister, err = conn.ListDir("/")
require.NoError(t, err)
lister.Add(vfs.NewFileInfo("..", true, 0, time.Unix(0, 0), false))
lister.Add(vfs.NewFileInfo(".", true, 0, time.Unix(0, 0), false))
lister.Prepend(vfs.NewFileInfo("..", true, 0, time.Unix(0, 0), false))
lister.Prepend(vfs.NewFileInfo(".", true, 0, time.Unix(0, 0), false))
files = make([]os.FileInfo, 1)
n, err = lister.ListAt(files, 0)
require.NoError(t, err)

View file

@ -31,6 +31,7 @@ import (
"os/exec"
"path"
"path/filepath"
"slices"
"strconv"
"strings"
"sync"
@ -307,7 +308,7 @@ func (*eventRulesContainer) checkIPDLoginEventMatch(conditions *dataprovider.Eve
}
func (*eventRulesContainer) checkProviderEventMatch(conditions *dataprovider.EventConditions, params *EventParams) bool {
if !util.Contains(conditions.ProviderEvents, params.Event) {
if !slices.Contains(conditions.ProviderEvents, params.Event) {
return false
}
if !checkEventConditionPatterns(params.Name, conditions.Options.Names) {
@ -316,14 +317,14 @@ func (*eventRulesContainer) checkProviderEventMatch(conditions *dataprovider.Eve
if !checkEventConditionPatterns(params.Role, conditions.Options.RoleNames) {
return false
}
if len(conditions.Options.ProviderObjects) > 0 && !util.Contains(conditions.Options.ProviderObjects, params.ObjectType) {
if len(conditions.Options.ProviderObjects) > 0 && !slices.Contains(conditions.Options.ProviderObjects, params.ObjectType) {
return false
}
return true
}
func (*eventRulesContainer) checkFsEventMatch(conditions *dataprovider.EventConditions, params *EventParams) bool {
if !util.Contains(conditions.FsEvents, params.Event) {
if !slices.Contains(conditions.FsEvents, params.Event) {
return false
}
if !checkEventConditionPatterns(params.Name, conditions.Options.Names) {
@ -338,7 +339,7 @@ func (*eventRulesContainer) checkFsEventMatch(conditions *dataprovider.EventCond
if !checkEventConditionPatterns(params.VirtualPath, conditions.Options.FsPaths) {
return false
}
if len(conditions.Options.Protocols) > 0 && !util.Contains(conditions.Options.Protocols, params.Protocol) {
if len(conditions.Options.Protocols) > 0 && !slices.Contains(conditions.Options.Protocols, params.Protocol) {
return false
}
if params.Event == operationUpload || params.Event == operationDownload {
@ -908,10 +909,7 @@ func updateUserQuotaAfterFileWrite(conn *BaseConnection, virtualPath string, num
dataprovider.UpdateUserQuota(&conn.User, numFiles, fileSize, false) //nolint:errcheck
return
}
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, numFiles, fileSize, false) //nolint:errcheck
if vfolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(&conn.User, numFiles, fileSize, false) //nolint:errcheck
}
dataprovider.UpdateUserFolderQuota(&vfolder, &conn.User, numFiles, fileSize, false)
}
func checkWriterPermsAndQuota(conn *BaseConnection, virtualPath string, numFiles int, expectedSize, truncatedSize int64) error {

View file

@ -19,6 +19,8 @@ import (
"github.com/robfig/cron/v3"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
)
@ -36,7 +38,15 @@ func stopEventScheduler() {
func startEventScheduler() {
stopEventScheduler()
eventScheduler = cron.New(cron.WithLocation(time.UTC), cron.WithLogger(cron.DiscardLogger))
options := []cron.Option{
cron.WithLogger(cron.DiscardLogger),
}
if !dataprovider.UseLocalTime() {
eventManagerLog(logger.LevelDebug, "use UTC time for the scheduler")
options = append(options, cron.WithLocation(time.UTC))
}
eventScheduler = cron.New(options...)
eventManager.loadRules()
_, err := eventScheduler.AddFunc("@every 10m", eventManager.loadRules)
util.PanicOnError(err)

View file

@ -30,6 +30,7 @@ import (
"path"
"path/filepath"
"runtime"
"slices"
"strings"
"sync"
"testing"
@ -1457,15 +1458,15 @@ func TestTruncateQuotaLimits(t *testing.T) {
expectedQuotaSize := int64(3)
fold, _, err := httpdtest.GetFolderByName(folder2.Name, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, expectedQuotaSize, fold.UsedQuotaSize)
assert.Equal(t, expectedQuotaFiles, fold.UsedQuotaFiles)
assert.Equal(t, int64(0), fold.UsedQuotaSize)
assert.Equal(t, 0, fold.UsedQuotaFiles)
err = f.Close()
assert.NoError(t, err)
expectedQuotaFiles = 1
fold, _, err = httpdtest.GetFolderByName(folder2.Name, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, expectedQuotaSize, fold.UsedQuotaSize)
assert.Equal(t, expectedQuotaFiles, fold.UsedQuotaFiles)
assert.Equal(t, int64(0), fold.UsedQuotaSize)
assert.Equal(t, 0, fold.UsedQuotaFiles)
user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
@ -1777,8 +1778,8 @@ func TestVirtualFoldersQuotaValues(t *testing.T) {
assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
@ -1885,8 +1886,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 2, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
@ -1910,8 +1911,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 2, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
// rename a file inside vdir2, it isn't included inside user quota, so we have:
// - vdir1/dir1/testFileName.rename
// - vdir1/dir2/testFileName1
@ -1929,8 +1930,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
assert.Equal(t, 2, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 2, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
// rename a file inside vdir2 overwriting an existing, we now have:
// - vdir1/dir1/testFileName.rename
// - vdir1/dir2/testFileName1
@ -1947,8 +1948,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
assert.Equal(t, 1, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 2, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
// rename a file inside vdir1 overwriting an existing, we now have:
// - vdir1/dir1/testFileName.rename (initial testFileName1)
// - vdir2/dir1/testFileName.rename (initial testFileName1)
@ -1960,8 +1961,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
@ -1981,8 +1982,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
@ -2087,8 +2088,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize, user.UsedQuotaSize)
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1+testFileSize1, f.UsedQuotaSize)
@ -2106,8 +2107,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize*2, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize*2, f.UsedQuotaSize)
assert.Equal(t, 2, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1*2, f.UsedQuotaSize)
@ -2124,8 +2125,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1+testFileSize, f.UsedQuotaSize)
@ -2141,8 +2142,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
@ -2172,8 +2173,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize1*3+testFileSize*2, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1*3+testFileSize*2, f.UsedQuotaSize)
assert.Equal(t, 5, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, int64(0), f.UsedQuotaSize)
@ -2187,8 +2188,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize1*2+testFileSize, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1*2+testFileSize, f.UsedQuotaSize)
assert.Equal(t, 3, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
@ -2293,8 +2294,8 @@ func TestQuotaRenameFromVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
@ -2312,8 +2313,8 @@ func TestQuotaRenameFromVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize+testFileSize1+testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
@ -2376,8 +2377,8 @@ func TestQuotaRenameFromVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize*3+testFileSize1*3, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 2, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, int64(0), f.UsedQuotaSize)
@ -2497,8 +2498,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
// rename a file from user home dir to vdir2, vdir2 is not included in user quota so we have:
// - /vdir2/dir1/testFileName
// - /vdir1/dir1/testFileName1
@ -2537,8 +2538,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
@ -2554,8 +2555,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
@ -2577,8 +2578,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
@ -2595,8 +2596,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize*2+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 3, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
@ -2621,8 +2622,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize*2+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 3, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1*2+testFileSize, f.UsedQuotaSize)
@ -3978,9 +3979,9 @@ func TestEventRule(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 3)
assert.True(t, util.Contains(email.To, "test1@example.com"))
assert.True(t, util.Contains(email.To, "test2@example.com"))
assert.True(t, util.Contains(email.To, "test3@example.com"))
assert.True(t, slices.Contains(email.To, "test1@example.com"))
assert.True(t, slices.Contains(email.To, "test2@example.com"))
assert.True(t, slices.Contains(email.To, "test3@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: New "upload" from "%s" status OK`, user.Username))
// test the failure action, we download a file that exceeds the transfer quota limit
err = writeSFTPFileNoCheck(path.Join("subdir1", testFileName), 1*1024*1024+65535, client)
@ -3999,9 +4000,9 @@ func TestEventRule(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email = lastReceivedEmail.get()
assert.Len(t, email.To, 3)
assert.True(t, util.Contains(email.To, "test1@example.com"))
assert.True(t, util.Contains(email.To, "test2@example.com"))
assert.True(t, util.Contains(email.To, "test3@example.com"))
assert.True(t, slices.Contains(email.To, "test1@example.com"))
assert.True(t, slices.Contains(email.To, "test2@example.com"))
assert.True(t, slices.Contains(email.To, "test3@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: New "download" from "%s" status KO`, user.Username))
assert.Contains(t, email.Data, `"download" failed`)
assert.Contains(t, email.Data, common.ErrReadQuotaExceeded.Error())
@ -4019,7 +4020,7 @@ func TestEventRule(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email = lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "failure@example.com"))
assert.True(t, slices.Contains(email.To, "failure@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: Failed "upload" from "%s"`, user.Username))
assert.Contains(t, email.Data, fmt.Sprintf(`action %q failed`, action1.Name))
// now test the download rule
@ -4036,9 +4037,9 @@ func TestEventRule(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email = lastReceivedEmail.get()
assert.Len(t, email.To, 3)
assert.True(t, util.Contains(email.To, "test1@example.com"))
assert.True(t, util.Contains(email.To, "test2@example.com"))
assert.True(t, util.Contains(email.To, "test3@example.com"))
assert.True(t, slices.Contains(email.To, "test1@example.com"))
assert.True(t, slices.Contains(email.To, "test2@example.com"))
assert.True(t, slices.Contains(email.To, "test3@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: New "download" from "%s"`, user.Username))
}
// test upload action command with arguments
@ -4079,9 +4080,9 @@ func TestEventRule(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 3)
assert.True(t, util.Contains(email.To, "test1@example.com"))
assert.True(t, util.Contains(email.To, "test2@example.com"))
assert.True(t, util.Contains(email.To, "test3@example.com"))
assert.True(t, slices.Contains(email.To, "test1@example.com"))
assert.True(t, slices.Contains(email.To, "test2@example.com"))
assert.True(t, slices.Contains(email.To, "test3@example.com"))
assert.Contains(t, email.Data, `Subject: New "delete" from "admin"`)
_, err = httpdtest.RemoveEventRule(rule3, http.StatusOK)
assert.NoError(t, err)
@ -4236,7 +4237,7 @@ func TestEventRuleProviderEvents(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test3@example.com"))
assert.True(t, slices.Contains(email.To, "test3@example.com"))
assert.Contains(t, email.Data, `Subject: New "update" from "admin"`)
}
// now delete the script to generate an error
@ -4251,7 +4252,7 @@ func TestEventRuleProviderEvents(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "failure@example.com"))
assert.True(t, slices.Contains(email.To, "failure@example.com"))
assert.Contains(t, email.Data, `Subject: Failed "update" from "admin"`)
assert.Contains(t, email.Data, fmt.Sprintf("Object name: %s object type: folder", folder.Name))
lastReceivedEmail.reset()
@ -5306,7 +5307,7 @@ func TestBackupAsAttachment(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "%s OK"`, renewalEvent))
assert.Contains(t, email.Data, `Domain: example.com`)
assert.Contains(t, email.Data, "Content-Type: application/json")
@ -5676,7 +5677,7 @@ func TestEventActionCompressQuotaErrors(t *testing.T) {
}, 3*time.Second, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, `Subject: "Compress failed"`)
assert.Contains(t, email.Data, common.ErrQuotaExceeded.Error())
// update quota size so the user is already overquota
@ -5691,7 +5692,7 @@ func TestEventActionCompressQuotaErrors(t *testing.T) {
}, 3*time.Second, 100*time.Millisecond)
email = lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, `Subject: "Compress failed"`)
assert.Contains(t, email.Data, common.ErrQuotaExceeded.Error())
// remove the path to compress to trigger an error for size estimation
@ -5705,7 +5706,7 @@ func TestEventActionCompressQuotaErrors(t *testing.T) {
}, 3*time.Second, 100*time.Millisecond)
email = lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, `Subject: "Compress failed"`)
assert.Contains(t, email.Data, "unable to estimate archive size")
}
@ -5834,8 +5835,8 @@ func TestEventActionCompressQuotaFolder(t *testing.T) {
assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
vfolder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, 2, vfolder.UsedQuotaFiles)
assert.Equal(t, info.Size()+int64(len(testFileContent)), vfolder.UsedQuotaSize)
assert.Equal(t, 0, vfolder.UsedQuotaFiles)
assert.Equal(t, int64(0), vfolder.UsedQuotaSize)
}
}
@ -6041,7 +6042,7 @@ func TestEventActionEmailAttachments(t *testing.T) {
}, 1500*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, `Subject: "upload" from`)
assert.Contains(t, email.Data, "Content-Disposition: attachment")
}
@ -6218,7 +6219,7 @@ func TestEventActionsRetentionReports(t *testing.T) {
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "upload" from "%s"`, user.Username))
assert.Contains(t, email.Data, "Content-Disposition: attachment")
_, err = client.Stat(testDir)
@ -6391,7 +6392,7 @@ func TestEventRuleFirstUploadDownloadActions(t *testing.T) {
}, 1500*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "first-upload" from "%s"`, user.Username))
lastReceivedEmail.reset()
// a new upload will not produce a new notification
@ -6414,7 +6415,7 @@ func TestEventRuleFirstUploadDownloadActions(t *testing.T) {
}, 1500*time.Millisecond, 100*time.Millisecond)
email = lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "first-download" from "%s"`, user.Username))
// download again
lastReceivedEmail.reset()
@ -6510,7 +6511,7 @@ func TestEventRuleRenameEvent(t *testing.T) {
}, 1500*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "rename" from "%s"`, user.Username))
assert.Contains(t, email.Data, "Content-Type: text/html")
assert.Contains(t, email.Data, fmt.Sprintf("Target path %q", path.Join("/subdir", testFileName)))
@ -6644,7 +6645,7 @@ func TestEventRuleIDPLogin(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "%s OK"`, common.IDPLoginUser))
assert.Contains(t, email.Data, username)
assert.Contains(t, email.Data, custom1)
@ -6708,7 +6709,7 @@ func TestEventRuleIDPLogin(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email = lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "%s OK"`, common.IDPLoginAdmin))
assert.Contains(t, email.Data, username)
assert.Contains(t, email.Data, custom1)
@ -6900,7 +6901,7 @@ func TestEventRuleEmailField(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, user.Email))
assert.True(t, slices.Contains(email.To, user.Email))
assert.Contains(t, email.Data, `Subject: "add" from "admin"`)
// if we add a user without email the notification will fail
@ -6914,7 +6915,7 @@ func TestEventRuleEmailField(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email = lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "failure@example.com"))
assert.True(t, slices.Contains(email.To, "failure@example.com"))
assert.Contains(t, email.Data, `no recipient addresses set`)
conn, client, err := getSftpClient(user)
@ -6931,7 +6932,7 @@ func TestEventRuleEmailField(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, user.Email))
assert.True(t, slices.Contains(email.To, user.Email))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "mkdir" from "%s"`, user.Username))
}
@ -7038,7 +7039,7 @@ func TestEventRuleCertificate(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "%s OK"`, renewalEvent))
assert.Contains(t, email.Data, "Content-Type: text/plain")
assert.Contains(t, email.Data, `Domain: example.com Timestamp`)
@ -7058,7 +7059,7 @@ func TestEventRuleCertificate(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email = lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "%s KO"`, renewalEvent))
assert.Contains(t, email.Data, `Domain: example.com Timestamp`)
assert.Contains(t, email.Data, errRenew.Error())
@ -7184,8 +7185,8 @@ func TestEventRuleIPBlocked(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 2)
assert.True(t, util.Contains(email.To, "test3@example.com"))
assert.True(t, util.Contains(email.To, "test4@example.com"))
assert.True(t, slices.Contains(email.To, "test3@example.com"))
assert.True(t, slices.Contains(email.To, "test4@example.com"))
assert.Contains(t, email.Data, `Subject: New "IP Blocked"`)
err = dataprovider.DeleteEventRule(rule1.Name, "", "", "")
@ -8357,7 +8358,7 @@ func TestSFTPLoopError(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "failure@example.com"))
assert.True(t, slices.Contains(email.To, "failure@example.com"))
assert.Contains(t, email.Data, `Subject: Failed action`)
user1.VirtualFolders[0].FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword)

View file

@ -17,6 +17,7 @@ package common
import (
"errors"
"fmt"
"slices"
"sort"
"sync"
"sync/atomic"
@ -94,7 +95,7 @@ func (r *RateLimiterConfig) validate() error {
}
r.Protocols = util.RemoveDuplicates(r.Protocols, true)
for _, protocol := range r.Protocols {
if !util.Contains(rateLimiterProtocolValues, protocol) {
if !slices.Contains(rateLimiterProtocolValues, protocol) {
return fmt.Errorf("invalid protocol %q", protocol)
}
}

View file

@ -25,6 +25,7 @@ import (
"math/rand"
"os"
"path/filepath"
"slices"
"sync"
"github.com/drakkan/sftpgo/v2/internal/logger"
@ -96,7 +97,7 @@ func (m *CertManager) loadCertificates() error {
}
logger.Debug(m.logSender, "", "TLS certificate %q successfully loaded, id %v", keyPair.Cert, keyPair.ID)
certs[keyPair.ID] = &newCert
if !util.Contains(m.monitorList, keyPair.Cert) {
if !slices.Contains(m.monitorList, keyPair.Cert) {
m.monitorList = append(m.monitorList, keyPair.Cert)
}
}
@ -190,7 +191,7 @@ func (m *CertManager) LoadCRLs() error {
logger.Debug(m.logSender, "", "CRL %q successfully loaded", revocationList)
crls = append(crls, crl)
if !util.Contains(m.monitorList, revocationList) {
if !slices.Contains(m.monitorList, revocationList) {
m.monitorList = append(m.monitorList, revocationList)
}
}

View file

@ -521,11 +521,8 @@ func (t *BaseTransfer) updateQuota(numFiles int, fileSize int64) bool {
if t.transferType == TransferUpload && (numFiles != 0 || sizeDiff != 0) {
vfolder, err := t.Connection.User.GetVirtualFolderForPath(path.Dir(t.requestPath))
if err == nil {
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, numFiles, //nolint:errcheck
dataprovider.UpdateUserFolderQuota(&vfolder, &t.Connection.User, numFiles,
sizeDiff, false)
if vfolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(&t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck
}
} else {
dataprovider.UpdateUserQuota(&t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck
}

View file

@ -20,6 +20,7 @@ import (
"fmt"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
@ -234,6 +235,7 @@ func Init() {
RateLimitersConfig: []common.RateLimiterConfig{defaultRateLimiter},
Umask: "",
ServerVersion: "",
TZ: "",
Metadata: common.MetadataConfig{
Read: 0,
},
@ -715,7 +717,7 @@ func checkOverrideDefaultSettings() {
}
}
if util.Contains(viper.AllKeys(), "mfa.totp") {
if slices.Contains(viper.AllKeys(), "mfa.totp") {
globalConf.MFAConfig.TOTP = nil
}
}
@ -2007,6 +2009,7 @@ func setViperDefaults() {
viper.SetDefault("common.defender.login_delay.password_failed", globalConf.Common.DefenderConfig.LoginDelay.PasswordFailed)
viper.SetDefault("common.umask", globalConf.Common.Umask)
viper.SetDefault("common.server_version", globalConf.Common.ServerVersion)
viper.SetDefault("common.tz", globalConf.Common.TZ)
viper.SetDefault("common.metadata.read", globalConf.Common.Metadata.Read)
viper.SetDefault("acme.email", globalConf.ACME.Email)
viper.SetDefault("acme.key_type", globalConf.ACME.KeyType)

View file

@ -19,6 +19,7 @@ import (
"encoding/json"
"os"
"path/filepath"
"slices"
"testing"
"github.com/sftpgo/sdk/kms"
@ -36,7 +37,6 @@ import (
"github.com/drakkan/sftpgo/v2/internal/plugin"
"github.com/drakkan/sftpgo/v2/internal/sftpd"
"github.com/drakkan/sftpgo/v2/internal/smtp"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/webdavd"
)
@ -679,8 +679,8 @@ func TestPluginsFromEnv(t *testing.T) {
pluginConf := pluginsConf[0]
require.Equal(t, "notifier", pluginConf.Type)
require.Len(t, pluginConf.NotifierOptions.FsEvents, 2)
require.True(t, util.Contains(pluginConf.NotifierOptions.FsEvents, "upload"))
require.True(t, util.Contains(pluginConf.NotifierOptions.FsEvents, "download"))
require.True(t, slices.Contains(pluginConf.NotifierOptions.FsEvents, "upload"))
require.True(t, slices.Contains(pluginConf.NotifierOptions.FsEvents, "download"))
require.Len(t, pluginConf.NotifierOptions.ProviderEvents, 2)
require.Equal(t, "add", pluginConf.NotifierOptions.ProviderEvents[0])
require.Equal(t, "update", pluginConf.NotifierOptions.ProviderEvents[1])
@ -729,8 +729,8 @@ func TestPluginsFromEnv(t *testing.T) {
pluginConf = pluginsConf[0]
require.Equal(t, "notifier", pluginConf.Type)
require.Len(t, pluginConf.NotifierOptions.FsEvents, 2)
require.True(t, util.Contains(pluginConf.NotifierOptions.FsEvents, "upload"))
require.True(t, util.Contains(pluginConf.NotifierOptions.FsEvents, "download"))
require.True(t, slices.Contains(pluginConf.NotifierOptions.FsEvents, "upload"))
require.True(t, slices.Contains(pluginConf.NotifierOptions.FsEvents, "download"))
require.Len(t, pluginConf.NotifierOptions.ProviderEvents, 2)
require.Equal(t, "add", pluginConf.NotifierOptions.ProviderEvents[0])
require.Equal(t, "update", pluginConf.NotifierOptions.ProviderEvents[1])
@ -787,8 +787,8 @@ func TestRateLimitersFromEnv(t *testing.T) {
require.Equal(t, 2, limiters[0].Type)
protocols := limiters[0].Protocols
require.Len(t, protocols, 2)
require.True(t, util.Contains(protocols, common.ProtocolFTP))
require.True(t, util.Contains(protocols, common.ProtocolSSH))
require.True(t, slices.Contains(protocols, common.ProtocolFTP))
require.True(t, slices.Contains(protocols, common.ProtocolSSH))
require.True(t, limiters[0].GenerateDefenderEvents)
require.Equal(t, 50, limiters[0].EntriesSoftLimit)
require.Equal(t, 100, limiters[0].EntriesHardLimit)
@ -799,10 +799,10 @@ func TestRateLimitersFromEnv(t *testing.T) {
require.Equal(t, 2, limiters[1].Type)
protocols = limiters[1].Protocols
require.Len(t, protocols, 4)
require.True(t, util.Contains(protocols, common.ProtocolFTP))
require.True(t, util.Contains(protocols, common.ProtocolSSH))
require.True(t, util.Contains(protocols, common.ProtocolWebDAV))
require.True(t, util.Contains(protocols, common.ProtocolHTTP))
require.True(t, slices.Contains(protocols, common.ProtocolFTP))
require.True(t, slices.Contains(protocols, common.ProtocolSSH))
require.True(t, slices.Contains(protocols, common.ProtocolWebDAV))
require.True(t, slices.Contains(protocols, common.ProtocolHTTP))
require.False(t, limiters[1].GenerateDefenderEvents)
require.Equal(t, 100, limiters[1].EntriesSoftLimit)
require.Equal(t, 150, limiters[1].EntriesHardLimit)

View file

@ -21,6 +21,7 @@ import (
"net/url"
"os/exec"
"path/filepath"
"slices"
"strings"
"time"
@ -78,8 +79,8 @@ func executeAction(operation, executor, ip, objectType, objectName, role string,
if config.Actions.Hook == "" {
return
}
if !util.Contains(config.Actions.ExecuteOn, operation) ||
!util.Contains(config.Actions.ExecuteFor, objectType) {
if !slices.Contains(config.Actions.ExecuteOn, operation) ||
!slices.Contains(config.Actions.ExecuteFor, objectType) {
return
}

View file

@ -20,6 +20,7 @@ import (
"fmt"
"net"
"os"
"slices"
"strconv"
"strings"
@ -96,7 +97,7 @@ func (c *AdminTOTPConfig) validate(username string) error {
if c.ConfigName == "" {
return util.NewValidationError("totp: config name is mandatory")
}
if !util.Contains(mfa.GetAvailableTOTPConfigNames(), c.ConfigName) {
if !slices.Contains(mfa.GetAvailableTOTPConfigNames(), c.ConfigName) {
return util.NewValidationError(fmt.Sprintf("totp: config name %q not found", c.ConfigName))
}
if c.Secret.IsEmpty() {
@ -337,15 +338,15 @@ func (a *Admin) validatePermissions() error {
util.I18nErrorPermissionsRequired,
)
}
if util.Contains(a.Permissions, PermAdminAny) {
if slices.Contains(a.Permissions, PermAdminAny) {
a.Permissions = []string{PermAdminAny}
}
for _, perm := range a.Permissions {
if !util.Contains(validAdminPerms, perm) {
if !slices.Contains(validAdminPerms, perm) {
return util.NewValidationError(fmt.Sprintf("invalid permission: %q", perm))
}
if a.Role != "" {
if util.Contains(forbiddenPermsForRoleAdmins, perm) {
if slices.Contains(forbiddenPermsForRoleAdmins, perm) {
deniedPerms := strings.Join(forbiddenPermsForRoleAdmins, ",")
return util.NewI18nError(
util.NewValidationError(fmt.Sprintf("a role admin cannot have the following permissions: %q", deniedPerms)),
@ -559,10 +560,10 @@ func (a *Admin) SetNilSecretsIfEmpty() {
// HasPermission returns true if the admin has the specified permission
func (a *Admin) HasPermission(perm string) bool {
if util.Contains(a.Permissions, PermAdminAny) {
if slices.Contains(a.Permissions, PermAdminAny) {
return true
}
return util.Contains(a.Permissions, perm)
return slices.Contains(a.Permissions, perm)
}
// GetAllowedIPAsString returns the allowed IP as comma separated string

View file

@ -25,6 +25,7 @@ import (
"fmt"
"net/netip"
"path/filepath"
"slices"
"sort"
"time"
@ -3134,15 +3135,11 @@ func (p *BoltProvider) migrateDatabase() error {
case version == boltDatabaseVersion:
providerLog(logger.LevelDebug, "bolt database is up to date, current version: %d", version)
return ErrNoInitRequired
case version < 28:
case version < 29:
err = fmt.Errorf("database schema version %d is too old, please see the upgrading docs", version)
providerLog(logger.LevelError, "%v", err)
logger.ErrorToConsole("%v", err)
return err
case version == 28:
logger.InfoToConsole("updating database schema version: %d -> 29", version)
providerLog(logger.LevelInfo, "updating database schema version: %d -> 29", version)
return updateBoltDatabaseVersion(p.dbHandle, 29)
default:
if version > boltDatabaseVersion {
providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version,
@ -3164,10 +3161,6 @@ func (p *BoltProvider) revertDatabase(targetVersion int) error { //nolint:gocycl
return errors.New("current version match target version, nothing to do")
}
switch dbVersion.Version {
case 29:
logger.InfoToConsole("downgrading database schema version: %d -> 28", dbVersion.Version)
providerLog(logger.LevelInfo, "downgrading database schema version: %d -> 28", dbVersion.Version)
return updateBoltDatabaseVersion(p.dbHandle, 28)
default:
return fmt.Errorf("database schema version not handled: %v", dbVersion.Version)
}
@ -3328,7 +3321,7 @@ func (p *BoltProvider) addAdminToRole(username, roleName string, bucket *bolt.Bu
if err != nil {
return err
}
if !util.Contains(role.Admins, username) {
if !slices.Contains(role.Admins, username) {
role.Admins = append(role.Admins, username)
buf, err := json.Marshal(role)
if err != nil {
@ -3353,7 +3346,7 @@ func (p *BoltProvider) removeAdminFromRole(username, roleName string, bucket *bo
if err != nil {
return err
}
if util.Contains(role.Admins, username) {
if slices.Contains(role.Admins, username) {
var admins []string
for _, admin := range role.Admins {
if admin != username {
@ -3383,7 +3376,7 @@ func (p *BoltProvider) addUserToRole(username, roleName string, bucket *bolt.Buc
if err != nil {
return err
}
if !util.Contains(role.Users, username) {
if !slices.Contains(role.Users, username) {
role.Users = append(role.Users, username)
buf, err := json.Marshal(role)
if err != nil {
@ -3408,7 +3401,7 @@ func (p *BoltProvider) removeUserFromRole(username, roleName string, bucket *bol
if err != nil {
return err
}
if util.Contains(role.Users, username) {
if slices.Contains(role.Users, username) {
var users []string
for _, user := range role.Users {
if user != username {
@ -3436,7 +3429,7 @@ func (p *BoltProvider) addRuleToActionMapping(ruleName, actionName string, bucke
if err != nil {
return err
}
if !util.Contains(action.Rules, ruleName) {
if !slices.Contains(action.Rules, ruleName) {
action.Rules = append(action.Rules, ruleName)
buf, err := json.Marshal(action)
if err != nil {
@ -3458,7 +3451,7 @@ func (p *BoltProvider) removeRuleFromActionMapping(ruleName, actionName string,
if err != nil {
return err
}
if util.Contains(action.Rules, ruleName) {
if slices.Contains(action.Rules, ruleName) {
var rules []string
for _, r := range action.Rules {
if r != ruleName {
@ -3485,7 +3478,7 @@ func (p *BoltProvider) addUserToGroupMapping(username, groupname string, bucket
if err != nil {
return err
}
if !util.Contains(group.Users, username) {
if !slices.Contains(group.Users, username) {
group.Users = append(group.Users, username)
buf, err := json.Marshal(group)
if err != nil {
@ -3530,7 +3523,7 @@ func (p *BoltProvider) addAdminToGroupMapping(username, groupname string, bucket
if err != nil {
return err
}
if !util.Contains(group.Admins, username) {
if !slices.Contains(group.Admins, username) {
group.Admins = append(group.Admins, username)
buf, err := json.Marshal(group)
if err != nil {
@ -3601,11 +3594,11 @@ func (p *BoltProvider) addRelationToFolderMapping(folderName string, user *User,
return err
}
updated := false
if user != nil && !util.Contains(folder.Users, user.Username) {
if user != nil && !slices.Contains(folder.Users, user.Username) {
folder.Users = append(folder.Users, user.Username)
updated = true
}
if group != nil && !util.Contains(folder.Groups, group.Name) {
if group != nil && !slices.Contains(folder.Groups, group.Name) {
folder.Groups = append(folder.Groups, group.Name)
updated = true
}
@ -3899,7 +3892,7 @@ func getBoltDatabaseVersion(dbHandle *bolt.DB) (schemaVersion, error) {
v := bucket.Get(dbVersionKey)
if v == nil {
dbVersion = schemaVersion{
Version: 28,
Version: 29,
}
return nil
}
@ -3908,7 +3901,7 @@ func getBoltDatabaseVersion(dbHandle *bolt.DB) (schemaVersion, error) {
return dbVersion, err
}
func updateBoltDatabaseVersion(dbHandle *bolt.DB, version int) error {
/*func updateBoltDatabaseVersion(dbHandle *bolt.DB, version int) error {
err := dbHandle.Update(func(tx *bolt.Tx) error {
bucket := tx.Bucket(dbVersionBucket)
if bucket == nil {
@ -3924,4 +3917,4 @@ func updateBoltDatabaseVersion(dbHandle *bolt.DB, version int) error {
return bucket.Put(dbVersionKey, buf)
})
return err
}
}*/

View file

@ -15,8 +15,12 @@
package dataprovider
import (
"bytes"
"encoding/json"
"fmt"
"image/png"
"net/url"
"slices"
"golang.org/x/crypto/ssh"
@ -102,7 +106,7 @@ func (c *SFTPDConfigs) validate() error {
if algo == ssh.CertAlgoRSAv01 {
continue
}
if !util.Contains(supportedHostKeyAlgos, algo) {
if !slices.Contains(supportedHostKeyAlgos, algo) {
return util.NewValidationError(fmt.Sprintf("unsupported host key algorithm %q", algo))
}
hostKeyAlgos = append(hostKeyAlgos, algo)
@ -113,24 +117,24 @@ func (c *SFTPDConfigs) validate() error {
if algo == "diffie-hellman-group18-sha512" || algo == ssh.KeyExchangeDHGEXSHA256 {
continue
}
if !util.Contains(supportedKexAlgos, algo) {
if !slices.Contains(supportedKexAlgos, algo) {
return util.NewValidationError(fmt.Sprintf("unsupported KEX algorithm %q", algo))
}
kexAlgos = append(kexAlgos, algo)
}
c.KexAlgorithms = kexAlgos
for _, cipher := range c.Ciphers {
if !util.Contains(supportedCiphers, cipher) {
if !slices.Contains(supportedCiphers, cipher) {
return util.NewValidationError(fmt.Sprintf("unsupported cipher %q", cipher))
}
}
for _, mac := range c.MACs {
if !util.Contains(supportedMACs, mac) {
if !slices.Contains(supportedMACs, mac) {
return util.NewValidationError(fmt.Sprintf("unsupported MAC algorithm %q", mac))
}
}
for _, algo := range c.PublicKeyAlgos {
if !util.Contains(supportedPublicKeyAlgos, algo) {
if !slices.Contains(supportedPublicKeyAlgos, algo) {
return util.NewValidationError(fmt.Sprintf("unsupported public key algorithm %q", algo))
}
}
@ -305,6 +309,27 @@ func (c *SMTPConfigs) TryDecrypt() error {
return nil
}
func (c *SMTPConfigs) prepareForRendering() {
if c.Password != nil {
c.Password.Hide()
if c.Password.IsEmpty() {
c.Password = nil
}
}
if c.OAuth2.ClientSecret != nil {
c.OAuth2.ClientSecret.Hide()
if c.OAuth2.ClientSecret.IsEmpty() {
c.OAuth2.ClientSecret = nil
}
}
if c.OAuth2.RefreshToken != nil {
c.OAuth2.RefreshToken.Hide()
if c.OAuth2.RefreshToken.IsEmpty() {
c.OAuth2.RefreshToken = nil
}
}
}
func (c *SMTPConfigs) getACopy() *SMTPConfigs {
var password *kms.Secret
if c.Password != nil {
@ -387,12 +412,136 @@ func (c *ACMEConfigs) getACopy() *ACMEConfigs {
}
}
// BrandingConfig defines the branding configuration
type BrandingConfig struct {
Name string `json:"name"`
ShortName string `json:"short_name"`
Logo []byte `json:"logo"`
Favicon []byte `json:"favicon"`
DisclaimerName string `json:"disclaimer_name"`
DisclaimerURL string `json:"disclaimer_url"`
}
func (c *BrandingConfig) isEmpty() bool {
if c.Name != "" {
return false
}
if c.ShortName != "" {
return false
}
if len(c.Logo) > 0 {
return false
}
if len(c.Favicon) > 0 {
return false
}
if c.DisclaimerName != "" && c.DisclaimerURL != "" {
return false
}
return true
}
func (*BrandingConfig) validatePNG(b []byte, maxWidth, maxHeight int) error {
if len(b) == 0 {
return nil
}
// DecodeConfig is more efficient, but I'm not sure if this would lead to
// accepting invalid images in some edge cases and performance does not
// matter here.
img, err := png.Decode(bytes.NewBuffer(b))
if err != nil {
return util.NewI18nError(
util.NewValidationError("invalid PNG image"),
util.I18nErrorInvalidPNG,
)
}
bounds := img.Bounds()
if bounds.Dx() > maxWidth || bounds.Dy() > maxHeight {
return util.NewI18nError(
util.NewValidationError("invalid PNG image size"),
util.I18nErrorInvalidPNGSize,
)
}
return nil
}
func (c *BrandingConfig) validateDisclaimerURL() error {
if c.DisclaimerURL == "" {
return nil
}
u, err := url.Parse(c.DisclaimerURL)
if err != nil {
return util.NewI18nError(
util.NewValidationError("invalid disclaimer URL"),
util.I18nErrorInvalidDisclaimerURL,
)
}
if u.Scheme != "http" && u.Scheme != "https" {
return util.NewI18nError(
util.NewValidationError("invalid disclaimer URL scheme"),
util.I18nErrorInvalidDisclaimerURL,
)
}
return nil
}
func (c *BrandingConfig) validate() error {
if err := c.validateDisclaimerURL(); err != nil {
return err
}
if err := c.validatePNG(c.Logo, 512, 512); err != nil {
return err
}
return c.validatePNG(c.Favicon, 256, 256)
}
func (c *BrandingConfig) getACopy() BrandingConfig {
logo := make([]byte, len(c.Logo))
copy(logo, c.Logo)
favicon := make([]byte, len(c.Favicon))
copy(favicon, c.Favicon)
return BrandingConfig{
Name: c.Name,
ShortName: c.ShortName,
Logo: logo,
Favicon: favicon,
DisclaimerName: c.DisclaimerName,
DisclaimerURL: c.DisclaimerURL,
}
}
// BrandingConfigs defines the branding configuration for WebAdmin and WebClient UI
type BrandingConfigs struct {
WebAdmin BrandingConfig
WebClient BrandingConfig
}
func (c *BrandingConfigs) isEmpty() bool {
return c.WebAdmin.isEmpty() && c.WebClient.isEmpty()
}
func (c *BrandingConfigs) validate() error {
if err := c.WebAdmin.validate(); err != nil {
return err
}
return c.WebClient.validate()
}
func (c *BrandingConfigs) getACopy() *BrandingConfigs {
return &BrandingConfigs{
WebAdmin: c.WebAdmin.getACopy(),
WebClient: c.WebClient.getACopy(),
}
}
// Configs allows to set configuration keys disabled by default without
// modifying the config file or setting env vars
type Configs struct {
SFTPD *SFTPDConfigs `json:"sftpd,omitempty"`
SMTP *SMTPConfigs `json:"smtp,omitempty"`
ACME *ACMEConfigs `json:"acme,omitempty"`
Branding *BrandingConfigs `json:"branding,omitempty"`
UpdatedAt int64 `json:"updated_at,omitempty"`
}
@ -412,6 +561,11 @@ func (c *Configs) validate() error {
return err
}
}
if c.Branding != nil {
if err := c.Branding.validate(); err != nil {
return err
}
}
return nil
}
@ -428,25 +582,11 @@ func (c *Configs) PrepareForRendering() {
if c.ACME != nil && c.ACME.isEmpty() {
c.ACME = nil
}
if c.Branding != nil && c.Branding.isEmpty() {
c.Branding = nil
}
if c.SMTP != nil {
if c.SMTP.Password != nil {
c.SMTP.Password.Hide()
if c.SMTP.Password.IsEmpty() {
c.SMTP.Password = nil
}
}
if c.SMTP.OAuth2.ClientSecret != nil {
c.SMTP.OAuth2.ClientSecret.Hide()
if c.SMTP.OAuth2.ClientSecret.IsEmpty() {
c.SMTP.OAuth2.ClientSecret = nil
}
}
if c.SMTP.OAuth2.RefreshToken != nil {
c.SMTP.OAuth2.RefreshToken.Hide()
if c.SMTP.OAuth2.RefreshToken.IsEmpty() {
c.SMTP.OAuth2.RefreshToken = nil
}
}
c.SMTP.prepareForRendering()
}
}
@ -470,6 +610,9 @@ func (c *Configs) SetNilsToEmpty() {
if c.ACME == nil {
c.ACME = &ACMEConfigs{}
}
if c.Branding == nil {
c.Branding = &BrandingConfigs{}
}
}
// RenderAsJSON implements the renderer interface used within plugins
@ -498,6 +641,9 @@ func (c *Configs) getACopy() Configs {
if c.ACME != nil {
result.ACME = c.ACME.getACopy()
}
if c.Branding != nil {
result.Branding = c.Branding.getACopy()
}
result.UpdatedAt = c.UpdatedAt
return result
}

View file

@ -44,6 +44,7 @@ import (
"path/filepath"
"regexp"
"runtime"
"slices"
"strconv"
"strings"
"sync"
@ -187,6 +188,7 @@ var (
ErrDuplicatedKey = errors.New("duplicated key not allowed")
// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
ErrForeignKeyViolated = errors.New("violates foreign key constraint")
tz = ""
isAdminCreated atomic.Bool
validTLSUsernames = []string{string(sdk.TLSUsernameNone), string(sdk.TLSUsernameCN)}
config Config
@ -518,7 +520,7 @@ type Config struct {
// GetShared returns the provider share mode.
// This method is called before the provider is initialized
func (c *Config) GetShared() int {
if !util.Contains(sharedProviders, c.Driver) {
if !slices.Contains(sharedProviders, c.Driver) {
return 0
}
return c.IsShared
@ -590,6 +592,16 @@ func (c *Config) doBackup() (string, error) {
return outputFile, nil
}
// SetTZ sets the configured timezone.
func SetTZ(val string) {
tz = val
}
// UseLocalTime returns true if local time should be used instead of UTC.
func UseLocalTime() bool {
return tz == "local"
}
// ExecuteBackup executes a backup
func ExecuteBackup() (string, error) {
return config.doBackup()
@ -874,7 +886,7 @@ func SetTempPath(fsPath string) {
}
func checkSharedMode() {
if !util.Contains(sharedProviders, config.Driver) {
if !slices.Contains(sharedProviders, config.Driver) {
config.IsShared = 0
}
}
@ -929,12 +941,13 @@ func checkDatabase(checkAdmins bool) error {
if config.UpdateMode == 0 {
err := provider.initializeDatabase()
if err != nil && err != ErrNoInitRequired {
logger.WarnToConsole("Unable to initialize data provider: %v", err)
providerLog(logger.LevelError, "Unable to initialize data provider: %v", err)
logger.WarnToConsole("unable to initialize data provider: %v", err)
providerLog(logger.LevelError, "unable to initialize data provider: %v", err)
return err
}
if err == nil {
logger.DebugToConsole("Data provider successfully initialized")
logger.DebugToConsole("data provider successfully initialized")
providerLog(logger.LevelInfo, "data provider successfully initialized")
}
err = provider.migrateDatabase()
if err != nil && err != ErrNoInitRequired {
@ -1503,6 +1516,15 @@ func UpdateUserQuota(user *User, filesAdd int, sizeAdd int64, reset bool) error
return nil
}
// UpdateUserFolderQuota updates the quota for the given user and virtual folder.
func UpdateUserFolderQuota(folder *vfs.VirtualFolder, user *User, filesAdd int, sizeAdd int64, reset bool) {
if folder.IsIncludedInUserQuota() {
UpdateUserQuota(user, filesAdd, sizeAdd, reset) //nolint:errcheck
return
}
UpdateVirtualFolderQuota(&folder.BaseVirtualFolder, filesAdd, sizeAdd, reset) //nolint:errcheck
}
// UpdateVirtualFolderQuota updates the quota for the given virtual folder adding filesAdd and sizeAdd.
// If reset is true filesAdd and sizeAdd indicates the total files and the total size instead of the difference.
func UpdateVirtualFolderQuota(vfolder *vfs.BaseVirtualFolder, filesAdd int, sizeAdd int64, reset bool) error {
@ -1693,7 +1715,7 @@ func IPListEntryExists(ipOrNet string, listType IPListType) (IPListEntry, error)
// GetIPListEntries returns the IP list entries applying the specified criteria and search limit
func GetIPListEntries(listType IPListType, filter, from, order string, limit int) ([]IPListEntry, error) {
if !util.Contains(supportedIPListType, listType) {
if !slices.Contains(supportedIPListType, listType) {
return nil, util.NewValidationError(fmt.Sprintf("invalid list type %d", listType))
}
return provider.getIPListEntries(listType, filter, from, order, limit)
@ -2352,7 +2374,7 @@ func GetFolders(limit, offset int, order string, minimal bool) ([]vfs.BaseVirtua
}
func dumpUsers(data *BackupData, scopes []string) error {
if len(scopes) == 0 || util.Contains(scopes, DumpScopeUsers) {
if len(scopes) == 0 || slices.Contains(scopes, DumpScopeUsers) {
users, err := provider.dumpUsers()
if err != nil {
return err
@ -2363,7 +2385,7 @@ func dumpUsers(data *BackupData, scopes []string) error {
}
func dumpFolders(data *BackupData, scopes []string) error {
if len(scopes) == 0 || util.Contains(scopes, DumpScopeFolders) {
if len(scopes) == 0 || slices.Contains(scopes, DumpScopeFolders) {
folders, err := provider.dumpFolders()
if err != nil {
return err
@ -2374,7 +2396,7 @@ func dumpFolders(data *BackupData, scopes []string) error {
}
func dumpGroups(data *BackupData, scopes []string) error {
if len(scopes) == 0 || util.Contains(scopes, DumpScopeGroups) {
if len(scopes) == 0 || slices.Contains(scopes, DumpScopeGroups) {
groups, err := provider.dumpGroups()
if err != nil {
return err
@ -2385,7 +2407,7 @@ func dumpGroups(data *BackupData, scopes []string) error {
}
func dumpAdmins(data *BackupData, scopes []string) error {
if len(scopes) == 0 || util.Contains(scopes, DumpScopeAdmins) {
if len(scopes) == 0 || slices.Contains(scopes, DumpScopeAdmins) {
admins, err := provider.dumpAdmins()
if err != nil {
return err
@ -2396,7 +2418,7 @@ func dumpAdmins(data *BackupData, scopes []string) error {
}
func dumpAPIKeys(data *BackupData, scopes []string) error {
if len(scopes) == 0 || util.Contains(scopes, DumpScopeAPIKeys) {
if len(scopes) == 0 || slices.Contains(scopes, DumpScopeAPIKeys) {
apiKeys, err := provider.dumpAPIKeys()
if err != nil {
return err
@ -2407,7 +2429,7 @@ func dumpAPIKeys(data *BackupData, scopes []string) error {
}
func dumpShares(data *BackupData, scopes []string) error {
if len(scopes) == 0 || util.Contains(scopes, DumpScopeShares) {
if len(scopes) == 0 || slices.Contains(scopes, DumpScopeShares) {
shares, err := provider.dumpShares()
if err != nil {
return err
@ -2418,7 +2440,7 @@ func dumpShares(data *BackupData, scopes []string) error {
}
func dumpActions(data *BackupData, scopes []string) error {
if len(scopes) == 0 || util.Contains(scopes, DumpScopeActions) {
if len(scopes) == 0 || slices.Contains(scopes, DumpScopeActions) {
actions, err := provider.dumpEventActions()
if err != nil {
return err
@ -2429,7 +2451,7 @@ func dumpActions(data *BackupData, scopes []string) error {
}
func dumpRules(data *BackupData, scopes []string) error {
if len(scopes) == 0 || util.Contains(scopes, DumpScopeRules) {
if len(scopes) == 0 || slices.Contains(scopes, DumpScopeRules) {
rules, err := provider.dumpEventRules()
if err != nil {
return err
@ -2440,7 +2462,7 @@ func dumpRules(data *BackupData, scopes []string) error {
}
func dumpRoles(data *BackupData, scopes []string) error {
if len(scopes) == 0 || util.Contains(scopes, DumpScopeRoles) {
if len(scopes) == 0 || slices.Contains(scopes, DumpScopeRoles) {
roles, err := provider.dumpRoles()
if err != nil {
return err
@ -2451,7 +2473,7 @@ func dumpRoles(data *BackupData, scopes []string) error {
}
func dumpIPLists(data *BackupData, scopes []string) error {
if len(scopes) == 0 || util.Contains(scopes, DumpScopeIPLists) {
if len(scopes) == 0 || slices.Contains(scopes, DumpScopeIPLists) {
ipLists, err := provider.dumpIPListEntries()
if err != nil {
return err
@ -2462,7 +2484,7 @@ func dumpIPLists(data *BackupData, scopes []string) error {
}
func dumpConfigs(data *BackupData, scopes []string) error {
if len(scopes) == 0 || util.Contains(scopes, DumpScopeConfigs) {
if len(scopes) == 0 || slices.Contains(scopes, DumpScopeConfigs) {
configs, err := provider.getConfigs()
if err != nil {
return err
@ -2766,7 +2788,7 @@ func validateUserTOTPConfig(c *UserTOTPConfig, username string) error {
if c.ConfigName == "" {
return util.NewValidationError("totp: config name is mandatory")
}
if !util.Contains(mfa.GetAvailableTOTPConfigNames(), c.ConfigName) {
if !slices.Contains(mfa.GetAvailableTOTPConfigNames(), c.ConfigName) {
return util.NewValidationError(fmt.Sprintf("totp: config name %q not found", c.ConfigName))
}
if c.Secret.IsEmpty() {
@ -2782,7 +2804,7 @@ func validateUserTOTPConfig(c *UserTOTPConfig, username string) error {
return util.NewValidationError("totp: specify at least one protocol")
}
for _, protocol := range c.Protocols {
if !util.Contains(MFAProtocols, protocol) {
if !slices.Contains(MFAProtocols, protocol) {
return util.NewValidationError(fmt.Sprintf("totp: invalid protocol %q", protocol))
}
}
@ -2815,7 +2837,7 @@ func validateUserPermissions(permsToCheck map[string][]string) (map[string][]str
return permissions, util.NewValidationError("invalid permissions")
}
for _, p := range perms {
if !util.Contains(ValidPerms, p) {
if !slices.Contains(ValidPerms, p) {
return permissions, util.NewValidationError(fmt.Sprintf("invalid permission: %q", p))
}
}
@ -2829,7 +2851,7 @@ func validateUserPermissions(permsToCheck map[string][]string) (map[string][]str
if dir != cleanedDir && cleanedDir == "/" {
return permissions, util.NewValidationError(fmt.Sprintf("cannot set permissions for invalid subdirectory: %q is an alias for \"/\"", dir))
}
if util.Contains(perms, PermAny) {
if slices.Contains(perms, PermAny) {
permissions[cleanedDir] = []string{PermAny}
} else {
permissions[cleanedDir] = util.RemoveDuplicates(perms, false)
@ -2905,7 +2927,7 @@ func validateFiltersPatternExtensions(baseFilters *sdk.BaseUserFilters) error {
util.I18nErrorFilePatternPathInvalid,
)
}
if util.Contains(filteredPaths, cleanedPath) {
if slices.Contains(filteredPaths, cleanedPath) {
return util.NewI18nError(
util.NewValidationError(fmt.Sprintf("duplicate file patterns filter for path %q", f.Path)),
util.I18nErrorFilePatternDuplicated,
@ -3024,13 +3046,13 @@ func validateFilterProtocols(filters *sdk.BaseUserFilters) error {
return util.NewValidationError("invalid denied_protocols")
}
for _, p := range filters.DeniedProtocols {
if !util.Contains(ValidProtocols, p) {
if !slices.Contains(ValidProtocols, p) {
return util.NewValidationError(fmt.Sprintf("invalid denied protocol %q", p))
}
}
for _, p := range filters.TwoFactorAuthProtocols {
if !util.Contains(MFAProtocols, p) {
if !slices.Contains(MFAProtocols, p) {
return util.NewValidationError(fmt.Sprintf("invalid two factor protocol %q", p))
}
}
@ -3086,7 +3108,7 @@ func validateBaseFilters(filters *sdk.BaseUserFilters) error {
return util.NewValidationError("invalid denied_login_methods")
}
for _, loginMethod := range filters.DeniedLoginMethods {
if !util.Contains(ValidLoginMethods, loginMethod) {
if !slices.Contains(ValidLoginMethods, loginMethod) {
return util.NewValidationError(fmt.Sprintf("invalid login method: %q", loginMethod))
}
}
@ -3094,7 +3116,7 @@ func validateBaseFilters(filters *sdk.BaseUserFilters) error {
return err
}
if filters.TLSUsername != "" {
if !util.Contains(validTLSUsernames, string(filters.TLSUsername)) {
if !slices.Contains(validTLSUsernames, string(filters.TLSUsername)) {
return util.NewValidationError(fmt.Sprintf("invalid TLS username: %q", filters.TLSUsername))
}
}
@ -3104,7 +3126,7 @@ func validateBaseFilters(filters *sdk.BaseUserFilters) error {
}
filters.TLSCerts = certs
for _, opts := range filters.WebClient {
if !util.Contains(sdk.WebClientOptions, opts) {
if !slices.Contains(sdk.WebClientOptions, opts) {
return util.NewValidationError(fmt.Sprintf("invalid web client options %q", opts))
}
}
@ -3172,19 +3194,19 @@ func validateAccessTimeFilters(filters *sdk.BaseUserFilters) error {
}
func validateCombinedUserFilters(user *User) error {
if user.Filters.TOTPConfig.Enabled && util.Contains(user.Filters.WebClient, sdk.WebClientMFADisabled) {
if user.Filters.TOTPConfig.Enabled && slices.Contains(user.Filters.WebClient, sdk.WebClientMFADisabled) {
return util.NewI18nError(
util.NewValidationError("two-factor authentication cannot be disabled for a user with an active configuration"),
util.I18nErrorDisableActive2FA,
)
}
if user.Filters.RequirePasswordChange && util.Contains(user.Filters.WebClient, sdk.WebClientPasswordChangeDisabled) {
if user.Filters.RequirePasswordChange && slices.Contains(user.Filters.WebClient, sdk.WebClientPasswordChangeDisabled) {
return util.NewI18nError(
util.NewValidationError("you cannot require password change and at the same time disallow it"),
util.I18nErrorPwdChangeConflict,
)
}
if len(user.Filters.TwoFactorAuthProtocols) > 0 && util.Contains(user.Filters.WebClient, sdk.WebClientMFADisabled) {
if len(user.Filters.TwoFactorAuthProtocols) > 0 && slices.Contains(user.Filters.WebClient, sdk.WebClientMFADisabled) {
return util.NewI18nError(
util.NewValidationError("you cannot require two-factor authentication and at the same time disallow it"),
util.I18nError2FAConflict,
@ -3505,7 +3527,7 @@ func checkUserPasscode(user *User, password, protocol string) (string, error) {
if user.Filters.TOTPConfig.Enabled {
switch protocol {
case protocolFTP:
if util.Contains(user.Filters.TOTPConfig.Protocols, protocol) {
if slices.Contains(user.Filters.TOTPConfig.Protocols, protocol) {
// the TOTP passcode has six digits
pwdLen := len(password)
if pwdLen < 7 {
@ -3711,7 +3733,7 @@ func doBuiltinKeyboardInteractiveAuth(user *User, client ssh.KeyboardInteractive
if err := user.LoadAndApplyGroupSettings(); err != nil {
return 0, err
}
hasSecondFactor := user.Filters.TOTPConfig.Enabled && util.Contains(user.Filters.TOTPConfig.Protocols, protocolSSH)
hasSecondFactor := user.Filters.TOTPConfig.Enabled && slices.Contains(user.Filters.TOTPConfig.Protocols, protocolSSH)
if !isPartialAuth || !hasSecondFactor {
answers, err := client("", "", []string{"Password: "}, []bool{false})
if err != nil {
@ -3729,7 +3751,7 @@ func doBuiltinKeyboardInteractiveAuth(user *User, client ssh.KeyboardInteractive
}
func checkKeyboardInteractiveSecondFactor(user *User, client ssh.KeyboardInteractiveChallenge, protocol string) (int, error) {
if !user.Filters.TOTPConfig.Enabled || !util.Contains(user.Filters.TOTPConfig.Protocols, protocolSSH) {
if !user.Filters.TOTPConfig.Enabled || !slices.Contains(user.Filters.TOTPConfig.Protocols, protocolSSH) {
return 1, nil
}
err := user.Filters.TOTPConfig.Secret.TryDecrypt()
@ -3853,7 +3875,7 @@ func getKeyboardInteractiveAnswers(client ssh.KeyboardInteractiveChallenge, resp
}
if len(answers) == 1 && response.CheckPwd > 0 {
if response.CheckPwd == 2 {
if !user.Filters.TOTPConfig.Enabled || !util.Contains(user.Filters.TOTPConfig.Protocols, protocolSSH) {
if !user.Filters.TOTPConfig.Enabled || !slices.Contains(user.Filters.TOTPConfig.Protocols, protocolSSH) {
providerLog(logger.LevelInfo, "keyboard interactive auth error: unable to check TOTP passcode, TOTP is not enabled for user %q",
user.Username)
return answers, errors.New("TOTP not enabled for SSH protocol")
@ -4619,7 +4641,7 @@ func getConfigPath(name, configDir string) string {
}
func checkReservedUsernames(username string) error {
if util.Contains(reservedUsers, username) {
if slices.Contains(reservedUsers, username) {
return util.NewValidationError("this username is reserved")
}
return nil

View file

@ -23,6 +23,7 @@ import (
"net/http"
"path"
"path/filepath"
"slices"
"strings"
"time"
@ -60,7 +61,7 @@ var (
)
func isActionTypeValid(action int) bool {
return util.Contains(supportedEventActions, action)
return slices.Contains(supportedEventActions, action)
}
func getActionTypeAsString(action int) string {
@ -115,7 +116,7 @@ var (
)
func isEventTriggerValid(trigger int) bool {
return util.Contains(supportedEventTriggers, trigger)
return slices.Contains(supportedEventTriggers, trigger)
}
func getTriggerTypeAsString(trigger int) string {
@ -169,7 +170,7 @@ var (
)
func isFilesystemActionValid(value int) bool {
return util.Contains(supportedFsActions, value)
return slices.Contains(supportedFsActions, value)
}
func getFsActionTypeAsString(value int) string {
@ -380,7 +381,7 @@ func (c *EventActionHTTPConfig) validate(additionalData string) error {
return util.NewValidationError(fmt.Sprintf("could not encrypt HTTP password: %v", err))
}
}
if !util.Contains(SupportedHTTPActionMethods, c.Method) {
if !slices.Contains(SupportedHTTPActionMethods, c.Method) {
return util.NewValidationError(fmt.Sprintf("unsupported HTTP method: %s", c.Method))
}
for _, kv := range c.QueryParameters {
@ -1280,7 +1281,7 @@ func (a *EventAction) validateAssociation(trigger int, fsEvents []string) error
}
if trigger == EventTriggerFsEvent {
for _, ev := range fsEvents {
if !util.Contains(allowedSyncFsEvents, ev) {
if !slices.Contains(allowedSyncFsEvents, ev) {
return util.NewI18nError(
util.NewValidationError("sync execution is only supported for upload and pre-* events"),
util.I18nErrorEvSyncUnsupportedFs,
@ -1361,12 +1362,12 @@ func (f *ConditionOptions) validate() error {
}
for _, p := range f.Protocols {
if !util.Contains(SupportedRuleConditionProtocols, p) {
if !slices.Contains(SupportedRuleConditionProtocols, p) {
return util.NewValidationError(fmt.Sprintf("unsupported rule condition protocol: %q", p))
}
}
for _, p := range f.ProviderObjects {
if !util.Contains(SupporteRuleConditionProviderObjects, p) {
if !slices.Contains(SupporteRuleConditionProviderObjects, p) {
return util.NewValidationError(fmt.Sprintf("unsupported provider object: %q", p))
}
}
@ -1468,7 +1469,7 @@ func (c *EventConditions) validate(trigger int) error {
)
}
for _, ev := range c.FsEvents {
if !util.Contains(SupportedFsEvents, ev) {
if !slices.Contains(SupportedFsEvents, ev) {
return util.NewValidationError(fmt.Sprintf("unsupported fs event: %q", ev))
}
}
@ -1488,7 +1489,7 @@ func (c *EventConditions) validate(trigger int) error {
)
}
for _, ev := range c.ProviderEvents {
if !util.Contains(SupportedProviderEvents, ev) {
if !slices.Contains(SupportedProviderEvents, ev) {
return util.NewValidationError(fmt.Sprintf("unsupported provider event: %q", ev))
}
}
@ -1537,7 +1538,7 @@ func (c *EventConditions) validate(trigger int) error {
c.Options.MinFileSize = 0
c.Options.MaxFileSize = 0
c.Schedules = nil
if !util.Contains(supportedIDPLoginEvents, c.IDPLoginEvent) {
if !slices.Contains(supportedIDPLoginEvents, c.IDPLoginEvent) {
return util.NewValidationError(fmt.Sprintf("invalid Identity Provider login event %d", c.IDPLoginEvent))
}
default:
@ -1690,7 +1691,7 @@ func (r *EventRule) validateMandatorySyncActions() error {
return nil
}
for _, ev := range r.Conditions.FsEvents {
if util.Contains(mandatorySyncFsEvents, ev) {
if slices.Contains(mandatorySyncFsEvents, ev) {
return util.NewI18nError(
util.NewValidationError(fmt.Sprintf("event %q requires at least a sync action", ev)),
util.I18nErrorRuleSyncActionRequired,
@ -1708,7 +1709,7 @@ func (r *EventRule) checkIPBlockedAndCertificateActions() error {
ActionTypeDataRetentionCheck, ActionTypeFilesystem, ActionTypePasswordExpirationCheck,
ActionTypeUserExpirationCheck}
for _, action := range r.Actions {
if util.Contains(unavailableActions, action.Type) {
if slices.Contains(unavailableActions, action.Type) {
return fmt.Errorf("action %q, type %q is not supported for event trigger %q",
action.Name, getActionTypeAsString(action.Type), getTriggerTypeAsString(r.Trigger))
}
@ -1724,7 +1725,7 @@ func (r *EventRule) checkProviderEventActions(providerObjectType string) error {
ActionTypeDataRetentionCheck, ActionTypeFilesystem,
ActionTypePasswordExpirationCheck, ActionTypeUserExpirationCheck}
for _, action := range r.Actions {
if util.Contains(userSpecificActions, action.Type) && providerObjectType != actionObjectUser {
if slices.Contains(userSpecificActions, action.Type) && providerObjectType != actionObjectUser {
return fmt.Errorf("action %q, type %q is only supported for provider user events",
action.Name, getActionTypeAsString(action.Type))
}

View file

@ -19,6 +19,7 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"strings"
"sync"
"sync/atomic"
@ -85,7 +86,7 @@ var (
// CheckIPListType returns an error if the provided IP list type is not valid
func CheckIPListType(t IPListType) error {
if !util.Contains(supportedIPListType, t) {
if !slices.Contains(supportedIPListType, t) {
return util.NewValidationError(fmt.Sprintf("invalid list type %d", t))
}
return nil

View file

@ -22,6 +22,7 @@ import (
"net/netip"
"os"
"path/filepath"
"slices"
"sort"
"sync"
"time"
@ -1210,7 +1211,7 @@ func (p *MemoryProvider) addRuleToActionMapping(ruleName, actionName string) err
if err != nil {
return util.NewGenericError(fmt.Sprintf("action %q does not exist", actionName))
}
if !util.Contains(a.Rules, ruleName) {
if !slices.Contains(a.Rules, ruleName) {
a.Rules = append(a.Rules, ruleName)
p.dbHandle.actions[actionName] = a
}
@ -1223,7 +1224,7 @@ func (p *MemoryProvider) removeRuleFromActionMapping(ruleName, actionName string
providerLog(logger.LevelWarn, "action %q does not exist, cannot remove from mapping", actionName)
return
}
if util.Contains(a.Rules, ruleName) {
if slices.Contains(a.Rules, ruleName) {
var rules []string
for _, r := range a.Rules {
if r != ruleName {
@ -1240,7 +1241,7 @@ func (p *MemoryProvider) addAdminToGroupMapping(username, groupname string) erro
if err != nil {
return err
}
if !util.Contains(g.Admins, username) {
if !slices.Contains(g.Admins, username) {
g.Admins = append(g.Admins, username)
p.dbHandle.groups[groupname] = g
}
@ -1283,7 +1284,7 @@ func (p *MemoryProvider) addUserToGroupMapping(username, groupname string) error
if err != nil {
return err
}
if !util.Contains(g.Users, username) {
if !slices.Contains(g.Users, username) {
g.Users = append(g.Users, username)
p.dbHandle.groups[groupname] = g
}
@ -1313,7 +1314,7 @@ func (p *MemoryProvider) addAdminToRole(username, role string) error {
if err != nil {
return fmt.Errorf("%w: role %q does not exist", ErrForeignKeyViolated, role)
}
if !util.Contains(r.Admins, username) {
if !slices.Contains(r.Admins, username) {
r.Admins = append(r.Admins, username)
p.dbHandle.roles[role] = r
}
@ -1347,7 +1348,7 @@ func (p *MemoryProvider) addUserToRole(username, role string) error {
if err != nil {
return fmt.Errorf("%w: role %q does not exist", ErrForeignKeyViolated, role)
}
if !util.Contains(r.Users, username) {
if !slices.Contains(r.Users, username) {
r.Users = append(r.Users, username)
p.dbHandle.roles[role] = r
}
@ -1378,7 +1379,7 @@ func (p *MemoryProvider) addUserToFolderMapping(username, foldername string) err
if err != nil {
return util.NewGenericError(fmt.Sprintf("unable to get folder %q: %v", foldername, err))
}
if !util.Contains(f.Users, username) {
if !slices.Contains(f.Users, username) {
f.Users = append(f.Users, username)
p.dbHandle.vfolders[foldername] = f
}
@ -1390,7 +1391,7 @@ func (p *MemoryProvider) addGroupToFolderMapping(name, foldername string) error
if err != nil {
return util.NewGenericError(fmt.Sprintf("unable to get folder %q: %v", foldername, err))
}
if !util.Contains(f.Groups, name) {
if !slices.Contains(f.Groups, name) {
f.Groups = append(f.Groups, name)
p.dbHandle.vfolders[foldername] = f
}

View file

@ -95,8 +95,8 @@ const (
"`last_login` bigint NOT NULL, `filters` longtext NULL, `filesystem` longtext NULL, `additional_info` longtext NULL, " +
"`created_at` bigint NOT NULL, `updated_at` bigint NOT NULL, `email` varchar(255) NULL, " +
"`upload_data_transfer` integer NOT NULL, `download_data_transfer` integer NOT NULL, " +
"`total_data_transfer` integer NOT NULL, `used_upload_data_transfer` integer NOT NULL, " +
"`used_download_data_transfer` integer NOT NULL, `deleted_at` bigint NOT NULL, `first_download` bigint NOT NULL, " +
"`total_data_transfer` integer NOT NULL, `used_upload_data_transfer` bigint NOT NULL, " +
"`used_download_data_transfer` bigint NOT NULL, `deleted_at` bigint NOT NULL, `first_download` bigint NOT NULL, " +
"`first_upload` bigint NOT NULL, `last_password_change` bigint NOT NULL, `role_id` integer NULL);" +
"CREATE TABLE `{{groups_folders_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`group_id` integer NOT NULL, `folder_id` integer NOT NULL, " +
@ -193,11 +193,7 @@ const (
"CREATE INDEX `{{prefix}}ip_lists_updated_at_idx` ON `{{ip_lists}}` (`updated_at`);" +
"CREATE INDEX `{{prefix}}ip_lists_deleted_at_idx` ON `{{ip_lists}}` (`deleted_at`);" +
"CREATE INDEX `{{prefix}}ip_lists_first_last_idx` ON `{{ip_lists}}` (`first`, `last`);" +
"INSERT INTO {{schema_version}} (version) VALUES (28);"
mysqlV29SQL = "ALTER TABLE `{{users}}` MODIFY `used_download_data_transfer` bigint NOT NULL;" +
"ALTER TABLE `{{users}}` MODIFY `used_upload_data_transfer` bigint NOT NULL;"
mysqlV29DownSQL = "ALTER TABLE `{{users}}` MODIFY `used_upload_data_transfer` integer NOT NULL;" +
"ALTER TABLE `{{users}}` MODIFY `used_download_data_transfer` integer NOT NULL;"
"INSERT INTO {{schema_version}} (version) VALUES (29);"
)
// MySQLProvider defines the auth provider for MySQL/MariaDB database
@ -776,11 +772,11 @@ func (p *MySQLProvider) initializeDatabase() error {
if errors.Is(err, sql.ErrNoRows) {
return errSchemaVersionEmpty
}
logger.InfoToConsole("creating initial database schema, version 28")
providerLog(logger.LevelInfo, "creating initial database schema, version 28")
logger.InfoToConsole("creating initial database schema, version 29")
providerLog(logger.LevelInfo, "creating initial database schema, version 29")
initialSQL := sqlReplaceAll(mysqlInitialSQL)
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(initialSQL, ";"), 28, true)
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(initialSQL, ";"), 29, true)
}
func (p *MySQLProvider) migrateDatabase() error {
@ -793,13 +789,11 @@ func (p *MySQLProvider) migrateDatabase() error {
case version == sqlDatabaseVersion:
providerLog(logger.LevelDebug, "sql database is up to date, current version: %d", version)
return ErrNoInitRequired
case version < 28:
case version < 29:
err = fmt.Errorf("database schema version %d is too old, please see the upgrading docs", version)
providerLog(logger.LevelError, "%v", err)
logger.ErrorToConsole("%v", err)
return err
case version == 28:
return updateMySQLDatabaseFrom28To29(p.dbHandle)
default:
if version > sqlDatabaseVersion {
providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version,
@ -822,8 +816,6 @@ func (p *MySQLProvider) revertDatabase(targetVersion int) error {
}
switch dbVersion.Version {
case 29:
return downgradeMySQLDatabaseFrom29To28(p.dbHandle)
default:
return fmt.Errorf("database schema version not handled: %d", dbVersion.Version)
}
@ -861,19 +853,3 @@ func (p *MySQLProvider) normalizeError(err error, fieldType int) error {
}
return err
}
func updateMySQLDatabaseFrom28To29(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database schema version: 28 -> 29")
providerLog(logger.LevelInfo, "updating database schema version: 28 -> 29")
sql := strings.ReplaceAll(mysqlV29SQL, "{{users}}", sqlTableUsers)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 29, true)
}
func downgradeMySQLDatabaseFrom29To28(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 29 -> 28")
providerLog(logger.LevelInfo, "downgrading database schema version: 29 -> 28")
sql := strings.ReplaceAll(mysqlV29DownSQL, "{{users}}", sqlTableUsers)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 28, false)
}

View file

@ -24,6 +24,7 @@ import (
"errors"
"fmt"
"net"
"slices"
"strconv"
"strings"
"time"
@ -95,7 +96,7 @@ CREATE TABLE "{{users}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS
"download_bandwidth" integer NOT NULL, "last_login" bigint NOT NULL, "filters" text NULL, "filesystem" text NULL,
"additional_info" text NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "email" varchar(255) NULL,
"upload_data_transfer" integer NOT NULL, "download_data_transfer" integer NOT NULL, "total_data_transfer" integer NOT NULL,
"used_upload_data_transfer" integer NOT NULL, "used_download_data_transfer" integer NOT NULL, "deleted_at" bigint NOT NULL,
"used_upload_data_transfer" bigint NOT NULL, "used_download_data_transfer" bigint NOT NULL, "deleted_at" bigint NOT NULL,
"first_download" bigint NOT NULL, "first_upload" bigint NOT NULL, "last_password_change" bigint NOT NULL, "role_id" integer NULL);
CREATE TABLE "{{groups_folders_mapping}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "group_id" integer NOT NULL,
"folder_id" integer NOT NULL, "virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL);
@ -205,16 +206,10 @@ CREATE INDEX "{{prefix}}ip_lists_ipornet_idx" ON "{{ip_lists}}" ("ipornet");
CREATE INDEX "{{prefix}}ip_lists_updated_at_idx" ON "{{ip_lists}}" ("updated_at");
CREATE INDEX "{{prefix}}ip_lists_deleted_at_idx" ON "{{ip_lists}}" ("deleted_at");
CREATE INDEX "{{prefix}}ip_lists_first_last_idx" ON "{{ip_lists}}" ("first", "last");
INSERT INTO {{schema_version}} (version) VALUES (28);
INSERT INTO {{schema_version}} (version) VALUES (29);
`
// not supported in CockroachDB
ipListsLikeIndex = `CREATE INDEX "{{prefix}}ip_lists_ipornet_like_idx" ON "{{ip_lists}}" ("ipornet" varchar_pattern_ops);`
pgsqlV29SQL = `ALTER TABLE "{{users}}" ALTER COLUMN "used_download_data_transfer" TYPE bigint;
ALTER TABLE "{{users}}" ALTER COLUMN "used_upload_data_transfer" TYPE bigint;
`
pgsqlV29DownSQL = `ALTER TABLE "{{users}}" ALTER COLUMN "used_upload_data_transfer" TYPE integer;
ALTER TABLE "{{users}}" ALTER COLUMN "used_download_data_transfer" TYPE integer;
`
)
var (
@ -311,7 +306,7 @@ func getPGSQLConnectionString(redactedPwd bool) string {
if config.DisableSNI {
connectionString += " sslsni=0"
}
if util.Contains(pgSQLTargetSessionAttrs, config.TargetSessionAttrs) {
if slices.Contains(pgSQLTargetSessionAttrs, config.TargetSessionAttrs) {
connectionString += fmt.Sprintf(" target_session_attrs='%s'", config.TargetSessionAttrs)
}
} else {
@ -795,8 +790,8 @@ func (p *PGSQLProvider) initializeDatabase() error {
if errors.Is(err, sql.ErrNoRows) {
return errSchemaVersionEmpty
}
logger.InfoToConsole("creating initial database schema, version 28")
providerLog(logger.LevelInfo, "creating initial database schema, version 28")
logger.InfoToConsole("creating initial database schema, version 29")
providerLog(logger.LevelInfo, "creating initial database schema, version 29")
var initialSQL string
if config.Driver == CockroachDataProviderName {
initialSQL = sqlReplaceAll(pgsqlInitial)
@ -805,7 +800,7 @@ func (p *PGSQLProvider) initializeDatabase() error {
initialSQL = sqlReplaceAll(pgsqlInitial + ipListsLikeIndex)
}
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{initialSQL}, 28, true)
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{initialSQL}, 29, true)
}
func (p *PGSQLProvider) migrateDatabase() error { //nolint:dupl
@ -818,13 +813,11 @@ func (p *PGSQLProvider) migrateDatabase() error { //nolint:dupl
case version == sqlDatabaseVersion:
providerLog(logger.LevelDebug, "sql database is up to date, current version: %d", version)
return ErrNoInitRequired
case version < 28:
case version < 29:
err = fmt.Errorf("database schema version %d is too old, please see the upgrading docs", version)
providerLog(logger.LevelError, "%v", err)
logger.ErrorToConsole("%v", err)
return err
case version == 28:
return updatePGSQLDatabaseFrom28To29(p.dbHandle)
default:
if version > sqlDatabaseVersion {
providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version,
@ -847,8 +840,6 @@ func (p *PGSQLProvider) revertDatabase(targetVersion int) error {
}
switch dbVersion.Version {
case 29:
return downgradePGSQLDatabaseFrom29To28(p.dbHandle)
default:
return fmt.Errorf("database schema version not handled: %d", dbVersion.Version)
}
@ -886,19 +877,3 @@ func (p *PGSQLProvider) normalizeError(err error, fieldType int) error {
}
return err
}
func updatePGSQLDatabaseFrom28To29(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database schema version: 28 -> 29")
providerLog(logger.LevelInfo, "updating database schema version: 28 -> 29")
sql := strings.ReplaceAll(pgsqlV29SQL, "{{users}}", sqlTableUsers)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 29, true)
}
func downgradePGSQLDatabaseFrom29To28(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 29 -> 28")
providerLog(logger.LevelInfo, "downgrading database schema version: 29 -> 28")
sql := strings.ReplaceAll(pgsqlV29DownSQL, "{{users}}", sqlTableUsers)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 28, false)
}

View file

@ -12,8 +12,8 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//go:build !nosqlite
// +build !nosqlite
//go:build !nosqlite && cgo
// +build !nosqlite,cgo
package dataprovider
@ -178,7 +178,7 @@ CREATE INDEX "{{prefix}}ip_lists_ip_type_idx" ON "{{ip_lists}}" ("ip_type");
CREATE INDEX "{{prefix}}ip_lists_ip_updated_at_idx" ON "{{ip_lists}}" ("updated_at");
CREATE INDEX "{{prefix}}ip_lists_ip_deleted_at_idx" ON "{{ip_lists}}" ("deleted_at");
CREATE INDEX "{{prefix}}ip_lists_first_last_idx" ON "{{ip_lists}}" ("first", "last");
INSERT INTO {{schema_version}} (version) VALUES (28);
INSERT INTO {{schema_version}} (version) VALUES (29);
`
)
@ -693,10 +693,10 @@ func (p *SQLiteProvider) initializeDatabase() error {
if errors.Is(err, sql.ErrNoRows) {
return errSchemaVersionEmpty
}
logger.InfoToConsole("creating initial database schema, version 28")
providerLog(logger.LevelInfo, "creating initial database schema, version 28")
logger.InfoToConsole("creating initial database schema, version 29")
providerLog(logger.LevelInfo, "creating initial database schema, version 29")
sql := sqlReplaceAll(sqliteInitialSQL)
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 28, true)
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 29, true)
}
func (p *SQLiteProvider) migrateDatabase() error { //nolint:dupl
@ -709,13 +709,11 @@ func (p *SQLiteProvider) migrateDatabase() error { //nolint:dupl
case version == sqlDatabaseVersion:
providerLog(logger.LevelDebug, "sql database is up to date, current version: %d", version)
return ErrNoInitRequired
case version < 28:
case version < 29:
err = fmt.Errorf("database schema version %d is too old, please see the upgrading docs", version)
providerLog(logger.LevelError, "%v", err)
logger.ErrorToConsole("%v", err)
return err
case version == 28:
return updateSQLiteDatabaseFrom28To29(p.dbHandle)
default:
if version > sqlDatabaseVersion {
providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version,
@ -738,8 +736,6 @@ func (p *SQLiteProvider) revertDatabase(targetVersion int) error {
}
switch dbVersion.Version {
case 29:
return downgradeSQLiteDatabaseFrom29To28(p.dbHandle)
default:
return fmt.Errorf("database schema version not handled: %d", dbVersion.Version)
}
@ -777,26 +773,6 @@ func (p *SQLiteProvider) normalizeError(err error, fieldType int) error {
return err
}
func updateSQLiteDatabaseFrom28To29(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database schema version: 28 -> 29")
providerLog(logger.LevelInfo, "updating database schema version: 28 -> 29")
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
return sqlCommonUpdateDatabaseVersion(ctx, dbHandle, 29)
}
func downgradeSQLiteDatabaseFrom29To28(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 29 -> 28")
providerLog(logger.LevelInfo, "downgrading database schema version: 29 -> 28")
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
return sqlCommonUpdateDatabaseVersion(ctx, dbHandle, 28)
}
/*func setPragmaFK(dbHandle *sql.DB, value string) error {
ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
defer cancel()

View file

@ -12,8 +12,8 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//go:build nosqlite
// +build nosqlite
//go:build nosqlite || !cgo
// +build nosqlite !cgo
package dataprovider

View file

@ -12,8 +12,8 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//go:build unixcrypt
// +build unixcrypt
//go:build unixcrypt && cgo
// +build unixcrypt,cgo
package dataprovider

View file

@ -12,8 +12,8 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//go:build !unixcrypt
// +build !unixcrypt
//go:build !unixcrypt || !cgo
// +build !unixcrypt !cgo
package dataprovider

View file

@ -22,6 +22,7 @@ import (
"net"
"os"
"path"
"slices"
"strconv"
"strings"
"time"
@ -342,7 +343,11 @@ func (u *User) isTimeBasedAccessAllowed(when time.Time) bool {
if when.IsZero() {
when = time.Now()
}
if UseLocalTime() {
when = when.Local()
} else {
when = when.UTC()
}
weekDay := when.Weekday()
hhMM := when.Format("15:04")
for _, p := range u.Filters.AccessTime {
@ -840,20 +845,20 @@ func (u *User) HasPermissionsInside(virtualPath string) bool {
// HasPerm returns true if the user has the given permission or any permission
func (u *User) HasPerm(permission, path string) bool {
perms := u.GetPermissionsForPath(path)
if util.Contains(perms, PermAny) {
if slices.Contains(perms, PermAny) {
return true
}
return util.Contains(perms, permission)
return slices.Contains(perms, permission)
}
// HasAnyPerm returns true if the user has at least one of the given permissions
func (u *User) HasAnyPerm(permissions []string, path string) bool {
perms := u.GetPermissionsForPath(path)
if util.Contains(perms, PermAny) {
if slices.Contains(perms, PermAny) {
return true
}
for _, permission := range permissions {
if util.Contains(perms, permission) {
if slices.Contains(perms, permission) {
return true
}
}
@ -863,11 +868,11 @@ func (u *User) HasAnyPerm(permissions []string, path string) bool {
// HasPerms returns true if the user has all the given permissions
func (u *User) HasPerms(permissions []string, path string) bool {
perms := u.GetPermissionsForPath(path)
if util.Contains(perms, PermAny) {
if slices.Contains(perms, PermAny) {
return true
}
for _, permission := range permissions {
if !util.Contains(perms, permission) {
if !slices.Contains(perms, permission) {
return false
}
}
@ -927,11 +932,11 @@ func (u *User) IsLoginMethodAllowed(loginMethod, protocol string) bool {
if len(u.Filters.DeniedLoginMethods) == 0 {
return true
}
if util.Contains(u.Filters.DeniedLoginMethods, loginMethod) {
if slices.Contains(u.Filters.DeniedLoginMethods, loginMethod) {
return false
}
if protocol == protocolSSH && loginMethod == LoginMethodPassword {
if util.Contains(u.Filters.DeniedLoginMethods, SSHLoginMethodPassword) {
if slices.Contains(u.Filters.DeniedLoginMethods, SSHLoginMethodPassword) {
return false
}
}
@ -965,10 +970,10 @@ func (u *User) IsPartialAuth() bool {
method == SSHLoginMethodPassword {
continue
}
if method == LoginMethodPassword && util.Contains(u.Filters.DeniedLoginMethods, SSHLoginMethodPassword) {
if method == LoginMethodPassword && slices.Contains(u.Filters.DeniedLoginMethods, SSHLoginMethodPassword) {
continue
}
if !util.Contains(SSHMultiStepsLoginMethods, method) {
if !slices.Contains(SSHMultiStepsLoginMethods, method) {
return false
}
}
@ -982,7 +987,7 @@ func (u *User) GetAllowedLoginMethods() []string {
if method == SSHLoginMethodPassword {
continue
}
if !util.Contains(u.Filters.DeniedLoginMethods, method) {
if !slices.Contains(u.Filters.DeniedLoginMethods, method) {
allowedMethods = append(allowedMethods, method)
}
}
@ -1052,7 +1057,7 @@ func (u *User) IsFileAllowed(virtualPath string) (bool, int) {
// CanManageMFA returns true if the user can add a multi-factor authentication configuration
func (u *User) CanManageMFA() bool {
if util.Contains(u.Filters.WebClient, sdk.WebClientMFADisabled) {
if slices.Contains(u.Filters.WebClient, sdk.WebClientMFADisabled) {
return false
}
return len(mfa.GetAvailableTOTPConfigs()) > 0
@ -1073,39 +1078,39 @@ func (u *User) skipExternalAuth() bool {
// CanManageShares returns true if the user can add, update and list shares
func (u *User) CanManageShares() bool {
return !util.Contains(u.Filters.WebClient, sdk.WebClientSharesDisabled)
return !slices.Contains(u.Filters.WebClient, sdk.WebClientSharesDisabled)
}
// CanResetPassword returns true if this user is allowed to reset its password
func (u *User) CanResetPassword() bool {
return !util.Contains(u.Filters.WebClient, sdk.WebClientPasswordResetDisabled)
return !slices.Contains(u.Filters.WebClient, sdk.WebClientPasswordResetDisabled)
}
// CanChangePassword returns true if this user is allowed to change its password
func (u *User) CanChangePassword() bool {
return !util.Contains(u.Filters.WebClient, sdk.WebClientPasswordChangeDisabled)
return !slices.Contains(u.Filters.WebClient, sdk.WebClientPasswordChangeDisabled)
}
// CanChangeAPIKeyAuth returns true if this user is allowed to enable/disable API key authentication
func (u *User) CanChangeAPIKeyAuth() bool {
return !util.Contains(u.Filters.WebClient, sdk.WebClientAPIKeyAuthChangeDisabled)
return !slices.Contains(u.Filters.WebClient, sdk.WebClientAPIKeyAuthChangeDisabled)
}
// CanChangeInfo returns true if this user is allowed to change its info such as email and description
func (u *User) CanChangeInfo() bool {
return !util.Contains(u.Filters.WebClient, sdk.WebClientInfoChangeDisabled)
return !slices.Contains(u.Filters.WebClient, sdk.WebClientInfoChangeDisabled)
}
// CanManagePublicKeys returns true if this user is allowed to manage public keys
// from the WebClient. Used in WebClient UI
func (u *User) CanManagePublicKeys() bool {
return !util.Contains(u.Filters.WebClient, sdk.WebClientPubKeyChangeDisabled)
return !slices.Contains(u.Filters.WebClient, sdk.WebClientPubKeyChangeDisabled)
}
// CanManageTLSCerts returns true if this user is allowed to manage TLS certificates
// from the WebClient. Used in WebClient UI
func (u *User) CanManageTLSCerts() bool {
return !util.Contains(u.Filters.WebClient, sdk.WebClientTLSCertChangeDisabled)
return !slices.Contains(u.Filters.WebClient, sdk.WebClientTLSCertChangeDisabled)
}
// CanUpdateProfile returns true if the user is allowed to update the profile.
@ -1117,7 +1122,7 @@ func (u *User) CanUpdateProfile() bool {
// CanAddFilesFromWeb returns true if the client can add files from the web UI.
// The specified target is the directory where the files must be uploaded
func (u *User) CanAddFilesFromWeb(target string) bool {
if util.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) {
if slices.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) {
return false
}
return u.HasPerm(PermUpload, target) || u.HasPerm(PermOverwrite, target)
@ -1126,7 +1131,7 @@ func (u *User) CanAddFilesFromWeb(target string) bool {
// CanAddDirsFromWeb returns true if the client can add directories from the web UI.
// The specified target is the directory where the new directory must be created
func (u *User) CanAddDirsFromWeb(target string) bool {
if util.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) {
if slices.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) {
return false
}
return u.HasPerm(PermCreateDirs, target)
@ -1135,7 +1140,7 @@ func (u *User) CanAddDirsFromWeb(target string) bool {
// CanRenameFromWeb returns true if the client can rename objects from the web UI.
// The specified src and dest are the source and target directories for the rename.
func (u *User) CanRenameFromWeb(src, dest string) bool {
if util.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) {
if slices.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) {
return false
}
return u.HasAnyPerm(permsRenameAny, src) && u.HasAnyPerm(permsRenameAny, dest)
@ -1144,7 +1149,7 @@ func (u *User) CanRenameFromWeb(src, dest string) bool {
// CanDeleteFromWeb returns true if the client can delete objects from the web UI.
// The specified target is the parent directory for the object to delete
func (u *User) CanDeleteFromWeb(target string) bool {
if util.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) {
if slices.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) {
return false
}
return u.HasAnyPerm(permsDeleteAny, target)
@ -1153,7 +1158,7 @@ func (u *User) CanDeleteFromWeb(target string) bool {
// CanCopyFromWeb returns true if the client can copy objects from the web UI.
// The specified src and dest are the source and target directories for the copy.
func (u *User) CanCopyFromWeb(src, dest string) bool {
if util.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) {
if slices.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) {
return false
}
if !u.HasPerm(PermListItems, src) {
@ -1213,7 +1218,7 @@ func (u *User) MustSetSecondFactor() bool {
return true
}
for _, p := range u.Filters.TwoFactorAuthProtocols {
if !util.Contains(u.Filters.TOTPConfig.Protocols, p) {
if !slices.Contains(u.Filters.TOTPConfig.Protocols, p) {
return true
}
}
@ -1224,11 +1229,11 @@ func (u *User) MustSetSecondFactor() bool {
// MustSetSecondFactorForProtocol returns true if the user must set a second factor authentication
// for the specified protocol
func (u *User) MustSetSecondFactorForProtocol(protocol string) bool {
if util.Contains(u.Filters.TwoFactorAuthProtocols, protocol) {
if slices.Contains(u.Filters.TwoFactorAuthProtocols, protocol) {
if !u.Filters.TOTPConfig.Enabled {
return true
}
if !util.Contains(u.Filters.TOTPConfig.Protocols, protocol) {
if !slices.Contains(u.Filters.TOTPConfig.Protocols, protocol) {
return true
}
}

View file

@ -2718,6 +2718,7 @@ func TestStat(t *testing.T) {
func TestUploadOverwriteVfolder(t *testing.T) {
u := getTestUser()
u.QuotaFiles = 1000
vdir := "/vdir"
mappedPath := filepath.Join(os.TempDir(), "vdir")
folderName := filepath.Base(mappedPath)
@ -2749,14 +2750,24 @@ func TestUploadOverwriteVfolder(t *testing.T) {
assert.NoError(t, err)
folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, folder.UsedQuotaSize)
assert.Equal(t, 1, folder.UsedQuotaFiles)
assert.Equal(t, int64(0), folder.UsedQuotaSize)
assert.Equal(t, 0, folder.UsedQuotaFiles)
user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, user.UsedQuotaSize)
assert.Equal(t, 1, user.UsedQuotaFiles)
err = ftpUploadFile(testFilePath, path.Join(vdir, testFileName), testFileSize, client, 0)
assert.NoError(t, err)
folder, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, folder.UsedQuotaSize)
assert.Equal(t, 1, folder.UsedQuotaFiles)
assert.Equal(t, int64(0), folder.UsedQuotaSize)
assert.Equal(t, 0, folder.UsedQuotaFiles)
user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, user.UsedQuotaSize)
assert.Equal(t, 1, user.UsedQuotaFiles)
err = client.Quit()
assert.NoError(t, err)
err = os.Remove(testFilePath)

View file

@ -493,10 +493,7 @@ func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolve
if vfs.HasTruncateSupport(fs) {
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
if err == nil {
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
if vfolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
}
dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -fileSize, false)
} else {
dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
}

View file

@ -22,6 +22,7 @@ import (
"net"
"os"
"path/filepath"
"slices"
ftpserver "github.com/fclairamb/ftpserverlib"
"github.com/sftpgo/sdk/plugin/notifier"
@ -361,7 +362,7 @@ func (s *Server) validateUser(user dataprovider.User, cc ftpserver.ClientContext
user.Username, user.HomeDir)
return nil, fmt.Errorf("cannot login user with invalid home dir: %q", user.HomeDir)
}
if util.Contains(user.Filters.DeniedProtocols, common.ProtocolFTP) {
if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolFTP) {
logger.Info(logSender, connectionID, "cannot login user %q, protocol FTP is not allowed", user.Username)
return nil, fmt.Errorf("protocol FTP is not allowed for user %q", user.Username)
}

View file

@ -297,7 +297,7 @@ func changeAdminPassword(w http.ResponseWriter, r *http.Request) {
sendAPIResponse(w, r, err, "", getRespStatus(err))
return
}
invalidateToken(r)
invalidateToken(r, false)
sendAPIResponse(w, r, err, "Password updated", http.StatusOK)
}

View file

@ -85,7 +85,7 @@ type oauth2TokenRequest struct {
BaseRedirectURL string `json:"base_redirect_url"`
}
func handleSMTPOAuth2TokenRequestPost(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) handleSMTPOAuth2TokenRequestPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
var req oauth2TokenRequest
@ -115,7 +115,7 @@ func handleSMTPOAuth2TokenRequestPost(w http.ResponseWriter, r *http.Request) {
clientSecret.SetAdditionalData(xid.New().String())
pendingAuth := newOAuth2PendingAuth(req.Provider, cfg.RedirectURL, cfg.ClientID, clientSecret)
oauth2Mgr.addPendingAuth(pendingAuth)
stateToken := createOAuth2Token(pendingAuth.State, util.GetIPFromRemoteAddress(r.RemoteAddr))
stateToken := createOAuth2Token(s.csrfTokenAuth, pendingAuth.State, util.GetIPFromRemoteAddress(r.RemoteAddr))
if stateToken == "" {
sendAPIResponse(w, r, nil, "unable to create state token", http.StatusInternalServerError)
return

View file

@ -531,7 +531,7 @@ func changeUserPassword(w http.ResponseWriter, r *http.Request) {
sendAPIResponse(w, r, err, "", getRespStatus(err))
return
}
invalidateToken(r)
invalidateToken(r, false)
sendAPIResponse(w, r, err, "Password updated", http.StatusOK)
}

View file

@ -551,9 +551,10 @@ func RestoreUsers(users []dataprovider.User, inputFile string, mode, scanQuota i
return fmt.Errorf("unable to restore user %q: %w", user.Username, err)
}
if scanQuota == 1 || (scanQuota == 2 && user.HasQuotaRestrictions()) {
if common.QuotaScans.AddUserQuotaScan(user.Username, user.Role) {
user, err = dataprovider.GetUserWithGroupSettings(user.Username, "")
if err == nil && common.QuotaScans.AddUserQuotaScan(user.Username, user.Role) {
logger.Debug(logSender, "", "starting quota scan for restored user: %q", user.Username)
go doUserQuotaScan(user) //nolint:errcheck
go doUserQuotaScan(&user) //nolint:errcheck
}
}
}

View file

@ -20,6 +20,7 @@ import (
"fmt"
"io"
"net/http"
"slices"
"strconv"
"strings"
@ -138,8 +139,7 @@ func saveTOTPConfig(w http.ResponseWriter, r *http.Request) {
if claims.MustSetTwoFactorAuth {
// force logout
defer func() {
c := jwtTokenClaims{}
c.removeCookie(w, r, baseURL)
removeCookie(w, r, baseURL)
}()
}
@ -276,7 +276,7 @@ func saveUserTOTPConfig(username string, r *http.Request, recoveryCodes []datapr
return util.NewValidationError("two-factor authentication must be enabled")
}
for _, p := range userMerged.Filters.TwoFactorAuthProtocols {
if !util.Contains(user.Filters.TOTPConfig.Protocols, p) {
if !slices.Contains(user.Filters.TOTPConfig.Protocols, p) {
return util.NewValidationError(fmt.Sprintf("totp: the following protocols are required: %q",
strings.Join(userMerged.Filters.TwoFactorAuthProtocols, ", ")))
}

View file

@ -219,7 +219,7 @@ func doStartUserQuotaScan(w http.ResponseWriter, r *http.Request, username strin
http.StatusConflict)
return
}
go doUserQuotaScan(user) //nolint:errcheck
go doUserQuotaScan(&user) //nolint:errcheck
sendAPIResponse(w, r, err, "Scan started", http.StatusAccepted)
}
@ -242,14 +242,14 @@ func doStartFolderQuotaScan(w http.ResponseWriter, r *http.Request, name string)
sendAPIResponse(w, r, err, "Scan started", http.StatusAccepted)
}
func doUserQuotaScan(user dataprovider.User) error {
func doUserQuotaScan(user *dataprovider.User) error {
defer common.QuotaScans.RemoveUserQuotaScan(user.Username)
numFiles, size, err := user.ScanQuota()
if err != nil {
logger.Warn(logSender, "", "error scanning user quota %q: %v", user.Username, err)
return err
}
err = dataprovider.UpdateUserQuota(&user, numFiles, size, true)
err = dataprovider.UpdateUserQuota(user, numFiles, size, true)
logger.Debug(logSender, "", "user quota scanned, user: %q, error: %v", user.Username, err)
return err
}

View file

@ -22,6 +22,7 @@ import (
"net/url"
"os"
"path"
"slices"
"strings"
"time"
@ -107,7 +108,7 @@ func addShare(w http.ResponseWriter, r *http.Request) {
share.Name = share.ShareID
}
if share.Password == "" {
if util.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) {
if slices.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) {
sendAPIResponse(w, r, nil, "You are not authorized to share files/folders without a password",
http.StatusForbidden)
return
@ -155,7 +156,7 @@ func updateShare(w http.ResponseWriter, r *http.Request) {
updatedShare.Password = share.Password
}
if updatedShare.Password == "" {
if util.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) {
if slices.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) {
sendAPIResponse(w, r, nil, "You are not authorized to share files/folders without a password",
http.StatusForbidden)
return
@ -425,36 +426,42 @@ func (s *httpdServer) uploadFilesToShare(w http.ResponseWriter, r *http.Request)
}
}
func (s *httpdServer) getShareClaims(r *http.Request, shareID string) (*jwtTokenClaims, error) {
token, err := jwtauth.VerifyRequest(s.tokenAuth, r, jwtauth.TokenFromCookie)
if err != nil || token == nil {
return nil, errInvalidToken
}
tokenString := jwtauth.TokenFromCookie(r)
if tokenString == "" || invalidatedJWTTokens.Get(tokenString) {
return nil, errInvalidToken
}
if !slices.Contains(token.Audience(), tokenAudienceWebShare) {
logger.Debug(logSender, "", "invalid token audience for share %q", shareID)
return nil, errInvalidToken
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := validateIPForToken(token, ipAddr); err != nil {
logger.Debug(logSender, "", "token for share %q is not valid for the ip address %q", shareID, ipAddr)
return nil, err
}
ctx := jwtauth.NewContext(r.Context(), token, nil)
claims, err := getTokenClaims(r.WithContext(ctx))
if err != nil || claims.Username != shareID {
logger.Debug(logSender, "", "token not valid for share %q", shareID)
return nil, errInvalidToken
}
return &claims, nil
}
func (s *httpdServer) checkWebClientShareCredentials(w http.ResponseWriter, r *http.Request, share *dataprovider.Share) error {
doRedirect := func() {
redirectURL := path.Join(webClientPubSharesPath, share.ShareID, fmt.Sprintf("login?next=%s", url.QueryEscape(r.RequestURI)))
http.Redirect(w, r, redirectURL, http.StatusFound)
}
token, err := jwtauth.VerifyRequest(s.tokenAuth, r, jwtauth.TokenFromCookie)
if err != nil || token == nil {
if _, err := s.getShareClaims(r, share.ShareID); err != nil {
doRedirect()
return errInvalidToken
}
if !util.Contains(token.Audience(), tokenAudienceWebShare) {
logger.Debug(logSender, "", "invalid token audience for share %q", share.ShareID)
doRedirect()
return errInvalidToken
}
if tokenValidationMode != tokenValidationNoIPMatch {
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if !util.Contains(token.Audience(), ipAddr) {
logger.Debug(logSender, "", "token for share %q is not valid for the ip address %q", share.ShareID, ipAddr)
doRedirect()
return errInvalidToken
}
}
ctx := jwtauth.NewContext(r.Context(), token, nil)
claims, err := getTokenClaims(r.WithContext(ctx))
if err != nil || claims.Username != share.ShareID {
logger.Debug(logSender, "", "token not valid for share %q", share.ShareID)
doRedirect()
return errInvalidToken
return err
}
return nil
}
@ -480,7 +487,7 @@ func (s *httpdServer) checkPublicShare(w http.ResponseWriter, r *http.Request, v
renderError(err, "", statusCode)
return share, nil, err
}
if !util.Contains(validScopes, share.Scope) {
if !slices.Contains(validScopes, share.Scope) {
err := errors.New("invalid share scope")
renderError(util.NewI18nError(err, util.I18nErrorShareScope), "", http.StatusForbidden)
return share, nil, err
@ -537,7 +544,7 @@ func getUserForShare(share dataprovider.Share) (dataprovider.User, error) {
if !user.CanManageShares() {
return user, util.NewI18nError(util.NewRecordNotFoundError("this share does not exist"), util.I18nError404Message)
}
if share.Password == "" && util.Contains(user.Filters.WebClient, sdk.WebClientShareNoPasswordDisabled) {
if share.Password == "" && slices.Contains(user.Filters.WebClient, sdk.WebClientShareNoPasswordDisabled) {
return user, util.NewI18nError(
fmt.Errorf("sharing without a password was disabled: %w", os.ErrPermission),
util.I18nError403Message,

View file

@ -27,6 +27,7 @@ import (
"net/url"
"os"
"path"
"slices"
"strconv"
"strings"
"sync"
@ -363,6 +364,16 @@ func streamJSONArray(w http.ResponseWriter, chunkSize int, dataGetter func(limit
streamData(w, []byte("]"))
}
func renderPNGImage(w http.ResponseWriter, r *http.Request, b []byte) {
if len(b) == 0 {
ctx := context.WithValue(r.Context(), render.StatusCtxKey, http.StatusNotFound)
render.PlainText(w, r.WithContext(ctx), http.StatusText(http.StatusNotFound))
return
}
w.Header().Set("Content-Type", "image/png")
streamData(w, b)
}
func getCompressedFileName(username string, files []string) string {
if len(files) == 1 {
name := path.Base(files[0])
@ -717,7 +728,7 @@ func updateLoginMetrics(user *dataprovider.User, loginMethod, ip string, err err
}
func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID string, checkSessions bool) error {
if util.Contains(user.Filters.DeniedProtocols, common.ProtocolHTTP) {
if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolHTTP) {
logger.Info(logSender, connectionID, "cannot login user %q, protocol HTTP is not allowed", user.Username)
return util.NewI18nError(
fmt.Errorf("protocol HTTP is not allowed for user %q", user.Username),
@ -902,7 +913,7 @@ func isUserAllowedToResetPassword(r *http.Request, user *dataprovider.User) bool
if !user.CanResetPassword() {
return false
}
if util.Contains(user.Filters.DeniedProtocols, common.ProtocolHTTP) {
if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolHTTP) {
return false
}
if !user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolHTTP) {

View file

@ -18,6 +18,7 @@ import (
"errors"
"fmt"
"net/http"
"slices"
"time"
"github.com/go-chi/jwtauth/v5"
@ -41,6 +42,7 @@ const (
tokenAudienceAPIUser tokenAudience = "APIUser"
tokenAudienceCSRF tokenAudience = "CSRF"
tokenAudienceOAuth2 tokenAudience = "OAuth2"
tokenAudienceWebLogin tokenAudience = "WebLogin"
)
type tokenValidation = int
@ -60,6 +62,7 @@ const (
claimMustSetSecondFactorKey = "2fa_required"
claimRequiredTwoFactorProtocols = "2fa_protos"
claimHideUserPageSection = "hus"
claimRef = "ref"
basicRealm = "Basic realm=\"SFTPGo\""
jwtCookieKey = "jwt"
)
@ -69,7 +72,7 @@ var (
shareTokenDuration = 2 * time.Hour
// csrf token duration is greater than normal token duration to reduce issues
// with the login form
csrfTokenDuration = 6 * time.Hour
csrfTokenDuration = 4 * time.Hour
tokenRefreshThreshold = 10 * time.Minute
tokenValidationMode = tokenValidationFull
)
@ -86,6 +89,8 @@ type jwtTokenClaims struct {
MustChangePassword bool
RequiredTwoFactorProtocols []string
HideUserPageSections int
JwtID string
Ref string
}
func (c *jwtTokenClaims) hasUserAudience() bool {
@ -103,6 +108,12 @@ func (c *jwtTokenClaims) asMap() map[string]any {
claims[claimUsernameKey] = c.Username
claims[claimPermissionsKey] = c.Permissions
if c.JwtID != "" {
claims[jwt.JwtIDKey] = c.JwtID
}
if c.Ref != "" {
claims[claimRef] = c.Ref
}
if c.Role != "" {
claims[claimRole] = c.Role
}
@ -169,6 +180,7 @@ func (c *jwtTokenClaims) Decode(token map[string]any) {
c.Permissions = nil
c.Username = c.decodeString(token[claimUsernameKey])
c.Signature = c.decodeString(token[jwt.SubjectKey])
c.JwtID = c.decodeString(token[jwt.JwtIDKey])
audience := token[jwt.AudienceKey]
switch v := audience.(type) {
@ -176,6 +188,10 @@ func (c *jwtTokenClaims) Decode(token map[string]any) {
c.Audience = v
}
if val, ok := token[claimRef]; ok {
c.Ref = c.decodeString(val)
}
if val, ok := token[claimAPIKey]; ok {
c.APIKeyID = c.decodeString(val)
}
@ -212,33 +228,39 @@ func (c *jwtTokenClaims) Decode(token map[string]any) {
}
func (c *jwtTokenClaims) isCriticalPermRemoved(permissions []string) bool {
if util.Contains(permissions, dataprovider.PermAdminAny) {
if slices.Contains(permissions, dataprovider.PermAdminAny) {
return false
}
if (util.Contains(c.Permissions, dataprovider.PermAdminManageAdmins) ||
util.Contains(c.Permissions, dataprovider.PermAdminAny)) &&
!util.Contains(permissions, dataprovider.PermAdminManageAdmins) &&
!util.Contains(permissions, dataprovider.PermAdminAny) {
if (slices.Contains(c.Permissions, dataprovider.PermAdminManageAdmins) ||
slices.Contains(c.Permissions, dataprovider.PermAdminAny)) &&
!slices.Contains(permissions, dataprovider.PermAdminManageAdmins) &&
!slices.Contains(permissions, dataprovider.PermAdminAny) {
return true
}
return false
}
func (c *jwtTokenClaims) hasPerm(perm string) bool {
if util.Contains(c.Permissions, dataprovider.PermAdminAny) {
if slices.Contains(c.Permissions, dataprovider.PermAdminAny) {
return true
}
return util.Contains(c.Permissions, perm)
return slices.Contains(c.Permissions, perm)
}
func (c *jwtTokenClaims) createToken(tokenAuth *jwtauth.JWTAuth, audience tokenAudience, ip string) (jwt.Token, string, error) {
claims := c.asMap()
now := time.Now().UTC()
if _, ok := claims[jwt.JwtIDKey]; !ok {
claims[jwt.JwtIDKey] = xid.New().String()
}
claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second)
if audience == tokenAudienceWebLogin {
claims[jwt.ExpirationKey] = now.Add(csrfTokenDuration)
} else {
claims[jwt.ExpirationKey] = now.Add(tokenDuration)
}
claims[jwt.AudienceKey] = []string{audience, ip}
return tokenAuth.Encode(claims)
@ -274,21 +296,25 @@ func (c *jwtTokenClaims) createAndSetCookie(w http.ResponseWriter, r *http.Reque
if audience == tokenAudienceWebShare {
duration = shareTokenDuration
}
setCookie(w, r, basePath, resp["access_token"].(string), duration)
return nil
}
func setCookie(w http.ResponseWriter, r *http.Request, cookiePath, cookieValue string, duration time.Duration) {
http.SetCookie(w, &http.Cookie{
Name: jwtCookieKey,
Value: resp["access_token"].(string),
Path: basePath,
Value: cookieValue,
Path: cookiePath,
Expires: time.Now().Add(duration),
MaxAge: int(duration / time.Second),
HttpOnly: true,
Secure: isTLS(r),
SameSite: http.SameSiteStrictMode,
})
return nil
}
func (c *jwtTokenClaims) removeCookie(w http.ResponseWriter, r *http.Request, cookiePath string) {
func removeCookie(w http.ResponseWriter, r *http.Request, cookiePath string) {
http.SetCookie(w, &http.Cookie{
Name: jwtCookieKey,
Value: "",
@ -300,10 +326,10 @@ func (c *jwtTokenClaims) removeCookie(w http.ResponseWriter, r *http.Request, co
SameSite: http.SameSiteStrictMode,
})
w.Header().Add("Cache-Control", `no-cache="Set-Cookie"`)
invalidateToken(r)
invalidateToken(r, false)
}
func tokenFromContext(r *http.Request) string {
func oidcTokenFromContext(r *http.Request) string {
if token, ok := r.Context().Value(oidcGeneratedToken).(string); ok {
return token
}
@ -324,7 +350,7 @@ func isTokenInvalidated(r *http.Request) bool {
var findTokenFns []func(r *http.Request) string
findTokenFns = append(findTokenFns, jwtauth.TokenFromHeader)
findTokenFns = append(findTokenFns, jwtauth.TokenFromCookie)
findTokenFns = append(findTokenFns, tokenFromContext)
findTokenFns = append(findTokenFns, oidcTokenFromContext)
isTokenFound := false
for _, fn := range findTokenFns {
@ -340,14 +366,18 @@ func isTokenInvalidated(r *http.Request) bool {
return !isTokenFound
}
func invalidateToken(r *http.Request) {
func invalidateToken(r *http.Request, isLoginToken bool) {
duration := tokenDuration
if isLoginToken {
duration = csrfTokenDuration
}
tokenString := jwtauth.TokenFromHeader(r)
if tokenString != "" {
invalidatedJWTTokens.Add(tokenString, time.Now().Add(tokenDuration).UTC())
invalidatedJWTTokens.Add(tokenString, time.Now().Add(duration).UTC())
}
tokenString = jwtauth.TokenFromCookie(r)
if tokenString != "" {
invalidatedJWTTokens.Add(tokenString, time.Now().Add(tokenDuration).UTC())
invalidatedJWTTokens.Add(tokenString, time.Now().Add(duration).UTC())
}
}
@ -380,7 +410,22 @@ func getAdminFromToken(r *http.Request) *dataprovider.Admin {
return admin
}
func createCSRFToken(ip string) string {
func createLoginCookie(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwtauth.JWTAuth, tokenID, basePath, ip string,
) {
c := jwtTokenClaims{
JwtID: tokenID,
}
resp, err := c.createTokenResponse(csrfTokenAuth, tokenAudienceWebLogin, ip)
if err != nil {
return
}
setCookie(w, r, basePath, resp["access_token"].(string), csrfTokenDuration)
}
func createCSRFToken(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwtauth.JWTAuth, tokenID,
basePath string,
) string {
ip := util.GetIPFromRemoteAddress(r.RemoteAddr)
claims := make(map[string]any)
now := time.Now().UTC()
@ -388,7 +433,16 @@ func createCSRFToken(ip string) string {
claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second)
claims[jwt.ExpirationKey] = now.Add(csrfTokenDuration)
claims[jwt.AudienceKey] = []string{tokenAudienceCSRF, ip}
if tokenID != "" {
createLoginCookie(w, r, csrfTokenAuth, tokenID, basePath, ip)
claims[claimRef] = tokenID
} else {
if c, err := getTokenClaims(r); err == nil {
claims[claimRef] = c.JwtID
} else {
logger.Error(logSender, "", "unable to add reference to CSRF token: %v", err)
}
}
_, tokenString, err := csrfTokenAuth.Encode(claims)
if err != nil {
logger.Debug(logSender, "", "unable to create CSRF token: %v", err)
@ -397,29 +451,73 @@ func createCSRFToken(ip string) string {
return tokenString
}
func verifyCSRFToken(tokenString, ip string) error {
func verifyCSRFToken(r *http.Request, csrfTokenAuth *jwtauth.JWTAuth) error {
tokenString := r.Form.Get(csrfFormToken)
token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString)
if err != nil || token == nil {
logger.Debug(logSender, "", "error validating CSRF token %q: %v", tokenString, err)
return fmt.Errorf("unable to verify form token: %v", err)
}
if !util.Contains(token.Audience(), tokenAudienceCSRF) {
if !slices.Contains(token.Audience(), tokenAudienceCSRF) {
logger.Debug(logSender, "", "error validating CSRF token audience")
return errors.New("the form token is not valid")
}
if tokenValidationMode != tokenValidationNoIPMatch {
if !util.Contains(token.Audience(), ip) {
if err := validateIPForToken(token, util.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil {
logger.Debug(logSender, "", "error validating CSRF token IP audience")
return errors.New("the form token is not valid")
}
claims, err := getTokenClaims(r)
if err != nil {
logger.Debug(logSender, "", "error getting token claims for CSRF validation: %v", err)
return err
}
ref, ok := token.Get(claimRef)
if !ok {
logger.Debug(logSender, "", "error validating CSRF token, missing reference")
return errors.New("the form token is not valid")
}
if claims.JwtID == "" || claims.JwtID != ref.(string) {
logger.Debug(logSender, "", "error validating CSRF reference, id %q, reference %q", claims.JwtID, ref)
return errors.New("unexpected form token")
}
return nil
}
func createOAuth2Token(state, ip string) string {
func verifyLoginCookie(r *http.Request) error {
token, _, err := jwtauth.FromContext(r.Context())
if err != nil || token == nil {
logger.Debug(logSender, "", "error getting login token: %v", err)
return errInvalidToken
}
if isTokenInvalidated(r) {
logger.Debug(logSender, "", "the login token has been invalidated")
return errInvalidToken
}
if !slices.Contains(token.Audience(), tokenAudienceWebLogin) {
logger.Debug(logSender, "", "the token with id %q is not valid for audience %q", token.JwtID(), tokenAudienceWebLogin)
return errInvalidToken
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := validateIPForToken(token, ipAddr); err != nil {
return err
}
return nil
}
func verifyLoginCookieAndCSRFToken(r *http.Request, csrfTokenAuth *jwtauth.JWTAuth) error {
if err := verifyLoginCookie(r); err != nil {
return err
}
if err := verifyCSRFToken(r, csrfTokenAuth); err != nil {
return err
}
return nil
}
func createOAuth2Token(csrfTokenAuth *jwtauth.JWTAuth, state, ip string) string {
claims := make(map[string]any)
now := time.Now().UTC()
@ -436,7 +534,7 @@ func createOAuth2Token(state, ip string) string {
return tokenString
}
func verifyOAuth2Token(tokenString, ip string) (string, error) {
func verifyOAuth2Token(csrfTokenAuth *jwtauth.JWTAuth, tokenString, ip string) (string, error) {
token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString)
if err != nil || token == nil {
logger.Debug(logSender, "", "error validating OAuth2 token %q: %v", tokenString, err)
@ -446,17 +544,15 @@ func verifyOAuth2Token(tokenString, ip string) (string, error) {
)
}
if !util.Contains(token.Audience(), tokenAudienceOAuth2) {
if !slices.Contains(token.Audience(), tokenAudienceOAuth2) {
logger.Debug(logSender, "", "error validating OAuth2 token audience")
return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState)
}
if tokenValidationMode != tokenValidationNoIPMatch {
if !util.Contains(token.Audience(), ip) {
if err := validateIPForToken(token, ip); err != nil {
logger.Debug(logSender, "", "error validating OAuth2 token IP audience")
return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState)
}
}
if val, ok := token.Get(jwt.JwtIDKey); ok {
if state, ok := val.(string); ok {
return state, nil
@ -465,3 +561,12 @@ func verifyOAuth2Token(tokenString, ip string) (string, error) {
logger.Debug(logSender, "", "jti not found in OAuth2 token")
return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState)
}
func validateIPForToken(token jwt.Token, ip string) error {
if tokenValidationMode != tokenValidationNoIPMatch {
if !slices.Contains(token.Audience(), ip) {
return errInvalidToken
}
}
return nil
}

View file

@ -213,10 +213,7 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request
if vfs.HasTruncateSupport(fs) {
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
if err == nil {
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
if vfolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
}
dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -fileSize, false)
} else {
dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
}

View file

@ -28,11 +28,10 @@ import (
"path/filepath"
"runtime"
"strings"
"sync"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/drakkan/sftpgo/v2/internal/acme"
"github.com/drakkan/sftpgo/v2/internal/common"
@ -196,7 +195,6 @@ var (
cleanupTicker *time.Ticker
cleanupDone chan bool
invalidatedJWTTokens tokenManager
csrfTokenAuth *jwtauth.JWTAuth
webRootPath string
webBasePath string
webBaseAdminPath string
@ -288,12 +286,88 @@ var (
installationCodeHint string
fnInstallationCodeResolver FnInstallationCodeResolver
configurationDir string
dbBrandingConfig brandingCache
)
func init() {
updateWebAdminURLs("")
updateWebClientURLs("")
acme.SetReloadHTTPDCertsFn(ReloadCertificateMgr)
common.SetUpdateBrandingFn(dbBrandingConfig.Set)
}
type brandingCache struct {
mu sync.RWMutex
configs *dataprovider.BrandingConfigs
}
func (b *brandingCache) Set(configs *dataprovider.BrandingConfigs) {
b.mu.Lock()
defer b.mu.Unlock()
b.configs = configs
}
func (b *brandingCache) getWebAdminLogo() []byte {
b.mu.RLock()
defer b.mu.RUnlock()
return b.configs.WebAdmin.Logo
}
func (b *brandingCache) getWebAdminFavicon() []byte {
b.mu.RLock()
defer b.mu.RUnlock()
return b.configs.WebAdmin.Favicon
}
func (b *brandingCache) getWebClientLogo() []byte {
b.mu.RLock()
defer b.mu.RUnlock()
return b.configs.WebClient.Logo
}
func (b *brandingCache) getWebClientFavicon() []byte {
b.mu.RLock()
defer b.mu.RUnlock()
return b.configs.WebClient.Favicon
}
func (b *brandingCache) mergeBrandingConfig(branding UIBranding, isWebClient bool) UIBranding {
b.mu.RLock()
defer b.mu.RUnlock()
var urlPrefix string
var cfg dataprovider.BrandingConfig
if isWebClient {
cfg = b.configs.WebClient
urlPrefix = "webclient"
} else {
cfg = b.configs.WebAdmin
urlPrefix = "webadmin"
}
if cfg.Name != "" {
branding.Name = cfg.Name
}
if cfg.ShortName != "" {
branding.ShortName = cfg.ShortName
}
if cfg.DisclaimerName != "" {
branding.DisclaimerName = cfg.DisclaimerName
}
if cfg.DisclaimerURL != "" {
branding.DisclaimerPath = cfg.DisclaimerURL
}
if len(cfg.Logo) > 0 {
branding.LogoPath = path.Join("/", "branding", urlPrefix, "logo.png")
}
if len(cfg.Favicon) > 0 {
branding.FaviconPath = path.Join("/", "branding", urlPrefix, "favicon.png")
}
return branding
}
// FnInstallationCodeResolver defines a method to get the installation code.
@ -410,18 +484,22 @@ type UIBranding struct {
DefaultCSS []string `json:"default_css" mapstructure:"default_css"`
// Additional CSS file paths, relative to "static_files_path", to include
ExtraCSS []string `json:"extra_css" mapstructure:"extra_css"`
DefaultLogoPath string `json:"-" mapstructure:"-"`
DefaultFaviconPath string `json:"-" mapstructure:"-"`
}
func (b *UIBranding) check() {
b.DefaultLogoPath = "/img/logo.png"
b.DefaultFaviconPath = "/favicon.png"
if b.LogoPath != "" {
b.LogoPath = util.CleanPath(b.LogoPath)
} else {
b.LogoPath = "/img/logo.png"
b.LogoPath = b.DefaultLogoPath
}
if b.FaviconPath != "" {
b.FaviconPath = util.CleanPath(b.FaviconPath)
} else {
b.FaviconPath = "/favicon.ico"
b.FaviconPath = b.DefaultFaviconPath
}
if b.DisclaimerPath != "" {
if !strings.HasPrefix(b.DisclaimerPath, "https://") && !strings.HasPrefix(b.DisclaimerPath, "http://") {
@ -551,6 +629,14 @@ func (b *Binding) checkBranding() {
}
}
func (b *Binding) webAdminBranding() UIBranding {
return dbBrandingConfig.mergeBrandingConfig(b.Branding.WebAdmin, false)
}
func (b *Binding) webClientBranding() UIBranding {
return dbBrandingConfig.mergeBrandingConfig(b.Branding.WebClient, true)
}
func (b *Binding) parseAllowedProxy() error {
if filepath.IsAbs(b.Address) && len(b.ProxyAllowed) > 0 {
// unix domain socket
@ -882,6 +968,7 @@ func (c *Conf) loadFromProvider() error {
return fmt.Errorf("unable to load config from provider: %w", err)
}
configs.SetNilsToEmpty()
dbBrandingConfig.Set(configs.Branding)
if configs.ACME.Domain == "" || !configs.ACME.HasProtocol(common.ProtocolHTTP) {
return nil
}
@ -967,7 +1054,6 @@ func (c *Conf) Initialize(configDir string, isShared int) error {
c.SigningPassphrase = passphrase
}
csrfTokenAuth = jwtauth.New(jwa.HS256.String(), getSigningKey(c.SigningPassphrase), nil)
hideSupportLink = c.HideSupportLink
exitChannel := make(chan error, 1)

File diff suppressed because it is too large Load diff

View file

@ -341,7 +341,7 @@ func TestBrandingValidation(t *testing.T) {
},
}
b.checkBranding()
assert.Equal(t, "/favicon.ico", b.Branding.WebAdmin.FaviconPath)
assert.Equal(t, "/favicon.png", b.Branding.WebAdmin.FaviconPath)
assert.Equal(t, "/path1", b.Branding.WebAdmin.LogoPath)
assert.Equal(t, []string{"/my.css"}, b.Branding.WebAdmin.DefaultCSS)
assert.Len(t, b.Branding.WebAdmin.ExtraCSS, 0)
@ -412,6 +412,45 @@ func TestGCSWebInvalidFormFile(t *testing.T) {
assert.EqualError(t, err, http.ErrNotMultipart.Error())
}
func TestBrandingInvalidFormFile(t *testing.T) {
form := make(url.Values)
req, _ := http.NewRequest(http.MethodPost, webConfigsPath, strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
err := req.ParseForm()
assert.NoError(t, err)
_, err = getBrandingConfigFromPostFields(req, &dataprovider.BrandingConfigs{})
assert.EqualError(t, err, http.ErrNotMultipart.Error())
}
func TestVerifyCSRFToken(t *testing.T) {
server := httpdServer{}
server.initializeRouter()
req, err := http.NewRequest(http.MethodPost, webAdminEventActionPath, nil)
require.NoError(t, err)
req = req.WithContext(context.WithValue(req.Context(), jwtauth.ErrorCtxKey, fs.ErrPermission))
rr := httptest.NewRecorder()
tokenString := createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseAdminPath)
assert.NotEmpty(t, tokenString)
token, err := server.csrfTokenAuth.Decode(tokenString)
require.NoError(t, err)
_, ok := token.Get(claimRef)
assert.False(t, ok)
req.Form = url.Values{}
req.Form.Set(csrfFormToken, tokenString)
err = verifyCSRFToken(req, server.csrfTokenAuth)
assert.ErrorIs(t, err, fs.ErrPermission)
req, err = http.NewRequest(http.MethodPost, webAdminEventActionPath, nil)
require.NoError(t, err)
req.Form = url.Values{}
req.Form.Set(csrfFormToken, tokenString)
err = verifyCSRFToken(req, server.csrfTokenAuth)
assert.ErrorContains(t, err, "the form token is not valid")
}
func TestInvalidToken(t *testing.T) {
server := httpdServer{}
server.initializeRouter()
@ -923,13 +962,24 @@ func TestUpdateWebAdminInvalidClaims(t *testing.T) {
token, err := c.createTokenResponse(server.tokenAuth, tokenAudienceWebAdmin, "")
assert.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, webAdminPath, nil)
assert.NoError(t, err)
req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"]))
parsedToken, err := jwtauth.VerifyRequest(server.tokenAuth, req, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx := req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
form := make(url.Values)
form.Set(csrfFormToken, createCSRFToken(""))
form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseAdminPath))
form.Set("status", "1")
form.Set("default_users_expiration", "30")
req, _ := http.NewRequest(http.MethodPost, path.Join(webAdminPath, "admin"), bytes.NewBuffer([]byte(form.Encode())))
req, err = http.NewRequest(http.MethodPost, path.Join(webAdminPath, "admin"), bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
rctx := chi.NewRouteContext()
rctx.URLParams.Add("username", "admin")
req = req.WithContext(ctx)
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"]))
@ -1028,7 +1078,7 @@ func TestOAuth2Redirect(t *testing.T) {
assert.Contains(t, rr.Body.String(), util.I18nOAuth2ErrorTitle)
ip := "127.1.1.4"
tokenString := createOAuth2Token(xid.New().String(), ip)
tokenString := createOAuth2Token(server.csrfTokenAuth, xid.New().String(), ip)
rr = httptest.NewRecorder()
req, err = http.NewRequest(http.MethodGet, webOAuth2RedirectPath+"?state="+tokenString, nil) //nolint:goconst
assert.NoError(t, err)
@ -1039,8 +1089,10 @@ func TestOAuth2Redirect(t *testing.T) {
}
func TestOAuth2Token(t *testing.T) {
server := httpdServer{}
server.initializeRouter()
// invalid token
_, err := verifyOAuth2Token("token", "")
_, err := verifyOAuth2Token(server.csrfTokenAuth, "token", "")
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "unable to verify OAuth2 state")
}
@ -1053,22 +1105,22 @@ func TestOAuth2Token(t *testing.T) {
claims[jwt.ExpirationKey] = now.Add(tokenDuration)
claims[jwt.AudienceKey] = []string{tokenAudienceAPI}
_, tokenString, err := csrfTokenAuth.Encode(claims)
_, tokenString, err := server.csrfTokenAuth.Encode(claims)
assert.NoError(t, err)
_, err = verifyOAuth2Token(tokenString, "")
_, err = verifyOAuth2Token(server.csrfTokenAuth, tokenString, "")
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "invalid OAuth2 state")
}
// bad IP
tokenString = createOAuth2Token("state", "127.1.1.1")
_, err = verifyOAuth2Token(tokenString, "127.1.1.2")
tokenString = createOAuth2Token(server.csrfTokenAuth, "state", "127.1.1.1")
_, err = verifyOAuth2Token(server.csrfTokenAuth, tokenString, "127.1.1.2")
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "invalid OAuth2 state")
}
// ok
state := xid.New().String()
tokenString = createOAuth2Token(state, "127.1.1.3")
s, err := verifyOAuth2Token(tokenString, "127.1.1.3")
tokenString = createOAuth2Token(server.csrfTokenAuth, state, "127.1.1.3")
s, err := verifyOAuth2Token(server.csrfTokenAuth, tokenString, "127.1.1.3")
assert.NoError(t, err)
assert.Equal(t, state, s)
// no jti
@ -1077,19 +1129,17 @@ func TestOAuth2Token(t *testing.T) {
claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second)
claims[jwt.ExpirationKey] = now.Add(tokenDuration)
claims[jwt.AudienceKey] = []string{tokenAudienceOAuth2, "127.1.1.4"}
_, tokenString, err = csrfTokenAuth.Encode(claims)
_, tokenString, err = server.csrfTokenAuth.Encode(claims)
assert.NoError(t, err)
_, err = verifyOAuth2Token(tokenString, "127.1.1.4")
_, err = verifyOAuth2Token(server.csrfTokenAuth, tokenString, "127.1.1.4")
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "invalid OAuth2 state")
}
// encode error
csrfTokenAuth = jwtauth.New("HT256", util.GenerateRandomBytes(32), nil)
tokenString = createOAuth2Token(xid.New().String(), "")
server.csrfTokenAuth = jwtauth.New("HT256", util.GenerateRandomBytes(32), nil)
tokenString = createOAuth2Token(server.csrfTokenAuth, xid.New().String(), "")
assert.Empty(t, tokenString)
server := httpdServer{}
server.initializeRouter()
rr := httptest.NewRecorder()
testReq := make(map[string]any)
testReq["base_redirect_url"] = "http://localhost:8082"
@ -1097,16 +1147,17 @@ func TestOAuth2Token(t *testing.T) {
assert.NoError(t, err)
req, err := http.NewRequest(http.MethodPost, webOAuth2TokenPath, bytes.NewBuffer(asJSON))
assert.NoError(t, err)
handleSMTPOAuth2TokenRequestPost(rr, req)
server.handleSMTPOAuth2TokenRequestPost(rr, req)
assert.Equal(t, http.StatusInternalServerError, rr.Code)
assert.Contains(t, rr.Body.String(), "unable to create state token")
csrfTokenAuth = jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil)
}
func TestCSRFToken(t *testing.T) {
server := httpdServer{}
server.initializeRouter()
// invalid token
err := verifyCSRFToken("token", "")
req := &http.Request{}
err := verifyCSRFToken(req, server.csrfTokenAuth)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "unable to verify form token")
}
@ -1119,16 +1170,23 @@ func TestCSRFToken(t *testing.T) {
claims[jwt.ExpirationKey] = now.Add(tokenDuration)
claims[jwt.AudienceKey] = []string{tokenAudienceAPI}
_, tokenString, err := csrfTokenAuth.Encode(claims)
_, tokenString, err := server.csrfTokenAuth.Encode(claims)
assert.NoError(t, err)
err = verifyCSRFToken(tokenString, "")
values := url.Values{}
values.Set(csrfFormToken, tokenString)
req.Form = values
err = verifyCSRFToken(req, server.csrfTokenAuth)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "form token is not valid")
}
// bad IP
tokenString = createCSRFToken("127.1.1.1")
err = verifyCSRFToken(tokenString, "127.1.1.2")
req.RemoteAddr = "127.1.1.1"
tokenString = createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath)
values.Set(csrfFormToken, tokenString)
req.Form = values
req.RemoteAddr = "127.1.1.2"
err = verifyCSRFToken(req, server.csrfTokenAuth)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "form token is not valid")
}
@ -1137,8 +1195,9 @@ func TestCSRFToken(t *testing.T) {
claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second)
claims[jwt.ExpirationKey] = now.Add(tokenDuration)
claims[jwt.AudienceKey] = []string{tokenAudienceAPI}
_, tokenString, err = csrfTokenAuth.Encode(claims)
_, tokenString, err = server.csrfTokenAuth.Encode(claims)
assert.NoError(t, err)
assert.NotEmpty(t, tokenString)
r := GetHTTPRouter(Binding{
Address: "",
@ -1148,9 +1207,9 @@ func TestCSRFToken(t *testing.T) {
EnableRESTAPI: true,
RenderOpenAPI: true,
})
fn := verifyCSRFHeader(r)
fn := server.verifyCSRFHeader(r)
rr := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodDelete, path.Join(userPath, "username"), nil)
req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, "username"), nil)
fn.ServeHTTP(rr, req)
assert.Equal(t, http.StatusForbidden, rr.Code)
assert.Contains(t, rr.Body.String(), "Invalid token")
@ -1163,18 +1222,20 @@ func TestCSRFToken(t *testing.T) {
assert.Contains(t, rr.Body.String(), "the token is not valid")
// invalid IP
tokenString = createCSRFToken("172.16.1.2")
tokenString = createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath)
req.Header.Set(csrfHeaderToken, tokenString)
req.RemoteAddr = "172.16.1.2"
rr = httptest.NewRecorder()
fn.ServeHTTP(rr, req)
assert.Equal(t, http.StatusForbidden, rr.Code)
assert.Contains(t, rr.Body.String(), "the token is not valid")
csrfTokenAuth = jwtauth.New("PS256", util.GenerateRandomBytes(32), nil)
tokenString = createCSRFToken("")
csrfTokenAuth := jwtauth.New("PS256", util.GenerateRandomBytes(32), nil)
tokenString = createCSRFToken(httptest.NewRecorder(), req, csrfTokenAuth, "", webBaseAdminPath)
assert.Empty(t, tokenString)
csrfTokenAuth = jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil)
rr = httptest.NewRecorder()
createLoginCookie(rr, req, csrfTokenAuth, "", webBaseAdminPath, req.RemoteAddr)
assert.Empty(t, rr.Header().Get("Set-Cookie"))
}
func TestCreateShareCookieError(t *testing.T) {
@ -1206,18 +1267,37 @@ func TestCreateShareCookieError(t *testing.T) {
server := httpdServer{
tokenAuth: jwtauth.New("TS256", util.GenerateRandomBytes(32), nil),
csrfTokenAuth: jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil),
}
c := jwtTokenClaims{
JwtID: xid.New().String(),
}
resp, err := c.createTokenResponse(server.csrfTokenAuth, tokenAudienceWebLogin, "127.0.0.1")
assert.NoError(t, err)
parsedToken, err := jwtauth.VerifyToken(server.csrfTokenAuth, resp["access_token"].(string))
assert.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, share.ShareID, "login"), nil)
assert.NoError(t, err)
req.RemoteAddr = "127.0.0.1:4567"
ctx := req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
form := make(url.Values)
form.Set("share_password", pwd)
form.Set(csrfFormToken, createCSRFToken("127.0.0.1"))
form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseClientPath))
rctx := chi.NewRouteContext()
rctx.URLParams.Add("id", share.ShareID)
rr := httptest.NewRecorder()
req, err := http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, share.ShareID, "login"),
req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, share.ShareID, "login"),
bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
req.RemoteAddr = "127.0.0.1:2345"
req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", resp["access_token"]))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req = req.WithContext(ctx)
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
server.handleClientShareLoginPost(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
@ -1230,6 +1310,7 @@ func TestCreateShareCookieError(t *testing.T) {
func TestCreateTokenError(t *testing.T) {
server := httpdServer{
tokenAuth: jwtauth.New("PS256", util.GenerateRandomBytes(32), nil),
csrfTokenAuth: jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil),
}
rr := httptest.NewRecorder()
admin := dataprovider.Admin{
@ -1253,14 +1334,36 @@ func TestCreateTokenError(t *testing.T) {
server.generateAndSendUserToken(rr, req, "", user)
assert.Equal(t, http.StatusInternalServerError, rr.Code)
c := jwtTokenClaims{
JwtID: xid.New().String(),
}
token, err := c.createTokenResponse(server.csrfTokenAuth, tokenAudienceWebLogin, "")
assert.NoError(t, err)
req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil)
assert.NoError(t, err)
req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"]))
parsedToken, err := jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx := req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
rr = httptest.NewRecorder()
form := make(url.Values)
form.Set("username", admin.Username)
form.Set("password", admin.Password)
form.Set(csrfFormToken, createCSRFToken("127.0.0.1"))
form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, xid.New().String(), webBaseAdminPath))
cookie := rr.Header().Get("Set-Cookie")
assert.NotEmpty(t, cookie)
req, _ = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode())))
req.RemoteAddr = "127.0.0.1:1234"
req.Header.Set("Cookie", cookie)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx = req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
server.handleWebAdminLoginPost(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
// req with no content type
@ -1287,7 +1390,7 @@ func TestCreateTokenError(t *testing.T) {
req, _ = http.NewRequest(http.MethodGet, webAdminLoginPath+"?a=a%C3%A2%G3", nil)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
_, err := getAdminFromPostFields(req)
_, err = getAdminFromPostFields(req)
assert.Error(t, err)
req, _ = http.NewRequest(http.MethodPost, webAdminEventActionPath+"?a=a%C3%A2%GG", nil)
@ -1421,13 +1524,21 @@ func TestCreateTokenError(t *testing.T) {
err = dataprovider.AddUser(&user, "", "", "")
assert.NoError(t, err)
req, err = http.NewRequest(http.MethodGet, webClientLoginPath, nil)
assert.NoError(t, err)
req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"]))
parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx = req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
rr = httptest.NewRecorder()
form = make(url.Values)
form.Set("username", user.Username)
form.Set("password", "clientpwd")
form.Set(csrfFormToken, createCSRFToken("127.0.0.1"))
form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseClientPath))
req, _ = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode())))
req.RemoteAddr = "127.0.0.1:4567"
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
server.handleWebClientLoginPost(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
@ -1616,6 +1727,7 @@ func TestCookieExpiration(t *testing.T) {
claims = make(map[string]any)
claims[claimUsernameKey] = admin.Username
claims[claimPermissionsKey] = admin.Permissions
claims[jwt.JwtIDKey] = xid.New().String()
claims[jwt.SubjectKey] = admin.GetSignature()
claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute)
claims[jwt.AudienceKey] = []string{tokenAudienceAPI}
@ -1648,9 +1760,11 @@ func TestCookieExpiration(t *testing.T) {
admin, err = dataprovider.AdminExists(admin.Username)
assert.NoError(t, err)
tokenID := xid.New().String()
claims = make(map[string]any)
claims[claimUsernameKey] = admin.Username
claims[claimPermissionsKey] = admin.Permissions
claims[jwt.JwtIDKey] = tokenID
claims[jwt.SubjectKey] = admin.GetSignature()
claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute)
claims[jwt.AudienceKey] = []string{tokenAudienceAPI}
@ -1669,6 +1783,11 @@ func TestCookieExpiration(t *testing.T) {
server.checkCookieExpiration(rr, req.WithContext(ctx))
cookie = rr.Header().Get("Set-Cookie")
assert.True(t, strings.HasPrefix(cookie, "jwt="))
req.Header.Set("Cookie", cookie)
token, err = jwtauth.VerifyRequest(server.tokenAuth, req, jwtauth.TokenFromCookie)
if assert.NoError(t, err) {
assert.Equal(t, tokenID, token.JwtID())
}
err = dataprovider.DeleteAdmin(admin.Username, "", "", "")
assert.NoError(t, err)
@ -1689,6 +1808,7 @@ func TestCookieExpiration(t *testing.T) {
claims = make(map[string]any)
claims[claimUsernameKey] = user.Username
claims[claimPermissionsKey] = user.Filters.WebClient
claims[jwt.JwtIDKey] = tokenID
claims[jwt.SubjectKey] = user.GetSignature()
claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute)
claims[jwt.AudienceKey] = []string{tokenAudienceWebClient}
@ -1721,6 +1841,7 @@ func TestCookieExpiration(t *testing.T) {
claims = make(map[string]any)
claims[claimUsernameKey] = user.Username
claims[claimPermissionsKey] = user.Filters.WebClient
claims[jwt.JwtIDKey] = tokenID
claims[jwt.SubjectKey] = user.GetSignature()
claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute)
claims[jwt.AudienceKey] = []string{tokenAudienceWebClient}
@ -1740,6 +1861,11 @@ func TestCookieExpiration(t *testing.T) {
server.checkCookieExpiration(rr, req.WithContext(ctx))
cookie = rr.Header().Get("Set-Cookie")
assert.NotEmpty(t, cookie)
req.Header.Set("Cookie", cookie)
token, err = jwtauth.VerifyRequest(server.tokenAuth, req, jwtauth.TokenFromCookie)
if assert.NoError(t, err) {
assert.Equal(t, tokenID, token.JwtID())
}
// test a disabled user
user.Status = 0
@ -1751,6 +1877,7 @@ func TestCookieExpiration(t *testing.T) {
claims = make(map[string]any)
claims[claimUsernameKey] = user.Username
claims[claimPermissionsKey] = user.Filters.WebClient
claims[jwt.JwtIDKey] = tokenID
claims[jwt.SubjectKey] = user.GetSignature()
claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute)
claims[jwt.AudienceKey] = []string{tokenAudienceWebClient}
@ -1834,7 +1961,7 @@ func TestRenderInvalidTemplate(t *testing.T) {
}
func TestQuotaScanInvalidFs(t *testing.T) {
user := dataprovider.User{
user := &dataprovider.User{
BaseUser: sdk.BaseUser{
Username: "test",
HomeDir: os.TempDir(),
@ -2127,34 +2254,95 @@ func TestProxyHeaders(t *testing.T) {
testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil)
assert.NoError(t, err)
req.RemoteAddr = testIP
rr = httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
cookie := rr.Header().Get("Set-Cookie")
assert.NotEmpty(t, cookie)
req.Header.Set("Cookie", cookie)
parsedToken, err := jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx := req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
form := make(url.Values)
form.Set("username", username)
form.Set("password", password)
form.Set(csrfFormToken, createCSRFToken(testIP))
form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath))
req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
req.RemoteAddr = testIP
req.Header.Set("Cookie", cookie)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rr = httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials)
form.Set(csrfFormToken, createCSRFToken(validForwardedFor))
req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil)
assert.NoError(t, err)
req.RemoteAddr = validForwardedFor
rr = httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
loginCookie := rr.Header().Get("Set-Cookie")
assert.NotEmpty(t, loginCookie)
req.Header.Set("Cookie", loginCookie)
parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx = req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath))
req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
req.RemoteAddr = testIP
req.Header.Set("Cookie", loginCookie)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("X-Forwarded-For", validForwardedFor)
rr = httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String())
cookie := rr.Header().Get("Set-Cookie")
cookie = rr.Header().Get("Set-Cookie")
assert.NotContains(t, cookie, "Secure")
// The login cookie is invalidated after a successful login, the same request will fail
req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
req.RemoteAddr = testIP
req.Header.Set("Cookie", loginCookie)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("X-Forwarded-For", validForwardedFor)
rr = httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)
req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil)
assert.NoError(t, err)
req.RemoteAddr = validForwardedFor
rr = httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
loginCookie = rr.Header().Get("Set-Cookie")
assert.NotEmpty(t, loginCookie)
req.Header.Set("Cookie", loginCookie)
parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx = req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath))
req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
req.RemoteAddr = testIP
req.Header.Set("Cookie", loginCookie)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("X-Forwarded-For", validForwardedFor)
req.Header.Set(xForwardedProto, "https")
@ -2164,9 +2352,26 @@ func TestProxyHeaders(t *testing.T) {
cookie = rr.Header().Get("Set-Cookie")
assert.Contains(t, cookie, "Secure")
req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil)
assert.NoError(t, err)
req.RemoteAddr = validForwardedFor
rr = httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
loginCookie = rr.Header().Get("Set-Cookie")
assert.NotEmpty(t, loginCookie)
req.Header.Set("Cookie", loginCookie)
parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx = req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath))
req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
req.RemoteAddr = testIP
req.Header.Set("Cookie", loginCookie)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("X-Forwarded-For", validForwardedFor)
req.Header.Set(xForwardedProto, "http")
@ -2737,10 +2942,22 @@ func TestInvalidClaims(t *testing.T) {
}
token, err := c.createTokenResponse(server.tokenAuth, tokenAudienceWebClient, "")
assert.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, webClientProfilePath, nil)
assert.NoError(t, err)
req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"]))
parsedToken, err := jwtauth.VerifyRequest(server.tokenAuth, req, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx := req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
form := make(url.Values)
form.Set(csrfFormToken, createCSRFToken(""))
form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseClientPath))
form.Set("public_keys", "")
req, _ := http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode())))
req, err = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"]))
server.handleWebClientProfilePost(rr, req)
@ -2757,14 +2974,27 @@ func TestInvalidClaims(t *testing.T) {
}
token, err = c.createTokenResponse(server.tokenAuth, tokenAudienceWebAdmin, "")
assert.NoError(t, err)
req, err = http.NewRequest(http.MethodGet, webAdminProfilePath, nil)
assert.NoError(t, err)
req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"]))
parsedToken, err = jwtauth.VerifyRequest(server.tokenAuth, req, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx = req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
form = make(url.Values)
form.Set(csrfFormToken, createCSRFToken(""))
form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseAdminPath))
form.Set("allow_api_key_auth", "")
req, _ = http.NewRequest(http.MethodPost, webAdminProfilePath, bytes.NewBuffer([]byte(form.Encode())))
req, err = http.NewRequest(http.MethodPost, webAdminProfilePath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"]))
server.handleWebAdminProfilePost(rr, req)
assert.Equal(t, http.StatusForbidden, rr.Code)
assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken)
}
func TestTLSReq(t *testing.T) {
@ -3063,24 +3293,31 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) {
}
server.initializeRouter()
rr := httptest.NewRecorder()
r, err := http.NewRequest(http.MethodGet, webAdminSetupPath, nil)
assert.NoError(t, err)
server.router.ServeHTTP(rr, r)
assert.Equal(t, http.StatusOK, rr.Code)
for _, webURL := range []string{"/", webBasePath, webBaseAdminPath, webAdminLoginPath, webClientLoginPath} {
rr = httptest.NewRecorder()
r, err = http.NewRequest(http.MethodGet, webURL, nil)
rr := httptest.NewRecorder()
r, err := http.NewRequest(http.MethodGet, webURL, nil)
assert.NoError(t, err)
server.router.ServeHTTP(rr, r)
assert.Equal(t, http.StatusFound, rr.Code)
assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location"))
}
rr := httptest.NewRecorder()
r, err := http.NewRequest(http.MethodGet, webAdminSetupPath, nil)
assert.NoError(t, err)
server.router.ServeHTTP(rr, r)
assert.Equal(t, http.StatusOK, rr.Code)
cookie := rr.Header().Get("Set-Cookie")
r.Header.Set("Cookie", cookie)
parsedToken, err := jwtauth.VerifyRequest(server.csrfTokenAuth, r, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx := r.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
r = r.WithContext(ctx)
form := make(url.Values)
csrfToken := createCSRFToken("")
form.Set("_form_token", csrfToken)
csrfToken := createCSRFToken(rr, r, server.csrfTokenAuth, "", webBaseAdminPath)
form.Set(csrfFormToken, csrfToken)
form.Set("install_code", installationCode+"5")
form.Set("username", defaultAdminUsername)
form.Set("password", "password")
@ -3088,6 +3325,8 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) {
rr = httptest.NewRecorder()
r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
r = r.WithContext(ctx)
r.Header.Set("Cookie", cookie)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
server.router.ServeHTTP(rr, r)
assert.Equal(t, http.StatusOK, rr.Code)
@ -3099,6 +3338,8 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) {
rr = httptest.NewRecorder()
r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
r = r.WithContext(ctx)
r.Header.Set("Cookie", cookie)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
server.router.ServeHTTP(rr, r)
assert.Equal(t, http.StatusFound, rr.Code)
@ -3120,12 +3361,6 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) {
return "5678"
})
rr = httptest.NewRecorder()
r, err = http.NewRequest(http.MethodGet, webAdminSetupPath, nil)
assert.NoError(t, err)
server.router.ServeHTTP(rr, r)
assert.Equal(t, http.StatusOK, rr.Code)
for _, webURL := range []string{"/", webBasePath, webBaseAdminPath, webAdminLoginPath, webClientLoginPath} {
rr = httptest.NewRecorder()
r, err = http.NewRequest(http.MethodGet, webURL, nil)
@ -3135,9 +3370,22 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) {
assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location"))
}
rr = httptest.NewRecorder()
r, err = http.NewRequest(http.MethodGet, webAdminSetupPath, nil)
assert.NoError(t, err)
server.router.ServeHTTP(rr, r)
assert.Equal(t, http.StatusOK, rr.Code)
cookie = rr.Header().Get("Set-Cookie")
r.Header.Set("Cookie", cookie)
parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, r, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx = r.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
r = r.WithContext(ctx)
form = make(url.Values)
csrfToken = createCSRFToken("")
form.Set("_form_token", csrfToken)
csrfToken = createCSRFToken(rr, r, server.csrfTokenAuth, "", webBaseAdminPath)
form.Set(csrfFormToken, csrfToken)
form.Set("install_code", installationCode)
form.Set("username", defaultAdminUsername)
form.Set("password", "password")
@ -3145,6 +3393,8 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) {
rr = httptest.NewRecorder()
r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
r = r.WithContext(ctx)
r.Header.Set("Cookie", cookie)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
server.router.ServeHTTP(rr, r)
assert.Equal(t, http.StatusOK, rr.Code)
@ -3156,6 +3406,8 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) {
rr = httptest.NewRecorder()
r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
r = r.WithContext(ctx)
r.Header.Set("Cookie", cookie)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
server.router.ServeHTTP(rr, r)
assert.Equal(t, http.StatusFound, rr.Code)
@ -3221,6 +3473,7 @@ func TestDecodeToken(t *testing.T) {
claimNodeID: nodeID,
claimMustChangePasswordKey: false,
claimMustSetSecondFactorKey: true,
claimRef: "ref",
}
c := jwtTokenClaims{}
c.Decode(token)
@ -3228,6 +3481,11 @@ func TestDecodeToken(t *testing.T) {
assert.Equal(t, nodeID, c.NodeID)
assert.False(t, c.MustChangePassword)
assert.True(t, c.MustSetTwoFactorAuth)
assert.Equal(t, "ref", c.Ref)
asMap := c.asMap()
asMap[claimMustChangePasswordKey] = false
assert.Equal(t, token, asMap)
token[claimMustChangePasswordKey] = 10
c = jwtTokenClaims{}

View file

@ -20,10 +20,10 @@ import (
"io/fs"
"net/http"
"net/url"
"slices"
"strings"
"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/rs/xid"
"github.com/sftpgo/sdk"
@ -75,12 +75,6 @@ func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudi
return errInvalidToken
}
err = jwt.Validate(token)
if err != nil {
logger.Debug(logSender, "", "error validating jwt token: %v", err)
doRedirect(http.StatusText(http.StatusUnauthorized), err)
return errInvalidToken
}
if isTokenInvalidated(r) {
logger.Debug(logSender, "", "the token has been invalidated")
doRedirect("Your token is no longer valid", nil)
@ -90,18 +84,16 @@ func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudi
if err := checkPartialAuth(w, r, audience, token.Audience()); err != nil {
return err
}
if !util.Contains(token.Audience(), audience) {
if !slices.Contains(token.Audience(), audience) {
logger.Debug(logSender, "", "the token is not valid for audience %q", audience)
doRedirect("Your token audience is not valid", nil)
return errInvalidToken
}
if tokenValidationMode != tokenValidationNoIPMatch {
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if !util.Contains(token.Audience(), ipAddr) {
if err := validateIPForToken(token, ipAddr); err != nil {
logger.Debug(logSender, "", "the token with id %q is not valid for the ip address %q", token.JwtID(), ipAddr)
doRedirect("Your token is not valid", nil)
return errInvalidToken
}
return err
}
return nil
}
@ -114,7 +106,7 @@ func (s *httpdServer) validateJWTPartialToken(w http.ResponseWriter, r *http.Req
} else {
notFoundFunc = s.renderClientNotFoundPage
}
if err != nil || token == nil || jwt.Validate(token) != nil {
if err != nil || token == nil {
notFoundFunc(w, r, nil)
return errInvalidToken
}
@ -122,11 +114,17 @@ func (s *httpdServer) validateJWTPartialToken(w http.ResponseWriter, r *http.Req
notFoundFunc(w, r, nil)
return errInvalidToken
}
if !util.Contains(token.Audience(), audience) {
logger.Debug(logSender, "", "the token is not valid for audience %q", audience)
if !slices.Contains(token.Audience(), audience) {
logger.Debug(logSender, "", "the partial token with id %q is not valid for audience %q", token.JwtID(), audience)
notFoundFunc(w, r, nil)
return errInvalidToken
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := validateIPForToken(token, ipAddr); err != nil {
logger.Debug(logSender, "", "the partial token with id %q is not valid for the ip address %q", token.JwtID(), ipAddr)
notFoundFunc(w, r, nil)
return err
}
return nil
}
@ -324,29 +322,27 @@ func (s *httpdServer) checkPerm(perm string) func(next http.Handler) http.Handle
}
}
func verifyCSRFHeader(next http.Handler) http.Handler {
func (s *httpdServer) verifyCSRFHeader(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokenString := r.Header.Get(csrfHeaderToken)
token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString)
token, err := jwtauth.VerifyToken(s.csrfTokenAuth, tokenString)
if err != nil || token == nil {
logger.Debug(logSender, "", "error validating CSRF header: %v", err)
sendAPIResponse(w, r, err, "Invalid token", http.StatusForbidden)
return
}
if !util.Contains(token.Audience(), tokenAudienceCSRF) {
if !slices.Contains(token.Audience(), tokenAudienceCSRF) {
logger.Debug(logSender, "", "error validating CSRF header token audience")
sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden)
return
}
if tokenValidationMode != tokenValidationNoIPMatch {
if !util.Contains(token.Audience(), util.GetIPFromRemoteAddress(r.RemoteAddr)) {
if err := validateIPForToken(token, util.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil {
logger.Debug(logSender, "", "error validating CSRF header IP audience")
sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden)
return
}
}
next.ServeHTTP(w, r)
})
@ -576,11 +572,11 @@ func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAu
}
func checkPartialAuth(w http.ResponseWriter, r *http.Request, audience string, tokenAudience []string) error {
if audience == tokenAudienceWebAdmin && util.Contains(tokenAudience, tokenAudienceWebAdminPartial) {
if audience == tokenAudienceWebAdmin && slices.Contains(tokenAudience, tokenAudienceWebAdminPartial) {
http.Redirect(w, r, webAdminTwoFactorPath, http.StatusFound)
return errInvalidToken
}
if audience == tokenAudienceWebClient && util.Contains(tokenAudience, tokenAudienceWebClientPartial) {
if audience == tokenAudienceWebClient && slices.Contains(tokenAudience, tokenAudienceWebClientPartial) {
http.Redirect(w, r, webClientTwoFactorPath, http.StatusFound)
return errInvalidToken
}

View file

@ -21,6 +21,7 @@ import (
"fmt"
"net/http"
"net/url"
"slices"
"strings"
"time"
@ -143,7 +144,7 @@ func (o *OIDC) initialize() error {
if o.RedirectBaseURL == "" {
return errors.New("oidc: redirect base URL cannot be empty")
}
if !util.Contains(o.Scopes, oidc.ScopeOpenID) {
if !slices.Contains(o.Scopes, oidc.ScopeOpenID) {
return fmt.Errorf("oidc: required scope %q is not set", oidc.ScopeOpenID)
}
if o.ClientSecretFile != "" {
@ -543,6 +544,7 @@ func (s *httpdServer) oidcTokenAuthenticator(audience tokenAudience) func(next h
return
}
jwtTokenClaims := jwtTokenClaims{
JwtID: token.Cookie,
Username: token.Username,
Permissions: token.Permissions,
Role: token.TokenRole,

View file

@ -33,7 +33,6 @@ import (
"github.com/coreos/go-oidc/v3/oidc"
"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/rs/xid"
"github.com/sftpgo/sdk"
"github.com/stretchr/testify/assert"
@ -1586,12 +1585,9 @@ func TestOIDCWithLoginFormsDisabled(t *testing.T) {
tokenCookie = k
}
// we should be able to create admins without setting a password
if csrfTokenAuth == nil {
csrfTokenAuth = jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil)
}
adminUsername := "testAdmin"
form := make(url.Values)
form.Set(csrfFormToken, createCSRFToken(""))
form.Set(csrfFormToken, createCSRFToken(rr, r, server.csrfTokenAuth, tokenCookie, webBaseAdminPath))
form.Set("username", adminUsername)
form.Set("password", "")
form.Set("status", "1")

View file

@ -24,7 +24,9 @@ import (
"net"
"net/http"
"net/url"
"path"
"path/filepath"
"slices"
"strings"
"time"
@ -68,6 +70,7 @@ type httpdServer struct {
isShared int
router *chi.Mux
tokenAuth *jwtauth.JWTAuth
csrfTokenAuth *jwtauth.JWTAuth
signingPassphrase string
cors CorsConfig
}
@ -164,14 +167,14 @@ func (s *httpdServer) refreshCookie(next http.Handler) http.Handler {
})
}
func (s *httpdServer) renderClientLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) {
func (s *httpdServer) renderClientLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := loginPage{
commonBasePage: getCommonBasePage(r),
Title: util.I18nLoginTitle,
CurrentURL: webClientLoginPath,
Error: err,
CSRFToken: createCSRFToken(ip),
Branding: s.binding.Branding.WebClient,
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseClientPath),
Branding: s.binding.webClientBranding(),
FormDisabled: s.binding.isWebClientLoginFormDisabled(),
CheckRedirect: true,
}
@ -180,7 +183,7 @@ func (s *httpdServer) renderClientLoginPage(w http.ResponseWriter, r *http.Reque
}
if s.binding.showAdminLoginURL() {
data.AltLoginURL = webAdminLoginPath
data.AltLoginName = s.binding.Branding.WebAdmin.ShortName
data.AltLoginName = s.binding.webAdminBranding().ShortName
}
if smtp.IsEnabled() && !data.FormDisabled {
data.ForgotPwdURL = webClientForgotPwdPath
@ -193,8 +196,7 @@ func (s *httpdServer) renderClientLoginPage(w http.ResponseWriter, r *http.Reque
func (s *httpdServer) handleWebClientLogout(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
c := jwtTokenClaims{}
c.removeCookie(w, r, webBaseClientPath)
removeCookie(w, r, webBaseClientPath)
s.logoutOIDCUser(w, r)
http.Redirect(w, r, webClientLoginPath, http.StatusFound)
@ -206,7 +208,7 @@ func (s *httpdServer) handleWebClientChangePwdPost(w http.ResponseWriter, r *htt
s.renderClientChangePasswordPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return
}
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), util.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -226,7 +228,7 @@ func (s *httpdServer) handleClientWebLogin(w http.ResponseWriter, r *http.Reques
return
}
msg := getFlashMessage(w, r)
s.renderClientLoginPage(w, r, msg.getI18nError(), util.GetIPFromRemoteAddress(r.RemoteAddr))
s.renderClientLoginPage(w, r, msg.getI18nError())
}
func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Request) {
@ -234,7 +236,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := r.ParseForm(); err != nil {
s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr)
s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return
}
protocol := common.ProtocolHTTP
@ -244,20 +246,19 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials)
s.renderClientLoginPage(w, r,
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr)
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return
}
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
dataprovider.LoginMethodPassword, ipAddr, err)
s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr)
return
s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
}
if err := common.Config.ExecutePostConnectHook(ipAddr, protocol); err != nil {
updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
dataprovider.LoginMethodPassword, ipAddr, err)
s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nError403Message), ipAddr)
s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nError403Message))
return
}
@ -265,13 +266,13 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
if err != nil {
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
s.renderClientLoginPage(w, r,
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr)
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return
}
connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil {
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nError403Message), ipAddr)
s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nError403Message))
return
}
@ -280,7 +281,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
if err != nil {
logger.Warn(logSender, connectionID, "unable to check fs root: %v", err)
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorFsGeneric), ipAddr)
s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorFsGeneric))
return
}
s.loginUser(w, r, &user, connectionID, ipAddr, false, s.renderClientLoginPage)
@ -292,10 +293,10 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
err := r.ParseForm()
if err != nil {
s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr)
s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return
}
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -304,12 +305,12 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r
_, user, err := handleResetPassword(r, strings.TrimSpace(r.Form.Get("code")),
newPassword, confirmPassword, false)
if err != nil {
s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorChangePwdGeneric), ipAddr)
s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorChangePwdGeneric))
return
}
connectionID := fmt.Sprintf("%v_%v", getProtocolFromRequest(r), xid.New().String())
if err := checkHTTPClientUser(user, r, connectionID, true); err != nil {
s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorLoginAfterReset), ipAddr)
s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorLoginAfterReset))
return
}
@ -317,7 +318,7 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r
err = user.CheckFsRoot(connectionID)
if err != nil {
logger.Warn(logSender, connectionID, "unable to check fs root: %v", err)
s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorLoginAfterReset), ipAddr)
s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorLoginAfterReset))
return
}
s.loginUser(w, r, user, connectionID, ipAddr, false, s.renderClientResetPwdPage)
@ -332,18 +333,18 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := r.ParseForm(); err != nil {
s.renderClientTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr)
s.renderClientTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return
}
username := claims.Username
recoveryCode := strings.TrimSpace(r.Form.Get("recovery_code"))
if username == "" || recoveryCode == "" {
s.renderClientTwoFactorRecoveryPage(w, r,
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr)
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return
}
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
s.renderClientTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr)
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderClientTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
user, userMerged, err := dataprovider.GetUserVariants(username, "")
@ -352,12 +353,12 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter
handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck
}
s.renderClientTwoFactorRecoveryPage(w, r,
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr)
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return
}
if !userMerged.Filters.TOTPConfig.Enabled || !util.Contains(userMerged.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) {
if !userMerged.Filters.TOTPConfig.Enabled || !slices.Contains(userMerged.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) {
s.renderClientTwoFactorPage(w, r, util.NewI18nError(
util.NewValidationError("two factory authentication is not enabled"), util.I18n2FADisabled), ipAddr)
util.NewValidationError("two factory authentication is not enabled"), util.I18n2FADisabled))
return
}
for idx, code := range user.Filters.RecoveryCodes {
@ -368,7 +369,7 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter
if code.Secret.GetPayload() == recoveryCode {
if code.Used {
s.renderClientTwoFactorRecoveryPage(w, r,
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr)
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return
}
user.Filters.RecoveryCodes[idx].Used = true
@ -386,7 +387,7 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter
}
handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck
s.renderClientTwoFactorRecoveryPage(w, r,
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr)
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
}
func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *http.Request) {
@ -398,7 +399,7 @@ func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *htt
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := r.ParseForm(); err != nil {
s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr)
s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return
}
username := claims.Username
@ -407,25 +408,25 @@ func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *htt
updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials)
s.renderClientTwoFactorPage(w, r,
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr)
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return
}
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
dataprovider.LoginMethodPassword, ipAddr, err)
s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr)
s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
user, err := dataprovider.GetUserWithGroupSettings(username, "")
if err != nil {
updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
dataprovider.LoginMethodPassword, ipAddr, err)
s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials), ipAddr)
s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials))
return
}
if !user.Filters.TOTPConfig.Enabled || !util.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) {
if !user.Filters.TOTPConfig.Enabled || !slices.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) {
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
s.renderClientTwoFactorPage(w, r, util.NewI18nError(common.ErrInternalFailure, util.I18n2FADisabled), ipAddr)
s.renderClientTwoFactorPage(w, r, util.NewI18nError(common.ErrInternalFailure, util.I18n2FADisabled))
return
}
err = user.Filters.TOTPConfig.Secret.Decrypt()
@ -439,7 +440,7 @@ func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *htt
if !match || err != nil {
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, dataprovider.ErrInvalidCredentials)
s.renderClientTwoFactorPage(w, r,
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr)
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return
}
connectionID := fmt.Sprintf("%s_%s", getProtocolFromRequest(r), xid.New().String())
@ -456,18 +457,17 @@ func (s *httpdServer) handleWebAdminTwoFactorRecoveryPost(w http.ResponseWriter,
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := r.ParseForm(); err != nil {
s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr)
s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return
}
username := claims.Username
recoveryCode := strings.TrimSpace(r.Form.Get("recovery_code"))
if username == "" || recoveryCode == "" {
s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials),
ipAddr)
s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return
}
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr)
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
admin, err := dataprovider.AdminExists(username)
@ -475,12 +475,11 @@ func (s *httpdServer) handleWebAdminTwoFactorRecoveryPost(w http.ResponseWriter,
if errors.Is(err, util.ErrNotFound) {
handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck
}
s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials),
ipAddr)
s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return
}
if !admin.Filters.TOTPConfig.Enabled {
s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(util.NewValidationError("two factory authentication is not enabled"), util.I18n2FADisabled), ipAddr)
s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(util.NewValidationError("two factory authentication is not enabled"), util.I18n2FADisabled))
return
}
for idx, code := range admin.Filters.RecoveryCodes {
@ -491,7 +490,7 @@ func (s *httpdServer) handleWebAdminTwoFactorRecoveryPost(w http.ResponseWriter,
if code.Secret.GetPayload() == recoveryCode {
if code.Used {
s.renderTwoFactorRecoveryPage(w, r,
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr)
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return
}
admin.Filters.RecoveryCodes[idx].Used = true
@ -506,8 +505,7 @@ func (s *httpdServer) handleWebAdminTwoFactorRecoveryPost(w http.ResponseWriter,
}
}
handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck
s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials),
ipAddr)
s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
}
func (s *httpdServer) handleWebAdminTwoFactorPost(w http.ResponseWriter, r *http.Request) {
@ -519,19 +517,18 @@ func (s *httpdServer) handleWebAdminTwoFactorPost(w http.ResponseWriter, r *http
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := r.ParseForm(); err != nil {
s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr)
s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return
}
username := claims.Username
passcode := strings.TrimSpace(r.Form.Get("passcode"))
if username == "" || passcode == "" {
s.renderTwoFactorPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials),
ipAddr)
s.renderTwoFactorPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return
}
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
err = handleDefenderEventLoginFailed(ipAddr, err)
s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr)
s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
admin, err := dataprovider.AdminExists(username)
@ -539,11 +536,11 @@ func (s *httpdServer) handleWebAdminTwoFactorPost(w http.ResponseWriter, r *http
if errors.Is(err, util.ErrNotFound) {
handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck
}
s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials), ipAddr)
s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials))
return
}
if !admin.Filters.TOTPConfig.Enabled {
s.renderTwoFactorPage(w, r, util.NewI18nError(common.ErrInternalFailure, util.I18n2FADisabled), ipAddr)
s.renderTwoFactorPage(w, r, util.NewI18nError(common.ErrInternalFailure, util.I18n2FADisabled))
return
}
err = admin.Filters.TOTPConfig.Secret.Decrypt()
@ -555,8 +552,7 @@ func (s *httpdServer) handleWebAdminTwoFactorPost(w http.ResponseWriter, r *http
admin.Filters.TOTPConfig.Secret.GetPayload())
if !match || err != nil {
handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck
s.renderTwoFactorPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials),
ipAddr)
s.renderTwoFactorPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return
}
s.loginAdmin(w, r, &admin, true, s.renderTwoFactorPage, ipAddr)
@ -567,44 +563,42 @@ func (s *httpdServer) handleWebAdminLoginPost(w http.ResponseWriter, r *http.Req
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := r.ParseForm(); err != nil {
s.renderAdminLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr)
s.renderAdminLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return
}
username := strings.TrimSpace(r.Form.Get("username"))
password := strings.TrimSpace(r.Form.Get("password"))
if username == "" || password == "" {
s.renderAdminLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials),
ipAddr)
s.renderAdminLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return
}
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
s.renderAdminLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr)
if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderAdminLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
admin, err := dataprovider.CheckAdminAndPass(username, password, ipAddr)
if err != nil {
handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck
s.renderAdminLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials),
ipAddr)
s.renderAdminLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return
}
s.loginAdmin(w, r, &admin, false, s.renderAdminLoginPage, ipAddr)
}
func (s *httpdServer) renderAdminLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) {
func (s *httpdServer) renderAdminLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := loginPage{
commonBasePage: getCommonBasePage(r),
Title: util.I18nLoginTitle,
CurrentURL: webAdminLoginPath,
Error: err,
CSRFToken: createCSRFToken(ip),
Branding: s.binding.Branding.WebAdmin,
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseAdminPath),
Branding: s.binding.webAdminBranding(),
FormDisabled: s.binding.isWebAdminLoginFormDisabled(),
CheckRedirect: false,
}
if s.binding.showClientLoginURL() {
data.AltLoginURL = webClientLoginPath
data.AltLoginName = s.binding.Branding.WebClient.ShortName
data.AltLoginName = s.binding.webClientBranding().ShortName
}
if smtp.IsEnabled() && !data.FormDisabled {
data.ForgotPwdURL = webAdminForgotPwdPath
@ -622,13 +616,12 @@ func (s *httpdServer) handleWebAdminLogin(w http.ResponseWriter, r *http.Request
return
}
msg := getFlashMessage(w, r)
s.renderAdminLoginPage(w, r, msg.getI18nError(), util.GetIPFromRemoteAddress(r.RemoteAddr))
s.renderAdminLoginPage(w, r, msg.getI18nError())
}
func (s *httpdServer) handleWebAdminLogout(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
c := jwtTokenClaims{}
c.removeCookie(w, r, webBaseAdminPath)
removeCookie(w, r, webBaseAdminPath)
s.logoutOIDCUser(w, r)
http.Redirect(w, r, webAdminLoginPath, http.StatusFound)
@ -641,7 +634,7 @@ func (s *httpdServer) handleWebAdminChangePwdPost(w http.ResponseWriter, r *http
s.renderChangePasswordPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return
}
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), util.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -660,10 +653,10 @@ func (s *httpdServer) handleWebAdminPasswordResetPost(w http.ResponseWriter, r *
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
err := r.ParseForm()
if err != nil {
s.renderResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr)
s.renderResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return
}
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -672,7 +665,7 @@ func (s *httpdServer) handleWebAdminPasswordResetPost(w http.ResponseWriter, r *
admin, _, err := handleResetPassword(r, strings.TrimSpace(r.Form.Get("code")),
newPassword, confirmPassword, true)
if err != nil {
s.renderResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorChangePwdGeneric), ipAddr)
s.renderResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorChangePwdGeneric))
return
}
@ -688,10 +681,10 @@ func (s *httpdServer) handleWebAdminSetupPost(w http.ResponseWriter, r *http.Req
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
err := r.ParseForm()
if err != nil {
s.renderAdminSetupPage(w, r, "", ipAddr, util.NewI18nError(err, util.I18nErrorInvalidForm))
s.renderAdminSetupPage(w, r, "", util.NewI18nError(err, util.I18nErrorInvalidForm))
return
}
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -700,7 +693,7 @@ func (s *httpdServer) handleWebAdminSetupPost(w http.ResponseWriter, r *http.Req
confirmPassword := strings.TrimSpace(r.Form.Get("confirm_password"))
installCode := strings.TrimSpace(r.Form.Get("install_code"))
if installationCode != "" && installCode != resolveInstallationCode() {
s.renderAdminSetupPage(w, r, username, ipAddr,
s.renderAdminSetupPage(w, r, username,
util.NewI18nError(
util.NewValidationError(fmt.Sprintf("%v mismatch", installationCodeHint)),
util.I18nErrorSetupInstallCode),
@ -708,17 +701,17 @@ func (s *httpdServer) handleWebAdminSetupPost(w http.ResponseWriter, r *http.Req
return
}
if username == "" {
s.renderAdminSetupPage(w, r, username, ipAddr,
s.renderAdminSetupPage(w, r, username,
util.NewI18nError(util.NewValidationError("please set a username"), util.I18nError500Message))
return
}
if password == "" {
s.renderAdminSetupPage(w, r, username, ipAddr,
s.renderAdminSetupPage(w, r, username,
util.NewI18nError(util.NewValidationError("please set a password"), util.I18nError500Message))
return
}
if password != confirmPassword {
s.renderAdminSetupPage(w, r, username, ipAddr,
s.renderAdminSetupPage(w, r, username,
util.NewI18nError(errors.New("the two password fields do not match"), util.I18nErrorChangePwdNoMatch))
return
}
@ -730,7 +723,7 @@ func (s *httpdServer) handleWebAdminSetupPost(w http.ResponseWriter, r *http.Req
}
err = dataprovider.AddAdmin(&admin, username, ipAddr, "")
if err != nil {
s.renderAdminSetupPage(w, r, username, ipAddr, util.NewI18nError(err, util.I18nError500Message))
s.renderAdminSetupPage(w, r, username, util.NewI18nError(err, util.I18nError500Message))
return
}
s.loginAdmin(w, r, &admin, false, nil, ipAddr)
@ -738,7 +731,7 @@ func (s *httpdServer) handleWebAdminSetupPost(w http.ResponseWriter, r *http.Req
func (s *httpdServer) loginUser(
w http.ResponseWriter, r *http.Request, user *dataprovider.User, connectionID, ipAddr string,
isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string),
isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, err *util.I18nError),
) {
c := jwtTokenClaims{
Username: user.Username,
@ -751,7 +744,7 @@ func (s *httpdServer) loginUser(
}
audience := tokenAudienceWebClient
if user.Filters.TOTPConfig.Enabled && util.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) &&
if user.Filters.TOTPConfig.Enabled && slices.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) &&
user.CanManageMFA() && !isSecondFactorAuth {
audience = tokenAudienceWebClientPartial
}
@ -760,12 +753,10 @@ func (s *httpdServer) loginUser(
if err != nil {
logger.Warn(logSender, connectionID, "unable to set user login cookie %v", err)
updateLoginMetrics(user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
errorFunc(w, r, util.NewI18nError(err, util.I18nError500Message), ipAddr)
errorFunc(w, r, util.NewI18nError(err, util.I18nError500Message))
return
}
if isSecondFactorAuth {
invalidateToken(r)
}
invalidateToken(r, !isSecondFactorAuth)
if audience == tokenAudienceWebClientPartial {
redirectPath := webClientTwoFactorPath
if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) {
@ -785,7 +776,7 @@ func (s *httpdServer) loginUser(
func (s *httpdServer) loginAdmin(
w http.ResponseWriter, r *http.Request, admin *dataprovider.Admin,
isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string),
isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, err *util.I18nError),
ipAddr string,
) {
c := jwtTokenClaims{
@ -807,15 +798,13 @@ func (s *httpdServer) loginAdmin(
if err != nil {
logger.Warn(logSender, "", "unable to set admin login cookie %v", err)
if errorFunc == nil {
s.renderAdminSetupPage(w, r, admin.Username, ipAddr, util.NewI18nError(err, util.I18nError500Message))
s.renderAdminSetupPage(w, r, admin.Username, util.NewI18nError(err, util.I18nError500Message))
return
}
errorFunc(w, r, util.NewI18nError(err, util.I18nError500Message), ipAddr)
errorFunc(w, r, util.NewI18nError(err, util.I18nError500Message))
return
}
if isSecondFactorAuth {
invalidateToken(r)
}
invalidateToken(r, !isSecondFactorAuth)
if audience == tokenAudienceWebAdminPartial {
http.Redirect(w, r, webAdminTwoFactorPath, http.StatusFound)
return
@ -831,7 +820,7 @@ func (s *httpdServer) loginAdmin(
func (s *httpdServer) logout(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
invalidateToken(r)
invalidateToken(r, false)
sendAPIResponse(w, r, nil, "Your token has been invalidated", http.StatusOK)
}
@ -875,7 +864,7 @@ func (s *httpdServer) getUserToken(w http.ResponseWriter, r *http.Request) {
return
}
if user.Filters.TOTPConfig.Enabled && util.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) {
if user.Filters.TOTPConfig.Enabled && slices.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) {
passcode := r.Header.Get(otpHeaderCode)
if passcode == "" {
logger.Debug(logSender, "", "TOTP enabled for user %q and not passcode provided, authentication refused", user.Username)
@ -1021,14 +1010,14 @@ func (s *httpdServer) checkCookieExpiration(w http.ResponseWriter, r *http.Reque
if time.Until(token.Expiration()) > tokenRefreshThreshold {
return
}
if util.Contains(token.Audience(), tokenAudienceWebClient) {
s.refreshClientToken(w, r, tokenClaims)
if slices.Contains(token.Audience(), tokenAudienceWebClient) {
s.refreshClientToken(w, r, &tokenClaims)
} else {
s.refreshAdminToken(w, r, tokenClaims)
s.refreshAdminToken(w, r, &tokenClaims)
}
}
func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request, tokenClaims jwtTokenClaims) {
func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request, tokenClaims *jwtTokenClaims) {
user, err := dataprovider.GetUserWithGroupSettings(tokenClaims.Username, "")
if err != nil {
return
@ -1052,7 +1041,7 @@ func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request,
tokenClaims.createAndSetCookie(w, r, s.tokenAuth, tokenAudienceWebClient, util.GetIPFromRemoteAddress(r.RemoteAddr)) //nolint:errcheck
}
func (s *httpdServer) refreshAdminToken(w http.ResponseWriter, r *http.Request, tokenClaims jwtTokenClaims) {
func (s *httpdServer) refreshAdminToken(w http.ResponseWriter, r *http.Request, tokenClaims *jwtTokenClaims) {
admin, err := dataprovider.AdminExists(tokenClaims.Username)
if err != nil {
return
@ -1236,6 +1225,7 @@ func (s *httpdServer) mustCheckPath(r *http.Request) bool {
func (s *httpdServer) initializeRouter() {
var hasHTTPSRedirect bool
s.tokenAuth = jwtauth.New(jwa.HS256.String(), getSigningKey(s.signingPassphrase), nil)
s.csrfTokenAuth = jwtauth.New(jwa.HS256.String(), getSigningKey(s.signingPassphrase), nil)
s.router = chi.NewRouter()
s.router.Use(middleware.RequestID)
@ -1532,16 +1522,30 @@ func (s *httpdServer) setupWebClientRoutes() {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
http.Redirect(w, r, webClientLoginPath, http.StatusFound)
})
s.router.Get(path.Join(webStaticFilesPath, "branding/webclient/logo.png"),
func(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
renderPNGImage(w, r, dbBrandingConfig.getWebClientLogo())
})
s.router.Get(path.Join(webStaticFilesPath, "branding/webclient/favicon.png"),
func(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
renderPNGImage(w, r, dbBrandingConfig.getWebClientFavicon())
})
s.router.Get(webClientLoginPath, s.handleClientWebLogin)
if s.binding.OIDC.isEnabled() && !s.binding.isWebClientOIDCLoginDisabled() {
s.router.Get(webClientOIDCLoginPath, s.handleWebClientOIDCLogin)
}
if !s.binding.isWebClientLoginFormDisabled() {
s.router.Post(webClientLoginPath, s.handleWebClientLoginPost)
s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
Post(webClientLoginPath, s.handleWebClientLoginPost)
s.router.Get(webClientForgotPwdPath, s.handleWebClientForgotPwd)
s.router.Post(webClientForgotPwdPath, s.handleWebClientForgotPwdPost)
s.router.Get(webClientResetPwdPath, s.handleWebClientPasswordReset)
s.router.Post(webClientResetPwdPath, s.handleWebClientPasswordResetPost)
s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
Post(webClientForgotPwdPath, s.handleWebClientForgotPwdPost)
s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
Get(webClientResetPwdPath, s.handleWebClientPasswordReset)
s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
Post(webClientResetPwdPath, s.handleWebClientPasswordResetPost)
s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie),
s.jwtAuthenticatorPartial(tokenAudienceWebClientPartial)).
Get(webClientTwoFactorPath, s.handleWebClientTwoFactor)
@ -1557,7 +1561,9 @@ func (s *httpdServer) setupWebClientRoutes() {
}
// share routes available to external users
s.router.Get(webClientPubSharesPath+"/{id}/login", s.handleClientShareLoginGet)
s.router.Post(webClientPubSharesPath+"/{id}/login", s.handleClientShareLoginPost)
s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
Post(webClientPubSharesPath+"/{id}/login", s.handleClientShareLoginPost)
s.router.Get(webClientPubSharesPath+"/{id}/logout", s.handleClientShareLogout)
s.router.Get(webClientPubSharesPath+"/{id}", s.downloadFromShare)
s.router.Post(webClientPubSharesPath+"/{id}/partial", s.handleClientSharePartialDownload)
s.router.Get(webClientPubSharesPath+"/{id}/browse", s.handleShareGetFiles)
@ -1574,32 +1580,32 @@ func (s *httpdServer) setupWebClientRoutes() {
if s.binding.OIDC.isEnabled() {
router.Use(s.oidcTokenAuthenticator(tokenAudienceWebClient))
}
router.Use(jwtauth.Verify(s.tokenAuth, tokenFromContext, jwtauth.TokenFromCookie))
router.Use(jwtauth.Verify(s.tokenAuth, oidcTokenFromContext, jwtauth.TokenFromCookie))
router.Use(jwtAuthenticatorWebClient)
router.Get(webClientLogoutPath, s.handleWebClientLogout)
router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientFilesPath, s.handleClientGetFiles)
router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientViewPDFPath, s.handleClientViewPDF)
router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientGetPDFPath, s.handleClientGetPDF)
router.With(s.checkAuthRequirements, s.refreshCookie, verifyCSRFHeader).Get(webClientFilePath, getUserFile)
router.With(s.checkAuthRequirements, s.refreshCookie, verifyCSRFHeader).Get(webClientTasksPath+"/{id}",
router.With(s.checkAuthRequirements, s.refreshCookie, s.verifyCSRFHeader).Get(webClientFilePath, getUserFile)
router.With(s.checkAuthRequirements, s.refreshCookie, s.verifyCSRFHeader).Get(webClientTasksPath+"/{id}",
getWebTask)
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), verifyCSRFHeader).
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader).
Post(webClientFilePath, uploadUserFile)
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), verifyCSRFHeader).
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader).
Post(webClientExistPath, s.handleClientCheckExist)
router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientEditFilePath, s.handleClientEditFile)
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), verifyCSRFHeader).
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader).
Delete(webClientFilesPath, deleteUserFile)
router.With(s.checkAuthRequirements, compressor.Handler, s.refreshCookie).
Get(webClientDirsPath, s.handleClientGetDirContents)
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), verifyCSRFHeader).
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader).
Post(webClientDirsPath, createUserDir)
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), verifyCSRFHeader).
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader).
Delete(webClientDirsPath, taskDeleteDir)
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), verifyCSRFHeader).
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader).
Post(webClientFileActionsPath+"/move", taskRenameFsEntry)
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), verifyCSRFHeader).
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader).
Post(webClientFileActionsPath+"/copy", taskCopyFsEntry)
router.With(s.checkAuthRequirements, s.refreshCookie).
Post(webClientDownloadZipPath, s.handleWebClientDownloadZip)
@ -1615,15 +1621,15 @@ func (s *httpdServer) setupWebClientRoutes() {
Get(webClientMFAPath, s.handleWebClientMFA)
router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.refreshCookie).
Get(webClientMFAPath+"/qrcode", getQRCode)
router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), verifyCSRFHeader).
router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader).
Post(webClientTOTPGeneratePath, generateTOTPSecret)
router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), verifyCSRFHeader).
router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader).
Post(webClientTOTPValidatePath, validateTOTPPasscode)
router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), verifyCSRFHeader).
router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader).
Post(webClientTOTPSavePath, saveTOTPConfig)
router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), verifyCSRFHeader, s.refreshCookie).
router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader, s.refreshCookie).
Get(webClientRecoveryCodesPath, getRecoveryCodes)
router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), verifyCSRFHeader).
router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader).
Post(webClientRecoveryCodesPath, generateRecoveryCodes)
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled), compressor.Handler, s.refreshCookie).
Get(webClientSharesPath+jsonAPISuffix, getAllShares)
@ -1637,7 +1643,7 @@ func (s *httpdServer) setupWebClientRoutes() {
Get(webClientSharePath+"/{id}", s.handleClientUpdateShareGet)
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled)).
Post(webClientSharePath+"/{id}", s.handleClientUpdateSharePost)
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled), verifyCSRFHeader).
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled), s.verifyCSRFHeader).
Delete(webClientSharePath+"/{id}", deleteShare)
})
}
@ -1649,15 +1655,27 @@ func (s *httpdServer) setupWebAdminRoutes() {
r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
s.redirectToWebPath(w, r, webAdminLoginPath)
})
s.router.Get(path.Join(webStaticFilesPath, "branding/webadmin/logo.png"),
func(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
renderPNGImage(w, r, dbBrandingConfig.getWebAdminLogo())
})
s.router.Get(path.Join(webStaticFilesPath, "branding/webadmin/favicon.png"),
func(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
renderPNGImage(w, r, dbBrandingConfig.getWebAdminFavicon())
})
s.router.Get(webAdminLoginPath, s.handleWebAdminLogin)
if s.binding.OIDC.hasRoles() && !s.binding.isWebAdminOIDCLoginDisabled() {
s.router.Get(webAdminOIDCLoginPath, s.handleWebAdminOIDCLogin)
}
s.router.Get(webOAuth2RedirectPath, s.handleOAuth2TokenRedirect)
s.router.Get(webAdminSetupPath, s.handleWebAdminSetupGet)
s.router.Post(webAdminSetupPath, s.handleWebAdminSetupPost)
s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
Post(webAdminSetupPath, s.handleWebAdminSetupPost)
if !s.binding.isWebAdminLoginFormDisabled() {
s.router.Post(webAdminLoginPath, s.handleWebAdminLoginPost)
s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
Post(webAdminLoginPath, s.handleWebAdminLoginPost)
s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie),
s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)).
Get(webAdminTwoFactorPath, s.handleWebAdminTwoFactor)
@ -1671,16 +1689,19 @@ func (s *httpdServer) setupWebAdminRoutes() {
s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)).
Post(webAdminTwoFactorRecoveryPath, s.handleWebAdminTwoFactorRecoveryPost)
s.router.Get(webAdminForgotPwdPath, s.handleWebAdminForgotPwd)
s.router.Post(webAdminForgotPwdPath, s.handleWebAdminForgotPwdPost)
s.router.Get(webAdminResetPwdPath, s.handleWebAdminPasswordReset)
s.router.Post(webAdminResetPwdPath, s.handleWebAdminPasswordResetPost)
s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
Post(webAdminForgotPwdPath, s.handleWebAdminForgotPwdPost)
s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
Get(webAdminResetPwdPath, s.handleWebAdminPasswordReset)
s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
Post(webAdminResetPwdPath, s.handleWebAdminPasswordResetPost)
}
s.router.Group(func(router chi.Router) {
if s.binding.OIDC.isEnabled() {
router.Use(s.oidcTokenAuthenticator(tokenAudienceWebAdmin))
}
router.Use(jwtauth.Verify(s.tokenAuth, tokenFromContext, jwtauth.TokenFromCookie))
router.Use(jwtauth.Verify(s.tokenAuth, oidcTokenFromContext, jwtauth.TokenFromCookie))
router.Use(jwtAuthenticatorWebAdmin)
router.Get(webLogoutPath, s.handleWebAdminLogout)
@ -1692,12 +1713,12 @@ func (s *httpdServer) setupWebAdminRoutes() {
router.With(s.refreshCookie, s.requireBuiltinLogin).Get(webAdminMFAPath, s.handleWebAdminMFA)
router.With(s.refreshCookie, s.requireBuiltinLogin).Get(webAdminMFAPath+"/qrcode", getQRCode)
router.With(verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPGeneratePath, generateTOTPSecret)
router.With(verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPValidatePath, validateTOTPPasscode)
router.With(verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPSavePath, saveTOTPConfig)
router.With(verifyCSRFHeader, s.requireBuiltinLogin, s.refreshCookie).Get(webAdminRecoveryCodesPath,
router.With(s.verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPGeneratePath, generateTOTPSecret)
router.With(s.verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPValidatePath, validateTOTPPasscode)
router.With(s.verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPSavePath, saveTOTPConfig)
router.With(s.verifyCSRFHeader, s.requireBuiltinLogin, s.refreshCookie).Get(webAdminRecoveryCodesPath,
getRecoveryCodes)
router.With(verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminRecoveryCodesPath, generateRecoveryCodes)
router.With(s.verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminRecoveryCodesPath, generateRecoveryCodes)
router.Group(func(router chi.Router) {
router.Use(s.checkAuthRequirements)
@ -1724,7 +1745,7 @@ func (s *httpdServer) setupWebAdminRoutes() {
Get(webGroupPath+"/{name}", s.handleWebUpdateGroupGet)
router.With(s.checkPerm(dataprovider.PermAdminManageGroups)).Post(webGroupPath+"/{name}",
s.handleWebUpdateGroupPost)
router.With(s.checkPerm(dataprovider.PermAdminManageGroups), verifyCSRFHeader).
router.With(s.checkPerm(dataprovider.PermAdminManageGroups), s.verifyCSRFHeader).
Delete(webGroupPath+"/{name}", deleteGroup)
router.With(s.checkPerm(dataprovider.PermAdminViewConnections), s.refreshCookie).
Get(webConnectionsPath, s.handleWebGetConnections)
@ -1750,25 +1771,25 @@ func (s *httpdServer) setupWebAdminRoutes() {
router.With(s.checkPerm(dataprovider.PermAdminManageAdmins)).Post(webAdminPath, s.handleWebAddAdminPost)
router.With(s.checkPerm(dataprovider.PermAdminManageAdmins)).Post(webAdminPath+"/{username}",
s.handleWebUpdateAdminPost)
router.With(s.checkPerm(dataprovider.PermAdminManageAdmins), verifyCSRFHeader).
router.With(s.checkPerm(dataprovider.PermAdminManageAdmins), s.verifyCSRFHeader).
Delete(webAdminPath+"/{username}", deleteAdmin)
router.With(s.checkPerm(dataprovider.PermAdminDisableMFA), verifyCSRFHeader).
router.With(s.checkPerm(dataprovider.PermAdminDisableMFA), s.verifyCSRFHeader).
Put(webAdminPath+"/{username}/2fa/disable", disableAdmin2FA)
router.With(s.checkPerm(dataprovider.PermAdminCloseConnections), verifyCSRFHeader).
router.With(s.checkPerm(dataprovider.PermAdminCloseConnections), s.verifyCSRFHeader).
Delete(webConnectionsPath+"/{connectionID}", handleCloseConnection)
router.With(s.checkPerm(dataprovider.PermAdminManageFolders), s.refreshCookie).
Get(webFolderPath+"/{name}", s.handleWebUpdateFolderGet)
router.With(s.checkPerm(dataprovider.PermAdminManageFolders)).Post(webFolderPath+"/{name}",
s.handleWebUpdateFolderPost)
router.With(s.checkPerm(dataprovider.PermAdminManageFolders), verifyCSRFHeader).
router.With(s.checkPerm(dataprovider.PermAdminManageFolders), s.verifyCSRFHeader).
Delete(webFolderPath+"/{name}", deleteFolder)
router.With(s.checkPerm(dataprovider.PermAdminQuotaScans), verifyCSRFHeader).
router.With(s.checkPerm(dataprovider.PermAdminQuotaScans), s.verifyCSRFHeader).
Post(webScanVFolderPath+"/{name}", startFolderQuotaScan)
router.With(s.checkPerm(dataprovider.PermAdminDeleteUsers), verifyCSRFHeader).
router.With(s.checkPerm(dataprovider.PermAdminDeleteUsers), s.verifyCSRFHeader).
Delete(webUserPath+"/{username}", deleteUser)
router.With(s.checkPerm(dataprovider.PermAdminDisableMFA), verifyCSRFHeader).
router.With(s.checkPerm(dataprovider.PermAdminDisableMFA), s.verifyCSRFHeader).
Put(webUserPath+"/{username}/2fa/disable", disableUser2FA)
router.With(s.checkPerm(dataprovider.PermAdminQuotaScans), verifyCSRFHeader).
router.With(s.checkPerm(dataprovider.PermAdminQuotaScans), s.verifyCSRFHeader).
Post(webQuotaScanPath+"/{username}", startUserQuotaScan)
router.With(s.checkPerm(dataprovider.PermAdminManageSystem)).Get(webMaintenancePath, s.handleWebMaintenance)
router.With(s.checkPerm(dataprovider.PermAdminManageSystem)).Get(webBackupPath, dumpData)
@ -1795,7 +1816,7 @@ func (s *httpdServer) setupWebAdminRoutes() {
Get(webAdminEventActionPath+"/{name}", s.handleWebUpdateEventActionGet)
router.With(s.checkPerm(dataprovider.PermAdminManageEventRules)).Post(webAdminEventActionPath+"/{name}",
s.handleWebUpdateEventActionPost)
router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), verifyCSRFHeader).
router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), s.verifyCSRFHeader).
Delete(webAdminEventActionPath+"/{name}", deleteEventAction)
router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), compressor.Handler, s.refreshCookie).
Get(webAdminEventRulesPath+jsonAPISuffix, getAllRules)
@ -1809,9 +1830,9 @@ func (s *httpdServer) setupWebAdminRoutes() {
Get(webAdminEventRulePath+"/{name}", s.handleWebUpdateEventRuleGet)
router.With(s.checkPerm(dataprovider.PermAdminManageEventRules)).Post(webAdminEventRulePath+"/{name}",
s.handleWebUpdateEventRulePost)
router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), verifyCSRFHeader).
router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), s.verifyCSRFHeader).
Delete(webAdminEventRulePath+"/{name}", deleteEventRule)
router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), verifyCSRFHeader).
router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), s.verifyCSRFHeader).
Post(webAdminEventRulePath+"/run/{name}", runOnDemandRule)
router.With(s.checkPerm(dataprovider.PermAdminManageRoles), s.refreshCookie).
Get(webAdminRolesPath, s.handleWebGetRoles)
@ -1824,7 +1845,7 @@ func (s *httpdServer) setupWebAdminRoutes() {
Get(webAdminRolePath+"/{name}", s.handleWebUpdateRoleGet)
router.With(s.checkPerm(dataprovider.PermAdminManageRoles)).Post(webAdminRolePath+"/{name}",
s.handleWebUpdateRolePost)
router.With(s.checkPerm(dataprovider.PermAdminManageRoles), verifyCSRFHeader).
router.With(s.checkPerm(dataprovider.PermAdminManageRoles), s.verifyCSRFHeader).
Delete(webAdminRolePath+"/{name}", deleteRole)
router.With(s.checkPerm(dataprovider.PermAdminViewEvents), s.refreshCookie).Get(webEventsPath,
s.handleWebGetEvents)
@ -1845,14 +1866,14 @@ func (s *httpdServer) setupWebAdminRoutes() {
s.handleWebUpdateIPListEntryGet)
router.With(s.checkPerm(dataprovider.PermAdminManageIPLists)).Post(webIPListPath+"/{type}/{ipornet}",
s.handleWebUpdateIPListEntryPost)
router.With(s.checkPerm(dataprovider.PermAdminManageIPLists), verifyCSRFHeader).
router.With(s.checkPerm(dataprovider.PermAdminManageIPLists), s.verifyCSRFHeader).
Delete(webIPListPath+"/{type}/{ipornet}", deleteIPListEntry)
router.With(s.checkPerm(dataprovider.PermAdminManageSystem), s.refreshCookie).Get(webConfigsPath, s.handleWebConfigs)
router.With(s.checkPerm(dataprovider.PermAdminManageSystem)).Post(webConfigsPath, s.handleWebConfigsPost)
router.With(s.checkPerm(dataprovider.PermAdminManageSystem), verifyCSRFHeader, s.refreshCookie).
router.With(s.checkPerm(dataprovider.PermAdminManageSystem), s.verifyCSRFHeader, s.refreshCookie).
Post(webConfigsPath+"/smtp/test", testSMTPConfig)
router.With(s.checkPerm(dataprovider.PermAdminManageSystem), verifyCSRFHeader, s.refreshCookie).
Post(webOAuth2TokenPath, handleSMTPOAuth2TokenRequestPost)
router.With(s.checkPerm(dataprovider.PermAdminManageSystem), s.verifyCSRFHeader, s.refreshCookie).
Post(webOAuth2TokenPath, s.handleSMTPOAuth2TokenRequestPost)
})
})
}

View file

@ -25,12 +25,14 @@ import (
"net/url"
"os"
"path/filepath"
"slices"
"sort"
"strconv"
"strings"
"time"
"github.com/go-chi/render"
"github.com/rs/xid"
"github.com/sftpgo/sdk"
sdkkms "github.com/sftpgo/sdk/kms"
@ -150,6 +152,7 @@ type basePage struct {
HasSearcher bool
HasExternalLogin bool
LoggedUser *dataprovider.Admin
IsLoggedToShare bool
Branding UIBranding
}
@ -327,6 +330,7 @@ type configsPage struct {
RedactedSecret string
OAuth2TokenURL string
OAuth2RedirectURL string
WebClientBranding UIBranding
Error *util.I18nError
}
@ -612,10 +616,10 @@ func isServerManagerResource(currentURL string) bool {
currentURL == webConfigsPath
}
func (s *httpdServer) getBasePageData(title, currentURL string, r *http.Request) basePage {
func (s *httpdServer) getBasePageData(title, currentURL string, w http.ResponseWriter, r *http.Request) basePage {
var csrfToken string
if currentURL != "" {
csrfToken = createCSRFToken(util.GetIPFromRemoteAddress(r.RemoteAddr))
csrfToken = createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath)
}
return basePage{
commonBasePage: getCommonBasePage(r),
@ -660,7 +664,7 @@ func (s *httpdServer) getBasePageData(title, currentURL string, r *http.Request)
HasSearcher: plugin.Handler.HasSearcher(),
HasExternalLogin: isLoggedInWithOIDC(r),
CSRFToken: csrfToken,
Branding: s.binding.Branding.WebAdmin,
Branding: s.binding.webAdminBranding(),
}
}
@ -675,7 +679,7 @@ func (s *httpdServer) renderMessagePageWithString(w http.ResponseWriter, r *http
err error, message, text string,
) {
data := messagePage{
basePage: s.getBasePageData(title, "", r),
basePage: s.getBasePageData(title, "", w, r),
Error: getI18nError(err),
Success: message,
Text: text,
@ -710,60 +714,60 @@ func (s *httpdServer) renderNotFoundPage(w http.ResponseWriter, r *http.Request,
util.NewI18nError(err, util.I18nError404Message), "")
}
func (s *httpdServer) renderForgotPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) {
func (s *httpdServer) renderForgotPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := forgotPwdPage{
commonBasePage: getCommonBasePage(r),
CurrentURL: webAdminForgotPwdPath,
Error: err,
CSRFToken: createCSRFToken(ip),
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseAdminPath),
LoginURL: webAdminLoginPath,
Title: util.I18nForgotPwdTitle,
Branding: s.binding.Branding.WebAdmin,
Branding: s.binding.webAdminBranding(),
}
renderAdminTemplate(w, templateForgotPassword, data)
}
func (s *httpdServer) renderResetPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) {
func (s *httpdServer) renderResetPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := resetPwdPage{
commonBasePage: getCommonBasePage(r),
CurrentURL: webAdminResetPwdPath,
Error: err,
CSRFToken: createCSRFToken(ip),
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath),
LoginURL: webAdminLoginPath,
Title: util.I18nResetPwdTitle,
Branding: s.binding.Branding.WebAdmin,
Branding: s.binding.webAdminBranding(),
}
renderAdminTemplate(w, templateResetPassword, data)
}
func (s *httpdServer) renderTwoFactorPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) {
func (s *httpdServer) renderTwoFactorPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := twoFactorPage{
commonBasePage: getCommonBasePage(r),
Title: pageTwoFactorTitle,
CurrentURL: webAdminTwoFactorPath,
Error: err,
CSRFToken: createCSRFToken(ip),
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath),
RecoveryURL: webAdminTwoFactorRecoveryPath,
Branding: s.binding.Branding.WebAdmin,
Branding: s.binding.webAdminBranding(),
}
renderAdminTemplate(w, templateTwoFactor, data)
}
func (s *httpdServer) renderTwoFactorRecoveryPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) {
func (s *httpdServer) renderTwoFactorRecoveryPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := twoFactorPage{
commonBasePage: getCommonBasePage(r),
Title: pageTwoFactorRecoveryTitle,
CurrentURL: webAdminTwoFactorRecoveryPath,
Error: err,
CSRFToken: createCSRFToken(ip),
Branding: s.binding.Branding.WebAdmin,
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath),
Branding: s.binding.webAdminBranding(),
}
renderAdminTemplate(w, templateTwoFactorRecovery, data)
}
func (s *httpdServer) renderMFAPage(w http.ResponseWriter, r *http.Request) {
data := mfaPage{
basePage: s.getBasePageData(pageMFATitle, webAdminMFAPath, r),
basePage: s.getBasePageData(pageMFATitle, webAdminMFAPath, w, r),
TOTPConfigs: mfa.GetAvailableTOTPConfigNames(),
GenerateTOTPURL: webAdminTOTPGeneratePath,
ValidateTOTPURL: webAdminTOTPValidatePath,
@ -782,7 +786,7 @@ func (s *httpdServer) renderMFAPage(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) renderProfilePage(w http.ResponseWriter, r *http.Request, err error) {
data := profilePage{
basePage: s.getBasePageData(util.I18nProfileTitle, webAdminProfilePath, r),
basePage: s.getBasePageData(util.I18nProfileTitle, webAdminProfilePath, w, r),
Error: getI18nError(err),
}
admin, err := dataprovider.AdminExists(data.LoggedUser.Username)
@ -799,7 +803,7 @@ func (s *httpdServer) renderProfilePage(w http.ResponseWriter, r *http.Request,
func (s *httpdServer) renderChangePasswordPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := changePasswordPage{
basePage: s.getBasePageData(util.I18nChangePwdTitle, webChangeAdminPwdPath, r),
basePage: s.getBasePageData(util.I18nChangePwdTitle, webChangeAdminPwdPath, w, r),
Error: err,
}
@ -808,7 +812,7 @@ func (s *httpdServer) renderChangePasswordPage(w http.ResponseWriter, r *http.Re
func (s *httpdServer) renderMaintenancePage(w http.ResponseWriter, r *http.Request, err error) {
data := maintenancePage{
basePage: s.getBasePageData(util.I18nMaintenanceTitle, webMaintenancePath, r),
basePage: s.getBasePageData(util.I18nMaintenanceTitle, webMaintenancePath, w, r),
BackupPath: webBackupPath,
RestorePath: webRestorePath,
Error: getI18nError(err),
@ -830,30 +834,31 @@ func (s *httpdServer) renderConfigsPage(w http.ResponseWriter, r *http.Request,
configs.ACME.HTTP01Challenge.Port = 80
}
data := configsPage{
basePage: s.getBasePageData(util.I18nConfigsTitle, webConfigsPath, r),
basePage: s.getBasePageData(util.I18nConfigsTitle, webConfigsPath, w, r),
Configs: configs,
ConfigSection: section,
RedactedSecret: redactedSecret,
OAuth2TokenURL: webOAuth2TokenPath,
OAuth2RedirectURL: webOAuth2RedirectPath,
WebClientBranding: s.binding.webClientBranding(),
Error: getI18nError(err),
}
renderAdminTemplate(w, templateConfigs, data)
}
func (s *httpdServer) renderAdminSetupPage(w http.ResponseWriter, r *http.Request, username, ip string, err *util.I18nError) {
func (s *httpdServer) renderAdminSetupPage(w http.ResponseWriter, r *http.Request, username string, err *util.I18nError) {
data := setupPage{
commonBasePage: getCommonBasePage(r),
Title: util.I18nSetupTitle,
CurrentURL: webAdminSetupPath,
CSRFToken: createCSRFToken(ip),
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseAdminPath),
Username: username,
HasInstallationCode: installationCode != "",
InstallationCodeHint: installationCodeHint,
HideSupportLink: hideSupportLink,
Error: err,
Branding: s.binding.Branding.WebAdmin,
Branding: s.binding.webAdminBranding(),
}
renderAdminTemplate(w, templateSetup, data)
@ -876,7 +881,7 @@ func (s *httpdServer) renderAddUpdateAdminPage(w http.ResponseWriter, r *http.Re
title = util.I18nUpdateAdminTitle
}
data := adminPage{
basePage: s.getBasePageData(title, currentURL, r),
basePage: s.getBasePageData(title, currentURL, w, r),
Admin: admin,
Groups: groups,
Roles: roles,
@ -917,7 +922,7 @@ func (s *httpdServer) renderUserPage(w http.ResponseWriter, r *http.Request, use
}
}
user.FsConfig.RedactedSecret = redactedSecret
basePage := s.getBasePageData(title, currentURL, r)
basePage := s.getBasePageData(title, currentURL, w, r)
if (mode == userPageModeAdd || mode == userPageModeTemplate) && len(user.Groups) == 0 && admin != nil {
for _, group := range admin.Groups {
user.Groups = append(user.Groups, sdk.GroupMapping{
@ -982,7 +987,7 @@ func (s *httpdServer) renderIPListPage(w http.ResponseWriter, r *http.Request, e
currentURL = fmt.Sprintf("%s/%d/%s", webIPListPath, entry.Type, url.PathEscape(entry.IPOrNet))
}
data := ipListPage{
basePage: s.getBasePageData(title, currentURL, r),
basePage: s.getBasePageData(title, currentURL, w, r),
Error: getI18nError(err),
Entry: &entry,
Mode: mode,
@ -1003,7 +1008,7 @@ func (s *httpdServer) renderRolePage(w http.ResponseWriter, r *http.Request, rol
currentURL = fmt.Sprintf("%s/%s", webAdminRolePath, url.PathEscape(role.Name))
}
data := rolePage{
basePage: s.getBasePageData(title, currentURL, r),
basePage: s.getBasePageData(title, currentURL, w, r),
Error: getI18nError(err),
Role: &role,
Mode: mode,
@ -1033,7 +1038,7 @@ func (s *httpdServer) renderGroupPage(w http.ResponseWriter, r *http.Request, gr
group.UserSettings.FsConfig.SetEmptySecretsIfNil()
data := groupPage{
basePage: s.getBasePageData(title, currentURL, r),
basePage: s.getBasePageData(title, currentURL, w, r),
Error: getI18nError(err),
Group: &group,
Mode: mode,
@ -1078,7 +1083,7 @@ func (s *httpdServer) renderEventActionPage(w http.ResponseWriter, r *http.Reque
}
data := eventActionPage{
basePage: s.getBasePageData(title, currentURL, r),
basePage: s.getBasePageData(title, currentURL, w, r),
Action: action,
ActionTypes: dataprovider.EventActionTypes,
FsActions: dataprovider.FsActionTypes,
@ -1108,7 +1113,7 @@ func (s *httpdServer) renderEventRulePage(w http.ResponseWriter, r *http.Request
}
data := eventRulePage{
basePage: s.getBasePageData(title, currentURL, r),
basePage: s.getBasePageData(title, currentURL, w, r),
Rule: rule,
TriggerTypes: dataprovider.EventTriggerTypes,
Actions: actions,
@ -1142,7 +1147,7 @@ func (s *httpdServer) renderFolderPage(w http.ResponseWriter, r *http.Request, f
folder.FsConfig.SetEmptySecretsIfNil()
data := folderPage{
basePage: s.getBasePageData(title, currentURL, r),
basePage: s.getBasePageData(title, currentURL, w, r),
Error: getI18nError(err),
Folder: folder,
Mode: mode,
@ -1484,13 +1489,13 @@ func getFiltersFromUserPostFields(r *http.Request) (sdk.BaseUserFilters, error)
filters.PasswordStrength = passwordStrength
filters.AccessTime = getAccessTimeRestrictionsFromPostFields(r)
hooks := r.Form["hooks"]
if util.Contains(hooks, "external_auth_disabled") {
if slices.Contains(hooks, "external_auth_disabled") {
filters.Hooks.ExternalAuthDisabled = true
}
if util.Contains(hooks, "pre_login_disabled") {
if slices.Contains(hooks, "pre_login_disabled") {
filters.Hooks.PreLoginDisabled = true
}
if util.Contains(hooks, "check_password_disabled") {
if slices.Contains(hooks, "check_password_disabled") {
filters.Hooks.CheckPasswordDisabled = true
}
filters.IsAnonymous = r.Form.Get("is_anonymous") != ""
@ -1580,7 +1585,7 @@ func getGCSConfig(r *http.Request) (vfs.GCSFsConfig, error) {
config.AutomaticCredentials = 0
}
credentials, _, err := r.FormFile("gcs_credential_file")
if err == http.ErrMissingFile {
if errors.Is(err, http.ErrMissingFile) {
return config, nil
}
if err != nil {
@ -2211,7 +2216,7 @@ func getFoldersRetentionFromPostFields(r *http.Request) ([]dataprovider.FolderRe
res = append(res, dataprovider.FolderRetention{
Path: p,
Retention: retention,
DeleteEmptyDirs: util.Contains(opts, "1"),
DeleteEmptyDirs: slices.Contains(opts, "1"),
})
}
}
@ -2553,9 +2558,9 @@ func getEventRuleActionsFromPostFields(r *http.Request) []dataprovider.EventActi
},
Order: order + 1,
Options: dataprovider.EventActionOptions{
IsFailureAction: util.Contains(options, "1"),
StopOnFailure: util.Contains(options, "2"),
ExecuteSync: util.Contains(options, "3"),
IsFailureAction: slices.Contains(options, "1"),
StopOnFailure: slices.Contains(options, "2"),
ExecuteSync: slices.Contains(options, "3"),
},
})
}
@ -2758,31 +2763,95 @@ func getSMTPConfigsFromPostFields(r *http.Request) *dataprovider.SMTPConfigs {
}
}
func getImageInputBytes(r *http.Request, fieldName, removeFieldName string, defaultVal []byte) ([]byte, error) {
var result []byte
remove := r.Form.Get(removeFieldName)
if remove == "" || remove == "0" {
result = defaultVal
}
f, _, err := r.FormFile(fieldName)
if err != nil {
if errors.Is(err, http.ErrMissingFile) {
return result, nil
}
return nil, err
}
defer f.Close()
return io.ReadAll(f)
}
func getBrandingConfigFromPostFields(r *http.Request, config *dataprovider.BrandingConfigs) (
*dataprovider.BrandingConfigs, error,
) {
if config == nil {
config = &dataprovider.BrandingConfigs{}
}
adminLogo, err := getImageInputBytes(r, "branding_webadmin_logo", "branding_webadmin_logo_remove", config.WebAdmin.Logo)
if err != nil {
return nil, util.NewI18nError(err, util.I18nErrorInvalidForm)
}
adminFavicon, err := getImageInputBytes(r, "branding_webadmin_favicon", "branding_webadmin_favicon_remove",
config.WebAdmin.Favicon)
if err != nil {
return nil, util.NewI18nError(err, util.I18nErrorInvalidForm)
}
clientLogo, err := getImageInputBytes(r, "branding_webclient_logo", "branding_webclient_logo_remove",
config.WebClient.Logo)
if err != nil {
return nil, util.NewI18nError(err, util.I18nErrorInvalidForm)
}
clientFavicon, err := getImageInputBytes(r, "branding_webclient_favicon", "branding_webclient_favicon_remove",
config.WebClient.Favicon)
if err != nil {
return nil, util.NewI18nError(err, util.I18nErrorInvalidForm)
}
branding := &dataprovider.BrandingConfigs{
WebAdmin: dataprovider.BrandingConfig{
Name: strings.TrimSpace(r.Form.Get("branding_webadmin_name")),
ShortName: strings.TrimSpace(r.Form.Get("branding_webadmin_short_name")),
Logo: adminLogo,
Favicon: adminFavicon,
DisclaimerName: strings.TrimSpace(r.Form.Get("branding_webadmin_disclaimer_name")),
DisclaimerURL: strings.TrimSpace(r.Form.Get("branding_webadmin_disclaimer_url")),
},
WebClient: dataprovider.BrandingConfig{
Name: strings.TrimSpace(r.Form.Get("branding_webclient_name")),
ShortName: strings.TrimSpace(r.Form.Get("branding_webclient_short_name")),
Logo: clientLogo,
Favicon: clientFavicon,
DisclaimerName: strings.TrimSpace(r.Form.Get("branding_webclient_disclaimer_name")),
DisclaimerURL: strings.TrimSpace(r.Form.Get("branding_webclient_disclaimer_url")),
},
}
return branding, nil
}
func (s *httpdServer) handleWebAdminForgotPwd(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
if !smtp.IsEnabled() {
s.renderNotFoundPage(w, r, errors.New("this page does not exist"))
return
}
s.renderForgotPwdPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr))
s.renderForgotPwdPage(w, r, nil)
}
func (s *httpdServer) handleWebAdminForgotPwdPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
err := r.ParseForm()
if err != nil {
s.renderForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr)
s.renderForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return
}
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
err = handleForgotPassword(r, r.Form.Get("username"), true)
if err != nil {
s.renderForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorPwdResetGeneric), ipAddr)
s.renderForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorPwdResetGeneric))
return
}
http.Redirect(w, r, webAdminResetPwdPath, http.StatusFound)
@ -2794,17 +2863,17 @@ func (s *httpdServer) handleWebAdminPasswordReset(w http.ResponseWriter, r *http
s.renderNotFoundPage(w, r, errors.New("this page does not exist"))
return
}
s.renderResetPwdPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr))
s.renderResetPwdPage(w, r, nil)
}
func (s *httpdServer) handleWebAdminTwoFactor(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
s.renderTwoFactorPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr))
s.renderTwoFactorPage(w, r, nil)
}
func (s *httpdServer) handleWebAdminTwoFactorRecovery(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
s.renderTwoFactorRecoveryPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr))
s.renderTwoFactorRecoveryPage(w, r, nil)
}
func (s *httpdServer) handleWebAdminMFA(w http.ResponseWriter, r *http.Request) {
@ -2830,7 +2899,7 @@ func (s *httpdServer) handleWebAdminProfilePost(w http.ResponseWriter, r *http.R
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -2875,7 +2944,7 @@ func (s *httpdServer) handleWebRestore(w http.ResponseWriter, r *http.Request) {
defer r.MultipartForm.RemoveAll() //nolint:errcheck
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -2936,7 +3005,7 @@ func getAllAdmins(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) handleGetWebAdmins(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := s.getBasePageData(util.I18nAdminsTitle, webAdminsPath, r)
data := s.getBasePageData(util.I18nAdminsTitle, webAdminsPath, w, r)
renderAdminTemplate(w, templateAdmins, data)
}
@ -2946,7 +3015,7 @@ func (s *httpdServer) handleWebAdminSetupGet(w http.ResponseWriter, r *http.Requ
http.Redirect(w, r, webAdminLoginPath, http.StatusFound)
return
}
s.renderAdminSetupPage(w, r, "", util.GetIPFromRemoteAddress(r.RemoteAddr), nil)
s.renderAdminSetupPage(w, r, "", nil)
}
func (s *httpdServer) handleWebAddAdminGet(w http.ResponseWriter, r *http.Request) {
@ -2987,7 +3056,7 @@ func (s *httpdServer) handleWebAddAdminPost(w http.ResponseWriter, r *http.Reque
admin.Password = util.GenerateUniqueID()
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -3018,7 +3087,7 @@ func (s *httpdServer) handleWebUpdateAdminPost(w http.ResponseWriter, r *http.Re
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -3071,7 +3140,7 @@ func (s *httpdServer) handleWebUpdateAdminPost(w http.ResponseWriter, r *http.Re
func (s *httpdServer) handleWebDefenderPage(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := defenderHostsPage{
basePage: s.getBasePageData(util.I18nDefenderTitle, webDefenderPath, r),
basePage: s.getBasePageData(util.I18nDefenderTitle, webDefenderPath, w, r),
DefenderHostsURL: webDefenderHostsPath,
}
@ -3105,7 +3174,7 @@ func (s *httpdServer) handleGetWebUsers(w http.ResponseWriter, r *http.Request)
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
}
data := s.getBasePageData(util.I18nUsersTitle, webUsersPath, r)
data := s.getBasePageData(util.I18nUsersTitle, webUsersPath, w, r)
renderAdminTemplate(w, templateUsers, data)
}
@ -3144,7 +3213,7 @@ func (s *httpdServer) handleWebTemplateFolderPost(w http.ResponseWriter, r *http
defer r.MultipartForm.RemoveAll() //nolint:errcheck
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -3244,7 +3313,7 @@ func (s *httpdServer) handleWebTemplateUserPost(w http.ResponseWriter, r *http.R
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -3341,7 +3410,7 @@ func (s *httpdServer) handleWebAddUserPost(w http.ResponseWriter, r *http.Reques
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -3387,7 +3456,7 @@ func (s *httpdServer) handleWebUpdateUserPost(w http.ResponseWriter, r *http.Req
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -3425,7 +3494,7 @@ func (s *httpdServer) handleWebUpdateUserPost(w http.ResponseWriter, r *http.Req
func (s *httpdServer) handleWebGetStatus(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := statusPage{
basePage: s.getBasePageData(util.I18nStatusTitle, webStatusPath, r),
basePage: s.getBasePageData(util.I18nStatusTitle, webStatusPath, w, r),
Status: getServicesStatus(),
}
renderAdminTemplate(w, templateStatus, data)
@ -3439,7 +3508,7 @@ func (s *httpdServer) handleWebGetConnections(w http.ResponseWriter, r *http.Req
return
}
data := s.getBasePageData(util.I18nSessionsTitle, webConnectionsPath, r)
data := s.getBasePageData(util.I18nSessionsTitle, webConnectionsPath, w, r)
renderAdminTemplate(w, templateConnections, data)
}
@ -3464,7 +3533,7 @@ func (s *httpdServer) handleWebAddFolderPost(w http.ResponseWriter, r *http.Requ
defer r.MultipartForm.RemoveAll() //nolint:errcheck
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -3525,7 +3594,7 @@ func (s *httpdServer) handleWebUpdateFolderPost(w http.ResponseWriter, r *http.R
defer r.MultipartForm.RemoveAll() //nolint:errcheck
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -3588,7 +3657,7 @@ func getAllFolders(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) handleWebGetFolders(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := s.getBasePageData(util.I18nFoldersTitle, webFoldersPath, r)
data := s.getBasePageData(util.I18nFoldersTitle, webFoldersPath, w, r)
renderAdminTemplate(w, templateFolders, data)
}
@ -3626,7 +3695,7 @@ func getAllGroups(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) handleWebGetGroups(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := s.getBasePageData(util.I18nGroupsTitle, webGroupsPath, r)
data := s.getBasePageData(util.I18nGroupsTitle, webGroupsPath, w, r)
renderAdminTemplate(w, templateGroups, data)
}
@ -3648,7 +3717,7 @@ func (s *httpdServer) handleWebAddGroupPost(w http.ResponseWriter, r *http.Reque
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -3695,7 +3764,7 @@ func (s *httpdServer) handleWebUpdateGroupPost(w http.ResponseWriter, r *http.Re
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -3748,7 +3817,7 @@ func getAllActions(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) handleWebGetEventActions(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := s.getBasePageData(util.I18nActionsTitle, webAdminEventActionsPath, r)
data := s.getBasePageData(util.I18nActionsTitle, webAdminEventActionsPath, w, r)
renderAdminTemplate(w, templateEventActions, data)
}
@ -3773,7 +3842,7 @@ func (s *httpdServer) handleWebAddEventActionPost(w http.ResponseWriter, r *http
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -3819,7 +3888,7 @@ func (s *httpdServer) handleWebUpdateEventActionPost(w http.ResponseWriter, r *h
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -3858,7 +3927,7 @@ func getAllRules(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) handleWebGetEventRules(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := s.getBasePageData(util.I18nRulesTitle, webAdminEventRulesPath, r)
data := s.getBasePageData(util.I18nRulesTitle, webAdminEventRulesPath, w, r)
renderAdminTemplate(w, templateEventRules, data)
}
@ -3884,7 +3953,7 @@ func (s *httpdServer) handleWebAddEventRulePost(w http.ResponseWriter, r *http.R
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
err = verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr)
err = verifyCSRFToken(r, s.csrfTokenAuth)
if err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
@ -3931,7 +4000,7 @@ func (s *httpdServer) handleWebUpdateEventRulePost(w http.ResponseWriter, r *htt
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -3978,7 +4047,7 @@ func getAllRoles(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) handleWebGetRoles(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := s.getBasePageData(util.I18nRolesTitle, webAdminRolesPath, r)
data := s.getBasePageData(util.I18nRolesTitle, webAdminRolesPath, w, r)
renderAdminTemplate(w, templateRoles, data)
}
@ -4001,7 +4070,7 @@ func (s *httpdServer) handleWebAddRolePost(w http.ResponseWriter, r *http.Reques
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -4047,7 +4116,7 @@ func (s *httpdServer) handleWebUpdateRolePost(w http.ResponseWriter, r *http.Req
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -4065,7 +4134,7 @@ func (s *httpdServer) handleWebGetEvents(w http.ResponseWriter, r *http.Request)
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := eventsPage{
basePage: s.getBasePageData(util.I18nEventsTitle, webEventsPath, r),
basePage: s.getBasePageData(util.I18nEventsTitle, webEventsPath, w, r),
FsEventsSearchURL: webEventsFsSearchPath,
ProviderEventsSearchURL: webEventsProviderSearchPath,
LogEventsSearchURL: webEventsLogSearchPath,
@ -4077,7 +4146,7 @@ func (s *httpdServer) handleWebIPListsPage(w http.ResponseWriter, r *http.Reques
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
rtlStatus, rtlProtocols := common.Config.GetRateLimitersStatus()
data := ipListsPage{
basePage: s.getBasePageData(util.I18nIPListsTitle, webIPListsPath, r),
basePage: s.getBasePageData(util.I18nIPListsTitle, webIPListsPath, w, r),
RateLimitersStatus: rtlStatus,
RateLimitersProtocols: strings.Join(rtlProtocols, ", "),
IsAllowListEnabled: common.Config.IsAllowListEnabled(),
@ -4115,7 +4184,7 @@ func (s *httpdServer) handleWebAddIPListEntryPost(w http.ResponseWriter, r *http
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -4170,7 +4239,7 @@ func (s *httpdServer) handleWebUpdateIPListEntryPost(w http.ResponseWriter, r *h
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -4206,13 +4275,15 @@ func (s *httpdServer) handleWebConfigsPost(w http.ResponseWriter, r *http.Reques
s.renderInternalServerErrorPage(w, r, err)
return
}
err = r.ParseForm()
err = r.ParseMultipartForm(maxRequestSize)
if err != nil {
s.renderBadRequestPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return
}
defer r.MultipartForm.RemoveAll() //nolint:errcheck
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -4236,6 +4307,15 @@ func (s *httpdServer) handleWebConfigsPost(w http.ResponseWriter, r *http.Reques
smtpConfigs := getSMTPConfigsFromPostFields(r)
updateSMTPSecrets(smtpConfigs, configs.SMTP)
configs.SMTP = smtpConfigs
case "branding_submit":
configSection = 4
brandingConfigs, err := getBrandingConfigFromPostFields(r, configs.Branding)
if err != nil {
logger.Info(logSender, "", "unable to get branding config: %v", err)
s.renderConfigsPage(w, r, configs, err, configSection)
return
}
configs.Branding = brandingConfigs
default:
s.renderBadRequestPage(w, r, errors.New("unsupported form action"))
return
@ -4246,15 +4326,22 @@ func (s *httpdServer) handleWebConfigsPost(w http.ResponseWriter, r *http.Reques
s.renderConfigsPage(w, r, configs, err, configSection)
return
}
if configSection == 3 {
postConfigsUpdate(configSection, configs)
s.renderMessagePage(w, r, util.I18nConfigsTitle, http.StatusOK, nil, util.I18nConfigsOK)
}
func postConfigsUpdate(section int, configs dataprovider.Configs) {
switch section {
case 3:
err := configs.SMTP.TryDecrypt()
if err == nil {
smtp.Activate(configs.SMTP)
} else {
logger.Error(logSender, "", "unable to decrypt SMTP configuration, cannot activate configuration: %v", err)
}
case 4:
dbBrandingConfig.Set(configs.Branding)
}
s.renderMessagePage(w, r, util.I18nConfigsTitle, http.StatusOK, nil, util.I18nConfigsOK)
}
func (s *httpdServer) handleOAuth2TokenRedirect(w http.ResponseWriter, r *http.Request) {
@ -4262,7 +4349,7 @@ func (s *httpdServer) handleOAuth2TokenRedirect(w http.ResponseWriter, r *http.R
stateToken := r.URL.Query().Get("state")
state, err := verifyOAuth2Token(stateToken, util.GetIPFromRemoteAddress(r.RemoteAddr))
state, err := verifyOAuth2Token(s.csrfTokenAuth, stateToken, util.GetIPFromRemoteAddress(r.RemoteAddr))
if err != nil {
s.renderMessagePage(w, r, util.I18nOAuth2ErrorTitle, http.StatusBadRequest, err, "")
return

View file

@ -27,6 +27,7 @@ import (
"os"
"path"
"path/filepath"
"slices"
"strconv"
"strings"
"time"
@ -95,6 +96,7 @@ type baseClientPage struct {
MFAURL string
CSRFToken string
LoggedUser *dataprovider.User
IsLoggedToShare bool
Branding UIBranding
}
@ -523,10 +525,10 @@ func loadClientTemplates(templatesPath string) {
clientTemplates[templateShareDownload] = shareDownloadTmpl
}
func (s *httpdServer) getBaseClientPageData(title, currentURL string, r *http.Request) baseClientPage {
func (s *httpdServer) getBaseClientPageData(title, currentURL string, w http.ResponseWriter, r *http.Request) baseClientPage {
var csrfToken string
if currentURL != "" {
csrfToken = createCSRFToken(util.GetIPFromRemoteAddress(r.RemoteAddr))
csrfToken = createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath)
}
data := baseClientPage{
@ -544,7 +546,8 @@ func (s *httpdServer) getBaseClientPageData(title, currentURL string, r *http.Re
MFAURL: webClientMFAPath,
CSRFToken: csrfToken,
LoggedUser: getUserFromToken(r),
Branding: s.binding.Branding.WebClient,
IsLoggedToShare: false,
Branding: s.binding.webClientBranding(),
}
if !strings.HasPrefix(r.RequestURI, webClientPubSharesPath) {
data.LoginURL = webClientLoginPath
@ -552,40 +555,40 @@ func (s *httpdServer) getBaseClientPageData(title, currentURL string, r *http.Re
return data
}
func (s *httpdServer) renderClientForgotPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) {
func (s *httpdServer) renderClientForgotPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := forgotPwdPage{
commonBasePage: getCommonBasePage(r),
CurrentURL: webClientForgotPwdPath,
Error: err,
CSRFToken: createCSRFToken(ip),
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseClientPath),
LoginURL: webClientLoginPath,
Title: util.I18nForgotPwdTitle,
Branding: s.binding.Branding.WebClient,
Branding: s.binding.webClientBranding(),
}
renderClientTemplate(w, templateForgotPassword, data)
}
func (s *httpdServer) renderClientResetPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) {
func (s *httpdServer) renderClientResetPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := resetPwdPage{
commonBasePage: getCommonBasePage(r),
CurrentURL: webClientResetPwdPath,
Error: err,
CSRFToken: createCSRFToken(ip),
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath),
LoginURL: webClientLoginPath,
Title: util.I18nResetPwdTitle,
Branding: s.binding.Branding.WebClient,
Branding: s.binding.webClientBranding(),
}
renderClientTemplate(w, templateResetPassword, data)
}
func (s *httpdServer) renderShareLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) {
func (s *httpdServer) renderShareLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := shareLoginPage{
commonBasePage: getCommonBasePage(r),
Title: util.I18nShareLoginTitle,
CurrentURL: r.RequestURI,
Error: err,
CSRFToken: createCSRFToken(ip),
Branding: s.binding.Branding.WebClient,
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseClientPath),
Branding: s.binding.webClientBranding(),
}
renderClientTemplate(w, templateShareLogin, data)
}
@ -599,7 +602,7 @@ func renderClientTemplate(w http.ResponseWriter, tmplName string, data any) {
func (s *httpdServer) renderClientMessagePage(w http.ResponseWriter, r *http.Request, title string, statusCode int, err error, message string) {
data := clientMessagePage{
baseClientPage: s.getBaseClientPageData(title, "", r),
baseClientPage: s.getBaseClientPageData(title, "", w, r),
Error: getI18nError(err),
Success: message,
}
@ -627,15 +630,15 @@ func (s *httpdServer) renderClientNotFoundPage(w http.ResponseWriter, r *http.Re
util.NewI18nError(err, util.I18nError404Message), "")
}
func (s *httpdServer) renderClientTwoFactorPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) {
func (s *httpdServer) renderClientTwoFactorPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := twoFactorPage{
commonBasePage: getCommonBasePage(r),
Title: pageTwoFactorTitle,
CurrentURL: webClientTwoFactorPath,
Error: err,
CSRFToken: createCSRFToken(ip),
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath),
RecoveryURL: webClientTwoFactorRecoveryPath,
Branding: s.binding.Branding.WebClient,
Branding: s.binding.webClientBranding(),
}
if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) {
data.CurrentURL += "?next=" + url.QueryEscape(next)
@ -643,21 +646,21 @@ func (s *httpdServer) renderClientTwoFactorPage(w http.ResponseWriter, r *http.R
renderClientTemplate(w, templateTwoFactor, data)
}
func (s *httpdServer) renderClientTwoFactorRecoveryPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) {
func (s *httpdServer) renderClientTwoFactorRecoveryPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := twoFactorPage{
commonBasePage: getCommonBasePage(r),
Title: pageTwoFactorRecoveryTitle,
CurrentURL: webClientTwoFactorRecoveryPath,
Error: err,
CSRFToken: createCSRFToken(ip),
Branding: s.binding.Branding.WebClient,
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath),
Branding: s.binding.webClientBranding(),
}
renderClientTemplate(w, templateTwoFactorRecovery, data)
}
func (s *httpdServer) renderClientMFAPage(w http.ResponseWriter, r *http.Request) {
data := clientMFAPage{
baseClientPage: s.getBaseClientPageData(util.I18n2FATitle, webClientMFAPath, r),
baseClientPage: s.getBaseClientPageData(util.I18n2FATitle, webClientMFAPath, w, r),
TOTPConfigs: mfa.GetAvailableTOTPConfigNames(),
GenerateTOTPURL: webClientTOTPGeneratePath,
ValidateTOTPURL: webClientTOTPValidatePath,
@ -681,7 +684,7 @@ func (s *httpdServer) renderEditFilePage(w http.ResponseWriter, r *http.Request,
title = util.I18nEditFileTitle
}
data := editFilePage{
baseClientPage: s.getBaseClientPageData(title, webClientEditFilePath, r),
baseClientPage: s.getBaseClientPageData(title, webClientEditFilePath, w, r),
Path: fileName,
Name: path.Base(fileName),
CurrentDir: path.Dir(fileName),
@ -702,7 +705,7 @@ func (s *httpdServer) renderAddUpdateSharePage(w http.ResponseWriter, r *http.Re
title = util.I18nShareUpdateTitle
}
data := clientSharePage{
baseClientPage: s.getBaseClientPageData(title, currentURL, r),
baseClientPage: s.getBaseClientPageData(title, currentURL, w, r),
Share: share,
Error: err,
IsAdd: isAdd,
@ -736,9 +739,11 @@ func (s *httpdServer) renderSharedFilesPage(w http.ResponseWriter, r *http.Reque
err *util.I18nError, share dataprovider.Share,
) {
currentURL := path.Join(webClientPubSharesPath, share.ShareID, "browse")
baseData := s.getBaseClientPageData(util.I18nSharedFilesTitle, currentURL, r)
baseData := s.getBaseClientPageData(util.I18nSharedFilesTitle, currentURL, w, r)
baseData.FilesURL = currentURL
baseSharePath := path.Join(webClientPubSharesPath, share.ShareID)
baseData.LogoutURL = path.Join(webClientPubSharesPath, share.ShareID, "logout")
baseData.IsLoggedToShare = share.Password != ""
data := filesPage{
baseClientPage: baseData,
@ -766,28 +771,39 @@ func (s *httpdServer) renderSharedFilesPage(w http.ResponseWriter, r *http.Reque
renderClientTemplate(w, templateClientFiles, data)
}
func (s *httpdServer) renderShareDownloadPage(w http.ResponseWriter, r *http.Request, downloadLink string) {
func (s *httpdServer) renderShareDownloadPage(w http.ResponseWriter, r *http.Request, share *dataprovider.Share,
downloadLink string,
) {
data := shareDownloadPage{
baseClientPage: s.getBaseClientPageData(util.I18nShareDownloadTitle, "", r),
baseClientPage: s.getBaseClientPageData(util.I18nShareDownloadTitle, "", w, r),
DownloadLink: downloadLink,
}
data.LogoutURL = ""
if share.Password != "" {
data.LogoutURL = path.Join(webClientPubSharesPath, share.ShareID, "logout")
}
renderClientTemplate(w, templateShareDownload, data)
}
func (s *httpdServer) renderUploadToSharePage(w http.ResponseWriter, r *http.Request, share dataprovider.Share) {
func (s *httpdServer) renderUploadToSharePage(w http.ResponseWriter, r *http.Request, share *dataprovider.Share) {
currentURL := path.Join(webClientPubSharesPath, share.ShareID, "upload")
data := shareUploadPage{
baseClientPage: s.getBaseClientPageData(util.I18nShareUploadTitle, currentURL, r),
Share: &share,
baseClientPage: s.getBaseClientPageData(util.I18nShareUploadTitle, currentURL, w, r),
Share: share,
UploadBasePath: path.Join(webClientPubSharesPath, share.ShareID),
}
data.LogoutURL = ""
if share.Password != "" {
data.LogoutURL = path.Join(webClientPubSharesPath, share.ShareID, "logout")
}
renderClientTemplate(w, templateUploadToShare, data)
}
func (s *httpdServer) renderFilesPage(w http.ResponseWriter, r *http.Request, dirName string,
err *util.I18nError, user *dataprovider.User) {
data := filesPage{
baseClientPage: s.getBaseClientPageData(util.I18nFilesTitle, webClientFilesPath, r),
baseClientPage: s.getBaseClientPageData(util.I18nFilesTitle, webClientFilesPath, w, r),
Error: err,
CurrentDir: url.QueryEscape(dirName),
DownloadURL: webClientDownloadZipPath,
@ -813,7 +829,7 @@ func (s *httpdServer) renderFilesPage(w http.ResponseWriter, r *http.Request, di
func (s *httpdServer) renderClientProfilePage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := clientProfilePage{
baseClientPage: s.getBaseClientPageData(util.I18nProfileTitle, webClientProfilePath, r),
baseClientPage: s.getBaseClientPageData(util.I18nProfileTitle, webClientProfilePath, w, r),
Error: err,
}
user, userMerged, errUser := dataprovider.GetUserVariants(data.LoggedUser.Username, "")
@ -832,7 +848,7 @@ func (s *httpdServer) renderClientProfilePage(w http.ResponseWriter, r *http.Req
func (s *httpdServer) renderClientChangePasswordPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := changeClientPasswordPage{
baseClientPage: s.getBaseClientPageData(util.I18nChangePwdTitle, webChangeClientPwdPath, r),
baseClientPage: s.getBaseClientPageData(util.I18nChangePwdTitle, webChangeClientPwdPath, w, r),
Error: err,
}
@ -850,8 +866,7 @@ func (s *httpdServer) handleWebClientDownloadZip(w http.ResponseWriter, r *http.
s.renderClientBadRequestPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -1023,7 +1038,7 @@ func (s *httpdServer) handleClientUploadToShare(w http.ResponseWriter, r *http.R
http.Redirect(w, r, path.Join(webClientPubSharesPath, share.ShareID, "browse"), http.StatusFound)
return
}
s.renderUploadToSharePage(w, r, share)
s.renderUploadToSharePage(w, r, &share)
}
func (s *httpdServer) handleShareGetFiles(w http.ResponseWriter, r *http.Request) {
@ -1088,7 +1103,7 @@ func (s *httpdServer) handleShareViewPDF(w http.ResponseWriter, r *http.Request)
Title: path.Base(name),
URL: fmt.Sprintf("%s?path=%s&_=%d", path.Join(webClientPubSharesPath, share.ShareID, "getpdf"),
url.QueryEscape(name), time.Now().UTC().Unix()),
Branding: s.binding.Branding.WebClient,
Branding: s.binding.webClientBranding(),
}
renderClientTemplate(w, templateClientViewPDF, data)
}
@ -1440,7 +1455,7 @@ func (s *httpdServer) handleClientAddSharePost(w http.ResponseWriter, r *http.Re
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -1449,7 +1464,7 @@ func (s *httpdServer) handleClientAddSharePost(w http.ResponseWriter, r *http.Re
share.LastUseAt = 0
share.Username = claims.Username
if share.Password == "" {
if util.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) {
if slices.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) {
s.renderAddUpdateSharePage(w, r, share,
util.NewI18nError(util.NewValidationError("You are not allowed to share files/folders without password"), util.I18nErrorShareNoPwd),
true)
@ -1508,7 +1523,7 @@ func (s *httpdServer) handleClientUpdateSharePost(w http.ResponseWriter, r *http
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -1518,7 +1533,7 @@ func (s *httpdServer) handleClientUpdateSharePost(w http.ResponseWriter, r *http
updatedShare.Password = share.Password
}
if updatedShare.Password == "" {
if util.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) {
if slices.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) {
s.renderAddUpdateSharePage(w, r, updatedShare,
util.NewI18nError(util.NewValidationError("You are not allowed to share files/folders without password"), util.I18nErrorShareNoPwd),
false)
@ -1579,7 +1594,7 @@ func (s *httpdServer) handleClientGetShares(w http.ResponseWriter, r *http.Reque
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := clientSharesPage{
baseClientPage: s.getBaseClientPageData(util.I18nSharesTitle, webClientSharesPath, r),
baseClientPage: s.getBaseClientPageData(util.I18nSharesTitle, webClientSharesPath, w, r),
BasePublicSharesURL: webClientPubSharesPath,
}
renderClientTemplate(w, templateClientShares, data)
@ -1603,7 +1618,7 @@ func (s *httpdServer) handleWebClientProfilePost(w http.ResponseWriter, r *http.
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@ -1662,12 +1677,12 @@ func (s *httpdServer) handleWebClientMFA(w http.ResponseWriter, r *http.Request)
func (s *httpdServer) handleWebClientTwoFactor(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
s.renderClientTwoFactorPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr))
s.renderClientTwoFactorPage(w, r, nil)
}
func (s *httpdServer) handleWebClientTwoFactorRecovery(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
s.renderClientTwoFactorRecoveryPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr))
s.renderClientTwoFactorRecoveryPage(w, r, nil)
}
func getShareFromPostFields(r *http.Request) (*dataprovider.Share, error) {
@ -1719,26 +1734,25 @@ func (s *httpdServer) handleWebClientForgotPwd(w http.ResponseWriter, r *http.Re
s.renderClientNotFoundPage(w, r, errors.New("this page does not exist"))
return
}
s.renderClientForgotPwdPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr))
s.renderClientForgotPwdPage(w, r, nil)
}
func (s *httpdServer) handleWebClientForgotPwdPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
err := r.ParseForm()
if err != nil {
s.renderClientForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr)
s.renderClientForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return
}
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
username := strings.TrimSpace(r.Form.Get("username"))
err = handleForgotPassword(r, username, false)
if err != nil {
s.renderClientForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorPwdResetGeneric), ipAddr)
s.renderClientForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorPwdResetGeneric))
return
}
http.Redirect(w, r, webClientResetPwdPath, http.StatusFound)
@ -1750,7 +1764,7 @@ func (s *httpdServer) handleWebClientPasswordReset(w http.ResponseWriter, r *htt
s.renderClientNotFoundPage(w, r, errors.New("this page does not exist"))
return
}
s.renderClientResetPwdPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr))
s.renderClientResetPwdPage(w, r, nil)
}
func (s *httpdServer) handleClientViewPDF(w http.ResponseWriter, r *http.Request) {
@ -1765,7 +1779,7 @@ func (s *httpdServer) handleClientViewPDF(w http.ResponseWriter, r *http.Request
commonBasePage: getCommonBasePage(r),
Title: path.Base(name),
URL: fmt.Sprintf("%s?path=%s&_=%d", webClientGetPDFPath, url.QueryEscape(name), time.Now().UTC().Unix()),
Branding: s.binding.Branding.WebClient,
Branding: s.binding.webClientBranding(),
}
renderClientTemplate(w, templateClientViewPDF, data)
}
@ -1853,43 +1867,46 @@ func (s *httpdServer) ensurePDF(w http.ResponseWriter, r *http.Request, name str
func (s *httpdServer) handleClientShareLoginGet(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
s.renderShareLoginPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr))
s.renderShareLoginPage(w, r, nil)
}
func (s *httpdServer) handleClientShareLoginPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := r.ParseForm(); err != nil {
s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr)
s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return
}
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr)
if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
invalidateToken(r, true)
shareID := getURLParam(r, "id")
share, err := dataprovider.ShareExists(shareID, "")
if err != nil {
s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials), ipAddr)
s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials))
return
}
match, err := share.CheckCredentials(strings.TrimSpace(r.Form.Get("share_password")))
if !match || err != nil {
s.renderShareLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials),
ipAddr)
return
}
c := jwtTokenClaims{
Username: shareID,
}
err = c.createAndSetCookie(w, r, s.tokenAuth, tokenAudienceWebShare, ipAddr)
if err != nil {
s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nError500Message), ipAddr)
s.renderShareLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return
}
next := path.Clean(r.URL.Query().Get("next"))
baseShareURL := path.Join(webClientPubSharesPath, share.ShareID)
isRedirect, redirectTo := checkShareRedirectURL(next, baseShareURL)
c := jwtTokenClaims{
Username: shareID,
}
if isRedirect {
c.Ref = next
}
err = c.createAndSetCookie(w, r, s.tokenAuth, tokenAudienceWebShare, ipAddr)
if err != nil {
s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nError500Message))
return
}
if isRedirect {
http.Redirect(w, r, redirectTo, http.StatusFound)
return
@ -1897,6 +1914,22 @@ func (s *httpdServer) handleClientShareLoginPost(w http.ResponseWriter, r *http.
s.renderClientMessagePage(w, r, util.I18nSharedFilesTitle, http.StatusOK, nil, util.I18nShareLoginOK)
}
func (s *httpdServer) handleClientShareLogout(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
shareID := getURLParam(r, "id")
claims, err := s.getShareClaims(r, shareID)
if err != nil {
s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, http.StatusForbidden,
util.NewI18nError(err, util.I18nErrorInvalidToken), "")
return
}
removeCookie(w, r, webBaseClientPath)
redirectURL := path.Join(webClientPubSharesPath, shareID, fmt.Sprintf("login?next=%s", url.QueryEscape(claims.Ref)))
http.Redirect(w, r, redirectURL, http.StatusFound)
}
func (s *httpdServer) handleClientSharedFile(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeRead}
@ -1908,7 +1941,7 @@ func (s *httpdServer) handleClientSharedFile(w http.ResponseWriter, r *http.Requ
if r.URL.RawQuery != "" {
query = "?" + r.URL.RawQuery
}
s.renderShareDownloadPage(w, r, path.Join(webClientPubSharesPath, share.ShareID)+query)
s.renderShareDownloadPage(w, r, &share, path.Join(webClientPubSharesPath, share.ShareID)+query)
}
func (s *httpdServer) handleClientCheckExist(w http.ResponseWriter, r *http.Request) {
@ -1983,7 +2016,7 @@ func doCheckExist(w http.ResponseWriter, r *http.Request, connection *Connection
}
existing := make([]map[string]any, 0)
for _, info := range contents {
if util.Contains(filesList.Files, info.Name()) {
if slices.Contains(filesList.Files, info.Name()) {
res := make(map[string]any)
res["name"] = info.Name()
if info.IsDir() {

View file

@ -25,6 +25,7 @@ import (
"net/http"
"net/url"
"path"
"slices"
"strconv"
"strings"
@ -36,7 +37,6 @@ import (
"github.com/drakkan/sftpgo/v2/internal/httpclient"
"github.com/drakkan/sftpgo/v2/internal/httpd"
"github.com/drakkan/sftpgo/v2/internal/kms"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/version"
"github.com/drakkan/sftpgo/v2/internal/vfs"
)
@ -1679,7 +1679,7 @@ func checkEventConditionOptions(expected, actual dataprovider.ConditionOptions)
return errors.New("condition protocols mismatch")
}
for _, v := range expected.Protocols {
if !util.Contains(actual.Protocols, v) {
if !slices.Contains(actual.Protocols, v) {
return errors.New("condition protocols content mismatch")
}
}
@ -1687,7 +1687,7 @@ func checkEventConditionOptions(expected, actual dataprovider.ConditionOptions)
return errors.New("condition provider objects mismatch")
}
for _, v := range expected.ProviderObjects {
if !util.Contains(actual.ProviderObjects, v) {
if !slices.Contains(actual.ProviderObjects, v) {
return errors.New("condition provider objects content mismatch")
}
}
@ -1705,7 +1705,7 @@ func checkEventConditions(expected, actual dataprovider.EventConditions) error {
return errors.New("fs events mismatch")
}
for _, v := range expected.FsEvents {
if !util.Contains(actual.FsEvents, v) {
if !slices.Contains(actual.FsEvents, v) {
return errors.New("fs events content mismatch")
}
}
@ -1713,7 +1713,7 @@ func checkEventConditions(expected, actual dataprovider.EventConditions) error {
return errors.New("provider events mismatch")
}
for _, v := range expected.ProviderEvents {
if !util.Contains(actual.ProviderEvents, v) {
if !slices.Contains(actual.ProviderEvents, v) {
return errors.New("provider events content mismatch")
}
}
@ -1948,7 +1948,7 @@ func checkAdmin(expected, actual *dataprovider.Admin) error {
return errors.New("permissions mismatch")
}
for _, p := range expected.Permissions {
if !util.Contains(actual.Permissions, p) {
if !slices.Contains(actual.Permissions, p) {
return errors.New("permissions content mismatch")
}
}
@ -1966,7 +1966,7 @@ func compareAdminFilters(expected, actual dataprovider.AdminFilters) error {
return errors.New("allow list mismatch")
}
for _, v := range expected.AllowList {
if !util.Contains(actual.AllowList, v) {
if !slices.Contains(actual.AllowList, v) {
return errors.New("allow list content mismatch")
}
}
@ -2057,7 +2057,7 @@ func compareUserPermissions(expected map[string][]string, actual map[string][]st
for dir, perms := range expected {
if actualPerms, ok := actual[dir]; ok {
for _, v := range actualPerms {
if !util.Contains(perms, v) {
if !slices.Contains(perms, v) {
return errors.New("permissions contents mismatch")
}
}
@ -2310,7 +2310,7 @@ func compareSFTPFsConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error
return errors.New("SFTPFs fingerprints mismatch")
}
for _, value := range actual.SFTPConfig.Fingerprints {
if !util.Contains(expected.SFTPConfig.Fingerprints, value) {
if !slices.Contains(expected.SFTPConfig.Fingerprints, value) {
return errors.New("SFTPFs fingerprints mismatch")
}
}
@ -2401,27 +2401,27 @@ func checkEncryptedSecret(expected, actual *kms.Secret) error {
func compareUserFilterSubStructs(expected sdk.BaseUserFilters, actual sdk.BaseUserFilters) error {
for _, IPMask := range expected.AllowedIP {
if !util.Contains(actual.AllowedIP, IPMask) {
if !slices.Contains(actual.AllowedIP, IPMask) {
return errors.New("allowed IP contents mismatch")
}
}
for _, IPMask := range expected.DeniedIP {
if !util.Contains(actual.DeniedIP, IPMask) {
if !slices.Contains(actual.DeniedIP, IPMask) {
return errors.New("denied IP contents mismatch")
}
}
for _, method := range expected.DeniedLoginMethods {
if !util.Contains(actual.DeniedLoginMethods, method) {
if !slices.Contains(actual.DeniedLoginMethods, method) {
return errors.New("denied login methods contents mismatch")
}
}
for _, protocol := range expected.DeniedProtocols {
if !util.Contains(actual.DeniedProtocols, protocol) {
if !slices.Contains(actual.DeniedProtocols, protocol) {
return errors.New("denied protocols contents mismatch")
}
}
for _, options := range expected.WebClient {
if !util.Contains(actual.WebClient, options) {
if !slices.Contains(actual.WebClient, options) {
return errors.New("web client options contents mismatch")
}
}
@ -2430,7 +2430,7 @@ func compareUserFilterSubStructs(expected sdk.BaseUserFilters, actual sdk.BaseUs
return errors.New("TLS certs mismatch")
}
for _, cert := range expected.TLSCerts {
if !util.Contains(actual.TLSCerts, cert) {
if !slices.Contains(actual.TLSCerts, cert) {
return errors.New("TLS certs content mismatch")
}
}
@ -2527,7 +2527,7 @@ func checkFilterMatch(expected []string, actual []string) bool {
return false
}
for _, e := range expected {
if !util.Contains(actual, strings.ToLower(e)) {
if !slices.Contains(actual, strings.ToLower(e)) {
return false
}
}
@ -2570,7 +2570,7 @@ func compareUserBandwidthLimitFilters(expected sdk.BaseUserFilters, actual sdk.B
return errors.New("bandwidth filters sources mismatch")
}
for _, source := range actual.BandwidthLimits[idx].Sources {
if !util.Contains(l.Sources, source) {
if !slices.Contains(l.Sources, source) {
return errors.New("bandwidth filters source mismatch")
}
}
@ -2680,7 +2680,7 @@ func compareEventActionEmailConfigFields(expected, actual dataprovider.EventActi
return errors.New("email recipients mismatch")
}
for _, v := range expected.Recipients {
if !util.Contains(actual.Recipients, v) {
if !slices.Contains(actual.Recipients, v) {
return errors.New("email recipients content mismatch")
}
}
@ -2688,7 +2688,7 @@ func compareEventActionEmailConfigFields(expected, actual dataprovider.EventActi
return errors.New("email bcc mismatch")
}
for _, v := range expected.Bcc {
if !util.Contains(actual.Bcc, v) {
if !slices.Contains(actual.Bcc, v) {
return errors.New("email bcc content mismatch")
}
}
@ -2705,7 +2705,7 @@ func compareEventActionEmailConfigFields(expected, actual dataprovider.EventActi
return errors.New("email attachments mismatch")
}
for _, v := range expected.Attachments {
if !util.Contains(actual.Attachments, v) {
if !slices.Contains(actual.Attachments, v) {
return errors.New("email attachments content mismatch")
}
}
@ -2720,7 +2720,7 @@ func compareEventActionFsCompressFields(expected, actual dataprovider.EventActio
return errors.New("fs compress paths mismatch")
}
for _, v := range expected.Paths {
if !util.Contains(actual.Paths, v) {
if !slices.Contains(actual.Paths, v) {
return errors.New("fs compress paths content mismatch")
}
}
@ -2741,7 +2741,7 @@ func compareEventActionFsConfigFields(expected, actual dataprovider.EventActionF
return errors.New("fs deletes mismatch")
}
for _, v := range expected.Deletes {
if !util.Contains(actual.Deletes, v) {
if !slices.Contains(actual.Deletes, v) {
return errors.New("fs deletes content mismatch")
}
}
@ -2749,7 +2749,7 @@ func compareEventActionFsConfigFields(expected, actual dataprovider.EventActionF
return errors.New("fs mkdirs mismatch")
}
for _, v := range expected.MkDirs {
if !util.Contains(actual.MkDirs, v) {
if !slices.Contains(actual.MkDirs, v) {
return errors.New("fs mkdir content mismatch")
}
}
@ -2757,7 +2757,7 @@ func compareEventActionFsConfigFields(expected, actual dataprovider.EventActionF
return errors.New("fs exist mismatch")
}
for _, v := range expected.Exist {
if !util.Contains(actual.Exist, v) {
if !slices.Contains(actual.Exist, v) {
return errors.New("fs exist content mismatch")
}
}
@ -2788,7 +2788,7 @@ func compareEventActionCmdConfigFields(expected, actual dataprovider.EventAction
return errors.New("cmd args mismatch")
}
for _, v := range expected.Args {
if !util.Contains(actual.Args, v) {
if !slices.Contains(actual.Args, v) {
return errors.New("cmd args content mismatch")
}
}

View file

@ -17,6 +17,7 @@ package plugin
import (
"fmt"
"path/filepath"
"slices"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-plugin"
@ -25,7 +26,6 @@ import (
"github.com/drakkan/sftpgo/v2/internal/kms"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
)
var (
@ -41,10 +41,10 @@ type KMSConfig struct {
}
func (c *KMSConfig) validate() error {
if !util.Contains(validKMSSchemes, c.Scheme) {
if !slices.Contains(validKMSSchemes, c.Scheme) {
return fmt.Errorf("invalid kms scheme: %v", c.Scheme)
}
if !util.Contains(validKMSEncryptedStatuses, c.EncryptedStatus) {
if !slices.Contains(validKMSEncryptedStatuses, c.EncryptedStatus) {
return fmt.Errorf("invalid kms encrypted status: %v", c.EncryptedStatus)
}
return nil

View file

@ -16,6 +16,7 @@ package plugin
import (
"fmt"
"slices"
"sync"
"time"
@ -24,7 +25,6 @@ import (
"github.com/sftpgo/sdk/plugin/notifier"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
)
// NotifierConfig defines configuration parameters for notifiers plugins
@ -220,7 +220,7 @@ func (p *notifierPlugin) canQueueEvent(timestamp int64) bool {
}
func (p *notifierPlugin) notifyFsAction(event *notifier.FsEvent) {
if !util.Contains(p.config.NotifierOptions.FsEvents, event.Action) {
if !slices.Contains(p.config.NotifierOptions.FsEvents, event.Action) {
return
}
@ -233,8 +233,8 @@ func (p *notifierPlugin) notifyFsAction(event *notifier.FsEvent) {
}
func (p *notifierPlugin) notifyProviderAction(event *notifier.ProviderEvent, object Renderer) {
if !util.Contains(p.config.NotifierOptions.ProviderEvents, event.Action) ||
!util.Contains(p.config.NotifierOptions.ProviderObjects, event.ObjectType) {
if !slices.Contains(p.config.NotifierOptions.ProviderEvents, event.Action) ||
!slices.Contains(p.config.NotifierOptions.ProviderObjects, event.ObjectType) {
return
}

View file

@ -24,6 +24,7 @@ import (
"os"
"os/exec"
"path/filepath"
"slices"
"strings"
"sync"
"sync/atomic"
@ -336,7 +337,7 @@ func (m *Manager) NotifyLogEvent(event notifier.LogEventType, protocol, username
var e *notifier.LogEvent
for _, n := range m.notifiers {
if util.Contains(n.config.NotifierOptions.LogEvents, int(event)) {
if slices.Contains(n.config.NotifierOptions.LogEvents, int(event)) {
if e == nil {
message := ""
if err != nil {

View file

@ -20,6 +20,7 @@ package service
import (
"fmt"
"math/rand"
"slices"
"strings"
"github.com/sftpgo/sdk"
@ -211,7 +212,7 @@ func configurePortableSFTPService(port int, enabledSSHCommands []string) {
} else {
sftpdConf.Bindings[0].Port = 0
}
if util.Contains(enabledSSHCommands, "*") {
if slices.Contains(enabledSSHCommands, "*") {
sftpdConf.EnabledSSHCommands = sftpd.GetSupportedSSHCommands()
} else {
sftpdConf.EnabledSSHCommands = enabledSSHCommands

View file

@ -221,9 +221,9 @@ func (c *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
}
modTime := time.Unix(0, 0)
if request.Filepath != "/" {
lister.Add(vfs.NewFileInfo("..", true, 0, modTime, false))
lister.Prepend(vfs.NewFileInfo("..", true, 0, modTime, false))
}
lister.Add(vfs.NewFileInfo(".", true, 0, modTime, false))
lister.Prepend(vfs.NewFileInfo(".", true, 0, modTime, false))
return lister, nil
case "Stat":
if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(request.Filepath)) {
@ -559,13 +559,10 @@ func (c *Connection) getStatVFSFromQuotaResult(fs vfs.Fs, name string, quotaResu
func (c *Connection) updateQuotaAfterTruncate(requestPath string, fileSize int64) {
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
if err == nil {
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
if vfolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -fileSize, false)
return
}
} else {
dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
}
}
func getOSOpenFlags(requestFlags sftp.FileOpenFlags) (flags int) {

View file

@ -24,6 +24,7 @@ import (
"os"
"path/filepath"
"runtime"
"slices"
"testing"
"time"
@ -418,7 +419,7 @@ func TestSupportedSSHCommands(t *testing.T) {
assert.Equal(t, len(supportedSSHCommands), len(cmds))
for _, c := range cmds {
assert.True(t, util.Contains(supportedSSHCommands, c))
assert.True(t, slices.Contains(supportedSSHCommands, c))
}
}
@ -842,7 +843,7 @@ func TestRsyncOptions(t *testing.T) {
}
cmd, err := sshCmd.getSystemCommand()
assert.NoError(t, err)
assert.True(t, util.Contains(cmd.cmd.Args, "--safe-links"),
assert.True(t, slices.Contains(cmd.cmd.Args, "--safe-links"),
"--safe-links must be added if the user has the create symlinks permission")
permissions["/"] = []string{dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermCreateDirs,
@ -859,7 +860,7 @@ func TestRsyncOptions(t *testing.T) {
}
cmd, err = sshCmd.getSystemCommand()
assert.NoError(t, err)
assert.True(t, util.Contains(cmd.cmd.Args, "--munge-links"),
assert.True(t, slices.Contains(cmd.cmd.Args, "--munge-links"),
"--munge-links must be added if the user has the create symlinks permission")
sshCmd.connection.User.VirtualFolders = append(sshCmd.connection.User.VirtualFolders, vfs.VirtualFolder{

View file

@ -258,10 +258,7 @@ func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string,
if vfs.HasTruncateSupport(fs) {
vfolder, err := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath))
if err == nil {
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
if vfolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(&c.connection.User, 0, -fileSize, false) //nolint:errcheck
}
dataprovider.UpdateUserFolderQuota(&vfolder, &c.connection.User, 0, -fileSize, false)
} else {
dataprovider.UpdateUserQuota(&c.connection.User, 0, -fileSize, false) //nolint:errcheck
}

View file

@ -26,6 +26,7 @@ import (
"os"
"path/filepath"
"runtime/debug"
"slices"
"strings"
"sync"
"time"
@ -263,13 +264,13 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
func (c *Configuration) updateSupportedAuthentications() {
serviceStatus.Authentications = util.RemoveDuplicates(serviceStatus.Authentications, false)
if util.Contains(serviceStatus.Authentications, dataprovider.LoginMethodPassword) &&
util.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodPublicKey) {
if slices.Contains(serviceStatus.Authentications, dataprovider.LoginMethodPassword) &&
slices.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodPublicKey) {
serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyAndPassword)
}
if util.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyboardInteractive) &&
util.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodPublicKey) {
if slices.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyboardInteractive) &&
slices.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodPublicKey) {
serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyAndKeyboardInt)
}
}
@ -422,7 +423,7 @@ func (c *Configuration) configureKeyAlgos(serverConfig *ssh.ServerConfig) error
c.HostKeyAlgorithms = util.RemoveDuplicates(c.HostKeyAlgorithms, true)
}
for _, hostKeyAlgo := range c.HostKeyAlgorithms {
if !util.Contains(supportedHostKeyAlgos, hostKeyAlgo) {
if !slices.Contains(supportedHostKeyAlgos, hostKeyAlgo) {
return fmt.Errorf("unsupported host key algorithm %q", hostKeyAlgo)
}
}
@ -430,7 +431,7 @@ func (c *Configuration) configureKeyAlgos(serverConfig *ssh.ServerConfig) error
if len(c.PublicKeyAlgorithms) > 0 {
c.PublicKeyAlgorithms = util.RemoveDuplicates(c.PublicKeyAlgorithms, true)
for _, algo := range c.PublicKeyAlgorithms {
if !util.Contains(supportedPublicKeyAlgos, algo) {
if !slices.Contains(supportedPublicKeyAlgos, algo) {
return fmt.Errorf("unsupported public key authentication algorithm %q", algo)
}
}
@ -472,7 +473,7 @@ func (c *Configuration) configureSecurityOptions(serverConfig *ssh.ServerConfig)
if kex == keyExchangeCurve25519SHA256LibSSH {
continue
}
if !util.Contains(supportedKexAlgos, kex) {
if !slices.Contains(supportedKexAlgos, kex) {
return fmt.Errorf("unsupported key-exchange algorithm %q", kex)
}
}
@ -486,7 +487,7 @@ func (c *Configuration) configureSecurityOptions(serverConfig *ssh.ServerConfig)
if len(c.Ciphers) > 0 {
c.Ciphers = util.RemoveDuplicates(c.Ciphers, true)
for _, cipher := range c.Ciphers {
if !util.Contains(supportedCiphers, cipher) {
if !slices.Contains(supportedCiphers, cipher) {
return fmt.Errorf("unsupported cipher %q", cipher)
}
}
@ -499,7 +500,7 @@ func (c *Configuration) configureSecurityOptions(serverConfig *ssh.ServerConfig)
if len(c.MACs) > 0 {
c.MACs = util.RemoveDuplicates(c.MACs, true)
for _, mac := range c.MACs {
if !util.Contains(supportedMACs, mac) {
if !slices.Contains(supportedMACs, mac) {
return fmt.Errorf("unsupported MAC algorithm %q", mac)
}
}
@ -785,7 +786,7 @@ func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh.
user.Username, user.HomeDir)
return nil, fmt.Errorf("cannot login user with invalid home dir: %q", user.HomeDir)
}
if util.Contains(user.Filters.DeniedProtocols, common.ProtocolSSH) {
if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolSSH) {
logger.Info(logSender, connectionID, "cannot login user %q, protocol SSH is not allowed", user.Username)
return nil, fmt.Errorf("protocol SSH is not allowed for user %q", user.Username)
}
@ -830,14 +831,14 @@ func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh.
}
func (c *Configuration) checkSSHCommands() {
if util.Contains(c.EnabledSSHCommands, "*") {
if slices.Contains(c.EnabledSSHCommands, "*") {
c.EnabledSSHCommands = GetSupportedSSHCommands()
return
}
sshCommands := []string{}
for _, command := range c.EnabledSSHCommands {
command = strings.TrimSpace(command)
if util.Contains(supportedSSHCommands, command) {
if slices.Contains(supportedSSHCommands, command) {
sshCommands = append(sshCommands, command)
} else {
logger.Warn(logSender, "", "unsupported ssh command: %q ignored", command)
@ -927,7 +928,7 @@ func (c *Configuration) checkHostKeyAutoGeneration(configDir string) error {
func (c *Configuration) getHostKeyAlgorithms(keyFormat string) []string {
var algos []string
for _, algo := range algorithmsForKeyFormat(keyFormat) {
if util.Contains(c.HostKeyAlgorithms, algo) {
if slices.Contains(c.HostKeyAlgorithms, algo) {
algos = append(algos, algo)
}
}
@ -986,7 +987,7 @@ func (c *Configuration) checkAndLoadHostKeys(configDir string, serverConfig *ssh
var algos []string
for _, algo := range algorithmsForKeyFormat(signer.PublicKey().Type()) {
if underlyingAlgo, ok := certKeyAlgoNames[algo]; ok {
if util.Contains(mas.Algorithms(), underlyingAlgo) {
if slices.Contains(mas.Algorithms(), underlyingAlgo) {
algos = append(algos, algo)
}
}
@ -1098,12 +1099,12 @@ func (c *Configuration) initializeCertChecker(configDir string) error {
func (c *Configuration) getPartialSuccessError(nextAuthMethods []string) error {
err := &ssh.PartialSuccessError{}
if c.PasswordAuthentication && util.Contains(nextAuthMethods, dataprovider.LoginMethodPassword) {
if c.PasswordAuthentication && slices.Contains(nextAuthMethods, dataprovider.LoginMethodPassword) {
err.Next.PasswordCallback = func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
return c.validatePasswordCredentials(conn, password, dataprovider.SSHLoginMethodKeyAndPassword)
}
}
if c.KeyboardInteractiveAuthentication && util.Contains(nextAuthMethods, dataprovider.SSHLoginMethodKeyboardInteractive) {
if c.KeyboardInteractiveAuthentication && slices.Contains(nextAuthMethods, dataprovider.SSHLoginMethodKeyboardInteractive) {
err.Next.KeyboardInteractiveCallback = func(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
return c.validateKeyboardInteractiveCredentials(conn, client, dataprovider.SSHLoginMethodKeyAndKeyboardInt, true)
}

View file

@ -23,6 +23,7 @@ import (
"crypto/sha512"
"encoding/base64"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
@ -37,6 +38,7 @@ import (
"path"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"
"sync"
@ -782,6 +784,34 @@ func TestSFTPFsEscapeHomeDir(t *testing.T) {
assert.NoError(t, err)
}
func TestReadDirLongNames(t *testing.T) {
usePubKey := true
user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
assert.NoError(t, err)
conn, client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
numFiles := 1000
for i := 0; i < 1000; i++ {
fPath := filepath.Join(user.GetHomeDir(), hex.EncodeToString(util.GenerateRandomBytes(127)))
err = os.WriteFile(fPath, util.GenerateRandomBytes(30), 0666)
assert.NoError(t, err)
}
entries, err := client.ReadDir("/")
assert.NoError(t, err)
assert.Len(t, entries, numFiles)
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
}
func TestGroupSettingsOverride(t *testing.T) {
usePubKey := true
g := getTestGroup()
@ -5496,13 +5526,13 @@ func TestNestedVirtualFolders(t *testing.T) {
folderGet, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, int64(18769), folderGet.UsedQuotaSize)
assert.Equal(t, 1, folderGet.UsedQuotaFiles)
assert.Equal(t, int64(0), folderGet.UsedQuotaSize)
assert.Equal(t, 0, folderGet.UsedQuotaFiles)
folderGet, _, err = httpdtest.GetFolderByName(folderNameNested, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, int64(27658), folderGet.UsedQuotaSize)
assert.Equal(t, 1, folderGet.UsedQuotaFiles)
assert.Equal(t, int64(0), folderGet.UsedQuotaSize)
assert.Equal(t, 0, folderGet.UsedQuotaFiles)
files, err := client.ReadDir("/")
if assert.NoError(t, err) {
@ -6169,8 +6199,8 @@ func TestVirtualFoldersQuotaValues(t *testing.T) {
assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
@ -6289,8 +6319,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 2, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
@ -6314,8 +6344,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 2, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
// rename a file inside vdir2, it isn't included inside user quota, so we have:
// - vdir1/dir1/testFileName.rename
// - vdir1/dir2/testFileName1
@ -6333,8 +6363,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
assert.Equal(t, 2, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 2, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
// rename a file inside vdir2 overwriting an existing, we now have:
// - vdir1/dir1/testFileName.rename
// - vdir1/dir2/testFileName1
@ -6351,8 +6381,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
assert.Equal(t, 1, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 2, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
// rename a file inside vdir1 overwriting an existing, we now have:
// - vdir1/dir1/testFileName.rename (initial testFileName1)
// - vdir2/dir1/testFileName.rename (initial testFileName1)
@ -6364,8 +6394,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
@ -6385,8 +6415,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
@ -6507,8 +6537,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize, user.UsedQuotaSize)
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1+testFileSize1, f.UsedQuotaSize)
@ -6526,8 +6556,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize*2, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize*2, f.UsedQuotaSize)
assert.Equal(t, 2, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1*2, f.UsedQuotaSize)
@ -6544,8 +6574,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1+testFileSize, f.UsedQuotaSize)
@ -6561,8 +6591,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
@ -6592,8 +6622,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize1*3+testFileSize*2, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1*3+testFileSize*2, f.UsedQuotaSize)
assert.Equal(t, 5, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, int64(0), f.UsedQuotaSize)
@ -6607,8 +6637,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize1*2+testFileSize, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1*2+testFileSize, f.UsedQuotaSize)
assert.Equal(t, 3, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
@ -6729,8 +6759,8 @@ func TestQuotaRenameFromVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
@ -6748,8 +6778,8 @@ func TestQuotaRenameFromVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize+testFileSize1+testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
@ -6812,8 +6842,8 @@ func TestQuotaRenameFromVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize*3+testFileSize1*3, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 2, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, int64(0), f.UsedQuotaSize)
@ -6946,8 +6976,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
// rename a file from user home dir to vdir2, vdir2 is not included in user quota so we have:
// - /vdir2/dir1/testFileName
// - /vdir1/dir1/testFileName1
@ -6986,8 +7016,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
@ -7003,8 +7033,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
@ -7026,8 +7056,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
@ -7044,8 +7074,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize*2+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 3, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
@ -7070,8 +7100,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize*2+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 3, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1*2+testFileSize, f.UsedQuotaSize)
@ -7339,8 +7369,8 @@ func TestVFolderQuotaSize(t *testing.T) {
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
@ -8610,8 +8640,8 @@ func TestUserAllowedLoginMethods(t *testing.T) {
allowedMethods = user.GetAllowedLoginMethods()
assert.Equal(t, 4, len(allowedMethods))
assert.True(t, util.Contains(allowedMethods, dataprovider.SSHLoginMethodKeyAndKeyboardInt))
assert.True(t, util.Contains(allowedMethods, dataprovider.SSHLoginMethodKeyAndPassword))
assert.True(t, slices.Contains(allowedMethods, dataprovider.SSHLoginMethodKeyAndKeyboardInt))
assert.True(t, slices.Contains(allowedMethods, dataprovider.SSHLoginMethodKeyAndPassword))
}
func TestUserPartialAuth(t *testing.T) {
@ -9118,8 +9148,8 @@ func TestSSHCopy(t *testing.T) {
assert.Equal(t, 2*testFileSize+2*testFileSize1, user.UsedQuotaSize)
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 2, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
@ -9197,8 +9227,8 @@ func TestSSHCopy(t *testing.T) {
assert.Equal(t, 5*testFileSize+4*testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, 2*testFileSize+2*testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 4, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
}
// cross folder copy
newDir := "newdir"
@ -9894,8 +9924,8 @@ func TestGitIncludedVirtualFolders(t *testing.T) {
folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, user.UsedQuotaFiles, folder.UsedQuotaFiles)
assert.Equal(t, user.UsedQuotaSize, folder.UsedQuotaSize)
assert.Equal(t, 0, folder.UsedQuotaFiles)
assert.Equal(t, int64(0), folder.UsedQuotaSize)
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
@ -10680,8 +10710,8 @@ func TestSCPVirtualFoldersQuota(t *testing.T) {
assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, expectedQuotaSize, f.UsedQuotaSize)
assert.Equal(t, expectedQuotaFiles, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, expectedQuotaSize, f.UsedQuotaSize)

View file

@ -27,6 +27,7 @@ import (
"os/exec"
"path"
"runtime/debug"
"slices"
"strings"
"sync"
"time"
@ -91,7 +92,7 @@ func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommand
name, args, err := parseCommandPayload(msg.Command)
connection.Log(logger.LevelDebug, "new ssh command: %q args: %v num args: %d user: %s, error: %v",
name, args, len(args), connection.User.Username, err)
if err == nil && util.Contains(enabledSSHCommands, name) {
if err == nil && slices.Contains(enabledSSHCommands, name) {
connection.command = msg.Command
if name == scpCmdName && len(args) >= 2 {
connection.SetProtocol(common.ProtocolSCP)
@ -139,9 +140,9 @@ func (c *sshCommand) handle() (err error) {
defer common.Connections.Remove(c.connection.GetID())
c.connection.UpdateLastActivity()
if util.Contains(sshHashCommands, c.command) {
if slices.Contains(sshHashCommands, c.command) {
return c.handleHashCommands()
} else if util.Contains(systemCommands, c.command) {
} else if slices.Contains(systemCommands, c.command) {
command, err := c.getSystemCommand()
if err != nil {
return c.sendErrorResponse(err)
@ -192,13 +193,10 @@ func (c *sshCommand) handleSFTPGoRemove() error {
func (c *sshCommand) updateQuota(sshDestPath string, filesNum int, filesSize int64) {
vfolder, err := c.connection.User.GetVirtualFolderForPath(sshDestPath)
if err == nil {
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, filesNum, filesSize, false) //nolint:errcheck
if vfolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(&c.connection.User, filesNum, filesSize, false) //nolint:errcheck
dataprovider.UpdateUserFolderQuota(&vfolder, &c.connection.User, filesNum, filesSize, false)
return
}
} else {
dataprovider.UpdateUserQuota(&c.connection.User, filesNum, filesSize, false) //nolint:errcheck
}
}
func (c *sshCommand) handleHashCommands() error {
@ -432,11 +430,11 @@ func (c *sshCommand) getSystemCommand() (systemCommand, error) {
// If the user cannot create symlinks we add the option --munge-links, if it is not
// already set. This should make symlinks unusable (but manually recoverable)
if c.connection.User.HasPerm(dataprovider.PermCreateSymlinks, c.getDestPath()) {
if !util.Contains(args, "--safe-links") {
if !slices.Contains(args, "--safe-links") {
args = append([]string{"--safe-links"}, args...)
}
} else {
if !util.Contains(args, "--munge-links") {
if !slices.Contains(args, "--munge-links") {
args = append([]string{"--munge-links"}, args...)
}
}

View file

@ -19,6 +19,7 @@ import (
"context"
"errors"
"fmt"
"slices"
"sync"
"time"
@ -27,7 +28,6 @@ import (
"golang.org/x/oauth2/microsoft"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
)
// Supported OAuth2 providers
@ -56,7 +56,7 @@ type OAuth2Config struct {
// Validate validates and initializes the configuration
func (c *OAuth2Config) Validate() error {
if !util.Contains(supportedOAuth2Providers, c.Provider) {
if !slices.Contains(supportedOAuth2Providers, c.Provider) {
return fmt.Errorf("smtp oauth2: unsupported provider %d", c.Provider)
}
if c.ClientID == "" {

View file

@ -279,15 +279,15 @@ func (c *Config) Initialize(configDir string, isService bool) error {
}
func (c *Config) getMailClientOptions() []mail.Option {
options := []mail.Option{mail.WithoutNoop()}
options := []mail.Option{mail.WithPort(c.Port), mail.WithoutNoop()}
switch c.Encryption {
case 1:
options = append(options, mail.WithSSLPort(false))
options = append(options, mail.WithSSL())
case 2:
options = append(options, mail.WithTLSPortPolicy(mail.TLSMandatory))
options = append(options, mail.WithTLSPolicy(mail.TLSMandatory))
default:
options = append(options, mail.WithTLSPortPolicy(mail.NoTLS))
options = append(options, mail.WithTLSPolicy(mail.NoTLS))
}
if c.User != "" {
options = append(options, mail.WithUsername(c.User))
@ -317,7 +317,6 @@ func (c *Config) getMailClientOptions() []mail.Option {
}),
mail.WithDebugLog())
}
options = append(options, mail.WithPort(c.Port))
return options
}
@ -416,12 +415,6 @@ func SendEmail(to, bcc []string, subject, body string, contentType EmailContentT
return config.sendEmail(to, bcc, subject, body, contentType, attachments...)
}
// ReloadProviderConf reloads the configuration from the provider
// and apply it if different from the active one
func ReloadProviderConf() {
loadConfigFromProvider() //nolint:errcheck
}
func loadConfigFromProvider() error {
configs, err := dataprovider.GetConfigs()
if err != nil {

View file

@ -303,6 +303,9 @@ const (
I18nErrorEvSyncUnsupportedFs = "rules.sync_unsupported_fs_event"
I18nErrorRuleFailureActionsOnly = "rules.only_failure_actions"
I18nErrorRuleSyncActionRequired = "rules.sync_action_required"
I18nErrorInvalidPNG = "branding.invalid_png"
I18nErrorInvalidPNGSize = "branding.invalid_png_size"
I18nErrorInvalidDisclaimerURL = "branding.invalid_disclaimer_url"
)
// NewI18nError returns a I18nError wrappring the provided error

View file

@ -128,16 +128,6 @@ var bytesSizeTable = map[string]uint64{
"e": eByte,
}
// Contains reports whether v is present in elems.
func Contains[T comparable](elems []T, v T) bool {
for _, s := range elems {
if v == s {
return true
}
}
return false
}
// Remove removes an element from a string slice and
// returns the modified slice
func Remove(elems []string, val string) []string {

View file

@ -18,7 +18,7 @@ package version
import "strings"
const (
version = "2.6.2"
version = "2.6.99-dev"
appName = "SFTPGo"
)

View file

@ -24,6 +24,7 @@ import (
"os"
"path"
"path/filepath"
"slices"
"strings"
"time"
@ -34,7 +35,6 @@ import (
"github.com/sftpgo/sdk"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
)
const (
@ -475,7 +475,7 @@ func (fs *OsFs) findNonexistentDirs(filePath string) ([]string, error) {
for fs.IsNotExist(err) {
results = append(results, parent)
parent = filepath.Dir(parent)
if util.Contains(results, parent) {
if slices.Contains(results, parent) {
break
}
_, err = os.Stat(parent)

View file

@ -30,6 +30,7 @@ import (
"os"
"path"
"path/filepath"
"slices"
"sort"
"strings"
"sync"
@ -161,7 +162,7 @@ func (fs *S3Fs) Stat(name string) (os.FileInfo, error) {
if err == nil {
// Some S3 providers (like SeaweedFS) remove the trailing '/' from object keys.
// So we check some common content types to detect if this is a "directory".
isDir := util.Contains(s3DirMimeTypes, util.GetStringFromPointer(obj.ContentType))
isDir := slices.Contains(s3DirMimeTypes, util.GetStringFromPointer(obj.ContentType))
if util.GetIntFromPointer(obj.ContentLength) == 0 && !isDir {
_, err = fs.headObject(name + "/")
isDir = err == nil

View file

@ -28,6 +28,7 @@ import (
"os"
"path"
"path/filepath"
"slices"
"strconv"
"strings"
"sync"
@ -69,6 +70,17 @@ type SFTPFsConfig struct {
forbiddenSelfUsernames []string `json:"-"`
}
func (c *SFTPFsConfig) getKeySigner() (ssh.Signer, error) {
privPayload := c.PrivateKey.GetPayload()
if privPayload == "" {
return nil, nil
}
if key := c.KeyPassphrase.GetPayload(); key != "" {
return ssh.ParsePrivateKeyWithPassphrase([]byte(privPayload), []byte(key))
}
return ssh.ParsePrivateKey([]byte(privPayload))
}
// HideConfidentialData hides confidential data
func (c *SFTPFsConfig) HideConfidentialData() {
if c.Password != nil {
@ -114,7 +126,7 @@ func (c *SFTPFsConfig) isEqual(other SFTPFsConfig) bool {
return false
}
for _, fp := range c.Fingerprints {
if !util.Contains(other.Fingerprints, fp) {
if !slices.Contains(other.Fingerprints, fp) {
return false
}
}
@ -185,17 +197,11 @@ func (c *SFTPFsConfig) validate() error {
func (c *SFTPFsConfig) validatePrivateKey() error {
if c.PrivateKey.IsPlain() {
var signer ssh.Signer
var err error
if c.KeyPassphrase.IsPlain() {
signer, err = ssh.ParsePrivateKeyWithPassphrase([]byte(c.PrivateKey.GetPayload()),
[]byte(c.KeyPassphrase.GetPayload()))
} else {
signer, err = ssh.ParsePrivateKey([]byte(c.PrivateKey.GetPayload()))
}
signer, err := c.getKeySigner()
if err != nil {
return util.NewI18nError(fmt.Errorf("invalid private key: %w", err), util.I18nErrorPrivKeyInvalid)
}
if signer != nil {
if key, ok := signer.PublicKey().(ssh.CryptoPublicKey); ok {
cryptoKey := key.CryptoPublicKey()
if rsaKey, ok := cryptoKey.(*rsa.PublicKey); ok {
@ -208,6 +214,7 @@ func (c *SFTPFsConfig) validatePrivateKey() error {
}
}
}
}
return nil
}
@ -331,15 +338,19 @@ func NewSFTPFs(connectionID, mountPath, localTempDir string, forbiddenSelfUserna
return nil, err
}
}
conn, err := sftpConnsCache.Get(&config, connectionID)
if err != nil {
return nil, err
}
config.forbiddenSelfUsernames = forbiddenSelfUsernames
sftpFs := &SFTPFs{
connectionID: connectionID,
mountPath: getMountPath(mountPath),
localTempDir: localTempDir,
config: &config,
conn: sftpConnsCache.Get(&config, connectionID),
conn: conn,
}
err := sftpFs.createConnection()
err = sftpFs.createConnection()
if err != nil {
sftpFs.Close() //nolint:errcheck
}
@ -910,6 +921,7 @@ type sftpConnection struct {
isConnected bool
sessions map[string]bool
lastActivity time.Time
signer ssh.Signer
}
func newSFTPConnection(config *SFTPFsConfig, sessionID string) *sftpConnection {
@ -919,6 +931,7 @@ func newSFTPConnection(config *SFTPFsConfig, sessionID string) *sftpConnection {
isConnected: false,
sessions: map[string]bool{},
lastActivity: time.Now().UTC(),
signer: nil,
}
c.sessions[sessionID] = true
return c
@ -931,17 +944,6 @@ func (c *sftpConnection) OpenConnection() error {
return c.openConnNoLock()
}
func (c *sftpConnection) getKeySigner() (ssh.Signer, error) {
privPayload := c.config.PrivateKey.GetPayload()
if privPayload == "" {
return nil, nil
}
if key := c.config.KeyPassphrase.GetPayload(); key != "" {
return ssh.ParsePrivateKeyWithPassphrase([]byte(privPayload), []byte(key))
}
return ssh.ParsePrivateKey([]byte(privPayload))
}
func (c *sftpConnection) openConnNoLock() error {
if c.isConnected {
logger.Debug(c.logSender, "", "reusing connection")
@ -953,12 +955,12 @@ func (c *sftpConnection) openConnNoLock() error {
User: c.config.Username,
HostKeyCallback: func(_ string, _ net.Addr, key ssh.PublicKey) error {
fp := ssh.FingerprintSHA256(key)
if util.Contains(sftpFingerprints, fp) {
if slices.Contains(sftpFingerprints, fp) {
if allowSelfConnections == 0 {
logger.Log(logger.LevelError, c.logSender, "", "SFTP self connections not allowed")
return ErrSFTPLoop
}
if util.Contains(c.config.forbiddenSelfUsernames, c.config.Username) {
if slices.Contains(c.config.forbiddenSelfUsernames, c.config.Username) {
logger.Log(logger.LevelError, c.logSender, "",
"SFTP loop or nested local SFTP folders detected, username %q, forbidden usernames: %+v",
c.config.Username, c.config.forbiddenSelfUsernames)
@ -979,12 +981,8 @@ func (c *sftpConnection) openConnNoLock() error {
Timeout: 15 * time.Second,
ClientVersion: fmt.Sprintf("SSH-2.0-%s", version.GetServerVersion("_", false)),
}
signer, err := c.getKeySigner()
if err != nil {
return fmt.Errorf("sftpfs: unable to parse the private key: %w", err)
}
if signer != nil {
clientConfig.Auth = append(clientConfig.Auth, ssh.PublicKeys(signer))
if c.signer != nil {
clientConfig.Auth = append(clientConfig.Auth, ssh.PublicKeys(c.signer))
}
if pwd := c.config.Password.GetPayload(); pwd != "" {
clientConfig.Auth = append(clientConfig.Auth, ssh.Password(pwd))
@ -1156,7 +1154,7 @@ func newSFTPConnectionCache() *sftpConnectionsCache {
return c
}
func (c *sftpConnectionsCache) Get(config *SFTPFsConfig, sessionID string) *sftpConnection {
func (c *sftpConnectionsCache) Get(config *SFTPFsConfig, sessionID string) (*sftpConnection, error) {
partition := 0
key := config.getUniqueID(partition)
@ -1172,7 +1170,7 @@ func (c *sftpConnectionsCache) Get(config *SFTPFsConfig, sessionID string) *sftp
"reusing connection for session ID %q, key: %d, active sessions %d, active connections: %d",
sessionID, key, activeSessions+1, len(c.items))
val.AddSession(sessionID)
return val
return val, nil
}
partition++
oldKey = key
@ -1182,11 +1180,16 @@ func (c *sftpConnectionsCache) Get(config *SFTPFsConfig, sessionID string) *sftp
partition, activeSessions, oldKey, key)
} else {
conn := newSFTPConnection(config, sessionID)
signer, err := config.getKeySigner()
if err != nil {
return nil, fmt.Errorf("sftpfs: unable to parse the private key: %w", err)
}
conn.signer = signer
c.items[key] = conn
logger.Debug(logSenderSFTPCache, "",
"adding new connection for session ID %q, partition: %d, key: %d, active connections: %d",
sessionID, partition, key, len(c.items))
return conn
return conn, nil
}
}
}

View file

@ -24,6 +24,7 @@ import (
"path"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"
"sync"
@ -764,7 +765,7 @@ func (c *AzBlobFsConfig) validate() error {
if err := c.checkPartSizeAndConcurrency(); err != nil {
return err
}
if !util.Contains(validAzAccessTier, c.AccessTier) {
if !slices.Contains(validAzAccessTier, c.AccessTier) {
return fmt.Errorf("invalid access tier %q, valid values: \"''%v\"", c.AccessTier, strings.Join(validAzAccessTier, ", "))
}
return nil

View file

@ -23,6 +23,7 @@ import (
"net/http"
"os"
"path"
"slices"
"sync/atomic"
"time"
@ -447,7 +448,7 @@ func (f *webDavFile) Patch(patches []webdav.Proppatch) ([]webdav.Propstat, error
pstat := webdav.Propstat{}
for _, p := range patch.Props {
if status == http.StatusForbidden && !hasError {
if !patch.Remove && util.Contains(lastModifiedProps, p.XMLName.Local) {
if !patch.Remove && slices.Contains(lastModifiedProps, p.XMLName.Local) {
parsed, err := parseTime(util.BytesToString(p.InnerXML))
if err != nil {
f.Connection.Log(logger.LevelWarn, "unsupported last modification time: %q, err: %v",

View file

@ -272,10 +272,7 @@ func (c *Connection) handleUploadToExistingFile(fs vfs.Fs, resolvedPath, filePat
if vfs.HasTruncateSupport(fs) {
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
if err == nil {
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
if vfolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
}
dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -fileSize, false)
} else {
dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
}

View file

@ -26,6 +26,7 @@ import (
"path"
"path/filepath"
"runtime/debug"
"slices"
"strings"
"time"
@ -346,7 +347,7 @@ func (s *webDavServer) validateUser(user *dataprovider.User, r *http.Request, lo
user.Username, user.HomeDir)
return connID, fmt.Errorf("cannot login user with invalid home dir: %q", user.HomeDir)
}
if util.Contains(user.Filters.DeniedProtocols, common.ProtocolWebDAV) {
if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolWebDAV) {
logger.Info(logSender, connectionID, "cannot login user %q, protocol DAV is not allowed", user.Username)
return connID, fmt.Errorf("protocol DAV is not allowed for user %q", user.Username)
}

View file

@ -2539,6 +2539,7 @@ func TestStat(t *testing.T) {
func TestUploadOverwriteVfolder(t *testing.T) {
u := getTestUser()
u.QuotaFiles = 1000
vdir := "/vdir"
mappedPath := filepath.Join(os.TempDir(), "mappedDir")
folderName := filepath.Base(mappedPath)
@ -2585,15 +2586,25 @@ func TestUploadOverwriteVfolder(t *testing.T) {
assert.NoError(t, err)
folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, folder.UsedQuotaSize)
assert.Equal(t, 1, folder.UsedQuotaFiles)
assert.Equal(t, int64(0), folder.UsedQuotaSize)
assert.Equal(t, 0, folder.UsedQuotaFiles)
user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, user.UsedQuotaSize)
assert.Equal(t, 1, user.UsedQuotaFiles)
err = uploadFileWithRawClient(testFilePath, path.Join(vdir, testFileName), user.Username,
defaultPassword, true, testFileSize, client)
assert.NoError(t, err)
folder, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, folder.UsedQuotaSize)
assert.Equal(t, 1, folder.UsedQuotaFiles)
assert.Equal(t, int64(0), folder.UsedQuotaSize)
assert.Equal(t, 0, folder.UsedQuotaFiles)
user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, user.UsedQuotaSize)
assert.Equal(t, 1, user.UsedQuotaFiles)
err = os.Remove(testFilePath)
assert.NoError(t, err)
_, err = httpdtest.RemoveUser(user, http.StatusOK)

View file

@ -29,7 +29,7 @@ info:
SFTPGo supports groups to simplify the administration of multiple accounts by letting you assign settings once to a group, instead of multiple times to each individual user.
The SFTPGo WebClient allows end users to change their credentials, browse and manage their files in the browser and setup two-factor authentication which works with Authy, Google Authenticator and other compatible apps.
From the WebClient each authorized user can also create HTTP/S links to externally share files and folders securely, by setting limits to the number of downloads/uploads, protecting the share with a password, limiting access by source IP address, setting an automatic expiration date.
version: 2.6.2
version: 2.6.99-dev
contact:
name: API support
url: 'https://github.com/drakkan/sftpgo'

View file

@ -1,6 +1,6 @@
#!/bin/bash
NFPM_VERSION=2.37.1
NFPM_VERSION=2.38.0
NFPM_ARCH=${NFPM_ARCH:-amd64}
if [ -z ${SFTPGO_VERSION} ]
then

View file

@ -3,17 +3,17 @@
<package xmlns="http://schemas.microsoft.com/packaging/2015/06/nuspec.xsd">
<metadata>
<id>sftpgo</id>
<version>2.5.6</version>
<version>2.6.2</version>
<packageSourceUrl>https://github.com/drakkan/sftpgo/tree/main/pkgs/choco</packageSourceUrl>
<owners>asheroto</owners>
<title>SFTPGo</title>
<authors>Nicola Murino</authors>
<projectUrl>https://github.com/drakkan/sftpgo</projectUrl>
<iconUrl>https://cdn.statically.io/gh/drakkan/sftpgo/v2.5.6/static/img/logo.png</iconUrl>
<iconUrl>https://cdn.statically.io/gh/drakkan/sftpgo/v2.6.2/static/img/logo.png</iconUrl>
<licenseUrl>https://github.com/drakkan/sftpgo/blob/main/LICENSE</licenseUrl>
<requireLicenseAcceptance>false</requireLicenseAcceptance>
<projectSourceUrl>https://github.com/drakkan/sftpgo</projectSourceUrl>
<docsUrl>https://github.com/drakkan/sftpgo/tree/v2.5.6/docs</docsUrl>
<docsUrl>https://sftpgo.github.io/2.6/</docsUrl>
<bugTrackerUrl>https://github.com/drakkan/sftpgo/issues</bugTrackerUrl>
<tags>sftp sftp-server ftp webdav s3 azure-blob google-cloud-storage cloud-storage scp data-at-rest-encryption multi-factor-authentication multi-step-authentication</tags>
<summary>Full-featured and highly configurable file transfer server: SFTP, HTTP/S,FTP/S, WebDAV.</summary>
@ -26,13 +26,13 @@ SFTPGo allows to create HTTP/S links to externally share files and folders secur
SFTPGo is highly customizable and extensible to suit your needs.
You can find more info [here](https://github.com/drakkan/sftpgo).
You can find more info [here](https://sftpgo.github.io/2.6/).
### Notes
* This package installs SFTPGo as Windows Service.
* After the first installation please take a look at the [Getting Started Guide](https://sftpgo.github.io/latest/initial-configuration/).</description>
<releaseNotes>https://github.com/drakkan/sftpgo/releases/tag/v2.5.6</releaseNotes>
* After the first installation please take a look at the [Getting Started Guide](https://sftpgo.github.io/2.6/initial-configuration/).</description>
<releaseNotes>https://github.com/drakkan/sftpgo/releases/tag/v2.6.2</releaseNotes>
</metadata>
<files>
<file src="**" exclude="**\*.md;**\icon.png;**\icon.jpg;**\icon.svg" />

View file

@ -1,8 +1,8 @@
$ErrorActionPreference = 'Stop'
$packageName = 'sftpgo'
$softwareName = 'SFTPGo'
$url = 'https://github.com/drakkan/sftpgo/releases/download/v2.5.6/sftpgo_v2.5.6_windows_x86_64.exe'
$checksum = 'C20BB051D3EA2ACBF05231ECB94410D01648E15DE304CA6240D37FDC34007DBE'
$url = 'https://github.com/drakkan/sftpgo/releases/download/v2.6.2/sftpgo_v2.6.2_windows_x86_64.exe'
$checksum = '40905AED44A7189C5DF6164631DB07BCBB53FAB7A63503A6BE7AD1328D9986D5'
$silentArgs = '/VERYSILENT'
$validExitCodes = @(0)
@ -48,5 +48,7 @@ Write-Output "General information (README) location:"
Write-Output "`thttps://github.com/drakkan/sftpgo"
Write-Output "Documentation location:"
Write-Output "`thttps://sftpgo.github.io/"
Write-Output "Commercial support:"
Write-Output "`thttps://sftpgo.com/#pricing"
Write-Output ""
Write-Output "---------------------------"

View file

@ -1,3 +1,21 @@
sftpgo (2.6.2-1ppa1) bionic; urgency=medium
* New upstream release
-- Nicola Murino <nicola.murino@gmail.com> Fri, 21 Jun 2024 19:47:58 +0200
sftpgo (2.6.1-1ppa1) bionic; urgency=medium
* New upstream release
-- Nicola Murino <nicola.murino@gmail.com> Wed, 19 Jun 2024 10:17:49 +0200
sftpgo (2.6.0-1ppa1) bionic; urgency=medium
* New upstream release
-- Nicola Murino <nicola.murino@gmail.com> Wed, 15 May 2024 19:12:01 +0200
sftpgo (2.5.6-1ppa1) bionic; urgency=medium
* New upstream release

View file

@ -3,7 +3,7 @@ Upstream-Name: SFTPGo
Source: https://github.com/drakkan/sftpgo
Files: *
Copyright: 2019-2023 Nicola Murino <nicola.murino@gmail.com>
Copyright: 2019 Nicola Murino <nicola.murino@gmail.com>
License: AGPL-3
License: AGPL-3

View file

@ -2,7 +2,7 @@ Index: sftpgo/sftpgo.json
===================================================================
--- sftpgo.orig/sftpgo.json
+++ sftpgo/sftpgo.json
@@ -57,7 +57,7 @@
@@ -67,7 +67,7 @@
"domains": [],
"email": "",
"key_type": "4096",
@ -11,7 +11,7 @@ Index: sftpgo/sftpgo.json
"ca_endpoint": "https://acme-v02.api.letsencrypt.org/directory",
"renew_days": 30,
"http01_challenge": {
@@ -186,7 +186,7 @@
@@ -198,7 +198,7 @@
},
"data_provider": {
"driver": "sqlite",
@ -20,7 +20,7 @@ Index: sftpgo/sftpgo.json
"host": "",
"port": 0,
"username": "",
@@ -202,7 +202,7 @@
@@ -214,7 +214,7 @@
"track_quota": 2,
"delayed_quota_update": 0,
"pool_size": 0,
@ -29,7 +29,7 @@ Index: sftpgo/sftpgo.json
"actions": {
"execute_on": [],
"execute_for": [],
@@ -244,7 +244,7 @@
@@ -256,7 +256,7 @@
"port": 0,
"proto": "http"
},

View file

@ -24,6 +24,7 @@
"allow_self_connections": 0,
"umask": "",
"server_version": "",
"tz": "",
"metadata": {
"read": 0
},

Binary file not shown.

Before

Width:  |  Height:  |  Size: 17 KiB

BIN
static/favicon.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.1 KiB

View file

@ -730,7 +730,7 @@
"external_auth_cache_time": "External auth cache time",
"external_auth_cache_time_help": "Cache time, in seconds, for users authenticated using an external auth hook. 0 means no cache",
"access_time": "Access time restrictions",
"access_time_help": "No restrictions means access is always allowed, the time must be set in the format HH:MM. Use UTC time"
"access_time_help": "No restrictions means access is always allowed, the time must be set in the format HH:MM"
},
"admin": {
"role_permissions": "A role admin cannot have the following permissions: {{val}}",
@ -883,6 +883,19 @@
"help": "From this section you can enable algorithms disabled by default. You don't need to set values already defined using env vars or config file. A service restart is required to apply changes",
"host_key_algos": "Host Key Algorithms"
},
"branding": {
"title": "Branding",
"help": "From this section you can customize SFTPGo to fit your brand and add a disclaimer to the login pages",
"short_name": "Short name",
"logo": "Logo",
"logo_help": "PNG image, max accepted size 512x512, default logo size is 256x256",
"favicon": "Favicon",
"disclaimer_name": "Disclaimer title",
"disclaimer_url": "Disclaimer URL",
"invalid_png": "Invalid PNG image",
"invalid_png_size": "Invalid PNG image size",
"invalid_disclaimer_url": "The disclaimer URL must be an http or https link"
},
"events": {
"search": "Search logs",
"fs_events": "Filesystem events",
@ -1071,7 +1084,7 @@
"sync_unsupported_fs_event": "Synchronous execution is only supported for upload and pre-* filesystem events",
"only_failure_actions": "At least a non-failure action is required",
"sync_action_required": "Event \"{{val}}\" requires at least a synchronous action",
"scheduler_help": "The scheduler uses UTC time. Hours: 0-23. Day of week: 0-6 (Sun-Sat). Day of month: 1-31. Month: 1-12. Asterisk (*) indicates a match for all the values of the field. e.g. every day of week, every day of month and so on",
"scheduler_help": "Hours: 0-23. Day of week: 0-6 (Sun-Sat). Day of month: 1-31. Month: 1-12. Asterisk (*) indicates a match for all the values of the field. e.g. every day of week, every day of month and so on",
"concurrent_run": "Allow concurrent execution from multiple instances",
"protocol_filters": "Protocol filters",
"object_filters": "Object filters",

View file

@ -730,7 +730,7 @@
"external_auth_cache_time": "Cache per autenticazione esterna",
"external_auth_cache_time_help": "Tempo di memorizzazione nella cache, in secondi, per gli utenti autenticati utilizzando un hook di autenticazione esterno. 0 significa nessuna cache",
"access_time": "Limitazioni temporali all'accesso",
"access_time_help": "Nessuna restrizione significa che l'accesso è sempre consentito, l'ora deve essere impostata nel formato HH:MM. Utilizzare l'ora UTC"
"access_time_help": "Nessuna restrizione significa che l'accesso è sempre consentito, l'ora deve essere impostata nel formato HH:MM"
},
"admin": {
"role_permissions": "Un amministratore di ruolo non può avere le seguenti autorizzazioni: {{val}}",
@ -883,6 +883,19 @@
"help": "Da questa sezione è possibile abilitare gli algoritmi disabilitati di default. Non è necessario impostare valori già definiti utilizzando env vars o il file di configurazione. Per applicare le modifiche è necessario il riavvio del servizio",
"host_key_algos": "Algoritmi per chiavi host"
},
"branding": {
"title": "Branding",
"help": "Da questa sezione puoi personalizzare SFTPGo per adattarlo al tuo marchio e aggiungere un disclaimer alle pagine di accesso",
"short_name": "Nome breve",
"logo": "Logo",
"logo_help": "Immagine PNG, dimensione massima accettata 512x512, la dimensione predefinita del logo è 256x256",
"favicon": "Favicon",
"disclaimer_name": "Titolo Disclaimer",
"disclaimer_url": "URL Disclaimer",
"invalid_png": "Immagine PNG non valida",
"invalid_png_size": "Dimensione immagine PNG non valida",
"invalid_disclaimer_url": "L'URL del Disclaimer deve essere un link http o https"
},
"events": {
"search": "Cerca eventi",
"fs_events": "Eventi filesystem",
@ -1071,7 +1084,7 @@
"sync_unsupported_fs_event": "L'esecuzione sincrona è supporta solo per gli eventi \"upload\" e \"pre-*\"",
"only_failure_actions": "E' richiesta almeno un'azione che non venga eseguita su errore",
"sync_action_required": "L'evento \"{{val}}\" richiede almeno un'azione da eseguire sincronamente",
"scheduler_help": "Lo scheduler utilizza l'ora UTC. Orari: 0-23. Giorno della settimana: 0-6 (dom-sab). Giorno del mese: 1-31. Mese: 1-12. L'asterisco (*) indica una corrispondenza per tutti i valori del campo. per esempio. ogni giorno della settimana, ogni giorno del mese e così via",
"scheduler_help": "Orari: 0-23. Giorno della settimana: 0-6 (dom-sab). Giorno del mese: 1-31. Mese: 1-12. L'asterisco (*) indica una corrispondenza per tutti i valori del campo. per esempio. ogni giorno della settimana, ogni giorno del mese e così via",
"concurrent_run": "Consentire l'esecuzione simultanea da più istanze",
"protocol_filters": "Filtro su protocolli",
"object_filters": "Filtro su oggetti",

View file

@ -193,7 +193,7 @@ explicit grant from the SFTPGo Team (support@sftpgo.com).
.use(i18nextBrowserLanguageDetector)
.init({
debug: false,
supportedLngs: ["en", "it"],
supportedLngs: Object.keys(lngs),
fallbackLng: 'en',
load: 'languageOnly',
backend: {
@ -800,7 +800,7 @@ explicit grant from the SFTPGo Team (support@sftpgo.com).
<meta name="description" content="" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="robots" content="noindex">
<link rel="shortcut icon" href="{{.StaticURL}}{{.Branding.FaviconPath}}" />
<link rel="icon" type="image/png" href="{{.StaticURL}}{{.Branding.FaviconPath}}" />
{{- template "fonts" . }}
{{- block "extra_css" .}}{{- end}}
{{- range .Branding.DefaultCSS}}
@ -845,6 +845,17 @@ explicit grant from the SFTPGo Team (support@sftpgo.com).
</div>
</div>
</div>
{{- else if .IsLoggedToShare}}
<div class="app-container container-fluid d-flex mt-5">
<div class="d-flex align-items-center d-block ms-3">
<img alt="Logo" src="{{.StaticURL}}{{.Branding.LogoPath}}" class="h-50px" />
</div>
<div class="flex-grow-1">
<div class="d-flex justify-content-end">
{{- template "navitems" .}}
</div>
</div>
</div>
{{- end}}
<div class="{{- if .LoggedUser.Username}}app-wrapper{{- end}} flex-column flex-row-fluid " id="kt_app_wrapper">
{{- if .LoggedUser.Username}}

View file

@ -22,7 +22,7 @@ explicit grant from the SFTPGo Team (support@sftpgo.com).
<meta name="description" content="" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="robots" content="noindex">
<link rel="shortcut icon" href="{{.StaticURL}}{{.Branding.FaviconPath}}" />
<link rel="icon" type="image/png" href="{{.StaticURL}}{{.Branding.FaviconPath}}" />
{{- template "fonts" . }}
{{- range .Branding.DefaultCSS}}
<link href="{{$.StaticURL}}{{.}}" rel="stylesheet" type="text/css">

View file

@ -15,6 +15,36 @@ explicit grant from the SFTPGo Team (support@sftpgo.com).
-->
{{template "base" .}}
{{- define "extra_css"}}
<style {{- if .CSPNonce}} nonce="{{.CSPNonce}}"{{- end}}>
.image-input-placeholder {
background-image: url('{{.StaticURL}}{{.Branding.DefaultLogoPath}}');
}
/*{{- if ne .Branding.DefaultLogoPath .Branding.LogoPath}}*/
.image-input-webadmin-current {
background-image: url('{{.StaticURL}}{{.Branding.LogoPath}}');
}
/*{{- end}}*/
/*{{- if ne .Branding.DefaultFaviconPath .Branding.FaviconPath}}*/
.image-input-webadmin-fav-current {
background-image: url('{{.StaticURL}}{{.Branding.FaviconPath}}');
}
/*{{- end}}*/
/*{{- if ne .Branding.DefaultLogoPath .WebClientBranding.LogoPath}}*/
.image-input-webclient-current {
background-image: url('{{.StaticURL}}{{.WebClientBranding.LogoPath}}');
}
/*{{- end}}*/
/*{{- if ne .Branding.DefaultFaviconPath .WebClientBranding.FaviconPath}}*/
.image-input-webclient-fav-current {
background-image: url('{{.StaticURL}}{{.WebClientBranding.FaviconPath}}');
}
/*{{- end}}*/
</style>
{{- end}}
{{- define "page_body"}}
<div class="card shadow-sm">
<div class="card-header bg-light">
@ -352,6 +382,221 @@ explicit grant from the SFTPGo Team (support@sftpgo.com).
</div>
</div>
<div class="accordion-item">
<h2 class="accordion-header" id="accordion_header_branding">
<button class="accordion-button section-title-inner text-primary collapsed" type="button" data-bs-toggle="collapse" data-bs-target="#accordion_branding_body" aria-expanded="{{if eq .ConfigSection 4}}true{{else}}false{{end}}" aria-controls="accordion_branding_body">
<span data-i18n="branding.title">Branding</span>
</button>
</h2>
<div id="accordion_branding_body" class="accordion-collapse collapse {{if eq .ConfigSection 4}}show{{end}}" aria-labelledby="accordion_header_branding" data-bs-parent="#accordion_configs">
<div class="accordion-body">
{{- template "infomsg" "branding.help"}}
<form id="configs_branding_form" enctype="multipart/form-data" action="{{.CurrentURL}}" method="POST" autocomplete="off">
<div class="card">
<div class="card-header bg-light">
<h3 class="card-title section-title-inner">WebAdmin</h3>
</div>
<div class="card-body">
<div class="form-group row">
<label for="idBrandingWebAdminName" data-i18n="general.name" class="col-md-3 col-form-label">Name</label>
<div class="col-md-9">
<input id="idBrandingWebAdminName" type="text" placeholder="SFTPGo WebAdmin" name="branding_webadmin_name" value="{{.Configs.Branding.WebAdmin.Name}}" maxlength="255"
class="form-control" />
</div>
</div>
<div class="form-group row mt-10">
<label for="idBrandingWebAdminShortName" data-i18n="branding.short_name" class="col-md-3 col-form-label">Short Name</label>
<div class="col-md-9">
<input id="idBrandingWebAdminShortName" type="text" placeholder="WebAdmin" name="branding_webadmin_short_name" value="{{.Configs.Branding.WebAdmin.ShortName}}" maxlength="255"
class="form-control" />
</div>
</div>
<div class="form-group row mt-10">
<label for="idBrandingWebAdminLogo" data-i18n="branding.logo" class="col-md-3 col-form-label">Logo</label>
<div class="col-md-9">
<div class="image-input image-input-outline {{if eq .Branding.DefaultLogoPath .Branding.LogoPath}}image-input-empty{{end}} image-input-placeholder" data-kt-image-input="true">
<div class="image-input-wrapper w-125px h-125px image-input-webadmin-current"></div>
<label class="btn btn-icon btn-circle btn-color-muted btn-active-color-primary w-25px h-25px bg-body shadow"
data-kt-image-input-action="change">
<i class="ki-duotone ki-pencil fs-6"><span class="path1"></span><span class="path2"></span></i>
<input type="file" id="idBrandingWebAdminLogo" name="branding_webadmin_logo" accept=".png" aria-describedby="idBrandingWebAdminLogoHelp"/>
<input type="hidden"name="branding_webadmin_logo_remove" />
</label>
<span class="btn btn-icon btn-circle btn-color-muted btn-active-color-primary w-25px h-25px bg-body shadow"
data-kt-image-input-action="cancel">
<i class="ki-outline ki-cross fs-3"></i>
</span>
<span class="btn btn-icon btn-circle btn-color-muted btn-active-color-primary w-25px h-25px bg-body shadow"
data-kt-image-input-action="remove">
<i class="ki-outline ki-cross fs-3"></i>
</span>
</div>
<div id="idBrandingWebAdminLogoHelp" class="form-text mt-3" data-i18n="branding.logo_help"></div>
</div>
</div>
<div class="form-group row mt-10">
<label for="idBrandingWebAdminFavicon" data-i18n="branding.favicon" class="col-md-3 col-form-label">Favicon</label>
<div class="col-md-9">
<div class="image-input image-input-outline {{if eq .Branding.DefaultFaviconPath .Branding.FaviconPath}}image-input-empty{{end}} image-input-placeholder" data-kt-image-input="true">
<div class="image-input-wrapper w-50px h-50px image-input-webadmin-fav-current"></div>
<label class="btn btn-icon btn-circle btn-color-muted btn-active-color-primary w-25px h-25px bg-body shadow"
data-kt-image-input-action="change">
<i class="ki-duotone ki-pencil fs-6"><span class="path1"></span><span class="path2"></span></i>
<input type="file" id="idBrandingWebAdminFavicon" name="branding_webadmin_favicon" accept=".png" />
<input type="hidden"name="branding_webadmin_favicon_remove" />
</label>
<span class="btn btn-icon btn-circle btn-color-muted btn-active-color-primary w-25px h-25px bg-body shadow"
data-kt-image-input-action="cancel">
<i class="ki-outline ki-cross fs-3"></i>
</span>
<span class="btn btn-icon btn-circle btn-color-muted btn-active-color-primary w-25px h-25px bg-body shadow"
data-kt-image-input-action="remove">
<i class="ki-outline ki-cross fs-3"></i>
</span>
</div>
</div>
</div>
<div class="form-group row mt-10">
<label for="idBrandingWebAdminDisclaimerName" data-i18n="branding.disclaimer_name" class="col-md-3 col-form-label">Disclaimer Name</label>
<div class="col-md-9">
<input id="idBrandingWebAdminDisclaimerName" type="text" placeholder="" name="branding_webadmin_disclaimer_name" value="{{.Configs.Branding.WebAdmin.DisclaimerName}}" maxlength="255"
class="form-control" />
</div>
</div>
<div class="form-group row mt-10">
<label for="idBrandingWebAdminDisclaimerURL" data-i18n="branding.disclaimer_url" class="col-md-3 col-form-label">Disclaimer URL</label>
<div class="col-md-9">
<input id="idBrandingWebAdminDisclaimerURL" type="text" placeholder="" name="branding_webadmin_disclaimer_url" value="{{.Configs.Branding.WebAdmin.DisclaimerURL}}" maxlength="1024"
class="form-control" />
</div>
</div>
</div>
</div>
<div class="card mt-10">
<div class="card-header bg-light">
<h3 class="card-title section-title-inner">WebClient</h3>
</div>
<div class="card-body">
<div class="form-group row">
<label for="idBrandingWebClientName" data-i18n="general.name" class="col-md-3 col-form-label">Name</label>
<div class="col-md-9">
<input id="idBrandingWebClientName" type="text" placeholder="SFTPGo WebClient" name="branding_webclient_name" value="{{.Configs.Branding.WebClient.Name}}" maxlength="255"
class="form-control" />
</div>
</div>
<div class="form-group row mt-10">
<label for="idBrandingWebClientShortName" data-i18n="branding.short_name" class="col-md-3 col-form-label">Short Name</label>
<div class="col-md-9">
<input id="idBrandingWebClientShortName" type="text" placeholder="WebClient" name="branding_webclient_short_name" value="{{.Configs.Branding.WebClient.ShortName}}" maxlength="255"
class="form-control" />
</div>
</div>
<div class="form-group row mt-10">
<label for="idBrandingWebClientLogo" data-i18n="branding.logo" class="col-md-3 col-form-label">Logo</label>
<div class="col-md-9">
<div class="image-input image-input-outline {{if eq .Branding.DefaultLogoPath .WebClientBranding.LogoPath}}image-input-empty{{end}} image-input-placeholder" data-kt-image-input="true">
<div class="image-input-wrapper w-125px h-125px image-input-webclient-current"></div>
<label class="btn btn-icon btn-circle btn-color-muted btn-active-color-primary w-25px h-25px bg-body shadow"
data-kt-image-input-action="change">
<i class="ki-duotone ki-pencil fs-6"><span class="path1"></span><span class="path2"></span></i>
<input type="file" id="idBrandingWebClientLogo" name="branding_webclient_logo" accept=".png" aria-describedby="idBrandingWebClientLogoHelp"/>
<input type="hidden"name="branding_webclient_logo_remove" />
</label>
<span class="btn btn-icon btn-circle btn-color-muted btn-active-color-primary w-25px h-25px bg-body shadow"
data-kt-image-input-action="cancel">
<i class="ki-outline ki-cross fs-3"></i>
</span>
<span class="btn btn-icon btn-circle btn-color-muted btn-active-color-primary w-25px h-25px bg-body shadow"
data-kt-image-input-action="remove">
<i class="ki-outline ki-cross fs-3"></i>
</span>
</div>
<div id="idBrandingWebClientLogoHelp" class="form-text mt-3" data-i18n="branding.logo_help"></div>
</div>
</div>
<div class="form-group row mt-10">
<label for="idBrandingWebClientFavicon" data-i18n="branding.favicon" class="col-md-3 col-form-label">Favicon</label>
<div class="col-md-9">
<div class="image-input image-input-outline {{if eq .Branding.DefaultFaviconPath .WebClientBranding.FaviconPath}}image-input-empty{{end}} image-input-placeholder" data-kt-image-input="true">
<div class="image-input-wrapper w-50px h-50px image-input-webclient-fav-current"></div>
<label class="btn btn-icon btn-circle btn-color-muted btn-active-color-primary w-25px h-25px bg-body shadow"
data-kt-image-input-action="change">
<i class="ki-duotone ki-pencil fs-6"><span class="path1"></span><span class="path2"></span></i>
<input type="file" id="idBrandingWebClientFavicon" name="branding_webclient_favicon" accept=".png" />
<input type="hidden"name="branding_webclient_favicon_remove" />
</label>
<span class="btn btn-icon btn-circle btn-color-muted btn-active-color-primary w-25px h-25px bg-body shadow"
data-kt-image-input-action="cancel">
<i class="ki-outline ki-cross fs-3"></i>
</span>
<span class="btn btn-icon btn-circle btn-color-muted btn-active-color-primary w-25px h-25px bg-body shadow"
data-kt-image-input-action="remove">
<i class="ki-outline ki-cross fs-3"></i>
</span>
</div>
</div>
</div>
<div class="form-group row mt-10">
<label for="idBrandingWebClientDisclaimerName" data-i18n="branding.disclaimer_name" class="col-md-3 col-form-label">Disclaimer Name</label>
<div class="col-md-9">
<input id="idBrandingWebClientDisclaimerName" type="text" placeholder="" name="branding_webclient_disclaimer_name" value="{{.Configs.Branding.WebClient.DisclaimerName}}" maxlength="255"
class="form-control" />
</div>
</div>
<div class="form-group row mt-10">
<label for="idBrandingWebClientDisclaimerURL" data-i18n="branding.disclaimer_url" class="col-md-3 col-form-label">Disclaimer URL</label>
<div class="col-md-9">
<input id="idBrandingWebClientDisclaimerURL" type="text" placeholder="" name="branding_webclient_disclaimer_url" value="{{.Configs.Branding.WebClient.DisclaimerURL}}" maxlength="1024"
class="form-control" />
</div>
</div>
</div>
</div>
<div class="d-flex justify-content-end mt-12">
<input type="hidden" name="_form_token" value="{{.CSRFToken}}">
<input type="hidden" name="form_action" value="branding_submit">
<button type="submit" id="branding_form_submit" class="btn btn-primary px-10">
<span data-i18n="general.submit" class="indicator-label">
Submit
</span>
<span data-i18n="general.wait" class="indicator-progress">
Please wait...
<span class="spinner-border spinner-border-sm align-middle ms-2"></span>
</span>
</button>
</div>
</form>
</div>
</div>
</div>
</div>
</div>
</div>
@ -591,6 +836,12 @@ explicit grant from the SFTPGo Team (support@sftpgo.com).
submitButton.setAttribute('data-kt-indicator', 'on');
submitButton.disabled = true;
});
$('#configs_branding_form').submit(function (event) {
let submitButton = document.querySelector('#branding_form_submit');
submitButton.setAttribute('data-kt-indicator', 'on');
submitButton.disabled = true;
});
});
</script>
{{- end}}

View file

@ -26,6 +26,7 @@ explicit grant from the SFTPGo Team (support@sftpgo.com).
</i>
</div>
<div class="menu menu-sub menu-sub-dropdown menu-column menu-rounded menu-title-gray-700 menu-icon-gray-500 menu-active-bg menu-state-color fw-semibold py-4 w-250px" data-kt-menu="true">
{{- if not .IsLoggedToShare }}
<div class="menu-item px-3 my-0">
<div class="menu-content d-flex align-items-center px-3 py-2">
<div class="me-5">
@ -54,6 +55,7 @@ explicit grant from the SFTPGo Team (support@sftpgo.com).
</a>
</div>
{{- end}}
{{- end}}
<div class="menu-item px-3 my-0">
<a id="id_logout_link" href="#" class="menu-link px-3 py-2">
<span data-i18n="login.signout" class="menu-title">Sign out</span>

View file

@ -25,9 +25,16 @@ explicit grant from the SFTPGo Team (support@sftpgo.com).
{{.Branding.ShortName}}
</span>
</div>
<div class="card shadow-sm w-lg-600px">
<div class="card shadow-sm w-lg-600px w-md-450px">
<div class="card-header bg-light">
<h3 data-i18n="fs.download_ready" class="card-title section-title">Your download is ready</h3>
{{- if .LogoutURL}}
<div class="card-toolbar">
<a id="id_logout_link" href="#" class="btn btn-light-primary">
<span data-i18n="login.signout">Sign out</span>
</a>
</div>
{{- end}}
</div>
<div class="card-body">
<div>

View file

@ -25,9 +25,16 @@ explicit grant from the SFTPGo Team (support@sftpgo.com).
{{.Branding.ShortName}}
</span>
</div>
<div class="card shadow-sm w-lg-600px">
<div class="card shadow-sm w-lg-600px w-md-450px">
<div class="card-header bg-light">
<h3 data-i18n="title.upload_to_share" class="card-title section-title">Upload one or more files to share</h3>
{{- if .LogoutURL}}
<div class="card-toolbar">
<a id="id_logout_link" href="#" class="btn btn-light-primary">
<span data-i18n="login.signout">Sign out</span>
</a>
</div>
{{- end}}
</div>
<div class="card-body">
{{- template "errmsg" ""}}

View file

@ -20,7 +20,7 @@ explicit grant from the SFTPGo Team (support@sftpgo.com).
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="robots" content="noindex">
<link rel="shortcut icon" href="{{.StaticURL}}{{.Branding.FaviconPath}}" />
<link rel="icon" type="image/png" href="{{.StaticURL}}{{.Branding.FaviconPath}}" />
{{range .Branding.ExtraCSS}}
<link href="{{$.StaticURL}}{{.}}" rel="stylesheet" type="text/css">