Compare commits
217 commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
37f8fb3a0e | ||
![]() |
484bda7940 | ||
![]() |
deea9ff038 | ||
![]() |
91340bbe2f | ||
![]() |
e689d52dca | ||
![]() |
22f80b97f0 | ||
![]() |
dee3f3f87a | ||
![]() |
d2c5a6a914 | ||
![]() |
1a7f346b51 | ||
![]() |
bb579e36db | ||
![]() |
843b8c38d3 | ||
![]() |
70fc00d7eb | ||
![]() |
9f873d1059 | ||
![]() |
b0061f570e | ||
![]() |
bfe6c58133 | ||
![]() |
8c5f92aeb1 | ||
![]() |
ec90b61bb4 | ||
![]() |
6a72552754 | ||
![]() |
1ce408e673 | ||
![]() |
d3db80dc32 | ||
![]() |
c56be285a5 | ||
![]() |
599ee5a58f | ||
![]() |
7703f57122 | ||
![]() |
b8a4ea50bd | ||
![]() |
49f2555914 | ||
![]() |
e21c989038 | ||
![]() |
f8bdb84e8d | ||
![]() |
e161015c67 | ||
![]() |
cbd7fc917e | ||
![]() |
6a7c8df1ef | ||
![]() |
d3e76898cd | ||
![]() |
0f9314f900 | ||
![]() |
502e3658e0 | ||
![]() |
0e77ba9546 | ||
![]() |
10b2e5671b | ||
![]() |
ebc085da77 | ||
![]() |
4a414f0fa4 | ||
![]() |
7a12db6cdb | ||
![]() |
f30a9a2095 | ||
![]() |
ed5ff9c5cc | ||
![]() |
59833fba0d | ||
![]() |
a79cb30cdc | ||
![]() |
e1cd69d5ff | ||
![]() |
85333087fa | ||
![]() |
5ddac4b3b4 | ||
![]() |
c37b7f0493 | ||
![]() |
5896c1b7a5 | ||
![]() |
0f073a40fd | ||
![]() |
618723c457 | ||
![]() |
4cb6acefb2 | ||
![]() |
f22ec2275f | ||
![]() |
7bffed712a | ||
![]() |
f30d6ad82a | ||
![]() |
b524da11e9 | ||
![]() |
3dd412f6e3 | ||
![]() |
ef98ee7d11 | ||
![]() |
30fb1d6240 | ||
![]() |
7aac64531f | ||
![]() |
03724d5eb1 | ||
![]() |
4eb4ff66ce | ||
![]() |
0bff3e1a67 | ||
![]() |
82b437c502 | ||
![]() |
88b1850b58 | ||
![]() |
60558de728 | ||
![]() |
beff4432dc | ||
![]() |
9ae0bc4ec4 | ||
![]() |
21bd8c5660 | ||
![]() |
97bb004c12 | ||
![]() |
e4e31ec4fb | ||
![]() |
259986ed1d | ||
![]() |
0c75d234b9 | ||
![]() |
ae1487d733 | ||
![]() |
c69fbe6bf9 | ||
![]() |
8d697bcc94 | ||
![]() |
7e7005f5b3 | ||
![]() |
12a210e1f6 | ||
![]() |
169d8f6223 | ||
![]() |
cd3147c654 | ||
![]() |
7feeec6941 | ||
![]() |
12d888f49d | ||
![]() |
ca41b59fc4 | ||
![]() |
77b2f8dfb3 | ||
![]() |
d8691d1e1a | ||
![]() |
5cb1b9c1e9 | ||
![]() |
b23e67ae6a | ||
![]() |
8e7086ab39 | ||
![]() |
dc907c0ba3 | ||
![]() |
eba4c93efd | ||
![]() |
bdd6de10a5 | ||
![]() |
66e1e7ac2b | ||
![]() |
4103344989 | ||
![]() |
0c470b9202 | ||
![]() |
57309dcd5f | ||
![]() |
72ba54b5be | ||
![]() |
18bf0c6121 | ||
![]() |
f88ce014df | ||
![]() |
3b2f709aeb | ||
![]() |
6626c8846b | ||
![]() |
2ecd20d444 | ||
![]() |
46e64706ea | ||
![]() |
bdc5493593 | ||
![]() |
424999dacd | ||
![]() |
2ec6aecc5d | ||
![]() |
addee12510 | ||
![]() |
57c0ca90e5 | ||
![]() |
85c65dcad3 | ||
![]() |
a34d9bc916 | ||
![]() |
27e98b85ce | ||
![]() |
ae5ecbe909 | ||
![]() |
126cb1ee0d | ||
![]() |
eeef23139d | ||
![]() |
433d45ed87 | ||
![]() |
5f67fcdce5 | ||
![]() |
9288010636 | ||
![]() |
5162c5de87 | ||
![]() |
c2aed5ee92 | ||
![]() |
44ebf2f48d | ||
![]() |
6896d2bfb1 | ||
![]() |
14cabda5c2 | ||
![]() |
8cf0491b65 | ||
![]() |
1f46df0d60 | ||
![]() |
1b928ef6b2 | ||
![]() |
db35a55a3d | ||
![]() |
fd6126134e | ||
![]() |
3b5fba2eec | ||
![]() |
eb5ffb940e | ||
![]() |
bb422ad5b9 | ||
![]() |
53c3905ce3 | ||
![]() |
dc42680e1c | ||
![]() |
56ef9355da | ||
![]() |
d8e4978b61 | ||
![]() |
b9b370fbb8 | ||
![]() |
dfeca3a972 | ||
![]() |
a2934deaa6 | ||
![]() |
2fbf608895 | ||
![]() |
d783ffc13f | ||
![]() |
62426d25da | ||
![]() |
fa710b36c2 | ||
![]() |
321c3f00d2 | ||
![]() |
ec4bf3d76a | ||
![]() |
68e62d3d9b | ||
![]() |
954c36c0a2 | ||
![]() |
81433e00d1 | ||
![]() |
a5c5e85144 | ||
![]() |
b94451f731 | ||
![]() |
4edecc5c77 | ||
![]() |
51e9a689a6 | ||
![]() |
aa920432f3 | ||
![]() |
ce189e5065 | ||
![]() |
00155eaaf6 | ||
![]() |
d94f80c8da | ||
![]() |
bd5eb03d9c | ||
![]() |
6ba1198c47 | ||
![]() |
b5c821795a | ||
![]() |
b2926377b7 | ||
![]() |
99f47ca4e7 | ||
![]() |
fef388d8cb | ||
![]() |
92849ca473 | ||
![]() |
0952887157 | ||
![]() |
d010b26e1c | ||
![]() |
58de410850 | ||
![]() |
54bc3ea87d | ||
![]() |
64a2f7aa4f | ||
![]() |
55be9f0b9c | ||
![]() |
97ffa0394f | ||
![]() |
dc91ec2056 | ||
![]() |
356795f8b0 | ||
![]() |
3efcd94e14 | ||
![]() |
34bc21b3b7 | ||
![]() |
37845c2936 | ||
![]() |
47924716c1 | ||
![]() |
1d60505629 | ||
![]() |
9daf0ba767 | ||
![]() |
bdae378569 | ||
![]() |
363770ab84 | ||
![]() |
8bc08b25dc | ||
![]() |
e0c1b974c9 | ||
![]() |
39cf9f6943 | ||
![]() |
d650defa08 | ||
![]() |
c5c42f072b | ||
![]() |
bd5b32101f | ||
![]() |
8208ac817d | ||
![]() |
a99c4879de | ||
![]() |
01b666a78f | ||
![]() |
8294952474 | ||
![]() |
7fb5b1b996 | ||
![]() |
2749a98f26 | ||
![]() |
08526da153 | ||
![]() |
8269adf176 | ||
![]() |
0cddcba5a7 | ||
![]() |
3bd1eeacc1 | ||
![]() |
1698ec2eb3 | ||
![]() |
07710ad98d | ||
![]() |
f63bf7093c | ||
![]() |
0597bf1047 | ||
![]() |
5bde4b92a2 | ||
![]() |
faa994e3b3 | ||
![]() |
68cc1a8e2c | ||
![]() |
9c775e2213 | ||
![]() |
6c94173ca1 | ||
![]() |
d1e0560d28 | ||
![]() |
52a94b2593 | ||
![]() |
9550fd2921 | ||
![]() |
a6549b08f9 | ||
![]() |
ba3e2ecb5f | ||
![]() |
2bd3b46e3f | ||
![]() |
7831ddaede | ||
![]() |
613f2f1c24 | ||
![]() |
525f33a07a | ||
![]() |
3f2604d33f | ||
![]() |
b823bb04d2 | ||
![]() |
9ba92d9495 | ||
![]() |
0127fc188b | ||
![]() |
3c7a651d27 | ||
![]() |
50a3c0d911 | ||
![]() |
b2bea85add | ||
![]() |
61bc0065f9 |
150 changed files with 7094 additions and 2620 deletions
|
@ -8,11 +8,11 @@ freebsd_task:
|
|||
|
||||
pkginstall_script:
|
||||
- pkg update -f
|
||||
- pkg install -y go122
|
||||
- pkg install -y go123
|
||||
- pkg install -y git
|
||||
|
||||
setup_script:
|
||||
- ln -s /usr/local/bin/go122 /usr/local/bin/go
|
||||
- ln -s /usr/local/bin/go123 /usr/local/bin/go
|
||||
- pw groupadd sftpgo
|
||||
- pw useradd sftpgo -g sftpgo -w none -m
|
||||
- mkdir /home/sftpgo/sftpgo
|
||||
|
|
2
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
2
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
|
@ -8,7 +8,7 @@ body:
|
|||
value: |
|
||||
### 👍 Thank you for contributing to our project!
|
||||
Before asking for help please check the [support policy](https://github.com/drakkan/sftpgo#support-policy).
|
||||
If you are a commercial user or a project sponsor please contact us using the dedicated [email address](mailto:support@sftpgo.com).
|
||||
If you are a [commercial user](https://sftpgo.com/) or a project sponsor please contact us using the dedicated [email address](mailto:support@sftpgo.com).
|
||||
- type: checkboxes
|
||||
id: before-posting
|
||||
attributes:
|
||||
|
|
14
.github/workflows/development.yml
vendored
14
.github/workflows/development.yml
vendored
|
@ -2,7 +2,7 @@ name: CI
|
|||
|
||||
on:
|
||||
push:
|
||||
branches: [2.6.x]
|
||||
branches: [main]
|
||||
pull_request:
|
||||
|
||||
permissions:
|
||||
|
@ -15,7 +15,7 @@ jobs:
|
|||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
go: ['1.22']
|
||||
go: ['1.23']
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
|
||||
steps:
|
||||
|
@ -50,7 +50,7 @@ jobs:
|
|||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
file: ./coverage.txt
|
||||
files: ./coverage.txt
|
||||
fail_ci_if_error: false
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
|
@ -119,7 +119,7 @@ jobs:
|
|||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version: '1.23'
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
|
@ -326,7 +326,7 @@ jobs:
|
|||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version: '1.23'
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
|
@ -390,7 +390,7 @@ jobs:
|
|||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version: '1.23'
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
|
@ -611,7 +611,7 @@ jobs:
|
|||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version: '1.23'
|
||||
- uses: actions/checkout@v4
|
||||
- name: Run golangci-lint
|
||||
uses: golangci/golangci-lint-action@v6
|
||||
|
|
4
.github/workflows/docker.yml
vendored
4
.github/workflows/docker.yml
vendored
|
@ -5,7 +5,7 @@ on:
|
|||
# - cron: '0 4 * * *' # everyday at 4:00 AM UTC
|
||||
push:
|
||||
branches:
|
||||
- 2.6.x
|
||||
- main
|
||||
tags:
|
||||
- v*
|
||||
pull_request:
|
||||
|
@ -163,7 +163,7 @@ jobs:
|
|||
if: ${{ github.event_name != 'pull_request' }}
|
||||
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
builder: ${{ steps.builder.outputs.name }}
|
||||
|
|
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
|
@ -9,7 +9,7 @@ permissions:
|
|||
contents: write
|
||||
|
||||
env:
|
||||
GO_VERSION: 1.22.9
|
||||
GO_VERSION: 1.23.3
|
||||
|
||||
jobs:
|
||||
prepare-sources-with-deps:
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
FROM golang:1.22-bookworm as builder
|
||||
FROM golang:1.23-bookworm AS builder
|
||||
|
||||
ENV GOFLAGS="-mod=readonly"
|
||||
|
||||
|
@ -10,7 +10,7 @@ WORKDIR /workspace
|
|||
ARG GOPROXY
|
||||
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
RUN go mod download && go mod verify
|
||||
|
||||
ARG COMMIT_SHA
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
FROM golang:1.22-alpine3.19 AS builder
|
||||
FROM golang:1.23-alpine3.21 AS builder
|
||||
|
||||
ENV GOFLAGS="-mod=readonly"
|
||||
|
||||
|
@ -10,7 +10,7 @@ WORKDIR /workspace
|
|||
ARG GOPROXY
|
||||
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
RUN go mod download && go mod verify
|
||||
|
||||
ARG COMMIT_SHA
|
||||
|
||||
|
@ -25,7 +25,7 @@ RUN set -xe && \
|
|||
export COMMIT_SHA=${COMMIT_SHA:-$(git describe --always --abbrev=8 --dirty)} && \
|
||||
go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${COMMIT_SHA} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -v -o sftpgo
|
||||
|
||||
FROM alpine:3.19
|
||||
FROM alpine:3.21
|
||||
|
||||
# Set to "true" to install jq and the optional git and rsync dependencies
|
||||
ARG INSTALL_OPTIONAL_PACKAGES=false
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
FROM golang:1.22-bookworm as builder
|
||||
FROM golang:1.23-bookworm AS builder
|
||||
|
||||
ENV CGO_ENABLED=0 GOFLAGS="-mod=readonly"
|
||||
|
||||
|
@ -10,7 +10,7 @@ WORKDIR /workspace
|
|||
ARG GOPROXY
|
||||
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
RUN go mod download && go mod verify
|
||||
|
||||
ARG COMMIT_SHA
|
||||
|
||||
|
|
|
@ -58,6 +58,12 @@ We offer commercial support, guarantees, and advice for SFTPGo:
|
|||
|
||||
You can read more about supported features and documentation at [docs.sftpgo.com](https://docs.sftpgo.com/).
|
||||
|
||||
## Internationalization
|
||||
|
||||
The translations are available via [Crowdin](https://crowdin.com/project/sftpgo), who have granted us an open source license.
|
||||
|
||||
Before start translating please take a look at our contribution [guidelines](https://sftpgo.github.io/latest/web-interfaces/#internationalization).
|
||||
|
||||
## Release Cadence
|
||||
|
||||
SFTPGo releases are feature-driven, we don't have a fixed time based schedule. As a rough estimate, you can expect 1 or 2 new major releases per year and several bug fix releases.
|
||||
|
|
8
go.mod
8
go.mod
|
@ -35,11 +35,11 @@ require (
|
|||
github.com/hashicorp/go-hclog v1.6.3
|
||||
github.com/hashicorp/go-plugin v1.6.2
|
||||
github.com/hashicorp/go-retryablehttp v0.7.7
|
||||
github.com/jackc/pgx/v5 v5.7.1
|
||||
github.com/jackc/pgx/v5 v5.7.2
|
||||
github.com/jlaffaye/ftp v0.2.0
|
||||
github.com/klauspost/compress v1.17.11
|
||||
github.com/lestrrat-go/jwx/v2 v2.1.3
|
||||
github.com/lithammer/shortuuid/v3 v3.0.7
|
||||
github.com/lithammer/shortuuid/v4 v4.2.0
|
||||
github.com/mattn/go-sqlite3 v1.14.24
|
||||
github.com/mhale/smtpd v0.8.3
|
||||
github.com/minio/sio v0.4.1
|
||||
|
@ -52,7 +52,7 @@ require (
|
|||
github.com/rs/cors v1.11.1
|
||||
github.com/rs/xid v1.6.0
|
||||
github.com/rs/zerolog v1.33.0
|
||||
github.com/sftpgo/sdk v0.1.8
|
||||
github.com/sftpgo/sdk v0.1.9-0.20241011171103-64fc18a344f9
|
||||
github.com/shirou/gopsutil/v3 v3.24.5
|
||||
github.com/spf13/afero v1.11.0
|
||||
github.com/spf13/cobra v1.8.1
|
||||
|
@ -199,5 +199,5 @@ replace (
|
|||
github.com/fclairamb/ftpserverlib => github.com/drakkan/ftpserverlib v0.0.0-20240603150004-6a8f643fbf2e
|
||||
github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20240430173938-7ba8270c8e7f
|
||||
github.com/robfig/cron/v3 => github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0
|
||||
golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20241215105839-39c227bf1e16
|
||||
golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20241215104834-a9cd4736223d
|
||||
)
|
||||
|
|
17
go.sum
17
go.sum
|
@ -134,8 +134,8 @@ github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 h1:rpfIENRNNilwHwZeG5+P150SMrnN
|
|||
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0=
|
||||
github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0 h1:EW9gIJRmt9lzk66Fhh4S8VEtURA6QHZqGeSRE9Nb2/U=
|
||||
github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||
github.com/drakkan/crypto v0.0.0-20241215105839-39c227bf1e16 h1:L5sDMg/hkI4OLReZ9niFvGAJs0vhMtWJreTeRXliNrQ=
|
||||
github.com/drakkan/crypto v0.0.0-20241215105839-39c227bf1e16/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
github.com/drakkan/crypto v0.0.0-20241215104834-a9cd4736223d h1:xpQVtm9fMX+Moy260dDuYlKyER0L8jj7WJF6bcLXtL4=
|
||||
github.com/drakkan/crypto v0.0.0-20241215104834-a9cd4736223d/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
github.com/drakkan/ftp v0.0.0-20240430173938-7ba8270c8e7f h1:S9JUlrOzjK58UKoLqqb40YLyVlt0bcIFtYrvnanV3zc=
|
||||
github.com/drakkan/ftp v0.0.0-20240430173938-7ba8270c8e7f/go.mod h1:4p8lUl4vQ80L598CygL+3IFtm+3nggvvW/palOlViwE=
|
||||
github.com/drakkan/ftpserverlib v0.0.0-20240603150004-6a8f643fbf2e h1:VBpqQeChkGXSV1FXCtvd3BJTyB+DcMgiu7SfkpsGuKw=
|
||||
|
@ -225,7 +225,6 @@ github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO
|
|||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/wire v0.6.0 h1:HBkoIh4BdSxoyo9PveV8giw7ZsaBOvzWKfcg/6MrVwI=
|
||||
|
@ -257,8 +256,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
|
|||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs=
|
||||
github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA=
|
||||
github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI=
|
||||
github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jhump/protoreflect v1.15.1 h1:HUMERORf3I3ZdX05WaQ6MIpd/NJ434hTp5YiKgfCL6c=
|
||||
|
@ -287,8 +286,8 @@ github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNB
|
|||
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/lithammer/shortuuid/v3 v3.0.7 h1:trX0KTHy4Pbwo/6ia8fscyHoGA+mf1jWbPJVuvyJQQ8=
|
||||
github.com/lithammer/shortuuid/v3 v3.0.7/go.mod h1:vMk8ke37EmiewwolSO1NLW8vP4ZaKlRuDIi8tWWmAts=
|
||||
github.com/lithammer/shortuuid/v4 v4.2.0 h1:LMFOzVB3996a7b8aBuEXxqOBflbfPQAiVzkIcHO0h8c=
|
||||
github.com/lithammer/shortuuid/v4 v4.2.0/go.mod h1:D5noHZ2oFw/YaKCfGy0YxyE7M0wMbezmMjPdhyEFe6Y=
|
||||
github.com/lufia/plan9stats v0.0.0-20240909124753-873cd0166683 h1:7UMa6KCCMjZEMDtTVdcGu0B1GmmC7QJKiCCjyTAWQy0=
|
||||
github.com/lufia/plan9stats v0.0.0-20240909124753-873cd0166683/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k=
|
||||
github.com/magiconair/properties v1.8.9 h1:nWcCbLq1N2v/cpNsy5WvQ37Fb+YElfq20WJ/a8RkpQM=
|
||||
|
@ -369,8 +368,8 @@ github.com/secsy/goftp v0.0.0-20200609142545-aa2de14babf4 h1:PT+ElG/UUFMfqy5HrxJ
|
|||
github.com/secsy/goftp v0.0.0-20200609142545-aa2de14babf4/go.mod h1:MnkX001NG75g3p8bhFycnyIjeQoOjGL6CEIsdE/nKSY=
|
||||
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
|
||||
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
|
||||
github.com/sftpgo/sdk v0.1.8 h1:HAywJl9jZnigFGztA/CWLieOW+R+HH6js6o6/qYvuSY=
|
||||
github.com/sftpgo/sdk v0.1.8/go.mod h1:Isl0IEzS/Muvh8Fr4X+NWFsOS/fZQHRD4oPQpoY7C4g=
|
||||
github.com/sftpgo/sdk v0.1.9-0.20241011171103-64fc18a344f9 h1:wlXBnaNfJJJRZjHO2AerSS5gp0ckkYUgBzSXivUo0Wo=
|
||||
github.com/sftpgo/sdk v0.1.9-0.20241011171103-64fc18a344f9/go.mod h1:ehimvlTP+XTEiE3t1CPwWx9n7+6A6OGvMGlZ7ouvKFk=
|
||||
github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI=
|
||||
github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk=
|
||||
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
|
||||
|
|
|
@ -30,6 +30,7 @@ import (
|
|||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -43,6 +44,7 @@ import (
|
|||
"github.com/go-acme/lego/v4/log"
|
||||
"github.com/go-acme/lego/v4/providers/http/webroot"
|
||||
"github.com/go-acme/lego/v4/registration"
|
||||
"github.com/hashicorp/go-retryablehttp"
|
||||
"github.com/robfig/cron/v3"
|
||||
|
||||
"github.com/drakkan/sftpgo/v2/internal/common"
|
||||
|
@ -249,7 +251,7 @@ func (c *Configuration) Initialize(configDir string) error {
|
|||
if c.RenewDays < 1 {
|
||||
return fmt.Errorf("invalid number of days remaining before renewal: %d", c.RenewDays)
|
||||
}
|
||||
if !util.Contains(supportedKeyTypes, c.KeyType) {
|
||||
if !slices.Contains(supportedKeyTypes, c.KeyType) {
|
||||
return fmt.Errorf("invalid key type %q", c.KeyType)
|
||||
}
|
||||
caURL, err := url.Parse(c.CAEndpoint)
|
||||
|
@ -489,7 +491,15 @@ func (c *Configuration) setup() (*account, *lego.Client, error) {
|
|||
config := lego.NewConfig(&account)
|
||||
config.CADirURL = c.CAEndpoint
|
||||
config.Certificate.KeyType = certcrypto.KeyType(c.KeyType)
|
||||
config.Certificate.OverallRequestLimit = 6
|
||||
config.UserAgent = version.GetServerVersion("/", false)
|
||||
|
||||
retryClient := retryablehttp.NewClient()
|
||||
retryClient.RetryMax = 5
|
||||
retryClient.HTTPClient = config.HTTPClient
|
||||
|
||||
config.HTTPClient = retryClient.StandardClient()
|
||||
|
||||
client, err := lego.NewClient(config)
|
||||
if err != nil {
|
||||
acmeLog(logger.LevelError, "unable to get ACME client: %v", err)
|
||||
|
@ -557,6 +567,12 @@ func (c *Configuration) tryRecoverRegistration(privateKey crypto.PrivateKey) (*r
|
|||
config.CADirURL = c.CAEndpoint
|
||||
config.UserAgent = version.GetServerVersion("/", false)
|
||||
|
||||
retryClient := retryablehttp.NewClient()
|
||||
retryClient.RetryMax = 5
|
||||
retryClient.HTTPClient = config.HTTPClient
|
||||
|
||||
config.HTTPClient = retryClient.StandardClient()
|
||||
|
||||
client, err := lego.NewClient(config)
|
||||
if err != nil {
|
||||
acmeLog(logger.LevelError, "unable to get the ACME client: %v", err)
|
||||
|
|
|
@ -40,8 +40,8 @@ Please take a look at the usage below to customize the options.`,
|
|||
Run: func(_ *cobra.Command, _ []string) {
|
||||
logger.DisableLogger()
|
||||
logger.EnableConsoleLogger(zerolog.DebugLevel)
|
||||
if revertProviderTargetVersion != 28 {
|
||||
logger.WarnToConsole("Unsupported target version, 28 is the only supported one")
|
||||
if revertProviderTargetVersion != 29 {
|
||||
logger.WarnToConsole("Unsupported target version, 29 is the only supported one")
|
||||
os.Exit(1)
|
||||
}
|
||||
configDir = util.CleanDirInput(configDir)
|
||||
|
@ -71,7 +71,7 @@ Please take a look at the usage below to customize the options.`,
|
|||
|
||||
func init() {
|
||||
addConfigFlags(revertProviderCmd)
|
||||
revertProviderCmd.Flags().IntVar(&revertProviderTargetVersion, "to-version", 28, `28 means the version supported in v2.5.x`)
|
||||
revertProviderCmd.Flags().IntVar(&revertProviderTargetVersion, "to-version", 29, `29 means the version supported in v2.6.x`)
|
||||
|
||||
rootCmd.AddCommand(revertProviderCmd)
|
||||
}
|
||||
|
|
|
@ -17,10 +17,9 @@ package command
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/drakkan/sftpgo/v2/internal/util"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -117,7 +116,7 @@ func (c Config) Initialize() error {
|
|||
}
|
||||
// don't validate args, we allow to pass empty arguments
|
||||
if cmd.Hook != "" {
|
||||
if !util.Contains(supportedHooks, cmd.Hook) {
|
||||
if !slices.Contains(supportedHooks, cmd.Hook) {
|
||||
return fmt.Errorf("invalid hook name %q, supported values: %+v", cmd.Hook, supportedHooks)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
@ -86,7 +87,7 @@ func InitializeActionHandler(handler ActionHandler) {
|
|||
func ExecutePreAction(conn *BaseConnection, operation, filePath, virtualPath string, fileSize int64, openFlags int) (int, error) {
|
||||
var event *notifier.FsEvent
|
||||
hasNotifiersPlugin := plugin.Handler.HasNotifiers()
|
||||
hasHook := util.Contains(Config.Actions.ExecuteOn, operation)
|
||||
hasHook := slices.Contains(Config.Actions.ExecuteOn, operation)
|
||||
hasRules := eventManager.hasFsRules()
|
||||
if !hasHook && !hasNotifiersPlugin && !hasRules {
|
||||
return 0, nil
|
||||
|
@ -133,7 +134,7 @@ func ExecuteActionNotification(conn *BaseConnection, operation, filePath, virtua
|
|||
fileSize int64, err error, elapsed int64, metadata map[string]string,
|
||||
) error {
|
||||
hasNotifiersPlugin := plugin.Handler.HasNotifiers()
|
||||
hasHook := util.Contains(Config.Actions.ExecuteOn, operation)
|
||||
hasHook := slices.Contains(Config.Actions.ExecuteOn, operation)
|
||||
hasRules := eventManager.hasFsRules()
|
||||
if !hasHook && !hasNotifiersPlugin && !hasRules {
|
||||
return nil
|
||||
|
@ -175,7 +176,7 @@ func ExecuteActionNotification(conn *BaseConnection, operation, filePath, virtua
|
|||
}
|
||||
}
|
||||
if hasHook {
|
||||
if util.Contains(Config.Actions.ExecuteSync, operation) {
|
||||
if slices.Contains(Config.Actions.ExecuteSync, operation) {
|
||||
_, err := actionHandler.Handle(notification)
|
||||
return err
|
||||
}
|
||||
|
@ -250,7 +251,7 @@ func newActionNotification(
|
|||
type defaultActionHandler struct{}
|
||||
|
||||
func (h *defaultActionHandler) Handle(event *notifier.FsEvent) (int, error) {
|
||||
if !util.Contains(Config.Actions.ExecuteOn, event.Action) {
|
||||
if !slices.Contains(Config.Actions.ExecuteOn, event.Action) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lithammer/shortuuid/v3"
|
||||
"github.com/lithammer/shortuuid/v4"
|
||||
"github.com/rs/xid"
|
||||
"github.com/sftpgo/sdk"
|
||||
"github.com/sftpgo/sdk/plugin/notifier"
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -124,6 +125,9 @@ func init() {
|
|||
Connections.clients = clientsMap{
|
||||
clients: make(map[string]int),
|
||||
}
|
||||
Connections.transfers = clientsMap{
|
||||
clients: make(map[string]int),
|
||||
}
|
||||
Connections.perUserConns = make(map[string]int)
|
||||
Connections.mapping = make(map[string]int)
|
||||
Connections.sshMapping = make(map[string]int)
|
||||
|
@ -163,13 +167,20 @@ var (
|
|||
rateLimiters map[string][]*rateLimiter
|
||||
isShuttingDown atomic.Bool
|
||||
ftpLoginCommands = []string{"PASS", "USER"}
|
||||
fnUpdateBranding func(*dataprovider.BrandingConfigs)
|
||||
)
|
||||
|
||||
// SetUpdateBrandingFn sets the function to call to update branding configs.
|
||||
func SetUpdateBrandingFn(fn func(*dataprovider.BrandingConfigs)) {
|
||||
fnUpdateBranding = fn
|
||||
}
|
||||
|
||||
// Initialize sets the common configuration
|
||||
func Initialize(c Configuration, isShared int) error {
|
||||
isShuttingDown.Store(false)
|
||||
util.SetUmask(c.Umask)
|
||||
version.SetConfig(c.ServerVersion)
|
||||
dataprovider.SetTZ(c.TZ)
|
||||
Config = c
|
||||
Config.Actions.ExecuteOn = util.RemoveDuplicates(Config.Actions.ExecuteOn, true)
|
||||
Config.Actions.ExecuteSync = util.RemoveDuplicates(Config.Actions.ExecuteSync, true)
|
||||
|
@ -200,7 +211,7 @@ func Initialize(c Configuration, isShared int) error {
|
|||
Config.rateLimitersList = rateLimitersList
|
||||
}
|
||||
if c.DefenderConfig.Enabled {
|
||||
if !util.Contains(supportedDefenderDrivers, c.DefenderConfig.Driver) {
|
||||
if !slices.Contains(supportedDefenderDrivers, c.DefenderConfig.Driver) {
|
||||
return fmt.Errorf("unsupported defender driver %q", c.DefenderConfig.Driver)
|
||||
}
|
||||
var defender Defender
|
||||
|
@ -406,6 +417,23 @@ func AddDefenderEvent(ip, protocol string, event HostEvent) bool {
|
|||
return Config.defender.AddEvent(ip, protocol, event)
|
||||
}
|
||||
|
||||
func reloadProviderConfigs() {
|
||||
configs, err := dataprovider.GetConfigs()
|
||||
if err != nil {
|
||||
logger.Error(logSender, "", "unable to load config from provider: %v", err)
|
||||
return
|
||||
}
|
||||
configs.SetNilsToEmpty()
|
||||
if fnUpdateBranding != nil {
|
||||
fnUpdateBranding(configs.Branding)
|
||||
}
|
||||
if err := configs.SMTP.TryDecrypt(); err != nil {
|
||||
logger.Error(logSender, "", "unable to decrypt smtp config: %v", err)
|
||||
return
|
||||
}
|
||||
smtp.Activate(configs.SMTP)
|
||||
}
|
||||
|
||||
func startPeriodicChecks(duration time.Duration, isShared int) {
|
||||
startEventScheduler()
|
||||
spec := fmt.Sprintf("@every %s", duration)
|
||||
|
@ -414,7 +442,7 @@ func startPeriodicChecks(duration time.Duration, isShared int) {
|
|||
logger.Info(logSender, "", "scheduled overquota transfers check, schedule %q", spec)
|
||||
if isShared == 1 {
|
||||
logger.Info(logSender, "", "add reload configs task")
|
||||
_, err := eventScheduler.AddFunc("@every 10m", smtp.ReloadProviderConf)
|
||||
_, err := eventScheduler.AddFunc("@every 10m", reloadProviderConfigs)
|
||||
util.PanicOnError(err)
|
||||
}
|
||||
if Config.IdleTimeout > 0 {
|
||||
|
@ -609,6 +637,10 @@ type Configuration struct {
|
|||
Umask string `json:"umask" mapstructure:"umask"`
|
||||
// Defines the server version
|
||||
ServerVersion string `json:"server_version" mapstructure:"server_version"`
|
||||
// TZ defines the time zone to use for the EventManager scheduler and to
|
||||
// control time-based access restrictions. Set to "local" to use the
|
||||
// server's local time, otherwise UTC will be used.
|
||||
TZ string `json:"tz" mapstructure:"tz"`
|
||||
// Metadata configuration
|
||||
Metadata MetadataConfig `json:"metadata" mapstructure:"metadata"`
|
||||
// EventManager configuration
|
||||
|
@ -645,7 +677,7 @@ func (c *Configuration) initializeProxyProtocol() error {
|
|||
|
||||
// GetProxyListener returns a wrapper for the given listener that supports the
|
||||
// HAProxy Proxy Protocol
|
||||
func (c *Configuration) GetProxyListener(listener net.Listener) (*proxyproto.Listener, error) {
|
||||
func (c *Configuration) GetProxyListener(listener net.Listener) (net.Listener, error) {
|
||||
if c.ProxyProtocol > 0 {
|
||||
defaultPolicy := proxyproto.REQUIRE
|
||||
if c.ProxyProtocol == 1 {
|
||||
|
@ -772,7 +804,7 @@ func (c *Configuration) checkPostDisconnectHook(remoteAddr, protocol, username,
|
|||
if c.PostDisconnectHook == "" {
|
||||
return
|
||||
}
|
||||
if !util.Contains(disconnHookProtocols, protocol) {
|
||||
if !slices.Contains(disconnHookProtocols, protocol) {
|
||||
return
|
||||
}
|
||||
go c.executePostDisconnectHook(remoteAddr, protocol, username, connID, connectionTime)
|
||||
|
@ -835,6 +867,7 @@ func getProxyPolicy(allowed, skipped []func(net.IP) bool, def proxyproto.Policy)
|
|||
if err != nil {
|
||||
// Something is wrong with the source IP, better reject the
|
||||
// connection.
|
||||
logger.Error(logSender, "", "reject connection from ip %q, err: %v", connPolicyOptions.Upstream, err)
|
||||
return proxyproto.REJECT, proxyproto.ErrInvalidUpstream
|
||||
}
|
||||
|
||||
|
@ -904,7 +937,9 @@ func (c *SSHConnection) Close() error {
|
|||
type ActiveConnections struct {
|
||||
// clients contains both authenticated and estabilished connections and the ones waiting
|
||||
// for authentication
|
||||
clients clientsMap
|
||||
clients clientsMap
|
||||
// transfers contains active transfers, total and per-user
|
||||
transfers clientsMap
|
||||
transfersCheckStatus atomic.Bool
|
||||
sync.RWMutex
|
||||
connections []ActiveConnection
|
||||
|
@ -955,6 +990,9 @@ func (conns *ActiveConnections) Add(c ActiveConnection) error {
|
|||
if val := conns.perUserConns[username]; val >= maxSessions {
|
||||
return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions)
|
||||
}
|
||||
if val := conns.transfers.getTotalFrom(username); val >= maxSessions {
|
||||
return fmt.Errorf("too many open transfers: %d/%d", val, maxSessions)
|
||||
}
|
||||
}
|
||||
conns.addUserConnection(username)
|
||||
}
|
||||
|
@ -1016,7 +1054,7 @@ func (conns *ActiveConnections) Remove(connectionID string) {
|
|||
metric.UpdateActiveConnectionsSize(lastIdx)
|
||||
logger.Debug(conn.GetProtocol(), conn.GetID(), "connection removed, local address %q, remote address %q close fs error: %v, num open connections: %d",
|
||||
conn.GetLocalAddress(), conn.GetRemoteAddress(), err, lastIdx)
|
||||
if conn.GetProtocol() == ProtocolFTP && conn.GetUsername() == "" && !util.Contains(ftpLoginCommands, conn.GetCommand()) {
|
||||
if conn.GetProtocol() == ProtocolFTP && conn.GetUsername() == "" && !slices.Contains(ftpLoginCommands, conn.GetCommand()) {
|
||||
ip := util.GetIPFromRemoteAddress(conn.GetRemoteAddress())
|
||||
logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTried, ProtocolFTP,
|
||||
dataprovider.ErrNoAuthTried.Error())
|
||||
|
@ -1215,6 +1253,35 @@ func (conns *ActiveConnections) GetClientConnections() int32 {
|
|||
return conns.clients.getTotal()
|
||||
}
|
||||
|
||||
// GetTotalTransfers returns the total number of active transfers
|
||||
func (conns *ActiveConnections) GetTotalTransfers() int32 {
|
||||
return conns.transfers.getTotal()
|
||||
}
|
||||
|
||||
// IsNewTransferAllowed returns an error if the maximum number of concurrent allowed
|
||||
// transfers is exceeded
|
||||
func (conns *ActiveConnections) IsNewTransferAllowed(username string) error {
|
||||
if isShuttingDown.Load() {
|
||||
return ErrShuttingDown
|
||||
}
|
||||
if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 {
|
||||
return nil
|
||||
}
|
||||
if Config.MaxPerHostConnections > 0 {
|
||||
if transfers := conns.transfers.getTotalFrom(username); transfers >= Config.MaxPerHostConnections {
|
||||
logger.Info(logSender, "", "active transfers from user %q: %d/%d", username, transfers, Config.MaxPerHostConnections)
|
||||
return ErrConnectionDenied
|
||||
}
|
||||
}
|
||||
if Config.MaxTotalConnections > 0 {
|
||||
if transfers := conns.transfers.getTotal(); transfers >= int32(Config.MaxTotalConnections) {
|
||||
logger.Info(logSender, "", "active transfers %d/%d", transfers, Config.MaxTotalConnections)
|
||||
return ErrConnectionDenied
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsNewConnectionAllowed returns an error if the maximum number of concurrent allowed
|
||||
// connections is exceeded or a whitelist is defined and the specified ipAddr is not listed
|
||||
// or the service is shutting down
|
||||
|
@ -1255,7 +1322,11 @@ func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr, protocol string)
|
|||
}
|
||||
|
||||
// on a single SFTP connection we could have multiple SFTP channels or commands
|
||||
// so we check the estabilished connections too
|
||||
// so we check the estabilished connections and active uploads too
|
||||
if transfers := conns.transfers.getTotal(); transfers >= int32(Config.MaxTotalConnections) {
|
||||
logger.Info(logSender, "", "active transfers %d/%d", transfers, Config.MaxTotalConnections)
|
||||
return ErrConnectionDenied
|
||||
}
|
||||
|
||||
conns.RLock()
|
||||
defer conns.RUnlock()
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -652,11 +653,17 @@ func TestMaxConnections(t *testing.T) {
|
|||
|
||||
ipAddr := "192.168.7.8"
|
||||
assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolFTP))
|
||||
assert.NoError(t, Connections.IsNewTransferAllowed(userTestUsername))
|
||||
|
||||
Config.MaxTotalConnections = 1
|
||||
Config.MaxPerHostConnections = perHost
|
||||
|
||||
assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolHTTP))
|
||||
assert.NoError(t, Connections.IsNewTransferAllowed(userTestUsername))
|
||||
isShuttingDown.Store(true)
|
||||
assert.ErrorIs(t, Connections.IsNewTransferAllowed(userTestUsername), ErrShuttingDown)
|
||||
isShuttingDown.Store(false)
|
||||
|
||||
c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{})
|
||||
fakeConn := &fakeConnection{
|
||||
BaseConnection: c,
|
||||
|
@ -665,6 +672,10 @@ func TestMaxConnections(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.Len(t, Connections.GetStats(""), 1)
|
||||
assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH))
|
||||
Connections.transfers.add(userTestUsername)
|
||||
assert.Error(t, Connections.IsNewTransferAllowed(userTestUsername))
|
||||
Connections.transfers.remove(userTestUsername)
|
||||
assert.Equal(t, int32(0), Connections.GetTotalTransfers())
|
||||
|
||||
res := Connections.Close(fakeConn.GetID(), "")
|
||||
assert.True(t, res)
|
||||
|
@ -676,6 +687,9 @@ func TestMaxConnections(t *testing.T) {
|
|||
assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH))
|
||||
Connections.RemoveClientConnection(ipAddr)
|
||||
assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolWebDAV))
|
||||
Connections.transfers.add(userTestUsername)
|
||||
assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH))
|
||||
Connections.transfers.remove(userTestUsername)
|
||||
Connections.RemoveClientConnection(ipAddr)
|
||||
|
||||
Config.MaxTotalConnections = oldValue
|
||||
|
@ -1144,13 +1158,17 @@ func TestProxyProtocolVersion(t *testing.T) {
|
|||
assert.Contains(t, err.Error(), "proxy protocol not configured")
|
||||
}
|
||||
c.ProxyProtocol = 1
|
||||
proxyListener, err := c.GetProxyListener(nil)
|
||||
listener, err := c.GetProxyListener(nil)
|
||||
assert.NoError(t, err)
|
||||
proxyListener, ok := listener.(*proxyproto.Listener)
|
||||
require.True(t, ok)
|
||||
assert.NotNil(t, proxyListener.ConnPolicy)
|
||||
|
||||
c.ProxyProtocol = 2
|
||||
proxyListener, err = c.GetProxyListener(nil)
|
||||
listener, err = c.GetProxyListener(nil)
|
||||
assert.NoError(t, err)
|
||||
proxyListener, ok = listener.(*proxyproto.Listener)
|
||||
require.True(t, ok)
|
||||
assert.NotNil(t, proxyListener.ConnPolicy)
|
||||
}
|
||||
|
||||
|
@ -1281,8 +1299,8 @@ func TestFolderCopy(t *testing.T) {
|
|||
folder.ID = 2
|
||||
folder.Users = []string{"user3"}
|
||||
require.Len(t, folderCopy.Users, 2)
|
||||
require.True(t, util.Contains(folderCopy.Users, "user1"))
|
||||
require.True(t, util.Contains(folderCopy.Users, "user2"))
|
||||
require.True(t, slices.Contains(folderCopy.Users, "user1"))
|
||||
require.True(t, slices.Contains(folderCopy.Users, "user2"))
|
||||
require.Equal(t, int64(1), folderCopy.ID)
|
||||
require.Equal(t, folder.Name, folderCopy.Name)
|
||||
require.Equal(t, folder.MappedPath, folderCopy.MappedPath)
|
||||
|
@ -1298,7 +1316,7 @@ func TestFolderCopy(t *testing.T) {
|
|||
folderCopy = folder.GetACopy()
|
||||
folder.FsConfig.CryptConfig.Passphrase = kms.NewEmptySecret()
|
||||
require.Len(t, folderCopy.Users, 1)
|
||||
require.True(t, util.Contains(folderCopy.Users, "user3"))
|
||||
require.True(t, slices.Contains(folderCopy.Users, "user3"))
|
||||
require.Equal(t, int64(2), folderCopy.ID)
|
||||
require.Equal(t, folder.Name, folderCopy.Name)
|
||||
require.Equal(t, folder.MappedPath, folderCopy.MappedPath)
|
||||
|
|
|
@ -63,7 +63,7 @@ type BaseConnection struct {
|
|||
// NewBaseConnection returns a new BaseConnection
|
||||
func NewBaseConnection(id, protocol, localAddr, remoteAddr string, user dataprovider.User) *BaseConnection {
|
||||
connID := id
|
||||
if util.Contains(supportedProtocols, protocol) {
|
||||
if slices.Contains(supportedProtocols, protocol) {
|
||||
connID = fmt.Sprintf("%s_%s", protocol, id)
|
||||
}
|
||||
user.UploadBandwidth, user.DownloadBandwidth = user.GetBandwidthForIP(util.GetIPFromRemoteAddress(remoteAddr), connID)
|
||||
|
@ -132,7 +132,7 @@ func (c *BaseConnection) GetRemoteIP() string {
|
|||
// SetProtocol sets the protocol for this connection
|
||||
func (c *BaseConnection) SetProtocol(protocol string) {
|
||||
c.protocol = protocol
|
||||
if util.Contains(supportedProtocols, c.protocol) {
|
||||
if slices.Contains(supportedProtocols, c.protocol) {
|
||||
c.ID = fmt.Sprintf("%v_%v", c.protocol, c.ID)
|
||||
}
|
||||
}
|
||||
|
@ -159,6 +159,8 @@ func (c *BaseConnection) CloseFS() error {
|
|||
|
||||
// AddTransfer associates a new transfer to this connection
|
||||
func (c *BaseConnection) AddTransfer(t ActiveTransfer) {
|
||||
Connections.transfers.add(c.User.Username)
|
||||
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
|
@ -190,6 +192,8 @@ func (c *BaseConnection) AddTransfer(t ActiveTransfer) {
|
|||
|
||||
// RemoveTransfer removes the specified transfer from the active ones
|
||||
func (c *BaseConnection) RemoveTransfer(t ActiveTransfer) {
|
||||
Connections.transfers.remove(c.User.Username)
|
||||
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
|
@ -449,10 +453,7 @@ func (c *BaseConnection) RemoveFile(fs vfs.Fs, fsPath, virtualPath string, info
|
|||
if updateQuota && info.Mode()&os.ModeSymlink == 0 {
|
||||
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath))
|
||||
if err == nil {
|
||||
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, -1, -size, false) //nolint:errcheck
|
||||
if vfolder.IsIncludedInUserQuota() {
|
||||
dataprovider.UpdateUserQuota(&c.User, -1, -size, false) //nolint:errcheck
|
||||
}
|
||||
dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, -1, -size, false)
|
||||
} else {
|
||||
dataprovider.UpdateUserQuota(&c.User, -1, -size, false) //nolint:errcheck
|
||||
}
|
||||
|
@ -622,7 +623,7 @@ func (c *BaseConnection) checkCopy(srcInfo, dstInfo os.FileInfo, virtualSource,
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *BaseConnection) copyFile(virtualSourcePath, virtualTargetPath string, srcSize int64) error {
|
||||
func (c *BaseConnection) copyFile(virtualSourcePath, virtualTargetPath string, srcInfo os.FileInfo) error {
|
||||
if !c.User.HasPerm(dataprovider.PermCopy, virtualSourcePath) || !c.User.HasPerm(dataprovider.PermCopy, virtualTargetPath) {
|
||||
return c.GetPermissionDeniedError()
|
||||
}
|
||||
|
@ -640,12 +641,12 @@ func (c *BaseConnection) copyFile(virtualSourcePath, virtualTargetPath string, s
|
|||
return err
|
||||
}
|
||||
startTime := time.Now()
|
||||
numFiles, sizeDiff, err := copier.CopyFile(fsSourcePath, fsTargetPath, srcSize)
|
||||
numFiles, sizeDiff, err := copier.CopyFile(fsSourcePath, fsTargetPath, srcInfo)
|
||||
elapsed := time.Since(startTime).Nanoseconds() / 1000000
|
||||
updateUserQuotaAfterFileWrite(c, virtualTargetPath, numFiles, sizeDiff)
|
||||
logger.CommandLog(copyLogSender, fsSourcePath, fsTargetPath, c.User.Username, "", c.ID, c.protocol, -1, -1,
|
||||
"", "", "", srcSize, c.localAddr, c.remoteAddr, elapsed)
|
||||
ExecuteActionNotification(c, operationCopy, fsSourcePath, virtualSourcePath, fsTargetPath, virtualTargetPath, "", srcSize, err, elapsed, nil) //nolint:errcheck
|
||||
"", "", "", srcInfo.Size(), c.localAddr, c.remoteAddr, elapsed)
|
||||
ExecuteActionNotification(c, operationCopy, fsSourcePath, virtualSourcePath, fsTargetPath, virtualTargetPath, "", srcInfo.Size(), err, elapsed, nil) //nolint:errcheck
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -657,7 +658,7 @@ func (c *BaseConnection) copyFile(virtualSourcePath, virtualTargetPath string, s
|
|||
defer rCancelFn()
|
||||
defer reader.Close()
|
||||
|
||||
writer, numFiles, truncatedSize, wCancelFn, err := getFileWriter(c, virtualTargetPath, srcSize)
|
||||
writer, numFiles, truncatedSize, wCancelFn, err := getFileWriter(c, virtualTargetPath, srcInfo.Size())
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get writer for path %q: %w", virtualTargetPath, err)
|
||||
}
|
||||
|
@ -708,7 +709,7 @@ func (c *BaseConnection) doRecursiveCopy(virtualSourcePath, virtualTargetPath st
|
|||
return nil
|
||||
}
|
||||
|
||||
return c.copyFile(virtualSourcePath, virtualTargetPath, srcInfo.Size())
|
||||
return c.copyFile(virtualSourcePath, virtualTargetPath, srcInfo)
|
||||
}
|
||||
|
||||
func (c *BaseConnection) recursiveCopyEntries(virtualSourcePath, virtualTargetPath string, entries []os.FileInfo, recursion int) error {
|
||||
|
@ -789,10 +790,12 @@ func (c *BaseConnection) Copy(virtualSourcePath, virtualTargetPath string) error
|
|||
|
||||
// Rename renames (moves) virtualSourcePath to virtualTargetPath
|
||||
func (c *BaseConnection) Rename(virtualSourcePath, virtualTargetPath string) error {
|
||||
return c.renameInternal(virtualSourcePath, virtualTargetPath, false)
|
||||
return c.renameInternal(virtualSourcePath, virtualTargetPath, false, vfs.CheckParentDir)
|
||||
}
|
||||
|
||||
func (c *BaseConnection) renameInternal(virtualSourcePath, virtualTargetPath string, checkParentDestination bool) error { //nolint:gocyclo
|
||||
func (c *BaseConnection) renameInternal(virtualSourcePath, virtualTargetPath string, //nolint:gocyclo
|
||||
checkParentDestination bool, checks int,
|
||||
) error {
|
||||
if virtualSourcePath == virtualTargetPath {
|
||||
return fmt.Errorf("the rename source and target cannot be the same: %w", c.GetOpUnsupportedError())
|
||||
}
|
||||
|
@ -839,7 +842,7 @@ func (c *BaseConnection) renameInternal(virtualSourcePath, virtualTargetPath str
|
|||
return err
|
||||
}
|
||||
}
|
||||
if !c.hasSpaceForRename(fsSrc, virtualSourcePath, virtualTargetPath, initialSize, fsSourcePath) {
|
||||
if !c.hasSpaceForRename(fsSrc, virtualSourcePath, virtualTargetPath, initialSize, fsSourcePath, srcInfo) {
|
||||
c.Log(logger.LevelInfo, "denying cross rename due to space limit")
|
||||
return c.GetGenericError(ErrQuotaExceeded)
|
||||
}
|
||||
|
@ -850,7 +853,7 @@ func (c *BaseConnection) renameInternal(virtualSourcePath, virtualTargetPath str
|
|||
defer close(done)
|
||||
go keepConnectionAlive(c, done, 2*time.Minute)
|
||||
|
||||
files, size, err := fsDst.Rename(fsSourcePath, fsTargetPath)
|
||||
files, size, err := fsDst.Rename(fsSourcePath, fsTargetPath, checks)
|
||||
if err != nil {
|
||||
c.Log(logger.LevelError, "failed to rename %q -> %q: %+v", fsSourcePath, fsTargetPath, err)
|
||||
return c.GetFsError(fsSrc, err)
|
||||
|
@ -1115,10 +1118,7 @@ func (c *BaseConnection) truncateFile(fs vfs.Fs, fsPath, virtualPath string, siz
|
|||
sizeDiff := initialSize - size
|
||||
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath))
|
||||
if err == nil {
|
||||
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -sizeDiff, false) //nolint:errcheck
|
||||
if vfolder.IsIncludedInUserQuota() {
|
||||
dataprovider.UpdateUserQuota(&c.User, 0, -sizeDiff, false) //nolint:errcheck
|
||||
}
|
||||
dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -sizeDiff, false)
|
||||
} else {
|
||||
dataprovider.UpdateUserQuota(&c.User, 0, -sizeDiff, false) //nolint:errcheck
|
||||
}
|
||||
|
@ -1127,11 +1127,11 @@ func (c *BaseConnection) truncateFile(fs vfs.Fs, fsPath, virtualPath string, siz
|
|||
}
|
||||
|
||||
func (c *BaseConnection) checkRecursiveRenameDirPermissions(fsSrc, fsDst vfs.Fs, sourcePath, targetPath,
|
||||
virtualSourcePath, virtualTargetPath string, fi os.FileInfo,
|
||||
virtualSourcePath, virtualTargetPath string, srcInfo os.FileInfo,
|
||||
) error {
|
||||
if !c.User.HasPermissionsInside(virtualSourcePath) &&
|
||||
!c.User.HasPermissionsInside(virtualTargetPath) {
|
||||
if !c.isRenamePermitted(fsSrc, fsDst, sourcePath, targetPath, virtualSourcePath, virtualTargetPath, fi) {
|
||||
if !c.isRenamePermitted(fsSrc, fsDst, sourcePath, targetPath, virtualSourcePath, virtualTargetPath, srcInfo) {
|
||||
c.Log(logger.LevelInfo, "rename %q -> %q is not allowed, virtual destination path: %q",
|
||||
sourcePath, targetPath, virtualTargetPath)
|
||||
return c.GetPermissionDeniedError()
|
||||
|
@ -1191,7 +1191,7 @@ func (c *BaseConnection) hasRenamePerms(virtualSourcePath, virtualTargetPath str
|
|||
}
|
||||
|
||||
func (c *BaseConnection) checkFolderRename(fsSrc, fsDst vfs.Fs, fsSourcePath, fsTargetPath, virtualSourcePath,
|
||||
virtualTargetPath string, fi os.FileInfo) error {
|
||||
virtualTargetPath string, srcInfo os.FileInfo) error {
|
||||
if util.IsDirOverlapped(virtualSourcePath, virtualTargetPath, true, "/") {
|
||||
c.Log(logger.LevelDebug, "renaming the folder %q->%q is not supported: nested folders",
|
||||
virtualSourcePath, virtualTargetPath)
|
||||
|
@ -1215,7 +1215,7 @@ func (c *BaseConnection) checkFolderRename(fsSrc, fsDst vfs.Fs, fsSourcePath, fs
|
|||
return fmt.Errorf("folder %q has virtual folders inside it: %w", virtualTargetPath, c.GetOpUnsupportedError())
|
||||
}
|
||||
if err := c.checkRecursiveRenameDirPermissions(fsSrc, fsDst, fsSourcePath, fsTargetPath,
|
||||
virtualSourcePath, virtualTargetPath, fi); err != nil {
|
||||
virtualSourcePath, virtualTargetPath, srcInfo); err != nil {
|
||||
c.Log(logger.LevelDebug, "error checking recursive permissions before renaming %q: %+v", fsSourcePath, err)
|
||||
return err
|
||||
}
|
||||
|
@ -1223,7 +1223,7 @@ func (c *BaseConnection) checkFolderRename(fsSrc, fsDst vfs.Fs, fsSourcePath, fs
|
|||
}
|
||||
|
||||
func (c *BaseConnection) isRenamePermitted(fsSrc, fsDst vfs.Fs, fsSourcePath, fsTargetPath, virtualSourcePath,
|
||||
virtualTargetPath string, fi os.FileInfo,
|
||||
virtualTargetPath string, srcInfo os.FileInfo,
|
||||
) bool {
|
||||
if !c.IsSameResource(virtualSourcePath, virtualTargetPath) {
|
||||
c.Log(logger.LevelInfo, "rename %q->%q is not allowed: the paths must be on the same resource",
|
||||
|
@ -1253,11 +1253,11 @@ func (c *BaseConnection) isRenamePermitted(fsSrc, fsDst vfs.Fs, fsSourcePath, fs
|
|||
virtualTargetPath)
|
||||
return false
|
||||
}
|
||||
return c.hasRenamePerms(virtualSourcePath, virtualTargetPath, fi)
|
||||
return c.hasRenamePerms(virtualSourcePath, virtualTargetPath, srcInfo)
|
||||
}
|
||||
|
||||
func (c *BaseConnection) hasSpaceForRename(fs vfs.Fs, virtualSourcePath, virtualTargetPath string, initialSize int64,
|
||||
fsSourcePath string) bool {
|
||||
sourcePath string, srcInfo os.FileInfo) bool {
|
||||
if dataprovider.GetQuotaTracking() == 0 {
|
||||
return true
|
||||
}
|
||||
|
@ -1287,30 +1287,28 @@ func (c *BaseConnection) hasSpaceForRename(fs vfs.Fs, virtualSourcePath, virtual
|
|||
// no quota restrictions
|
||||
return true
|
||||
}
|
||||
return c.hasSpaceForCrossRename(fs, quotaResult, initialSize, fsSourcePath)
|
||||
return c.hasSpaceForCrossRename(fs, quotaResult, initialSize, sourcePath, srcInfo)
|
||||
}
|
||||
|
||||
// hasSpaceForCrossRename checks the quota after a rename between different folders
|
||||
func (c *BaseConnection) hasSpaceForCrossRename(fs vfs.Fs, quotaResult vfs.QuotaCheckResult, initialSize int64, sourcePath string) bool {
|
||||
func (c *BaseConnection) hasSpaceForCrossRename(fs vfs.Fs, quotaResult vfs.QuotaCheckResult, initialSize int64,
|
||||
sourcePath string, srcInfo os.FileInfo,
|
||||
) bool {
|
||||
if !quotaResult.HasSpace && initialSize == -1 {
|
||||
// we are over quota and this is not a file replace
|
||||
return false
|
||||
}
|
||||
fi, err := fs.Lstat(sourcePath)
|
||||
if err != nil {
|
||||
c.Log(logger.LevelError, "cross rename denied, stat error for path %q: %v", sourcePath, err)
|
||||
return false
|
||||
}
|
||||
var sizeDiff int64
|
||||
var filesDiff int
|
||||
if fi.Mode().IsRegular() {
|
||||
sizeDiff = fi.Size()
|
||||
var err error
|
||||
if srcInfo.Mode().IsRegular() {
|
||||
sizeDiff = srcInfo.Size()
|
||||
filesDiff = 1
|
||||
if initialSize != -1 {
|
||||
sizeDiff -= initialSize
|
||||
filesDiff = 0
|
||||
}
|
||||
} else if fi.IsDir() {
|
||||
} else if srcInfo.IsDir() {
|
||||
filesDiff, sizeDiff, err = fs.GetDirSize(sourcePath)
|
||||
if err != nil {
|
||||
c.Log(logger.LevelError, "cross rename denied, error getting size for directory %q: %v", sourcePath, err)
|
||||
|
@ -1337,7 +1335,7 @@ func (c *BaseConnection) hasSpaceForCrossRename(fs vfs.Fs, quotaResult vfs.Quota
|
|||
}
|
||||
if quotaResult.QuotaSize > 0 {
|
||||
remainingSize := quotaResult.GetRemainingSize()
|
||||
c.Log(logger.LevelDebug, "cross rename, source %q remaining size %d to add %d", sourcePath,
|
||||
c.Log(logger.LevelDebug, "cross rename, source %q remaining size %d to add %d", srcInfo.Name(),
|
||||
remainingSize, sizeDiff)
|
||||
if remainingSize < sizeDiff {
|
||||
return false
|
||||
|
@ -1512,61 +1510,40 @@ func (c *BaseConnection) updateQuotaMoveBetweenVFolders(sourceFolder, dstFolder
|
|||
if sourceFolder.Name == dstFolder.Name {
|
||||
// both files are inside the same virtual folder
|
||||
if initialSize != -1 {
|
||||
dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, -numFiles, -initialSize, false) //nolint:errcheck
|
||||
if dstFolder.IsIncludedInUserQuota() {
|
||||
dataprovider.UpdateUserQuota(&c.User, -numFiles, -initialSize, false) //nolint:errcheck
|
||||
}
|
||||
dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, -numFiles, -initialSize, false)
|
||||
}
|
||||
return
|
||||
}
|
||||
// files are inside different virtual folders
|
||||
dataprovider.UpdateVirtualFolderQuota(&sourceFolder.BaseVirtualFolder, -numFiles, -filesSize, false) //nolint:errcheck
|
||||
if sourceFolder.IsIncludedInUserQuota() {
|
||||
dataprovider.UpdateUserQuota(&c.User, -numFiles, -filesSize, false) //nolint:errcheck
|
||||
}
|
||||
dataprovider.UpdateUserFolderQuota(sourceFolder, &c.User, -numFiles, -filesSize, false)
|
||||
if initialSize == -1 {
|
||||
dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, numFiles, filesSize, false) //nolint:errcheck
|
||||
if dstFolder.IsIncludedInUserQuota() {
|
||||
dataprovider.UpdateUserQuota(&c.User, numFiles, filesSize, false) //nolint:errcheck
|
||||
}
|
||||
} else {
|
||||
// we cannot have a directory here, initialSize != -1 only for files
|
||||
dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, 0, filesSize-initialSize, false) //nolint:errcheck
|
||||
if dstFolder.IsIncludedInUserQuota() {
|
||||
dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck
|
||||
}
|
||||
dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, numFiles, filesSize, false)
|
||||
return
|
||||
}
|
||||
// we cannot have a directory here, initialSize != -1 only for files
|
||||
dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, 0, filesSize-initialSize, false)
|
||||
}
|
||||
|
||||
func (c *BaseConnection) updateQuotaMoveFromVFolder(sourceFolder *vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) {
|
||||
// move between a virtual folder and the user home dir
|
||||
dataprovider.UpdateVirtualFolderQuota(&sourceFolder.BaseVirtualFolder, -numFiles, -filesSize, false) //nolint:errcheck
|
||||
if sourceFolder.IsIncludedInUserQuota() {
|
||||
dataprovider.UpdateUserQuota(&c.User, -numFiles, -filesSize, false) //nolint:errcheck
|
||||
}
|
||||
dataprovider.UpdateUserFolderQuota(sourceFolder, &c.User, -numFiles, -filesSize, false)
|
||||
if initialSize == -1 {
|
||||
dataprovider.UpdateUserQuota(&c.User, numFiles, filesSize, false) //nolint:errcheck
|
||||
} else {
|
||||
// we cannot have a directory here, initialSize != -1 only for files
|
||||
dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck
|
||||
return
|
||||
}
|
||||
// we cannot have a directory here, initialSize != -1 only for files
|
||||
dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck
|
||||
}
|
||||
|
||||
func (c *BaseConnection) updateQuotaMoveToVFolder(dstFolder *vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) {
|
||||
// move between the user home dir and a virtual folder
|
||||
dataprovider.UpdateUserQuota(&c.User, -numFiles, -filesSize, false) //nolint:errcheck
|
||||
if initialSize == -1 {
|
||||
dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, numFiles, filesSize, false) //nolint:errcheck
|
||||
if dstFolder.IsIncludedInUserQuota() {
|
||||
dataprovider.UpdateUserQuota(&c.User, numFiles, filesSize, false) //nolint:errcheck
|
||||
}
|
||||
} else {
|
||||
// we cannot have a directory here, initialSize != -1 only for files
|
||||
dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, 0, filesSize-initialSize, false) //nolint:errcheck
|
||||
if dstFolder.IsIncludedInUserQuota() {
|
||||
dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck
|
||||
}
|
||||
dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, numFiles, filesSize, false)
|
||||
return
|
||||
}
|
||||
// we cannot have a directory here, initialSize != -1 only for files
|
||||
dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, 0, filesSize-initialSize, false)
|
||||
}
|
||||
|
||||
func (c *BaseConnection) updateQuotaAfterRename(fs vfs.Fs, virtualSourcePath, virtualTargetPath, targetPath string,
|
||||
|
@ -1815,8 +1792,8 @@ type DirListerAt struct {
|
|||
lister vfs.DirLister
|
||||
}
|
||||
|
||||
// Add adds the given os.FileInfo to the internal cache
|
||||
func (l *DirListerAt) Add(fi os.FileInfo) {
|
||||
// Prepend adds the given os.FileInfo as first element of the internal cache
|
||||
func (l *DirListerAt) Prepend(fi os.FileInfo) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
|
|
|
@@ -22,6 +22,7 @@ import (
"path"
"path/filepath"
"runtime"
"slices"
"strconv"
"testing"
"time"

@@ -197,25 +198,24 @@ func TestRecursiveRenameWalkError(t *testing.T) {
}

func TestCrossRenameFsErrors(t *testing.T) {
if runtime.GOOS == osWindows {
t.Skip("this test is not available on Windows")
}
fs := vfs.NewOsFs("", os.TempDir(), "", nil)
conn := NewBaseConnection("", ProtocolWebDAV, "", "", dataprovider.User{})
res := conn.hasSpaceForCrossRename(fs, vfs.QuotaCheckResult{}, 1, "missingsource")
dirPath := filepath.Join(os.TempDir(), "d")
err := os.Mkdir(dirPath, os.ModePerm)
assert.NoError(t, err)
err = os.Chmod(dirPath, 0001)
assert.NoError(t, err)
srcInfo := vfs.NewFileInfo(filepath.Base(dirPath), true, 0, time.Now(), false)
res := conn.hasSpaceForCrossRename(fs, vfs.QuotaCheckResult{}, 1, dirPath, srcInfo)
assert.False(t, res)
if runtime.GOOS != osWindows {
dirPath := filepath.Join(os.TempDir(), "d")
err := os.Mkdir(dirPath, os.ModePerm)
assert.NoError(t, err)
err = os.Chmod(dirPath, 0001)
assert.NoError(t, err)

res = conn.hasSpaceForCrossRename(fs, vfs.QuotaCheckResult{}, 1, dirPath)
assert.False(t, res)

err = os.Chmod(dirPath, os.ModePerm)
assert.NoError(t, err)
err = os.Remove(dirPath)
assert.NoError(t, err)
}
err = os.Chmod(dirPath, os.ModePerm)
assert.NoError(t, err)
err = os.Remove(dirPath)
assert.NoError(t, err)
}

func TestRenameVirtualFolders(t *testing.T) {

@@ -389,7 +389,7 @@ func TestErrorsMapping(t *testing.T) {
err := conn.GetFsError(fs, os.ErrNotExist)
if protocol == ProtocolSFTP {
assert.ErrorIs(t, err, sftp.ErrSSHFxNoSuchFile)
} else if util.Contains(osErrorsProtocols, protocol) {
} else if slices.Contains(osErrorsProtocols, protocol) {
assert.EqualError(t, err, os.ErrNotExist.Error())
} else {
assert.EqualError(t, err, ErrNotExist.Error())
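This hunk is one instance of a change applied throughout the diff: the custom util.Contains helper is replaced by slices.Contains from the Go standard library (available since Go 1.21), which is why the touched files gain a "slices" import. A small stand-alone sketch of the equivalence, using a hypothetical protocol list:

package main

import (
	"fmt"
	"slices" // standard library package, Go 1.21+
)

func main() {
	osErrorsProtocols := []string{"SFTP", "SCP"} // illustrative values only
	// slices.Contains is a drop-in replacement for a hand-rolled Contains helper.
	fmt.Println(slices.Contains(osErrorsProtocols, "SFTP")) // true
	fmt.Println(slices.Contains(osErrorsProtocols, "FTP"))  // false
}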
@@ -1134,8 +1134,8 @@ func TestListerAt(t *testing.T) {
require.Equal(t, 0, n)
lister, err = conn.ListDir("/")
require.NoError(t, err)
lister.Add(vfs.NewFileInfo("..", true, 0, time.Unix(0, 0), false))
lister.Add(vfs.NewFileInfo(".", true, 0, time.Unix(0, 0), false))
lister.Prepend(vfs.NewFileInfo("..", true, 0, time.Unix(0, 0), false))
lister.Prepend(vfs.NewFileInfo(".", true, 0, time.Unix(0, 0), false))
files = make([]os.FileInfo, 1)
n, err = lister.ListAt(files, 0)
require.NoError(t, err)

@@ -277,7 +277,8 @@ func (r *eventRulesContainer) addUpdateRuleInternal(rule dataprovider.EventRule)
func (r *eventRulesContainer) loadRules() {
eventManagerLog(logger.LevelDebug, "loading updated rules")
modTime := util.GetTimeAsMsSinceEpoch(time.Now())
rules, err := dataprovider.GetRecentlyUpdatedRules(r.getLastLoadTime())
lastLoadTime := r.getLastLoadTime()
rules, err := dataprovider.GetRecentlyUpdatedRules(lastLoadTime)
if err != nil {
eventManagerLog(logger.LevelError, "unable to load event rules: %v", err)
return

@@ -313,7 +314,7 @@ func (*eventRulesContainer) checkIPDLoginEventMatch(conditions *dataprovider.Eve
}

func (*eventRulesContainer) checkProviderEventMatch(conditions *dataprovider.EventConditions, params *EventParams) bool {
if !util.Contains(conditions.ProviderEvents, params.Event) {
if !slices.Contains(conditions.ProviderEvents, params.Event) {
return false
}
if !checkEventConditionPatterns(params.Name, conditions.Options.Names) {

@@ -325,14 +326,14 @@ func (*eventRulesContainer) checkProviderEventMatch(conditions *dataprovider.Eve
if !checkEventConditionPatterns(params.Role, conditions.Options.RoleNames) {
return false
}
if len(conditions.Options.ProviderObjects) > 0 && !util.Contains(conditions.Options.ProviderObjects, params.ObjectType) {
if len(conditions.Options.ProviderObjects) > 0 && !slices.Contains(conditions.Options.ProviderObjects, params.ObjectType) {
return false
}
return true
}

func (*eventRulesContainer) checkFsEventMatch(conditions *dataprovider.EventConditions, params *EventParams) bool {
if !util.Contains(conditions.FsEvents, params.Event) {
if !slices.Contains(conditions.FsEvents, params.Event) {
return false
}
if !checkEventConditionPatterns(params.Name, conditions.Options.Names) {

@@ -347,7 +348,7 @@ func (*eventRulesContainer) checkFsEventMatch(conditions *dataprovider.EventCond
if !checkEventConditionPatterns(params.VirtualPath, conditions.Options.FsPaths) {
return false
}
if len(conditions.Options.Protocols) > 0 && !util.Contains(conditions.Options.Protocols, params.Protocol) {
if len(conditions.Options.Protocols) > 0 && !slices.Contains(conditions.Options.Protocols, params.Protocol) {
return false
}
if slices.Contains(fsEventsWithSize, params.Event) {

@@ -782,6 +783,12 @@ func (*EventParams) getStringReplacement(val string, jsonEscaped bool) string {
}

func (p *EventParams) getStringReplacements(addObjectData, jsonEscaped bool) []string {
var dateTimeString string
if Config.TZ == "local" {
dateTimeString = p.Timestamp.Local().Format(dateTimeMillisFormat)
} else {
dateTimeString = p.Timestamp.UTC().Format(dateTimeMillisFormat)
}
replacements := []string{
"{{Name}}", p.getStringReplacement(p.Name, jsonEscaped),
"{{Event}}", p.Event,

@@ -792,6 +799,7 @@ func (p *EventParams) getStringReplacements(addObjectData, jsonEscaped bool) []s
"{{VirtualTargetPath}}", p.getStringReplacement(p.VirtualTargetPath, jsonEscaped),
"{{FsTargetPath}}", p.getStringReplacement(p.FsTargetPath, jsonEscaped),
"{{ObjectName}}", p.getStringReplacement(p.ObjectName, jsonEscaped),
"{{ObjectBaseName}}", p.getStringReplacement(strings.TrimSuffix(p.ObjectName, p.Extension), jsonEscaped),
"{{ObjectType}}", p.ObjectType,
"{{FileSize}}", strconv.FormatInt(p.FileSize, 10),
"{{Elapsed}}", strconv.FormatInt(p.Elapsed, 10),

@@ -800,7 +808,7 @@ func (p *EventParams) getStringReplacements(addObjectData, jsonEscaped bool) []s
"{{Role}}", p.getStringReplacement(p.Role, jsonEscaped),
"{{Email}}", p.getStringReplacement(p.Email, jsonEscaped),
"{{Timestamp}}", strconv.FormatInt(p.Timestamp.UnixNano(), 10),
"{{DateTime}}", p.Timestamp.UTC().Format(dateTimeMillisFormat),
"{{DateTime}}", dateTimeString,
"{{StatusString}}", p.getStatusString(),
"{{UID}}", p.getStringReplacement(p.UID, jsonEscaped),
"{{Ext}}", p.getStringReplacement(p.Extension, jsonEscaped),
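The hunks above make the {{DateTime}} placeholder honor the TZ setting: when Config.TZ is "local" the timestamp is rendered in the server's local time zone, otherwise it stays in UTC. A hedged, stand-alone sketch of the same selection (the millisecond format string matches the one asserted later in this diff):

package main

import (
	"fmt"
	"time"
)

const dateTimeMillisFormat = "2006-01-02T15:04:05.000"

// formatDateTime mirrors the TZ handling added to getStringReplacements.
func formatDateTime(ts time.Time, tz string) string {
	if tz == "local" {
		return ts.Local().Format(dateTimeMillisFormat)
	}
	return ts.UTC().Format(dateTimeMillisFormat)
}

func main() {
	now := time.Now()
	fmt.Println(formatDateTime(now, ""))      // UTC rendering
	fmt.Println(formatDateTime(now, "local")) // local-time rendering
}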
@@ -919,10 +927,7 @@ func updateUserQuotaAfterFileWrite(conn *BaseConnection, virtualPath string, num
dataprovider.UpdateUserQuota(&conn.User, numFiles, fileSize, false) //nolint:errcheck
return
}
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, numFiles, fileSize, false) //nolint:errcheck
if vfolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(&conn.User, numFiles, fileSize, false) //nolint:errcheck
}
dataprovider.UpdateUserFolderQuota(&vfolder, &conn.User, numFiles, fileSize, false)
}

func checkWriterPermsAndQuota(conn *BaseConnection, virtualPath string, numFiles int, expectedSize, truncatedSize int64) error {

@@ -1773,7 +1778,7 @@ func executeMkdirFsRuleAction(dirs []string, replacer *strings.Replacer,
return nil
}

func executeRenameFsActionForUser(renames []dataprovider.KeyValue, replacer *strings.Replacer,
func executeRenameFsActionForUser(renames []dataprovider.RenameConfig, replacer *strings.Replacer,
user dataprovider.User,
) error {
user, err := getUserForEventAction(user)

@@ -1792,7 +1797,11 @@ func executeRenameFsActionForUser(renames []dataprovider.KeyValue, replacer *str
for _, item := range renames {
source := util.CleanPath(replaceWithReplacer(item.Key, replacer))
target := util.CleanPath(replaceWithReplacer(item.Value, replacer))
if err = conn.renameInternal(source, target, true); err != nil {
checks := 0
if item.UpdateModTime {
checks += vfs.CheckUpdateModTime
}
if err = conn.renameInternal(source, target, true, checks); err != nil {
return fmt.Errorf("unable to rename %q->%q, user %q: %w", source, target, user.Username, err)
}
eventManagerLog(logger.LevelDebug, "rename %q->%q ok, user %q", source, target, user.Username)
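renameInternal now takes an extra checks argument, a small bitmask assembled from the rename configuration; in this diff only vfs.CheckUpdateModTime is set, when item.UpdateModTime is enabled. A hedged sketch of how such a flag mask composes and is tested (the constant values and the second flag are illustrative, not taken from the code):

package main

import "fmt"

const (
	CheckUpdateModTime = 1 << iota // named after the flag used in the diff
	CheckSomethingElse             // made-up second flag, only to show composition
)

func main() {
	checks := 0
	updateModTime := true // would come from item.UpdateModTime
	if updateModTime {
		checks += CheckUpdateModTime
	}
	// the callee inspects individual flags with a bitwise AND
	fmt.Println(checks&CheckUpdateModTime != 0) // true
	fmt.Println(checks&CheckSomethingElse != 0) // false
}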
@@ -1858,7 +1867,7 @@ func executeExistFsActionForUser(exist []string, replacer *strings.Replacer,
return nil
}

func executeRenameFsRuleAction(renames []dataprovider.KeyValue, replacer *strings.Replacer,
func executeRenameFsRuleAction(renames []dataprovider.RenameConfig, replacer *strings.Replacer,
conditions dataprovider.ConditionOptions, params *EventParams,
) error {
users, err := params.getUsers()

@@ -2479,7 +2488,7 @@ func executePwdExpirationCheckForUser(user *dataprovider.User, config dataprovid
}
subject := "SFTPGo password expiration notification"
startTime := time.Now()
if err := smtp.SendEmail([]string{user.Email}, nil, subject, body.String(), smtp.EmailContentTypeTextHTML); err != nil {
if err := smtp.SendEmail(user.GetEmailAddresses(), nil, subject, body.String(), smtp.EmailContentTypeTextHTML); err != nil {
eventManagerLog(logger.LevelError, "unable to notify password expiration for user %s: %v, elapsed: %s",
user.Username, err, time.Since(startTime))
return err
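Password-expiration notifications are now sent to user.GetEmailAddresses() rather than to user.Email alone, so the additional addresses kept in Filters.AdditionalEmails (preserved by the profile hunk below and exercised in TestEventRulePasswordExpiration) are notified too. A hedged sketch of what such an accessor plausibly returns, assuming it merges the primary address with the additional ones and skips empty values:

type notifiedUser struct {
	Email            string
	AdditionalEmails []string
}

// GetEmailAddresses is an assumption-based sketch, not the project's code.
func (u *notifiedUser) GetEmailAddresses() []string {
	addrs := make([]string, 0, 1+len(u.AdditionalEmails))
	if u.Email != "" {
		addrs = append(addrs, u.Email)
	}
	return append(addrs, u.AdditionalEmails...)
}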
@@ -2574,6 +2583,9 @@ func preserveUserProfile(user, newUser *dataprovider.User) {
if user.Email != "" {
newUser.Email = user.Email
}
if len(user.Filters.AdditionalEmails) > 0 {
newUser.Filters.AdditionalEmails = user.Filters.AdditionalEmails
}
}
if newUser.CanChangeAPIKeyAuth() {
newUser.Filters.AllowAPIKeyAuth = user.Filters.AllowAPIKeyAuth

@@ -801,6 +801,9 @@ func TestEventManagerErrors(t *testing.T) {
}

func TestDateTimePlaceholder(t *testing.T) {
oldTZ := Config.TZ

Config.TZ = ""
dateTime := time.Now()
params := EventParams{
Timestamp: dateTime,

@@ -809,6 +812,14 @@ func TestDateTimePlaceholder(t *testing.T) {
r := strings.NewReplacer(replacements...)
res := r.Replace("{{DateTime}}")
assert.Equal(t, dateTime.UTC().Format(dateTimeMillisFormat), res)

Config.TZ = "local"
replacements = params.getStringReplacements(false, false)
r = strings.NewReplacer(replacements...)
res = r.Replace("{{DateTime}}")
assert.Equal(t, dateTime.Local().Format(dateTimeMillisFormat), res)

Config.TZ = oldTZ
}

func TestEventRuleActions(t *testing.T) {

@@ -1177,10 +1188,12 @@ func TestEventRuleActions(t *testing.T) {
action.Options = dataprovider.BaseEventActionOptions{
FsConfig: dataprovider.EventActionFilesystemConfig{
Type: dataprovider.FilesystemActionRename,
Renames: []dataprovider.KeyValue{
Renames: []dataprovider.RenameConfig{
{
Key: "/source",
Value: "/target",
KeyValue: dataprovider.KeyValue{
Key: "/source",
Value: "/target",
},
},
},
},
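Throughout the diff the Renames field changes from []dataprovider.KeyValue to []dataprovider.RenameConfig, a struct that embeds the old source/target pair and adds the UpdateModTime option used by the rename action. A sketch of the shape inferred from these hunks, with an illustrative literal:

// Inferred from the hunks in this diff; field names match what the tests set.
type KeyValue struct {
	Key   string
	Value string
}

type RenameConfig struct {
	KeyValue
	UpdateModTime bool
}

var renames = []RenameConfig{
	{
		KeyValue: KeyValue{
			Key:   "/{{VirtualPath}}",
			Value: "/archive/{{VirtualPath}}", // illustrative target
		},
		UpdateModTime: true,
	},
}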
@@ -1405,6 +1418,7 @@ func TestIDPAccountCheckRule(t *testing.T) {
// Update the profile attribute and make sure they are preserved
user.Password = "secret"
user.Email = "example@example.com"
user.Filters.AdditionalEmails = []string{"alias@example.com"}
user.Description = "some desc"
user.Filters.TLSCerts = []string{serverCert}
user.PublicKeys = []string{"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1"}

@@ -1419,6 +1433,7 @@ func TestIDPAccountCheckRule(t *testing.T) {
assert.Len(t, user.PublicKeys, 1)
assert.Len(t, user.Filters.TLSCerts, 1)
assert.NotEmpty(t, user.Email)
assert.Len(t, user.Filters.AdditionalEmails, 1)
assert.NotEmpty(t, user.Description)

err = dataprovider.DeleteUser(username, "", "", "")

@@ -1720,10 +1735,12 @@ func TestFilesystemActionErrors(t *testing.T) {
assert.NoError(t, err)
err = dataprovider.AddUser(&user, "", "", "")
assert.NoError(t, err)
err = executeRenameFsActionForUser([]dataprovider.KeyValue{
err = executeRenameFsActionForUser([]dataprovider.RenameConfig{
{
Key: "/p1",
Value: "/p1",
KeyValue: dataprovider.KeyValue{
Key: "/p1",
Value: "/p1",
},
},
}, testReplacer, user)
if assert.Error(t, err) {

@@ -1734,10 +1751,12 @@ func TestFilesystemActionErrors(t *testing.T) {
Options: dataprovider.BaseEventActionOptions{
FsConfig: dataprovider.EventActionFilesystemConfig{
Type: dataprovider.FilesystemActionRename,
Renames: []dataprovider.KeyValue{
Renames: []dataprovider.RenameConfig{
{
Key: "/p2",
Value: "/p2",
KeyValue: dataprovider.KeyValue{
Key: "/p2",
Value: "/p2",
},
},
},
},

@@ -19,6 +19,8 @@ import (

"github.com/robfig/cron/v3"

"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
)

@@ -36,7 +38,15 @@ func stopEventScheduler() {
func startEventScheduler() {
stopEventScheduler()

eventScheduler = cron.New(cron.WithLocation(time.UTC), cron.WithLogger(cron.DiscardLogger))
options := []cron.Option{
cron.WithLogger(cron.DiscardLogger),
}
if !dataprovider.UseLocalTime() {
eventManagerLog(logger.LevelDebug, "use UTC time for the scheduler")
options = append(options, cron.WithLocation(time.UTC))
}

eventScheduler = cron.New(options...)
eventManager.loadRules()
_, err := eventScheduler.AddFunc("@every 10m", eventManager.loadRules)
util.PanicOnError(err)
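The scheduler is now assembled from a slice of cron options, so the UTC location is applied only when the data provider is not configured to use local time. A stand-alone sketch of the same pattern with github.com/robfig/cron/v3 (the useLocalTime flag stands in for dataprovider.UseLocalTime()):

package main

import (
	"time"

	"github.com/robfig/cron/v3"
)

func newScheduler(useLocalTime bool) *cron.Cron {
	options := []cron.Option{
		cron.WithLogger(cron.DiscardLogger),
	}
	if !useLocalTime {
		// default to UTC unless local time was explicitly requested
		options = append(options, cron.WithLocation(time.UTC))
	}
	return cron.New(options...)
}

func main() {
	c := newScheduler(false)
	_, _ = c.AddFunc("@every 10m", func() {})
	c.Start()
	defer c.Stop()
}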
@@ -1459,15 +1459,15 @@ func TestTruncateQuotaLimits(t *testing.T) {
expectedQuotaSize := int64(3)
fold, _, err := httpdtest.GetFolderByName(folder2.Name, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, expectedQuotaSize, fold.UsedQuotaSize)
assert.Equal(t, expectedQuotaFiles, fold.UsedQuotaFiles)
assert.Equal(t, int64(0), fold.UsedQuotaSize)
assert.Equal(t, 0, fold.UsedQuotaFiles)
err = f.Close()
assert.NoError(t, err)
expectedQuotaFiles = 1
fold, _, err = httpdtest.GetFolderByName(folder2.Name, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, expectedQuotaSize, fold.UsedQuotaSize)
assert.Equal(t, expectedQuotaFiles, fold.UsedQuotaFiles)
assert.Equal(t, int64(0), fold.UsedQuotaSize)
assert.Equal(t, 0, fold.UsedQuotaFiles)
user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)

@@ -1779,8 +1779,8 @@ func TestVirtualFoldersQuotaValues(t *testing.T) {
assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)

@@ -1887,8 +1887,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 2, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)

@@ -1912,8 +1912,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 2, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
// rename a file inside vdir2, it isn't included inside user quota, so we have:
// - vdir1/dir1/testFileName.rename
// - vdir1/dir2/testFileName1

@@ -1931,8 +1931,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
assert.Equal(t, 2, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 2, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
// rename a file inside vdir2 overwriting an existing, we now have:
// - vdir1/dir1/testFileName.rename
// - vdir1/dir2/testFileName1

@@ -1949,8 +1949,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
assert.Equal(t, 1, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 2, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
// rename a file inside vdir1 overwriting an existing, we now have:
// - vdir1/dir1/testFileName.rename (initial testFileName1)
// - vdir2/dir1/testFileName.rename (initial testFileName1)

@@ -1962,8 +1962,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)

@@ -1983,8 +1983,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)

@@ -2089,8 +2089,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize, user.UsedQuotaSize)
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1+testFileSize1, f.UsedQuotaSize)

@@ -2108,8 +2108,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize*2, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize*2, f.UsedQuotaSize)
assert.Equal(t, 2, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1*2, f.UsedQuotaSize)

@@ -2126,8 +2126,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1+testFileSize, f.UsedQuotaSize)

@@ -2143,8 +2143,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)

@@ -2174,8 +2174,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize1*3+testFileSize*2, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1*3+testFileSize*2, f.UsedQuotaSize)
assert.Equal(t, 5, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, int64(0), f.UsedQuotaSize)

@@ -2189,8 +2189,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize1*2+testFileSize, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1*2+testFileSize, f.UsedQuotaSize)
assert.Equal(t, 3, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)

@@ -2295,8 +2295,8 @@ func TestQuotaRenameFromVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)

@@ -2314,8 +2314,8 @@ func TestQuotaRenameFromVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize+testFileSize1+testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)

@@ -2378,8 +2378,8 @@ func TestQuotaRenameFromVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize*3+testFileSize1*3, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 2, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, int64(0), f.UsedQuotaSize)

@@ -2499,8 +2499,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
// rename a file from user home dir to vdir2, vdir2 is not included in user quota so we have:
// - /vdir2/dir1/testFileName
// - /vdir1/dir1/testFileName1

@@ -2539,8 +2539,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)

@@ -2556,8 +2556,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)

@@ -2579,8 +2579,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, f.UsedQuotaSize)
assert.Equal(t, 1, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)

@@ -2597,8 +2597,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize*2+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 3, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1, f.UsedQuotaSize)

@@ -2623,8 +2623,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize)
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize*2+testFileSize1, f.UsedQuotaSize)
assert.Equal(t, 3, f.UsedQuotaFiles)
assert.Equal(t, int64(0), f.UsedQuotaSize)
assert.Equal(t, 0, f.UsedQuotaFiles)
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize1*2+testFileSize, f.UsedQuotaSize)
@@ -3986,9 +3986,9 @@ func TestEventRule(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 3)
assert.True(t, util.Contains(email.To, "test1@example.com"))
assert.True(t, util.Contains(email.To, "test2@example.com"))
assert.True(t, util.Contains(email.To, "test3@example.com"))
assert.True(t, slices.Contains(email.To, "test1@example.com"))
assert.True(t, slices.Contains(email.To, "test2@example.com"))
assert.True(t, slices.Contains(email.To, "test3@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: New "upload" from "%s" status OK`, user.Username))
// test the failure action, we download a file that exceeds the transfer quota limit
err = writeSFTPFileNoCheck(path.Join("subdir1", testFileName), 1*1024*1024+65535, client)

@@ -4007,9 +4007,9 @@ func TestEventRule(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email = lastReceivedEmail.get()
assert.Len(t, email.To, 3)
assert.True(t, util.Contains(email.To, "test1@example.com"))
assert.True(t, util.Contains(email.To, "test2@example.com"))
assert.True(t, util.Contains(email.To, "test3@example.com"))
assert.True(t, slices.Contains(email.To, "test1@example.com"))
assert.True(t, slices.Contains(email.To, "test2@example.com"))
assert.True(t, slices.Contains(email.To, "test3@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: New "download" from "%s" status KO`, user.Username))
assert.Contains(t, email.Data, `"download" failed`)
assert.Contains(t, email.Data, common.ErrReadQuotaExceeded.Error())

@@ -4027,7 +4027,7 @@ func TestEventRule(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email = lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "failure@example.com"))
assert.True(t, slices.Contains(email.To, "failure@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: Failed "upload" from "%s"`, user.Username))
assert.Contains(t, email.Data, fmt.Sprintf(`action %q failed`, action1.Name))
// now test the download rule

@@ -4044,9 +4044,9 @@ func TestEventRule(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email = lastReceivedEmail.get()
assert.Len(t, email.To, 3)
assert.True(t, util.Contains(email.To, "test1@example.com"))
assert.True(t, util.Contains(email.To, "test2@example.com"))
assert.True(t, util.Contains(email.To, "test3@example.com"))
assert.True(t, slices.Contains(email.To, "test1@example.com"))
assert.True(t, slices.Contains(email.To, "test2@example.com"))
assert.True(t, slices.Contains(email.To, "test3@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: New "download" from "%s"`, user.Username))
}
// test upload action command with arguments

@@ -4087,9 +4087,9 @@ func TestEventRule(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 3)
assert.True(t, util.Contains(email.To, "test1@example.com"))
assert.True(t, util.Contains(email.To, "test2@example.com"))
assert.True(t, util.Contains(email.To, "test3@example.com"))
assert.True(t, slices.Contains(email.To, "test1@example.com"))
assert.True(t, slices.Contains(email.To, "test2@example.com"))
assert.True(t, slices.Contains(email.To, "test3@example.com"))
assert.Contains(t, email.Data, `Subject: New "delete" from "admin"`)
_, err = httpdtest.RemoveEventRule(rule3, http.StatusOK)
assert.NoError(t, err)

@@ -4494,7 +4494,7 @@ func TestEventRuleProviderEvents(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test3@example.com"))
assert.True(t, slices.Contains(email.To, "test3@example.com"))
assert.Contains(t, email.Data, `Subject: New "update" from "admin"`)
}
// now delete the script to generate an error

@@ -4509,7 +4509,7 @@ func TestEventRuleProviderEvents(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "failure@example.com"))
assert.True(t, slices.Contains(email.To, "failure@example.com"))
assert.Contains(t, email.Data, `Subject: Failed "update" from "admin"`)
assert.Contains(t, email.Data, fmt.Sprintf("Object name: %s object type: folder", folder.Name))
lastReceivedEmail.reset()

@@ -4558,10 +4558,12 @@ func TestEventRuleFsActions(t *testing.T) {
Options: dataprovider.BaseEventActionOptions{
FsConfig: dataprovider.EventActionFilesystemConfig{
Type: dataprovider.FilesystemActionRename,
Renames: []dataprovider.KeyValue{
Renames: []dataprovider.RenameConfig{
{
Key: "/{{VirtualDirPath}}/{{ObjectName}}",
Value: "/{{ObjectName}}_renamed",
KeyValue: dataprovider.KeyValue{
Key: "/{{VirtualDirPath}}/{{ObjectName}}",
Value: "/{{ObjectName}}_renamed",
},
},
},
},
@@ -4819,6 +4821,80 @@ func TestEventRuleFsActions(t *testing.T) {
assert.NoError(t, err)
}

func TestEventActionObjectBaseName(t *testing.T) {
a1 := dataprovider.BaseEventAction{
Name: "a1",
Type: dataprovider.ActionTypeFilesystem,
Options: dataprovider.BaseEventActionOptions{
FsConfig: dataprovider.EventActionFilesystemConfig{
Type: dataprovider.FilesystemActionRename,
Renames: []dataprovider.RenameConfig{
{
KeyValue: dataprovider.KeyValue{
Key: "/{{VirtualDirPath}}/{{ObjectName}}",
Value: "/{{ObjectBaseName}}",
},
},
},
},
},
}
action1, resp, err := httpdtest.AddEventAction(a1, http.StatusCreated)
assert.NoError(t, err, string(resp))

r1 := dataprovider.EventRule{
Name: "r2",
Status: 1,
Trigger: dataprovider.EventTriggerFsEvent,
Conditions: dataprovider.EventConditions{
FsEvents: []string{"upload"},
},
Actions: []dataprovider.EventAction{
{
BaseEventAction: dataprovider.BaseEventAction{
Name: action1.Name,
},
Order: 1,
Options: dataprovider.EventActionOptions{
ExecuteSync: true,
},
},
},
}
rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated)
assert.NoError(t, err)

user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
assert.NoError(t, err)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()

testDir := "test dir name"
err = client.Mkdir(testDir)
fileSize := int64(32768)
assert.NoError(t, err)
err = writeSFTPFileNoCheck(path.Join(testDir, testFileName), fileSize, client)
assert.NoError(t, err)

_, err = client.Stat(path.Join(testDir, testFileName))
assert.ErrorIs(t, err, os.ErrNotExist)

_, err = client.Stat(strings.TrimSuffix(testFileName, path.Ext(testFileName)))
assert.NoError(t, err)
}

_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
_, err = httpdtest.RemoveEventRule(rule1, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
assert.NoError(t, err)
}

func TestUploadEventRule(t *testing.T) {
smtpCfg := smtp.Config{
Host: "127.0.0.1",
@@ -4968,10 +5044,13 @@ func TestEventRulePreDelete(t *testing.T) {
Options: dataprovider.BaseEventActionOptions{
FsConfig: dataprovider.EventActionFilesystemConfig{
Type: dataprovider.FilesystemActionRename,
Renames: []dataprovider.KeyValue{
Renames: []dataprovider.RenameConfig{
{
Key: "/{{VirtualPath}}",
Value: fmt.Sprintf("/%s/{{VirtualPath}}", movePath),
KeyValue: dataprovider.KeyValue{
Key: "/{{VirtualPath}}",
Value: fmt.Sprintf("/%s/{{VirtualPath}}", movePath),
},
UpdateModTime: true,
},
},
},

@@ -5025,59 +5104,83 @@ func TestEventRulePreDelete(t *testing.T) {
QuotaFiles: 1000,
},
}
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
u = getTestSFTPUser()
u.QuotaFiles = 1000
sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()

testDir := "sub dir"
err = client.MkdirAll(testDir)
assert.NoError(t, err)
err = writeSFTPFile(testFileName, 100, client)
assert.NoError(t, err)
err = writeSFTPFile(path.Join(testDir, testFileName), 100, client)
assert.NoError(t, err)
err = client.Remove(testFileName)
assert.NoError(t, err)
err = client.Remove(path.Join(testDir, testFileName))
assert.NoError(t, err)
// check files
_, err = client.Stat(testFileName)
assert.ErrorIs(t, err, os.ErrNotExist)
_, err = client.Stat(path.Join(testDir, testFileName))
assert.ErrorIs(t, err, os.ErrNotExist)
_, err = client.Stat(path.Join("/", movePath, testFileName))
assert.NoError(t, err)
_, err = client.Stat(path.Join("/", movePath, testDir, testFileName))
assert.NoError(t, err)
// check quota
user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, 0, user.UsedQuotaFiles)
assert.Equal(t, int64(0), user.UsedQuotaSize)
folder, _, err := httpdtest.GetFolderByName(movePath, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, 2, folder.UsedQuotaFiles)
assert.Equal(t, int64(200), folder.UsedQuotaSize)
// pre-delete action is not executed in movePath
err = client.Remove(path.Join("/", movePath, testFileName))
assert.NoError(t, err)
// check quota
folder, _, err = httpdtest.GetFolderByName(movePath, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, 1, folder.UsedQuotaFiles)
assert.Equal(t, int64(100), folder.UsedQuotaSize)
for _, user := range []dataprovider.User{localUser, sftpUser} {
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()

testDir := "sub dir"
err = client.MkdirAll(testDir)
assert.NoError(t, err)
err = writeSFTPFile(testFileName, 100, client)
assert.NoError(t, err)
err = writeSFTPFile(path.Join(testDir, testFileName), 100, client)
assert.NoError(t, err)
modTime := time.Now().Add(-36 * time.Hour)
err = client.Chtimes(testFileName, modTime, modTime)
assert.NoError(t, err)
err = client.Remove(testFileName)
assert.NoError(t, err)
err = client.Remove(path.Join(testDir, testFileName))
assert.NoError(t, err)
// check files
_, err = client.Stat(testFileName)
assert.ErrorIs(t, err, os.ErrNotExist)
_, err = client.Stat(path.Join(testDir, testFileName))
assert.ErrorIs(t, err, os.ErrNotExist)
info, err := client.Stat(path.Join("/", movePath, testFileName))
assert.NoError(t, err)
diff := math.Abs(time.Until(info.ModTime()).Seconds())
assert.LessOrEqual(t, diff, float64(2))

_, err = client.Stat(path.Join("/", movePath, testDir, testFileName))
assert.NoError(t, err)
// check quota
user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
assert.NoError(t, err)
if user.Username == localUser.Username {
assert.Equal(t, 0, user.UsedQuotaFiles)
assert.Equal(t, int64(0), user.UsedQuotaSize)
folder, _, err := httpdtest.GetFolderByName(movePath, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, 2, folder.UsedQuotaFiles)
assert.Equal(t, int64(200), folder.UsedQuotaSize)
} else {
assert.Equal(t, 1, user.UsedQuotaFiles)
assert.Equal(t, int64(100), user.UsedQuotaSize)
}
// pre-delete action is not executed in movePath
err = client.Remove(path.Join("/", movePath, testFileName))
assert.NoError(t, err)
if user.Username == localUser.Username {
// check quota
folder, _, err := httpdtest.GetFolderByName(movePath, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, 1, folder.UsedQuotaFiles)
assert.Equal(t, int64(100), folder.UsedQuotaSize)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
}
}
}

_, err = httpdtest.RemoveEventRule(rule1, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveUser(user, http.StatusOK)
_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(localUser.GetHomeDir())
assert.NoError(t, err)
_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: movePath}, http.StatusOK)
assert.NoError(t, err)
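The reworked TestEventRulePreDelete also exercises UpdateModTime: the source file is backdated with Chtimes before it is removed, and because the pre-delete rename refreshes the modification time, the moved file is expected to carry a timestamp within a couple of seconds of the rename, not the 36-hour-old one. In isolation, that check looks like this (client, movePath and testFileName as in the test above):

modTime := time.Now().Add(-36 * time.Hour)
err = client.Chtimes(testFileName, modTime, modTime) // backdate the source
assert.NoError(t, err)
// ... the pre-delete rule renames the file into movePath ...
info, err := client.Stat(path.Join("/", movePath, testFileName))
assert.NoError(t, err)
// the rename with UpdateModTime set refreshed the timestamp
diff := math.Abs(time.Until(info.ModTime()).Seconds())
assert.LessOrEqual(t, diff, float64(2))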
@@ -5105,10 +5208,12 @@ func TestEventRulePreDownloadUpload(t *testing.T) {
Options: dataprovider.BaseEventActionOptions{
FsConfig: dataprovider.EventActionFilesystemConfig{
Type: dataprovider.FilesystemActionRename,
Renames: []dataprovider.KeyValue{
Renames: []dataprovider.RenameConfig{
{
Key: "/missing source",
Value: "/missing target",
KeyValue: dataprovider.KeyValue{
Key: "/missing source",
Value: "/missing target",
},
},
},
},

@@ -5701,7 +5806,7 @@ func TestBackupAsAttachment(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "%s OK"`, renewalEvent))
assert.Contains(t, email.Data, `Domain: example.com`)
assert.Contains(t, email.Data, "Content-Type: application/json")

@@ -6071,7 +6176,7 @@ func TestEventActionCompressQuotaErrors(t *testing.T) {
}, 3*time.Second, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, `Subject: "Compress failed"`)
assert.Contains(t, email.Data, common.ErrQuotaExceeded.Error())
// update quota size so the user is already overquota

@@ -6086,7 +6191,7 @@ func TestEventActionCompressQuotaErrors(t *testing.T) {
}, 3*time.Second, 100*time.Millisecond)
email = lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, `Subject: "Compress failed"`)
assert.Contains(t, email.Data, common.ErrQuotaExceeded.Error())
// remove the path to compress to trigger an error for size estimation

@@ -6100,7 +6205,7 @@ func TestEventActionCompressQuotaErrors(t *testing.T) {
}, 3*time.Second, 100*time.Millisecond)
email = lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, `Subject: "Compress failed"`)
assert.Contains(t, email.Data, "unable to estimate archive size")
}

@@ -6229,8 +6334,8 @@ func TestEventActionCompressQuotaFolder(t *testing.T) {
assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
vfolder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, 2, vfolder.UsedQuotaFiles)
assert.Equal(t, info.Size()+int64(len(testFileContent)), vfolder.UsedQuotaSize)
assert.Equal(t, 0, vfolder.UsedQuotaFiles)
assert.Equal(t, int64(0), vfolder.UsedQuotaSize)
}
}

@@ -6436,7 +6541,7 @@ func TestEventActionEmailAttachments(t *testing.T) {
}, 1500*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, `Subject: "upload" from`)
assert.Contains(t, email.Data, url.QueryEscape("/"+testFileName))
assert.Contains(t, email.Data, "Content-Disposition: attachment")

@@ -6614,7 +6719,7 @@ func TestEventActionsRetentionReports(t *testing.T) {

email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "upload" from "%s"`, user.Username))
assert.Contains(t, email.Data, "Content-Disposition: attachment")
_, err = client.Stat(testDir)

@@ -6787,7 +6892,7 @@ func TestEventRuleFirstUploadDownloadActions(t *testing.T) {
}, 1500*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "first-upload" from "%s"`, user.Username))
lastReceivedEmail.reset()
// a new upload will not produce a new notification

@@ -6810,7 +6915,7 @@ func TestEventRuleFirstUploadDownloadActions(t *testing.T) {
}, 1500*time.Millisecond, 100*time.Millisecond)
email = lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "first-download" from "%s"`, user.Username))
// download again
lastReceivedEmail.reset()

@@ -6906,7 +7011,7 @@ func TestEventRuleRenameEvent(t *testing.T) {
}, 1500*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "rename" from "%s"`, user.Username))
assert.Contains(t, email.Data, "Content-Type: text/html")
assert.Contains(t, email.Data, fmt.Sprintf("Target path %q", path.Join("/subdir", testFileName)))

@@ -7040,7 +7145,7 @@ func TestEventRuleIDPLogin(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "%s OK"`, common.IDPLoginUser))
assert.Contains(t, email.Data, username)
assert.Contains(t, email.Data, custom1)

@@ -7104,7 +7209,7 @@ func TestEventRuleIDPLogin(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email = lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "%s OK"`, common.IDPLoginAdmin))
assert.Contains(t, email.Data, username)
assert.Contains(t, email.Data, custom1)

@@ -7296,7 +7401,7 @@ func TestEventRuleEmailField(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, user.Email))
assert.True(t, slices.Contains(email.To, user.Email))
assert.Contains(t, email.Data, `Subject: "add" from "admin"`)

// if we add a user without email the notification will fail

@@ -7310,7 +7415,7 @@ func TestEventRuleEmailField(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email = lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "failure@example.com"))
assert.True(t, slices.Contains(email.To, "failure@example.com"))
assert.Contains(t, email.Data, `no recipient addresses set`)

conn, client, err := getSftpClient(user)

@@ -7327,7 +7432,7 @@ func TestEventRuleEmailField(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, user.Email))
assert.True(t, slices.Contains(email.To, user.Email))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "mkdir" from "%s"`, user.Username))
}

@@ -7434,7 +7539,7 @@ func TestEventRuleCertificate(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "%s OK"`, renewalEvent))
assert.Contains(t, email.Data, "Content-Type: text/plain")
assert.Contains(t, email.Data, `Domain: example.com Timestamp`)

@@ -7455,7 +7560,7 @@ func TestEventRuleCertificate(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email = lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test@example.com"))
assert.True(t, slices.Contains(email.To, "test@example.com"))
assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "%s KO"`, renewalEvent))
assert.Contains(t, email.Data, `Domain: example.com Timestamp`)
assert.Contains(t, email.Data, dateTime.UTC().Format("2006-01-02T15:04:05.000"))

@@ -7582,8 +7687,8 @@ func TestEventRuleIPBlocked(t *testing.T) {
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 2)
assert.True(t, util.Contains(email.To, "test3@example.com"))
assert.True(t, util.Contains(email.To, "test4@example.com"))
assert.True(t, slices.Contains(email.To, "test3@example.com"))
assert.True(t, slices.Contains(email.To, "test4@example.com"))
assert.Contains(t, email.Data, `Subject: New "IP Blocked"`)

err = dataprovider.DeleteEventRule(rule1.Name, "", "", "")

@@ -7971,6 +8076,7 @@ func TestEventRulePasswordExpiration(t *testing.T) {
_, _, err = httpdtest.UpdateEventRule(rule1, http.StatusOK)
assert.NoError(t, err)
user.Email = "user@example.net"
user.Filters.AdditionalEmails = []string{"additional@example.net"}
_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
assert.NoError(t, err)
conn, client, err = getSftpClient(user)

@@ -7986,8 +8092,9 @@ func TestEventRulePasswordExpiration(t *testing.T) {
return lastReceivedEmail.get().From != ""
}, 1500*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.Len(t, email.To, 2)
assert.Contains(t, email.To, user.Email)
assert.Contains(t, email.To, user.Filters.AdditionalEmails[0])
assert.Contains(t, email.Data, "your SFTPGo password expires in 5 days")
err = client.RemoveDirectory(dirName)
assert.NoError(t, err)
@@ -8388,6 +8495,87 @@ func TestRetentionAPI(t *testing.T) {
assert.NoError(t, err)
}

func TestPerUserTransferLimits(t *testing.T) {
oldMaxPerHostConns := common.Config.MaxPerHostConnections

common.Config.MaxPerHostConnections = 2

u := getTestUser()
u.UploadBandwidth = 32
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()

var wg sync.WaitGroup
numErrors := 0
for i := 0; i <= 2; i++ {
wg.Add(1)
go func(counter int) {
defer wg.Done()

time.Sleep(20 * time.Millisecond)
err := writeSFTPFile(fmt.Sprintf("%s_%d", testFileName, counter), 64*1024, client)
if err != nil {
numErrors++
}
}(i)
}
wg.Wait()

assert.Equal(t, 1, numErrors)
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)

common.Config.MaxPerHostConnections = oldMaxPerHostConns
}

func TestMaxSessionsSameConnection(t *testing.T) {
u := getTestUser()
u.UploadBandwidth = 32
u.MaxSessions = 2
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()

var wg sync.WaitGroup
numErrors := 0
for i := 0; i <= 2; i++ {
wg.Add(1)
go func(counter int) {
defer wg.Done()

var err error
if counter < 2 {
err = writeSFTPFile(fmt.Sprintf("%s_%d", testFileName, counter), 64*1024, client)
} else {
// wait for the transfers to start
time.Sleep(50 * time.Millisecond)
_, _, err = getSftpClient(user)
}
if err != nil {
numErrors++
}
}(i)
}

wg.Wait()
assert.Equal(t, 1, numErrors)
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
}

func TestRenameDir(t *testing.T) {
u := getTestUser()
testDir := "/dir-to-rename"
@ -8755,7 +8943,7 @@ func TestSFTPLoopError(t *testing.T) {
|
|||
}, 3000*time.Millisecond, 100*time.Millisecond)
|
||||
email := lastReceivedEmail.get()
|
||||
assert.Len(t, email.To, 1)
|
||||
assert.True(t, util.Contains(email.To, "failure@example.com"))
|
||||
assert.True(t, slices.Contains(email.To, "failure@example.com"))
|
||||
assert.Contains(t, email.Data, `Subject: Failed action`)
|
||||
|
||||
user1.VirtualFolders[0].FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword)
|
||||
|
|
|
@@ -17,6 +17,7 @@ package common
import (
    "errors"
    "fmt"
    "slices"
    "sort"
    "sync"
    "sync/atomic"

@@ -94,7 +95,7 @@ func (r *RateLimiterConfig) validate() error {
    }
    r.Protocols = util.RemoveDuplicates(r.Protocols, true)
    for _, protocol := range r.Protocols {
        if !util.Contains(rateLimiterProtocolValues, protocol) {
        if !slices.Contains(rateLimiterProtocolValues, protocol) {
            return fmt.Errorf("invalid protocol %q", protocol)
        }
    }

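Most of the hunks in this compare make the same substitution shown above: the project-local `util.Contains` helper is replaced by the standard library's `slices.Contains`, available since Go 1.21. The two are interchangeable for simple membership tests, as in this small self-contained sketch:

```go
package main

import (
    "fmt"
    "slices"
)

func main() {
    protocols := []string{"SSH", "FTP", "DAV"}
    // slices.Contains works on any slice of comparable elements,
    // which is all util.Contains was used for in these hunks.
    fmt.Println(slices.Contains(protocols, "FTP"))  // true
    fmt.Println(slices.Contains(protocols, "HTTP")) // false
}
```
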
@ -25,6 +25,7 @@ import (
|
|||
"math/rand"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/drakkan/sftpgo/v2/internal/logger"
|
||||
|
@ -96,7 +97,7 @@ func (m *CertManager) loadCertificates() error {
|
|||
}
|
||||
logger.Debug(m.logSender, "", "TLS certificate %q successfully loaded, id %v", keyPair.Cert, keyPair.ID)
|
||||
certs[keyPair.ID] = &newCert
|
||||
if !util.Contains(m.monitorList, keyPair.Cert) {
|
||||
if !slices.Contains(m.monitorList, keyPair.Cert) {
|
||||
m.monitorList = append(m.monitorList, keyPair.Cert)
|
||||
}
|
||||
}
|
||||
|
@ -190,7 +191,7 @@ func (m *CertManager) LoadCRLs() error {
|
|||
|
||||
logger.Debug(m.logSender, "", "CRL %q successfully loaded", revocationList)
|
||||
crls = append(crls, crl)
|
||||
if !util.Contains(m.monitorList, revocationList) {
|
||||
if !slices.Contains(m.monitorList, revocationList) {
|
||||
m.monitorList = append(m.monitorList, revocationList)
|
||||
}
|
||||
}
|
||||
|
|
|
@@ -329,6 +329,21 @@ func (t *BaseTransfer) getUploadFileSize() (int64, int, error) {
    var fileSize int64
    var deletedFiles int

    switch dataprovider.GetQuotaTracking() {
    case 0:
        return fileSize, deletedFiles, errors.New("quota tracking disabled")
    case 2:
        if !t.Connection.User.HasQuotaRestrictions() {
            vfolder, err := t.Connection.User.GetVirtualFolderForPath(path.Dir(t.requestPath))
            if err != nil {
                return fileSize, deletedFiles, errors.New("quota tracking disabled for this user")
            }
            if vfolder.IsIncludedInUserQuota() {
                return fileSize, deletedFiles, errors.New("quota tracking disabled for this user and folder included in user quota")
            }
        }
    }

    info, err := t.Fs.Stat(t.fsPath)
    if err == nil {
        fileSize = info.Size()

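The switch added above short-circuits the upload-size lookup when quota tracking cannot apply: mode 0 disables tracking entirely, and in mode 2 tracking is skipped for users without quota restrictions unless the target path maps to a virtual folder that keeps its own quota. The same decision, restated as a boolean predicate purely for readability (this function does not exist in the code base):

```go
// shouldTrackUploadQuota is an illustrative restatement of the switch in
// getUploadFileSize above, using the same methods that hunk calls.
func shouldTrackUploadQuota(t *BaseTransfer) bool {
    switch dataprovider.GetQuotaTracking() {
    case 0:
        return false // quota tracking disabled globally
    case 2:
        if t.Connection.User.HasQuotaRestrictions() {
            return true
        }
        vfolder, err := t.Connection.User.GetVirtualFolderForPath(path.Dir(t.requestPath))
        // track only when the path maps to a virtual folder with its own quota
        return err == nil && !vfolder.IsIncludedInUserQuota()
    default:
        return true
    }
}
```
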
@ -394,7 +409,7 @@ func (t *BaseTransfer) Close() error {
|
|||
t.effectiveFsPath, err)
|
||||
} else if t.isAtomicUpload() {
|
||||
if t.ErrTransfer == nil || Config.UploadMode&UploadModeAtomicWithResume != 0 {
|
||||
_, _, err = t.Fs.Rename(t.effectiveFsPath, t.fsPath)
|
||||
_, _, err = t.Fs.Rename(t.effectiveFsPath, t.fsPath, 0)
|
||||
t.Connection.Log(logger.LevelDebug, "atomic upload completed, rename: %q -> %q, error: %v",
|
||||
t.effectiveFsPath, t.fsPath, err)
|
||||
// the file must be removed if it is uploaded to a path outside the home dir and cannot be renamed
|
||||
|
@ -521,11 +536,8 @@ func (t *BaseTransfer) updateQuota(numFiles int, fileSize int64) bool {
|
|||
if t.transferType == TransferUpload && (numFiles != 0 || sizeDiff != 0) {
|
||||
vfolder, err := t.Connection.User.GetVirtualFolderForPath(path.Dir(t.requestPath))
|
||||
if err == nil {
|
||||
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, numFiles, //nolint:errcheck
|
||||
dataprovider.UpdateUserFolderQuota(&vfolder, &t.Connection.User, numFiles,
|
||||
sizeDiff, false)
|
||||
if vfolder.IsIncludedInUserQuota() {
|
||||
dataprovider.UpdateUserQuota(&t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck
|
||||
}
|
||||
} else {
|
||||
dataprovider.UpdateUserQuota(&t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck
|
||||
}
|
||||
|
|
|
@ -306,8 +306,9 @@ func TestRemovePartialCryptoFile(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
u := dataprovider.User{
|
||||
BaseUser: sdk.BaseUser{
|
||||
Username: "test",
|
||||
HomeDir: os.TempDir(),
|
||||
Username: "test",
|
||||
HomeDir: os.TempDir(),
|
||||
QuotaFiles: 1000000,
|
||||
},
|
||||
}
|
||||
conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u)
|
||||
|
@ -323,6 +324,9 @@ func TestRemovePartialCryptoFile(t *testing.T) {
|
|||
assert.Equal(t, int64(0), size)
|
||||
assert.Equal(t, 1, deletedFiles)
|
||||
assert.NoFileExists(t, testFile)
|
||||
err = transfer.Close()
|
||||
assert.Error(t, err)
|
||||
assert.Len(t, conn.GetTransfers(), 0)
|
||||
}
|
||||
|
||||
func TestFTPMode(t *testing.T) {
|
||||
|
@ -434,6 +438,11 @@ func TestTransferQuota(t *testing.T) {
|
|||
}
|
||||
err = transfer.CheckWrite()
|
||||
assert.True(t, conn.IsQuotaExceededError(err))
|
||||
|
||||
err = transfer.Close()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, conn.GetTransfers(), 0)
|
||||
assert.Equal(t, int32(0), Connections.GetTotalTransfers())
|
||||
}
|
||||
|
||||
func TestUploadOutsideHomeRenameError(t *testing.T) {
|
||||
|
|
|
@ -250,6 +250,7 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
|
|||
Connections.Remove(fakeConn5.GetID())
|
||||
stats := Connections.GetStats("")
|
||||
assert.Len(t, stats, 0)
|
||||
assert.Equal(t, int32(0), Connections.GetTotalTransfers())
|
||||
|
||||
err = dataprovider.DeleteUser(user.Username, "", "", "")
|
||||
assert.NoError(t, err)
|
||||
|
@ -368,11 +369,16 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
|
|||
if assert.Error(t, transfer4.errAbort) {
|
||||
assert.Contains(t, transfer4.errAbort.Error(), ErrReadQuotaExceeded.Error())
|
||||
}
|
||||
err = transfer3.Close()
|
||||
assert.NoError(t, err)
|
||||
err = transfer4.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
Connections.Remove(fakeConn3.GetID())
|
||||
Connections.Remove(fakeConn4.GetID())
|
||||
stats := Connections.GetStats("")
|
||||
assert.Len(t, stats, 0)
|
||||
assert.Equal(t, int32(0), Connections.GetTotalTransfers())
|
||||
|
||||
err = dataprovider.DeleteUser(user.Username, "", "", "")
|
||||
assert.NoError(t, err)
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
|
@ -91,6 +92,7 @@ var (
|
|||
TLSCipherSuites: nil,
|
||||
Protocols: nil,
|
||||
Prefix: "",
|
||||
ProxyMode: 0,
|
||||
ProxyAllowed: nil,
|
||||
ClientIPProxyHeader: "",
|
||||
ClientIPHeaderDepth: 0,
|
||||
|
@ -110,11 +112,13 @@ var (
|
|||
ClientAuthType: 0,
|
||||
TLSCipherSuites: nil,
|
||||
Protocols: nil,
|
||||
ProxyMode: 0,
|
||||
ProxyAllowed: nil,
|
||||
ClientIPProxyHeader: "",
|
||||
ClientIPHeaderDepth: 0,
|
||||
HideLoginURL: 0,
|
||||
RenderOpenAPI: true,
|
||||
Languages: []string{"en"},
|
||||
OIDC: httpd.OIDC{
|
||||
ClientID: "",
|
||||
ClientSecret: "",
|
||||
|
@ -235,6 +239,7 @@ func Init() {
|
|||
RateLimitersConfig: []common.RateLimiterConfig{defaultRateLimiter},
|
||||
Umask: "",
|
||||
ServerVersion: "",
|
||||
TZ: "",
|
||||
Metadata: common.MetadataConfig{
|
||||
Read: 0,
|
||||
},
|
||||
|
@ -399,6 +404,9 @@ func Init() {
|
|||
SigningPassphrase: "",
|
||||
SigningPassphraseFile: "",
|
||||
TokenValidation: 0,
|
||||
CookieLifetime: 20,
|
||||
ShareCookieLifetime: 120,
|
||||
JWTLifetime: 20,
|
||||
MaxUploadFileSize: 0,
|
||||
Cors: httpd.CorsConfig{
|
||||
Enabled: false,
|
||||
|
@ -719,7 +727,7 @@ func checkOverrideDefaultSettings() {
|
|||
}
|
||||
}
|
||||
|
||||
if util.Contains(viper.AllKeys(), "mfa.totp") {
|
||||
if slices.Contains(viper.AllKeys(), "mfa.totp") {
|
||||
globalConf.MFAConfig.TOTP = nil
|
||||
}
|
||||
}
|
||||
|
@@ -883,13 +891,13 @@ func getRateLimitersFromEnv(idx int) {
        isSet = true
    }

    burst, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__BURST", idx), 0)
    burst, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__BURST", idx), 32)
    if ok {
        rtlConfig.Burst = int(burst)
        isSet = true
    }

    rtlType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__TYPE", idx), 0)
    rtlType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__TYPE", idx), 32)
    if ok {
        rtlConfig.Type = int(rtlType)
        isSet = true

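Throughout this file the second argument of `lookupIntFromEnv` changes from 0 to 32. The helper's implementation is not part of this compare; assuming it forwards the value to `strconv.ParseInt` as the bit size, 32 means the parsed value must fit in an int32 (bit size 0 accepts anything that fits the platform `int`), which matches the `int(...)` conversions at every call site. A plausible sketch under that assumption, not the repository's actual code:

```go
package config

import (
    "os"
    "strconv"
    "strings"
)

// Hypothetical shape of the helper; the real signature and behavior may differ.
func lookupIntFromEnv(key string, bitSize int) (int64, bool) {
    v, ok := os.LookupEnv(key)
    if !ok || strings.TrimSpace(v) == "" {
        return 0, false
    }
    // bitSize bounds the accepted range, e.g. 32 rejects values outside int32.
    val, err := strconv.ParseInt(strings.TrimSpace(v), 10, bitSize)
    if err != nil {
        return 0, false
    }
    return val, true
}
```
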
@ -907,13 +915,13 @@ func getRateLimitersFromEnv(idx int) {
|
|||
isSet = true
|
||||
}
|
||||
|
||||
softLimit, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__ENTRIES_SOFT_LIMIT", idx), 0)
|
||||
softLimit, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__ENTRIES_SOFT_LIMIT", idx), 32)
|
||||
if ok {
|
||||
rtlConfig.EntriesSoftLimit = int(softLimit)
|
||||
isSet = true
|
||||
}
|
||||
|
||||
hardLimit, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__ENTRIES_HARD_LIMIT", idx), 0)
|
||||
hardLimit, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__ENTRIES_HARD_LIMIT", idx), 32)
|
||||
if ok {
|
||||
rtlConfig.EntriesHardLimit = int(hardLimit)
|
||||
isSet = true
|
||||
|
@ -949,7 +957,7 @@ func getKMSPluginFromEnv(idx int, pluginConfig *plugin.Config) bool {
|
|||
func getAuthPluginFromEnv(idx int, pluginConfig *plugin.Config) bool {
|
||||
isSet := false
|
||||
|
||||
authScope, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__AUTH_OPTIONS__SCOPE", idx), 0)
|
||||
authScope, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__AUTH_OPTIONS__SCOPE", idx), 32)
|
||||
if ok {
|
||||
pluginConfig.AuthOptions.Scope = int(authScope)
|
||||
isSet = true
|
||||
|
@ -994,13 +1002,13 @@ func getNotifierPluginFromEnv(idx int, pluginConfig *plugin.Config) bool {
|
|||
}
|
||||
}
|
||||
|
||||
notifierRetryMaxTime, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__NOTIFIER_OPTIONS__RETRY_MAX_TIME", idx), 0)
|
||||
notifierRetryMaxTime, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__NOTIFIER_OPTIONS__RETRY_MAX_TIME", idx), 32)
|
||||
if ok {
|
||||
pluginConfig.NotifierOptions.RetryMaxTime = int(notifierRetryMaxTime)
|
||||
isSet = true
|
||||
}
|
||||
|
||||
notifierRetryQueueMaxSize, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__NOTIFIER_OPTIONS__RETRY_QUEUE_MAX_SIZE", idx), 0)
|
||||
notifierRetryQueueMaxSize, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__NOTIFIER_OPTIONS__RETRY_QUEUE_MAX_SIZE", idx), 32)
|
||||
if ok {
|
||||
pluginConfig.NotifierOptions.RetryQueueMaxSize = int(notifierRetryQueueMaxSize)
|
||||
isSet = true
|
||||
|
@ -1088,7 +1096,7 @@ func getSFTPDBindindFromEnv(idx int) {
|
|||
|
||||
isSet := false
|
||||
|
||||
port, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_SFTPD__BINDINGS__%v__PORT", idx), 0)
|
||||
port, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_SFTPD__BINDINGS__%v__PORT", idx), 32)
|
||||
if ok {
|
||||
binding.Port = int(port)
|
||||
isSet = true
|
||||
|
@ -1175,19 +1183,19 @@ func getFTPDBindingSecurityFromEnv(idx int, binding *ftpd.Binding) bool {
|
|||
isSet = true
|
||||
}
|
||||
|
||||
tlsMode, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__TLS_MODE", idx), 0)
|
||||
tlsMode, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__TLS_MODE", idx), 32)
|
||||
if ok {
|
||||
binding.TLSMode = int(tlsMode)
|
||||
isSet = true
|
||||
}
|
||||
|
||||
tlsSessionReuse, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__TLS_SESSION_REUSE", idx), 0)
|
||||
tlsSessionReuse, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__TLS_SESSION_REUSE", idx), 32)
|
||||
if ok {
|
||||
binding.TLSSessionReuse = int(tlsSessionReuse)
|
||||
isSet = true
|
||||
}
|
||||
|
||||
tlsVer, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__MIN_TLS_VERSION", idx), 0)
|
||||
tlsVer, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__MIN_TLS_VERSION", idx), 32)
|
||||
if ok {
|
||||
binding.MinTLSVersion = int(tlsVer)
|
||||
isSet = true
|
||||
|
@ -1199,25 +1207,25 @@ func getFTPDBindingSecurityFromEnv(idx int, binding *ftpd.Binding) bool {
|
|||
isSet = true
|
||||
}
|
||||
|
||||
clientAuthType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__CLIENT_AUTH_TYPE", idx), 0)
|
||||
clientAuthType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__CLIENT_AUTH_TYPE", idx), 32)
|
||||
if ok {
|
||||
binding.ClientAuthType = int(clientAuthType)
|
||||
isSet = true
|
||||
}
|
||||
|
||||
pasvSecurity, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__PASSIVE_CONNECTIONS_SECURITY", idx), 0)
|
||||
pasvSecurity, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__PASSIVE_CONNECTIONS_SECURITY", idx), 32)
|
||||
if ok {
|
||||
binding.PassiveConnectionsSecurity = int(pasvSecurity)
|
||||
isSet = true
|
||||
}
|
||||
|
||||
activeSecurity, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__ACTIVE_CONNECTIONS_SECURITY", idx), 0)
|
||||
activeSecurity, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__ACTIVE_CONNECTIONS_SECURITY", idx), 32)
|
||||
if ok {
|
||||
binding.ActiveConnectionsSecurity = int(activeSecurity)
|
||||
isSet = true
|
||||
}
|
||||
|
||||
ignoreASCIITransferType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%d__IGNORE_ASCII_TRANSFER_TYPE", idx), 0)
|
||||
ignoreASCIITransferType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%d__IGNORE_ASCII_TRANSFER_TYPE", idx), 32)
|
||||
if ok {
|
||||
binding.IgnoreASCIITransferType = int(ignoreASCIITransferType)
|
||||
isSet = true
|
||||
|
@ -1230,7 +1238,7 @@ func getFTPDBindingFromEnv(idx int) {
|
|||
binding := getDefaultFTPDBinding(idx)
|
||||
isSet := false
|
||||
|
||||
port, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__PORT", idx), 0)
|
||||
port, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__PORT", idx), 32)
|
||||
if ok {
|
||||
binding.Port = int(port)
|
||||
isSet = true
|
||||
|
@ -1310,13 +1318,13 @@ func getWebDAVBindingHTTPSConfigsFromEnv(idx int, binding *webdavd.Binding) bool
|
|||
isSet = true
|
||||
}
|
||||
|
||||
tlsVer, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__MIN_TLS_VERSION", idx), 0)
|
||||
tlsVer, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__MIN_TLS_VERSION", idx), 32)
|
||||
if ok {
|
||||
binding.MinTLSVersion = int(tlsVer)
|
||||
isSet = true
|
||||
}
|
||||
|
||||
clientAuthType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__CLIENT_AUTH_TYPE", idx), 0)
|
||||
clientAuthType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__CLIENT_AUTH_TYPE", idx), 32)
|
||||
if ok {
|
||||
binding.ClientAuthType = int(clientAuthType)
|
||||
isSet = true
|
||||
|
@ -1340,6 +1348,12 @@ func getWebDAVBindingHTTPSConfigsFromEnv(idx int, binding *webdavd.Binding) bool
|
|||
func getWebDAVDBindingProxyConfigsFromEnv(idx int, binding *webdavd.Binding) bool {
|
||||
isSet := false
|
||||
|
||||
proxyMode, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PROXY_MODE", idx), 32)
|
||||
if ok {
|
||||
binding.ProxyMode = int(proxyMode)
|
||||
isSet = true
|
||||
}
|
||||
|
||||
proxyAllowed, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PROXY_ALLOWED", idx))
|
||||
if ok {
|
||||
binding.ProxyAllowed = proxyAllowed
|
||||
|
@ -1352,7 +1366,7 @@ func getWebDAVDBindingProxyConfigsFromEnv(idx int, binding *webdavd.Binding) boo
|
|||
isSet = true
|
||||
}
|
||||
|
||||
clientIPHeaderDepth, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__CLIENT_IP_HEADER_DEPTH", idx), 0)
|
||||
clientIPHeaderDepth, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__CLIENT_IP_HEADER_DEPTH", idx), 32)
|
||||
if ok {
|
||||
binding.ClientIPHeaderDepth = int(clientIPHeaderDepth)
|
||||
isSet = true
|
||||
|
@ -1390,7 +1404,7 @@ func getWebDAVDBindingFromEnv(idx int) {
|
|||
|
||||
isSet := false
|
||||
|
||||
port, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PORT", idx), 0)
|
||||
port, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PORT", idx), 32)
|
||||
if ok {
|
||||
binding.Port = int(port)
|
||||
isSet = true
|
||||
|
@ -1761,6 +1775,12 @@ func getHTTPDNestedObjectsFromEnv(idx int, binding *httpd.Binding) bool {
|
|||
func getHTTPDBindingProxyConfigsFromEnv(idx int, binding *httpd.Binding) bool {
|
||||
isSet := false
|
||||
|
||||
proxyMode, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__PROXY_MODE", idx), 32)
|
||||
if ok {
|
||||
binding.ProxyMode = int(proxyMode)
|
||||
isSet = true
|
||||
}
|
||||
|
||||
proxyAllowed, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__PROXY_ALLOWED", idx))
|
||||
if ok {
|
||||
binding.ProxyAllowed = proxyAllowed
|
||||
|
@ -1773,7 +1793,7 @@ func getHTTPDBindingProxyConfigsFromEnv(idx int, binding *httpd.Binding) bool {
|
|||
isSet = true
|
||||
}
|
||||
|
||||
clientIPHeaderDepth, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__CLIENT_IP_HEADER_DEPTH", idx), 0)
|
||||
clientIPHeaderDepth, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__CLIENT_IP_HEADER_DEPTH", idx), 32)
|
||||
if ok {
|
||||
binding.ClientIPHeaderDepth = int(clientIPHeaderDepth)
|
||||
isSet = true
|
||||
|
@ -1786,7 +1806,7 @@ func getHTTPDBindingFromEnv(idx int) { //nolint:gocyclo
|
|||
binding := getDefaultHTTPBinding(idx)
|
||||
isSet := false
|
||||
|
||||
port, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__PORT", idx), 0)
|
||||
port, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__PORT", idx), 32)
|
||||
if ok {
|
||||
binding.Port = int(port)
|
||||
isSet = true
|
||||
|
@ -1828,7 +1848,7 @@ func getHTTPDBindingFromEnv(idx int) { //nolint:gocyclo
|
|||
isSet = true
|
||||
}
|
||||
|
||||
enabledLoginMethods, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__ENABLED_LOGIN_METHODS", idx), 0)
|
||||
enabledLoginMethods, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__ENABLED_LOGIN_METHODS", idx), 32)
|
||||
if ok {
|
||||
binding.EnabledLoginMethods = int(enabledLoginMethods)
|
||||
isSet = true
|
||||
|
@ -1840,19 +1860,25 @@ func getHTTPDBindingFromEnv(idx int) { //nolint:gocyclo
|
|||
isSet = true
|
||||
}
|
||||
|
||||
languages, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%d__LANGUAGES", idx))
|
||||
if ok {
|
||||
binding.Languages = languages
|
||||
isSet = true
|
||||
}
|
||||
|
||||
enableHTTPS, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__ENABLE_HTTPS", idx))
|
||||
if ok {
|
||||
binding.EnableHTTPS = enableHTTPS
|
||||
isSet = true
|
||||
}
|
||||
|
||||
tlsVer, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__MIN_TLS_VERSION", idx), 0)
|
||||
tlsVer, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__MIN_TLS_VERSION", idx), 32)
|
||||
if ok {
|
||||
binding.MinTLSVersion = int(tlsVer)
|
||||
isSet = true
|
||||
}
|
||||
|
||||
clientAuthType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__CLIENT_AUTH_TYPE", idx), 0)
|
||||
clientAuthType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__CLIENT_AUTH_TYPE", idx), 32)
|
||||
if ok {
|
||||
binding.ClientAuthType = int(clientAuthType)
|
||||
isSet = true
|
||||
|
@ -1874,7 +1900,7 @@ func getHTTPDBindingFromEnv(idx int) { //nolint:gocyclo
|
|||
isSet = true
|
||||
}
|
||||
|
||||
hideLoginURL, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__HIDE_LOGIN_URL", idx), 0)
|
||||
hideLoginURL, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__HIDE_LOGIN_URL", idx), 32)
|
||||
if ok {
|
||||
binding.HideLoginURL = int(hideLoginURL)
|
||||
isSet = true
|
||||
|
@ -1963,7 +1989,7 @@ func getCommandConfigsFromEnv(idx int) {
|
|||
cfg.Path = path
|
||||
}
|
||||
|
||||
timeout, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMAND__COMMANDS__%v__TIMEOUT", idx), 0)
|
||||
timeout, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMAND__COMMANDS__%v__TIMEOUT", idx), 32)
|
||||
if ok {
|
||||
cfg.Timeout = int(timeout)
|
||||
}
|
||||
|
@ -2023,6 +2049,7 @@ func setViperDefaults() {
|
|||
viper.SetDefault("common.defender.login_delay.password_failed", globalConf.Common.DefenderConfig.LoginDelay.PasswordFailed)
|
||||
viper.SetDefault("common.umask", globalConf.Common.Umask)
|
||||
viper.SetDefault("common.server_version", globalConf.Common.ServerVersion)
|
||||
viper.SetDefault("common.tz", globalConf.Common.TZ)
|
||||
viper.SetDefault("common.metadata.read", globalConf.Common.Metadata.Read)
|
||||
viper.SetDefault("common.event_manager.enabled_commands", globalConf.Common.EventManager.EnabledCommands)
|
||||
viper.SetDefault("acme.email", globalConf.ACME.Email)
|
||||
|
@ -2137,6 +2164,9 @@ func setViperDefaults() {
|
|||
viper.SetDefault("httpd.signing_passphrase", globalConf.HTTPDConfig.SigningPassphrase)
|
||||
viper.SetDefault("httpd.signing_passphrase_file", globalConf.HTTPDConfig.SigningPassphraseFile)
|
||||
viper.SetDefault("httpd.token_validation", globalConf.HTTPDConfig.TokenValidation)
|
||||
viper.SetDefault("httpd.cookie_lifetime", globalConf.HTTPDConfig.CookieLifetime)
|
||||
viper.SetDefault("httpd.share_cookie_lifetime", globalConf.HTTPDConfig.ShareCookieLifetime)
|
||||
viper.SetDefault("httpd.jwt_lifetime", globalConf.HTTPDConfig.JWTLifetime)
|
||||
viper.SetDefault("httpd.max_upload_file_size", globalConf.HTTPDConfig.MaxUploadFileSize)
|
||||
viper.SetDefault("httpd.cors.enabled", globalConf.HTTPDConfig.Cors.Enabled)
|
||||
viper.SetDefault("httpd.cors.allowed_origins", globalConf.HTTPDConfig.Cors.AllowedOrigins)
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/sftpgo/sdk/kms"
|
||||
|
@ -36,7 +37,6 @@ import (
|
|||
"github.com/drakkan/sftpgo/v2/internal/plugin"
|
||||
"github.com/drakkan/sftpgo/v2/internal/sftpd"
|
||||
"github.com/drakkan/sftpgo/v2/internal/smtp"
|
||||
"github.com/drakkan/sftpgo/v2/internal/util"
|
||||
"github.com/drakkan/sftpgo/v2/internal/webdavd"
|
||||
)
|
||||
|
||||
|
@ -699,8 +699,8 @@ func TestPluginsFromEnv(t *testing.T) {
|
|||
pluginConf := pluginsConf[0]
|
||||
require.Equal(t, "notifier", pluginConf.Type)
|
||||
require.Len(t, pluginConf.NotifierOptions.FsEvents, 2)
|
||||
require.True(t, util.Contains(pluginConf.NotifierOptions.FsEvents, "upload"))
|
||||
require.True(t, util.Contains(pluginConf.NotifierOptions.FsEvents, "download"))
|
||||
require.True(t, slices.Contains(pluginConf.NotifierOptions.FsEvents, "upload"))
|
||||
require.True(t, slices.Contains(pluginConf.NotifierOptions.FsEvents, "download"))
|
||||
require.Len(t, pluginConf.NotifierOptions.ProviderEvents, 2)
|
||||
require.Equal(t, "add", pluginConf.NotifierOptions.ProviderEvents[0])
|
||||
require.Equal(t, "update", pluginConf.NotifierOptions.ProviderEvents[1])
|
||||
|
@ -749,8 +749,8 @@ func TestPluginsFromEnv(t *testing.T) {
|
|||
pluginConf = pluginsConf[0]
|
||||
require.Equal(t, "notifier", pluginConf.Type)
|
||||
require.Len(t, pluginConf.NotifierOptions.FsEvents, 2)
|
||||
require.True(t, util.Contains(pluginConf.NotifierOptions.FsEvents, "upload"))
|
||||
require.True(t, util.Contains(pluginConf.NotifierOptions.FsEvents, "download"))
|
||||
require.True(t, slices.Contains(pluginConf.NotifierOptions.FsEvents, "upload"))
|
||||
require.True(t, slices.Contains(pluginConf.NotifierOptions.FsEvents, "download"))
|
||||
require.Len(t, pluginConf.NotifierOptions.ProviderEvents, 2)
|
||||
require.Equal(t, "add", pluginConf.NotifierOptions.ProviderEvents[0])
|
||||
require.Equal(t, "update", pluginConf.NotifierOptions.ProviderEvents[1])
|
||||
|
@ -807,8 +807,8 @@ func TestRateLimitersFromEnv(t *testing.T) {
|
|||
require.Equal(t, 2, limiters[0].Type)
|
||||
protocols := limiters[0].Protocols
|
||||
require.Len(t, protocols, 2)
|
||||
require.True(t, util.Contains(protocols, common.ProtocolFTP))
|
||||
require.True(t, util.Contains(protocols, common.ProtocolSSH))
|
||||
require.True(t, slices.Contains(protocols, common.ProtocolFTP))
|
||||
require.True(t, slices.Contains(protocols, common.ProtocolSSH))
|
||||
require.True(t, limiters[0].GenerateDefenderEvents)
|
||||
require.Equal(t, 50, limiters[0].EntriesSoftLimit)
|
||||
require.Equal(t, 100, limiters[0].EntriesHardLimit)
|
||||
|
@ -819,10 +819,10 @@ func TestRateLimitersFromEnv(t *testing.T) {
|
|||
require.Equal(t, 2, limiters[1].Type)
|
||||
protocols = limiters[1].Protocols
|
||||
require.Len(t, protocols, 4)
|
||||
require.True(t, util.Contains(protocols, common.ProtocolFTP))
|
||||
require.True(t, util.Contains(protocols, common.ProtocolSSH))
|
||||
require.True(t, util.Contains(protocols, common.ProtocolWebDAV))
|
||||
require.True(t, util.Contains(protocols, common.ProtocolHTTP))
|
||||
require.True(t, slices.Contains(protocols, common.ProtocolFTP))
|
||||
require.True(t, slices.Contains(protocols, common.ProtocolSSH))
|
||||
require.True(t, slices.Contains(protocols, common.ProtocolWebDAV))
|
||||
require.True(t, slices.Contains(protocols, common.ProtocolHTTP))
|
||||
require.False(t, limiters[1].GenerateDefenderEvents)
|
||||
require.Equal(t, 100, limiters[1].EntriesSoftLimit)
|
||||
require.Equal(t, 150, limiters[1].EntriesHardLimit)
|
||||
|
@ -1094,6 +1094,7 @@ func TestWebDAVBindingsFromEnv(t *testing.T) {
|
|||
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__ENABLE_HTTPS", "0")
|
||||
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__TLS_CIPHER_SUITES", "TLS_RSA_WITH_AES_128_CBC_SHA ")
|
||||
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__TLS_PROTOCOLS", "http/1.1 ")
|
||||
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__PROXY_MODE", "1")
|
||||
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__PROXY_ALLOWED", "192.168.10.1")
|
||||
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__CLIENT_IP_PROXY_HEADER", "X-Forwarded-For")
|
||||
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__CLIENT_IP_HEADER_DEPTH", "2")
|
||||
|
@ -1113,6 +1114,7 @@ func TestWebDAVBindingsFromEnv(t *testing.T) {
|
|||
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__ENABLE_HTTPS")
|
||||
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__TLS_CIPHER_SUITES")
|
||||
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__TLS_PROTOCOLS")
|
||||
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__PROXY_MODE")
|
||||
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__PROXY_ALLOWED")
|
||||
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__CLIENT_IP_PROXY_HEADER")
|
||||
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__CLIENT_IP_HEADER_DEPTH")
|
||||
|
@ -1137,6 +1139,7 @@ func TestWebDAVBindingsFromEnv(t *testing.T) {
|
|||
require.Equal(t, 12, bindings[0].MinTLSVersion)
|
||||
require.Len(t, bindings[0].TLSCipherSuites, 0)
|
||||
require.Len(t, bindings[0].Protocols, 0)
|
||||
require.Equal(t, 0, bindings[0].ProxyMode)
|
||||
require.Empty(t, bindings[0].Prefix)
|
||||
require.Equal(t, 0, bindings[0].ClientIPHeaderDepth)
|
||||
require.False(t, bindings[0].DisableWWWAuthHeader)
|
||||
|
@ -1149,6 +1152,7 @@ func TestWebDAVBindingsFromEnv(t *testing.T) {
|
|||
require.Equal(t, "TLS_RSA_WITH_AES_128_CBC_SHA", bindings[1].TLSCipherSuites[0])
|
||||
require.Len(t, bindings[1].Protocols, 1)
|
||||
assert.Equal(t, "http/1.1", bindings[1].Protocols[0])
|
||||
require.Equal(t, 1, bindings[1].ProxyMode)
|
||||
require.Equal(t, "192.168.10.1", bindings[1].ProxyAllowed[0])
|
||||
require.Equal(t, "X-Forwarded-For", bindings[1].ClientIPProxyHeader)
|
||||
require.Equal(t, 2, bindings[1].ClientIPHeaderDepth)
|
||||
|
@ -1159,6 +1163,7 @@ func TestWebDAVBindingsFromEnv(t *testing.T) {
|
|||
require.True(t, bindings[2].EnableHTTPS)
|
||||
require.Equal(t, 13, bindings[2].MinTLSVersion)
|
||||
require.Equal(t, 1, bindings[2].ClientAuthType)
|
||||
require.Equal(t, 0, bindings[2].ProxyMode)
|
||||
require.Nil(t, bindings[2].TLSCipherSuites)
|
||||
require.Equal(t, "/dav2", bindings[2].Prefix)
|
||||
require.Equal(t, "webdav.crt", bindings[2].CertificateFile)
|
||||
|
@ -1188,11 +1193,13 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
|
|||
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_REST_API", "0")
|
||||
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLED_LOGIN_METHODS", "3")
|
||||
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__RENDER_OPENAPI", "0")
|
||||
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__LANGUAGES", "en,es")
|
||||
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_HTTPS", "1 ")
|
||||
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__MIN_TLS_VERSION", "13")
|
||||
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_AUTH_TYPE", "1")
|
||||
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__TLS_CIPHER_SUITES", " TLS_AES_256_GCM_SHA384 , TLS_CHACHA20_POLY1305_SHA256")
|
||||
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__TLS_PROTOCOLS", "h2, http/1.1")
|
||||
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__PROXY_MODE", "1")
|
||||
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__PROXY_ALLOWED", " 192.168.9.1 , 172.16.25.0/24")
|
||||
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_IP_PROXY_HEADER", "X-Real-IP")
|
||||
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_IP_HEADER_DEPTH", "2")
|
||||
|
@ -1255,9 +1262,11 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
|
|||
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_REST_API")
|
||||
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLED_LOGIN_METHODS")
|
||||
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__RENDER_OPENAPI")
|
||||
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__LANGUAGES")
|
||||
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_AUTH_TYPE")
|
||||
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__TLS_CIPHER_SUITES")
|
||||
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__TLS_PROTOCOLS")
|
||||
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__PROXY_MODE")
|
||||
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__PROXY_ALLOWED")
|
||||
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_IP_PROXY_HEADER")
|
||||
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_IP_HEADER_DEPTH")
|
||||
|
@ -1315,7 +1324,10 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
|
|||
require.True(t, bindings[0].EnableRESTAPI)
|
||||
require.Equal(t, 0, bindings[0].EnabledLoginMethods)
|
||||
require.True(t, bindings[0].RenderOpenAPI)
|
||||
require.Len(t, bindings[0].Languages, 1)
|
||||
assert.Contains(t, bindings[0].Languages, "en")
|
||||
require.Len(t, bindings[0].TLSCipherSuites, 1)
|
||||
require.Equal(t, 0, bindings[0].ProxyMode)
|
||||
require.Empty(t, bindings[0].OIDC.ConfigURL)
|
||||
require.Equal(t, "TLS_AES_128_GCM_SHA256", bindings[0].TLSCipherSuites[0])
|
||||
require.Equal(t, 0, bindings[0].HideLoginURL)
|
||||
|
@ -1333,6 +1345,8 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
|
|||
require.True(t, bindings[1].EnableRESTAPI)
|
||||
require.Equal(t, 0, bindings[1].EnabledLoginMethods)
|
||||
require.True(t, bindings[1].RenderOpenAPI)
|
||||
require.Len(t, bindings[1].Languages, 1)
|
||||
assert.Contains(t, bindings[1].Languages, "en")
|
||||
require.Nil(t, bindings[1].TLSCipherSuites)
|
||||
require.Equal(t, 1, bindings[1].HideLoginURL)
|
||||
require.Empty(t, bindings[1].OIDC.ClientID)
|
||||
|
@ -1342,6 +1356,7 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
|
|||
require.False(t, bindings[1].Security.Enabled)
|
||||
require.Equal(t, "Web Admin", bindings[1].Branding.WebAdmin.Name)
|
||||
require.Equal(t, "WebClient", bindings[1].Branding.WebClient.ShortName)
|
||||
require.Equal(t, 0, bindings[1].ProxyMode)
|
||||
require.Equal(t, 0, bindings[1].ClientIPHeaderDepth)
|
||||
require.Equal(t, 9000, bindings[2].Port)
|
||||
require.Equal(t, "127.0.1.1", bindings[2].Address)
|
||||
|
@ -1352,6 +1367,9 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
|
|||
require.False(t, bindings[2].EnableRESTAPI)
|
||||
require.Equal(t, 3, bindings[2].EnabledLoginMethods)
|
||||
require.False(t, bindings[2].RenderOpenAPI)
|
||||
require.Len(t, bindings[2].Languages, 2)
|
||||
assert.Contains(t, bindings[2].Languages, "en")
|
||||
assert.Contains(t, bindings[2].Languages, "es")
|
||||
require.Equal(t, 1, bindings[2].ClientAuthType)
|
||||
require.Len(t, bindings[2].TLSCipherSuites, 2)
|
||||
require.Equal(t, "TLS_AES_256_GCM_SHA384", bindings[2].TLSCipherSuites[0])
|
||||
|
@ -1359,6 +1377,7 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
|
|||
require.Len(t, bindings[2].Protocols, 2)
|
||||
require.Equal(t, "h2", bindings[2].Protocols[0])
|
||||
require.Equal(t, "http/1.1", bindings[2].Protocols[1])
|
||||
require.Equal(t, 1, bindings[2].ProxyMode)
|
||||
require.Len(t, bindings[2].ProxyAllowed, 2)
|
||||
require.Equal(t, "192.168.9.1", bindings[2].ProxyAllowed[0])
|
||||
require.Equal(t, "172.16.25.0/24", bindings[2].ProxyAllowed[1])
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"net/url"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -78,8 +79,8 @@ func executeAction(operation, executor, ip, objectType, objectName, role string,
|
|||
if config.Actions.Hook == "" {
|
||||
return
|
||||
}
|
||||
if !util.Contains(config.Actions.ExecuteOn, operation) ||
|
||||
!util.Contains(config.Actions.ExecuteFor, objectType) {
|
||||
if !slices.Contains(config.Actions.ExecuteOn, operation) ||
|
||||
!slices.Contains(config.Actions.ExecuteFor, objectType) {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
|
@ -86,7 +87,7 @@ func (c *AdminTOTPConfig) validate(username string) error {
|
|||
if c.ConfigName == "" {
|
||||
return util.NewValidationError("totp: config name is mandatory")
|
||||
}
|
||||
if !util.Contains(mfa.GetAvailableTOTPConfigNames(), c.ConfigName) {
|
||||
if !slices.Contains(mfa.GetAvailableTOTPConfigNames(), c.ConfigName) {
|
||||
return util.NewValidationError(fmt.Sprintf("totp: config name %q not found", c.ConfigName))
|
||||
}
|
||||
if c.Secret.IsEmpty() {
|
||||
|
@ -322,15 +323,15 @@ func (a *Admin) validatePermissions() error {
|
|||
util.I18nErrorPermissionsRequired,
|
||||
)
|
||||
}
|
||||
if util.Contains(a.Permissions, PermAdminAny) {
|
||||
if slices.Contains(a.Permissions, PermAdminAny) {
|
||||
a.Permissions = []string{PermAdminAny}
|
||||
}
|
||||
for _, perm := range a.Permissions {
|
||||
if !util.Contains(validAdminPerms, perm) {
|
||||
if !slices.Contains(validAdminPerms, perm) {
|
||||
return util.NewValidationError(fmt.Sprintf("invalid permission: %q", perm))
|
||||
}
|
||||
if a.Role != "" {
|
||||
if util.Contains(forbiddenPermsForRoleAdmins, perm) {
|
||||
if slices.Contains(forbiddenPermsForRoleAdmins, perm) {
|
||||
return util.NewI18nError(
|
||||
util.NewValidationError("a role admin cannot be a super admin"),
|
||||
util.I18nErrorRoleAdminPerms,
|
||||
|
@@ -540,10 +541,20 @@ func (a *Admin) SetNilSecretsIfEmpty() {

// HasPermission returns true if the admin has the specified permission
func (a *Admin) HasPermission(perm string) bool {
    if util.Contains(a.Permissions, PermAdminAny) {
    if slices.Contains(a.Permissions, PermAdminAny) {
        return true
    }
    return util.Contains(a.Permissions, perm)
    return slices.Contains(a.Permissions, perm)
}

// HasPermissions returns true if the admin has all the specified permissions
func (a *Admin) HasPermissions(perms ...string) bool {
    for _, perm := range perms {
        if !a.HasPermission(perm) {
            return false
        }
    }
    return len(perms) > 0
}

// GetAllowedIPAsString returns the allowed IP as comma separated string

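One detail worth noting in the new `HasPermissions` helper above: it returns `len(perms) > 0`, so calling it with no arguments yields false rather than vacuously true. A short usage sketch (the permission strings are hypothetical, chosen only to illustrate the semantics):

```go
// Hypothetical admin value; Permissions holds the admin's permission names.
admin := Admin{Permissions: []string{"add_users", "edit_users"}}

fmt.Println(admin.HasPermissions("add_users", "edit_users")) // true: every listed permission is present
fmt.Println(admin.HasPermissions("add_users", "del_users"))  // false: one permission is missing
fmt.Println(admin.HasPermissions())                          // false: len(perms) > 0 rejects the empty call
```
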
@ -25,7 +25,9 @@ import (
|
|||
"fmt"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
bolt "go.etcd.io/bbolt"
|
||||
|
@ -181,6 +183,50 @@ func (p *BoltProvider) updateAPIKeyLastUse(keyID string) error {
|
|||
})
|
||||
}
|
||||
|
||||
func (p *BoltProvider) getAdminSignature(username string) (string, error) {
|
||||
var updatedAt int64
|
||||
err := p.dbHandle.View(func(tx *bolt.Tx) error {
|
||||
bucket, err := p.getAdminsBucket(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u := bucket.Get([]byte(username))
|
||||
var admin Admin
|
||||
err = json.Unmarshal(u, &admin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
updatedAt = admin.UpdatedAt
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strconv.FormatInt(updatedAt, 10), nil
|
||||
}
|
||||
|
||||
func (p *BoltProvider) getUserSignature(username string) (string, error) {
|
||||
var updatedAt int64
|
||||
err := p.dbHandle.View(func(tx *bolt.Tx) error {
|
||||
bucket, err := p.getUsersBucket(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u := bucket.Get([]byte(username))
|
||||
var user User
|
||||
err = json.Unmarshal(u, &user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
updatedAt = user.UpdatedAt
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strconv.FormatInt(updatedAt, 10), nil
|
||||
}
|
||||
|
||||
func (p *BoltProvider) setUpdatedAt(username string) {
|
||||
p.dbHandle.Update(func(tx *bolt.Tx) error { //nolint:errcheck
|
||||
bucket, err := p.getUsersBucket(tx)
|
||||
|
@ -3134,15 +3180,11 @@ func (p *BoltProvider) migrateDatabase() error {
|
|||
case version == boltDatabaseVersion:
|
||||
providerLog(logger.LevelDebug, "bolt database is up to date, current version: %d", version)
|
||||
return ErrNoInitRequired
|
||||
case version < 28:
|
||||
case version < 29:
|
||||
err = errSchemaVersionTooOld(version)
|
||||
providerLog(logger.LevelError, "%v", err)
|
||||
logger.ErrorToConsole("%v", err)
|
||||
return err
|
||||
case version == 28:
|
||||
logger.InfoToConsole("updating database schema version: %d -> 29", version)
|
||||
providerLog(logger.LevelInfo, "updating database schema version: %d -> 29", version)
|
||||
return updateBoltDatabaseVersion(p.dbHandle, 29)
|
||||
default:
|
||||
if version > boltDatabaseVersion {
|
||||
providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version,
|
||||
|
@ -3164,10 +3206,6 @@ func (p *BoltProvider) revertDatabase(targetVersion int) error { //nolint:gocycl
|
|||
return errors.New("current version match target version, nothing to do")
|
||||
}
|
||||
switch dbVersion.Version {
|
||||
case 29:
|
||||
logger.InfoToConsole("downgrading database schema version: %d -> 28", dbVersion.Version)
|
||||
providerLog(logger.LevelInfo, "downgrading database schema version: %d -> 28", dbVersion.Version)
|
||||
return updateBoltDatabaseVersion(p.dbHandle, 28)
|
||||
default:
|
||||
return fmt.Errorf("database schema version not handled: %v", dbVersion.Version)
|
||||
}
|
||||
|
@ -3328,7 +3366,7 @@ func (p *BoltProvider) addAdminToRole(username, roleName string, bucket *bolt.Bu
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !util.Contains(role.Admins, username) {
|
||||
if !slices.Contains(role.Admins, username) {
|
||||
role.Admins = append(role.Admins, username)
|
||||
buf, err := json.Marshal(role)
|
||||
if err != nil {
|
||||
|
@ -3353,7 +3391,7 @@ func (p *BoltProvider) removeAdminFromRole(username, roleName string, bucket *bo
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if util.Contains(role.Admins, username) {
|
||||
if slices.Contains(role.Admins, username) {
|
||||
var admins []string
|
||||
for _, admin := range role.Admins {
|
||||
if admin != username {
|
||||
|
@ -3383,7 +3421,7 @@ func (p *BoltProvider) addUserToRole(username, roleName string, bucket *bolt.Buc
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !util.Contains(role.Users, username) {
|
||||
if !slices.Contains(role.Users, username) {
|
||||
role.Users = append(role.Users, username)
|
||||
buf, err := json.Marshal(role)
|
||||
if err != nil {
|
||||
|
@ -3408,7 +3446,7 @@ func (p *BoltProvider) removeUserFromRole(username, roleName string, bucket *bol
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if util.Contains(role.Users, username) {
|
||||
if slices.Contains(role.Users, username) {
|
||||
var users []string
|
||||
for _, user := range role.Users {
|
||||
if user != username {
|
||||
|
@ -3436,7 +3474,7 @@ func (p *BoltProvider) addRuleToActionMapping(ruleName, actionName string, bucke
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !util.Contains(action.Rules, ruleName) {
|
||||
if !slices.Contains(action.Rules, ruleName) {
|
||||
action.Rules = append(action.Rules, ruleName)
|
||||
buf, err := json.Marshal(action)
|
||||
if err != nil {
|
||||
|
@ -3458,7 +3496,7 @@ func (p *BoltProvider) removeRuleFromActionMapping(ruleName, actionName string,
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if util.Contains(action.Rules, ruleName) {
|
||||
if slices.Contains(action.Rules, ruleName) {
|
||||
var rules []string
|
||||
for _, r := range action.Rules {
|
||||
if r != ruleName {
|
||||
|
@ -3485,7 +3523,7 @@ func (p *BoltProvider) addUserToGroupMapping(username, groupname string, bucket
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !util.Contains(group.Users, username) {
|
||||
if !slices.Contains(group.Users, username) {
|
||||
group.Users = append(group.Users, username)
|
||||
buf, err := json.Marshal(group)
|
||||
if err != nil {
|
||||
|
@ -3530,7 +3568,7 @@ func (p *BoltProvider) addAdminToGroupMapping(username, groupname string, bucket
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !util.Contains(group.Admins, username) {
|
||||
if !slices.Contains(group.Admins, username) {
|
||||
group.Admins = append(group.Admins, username)
|
||||
buf, err := json.Marshal(group)
|
||||
if err != nil {
|
||||
|
@ -3601,11 +3639,11 @@ func (p *BoltProvider) addRelationToFolderMapping(folderName string, user *User,
|
|||
return err
|
||||
}
|
||||
updated := false
|
||||
if user != nil && !util.Contains(folder.Users, user.Username) {
|
||||
if user != nil && !slices.Contains(folder.Users, user.Username) {
|
||||
folder.Users = append(folder.Users, user.Username)
|
||||
updated = true
|
||||
}
|
||||
if group != nil && !util.Contains(folder.Groups, group.Name) {
|
||||
if group != nil && !slices.Contains(folder.Groups, group.Name) {
|
||||
folder.Groups = append(folder.Groups, group.Name)
|
||||
updated = true
|
||||
}
|
||||
|
@ -3899,7 +3937,7 @@ func getBoltDatabaseVersion(dbHandle *bolt.DB) (schemaVersion, error) {
|
|||
v := bucket.Get(dbVersionKey)
|
||||
if v == nil {
|
||||
dbVersion = schemaVersion{
|
||||
Version: 28,
|
||||
Version: 29,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -3908,7 +3946,7 @@ func getBoltDatabaseVersion(dbHandle *bolt.DB) (schemaVersion, error) {
|
|||
return dbVersion, err
|
||||
}
|
||||
|
||||
func updateBoltDatabaseVersion(dbHandle *bolt.DB, version int) error {
|
||||
/*func updateBoltDatabaseVersion(dbHandle *bolt.DB, version int) error {
|
||||
err := dbHandle.Update(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket(dbVersionBucket)
|
||||
if bucket == nil {
|
||||
|
@ -3924,4 +3962,4 @@ func updateBoltDatabaseVersion(dbHandle *bolt.DB, version int) error {
|
|||
return bucket.Put(dbVersionKey, buf)
|
||||
})
|
||||
return err
|
||||
}
|
||||
}*/
|
||||
|
|
|
@ -15,8 +15,12 @@
|
|||
package dataprovider
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"image/png"
|
||||
"net/url"
|
||||
"slices"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
|
@ -102,7 +106,7 @@ func (c *SFTPDConfigs) validate() error {
|
|||
if algo == ssh.CertAlgoRSAv01 {
|
||||
continue
|
||||
}
|
||||
if !util.Contains(supportedHostKeyAlgos, algo) {
|
||||
if !slices.Contains(supportedHostKeyAlgos, algo) {
|
||||
return util.NewValidationError(fmt.Sprintf("unsupported host key algorithm %q", algo))
|
||||
}
|
||||
hostKeyAlgos = append(hostKeyAlgos, algo)
|
||||
|
@ -113,24 +117,24 @@ func (c *SFTPDConfigs) validate() error {
|
|||
if algo == "diffie-hellman-group18-sha512" || algo == ssh.KeyExchangeDHGEXSHA256 {
|
||||
continue
|
||||
}
|
||||
if !util.Contains(supportedKexAlgos, algo) {
|
||||
if !slices.Contains(supportedKexAlgos, algo) {
|
||||
return util.NewValidationError(fmt.Sprintf("unsupported KEX algorithm %q", algo))
|
||||
}
|
||||
kexAlgos = append(kexAlgos, algo)
|
||||
}
|
||||
c.KexAlgorithms = kexAlgos
|
||||
for _, cipher := range c.Ciphers {
|
||||
if !util.Contains(supportedCiphers, cipher) {
|
||||
if !slices.Contains(supportedCiphers, cipher) {
|
||||
return util.NewValidationError(fmt.Sprintf("unsupported cipher %q", cipher))
|
||||
}
|
||||
}
|
||||
for _, mac := range c.MACs {
|
||||
if !util.Contains(supportedMACs, mac) {
|
||||
if !slices.Contains(supportedMACs, mac) {
|
||||
return util.NewValidationError(fmt.Sprintf("unsupported MAC algorithm %q", mac))
|
||||
}
|
||||
}
|
||||
for _, algo := range c.PublicKeyAlgos {
|
||||
if !util.Contains(supportedPublicKeyAlgos, algo) {
|
||||
if !slices.Contains(supportedPublicKeyAlgos, algo) {
|
||||
return util.NewValidationError(fmt.Sprintf("unsupported public key algorithm %q", algo))
|
||||
}
|
||||
}
|
||||
|
@ -193,19 +197,19 @@ func (c *SMTPOAuth2) validate() error {
|
|||
if c.ClientID == "" {
|
||||
return util.NewI18nError(
|
||||
util.NewValidationError("smtp oauth2: client id is required"),
|
||||
util.I18nErrorSMTPClientIDRequired,
|
||||
util.I18nErrorClientIDRequired,
|
||||
)
|
||||
}
|
||||
if c.ClientSecret == nil {
|
||||
if c.ClientSecret == nil || c.ClientSecret.IsEmpty() {
|
||||
return util.NewI18nError(
|
||||
util.NewValidationError("smtp oauth2: client secret is required"),
|
||||
util.I18nErrorSMTPClientSecretRequired,
|
||||
util.I18nErrorClientSecretRequired,
|
||||
)
|
||||
}
|
||||
if c.RefreshToken == nil {
|
||||
if c.RefreshToken == nil || c.RefreshToken.IsEmpty() {
|
||||
return util.NewI18nError(
|
||||
util.NewValidationError("smtp oauth2: refresh token is required"),
|
||||
util.I18nErrorSMTPRefreshTokenRequired,
|
||||
util.I18nErrorRefreshTokenRequired,
|
||||
)
|
||||
}
|
||||
if err := validateSMTPSecret(c.ClientSecret, "oauth2 client secret"); err != nil {
|
||||
|
@ -305,6 +309,27 @@ func (c *SMTPConfigs) TryDecrypt() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *SMTPConfigs) prepareForRendering() {
|
||||
if c.Password != nil {
|
||||
c.Password.Hide()
|
||||
if c.Password.IsEmpty() {
|
||||
c.Password = nil
|
||||
}
|
||||
}
|
||||
if c.OAuth2.ClientSecret != nil {
|
||||
c.OAuth2.ClientSecret.Hide()
|
||||
if c.OAuth2.ClientSecret.IsEmpty() {
|
||||
c.OAuth2.ClientSecret = nil
|
||||
}
|
||||
}
|
||||
if c.OAuth2.RefreshToken != nil {
|
||||
c.OAuth2.RefreshToken.Hide()
|
||||
if c.OAuth2.RefreshToken.IsEmpty() {
|
||||
c.OAuth2.RefreshToken = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SMTPConfigs) getACopy() *SMTPConfigs {
|
||||
var password *kms.Secret
|
||||
if c.Password != nil {
|
||||
|
@ -387,13 +412,137 @@ func (c *ACMEConfigs) getACopy() *ACMEConfigs {
|
|||
}
|
||||
}
|
||||
|
||||
// BrandingConfig defines the branding configuration
|
||||
type BrandingConfig struct {
|
||||
Name string `json:"name"`
|
||||
ShortName string `json:"short_name"`
|
||||
Logo []byte `json:"logo"`
|
||||
Favicon []byte `json:"favicon"`
|
||||
DisclaimerName string `json:"disclaimer_name"`
|
||||
DisclaimerURL string `json:"disclaimer_url"`
|
||||
}
|
||||
|
||||
func (c *BrandingConfig) isEmpty() bool {
|
||||
if c.Name != "" {
|
||||
return false
|
||||
}
|
||||
if c.ShortName != "" {
|
||||
return false
|
||||
}
|
||||
if len(c.Logo) > 0 {
|
||||
return false
|
||||
}
|
||||
if len(c.Favicon) > 0 {
|
||||
return false
|
||||
}
|
||||
if c.DisclaimerName != "" && c.DisclaimerURL != "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (*BrandingConfig) validatePNG(b []byte, maxWidth, maxHeight int) error {
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
// DecodeConfig is more efficient, but I'm not sure if this would lead to
|
||||
// accepting invalid images in some edge cases and performance does not
|
||||
// matter here.
|
||||
img, err := png.Decode(bytes.NewBuffer(b))
|
||||
if err != nil {
|
||||
return util.NewI18nError(
|
||||
util.NewValidationError("invalid PNG image"),
|
||||
util.I18nErrorInvalidPNG,
|
||||
)
|
||||
}
|
||||
bounds := img.Bounds()
|
||||
if bounds.Dx() > maxWidth || bounds.Dy() > maxHeight {
|
||||
return util.NewI18nError(
|
||||
util.NewValidationError("invalid PNG image size"),
|
||||
util.I18nErrorInvalidPNGSize,
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
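The comment inside `validatePNG` above weighs `png.DecodeConfig` against a full decode and deliberately keeps the full decode. For reference, a header-only size check with the standard library would look roughly like the sketch below; it is cheaper but validates less of the payload, which is exactly the trade-off the author declined.

```go
// Sketch of the DecodeConfig alternative mentioned in the comment above.
// Requires "bytes", "errors" and "image/png"; error values are simplified here.
func pngWithinBounds(b []byte, maxWidth, maxHeight int) error {
    if len(b) == 0 {
        return nil
    }
    // DecodeConfig parses only the image header, so invalid pixel data is not detected.
    cfg, err := png.DecodeConfig(bytes.NewReader(b))
    if err != nil {
        return errors.New("invalid PNG image")
    }
    if cfg.Width > maxWidth || cfg.Height > maxHeight {
        return errors.New("invalid PNG image size")
    }
    return nil
}
```
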
|
||||
|
||||
func (c *BrandingConfig) validateDisclaimerURL() error {
|
||||
if c.DisclaimerURL == "" {
|
||||
return nil
|
||||
}
|
||||
u, err := url.Parse(c.DisclaimerURL)
|
||||
if err != nil {
|
||||
return util.NewI18nError(
|
||||
util.NewValidationError("invalid disclaimer URL"),
|
||||
util.I18nErrorInvalidDisclaimerURL,
|
||||
)
|
||||
}
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return util.NewI18nError(
|
||||
util.NewValidationError("invalid disclaimer URL scheme"),
|
||||
util.I18nErrorInvalidDisclaimerURL,
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BrandingConfig) validate() error {
|
||||
if err := c.validateDisclaimerURL(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.validatePNG(c.Logo, 512, 512); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.validatePNG(c.Favicon, 256, 256)
|
||||
}
|
||||
|
||||
func (c *BrandingConfig) getACopy() BrandingConfig {
|
||||
logo := make([]byte, len(c.Logo))
|
||||
copy(logo, c.Logo)
|
||||
favicon := make([]byte, len(c.Favicon))
|
||||
copy(favicon, c.Favicon)
|
||||
|
||||
return BrandingConfig{
|
||||
Name: c.Name,
|
||||
ShortName: c.ShortName,
|
||||
Logo: logo,
|
||||
Favicon: favicon,
|
||||
DisclaimerName: c.DisclaimerName,
|
||||
DisclaimerURL: c.DisclaimerURL,
|
||||
}
|
||||
}
|
||||
|
||||
// BrandingConfigs defines the branding configuration for WebAdmin and WebClient UI
|
||||
type BrandingConfigs struct {
|
||||
WebAdmin BrandingConfig
|
||||
WebClient BrandingConfig
|
||||
}
|
||||
|
||||
func (c *BrandingConfigs) isEmpty() bool {
|
||||
return c.WebAdmin.isEmpty() && c.WebClient.isEmpty()
|
||||
}
|
||||
|
||||
func (c *BrandingConfigs) validate() error {
|
||||
if err := c.WebAdmin.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.WebClient.validate()
|
||||
}
|
||||
|
||||
func (c *BrandingConfigs) getACopy() *BrandingConfigs {
|
||||
return &BrandingConfigs{
|
||||
WebAdmin: c.WebAdmin.getACopy(),
|
||||
WebClient: c.WebClient.getACopy(),
|
||||
}
|
||||
}
|
||||
|
||||
// Configs allows to set configuration keys disabled by default without
|
||||
// modifying the config file or setting env vars
|
||||
type Configs struct {
|
||||
SFTPD *SFTPDConfigs `json:"sftpd,omitempty"`
|
||||
SMTP *SMTPConfigs `json:"smtp,omitempty"`
|
||||
ACME *ACMEConfigs `json:"acme,omitempty"`
|
||||
UpdatedAt int64 `json:"updated_at,omitempty"`
|
||||
SFTPD *SFTPDConfigs `json:"sftpd,omitempty"`
|
||||
SMTP *SMTPConfigs `json:"smtp,omitempty"`
|
||||
ACME *ACMEConfigs `json:"acme,omitempty"`
|
||||
Branding *BrandingConfigs `json:"branding,omitempty"`
|
||||
UpdatedAt int64 `json:"updated_at,omitempty"`
|
||||
}
|
||||
|
||||
func (c *Configs) validate() error {
|
||||
|
@ -412,6 +561,11 @@ func (c *Configs) validate() error {
|
|||
return err
|
||||
}
|
||||
}
|
||||
if c.Branding != nil {
|
||||
if err := c.Branding.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -428,25 +582,11 @@ func (c *Configs) PrepareForRendering() {
|
|||
if c.ACME != nil && c.ACME.isEmpty() {
|
||||
c.ACME = nil
|
||||
}
|
||||
if c.Branding != nil && c.Branding.isEmpty() {
|
||||
c.Branding = nil
|
||||
}
|
||||
if c.SMTP != nil {
|
||||
if c.SMTP.Password != nil {
|
||||
c.SMTP.Password.Hide()
|
||||
if c.SMTP.Password.IsEmpty() {
|
||||
c.SMTP.Password = nil
|
||||
}
|
||||
}
|
||||
if c.SMTP.OAuth2.ClientSecret != nil {
|
||||
c.SMTP.OAuth2.ClientSecret.Hide()
|
||||
if c.SMTP.OAuth2.ClientSecret.IsEmpty() {
|
||||
c.SMTP.OAuth2.ClientSecret = nil
|
||||
}
|
||||
}
|
||||
if c.SMTP.OAuth2.RefreshToken != nil {
|
||||
c.SMTP.OAuth2.RefreshToken.Hide()
|
||||
if c.SMTP.OAuth2.RefreshToken.IsEmpty() {
|
||||
c.SMTP.OAuth2.RefreshToken = nil
|
||||
}
|
||||
}
|
||||
c.SMTP.prepareForRendering()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -470,6 +610,9 @@ func (c *Configs) SetNilsToEmpty() {
|
|||
if c.ACME == nil {
|
||||
c.ACME = &ACMEConfigs{}
|
||||
}
|
||||
if c.Branding == nil {
|
||||
c.Branding = &BrandingConfigs{}
|
||||
}
|
||||
}
|
||||
|
||||
// RenderAsJSON implements the renderer interface used within plugins
|
||||
|
@ -498,6 +641,9 @@ func (c *Configs) getACopy() Configs {
|
|||
if c.ACME != nil {
|
||||
result.ACME = c.ACME.getACopy()
|
||||
}
|
||||
if c.Branding != nil {
|
||||
result.Branding = c.Branding.getACopy()
|
||||
}
|
||||
result.UpdatedAt = c.UpdatedAt
|
||||
return result
|
||||
}
|
||||
|
|
|
@ -44,6 +44,7 @@ import (
|
|||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -187,6 +188,7 @@ var (
|
|||
ErrDuplicatedKey = errors.New("duplicated key not allowed")
|
||||
// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
|
||||
ErrForeignKeyViolated = errors.New("violates foreign key constraint")
|
||||
tz = ""
|
||||
isAdminCreated atomic.Bool
|
||||
validTLSUsernames = []string{string(sdk.TLSUsernameNone), string(sdk.TLSUsernameCN)}
|
||||
config Config
|
||||
|
@ -518,7 +520,7 @@ type Config struct {
|
|||
// GetShared returns the provider share mode.
|
||||
// This method is called before the provider is initialized
|
||||
func (c *Config) GetShared() int {
|
||||
if !util.Contains(sharedProviders, c.Driver) {
|
||||
if !slices.Contains(sharedProviders, c.Driver) {
|
||||
return 0
|
||||
}
|
||||
return c.IsShared
|
||||
|
@ -590,6 +592,16 @@ func (c *Config) doBackup() (string, error) {
|
|||
return outputFile, nil
|
||||
}
|
||||
|
||||
// SetTZ sets the configured timezone.
|
||||
func SetTZ(val string) {
|
||||
tz = val
|
||||
}
|
||||
|
||||
// UseLocalTime returns true if local time should be used instead of UTC.
|
||||
func UseLocalTime() bool {
|
||||
return tz == "local"
|
||||
}
|
||||
|
||||
// ExecuteBackup executes a backup
|
||||
func ExecuteBackup() (string, error) {
|
||||
return config.doBackup()
|
||||
|
@ -759,6 +771,8 @@ type Provider interface {
|
|||
updateLastLogin(username string) error
|
||||
updateAdminLastLogin(username string) error
|
||||
setUpdatedAt(username string)
|
||||
getAdminSignature(username string) (string, error)
|
||||
getUserSignature(username string) (string, error)
|
||||
getFolders(limit, offset int, order string, minimal bool) ([]vfs.BaseVirtualFolder, error)
|
||||
getFolderByName(name string) (vfs.BaseVirtualFolder, error)
|
||||
addFolder(folder *vfs.BaseVirtualFolder) error
|
||||
|
@ -874,7 +888,7 @@ func SetTempPath(fsPath string) {
|
|||
}
|
||||
|
||||
func checkSharedMode() {
|
||||
if !util.Contains(sharedProviders, config.Driver) {
|
||||
if !slices.Contains(sharedProviders, config.Driver) {
|
||||
config.IsShared = 0
|
||||
}
|
||||
}
|
||||
|
@ -929,12 +943,13 @@ func checkDatabase(checkAdmins bool) error {
|
|||
if config.UpdateMode == 0 {
|
||||
err := provider.initializeDatabase()
|
||||
if err != nil && err != ErrNoInitRequired {
|
||||
logger.WarnToConsole("Unable to initialize data provider: %v", err)
|
||||
providerLog(logger.LevelError, "Unable to initialize data provider: %v", err)
|
||||
logger.WarnToConsole("unable to initialize data provider: %v", err)
|
||||
providerLog(logger.LevelError, "unable to initialize data provider: %v", err)
|
||||
return err
|
||||
}
|
||||
if err == nil {
|
||||
logger.DebugToConsole("Data provider successfully initialized")
|
||||
logger.DebugToConsole("data provider successfully initialized")
|
||||
providerLog(logger.LevelInfo, "data provider successfully initialized")
|
||||
}
|
||||
err = provider.migrateDatabase()
|
||||
if err != nil && err != ErrNoInitRequired {
|
||||
|
@ -1503,6 +1518,15 @@ func UpdateUserQuota(user *User, filesAdd int, sizeAdd int64, reset bool) error
|
|||
return nil
|
||||
}
|
||||
|
||||
// UpdateUserFolderQuota updates the quota for the given user and virtual folder.
|
||||
func UpdateUserFolderQuota(folder *vfs.VirtualFolder, user *User, filesAdd int, sizeAdd int64, reset bool) {
|
||||
if folder.IsIncludedInUserQuota() {
|
||||
UpdateUserQuota(user, filesAdd, sizeAdd, reset) //nolint:errcheck
|
||||
return
|
||||
}
|
||||
UpdateVirtualFolderQuota(&folder.BaseVirtualFolder, filesAdd, sizeAdd, reset) //nolint:errcheck
|
||||
}
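`UpdateUserFolderQuota` is the convenience wrapper that the earlier `BaseTransfer.updateQuota` hunk switches to: callers no longer decide between the virtual-folder quota and the user quota themselves. Note the behavioral nuance visible in the code above: when the folder is included in the user quota, the wrapper updates only the user quota, whereas the replaced call site updated the folder quota as well. Simplified from that transfer hunk:

```go
// Old call site: the caller updated the folder quota and, when the folder
// counted toward the user quota, the user quota too.
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, numFiles, sizeDiff, false) //nolint:errcheck
if vfolder.IsIncludedInUserQuota() {
    dataprovider.UpdateUserQuota(&t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck
}

// New call site: the wrapper routes the update to the folder or the user quota.
dataprovider.UpdateUserFolderQuota(&vfolder, &t.Connection.User, numFiles, sizeDiff, false)
```
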
|
||||
|
||||
// UpdateVirtualFolderQuota updates the quota for the given virtual folder adding filesAdd and sizeAdd.
|
||||
// If reset is true filesAdd and sizeAdd indicates the total files and the total size instead of the difference.
|
||||
func UpdateVirtualFolderQuota(vfolder *vfs.BaseVirtualFolder, filesAdd int, sizeAdd int64, reset bool) error {
|
||||
|
@ -1693,7 +1717,7 @@ func IPListEntryExists(ipOrNet string, listType IPListType) (IPListEntry, error)
|
|||
|
||||
// GetIPListEntries returns the IP list entries applying the specified criteria and search limit
|
||||
func GetIPListEntries(listType IPListType, filter, from, order string, limit int) ([]IPListEntry, error) {
|
||||
if !util.Contains(supportedIPListType, listType) {
|
||||
if !slices.Contains(supportedIPListType, listType) {
|
||||
return nil, util.NewValidationError(fmt.Sprintf("invalid list type %d", listType))
|
||||
}
|
||||
return provider.getIPListEntries(listType, filter, from, order, limit)
|
||||
|
@@ -2064,6 +2088,20 @@ func UserExists(username, role string) (User, error) {
	return provider.userExists(username, role)
}

// GetAdminSignature returns the signature for the admin with the specified
// username.
func GetAdminSignature(username string) (string, error) {
	username = config.convertName(username)
	return provider.getAdminSignature(username)
}

// GetUserSignature returns the signature for the user with the specified
// username.
func GetUserSignature(username string) (string, error) {
	username = config.convertName(username)
	return provider.getUserSignature(username)
}

// GetUserWithGroupSettings tries to return the user with the specified username
// loading also the group settings
func GetUserWithGroupSettings(username, role string) (User, error) {
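GetAdminSignature and GetUserSignature surface the provider's updated_at value as an opaque string (the SQL and memory implementations further down format the timestamp with strconv.FormatInt). One plausible use, shown here only as a hypothetical sketch, is cheap change detection: cache the signature next to a cached object and refresh when it no longer matches. The import path assumes code living inside the SFTPGo module, since the package is internal.

```go
package cache

import "github.com/drakkan/sftpgo/v2/internal/dataprovider"

// hypothetical cache entry; only dataprovider.GetUserSignature comes from this diff
type cachedUser struct {
	user      dataprovider.User
	signature string
}

// isStale reports whether the stored copy is older than what the provider has.
func (c *cachedUser) isStale() (bool, error) {
	sig, err := dataprovider.GetUserSignature(c.user.Username)
	if err != nil {
		return true, err
	}
	return sig != c.signature, nil
}
```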
@ -2352,7 +2390,7 @@ func GetFolders(limit, offset int, order string, minimal bool) ([]vfs.BaseVirtua
|
|||
}
|
||||
|
||||
func dumpUsers(data *BackupData, scopes []string) error {
|
||||
if len(scopes) == 0 || util.Contains(scopes, DumpScopeUsers) {
|
||||
if len(scopes) == 0 || slices.Contains(scopes, DumpScopeUsers) {
|
||||
users, err := provider.dumpUsers()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -2363,7 +2401,7 @@ func dumpUsers(data *BackupData, scopes []string) error {
|
|||
}
|
||||
|
||||
func dumpFolders(data *BackupData, scopes []string) error {
|
||||
if len(scopes) == 0 || util.Contains(scopes, DumpScopeFolders) {
|
||||
if len(scopes) == 0 || slices.Contains(scopes, DumpScopeFolders) {
|
||||
folders, err := provider.dumpFolders()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -2374,7 +2412,7 @@ func dumpFolders(data *BackupData, scopes []string) error {
|
|||
}
|
||||
|
||||
func dumpGroups(data *BackupData, scopes []string) error {
|
||||
if len(scopes) == 0 || util.Contains(scopes, DumpScopeGroups) {
|
||||
if len(scopes) == 0 || slices.Contains(scopes, DumpScopeGroups) {
|
||||
groups, err := provider.dumpGroups()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -2385,7 +2423,7 @@ func dumpGroups(data *BackupData, scopes []string) error {
|
|||
}
|
||||
|
||||
func dumpAdmins(data *BackupData, scopes []string) error {
|
||||
if len(scopes) == 0 || util.Contains(scopes, DumpScopeAdmins) {
|
||||
if len(scopes) == 0 || slices.Contains(scopes, DumpScopeAdmins) {
|
||||
admins, err := provider.dumpAdmins()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -2396,7 +2434,7 @@ func dumpAdmins(data *BackupData, scopes []string) error {
|
|||
}
|
||||
|
||||
func dumpAPIKeys(data *BackupData, scopes []string) error {
|
||||
if len(scopes) == 0 || util.Contains(scopes, DumpScopeAPIKeys) {
|
||||
if len(scopes) == 0 || slices.Contains(scopes, DumpScopeAPIKeys) {
|
||||
apiKeys, err := provider.dumpAPIKeys()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -2407,7 +2445,7 @@ func dumpAPIKeys(data *BackupData, scopes []string) error {
|
|||
}
|
||||
|
||||
func dumpShares(data *BackupData, scopes []string) error {
|
||||
if len(scopes) == 0 || util.Contains(scopes, DumpScopeShares) {
|
||||
if len(scopes) == 0 || slices.Contains(scopes, DumpScopeShares) {
|
||||
shares, err := provider.dumpShares()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -2418,7 +2456,7 @@ func dumpShares(data *BackupData, scopes []string) error {
|
|||
}
|
||||
|
||||
func dumpActions(data *BackupData, scopes []string) error {
|
||||
if len(scopes) == 0 || util.Contains(scopes, DumpScopeActions) {
|
||||
if len(scopes) == 0 || slices.Contains(scopes, DumpScopeActions) {
|
||||
actions, err := provider.dumpEventActions()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -2429,7 +2467,7 @@ func dumpActions(data *BackupData, scopes []string) error {
|
|||
}
|
||||
|
||||
func dumpRules(data *BackupData, scopes []string) error {
|
||||
if len(scopes) == 0 || util.Contains(scopes, DumpScopeRules) {
|
||||
if len(scopes) == 0 || slices.Contains(scopes, DumpScopeRules) {
|
||||
rules, err := provider.dumpEventRules()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -2440,7 +2478,7 @@ func dumpRules(data *BackupData, scopes []string) error {
|
|||
}
|
||||
|
||||
func dumpRoles(data *BackupData, scopes []string) error {
|
||||
if len(scopes) == 0 || util.Contains(scopes, DumpScopeRoles) {
|
||||
if len(scopes) == 0 || slices.Contains(scopes, DumpScopeRoles) {
|
||||
roles, err := provider.dumpRoles()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -2451,7 +2489,7 @@ func dumpRoles(data *BackupData, scopes []string) error {
|
|||
}
|
||||
|
||||
func dumpIPLists(data *BackupData, scopes []string) error {
|
||||
if len(scopes) == 0 || util.Contains(scopes, DumpScopeIPLists) {
|
||||
if len(scopes) == 0 || slices.Contains(scopes, DumpScopeIPLists) {
|
||||
ipLists, err := provider.dumpIPListEntries()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@@ -2462,7 +2500,7 @@ func dumpIPLists(data *BackupData, scopes []string) error {
}

func dumpConfigs(data *BackupData, scopes []string) error {
	if len(scopes) == 0 || util.Contains(scopes, DumpScopeConfigs) {
	if len(scopes) == 0 || slices.Contains(scopes, DumpScopeConfigs) {
		configs, err := provider.getConfigs()
		if err != nil {
			return err
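Every dump* helper above applies the same gate: an empty scopes slice means export everything, otherwise a section is only dumped when its DumpScope constant was requested. A condensed, purely illustrative generic version of that pattern (not code from the repository):

```go
package main

import (
	"fmt"
	"slices"
)

// dumpSection generalizes the gate shared by dumpUsers, dumpFolders, dumpAdmins, etc.
func dumpSection[T any](scopes []string, scope string, load func() ([]T, error), store func([]T)) error {
	if len(scopes) > 0 && !slices.Contains(scopes, scope) {
		return nil // section not requested, skip it
	}
	items, err := load()
	if err != nil {
		return err
	}
	store(items)
	return nil
}

func main() {
	var out []string
	err := dumpSection([]string{"users"}, "users",
		func() ([]string, error) { return []string{"alice", "bob"}, nil },
		func(v []string) { out = v })
	fmt.Println(out, err) // [alice bob] <nil>
}
```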
@ -2765,7 +2803,7 @@ func validateUserTOTPConfig(c *UserTOTPConfig, username string) error {
|
|||
if c.ConfigName == "" {
|
||||
return util.NewValidationError("totp: config name is mandatory")
|
||||
}
|
||||
if !util.Contains(mfa.GetAvailableTOTPConfigNames(), c.ConfigName) {
|
||||
if !slices.Contains(mfa.GetAvailableTOTPConfigNames(), c.ConfigName) {
|
||||
return util.NewValidationError(fmt.Sprintf("totp: config name %q not found", c.ConfigName))
|
||||
}
|
||||
if c.Secret.IsEmpty() {
|
||||
|
@ -2781,7 +2819,7 @@ func validateUserTOTPConfig(c *UserTOTPConfig, username string) error {
|
|||
return util.NewValidationError("totp: specify at least one protocol")
|
||||
}
|
||||
for _, protocol := range c.Protocols {
|
||||
if !util.Contains(MFAProtocols, protocol) {
|
||||
if !slices.Contains(MFAProtocols, protocol) {
|
||||
return util.NewValidationError(fmt.Sprintf("totp: invalid protocol %q", protocol))
|
||||
}
|
||||
}
|
||||
|
@ -2814,7 +2852,7 @@ func validateUserPermissions(permsToCheck map[string][]string) (map[string][]str
|
|||
return permissions, util.NewValidationError("invalid permissions")
|
||||
}
|
||||
for _, p := range perms {
|
||||
if !util.Contains(ValidPerms, p) {
|
||||
if !slices.Contains(ValidPerms, p) {
|
||||
return permissions, util.NewValidationError(fmt.Sprintf("invalid permission: %q", p))
|
||||
}
|
||||
}
|
||||
|
@ -2828,7 +2866,7 @@ func validateUserPermissions(permsToCheck map[string][]string) (map[string][]str
|
|||
if dir != cleanedDir && cleanedDir == "/" {
|
||||
return permissions, util.NewValidationError(fmt.Sprintf("cannot set permissions for invalid subdirectory: %q is an alias for \"/\"", dir))
|
||||
}
|
||||
if util.Contains(perms, PermAny) {
|
||||
if slices.Contains(perms, PermAny) {
|
||||
permissions[cleanedDir] = []string{PermAny}
|
||||
} else {
|
||||
permissions[cleanedDir] = util.RemoveDuplicates(perms, false)
|
||||
|
@ -2911,7 +2949,7 @@ func validateFiltersPatternExtensions(baseFilters *sdk.BaseUserFilters) error {
|
|||
util.I18nErrorFilePatternPathInvalid,
|
||||
)
|
||||
}
|
||||
if util.Contains(filteredPaths, cleanedPath) {
|
||||
if slices.Contains(filteredPaths, cleanedPath) {
|
||||
return util.NewI18nError(
|
||||
util.NewValidationError(fmt.Sprintf("duplicate file patterns filter for path %q", f.Path)),
|
||||
util.I18nErrorFilePatternDuplicated,
|
||||
|
@ -3030,13 +3068,13 @@ func validateFilterProtocols(filters *sdk.BaseUserFilters) error {
|
|||
return util.NewValidationError("invalid denied_protocols")
|
||||
}
|
||||
for _, p := range filters.DeniedProtocols {
|
||||
if !util.Contains(ValidProtocols, p) {
|
||||
if !slices.Contains(ValidProtocols, p) {
|
||||
return util.NewValidationError(fmt.Sprintf("invalid denied protocol %q", p))
|
||||
}
|
||||
}
|
||||
|
||||
for _, p := range filters.TwoFactorAuthProtocols {
|
||||
if !util.Contains(MFAProtocols, p) {
|
||||
if !slices.Contains(MFAProtocols, p) {
|
||||
return util.NewValidationError(fmt.Sprintf("invalid two factor protocol %q", p))
|
||||
}
|
||||
}
|
||||
|
@ -3092,7 +3130,7 @@ func validateBaseFilters(filters *sdk.BaseUserFilters) error {
|
|||
return util.NewValidationError("invalid denied_login_methods")
|
||||
}
|
||||
for _, loginMethod := range filters.DeniedLoginMethods {
|
||||
if !util.Contains(ValidLoginMethods, loginMethod) {
|
||||
if !slices.Contains(ValidLoginMethods, loginMethod) {
|
||||
return util.NewValidationError(fmt.Sprintf("invalid login method: %q", loginMethod))
|
||||
}
|
||||
}
|
||||
|
@ -3100,7 +3138,7 @@ func validateBaseFilters(filters *sdk.BaseUserFilters) error {
|
|||
return err
|
||||
}
|
||||
if filters.TLSUsername != "" {
|
||||
if !util.Contains(validTLSUsernames, string(filters.TLSUsername)) {
|
||||
if !slices.Contains(validTLSUsernames, string(filters.TLSUsername)) {
|
||||
return util.NewValidationError(fmt.Sprintf("invalid TLS username: %q", filters.TLSUsername))
|
||||
}
|
||||
}
|
||||
|
@ -3110,7 +3148,7 @@ func validateBaseFilters(filters *sdk.BaseUserFilters) error {
|
|||
}
|
||||
filters.TLSCerts = certs
|
||||
for _, opts := range filters.WebClient {
|
||||
if !util.Contains(sdk.WebClientOptions, opts) {
|
||||
if !slices.Contains(sdk.WebClientOptions, opts) {
|
||||
return util.NewValidationError(fmt.Sprintf("invalid web client options %q", opts))
|
||||
}
|
||||
}
|
||||
|
@ -3178,19 +3216,19 @@ func validateAccessTimeFilters(filters *sdk.BaseUserFilters) error {
|
|||
}
|
||||
|
||||
func validateCombinedUserFilters(user *User) error {
|
||||
if user.Filters.TOTPConfig.Enabled && util.Contains(user.Filters.WebClient, sdk.WebClientMFADisabled) {
|
||||
if user.Filters.TOTPConfig.Enabled && slices.Contains(user.Filters.WebClient, sdk.WebClientMFADisabled) {
|
||||
return util.NewI18nError(
|
||||
util.NewValidationError("two-factor authentication cannot be disabled for a user with an active configuration"),
|
||||
util.I18nErrorDisableActive2FA,
|
||||
)
|
||||
}
|
||||
if user.Filters.RequirePasswordChange && util.Contains(user.Filters.WebClient, sdk.WebClientPasswordChangeDisabled) {
|
||||
if user.Filters.RequirePasswordChange && slices.Contains(user.Filters.WebClient, sdk.WebClientPasswordChangeDisabled) {
|
||||
return util.NewI18nError(
|
||||
util.NewValidationError("you cannot require password change and at the same time disallow it"),
|
||||
util.I18nErrorPwdChangeConflict,
|
||||
)
|
||||
}
|
||||
if len(user.Filters.TwoFactorAuthProtocols) > 0 && util.Contains(user.Filters.WebClient, sdk.WebClientMFADisabled) {
|
||||
if len(user.Filters.TwoFactorAuthProtocols) > 0 && slices.Contains(user.Filters.WebClient, sdk.WebClientMFADisabled) {
|
||||
return util.NewI18nError(
|
||||
util.NewValidationError("you cannot require two-factor authentication and at the same time disallow it"),
|
||||
util.I18nError2FAConflict,
|
||||
|
@@ -3199,6 +3237,24 @@ func validateCombinedUserFilters(user *User) error {
	return nil
}

func validateEmails(user *User) error {
	if user.Email != "" && !util.IsEmailValid(user.Email) {
		return util.NewI18nError(
			util.NewValidationError(fmt.Sprintf("email %q is not valid", user.Email)),
			util.I18nErrorInvalidEmail,
		)
	}
	for _, email := range user.Filters.AdditionalEmails {
		if !util.IsEmailValid(email) {
			return util.NewI18nError(
				util.NewValidationError(fmt.Sprintf("email %q is not valid", email)),
				util.I18nErrorInvalidEmail,
			)
		}
	}
	return nil
}

func validateBaseParams(user *User) error {
	if user.Username == "" {
		return util.NewI18nError(util.NewValidationError("username is mandatory"), util.I18nErrorUsernameRequired)
@ -3206,11 +3262,8 @@ func validateBaseParams(user *User) error {
|
|||
if err := checkReservedUsernames(user.Username); err != nil {
|
||||
return util.NewI18nError(err, util.I18nErrorReservedUsername)
|
||||
}
|
||||
if user.Email != "" && !util.IsEmailValid(user.Email) {
|
||||
return util.NewI18nError(
|
||||
util.NewValidationError(fmt.Sprintf("email %q is not valid", user.Email)),
|
||||
util.I18nErrorInvalidEmail,
|
||||
)
|
||||
if err := validateEmails(user); err != nil {
|
||||
return err
|
||||
}
|
||||
if config.NamingRules&1 == 0 && !usernameRegex.MatchString(user.Username) {
|
||||
return util.NewI18nError(
|
||||
|
@ -3511,7 +3564,7 @@ func checkUserPasscode(user *User, password, protocol string) (string, error) {
|
|||
if user.Filters.TOTPConfig.Enabled {
|
||||
switch protocol {
|
||||
case protocolFTP:
|
||||
if util.Contains(user.Filters.TOTPConfig.Protocols, protocol) {
|
||||
if slices.Contains(user.Filters.TOTPConfig.Protocols, protocol) {
|
||||
// the TOTP passcode has six digits
|
||||
pwdLen := len(password)
|
||||
if pwdLen < 7 {
|
||||
|
@ -3717,7 +3770,7 @@ func doBuiltinKeyboardInteractiveAuth(user *User, client ssh.KeyboardInteractive
|
|||
if err := user.LoadAndApplyGroupSettings(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
hasSecondFactor := user.Filters.TOTPConfig.Enabled && util.Contains(user.Filters.TOTPConfig.Protocols, protocolSSH)
|
||||
hasSecondFactor := user.Filters.TOTPConfig.Enabled && slices.Contains(user.Filters.TOTPConfig.Protocols, protocolSSH)
|
||||
if !isPartialAuth || !hasSecondFactor {
|
||||
answers, err := client("", "", []string{"Password: "}, []bool{false})
|
||||
if err != nil {
|
||||
|
@ -3735,7 +3788,7 @@ func doBuiltinKeyboardInteractiveAuth(user *User, client ssh.KeyboardInteractive
|
|||
}
|
||||
|
||||
func checkKeyboardInteractiveSecondFactor(user *User, client ssh.KeyboardInteractiveChallenge, protocol string) (int, error) {
|
||||
if !user.Filters.TOTPConfig.Enabled || !util.Contains(user.Filters.TOTPConfig.Protocols, protocolSSH) {
|
||||
if !user.Filters.TOTPConfig.Enabled || !slices.Contains(user.Filters.TOTPConfig.Protocols, protocolSSH) {
|
||||
return 1, nil
|
||||
}
|
||||
err := user.Filters.TOTPConfig.Secret.TryDecrypt()
|
||||
|
@ -3859,7 +3912,7 @@ func getKeyboardInteractiveAnswers(client ssh.KeyboardInteractiveChallenge, resp
|
|||
}
|
||||
if len(answers) == 1 && response.CheckPwd > 0 {
|
||||
if response.CheckPwd == 2 {
|
||||
if !user.Filters.TOTPConfig.Enabled || !util.Contains(user.Filters.TOTPConfig.Protocols, protocolSSH) {
|
||||
if !user.Filters.TOTPConfig.Enabled || !slices.Contains(user.Filters.TOTPConfig.Protocols, protocolSSH) {
|
||||
providerLog(logger.LevelInfo, "keyboard interactive auth error: unable to check TOTP passcode, TOTP is not enabled for user %q",
|
||||
user.Username)
|
||||
return answers, errors.New("TOTP not enabled for SSH protocol")
|
||||
|
@ -4454,6 +4507,7 @@ func doExternalAuth(username, password string, pubKey []byte, keyboardInteractiv
|
|||
webDAVUsersCache.swap(&user, password)
|
||||
}
|
||||
cachedUserPasswords.Add(user.Username, password, user.Password)
|
||||
executeAction(operationUpdate, ActionExecutorSelf, "", actionObjectUser, user.Username, "", &user)
|
||||
}
|
||||
return user, err
|
||||
}
|
||||
|
@ -4461,6 +4515,7 @@ func doExternalAuth(username, password string, pubKey []byte, keyboardInteractiv
|
|||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
executeAction(operationAdd, ActionExecutorSelf, "", actionObjectUser, user.Username, "", &user)
|
||||
return provider.userExists(user.Username, "")
|
||||
}
|
||||
|
||||
|
@ -4526,6 +4581,7 @@ func doPluginAuth(username, password string, pubKey []byte, ip, protocol string,
|
|||
webDAVUsersCache.swap(&user, password)
|
||||
}
|
||||
cachedUserPasswords.Add(user.Username, password, user.Password)
|
||||
executeAction(operationUpdate, ActionExecutorSelf, "", actionObjectUser, user.Username, "", &user)
|
||||
}
|
||||
return user, err
|
||||
}
|
||||
|
@ -4533,6 +4589,7 @@ func doPluginAuth(username, password string, pubKey []byte, ip, protocol string,
|
|||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
executeAction(operationAdd, ActionExecutorSelf, "", actionObjectUser, user.Username, "", &user)
|
||||
return provider.userExists(user.Username, "")
|
||||
}
|
||||
|
||||
|
@ -4625,7 +4682,7 @@ func getConfigPath(name, configDir string) string {
|
|||
}
|
||||
|
||||
func checkReservedUsernames(username string) error {
|
||||
if util.Contains(reservedUsers, username) {
|
||||
if slices.Contains(reservedUsers, username) {
|
||||
return util.NewValidationError("this username is reserved")
|
||||
}
|
||||
return nil
@ -64,7 +64,7 @@ var (
|
|||
)
|
||||
|
||||
func isActionTypeValid(action int) bool {
|
||||
return util.Contains(supportedEventActions, action)
|
||||
return slices.Contains(supportedEventActions, action)
|
||||
}
|
||||
|
||||
func getActionTypeAsString(action int) string {
|
||||
|
@ -119,7 +119,7 @@ var (
|
|||
)
|
||||
|
||||
func isEventTriggerValid(trigger int) bool {
|
||||
return util.Contains(supportedEventTriggers, trigger)
|
||||
return slices.Contains(supportedEventTriggers, trigger)
|
||||
}
|
||||
|
||||
func getTriggerTypeAsString(trigger int) string {
|
||||
|
@ -173,7 +173,7 @@ var (
|
|||
)
|
||||
|
||||
func isFilesystemActionValid(value int) bool {
|
||||
return util.Contains(supportedFsActions, value)
|
||||
return slices.Contains(supportedFsActions, value)
|
||||
}
|
||||
|
||||
func getFsActionTypeAsString(value int) string {
|
||||
|
@@ -342,7 +342,7 @@ func (c *EventActionHTTPConfig) validateMultiparts() error {
		)
	}
	for _, k := range c.Headers {
		if strings.ToLower(k.Key) == "content-type" {
		if strings.EqualFold(k.Key, "content-type") {
			return util.NewI18nError(
				util.NewValidationError("content type is automatically set for multipart requests"),
				util.I18nErrorMultipartCType,
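Replacing strings.ToLower(k.Key) == "content-type" with strings.EqualFold keeps the case-insensitive match while avoiding the temporary lowered string. A quick demonstration of the equivalence:

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	// both forms accept any capitalization of the header name;
	// EqualFold simply skips the intermediate allocation.
	fmt.Println(strings.EqualFold("Content-Type", "content-type")) // true
	fmt.Println(strings.ToLower("Content-Type") == "content-type") // true
}
```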
@ -384,7 +384,7 @@ func (c *EventActionHTTPConfig) validate(additionalData string) error {
|
|||
return util.NewValidationError(fmt.Sprintf("could not encrypt HTTP password: %v", err))
|
||||
}
|
||||
}
|
||||
if !util.Contains(SupportedHTTPActionMethods, c.Method) {
|
||||
if !slices.Contains(SupportedHTTPActionMethods, c.Method) {
|
||||
return util.NewValidationError(fmt.Sprintf("unsupported HTTP method: %s", c.Method))
|
||||
}
|
||||
for _, kv := range c.QueryParameters {
|
||||
|
@@ -671,12 +671,21 @@ func (c *EventActionFsCompress) validate() error {
	return nil
}

// RenameConfig defines the configuration for a filesystem rename
type RenameConfig struct {
	// key is the source and target the value
	KeyValue
	// This setting only applies to storage providers that support
	// changing modification times.
	UpdateModTime bool `json:"update_modtime,omitempty"`
}

// EventActionFilesystemConfig defines the configuration for filesystem actions
type EventActionFilesystemConfig struct {
	// Filesystem actions, see the above enum
	Type int `json:"type,omitempty"`
	// files/dirs to rename, key is the source and target the value
	Renames []KeyValue `json:"renames,omitempty"`
	// files/dirs to rename
	Renames []RenameConfig `json:"renames,omitempty"`
	// directories to create
	MkDirs []string `json:"mkdirs,omitempty"`
	// files/dirs to delete
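Rename entries move from plain KeyValue pairs to the new RenameConfig, which embeds KeyValue and adds UpdateModTime. A sketch of how a filesystem action might be populated with the new shape; the struct and field names come from the diff, while the paths and the Type constant name are illustrative assumptions:

```go
// illustrative literal, written as if inside the dataprovider package;
// FilesystemActionRename is an assumed constant name for the rename action type
var exampleRename = EventActionFilesystemConfig{
	Type: FilesystemActionRename,
	Renames: []RenameConfig{
		{
			KeyValue: KeyValue{
				Key:   "/staging/report.csv", // source path
				Value: "/archive/report.csv", // target path
			},
			// honoured only by storage backends that can change modification times
			UpdateModTime: true,
		},
	},
}
```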
@ -717,9 +726,9 @@ func (c *EventActionFilesystemConfig) validateRenames() error {
|
|||
if len(c.Renames) == 0 {
|
||||
return util.NewI18nError(util.NewValidationError("no path to rename specified"), util.I18nErrorPathRequired)
|
||||
}
|
||||
for idx, kv := range c.Renames {
|
||||
key := strings.TrimSpace(kv.Key)
|
||||
value := strings.TrimSpace(kv.Value)
|
||||
for idx, cfg := range c.Renames {
|
||||
key := strings.TrimSpace(cfg.Key)
|
||||
value := strings.TrimSpace(cfg.Value)
|
||||
if key == "" || value == "" {
|
||||
return util.NewValidationError("invalid paths to rename")
|
||||
}
|
||||
|
@ -737,9 +746,12 @@ func (c *EventActionFilesystemConfig) validateRenames() error {
|
|||
util.I18nErrorRootNotAllowed,
|
||||
)
|
||||
}
|
||||
c.Renames[idx] = KeyValue{
|
||||
Key: key,
|
||||
Value: value,
|
||||
c.Renames[idx] = RenameConfig{
|
||||
KeyValue: KeyValue{
|
||||
Key: key,
|
||||
Value: value,
|
||||
},
|
||||
UpdateModTime: cfg.UpdateModTime,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
@ -903,7 +915,7 @@ func (c *EventActionFilesystemConfig) getACopy() EventActionFilesystemConfig {
|
|||
|
||||
return EventActionFilesystemConfig{
|
||||
Type: c.Type,
|
||||
Renames: cloneKeyValues(c.Renames),
|
||||
Renames: cloneRenameConfigs(c.Renames),
|
||||
MkDirs: mkdirs,
|
||||
Deletes: deletes,
|
||||
Exist: exist,
|
||||
|
@ -1292,7 +1304,7 @@ func (a *EventAction) validateAssociation(trigger int, fsEvents []string) error
|
|||
}
|
||||
if trigger == EventTriggerFsEvent {
|
||||
for _, ev := range fsEvents {
|
||||
if !util.Contains(allowedSyncFsEvents, ev) {
|
||||
if !slices.Contains(allowedSyncFsEvents, ev) {
|
||||
return util.NewI18nError(
|
||||
util.NewValidationError("sync execution is only supported for upload and pre-* events"),
|
||||
util.I18nErrorEvSyncUnsupportedFs,
|
||||
|
@ -1386,12 +1398,12 @@ func (f *ConditionOptions) validate() error {
|
|||
}
|
||||
|
||||
for _, p := range f.Protocols {
|
||||
if !util.Contains(SupportedRuleConditionProtocols, p) {
|
||||
if !slices.Contains(SupportedRuleConditionProtocols, p) {
|
||||
return util.NewValidationError(fmt.Sprintf("unsupported rule condition protocol: %q", p))
|
||||
}
|
||||
}
|
||||
for _, p := range f.ProviderObjects {
|
||||
if !util.Contains(SupporteRuleConditionProviderObjects, p) {
|
||||
if !slices.Contains(SupporteRuleConditionProviderObjects, p) {
|
||||
return util.NewValidationError(fmt.Sprintf("unsupported provider object: %q", p))
|
||||
}
|
||||
}
|
||||
|
@ -1496,7 +1508,7 @@ func (c *EventConditions) validate(trigger int) error {
|
|||
)
|
||||
}
|
||||
for _, ev := range c.FsEvents {
|
||||
if !util.Contains(SupportedFsEvents, ev) {
|
||||
if !slices.Contains(SupportedFsEvents, ev) {
|
||||
return util.NewValidationError(fmt.Sprintf("unsupported fs event: %q", ev))
|
||||
}
|
||||
}
|
||||
|
@ -1516,7 +1528,7 @@ func (c *EventConditions) validate(trigger int) error {
|
|||
)
|
||||
}
|
||||
for _, ev := range c.ProviderEvents {
|
||||
if !util.Contains(SupportedProviderEvents, ev) {
|
||||
if !slices.Contains(SupportedProviderEvents, ev) {
|
||||
return util.NewValidationError(fmt.Sprintf("unsupported provider event: %q", ev))
|
||||
}
|
||||
}
|
||||
|
@ -1569,7 +1581,7 @@ func (c *EventConditions) validate(trigger int) error {
|
|||
c.Options.MinFileSize = 0
|
||||
c.Options.MaxFileSize = 0
|
||||
c.Schedules = nil
|
||||
if !util.Contains(supportedIDPLoginEvents, c.IDPLoginEvent) {
|
||||
if !slices.Contains(supportedIDPLoginEvents, c.IDPLoginEvent) {
|
||||
return util.NewValidationError(fmt.Sprintf("invalid Identity Provider login event %d", c.IDPLoginEvent))
|
||||
}
|
||||
default:
|
||||
|
@ -1723,7 +1735,7 @@ func (r *EventRule) validateMandatorySyncActions() error {
|
|||
return nil
|
||||
}
|
||||
for _, ev := range r.Conditions.FsEvents {
|
||||
if util.Contains(mandatorySyncFsEvents, ev) {
|
||||
if slices.Contains(mandatorySyncFsEvents, ev) {
|
||||
return util.NewI18nError(
|
||||
util.NewValidationError(fmt.Sprintf("event %q requires at least a sync action", ev)),
|
||||
util.I18nErrorRuleSyncActionRequired,
|
||||
|
@ -1741,7 +1753,7 @@ func (r *EventRule) checkIPBlockedAndCertificateActions() error {
|
|||
ActionTypeDataRetentionCheck, ActionTypeFilesystem, ActionTypePasswordExpirationCheck,
|
||||
ActionTypeUserExpirationCheck}
|
||||
for _, action := range r.Actions {
|
||||
if util.Contains(unavailableActions, action.Type) {
|
||||
if slices.Contains(unavailableActions, action.Type) {
|
||||
return fmt.Errorf("action %q, type %q is not supported for event trigger %q",
|
||||
action.Name, getActionTypeAsString(action.Type), getTriggerTypeAsString(r.Trigger))
|
||||
}
|
||||
|
@ -1757,7 +1769,7 @@ func (r *EventRule) checkProviderEventActions(providerObjectType string) error {
|
|||
ActionTypeDataRetentionCheck, ActionTypeFilesystem,
|
||||
ActionTypePasswordExpirationCheck, ActionTypeUserExpirationCheck}
|
||||
for _, action := range r.Actions {
|
||||
if util.Contains(userSpecificActions, action.Type) && providerObjectType != actionObjectUser {
|
||||
if slices.Contains(userSpecificActions, action.Type) && providerObjectType != actionObjectUser {
|
||||
return fmt.Errorf("action %q, type %q is only supported for provider user events",
|
||||
action.Name, getActionTypeAsString(action.Type))
|
||||
}
|
||||
|
@ -1865,6 +1877,20 @@ func (r *EventRule) RenderAsJSON(reload bool) ([]byte, error) {
|
|||
return json.Marshal(r)
|
||||
}
|
||||
|
||||
func cloneRenameConfigs(renames []RenameConfig) []RenameConfig {
|
||||
res := make([]RenameConfig, 0, len(renames))
|
||||
for _, c := range renames {
|
||||
res = append(res, RenameConfig{
|
||||
KeyValue: KeyValue{
|
||||
Key: c.Key,
|
||||
Value: c.Value,
|
||||
},
|
||||
UpdateModTime: c.UpdateModTime,
|
||||
})
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func cloneKeyValues(keyVals []KeyValue) []KeyValue {
|
||||
res := make([]KeyValue, 0, len(keyVals))
|
||||
for _, kv := range keyVals {
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
@ -85,7 +86,7 @@ var (
|
|||
|
||||
// CheckIPListType returns an error if the provided IP list type is not valid
|
||||
func CheckIPListType(t IPListType) error {
|
||||
if !util.Contains(supportedIPListType, t) {
|
||||
if !slices.Contains(supportedIPListType, t) {
|
||||
return util.NewValidationError(fmt.Sprintf("invalid list type %d", t))
|
||||
}
|
||||
return nil
@ -22,7 +22,9 @@ import (
|
|||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -206,6 +208,32 @@ func (p *MemoryProvider) updateAPIKeyLastUse(keyID string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) getAdminSignature(username string) (string, error) {
|
||||
p.dbHandle.Lock()
|
||||
defer p.dbHandle.Unlock()
|
||||
if p.dbHandle.isClosed {
|
||||
return "", errMemoryProviderClosed
|
||||
}
|
||||
admin, err := p.adminExistsInternal(username)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strconv.FormatInt(admin.UpdatedAt, 10), nil
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) getUserSignature(username string) (string, error) {
|
||||
p.dbHandle.Lock()
|
||||
defer p.dbHandle.Unlock()
|
||||
if p.dbHandle.isClosed {
|
||||
return "", errMemoryProviderClosed
|
||||
}
|
||||
user, err := p.userExistsInternal(username)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strconv.FormatInt(user.UpdatedAt, 10), nil
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) setUpdatedAt(username string) {
|
||||
p.dbHandle.Lock()
|
||||
defer p.dbHandle.Unlock()
|
||||
|
@ -1210,7 +1238,7 @@ func (p *MemoryProvider) addRuleToActionMapping(ruleName, actionName string) err
|
|||
if err != nil {
|
||||
return util.NewGenericError(fmt.Sprintf("action %q does not exist", actionName))
|
||||
}
|
||||
if !util.Contains(a.Rules, ruleName) {
|
||||
if !slices.Contains(a.Rules, ruleName) {
|
||||
a.Rules = append(a.Rules, ruleName)
|
||||
p.dbHandle.actions[actionName] = a
|
||||
}
|
||||
|
@ -1223,7 +1251,7 @@ func (p *MemoryProvider) removeRuleFromActionMapping(ruleName, actionName string
|
|||
providerLog(logger.LevelWarn, "action %q does not exist, cannot remove from mapping", actionName)
|
||||
return
|
||||
}
|
||||
if util.Contains(a.Rules, ruleName) {
|
||||
if slices.Contains(a.Rules, ruleName) {
|
||||
var rules []string
|
||||
for _, r := range a.Rules {
|
||||
if r != ruleName {
|
||||
|
@ -1240,7 +1268,7 @@ func (p *MemoryProvider) addAdminToGroupMapping(username, groupname string) erro
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !util.Contains(g.Admins, username) {
|
||||
if !slices.Contains(g.Admins, username) {
|
||||
g.Admins = append(g.Admins, username)
|
||||
p.dbHandle.groups[groupname] = g
|
||||
}
|
||||
|
@ -1283,7 +1311,7 @@ func (p *MemoryProvider) addUserToGroupMapping(username, groupname string) error
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !util.Contains(g.Users, username) {
|
||||
if !slices.Contains(g.Users, username) {
|
||||
g.Users = append(g.Users, username)
|
||||
p.dbHandle.groups[groupname] = g
|
||||
}
|
||||
|
@ -1313,7 +1341,7 @@ func (p *MemoryProvider) addAdminToRole(username, role string) error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("%w: role %q does not exist", ErrForeignKeyViolated, role)
|
||||
}
|
||||
if !util.Contains(r.Admins, username) {
|
||||
if !slices.Contains(r.Admins, username) {
|
||||
r.Admins = append(r.Admins, username)
|
||||
p.dbHandle.roles[role] = r
|
||||
}
|
||||
|
@ -1347,7 +1375,7 @@ func (p *MemoryProvider) addUserToRole(username, role string) error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("%w: role %q does not exist", ErrForeignKeyViolated, role)
|
||||
}
|
||||
if !util.Contains(r.Users, username) {
|
||||
if !slices.Contains(r.Users, username) {
|
||||
r.Users = append(r.Users, username)
|
||||
p.dbHandle.roles[role] = r
|
||||
}
|
||||
|
@ -1378,7 +1406,7 @@ func (p *MemoryProvider) addUserToFolderMapping(username, foldername string) err
|
|||
if err != nil {
|
||||
return util.NewGenericError(fmt.Sprintf("unable to get folder %q: %v", foldername, err))
|
||||
}
|
||||
if !util.Contains(f.Users, username) {
|
||||
if !slices.Contains(f.Users, username) {
|
||||
f.Users = append(f.Users, username)
|
||||
p.dbHandle.vfolders[foldername] = f
|
||||
}
|
||||
|
@ -1390,7 +1418,7 @@ func (p *MemoryProvider) addGroupToFolderMapping(name, foldername string) error
|
|||
if err != nil {
|
||||
return util.NewGenericError(fmt.Sprintf("unable to get folder %q: %v", foldername, err))
|
||||
}
|
||||
if !util.Contains(f.Groups, name) {
|
||||
if !slices.Contains(f.Groups, name) {
|
||||
f.Groups = append(f.Groups, name)
|
||||
p.dbHandle.vfolders[foldername] = f
|
||||
}
|
||||
|
|
|
@@ -95,8 +95,8 @@ const (
"`last_login` bigint NOT NULL, `filters` longtext NULL, `filesystem` longtext NULL, `additional_info` longtext NULL, " +
"`created_at` bigint NOT NULL, `updated_at` bigint NOT NULL, `email` varchar(255) NULL, " +
"`upload_data_transfer` integer NOT NULL, `download_data_transfer` integer NOT NULL, " +
"`total_data_transfer` integer NOT NULL, `used_upload_data_transfer` integer NOT NULL, " +
"`used_download_data_transfer` integer NOT NULL, `deleted_at` bigint NOT NULL, `first_download` bigint NOT NULL, " +
"`total_data_transfer` integer NOT NULL, `used_upload_data_transfer` bigint NOT NULL, " +
"`used_download_data_transfer` bigint NOT NULL, `deleted_at` bigint NOT NULL, `first_download` bigint NOT NULL, " +
"`first_upload` bigint NOT NULL, `last_password_change` bigint NOT NULL, `role_id` integer NULL);" +
"CREATE TABLE `{{groups_folders_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`group_id` integer NOT NULL, `folder_id` integer NOT NULL, " +
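The only column change in the users table is widening used_upload_data_transfer and used_download_data_transfer from integer to bigint, presumably to avoid overflowing a 32-bit counter on long-lived accounts: MySQL's signed integer caps at 2,147,483,647, while bigint matches Go's int64 range. The headroom, in Go terms:

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	fmt.Println(math.MaxInt32) // 2147483647: ceiling of the old integer column
	fmt.Println(math.MaxInt64) // 9223372036854775807: ceiling of the new bigint column
}
```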
@ -193,11 +193,7 @@ const (
|
|||
"CREATE INDEX `{{prefix}}ip_lists_updated_at_idx` ON `{{ip_lists}}` (`updated_at`);" +
|
||||
"CREATE INDEX `{{prefix}}ip_lists_deleted_at_idx` ON `{{ip_lists}}` (`deleted_at`);" +
|
||||
"CREATE INDEX `{{prefix}}ip_lists_first_last_idx` ON `{{ip_lists}}` (`first`, `last`);" +
|
||||
"INSERT INTO {{schema_version}} (version) VALUES (28);"
|
||||
mysqlV29SQL = "ALTER TABLE `{{users}}` MODIFY `used_download_data_transfer` bigint NOT NULL;" +
|
||||
"ALTER TABLE `{{users}}` MODIFY `used_upload_data_transfer` bigint NOT NULL;"
|
||||
mysqlV29DownSQL = "ALTER TABLE `{{users}}` MODIFY `used_upload_data_transfer` integer NOT NULL;" +
|
||||
"ALTER TABLE `{{users}}` MODIFY `used_download_data_transfer` integer NOT NULL;"
|
||||
"INSERT INTO {{schema_version}} (version) VALUES (29);"
|
||||
)
|
||||
|
||||
// MySQLProvider defines the auth provider for MySQL/MariaDB database
|
||||
|
@ -329,6 +325,14 @@ func (p *MySQLProvider) getUsedQuota(username string) (int, int64, int64, int64,
|
|||
return sqlCommonGetUsedQuota(username, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *MySQLProvider) getAdminSignature(username string) (string, error) {
|
||||
return sqlCommonGetAdminSignature(username, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *MySQLProvider) getUserSignature(username string) (string, error) {
|
||||
return sqlCommonGetUserSignature(username, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *MySQLProvider) setUpdatedAt(username string) {
|
||||
sqlCommonSetUpdatedAt(username, p.dbHandle)
|
||||
}
|
||||
|
@ -776,11 +780,11 @@ func (p *MySQLProvider) initializeDatabase() error {
|
|||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return errSchemaVersionEmpty
|
||||
}
|
||||
logger.InfoToConsole("creating initial database schema, version 28")
|
||||
providerLog(logger.LevelInfo, "creating initial database schema, version 28")
|
||||
logger.InfoToConsole("creating initial database schema, version 29")
|
||||
providerLog(logger.LevelInfo, "creating initial database schema, version 29")
|
||||
initialSQL := sqlReplaceAll(mysqlInitialSQL)
|
||||
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(initialSQL, ";"), 28, true)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(initialSQL, ";"), 29, true)
|
||||
}
|
||||
|
||||
func (p *MySQLProvider) migrateDatabase() error {
|
||||
|
@ -793,13 +797,11 @@ func (p *MySQLProvider) migrateDatabase() error {
|
|||
case version == sqlDatabaseVersion:
|
||||
providerLog(logger.LevelDebug, "sql database is up to date, current version: %d", version)
|
||||
return ErrNoInitRequired
|
||||
case version < 28:
|
||||
case version < 29:
|
||||
err = errSchemaVersionTooOld(version)
|
||||
providerLog(logger.LevelError, "%v", err)
|
||||
logger.ErrorToConsole("%v", err)
|
||||
return err
|
||||
case version == 28:
|
||||
return updateMySQLDatabaseFrom28To29(p.dbHandle)
|
||||
default:
|
||||
if version > sqlDatabaseVersion {
|
||||
providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version,
|
||||
|
@ -822,8 +824,6 @@ func (p *MySQLProvider) revertDatabase(targetVersion int) error {
|
|||
}
|
||||
|
||||
switch dbVersion.Version {
|
||||
case 29:
|
||||
return downgradeMySQLDatabaseFrom29To28(p.dbHandle)
|
||||
default:
|
||||
return fmt.Errorf("database schema version not handled: %d", dbVersion.Version)
|
||||
}
|
||||
|
@ -861,19 +861,3 @@ func (p *MySQLProvider) normalizeError(err error, fieldType int) error {
|
|||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func updateMySQLDatabaseFrom28To29(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("updating database schema version: 28 -> 29")
|
||||
providerLog(logger.LevelInfo, "updating database schema version: 28 -> 29")
|
||||
|
||||
sql := strings.ReplaceAll(mysqlV29SQL, "{{users}}", sqlTableUsers)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 29, true)
|
||||
}
|
||||
|
||||
func downgradeMySQLDatabaseFrom29To28(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("downgrading database schema version: 29 -> 28")
|
||||
providerLog(logger.LevelInfo, "downgrading database schema version: 29 -> 28")
|
||||
|
||||
sql := strings.ReplaceAll(mysqlV29DownSQL, "{{users}}", sqlTableUsers)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 28, false)
|
||||
}
|
||||
|
|
|
@ -184,6 +184,7 @@ func (n *Node) generateAuthToken(username, role string) (string, error) {
|
|||
t := jwt.New()
|
||||
t.Set("admin", username) //nolint:errcheck
|
||||
t.Set("role", role) //nolint:errcheck
|
||||
t.Set(jwt.IssuedAtKey, now) //nolint:errcheck
|
||||
t.Set(jwt.JwtIDKey, xid.New().String()) //nolint:errcheck
|
||||
t.Set(jwt.NotBeforeKey, now.Add(-30*time.Second)) //nolint:errcheck
|
||||
t.Set(jwt.ExpirationKey, now.Add(1*time.Minute)) //nolint:errcheck
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -95,7 +96,7 @@ CREATE TABLE "{{users}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS
|
|||
"download_bandwidth" integer NOT NULL, "last_login" bigint NOT NULL, "filters" text NULL, "filesystem" text NULL,
|
||||
"additional_info" text NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "email" varchar(255) NULL,
|
||||
"upload_data_transfer" integer NOT NULL, "download_data_transfer" integer NOT NULL, "total_data_transfer" integer NOT NULL,
|
||||
"used_upload_data_transfer" integer NOT NULL, "used_download_data_transfer" integer NOT NULL, "deleted_at" bigint NOT NULL,
|
||||
"used_upload_data_transfer" bigint NOT NULL, "used_download_data_transfer" bigint NOT NULL, "deleted_at" bigint NOT NULL,
|
||||
"first_download" bigint NOT NULL, "first_upload" bigint NOT NULL, "last_password_change" bigint NOT NULL, "role_id" integer NULL);
|
||||
CREATE TABLE "{{groups_folders_mapping}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "group_id" integer NOT NULL,
|
||||
"folder_id" integer NOT NULL, "virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL);
|
||||
|
@ -205,16 +206,10 @@ CREATE INDEX "{{prefix}}ip_lists_ipornet_idx" ON "{{ip_lists}}" ("ipornet");
|
|||
CREATE INDEX "{{prefix}}ip_lists_updated_at_idx" ON "{{ip_lists}}" ("updated_at");
|
||||
CREATE INDEX "{{prefix}}ip_lists_deleted_at_idx" ON "{{ip_lists}}" ("deleted_at");
|
||||
CREATE INDEX "{{prefix}}ip_lists_first_last_idx" ON "{{ip_lists}}" ("first", "last");
|
||||
INSERT INTO {{schema_version}} (version) VALUES (28);
|
||||
INSERT INTO {{schema_version}} (version) VALUES (29);
|
||||
`
|
||||
// not supported in CockroachDB
|
||||
ipListsLikeIndex = `CREATE INDEX "{{prefix}}ip_lists_ipornet_like_idx" ON "{{ip_lists}}" ("ipornet" varchar_pattern_ops);`
|
||||
pgsqlV29SQL = `ALTER TABLE "{{users}}" ALTER COLUMN "used_download_data_transfer" TYPE bigint;
|
||||
ALTER TABLE "{{users}}" ALTER COLUMN "used_upload_data_transfer" TYPE bigint;
|
||||
`
|
||||
pgsqlV29DownSQL = `ALTER TABLE "{{users}}" ALTER COLUMN "used_upload_data_transfer" TYPE integer;
|
||||
ALTER TABLE "{{users}}" ALTER COLUMN "used_download_data_transfer" TYPE integer;
|
||||
`
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -311,7 +306,7 @@ func getPGSQLConnectionString(redactedPwd bool) string {
|
|||
if config.DisableSNI {
|
||||
connectionString += " sslsni=0"
|
||||
}
|
||||
if util.Contains(pgSQLTargetSessionAttrs, config.TargetSessionAttrs) {
|
||||
if slices.Contains(pgSQLTargetSessionAttrs, config.TargetSessionAttrs) {
|
||||
connectionString += fmt.Sprintf(" target_session_attrs='%s'", config.TargetSessionAttrs)
|
||||
}
|
||||
} else {
|
||||
|
@ -348,6 +343,14 @@ func (p *PGSQLProvider) getUsedQuota(username string) (int, int64, int64, int64,
|
|||
return sqlCommonGetUsedQuota(username, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *PGSQLProvider) getAdminSignature(username string) (string, error) {
|
||||
return sqlCommonGetAdminSignature(username, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *PGSQLProvider) getUserSignature(username string) (string, error) {
|
||||
return sqlCommonGetUserSignature(username, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *PGSQLProvider) setUpdatedAt(username string) {
|
||||
sqlCommonSetUpdatedAt(username, p.dbHandle)
|
||||
}
|
||||
|
@ -795,8 +798,8 @@ func (p *PGSQLProvider) initializeDatabase() error {
|
|||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return errSchemaVersionEmpty
|
||||
}
|
||||
logger.InfoToConsole("creating initial database schema, version 28")
|
||||
providerLog(logger.LevelInfo, "creating initial database schema, version 28")
|
||||
logger.InfoToConsole("creating initial database schema, version 29")
|
||||
providerLog(logger.LevelInfo, "creating initial database schema, version 29")
|
||||
var initialSQL string
|
||||
if config.Driver == CockroachDataProviderName {
|
||||
initialSQL = sqlReplaceAll(pgsqlInitial)
|
||||
|
@ -805,7 +808,7 @@ func (p *PGSQLProvider) initializeDatabase() error {
|
|||
initialSQL = sqlReplaceAll(pgsqlInitial + ipListsLikeIndex)
|
||||
}
|
||||
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{initialSQL}, 28, true)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{initialSQL}, 29, true)
|
||||
}
|
||||
|
||||
func (p *PGSQLProvider) migrateDatabase() error { //nolint:dupl
|
||||
|
@ -818,13 +821,11 @@ func (p *PGSQLProvider) migrateDatabase() error { //nolint:dupl
|
|||
case version == sqlDatabaseVersion:
|
||||
providerLog(logger.LevelDebug, "sql database is up to date, current version: %d", version)
|
||||
return ErrNoInitRequired
|
||||
case version < 28:
|
||||
case version < 29:
|
||||
err = errSchemaVersionTooOld(version)
|
||||
providerLog(logger.LevelError, "%v", err)
|
||||
logger.ErrorToConsole("%v", err)
|
||||
return err
|
||||
case version == 28:
|
||||
return updatePGSQLDatabaseFrom28To29(p.dbHandle)
|
||||
default:
|
||||
if version > sqlDatabaseVersion {
|
||||
providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version,
|
||||
|
@ -847,8 +848,6 @@ func (p *PGSQLProvider) revertDatabase(targetVersion int) error {
|
|||
}
|
||||
|
||||
switch dbVersion.Version {
|
||||
case 29:
|
||||
return downgradePGSQLDatabaseFrom29To28(p.dbHandle)
|
||||
default:
|
||||
return fmt.Errorf("database schema version not handled: %d", dbVersion.Version)
|
||||
}
|
||||
|
@ -886,19 +885,3 @@ func (p *PGSQLProvider) normalizeError(err error, fieldType int) error {
|
|||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func updatePGSQLDatabaseFrom28To29(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("updating database schema version: 28 -> 29")
|
||||
providerLog(logger.LevelInfo, "updating database schema version: 28 -> 29")
|
||||
|
||||
sql := strings.ReplaceAll(pgsqlV29SQL, "{{users}}", sqlTableUsers)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 29, true)
|
||||
}
|
||||
|
||||
func downgradePGSQLDatabaseFrom29To28(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("downgrading database schema version: 29 -> 28")
|
||||
providerLog(logger.LevelInfo, "downgrading database schema version: 29 -> 28")
|
||||
|
||||
sql := strings.ReplaceAll(pgsqlV29DownSQL, "{{users}}", sqlTableUsers)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 28, false)
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
"fmt"
|
||||
"net/netip"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -1248,6 +1249,32 @@ func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bo
|
|||
return err
|
||||
}
|
||||
|
||||
func sqlCommonGetAdminSignature(username string, dbHandle *sql.DB) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getAdminSignatureQuery()
|
||||
var updatedAt int64
|
||||
err := dbHandle.QueryRowContext(ctx, q, username).Scan(&updatedAt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strconv.FormatInt(updatedAt, 10), nil
|
||||
}
|
||||
|
||||
func sqlCommonGetUserSignature(username string, dbHandle *sql.DB) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getUserSignatureQuery()
|
||||
var updatedAt int64
|
||||
err := dbHandle.QueryRowContext(ctx, q, username).Scan(&updatedAt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strconv.FormatInt(updatedAt, 10), nil
|
||||
}
|
||||
|
||||
func sqlCommonGetUsedQuota(username string, dbHandle *sql.DB) (int, int64, int64, int64, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
@@ -12,8 +12,8 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//go:build !nosqlite
// +build !nosqlite
//go:build !nosqlite && cgo
// +build !nosqlite,cgo

package dataprovider
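The SQLite provider's build constraint gains a cgo requirement, and the stub file shown further below flips to the exact complement, so for any combination of the nosqlite tag and cgo exactly one of the two files is compiled. Side by side, as taken from this compare:

```go
// sqlite.go: real provider, needs cgo and the absence of the nosqlite tag
//go:build !nosqlite && cgo

// sqlite_disabled.go: stub used in every other build
//go:build nosqlite || !cgo
```

A CGO_ENABLED=0 build therefore falls back to the explicit stub rather than compiling a provider that cannot work without cgo.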
@ -178,7 +178,7 @@ CREATE INDEX "{{prefix}}ip_lists_ip_type_idx" ON "{{ip_lists}}" ("ip_type");
|
|||
CREATE INDEX "{{prefix}}ip_lists_ip_updated_at_idx" ON "{{ip_lists}}" ("updated_at");
|
||||
CREATE INDEX "{{prefix}}ip_lists_ip_deleted_at_idx" ON "{{ip_lists}}" ("deleted_at");
|
||||
CREATE INDEX "{{prefix}}ip_lists_first_last_idx" ON "{{ip_lists}}" ("first", "last");
|
||||
INSERT INTO {{schema_version}} (version) VALUES (28);
|
||||
INSERT INTO {{schema_version}} (version) VALUES (29);
|
||||
`
|
||||
)
|
||||
|
||||
|
@ -215,7 +215,7 @@ func initializeSQLiteProvider(basePath string) error {
|
|||
providerLog(logger.LevelDebug, "sqlite database handle created, connection string: %q", connectionString)
|
||||
dbHandle.SetMaxOpenConns(1)
|
||||
provider = &SQLiteProvider{dbHandle: dbHandle}
|
||||
return nil
|
||||
return executePragmaOptimize(dbHandle)
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) checkAvailability() error {
|
||||
|
@ -246,6 +246,14 @@ func (p *SQLiteProvider) getUsedQuota(username string) (int, int64, int64, int64
|
|||
return sqlCommonGetUsedQuota(username, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) getAdminSignature(username string) (string, error) {
|
||||
return sqlCommonGetAdminSignature(username, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) getUserSignature(username string) (string, error) {
|
||||
return sqlCommonGetUserSignature(username, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) setUpdatedAt(username string) {
|
||||
sqlCommonSetUpdatedAt(username, p.dbHandle)
|
||||
}
|
||||
|
@ -693,10 +701,10 @@ func (p *SQLiteProvider) initializeDatabase() error {
|
|||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return errSchemaVersionEmpty
|
||||
}
|
||||
logger.InfoToConsole("creating initial database schema, version 28")
|
||||
providerLog(logger.LevelInfo, "creating initial database schema, version 28")
|
||||
logger.InfoToConsole("creating initial database schema, version 29")
|
||||
providerLog(logger.LevelInfo, "creating initial database schema, version 29")
|
||||
sql := sqlReplaceAll(sqliteInitialSQL)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 28, true)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 29, true)
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) migrateDatabase() error { //nolint:dupl
|
||||
|
@ -709,13 +717,11 @@ func (p *SQLiteProvider) migrateDatabase() error { //nolint:dupl
|
|||
case version == sqlDatabaseVersion:
|
||||
providerLog(logger.LevelDebug, "sql database is up to date, current version: %d", version)
|
||||
return ErrNoInitRequired
|
||||
case version < 28:
|
||||
case version < 29:
|
||||
err = errSchemaVersionTooOld(version)
|
||||
providerLog(logger.LevelError, "%v", err)
|
||||
logger.ErrorToConsole("%v", err)
|
||||
return err
|
||||
case version == 28:
|
||||
return updateSQLiteDatabaseFrom28To29(p.dbHandle)
|
||||
default:
|
||||
if version > sqlDatabaseVersion {
|
||||
providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version,
|
||||
|
@ -738,8 +744,6 @@ func (p *SQLiteProvider) revertDatabase(targetVersion int) error {
|
|||
}
|
||||
|
||||
switch dbVersion.Version {
|
||||
case 29:
|
||||
return downgradeSQLiteDatabaseFrom29To28(p.dbHandle)
|
||||
default:
|
||||
return fmt.Errorf("database schema version not handled: %d", dbVersion.Version)
|
||||
}
|
||||
|
@@ -777,24 +781,12 @@ func (p *SQLiteProvider) normalizeError(err error, fieldType int) error {
	return err
}

func updateSQLiteDatabaseFrom28To29(dbHandle *sql.DB) error {
	logger.InfoToConsole("updating database schema version: 28 -> 29")
	providerLog(logger.LevelInfo, "updating database schema version: 28 -> 29")

func executePragmaOptimize(dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	return sqlCommonUpdateDatabaseVersion(ctx, dbHandle, 29)
}

func downgradeSQLiteDatabaseFrom29To28(dbHandle *sql.DB) error {
	logger.InfoToConsole("downgrading database schema version: 29 -> 28")
	providerLog(logger.LevelInfo, "downgrading database schema version: 29 -> 28")

	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	return sqlCommonUpdateDatabaseVersion(ctx, dbHandle, 28)
	_, err := dbHandle.ExecContext(ctx, "PRAGMA optimize;")
	return err
}

/*func setPragmaFK(dbHandle *sql.DB, value string) error {
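initializeSQLiteProvider now finishes by running PRAGMA optimize instead of returning nil straight away; SQLite uses the pragma to refresh query-planner statistics when it considers that worthwhile. A minimal sketch of the same call against a database/sql handle, assuming an SQLite driver is registered:

```go
package sqliteutil

import (
	"context"
	"database/sql"
	"time"
)

// optimize mirrors the new executePragmaOptimize helper; db and timeout are caller-supplied.
func optimize(db *sql.DB, timeout time.Duration) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// PRAGMA optimize lets SQLite decide whether to run ANALYZE on tables
	// whose statistics look stale; it is cheap when nothing needs doing.
	_, err := db.ExecContext(ctx, "PRAGMA optimize;")
	return err
}
```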
@ -12,8 +12,8 @@
|
|||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
//go:build nosqlite
|
||||
// +build nosqlite
|
||||
//go:build nosqlite || !cgo
|
||||
// +build nosqlite !cgo
|
||||
|
||||
package dataprovider
|
||||
|
||||
|
|
|
@ -650,6 +650,14 @@ func getUpdateQuotaQuery(reset bool) string {
|
|||
WHERE username = %s`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
|
||||
}
|
||||
|
||||
func getAdminSignatureQuery() string {
|
||||
return fmt.Sprintf(`SELECT updated_at FROM %s WHERE username = %s`, sqlTableAdmins, sqlPlaceholders[0])
|
||||
}
|
||||
|
||||
func getUserSignatureQuery() string {
|
||||
return fmt.Sprintf(`SELECT updated_at FROM %s WHERE username = %s`, sqlTableUsers, sqlPlaceholders[0])
|
||||
}
|
||||
|
||||
func getSetUpdateAtQuery() string {
|
||||
return fmt.Sprintf(`UPDATE %s SET updated_at = %s WHERE username = %s`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1])
|
||||
}
|
||||
|
|
|
@ -12,8 +12,8 @@
|
|||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
//go:build unixcrypt
|
||||
// +build unixcrypt
|
||||
//go:build unixcrypt && cgo
|
||||
// +build unixcrypt,cgo
|
||||
|
||||
package dataprovider
|
||||
|
||||
|
|
|
@ -12,8 +12,8 @@
|
|||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
//go:build !unixcrypt
|
||||
// +build !unixcrypt
|
||||
//go:build !unixcrypt || !cgo
|
||||
// +build !unixcrypt !cgo
|
||||
|
||||
package dataprovider
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
"path"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -123,6 +124,8 @@ type UserFilters struct {
|
|||
sdk.BaseUserFilters
|
||||
// User must change password from WebClient/REST API at next login.
|
||||
RequirePasswordChange bool `json:"require_password_change,omitempty"`
|
||||
// AdditionalEmails defines additional email addresses
|
||||
AdditionalEmails []string `json:"additional_emails,omitempty"`
|
||||
// Time-based one time passwords configuration
|
||||
TOTPConfig UserTOTPConfig `json:"totp_config,omitempty"`
|
||||
// Recovery codes to use if the user loses access to their second factor auth device.
|
||||
|
@ -342,7 +345,11 @@ func (u *User) isTimeBasedAccessAllowed(when time.Time) bool {
|
|||
if when.IsZero() {
|
||||
when = time.Now()
|
||||
}
|
||||
when = when.UTC()
|
||||
if UseLocalTime() {
|
||||
when = when.Local()
|
||||
} else {
|
||||
when = when.UTC()
|
||||
}
|
||||
weekDay := when.Weekday()
|
||||
hhMM := when.Format("15:04")
|
||||
for _, p := range u.Filters.AccessTime {
|
||||
|
@@ -399,6 +406,15 @@ func (u *User) CheckMaxShareExpiration(expiresAt time.Time) error {
	return nil
}

// GetEmailAddresses returns all the email addresses.
func (u *User) GetEmailAddresses() []string {
	var res []string
	if u.Email != "" {
		res = append(res, u.Email)
	}
	return slices.Concat(res, u.Filters.AdditionalEmails)
}

// GetSubDirPermissions returns permissions for sub directories
func (u *User) GetSubDirPermissions() []sdk.DirectoryPermissions {
	var result []sdk.DirectoryPermissions
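GetEmailAddresses merges the primary address with Filters.AdditionalEmails via slices.Concat (Go 1.22+), which always allocates a fresh slice, so callers cannot accidentally mutate the user's stored values. A usage sketch with made-up addresses, written as a testable example as if inside the dataprovider package, since the package is internal:

```go
func ExampleUser_GetEmailAddresses() {
	u := User{}
	u.Email = "primary@example.com"
	u.Filters.AdditionalEmails = []string{"billing@example.com", "alerts@example.com"}

	fmt.Println(u.GetEmailAddresses())
	// Output: [primary@example.com billing@example.com alerts@example.com]
}
```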
@ -840,20 +856,20 @@ func (u *User) HasPermissionsInside(virtualPath string) bool {
|
|||
// HasPerm returns true if the user has the given permission or any permission
|
||||
func (u *User) HasPerm(permission, path string) bool {
|
||||
perms := u.GetPermissionsForPath(path)
|
||||
if util.Contains(perms, PermAny) {
|
||||
if slices.Contains(perms, PermAny) {
|
||||
return true
|
||||
}
|
||||
return util.Contains(perms, permission)
|
||||
return slices.Contains(perms, permission)
|
||||
}
|
||||
|
||||
// HasAnyPerm returns true if the user has at least one of the given permissions
|
||||
func (u *User) HasAnyPerm(permissions []string, path string) bool {
|
||||
perms := u.GetPermissionsForPath(path)
|
||||
if util.Contains(perms, PermAny) {
|
||||
if slices.Contains(perms, PermAny) {
|
||||
return true
|
||||
}
|
||||
for _, permission := range permissions {
|
||||
if util.Contains(perms, permission) {
|
||||
if slices.Contains(perms, permission) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
@ -863,11 +879,11 @@ func (u *User) HasAnyPerm(permissions []string, path string) bool {
|
|||
// HasPerms returns true if the user has all the given permissions
|
||||
func (u *User) HasPerms(permissions []string, path string) bool {
|
||||
perms := u.GetPermissionsForPath(path)
|
||||
if util.Contains(perms, PermAny) {
|
||||
if slices.Contains(perms, PermAny) {
|
||||
return true
|
||||
}
|
||||
for _, permission := range permissions {
|
||||
if !util.Contains(perms, permission) {
|
||||
if !slices.Contains(perms, permission) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
@ -927,11 +943,11 @@ func (u *User) IsLoginMethodAllowed(loginMethod, protocol string) bool {
|
|||
if len(u.Filters.DeniedLoginMethods) == 0 {
|
||||
return true
|
||||
}
|
||||
if util.Contains(u.Filters.DeniedLoginMethods, loginMethod) {
|
||||
if slices.Contains(u.Filters.DeniedLoginMethods, loginMethod) {
|
||||
return false
|
||||
}
|
||||
if protocol == protocolSSH && loginMethod == LoginMethodPassword {
|
||||
if util.Contains(u.Filters.DeniedLoginMethods, SSHLoginMethodPassword) {
|
||||
if slices.Contains(u.Filters.DeniedLoginMethods, SSHLoginMethodPassword) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
@ -965,10 +981,10 @@ func (u *User) IsPartialAuth() bool {
|
|||
method == SSHLoginMethodPassword {
|
||||
continue
|
||||
}
|
||||
if method == LoginMethodPassword && util.Contains(u.Filters.DeniedLoginMethods, SSHLoginMethodPassword) {
|
||||
if method == LoginMethodPassword && slices.Contains(u.Filters.DeniedLoginMethods, SSHLoginMethodPassword) {
|
||||
continue
|
||||
}
|
||||
if !util.Contains(SSHMultiStepsLoginMethods, method) {
|
||||
if !slices.Contains(SSHMultiStepsLoginMethods, method) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
@ -982,7 +998,7 @@ func (u *User) GetAllowedLoginMethods() []string {
|
|||
if method == SSHLoginMethodPassword {
|
||||
continue
|
||||
}
|
||||
if !util.Contains(u.Filters.DeniedLoginMethods, method) {
|
||||
if !slices.Contains(u.Filters.DeniedLoginMethods, method) {
|
||||
allowedMethods = append(allowedMethods, method)
|
||||
}
|
||||
}
|
||||
|
@ -1052,7 +1068,7 @@ func (u *User) IsFileAllowed(virtualPath string) (bool, int) {
|
|||
|
||||
// CanManageMFA returns true if the user can add a multi-factor authentication configuration
|
||||
func (u *User) CanManageMFA() bool {
|
||||
if util.Contains(u.Filters.WebClient, sdk.WebClientMFADisabled) {
|
||||
if slices.Contains(u.Filters.WebClient, sdk.WebClientMFADisabled) {
|
||||
return false
|
||||
}
|
||||
return len(mfa.GetAvailableTOTPConfigs()) > 0
|
||||
|
@ -1073,39 +1089,39 @@ func (u *User) skipExternalAuth() bool {
|
|||
|
||||
// CanManageShares returns true if the user can add, update and list shares
|
||||
func (u *User) CanManageShares() bool {
|
||||
return !util.Contains(u.Filters.WebClient, sdk.WebClientSharesDisabled)
|
||||
return !slices.Contains(u.Filters.WebClient, sdk.WebClientSharesDisabled)
|
||||
}
|
||||
|
||||
// CanResetPassword returns true if this user is allowed to reset its password
|
||||
func (u *User) CanResetPassword() bool {
|
||||
return !util.Contains(u.Filters.WebClient, sdk.WebClientPasswordResetDisabled)
|
||||
return !slices.Contains(u.Filters.WebClient, sdk.WebClientPasswordResetDisabled)
|
||||
}
|
||||
|
||||
// CanChangePassword returns true if this user is allowed to change its password
|
||||
func (u *User) CanChangePassword() bool {
|
||||
return !util.Contains(u.Filters.WebClient, sdk.WebClientPasswordChangeDisabled)
|
||||
return !slices.Contains(u.Filters.WebClient, sdk.WebClientPasswordChangeDisabled)
|
||||
}
|
||||
|
||||
// CanChangeAPIKeyAuth returns true if this user is allowed to enable/disable API key authentication
|
||||
func (u *User) CanChangeAPIKeyAuth() bool {
|
||||
return !util.Contains(u.Filters.WebClient, sdk.WebClientAPIKeyAuthChangeDisabled)
|
||||
return !slices.Contains(u.Filters.WebClient, sdk.WebClientAPIKeyAuthChangeDisabled)
|
||||
}
|
||||
|
||||
// CanChangeInfo returns true if this user is allowed to change its info such as email and description
|
||||
func (u *User) CanChangeInfo() bool {
|
||||
return !util.Contains(u.Filters.WebClient, sdk.WebClientInfoChangeDisabled)
|
||||
return !slices.Contains(u.Filters.WebClient, sdk.WebClientInfoChangeDisabled)
|
||||
}
|
||||
|
||||
// CanManagePublicKeys returns true if this user is allowed to manage public keys
|
||||
// from the WebClient. Used in WebClient UI
|
||||
func (u *User) CanManagePublicKeys() bool {
|
||||
return !util.Contains(u.Filters.WebClient, sdk.WebClientPubKeyChangeDisabled)
|
||||
return !slices.Contains(u.Filters.WebClient, sdk.WebClientPubKeyChangeDisabled)
|
||||
}
|
||||
|
||||
// CanManageTLSCerts returns true if this user is allowed to manage TLS certificates
|
||||
// from the WebClient. Used in WebClient UI
|
||||
func (u *User) CanManageTLSCerts() bool {
|
||||
return !util.Contains(u.Filters.WebClient, sdk.WebClientTLSCertChangeDisabled)
|
||||
return !slices.Contains(u.Filters.WebClient, sdk.WebClientTLSCertChangeDisabled)
|
||||
}
|
||||
|
||||
// CanUpdateProfile returns true if the user is allowed to update the profile.
|
||||
|
@ -1117,7 +1133,7 @@ func (u *User) CanUpdateProfile() bool {
|
|||
// CanAddFilesFromWeb returns true if the client can add files from the web UI.
|
||||
// The specified target is the directory where the files must be uploaded
|
||||
func (u *User) CanAddFilesFromWeb(target string) bool {
|
||||
if util.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) {
|
||||
if slices.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) {
|
||||
return false
|
||||
}
|
||||
return u.HasPerm(PermUpload, target) || u.HasPerm(PermOverwrite, target)
|
||||
|
@ -1126,7 +1142,7 @@ func (u *User) CanAddFilesFromWeb(target string) bool {
|
|||
// CanAddDirsFromWeb returns true if the client can add directories from the web UI.
|
||||
// The specified target is the directory where the new directory must be created
|
||||
func (u *User) CanAddDirsFromWeb(target string) bool {
|
||||
if util.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) {
|
||||
if slices.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) {
|
||||
return false
|
||||
}
|
||||
return u.HasPerm(PermCreateDirs, target)
|
||||
|
@ -1135,7 +1151,7 @@ func (u *User) CanAddDirsFromWeb(target string) bool {
|
|||
// CanRenameFromWeb returns true if the client can rename objects from the web UI.
|
||||
// The specified src and dest are the source and target directories for the rename.
|
||||
func (u *User) CanRenameFromWeb(src, dest string) bool {
|
||||
if util.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) {
|
||||
if slices.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) {
|
||||
return false
|
||||
}
|
||||
return u.HasAnyPerm(permsRenameAny, src) && u.HasAnyPerm(permsRenameAny, dest)
|
||||
|
@ -1144,7 +1160,7 @@ func (u *User) CanRenameFromWeb(src, dest string) bool {
|
|||
// CanDeleteFromWeb returns true if the client can delete objects from the web UI.
|
||||
// The specified target is the parent directory for the object to delete
|
||||
func (u *User) CanDeleteFromWeb(target string) bool {
|
||||
if util.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) {
|
||||
if slices.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) {
|
||||
return false
|
||||
}
|
||||
return u.HasAnyPerm(permsDeleteAny, target)
|
||||
|
@ -1153,7 +1169,7 @@ func (u *User) CanDeleteFromWeb(target string) bool {
|
|||
// CanCopyFromWeb returns true if the client can copy objects from the web UI.
|
||||
// The specified src and dest are the source and target directories for the copy.
|
||||
func (u *User) CanCopyFromWeb(src, dest string) bool {
|
||||
if util.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) {
|
||||
if slices.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) {
|
||||
return false
|
||||
}
|
||||
if !u.HasPerm(PermListItems, src) {
|
||||
|
@ -1213,7 +1229,7 @@ func (u *User) MustSetSecondFactor() bool {
|
|||
return true
|
||||
}
|
||||
for _, p := range u.Filters.TwoFactorAuthProtocols {
|
||||
if !util.Contains(u.Filters.TOTPConfig.Protocols, p) {
|
||||
if !slices.Contains(u.Filters.TOTPConfig.Protocols, p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
@ -1224,11 +1240,11 @@ func (u *User) MustSetSecondFactor() bool {
|
|||
// MustSetSecondFactorForProtocol returns true if the user must set a second factor authentication
|
||||
// for the specified protocol
|
||||
func (u *User) MustSetSecondFactorForProtocol(protocol string) bool {
|
||||
if util.Contains(u.Filters.TwoFactorAuthProtocols, protocol) {
|
||||
if slices.Contains(u.Filters.TwoFactorAuthProtocols, protocol) {
|
||||
if !u.Filters.TOTPConfig.Enabled {
|
||||
return true
|
||||
}
|
||||
if !util.Contains(u.Filters.TOTPConfig.Protocols, protocol) {
|
||||
if !slices.Contains(u.Filters.TOTPConfig.Protocols, protocol) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
@ -1779,6 +1795,8 @@ func (u *User) getACopy() User {
|
|||
filters.TOTPConfig.Secret = u.Filters.TOTPConfig.Secret.Clone()
|
||||
filters.TOTPConfig.Protocols = make([]string, len(u.Filters.TOTPConfig.Protocols))
|
||||
copy(filters.TOTPConfig.Protocols, u.Filters.TOTPConfig.Protocols)
|
||||
filters.AdditionalEmails = make([]string, len(u.Filters.AdditionalEmails))
|
||||
copy(filters.AdditionalEmails, u.Filters.AdditionalEmails)
|
||||
filters.RecoveryCodes = make([]RecoveryCode, 0, len(u.Filters.RecoveryCodes))
|
||||
for _, code := range u.Filters.RecoveryCodes {
|
||||
if code.Secret == nil {
|
||||
|
|
|
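The permission and filter checks in the hunks above drop SFTPGo's own util.Contains helper in favor of the standard library's slices.Contains (available since Go 1.21). A minimal, self-contained sketch of the pattern; the permission names used here are illustrative only and not taken from the diff:

```go
package main

import (
	"fmt"
	"slices"
)

// hasPerm mirrors the shape of the checks above: the "any" permission
// short-circuits, otherwise membership is tested with the stdlib slices package.
func hasPerm(perms []string, permission string) bool {
	if slices.Contains(perms, "*") {
		return true
	}
	return slices.Contains(perms, permission)
}

func main() {
	fmt.Println(hasPerm([]string{"list", "download"}, "download")) // true
	fmt.Println(hasPerm([]string{"*"}, "upload"))                  // true
	fmt.Println(hasPerm([]string{"list"}, "upload"))               // false
}
```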
@@ -134,6 +134,7 @@ func TestBasicFTPHandlingCryptFs(t *testing.T) {
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond)
assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond,
50*time.Millisecond)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}

func TestBufferedCryptFs(t *testing.T) {

@@ -179,6 +180,7 @@ func TestBufferedCryptFs(t *testing.T) {
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond)
assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond,
50*time.Millisecond)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}

func TestZeroBytesTransfersCryptFs(t *testing.T) {
@@ -37,6 +37,7 @@ import (
ftpserver "github.com/fclairamb/ftpserverlib"
"github.com/jlaffaye/ftp"
"github.com/pkg/sftp"
"github.com/pquerna/otp"
"github.com/pquerna/otp/totp"
"github.com/rs/zerolog"

@@ -44,6 +45,7 @@ import (
sdkkms "github.com/sftpgo/sdk/kms"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"

"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/config"

@@ -671,6 +673,7 @@ func TestBasicFTPHandling(t *testing.T) {
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond)
assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond,
50*time.Millisecond)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}

func TestHTTPFs(t *testing.T) {

@@ -715,6 +718,7 @@ func TestHTTPFs(t *testing.T) {
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond)
assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond,
50*time.Millisecond)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}

func TestListDirWithWildcards(t *testing.T) {

@@ -1735,6 +1739,66 @@ func TestMaxPerHostConnections(t *testing.T) {
common.Config.MaxPerHostConnections = oldValue
}

func TestMaxTransfers(t *testing.T) {
oldValue := common.Config.MaxPerHostConnections
common.Config.MaxPerHostConnections = 2

assert.Eventually(t, func() bool {
return common.Connections.GetClientConnections() == 0
}, 1000*time.Millisecond, 50*time.Millisecond)

user := getTestUser()
err := dataprovider.AddUser(&user, "", "", "")
assert.NoError(t, err)
user.Password = ""

testFilePath := filepath.Join(homeBasePath, testFileName)
testFileSize := int64(65535)
err = createTestFile(testFilePath, testFileSize)
assert.NoError(t, err)

conn, sftpClient, err := getSftpClient(user)
assert.NoError(t, err)
defer conn.Close()
defer sftpClient.Close()

f1, err := sftpClient.Create("file1")
assert.NoError(t, err)
f2, err := sftpClient.Create("file2")
assert.NoError(t, err)
_, err = f1.Write([]byte(" "))
assert.NoError(t, err)
_, err = f2.Write([]byte(" "))
assert.NoError(t, err)

client, err := getFTPClient(user, true, nil)
if assert.NoError(t, err) {
err = checkBasicFTP(client)
assert.NoError(t, err)
err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
assert.Error(t, err)
localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0)
assert.Error(t, err)
err := client.Quit()
assert.NoError(t, err)
err = os.Remove(localDownloadPath)
assert.NoError(t, err)
}

err = f1.Close()
assert.NoError(t, err)
err = f2.Close()
assert.NoError(t, err)

err = dataprovider.DeleteUser(user.Username, "", "", "")
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)

common.Config.MaxPerHostConnections = oldValue
}

func TestRateLimiter(t *testing.T) {
oldConfig := config.GetCommonConfig()

@@ -2718,6 +2782,7 @@ func TestStat(t *testing.T) {
func TestUploadOverwriteVfolder(t *testing.T) {
u := getTestUser()
u.QuotaFiles = 1000
vdir := "/vdir"
mappedPath := filepath.Join(os.TempDir(), "vdir")
folderName := filepath.Base(mappedPath)

@@ -2749,14 +2814,24 @@ func TestUploadOverwriteVfolder(t *testing.T) {
assert.NoError(t, err)
folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, folder.UsedQuotaSize)
assert.Equal(t, 1, folder.UsedQuotaFiles)
assert.Equal(t, int64(0), folder.UsedQuotaSize)
assert.Equal(t, 0, folder.UsedQuotaFiles)
user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, user.UsedQuotaSize)
assert.Equal(t, 1, user.UsedQuotaFiles)

err = ftpUploadFile(testFilePath, path.Join(vdir, testFileName), testFileSize, client, 0)
assert.NoError(t, err)
folder, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, folder.UsedQuotaSize)
assert.Equal(t, 1, folder.UsedQuotaFiles)
assert.Equal(t, int64(0), folder.UsedQuotaSize)
assert.Equal(t, 0, folder.UsedQuotaFiles)
user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, testFileSize, user.UsedQuotaSize)
assert.Equal(t, 1, user.UsedQuotaFiles)

err = client.Quit()
assert.NoError(t, err)
err = os.Remove(testFilePath)

@@ -3951,6 +4026,7 @@ func TestNestedVirtualFolders(t *testing.T) {
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond)
assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond,
50*time.Millisecond)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}

func checkBasicFTP(client *ftp.ServerConn) error {

@@ -4202,6 +4278,30 @@ func getPreLoginScriptContent(user dataprovider.User, nonJSONResponse bool) []by
return content
}

func getSftpClient(user dataprovider.User) (*ssh.Client, *sftp.Client, error) {
var sftpClient *sftp.Client
config := &ssh.ClientConfig{
User: user.Username,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
}
if user.Password != "" {
config.Auth = []ssh.AuthMethod{ssh.Password(user.Password)}
} else {
config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)}
}

conn, err := ssh.Dial("tcp", sftpServerAddr, config)
if err != nil {
return conn, sftpClient, err
}
sftpClient, err = sftp.NewClient(conn)
if err != nil {
conn.Close()
}
return conn, sftpClient, err
}

func getExitCodeScriptContent(exitCode int) []byte {
content := []byte("#!/bin/sh\n\n")
content = append(content, []byte(fmt.Sprintf("exit %v", exitCode))...)
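The assertions added in these tests all poll for connection teardown with testify's assert.Eventually. A small, hypothetical test (not part of the SFTPGo suite) showing the same waiting pattern in isolation:

```go
package example

import (
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// TestWaitForCleanup shows the polling pattern used by the assertions added
// above; the atomic counter is only a stand-in for common.Connections.
func TestWaitForCleanup(t *testing.T) {
	var active atomic.Int32
	active.Store(1)
	go func() {
		time.Sleep(100 * time.Millisecond)
		active.Store(0)
	}()
	// Poll every 50ms and fail if the condition is not met within one second,
	// mirroring assert.Eventually(..., 1*time.Second, 50*time.Millisecond).
	assert.Eventually(t, func() bool { return active.Load() == 0 },
		1*time.Second, 50*time.Millisecond)
}
```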
@@ -331,6 +331,11 @@ func (c *Connection) GetHandle(name string, flags int, offset int64) (ftpserver.
return nil, errCOMBNotSupported
}

if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil {
c.Log(logger.LevelInfo, "denying transfer due to count limits")
return nil, c.GetPermissionDeniedError()
}

if flags&os.O_WRONLY != 0 {
return c.uploadFile(fs, p, name, flags)
}

@@ -465,7 +470,7 @@ func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolve
}

if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
_, _, err = fs.Rename(resolvedPath, filePath)
_, _, err = fs.Rename(resolvedPath, filePath, 0)
if err != nil {
c.Log(logger.LevelError, "error renaming existing file for atomic upload, source: %q, dest: %q, err: %+v",
resolvedPath, filePath, err)

@@ -493,10 +498,7 @@ func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolve
if vfs.HasTruncateSupport(fs) {
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
if err == nil {
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
if vfolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
}
dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -fileSize, false)
} else {
dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
}
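The first hunk above gates every new FTP transfer behind common.Connections.IsNewTransferAllowed. The diff does not show that method's body, so the following is only a hypothetical sketch of the kind of per-user concurrent-transfer limit such a check could enforce, not SFTPGo's actual implementation:

```go
package example

import (
	"errors"
	"sync"
)

var errTooManyTransfers = errors.New("too many active transfers")

// transferGate is an illustrative, simplified gate: it rejects a new
// transfer once a per-user limit of concurrently active transfers is reached.
type transferGate struct {
	mu     sync.Mutex
	active map[string]int
	limit  int
}

func newTransferGate(limit int) *transferGate {
	return &transferGate{active: make(map[string]int), limit: limit}
}

// IsNewTransferAllowed returns an error when the user already has too many
// active transfers; callers map this to a permission-denied response.
func (g *transferGate) IsNewTransferAllowed(username string) error {
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.limit > 0 && g.active[username] >= g.limit {
		return errTooManyTransfers
	}
	return nil
}

func (g *transferGate) Add(username string)    { g.mu.Lock(); g.active[username]++; g.mu.Unlock() }
func (g *transferGate) Remove(username string) { g.mu.Lock(); g.active[username]--; g.mu.Unlock() }
```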
@@ -404,7 +404,7 @@ func (fs MockOsFs) Remove(name string, _ bool) error {
}

// Rename renames (moves) source to target
func (fs MockOsFs) Rename(source, target string) (int, int64, error) {
func (fs MockOsFs) Rename(source, target string, _ int) (int, int64, error) {
if fs.err != nil {
return -1, -1, fs.err
}

@@ -664,6 +664,7 @@ func TestClientVersion(t *testing.T) {
common.Connections.Remove(connection.GetID())
}
assert.Len(t, common.Connections.GetStats(""), 0)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}

func TestDriverMethodsNotImplemented(t *testing.T) {

@@ -918,6 +919,7 @@ func TestTransferErrors(t *testing.T) {
pipeWriter := vfs.NewPipeWriter(w)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, testfile,
common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{})
tr.Connection.RemoveTransfer(tr)
tr = newTransfer(baseTransfer, pipeWriter, nil, 0)

err = r.Close()

@@ -933,6 +935,7 @@ func TestTransferErrors(t *testing.T) {
if assert.Error(t, err) {
assert.EqualError(t, err, common.ErrOpUnsupported.Error())
}
tr.Connection.RemoveTransfer(tr)
err = os.Remove(testfile)
assert.NoError(t, err)
}
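As the MockOsFs hunk above shows, the filesystem Rename method now takes an extra integer argument, so every mock or custom implementation has to be updated to the new signature. A minimal, hypothetical mock illustrating the shape of the change (the extra parameter is simply ignored here, as in the test mock above):

```go
package example

import "os"

// mockFs is an illustrative stand-in, not SFTPGo's MockOsFs.
type mockFs struct {
	err error
}

// Rename renames (moves) source to target; the trailing int is the new
// parameter added by the diff and is ignored by this sketch. The return
// values mirror the (files, size, error) triple used above.
func (fs mockFs) Rename(source, target string, _ int) (int, int64, error) {
	if fs.err != nil {
		return -1, -1, fs.err
	}
	if err := os.Rename(source, target); err != nil {
		return -1, -1, err
	}
	return 1, 0, nil
}
```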
@@ -22,6 +22,7 @@ import (
"net"
"os"
"path/filepath"
"slices"

ftpserver "github.com/fclairamb/ftpserverlib"
"github.com/sftpgo/sdk/plugin/notifier"

@@ -187,20 +188,18 @@ func (s *Server) AuthUser(cc ftpserver.ClientContext, username, password string)
user, err := dataprovider.CheckUserAndPass(username, password, ipAddr, common.ProtocolFTP)
if err != nil {
user.Username = username
updateLoginMetrics(&user, ipAddr, loginMethod, err)
updateLoginMetrics(&user, ipAddr, loginMethod, err, nil)
return nil, dataprovider.ErrInvalidCredentials
}

connection, err := s.validateUser(user, cc, loginMethod)

defer updateLoginMetrics(&user, ipAddr, loginMethod, err)
defer updateLoginMetrics(&user, ipAddr, loginMethod, err, connection)

if err != nil {
return nil, err
}
setStartDirectory(user.Filters.StartDirectory, cc)
connection.Log(logger.LevelInfo, "User %q logged in with %q from ip %q, TLS enabled? %t",
user.Username, loginMethod, ipAddr, cc.HasTLSForControl())
dataprovider.UpdateLastLogin(&user)
return connection, nil
}

@@ -245,7 +244,7 @@ func (s *Server) VerifyConnection(cc ftpserver.ClientContext, user string, tlsCo
dbUser, err := dataprovider.CheckUserBeforeTLSAuth(user, ipAddr, common.ProtocolFTP, state.PeerCertificates[0])
if err != nil {
dbUser.Username = user
updateLoginMetrics(&dbUser, ipAddr, dataprovider.LoginMethodTLSCertificate, err)
updateLoginMetrics(&dbUser, ipAddr, dataprovider.LoginMethodTLSCertificate, err, nil)
return nil, dataprovider.ErrInvalidCredentials
}
if dbUser.IsTLSVerificationEnabled() {

@@ -259,14 +258,12 @@ func (s *Server) VerifyConnection(cc ftpserver.ClientContext, user string, tlsCo
if dbUser.IsLoginMethodAllowed(dataprovider.LoginMethodTLSCertificate, common.ProtocolFTP) {
connection, err := s.validateUser(dbUser, cc, dataprovider.LoginMethodTLSCertificate)

defer updateLoginMetrics(&dbUser, ipAddr, dataprovider.LoginMethodTLSCertificate, err)
defer updateLoginMetrics(&dbUser, ipAddr, dataprovider.LoginMethodTLSCertificate, err, connection)

if err != nil {
return nil, err
}
setStartDirectory(dbUser.Filters.StartDirectory, cc)
connection.Log(logger.LevelInfo, "User id: %d, logged in with FTP using a TLS certificate, username: %q, home_dir: %q remote addr: %q",
dbUser.ID, dbUser.Username, dbUser.HomeDir, ipAddr)
dataprovider.UpdateLastLogin(&dbUser)
return connection, nil
}

@@ -361,7 +358,7 @@ func (s *Server) validateUser(user dataprovider.User, cc ftpserver.ClientContext
user.Username, user.HomeDir)
return nil, fmt.Errorf("cannot login user with invalid home dir: %q", user.HomeDir)
}
if util.Contains(user.Filters.DeniedProtocols, common.ProtocolFTP) {
if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolFTP) {
logger.Info(logSender, connectionID, "cannot login user %q, protocol FTP is not allowed", user.Username)
return nil, fmt.Errorf("protocol FTP is not allowed for user %q", user.Username)
}

@@ -416,9 +413,11 @@ func setStartDirectory(startDirectory string, cc ftpserver.ClientContext) {
cc.SetPath(startDirectory)
}

func updateLoginMetrics(user *dataprovider.User, ip, loginMethod string, err error) {
func updateLoginMetrics(user *dataprovider.User, ip, loginMethod string, err error, c *Connection) {
metric.AddLoginAttempt(loginMethod)
if err == nil {
logger.LoginLog(user.Username, ip, loginMethod, common.ProtocolFTP, c.ID, c.GetClientVersion(),
c.clientContext.HasTLSForControl(), "")
plugin.Handler.NotifyLogEvent(notifier.LogEventTypeLoginOK, common.ProtocolFTP, user.Username, ip, "", nil)
common.DelayLogin(nil)
} else if err != common.ErrInternalFailure {
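In the FTP server hunks above, updateLoginMetrics now also receives the connection so that successful logins can be logged with client details; nil is passed when authentication failed and the connection is only dereferenced on success. A reduced, hypothetical sketch of that contract (the types and log format are illustrative, not SFTPGo's):

```go
package example

import "fmt"

// conn stands in for the FTP Connection type; illustrative only.
type conn struct{ id, clientVersion string }

// logLogin mimics the updated contract: c may be nil when err != nil,
// so the connection is only used on the success path, as in the diff above.
func logLogin(username string, err error, c *conn) {
	if err == nil {
		fmt.Printf("login OK user=%s connID=%s client=%s\n", username, c.id, c.clientVersion)
		return
	}
	fmt.Printf("login failed user=%s err=%v\n", username, err)
}
```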
@@ -85,7 +85,7 @@ type oauth2TokenRequest struct {
BaseRedirectURL string `json:"base_redirect_url"`
}

func handleSMTPOAuth2TokenRequestPost(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) handleSMTPOAuth2TokenRequestPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

var req oauth2TokenRequest

@@ -115,7 +115,7 @@ func handleSMTPOAuth2TokenRequestPost(w http.ResponseWriter, r *http.Request) {
clientSecret.SetAdditionalData(xid.New().String())
pendingAuth := newOAuth2PendingAuth(req.Provider, cfg.RedirectURL, cfg.ClientID, clientSecret)
oauth2Mgr.addPendingAuth(pendingAuth)
stateToken := createOAuth2Token(pendingAuth.State, util.GetIPFromRemoteAddress(r.RemoteAddr))
stateToken := createOAuth2Token(s.csrfTokenAuth, pendingAuth.State, util.GetIPFromRemoteAddress(r.RemoteAddr))
if stateToken == "" {
sendAPIResponse(w, r, nil, "unable to create state token", http.StatusInternalServerError)
return
@@ -317,6 +317,13 @@ func uploadUserFiles(w http.ResponseWriter, r *http.Request) {
}
defer common.Connections.Remove(connection.GetID())

if err := common.Connections.IsNewTransferAllowed(connection.User.Username); err != nil {
connection.Log(logger.LevelInfo, "denying file write due to number of transfer limits")
sendAPIResponse(w, r, err, "Denying file write due to transfer count limits",
http.StatusConflict)
return
}

transferQuota := connection.GetTransferQuota()
if !transferQuota.HasUploadSpace() {
connection.Log(logger.LevelInfo, "denying file write due to transfer quota limits")

@@ -469,8 +476,9 @@ func getUserProfile(w http.ResponseWriter, r *http.Request) {
Description: user.Description,
AllowAPIKeyAuth: user.Filters.AllowAPIKeyAuth,
},
PublicKeys: user.PublicKeys,
TLSCerts: user.Filters.TLSCerts,
AdditionalEmails: user.Filters.AdditionalEmails,
PublicKeys: user.PublicKeys,
TLSCerts: user.Filters.TLSCerts,
}
render.JSON(w, r, resp)
}

@@ -508,6 +516,7 @@ func updateUserProfile(w http.ResponseWriter, r *http.Request) {
}
if userMerged.CanChangeInfo() {
user.Email = req.Email
user.Filters.AdditionalEmails = req.AdditionalEmails
user.Description = req.Description
}
if err := dataprovider.UpdateUser(&user, dataprovider.ActionExecutorSelf, util.GetIPFromRemoteAddress(r.RemoteAddr), user.Role); err != nil {
@@ -551,9 +551,10 @@ func RestoreUsers(users []dataprovider.User, inputFile string, mode, scanQuota i
return fmt.Errorf("unable to restore user %q: %w", user.Username, err)
}
if scanQuota == 1 || (scanQuota == 2 && user.HasQuotaRestrictions()) {
if common.QuotaScans.AddUserQuotaScan(user.Username, user.Role) {
user, err = dataprovider.GetUserWithGroupSettings(user.Username, "")
if err == nil && common.QuotaScans.AddUserQuotaScan(user.Username, user.Role) {
logger.Debug(logSender, "", "starting quota scan for restored user: %q", user.Username)
go doUserQuotaScan(user) //nolint:errcheck
go doUserQuotaScan(&user) //nolint:errcheck
}
}
}
@@ -20,6 +20,7 @@ import (
"fmt"
"io"
"net/http"
"slices"
"strconv"
"strings"

@@ -138,8 +139,7 @@ func saveTOTPConfig(w http.ResponseWriter, r *http.Request) {
if claims.MustSetTwoFactorAuth {
// force logout
defer func() {
c := jwtTokenClaims{}
c.removeCookie(w, r, baseURL)
removeCookie(w, r, baseURL)
}()
}

@@ -276,7 +276,7 @@ func saveUserTOTPConfig(username string, r *http.Request, recoveryCodes []datapr
return util.NewValidationError("two-factor authentication must be enabled")
}
for _, p := range userMerged.Filters.TwoFactorAuthProtocols {
if !util.Contains(user.Filters.TOTPConfig.Protocols, p) {
if !slices.Contains(user.Filters.TOTPConfig.Protocols, p) {
return util.NewValidationError(fmt.Sprintf("totp: the following protocols are required: %q",
strings.Join(userMerged.Filters.TwoFactorAuthProtocols, ", ")))
}
@@ -219,7 +219,7 @@ func doStartUserQuotaScan(w http.ResponseWriter, r *http.Request, username strin
http.StatusConflict)
return
}
go doUserQuotaScan(user) //nolint:errcheck
go doUserQuotaScan(&user) //nolint:errcheck
sendAPIResponse(w, r, err, "Scan started", http.StatusAccepted)
}

@@ -242,14 +242,14 @@ func doStartFolderQuotaScan(w http.ResponseWriter, r *http.Request, name string)
sendAPIResponse(w, r, err, "Scan started", http.StatusAccepted)
}

func doUserQuotaScan(user dataprovider.User) error {
func doUserQuotaScan(user *dataprovider.User) error {
defer common.QuotaScans.RemoveUserQuotaScan(user.Username)
numFiles, size, err := user.ScanQuota()
if err != nil {
logger.Warn(logSender, "", "error scanning user quota %q: %v", user.Username, err)
return err
}
err = dataprovider.UpdateUserQuota(&user, numFiles, size, true)
err = dataprovider.UpdateUserQuota(user, numFiles, size, true)
logger.Debug(logSender, "", "user quota scanned, user: %q, error: %v", user.Username, err)
return err
}
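The hunk above changes doUserQuotaScan to take *dataprovider.User instead of a value. Passing a pointer avoids copying the user struct into the goroutine and lets helpers that already expect a pointer (UpdateUserQuota above) reuse it directly. A reduced, hypothetical illustration of the same style with a stand-in type:

```go
package example

import "fmt"

// user is a tiny stand-in for dataprovider.User; illustrative only.
type user struct {
	Username string
	Files    int
	Size     int64
}

// scanQuota mirrors the pointer-based style the diff moves to: the callee
// receives *user, so no copy is made and the same pointer can be handed to
// other helpers that update the record.
func scanQuota(u *user) error {
	numFiles, size := 3, int64(1024) // pretend scan result
	u.Files, u.Size = numFiles, size
	fmt.Printf("quota scanned, user: %q, files: %d, size: %d\n", u.Username, numFiles, size)
	return nil
}
```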
@@ -22,6 +22,7 @@ import (
"net/url"
"os"
"path"
"slices"
"strings"
"time"

@@ -107,7 +108,7 @@ func addShare(w http.ResponseWriter, r *http.Request) {
share.Name = share.ShareID
}
if share.Password == "" {
if util.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) {
if slices.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) {
sendAPIResponse(w, r, nil, "You are not authorized to share files/folders without a password",
http.StatusForbidden)
return

@@ -155,7 +156,7 @@ func updateShare(w http.ResponseWriter, r *http.Request) {
updatedShare.Password = share.Password
}
if updatedShare.Password == "" {
if util.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) {
if slices.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) {
sendAPIResponse(w, r, nil, "You are not authorized to share files/folders without a password",
http.StatusForbidden)
return

@@ -379,6 +380,12 @@ func (s *httpdServer) uploadFilesToShare(w http.ResponseWriter, r *http.Request)
if err != nil {
return
}
if err := common.Connections.IsNewTransferAllowed(connection.User.Username); err != nil {
connection.Log(logger.LevelInfo, "denying file write due to number of transfer limits")
sendAPIResponse(w, r, err, "Denying file write due to transfer count limits",
http.StatusConflict)
return
}

transferQuota := connection.GetTransferQuota()
if !transferQuota.HasUploadSpace() {

@@ -425,36 +432,42 @@ func (s *httpdServer) uploadFilesToShare(w http.ResponseWriter, r *http.Request)
}
}

func (s *httpdServer) getShareClaims(r *http.Request, shareID string) (context.Context, *jwtTokenClaims, error) {
token, err := jwtauth.VerifyRequest(s.tokenAuth, r, jwtauth.TokenFromCookie)
if err != nil || token == nil {
return nil, nil, errInvalidToken
}
tokenString := jwtauth.TokenFromCookie(r)
if tokenString == "" || invalidatedJWTTokens.Get(tokenString) {
return nil, nil, errInvalidToken
}
if !slices.Contains(token.Audience(), tokenAudienceWebShare) {
logger.Debug(logSender, "", "invalid token audience for share %q", shareID)
return nil, nil, errInvalidToken
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := validateIPForToken(token, ipAddr); err != nil {
logger.Debug(logSender, "", "token for share %q is not valid for the ip address %q", shareID, ipAddr)
return nil, nil, err
}
ctx := jwtauth.NewContext(r.Context(), token, nil)
claims, err := getTokenClaims(r.WithContext(ctx))
if err != nil || claims.Username != shareID {
logger.Debug(logSender, "", "token not valid for share %q", shareID)
return nil, nil, errInvalidToken
}
return ctx, &claims, nil
}

func (s *httpdServer) checkWebClientShareCredentials(w http.ResponseWriter, r *http.Request, share *dataprovider.Share) error {
doRedirect := func() {
redirectURL := path.Join(webClientPubSharesPath, share.ShareID, fmt.Sprintf("login?next=%s", url.QueryEscape(r.RequestURI)))
http.Redirect(w, r, redirectURL, http.StatusFound)
}

token, err := jwtauth.VerifyRequest(s.tokenAuth, r, jwtauth.TokenFromCookie)
if err != nil || token == nil {
if _, _, err := s.getShareClaims(r, share.ShareID); err != nil {
doRedirect()
return errInvalidToken
}
if !util.Contains(token.Audience(), tokenAudienceWebShare) {
logger.Debug(logSender, "", "invalid token audience for share %q", share.ShareID)
doRedirect()
return errInvalidToken
}
if tokenValidationMode != tokenValidationNoIPMatch {
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if !util.Contains(token.Audience(), ipAddr) {
logger.Debug(logSender, "", "token for share %q is not valid for the ip address %q", share.ShareID, ipAddr)
doRedirect()
return errInvalidToken
}
}
ctx := jwtauth.NewContext(r.Context(), token, nil)
claims, err := getTokenClaims(r.WithContext(ctx))
if err != nil || claims.Username != share.ShareID {
logger.Debug(logSender, "", "token not valid for share %q", share.ShareID)
doRedirect()
return errInvalidToken
return err
}
return nil
}

@@ -480,7 +493,7 @@ func (s *httpdServer) checkPublicShare(w http.ResponseWriter, r *http.Request, v
renderError(err, "", statusCode)
return share, nil, err
}
if !util.Contains(validScopes, share.Scope) {
if !slices.Contains(validScopes, share.Scope) {
err := errors.New("invalid share scope")
renderError(util.NewI18nError(err, util.I18nErrorShareScope), "", http.StatusForbidden)
return share, nil, err

@@ -537,7 +550,7 @@ func getUserForShare(share dataprovider.Share) (dataprovider.User, error) {
if !user.CanManageShares() {
return user, util.NewI18nError(util.NewRecordNotFoundError("this share does not exist"), util.I18nError404Message)
}
if share.Password == "" && util.Contains(user.Filters.WebClient, sdk.WebClientShareNoPasswordDisabled) {
if share.Password == "" && slices.Contains(user.Filters.WebClient, sdk.WebClientShareNoPasswordDisabled) {
return user, util.NewI18nError(
fmt.Errorf("sharing without a password was disabled: %w", os.ErrPermission),
util.I18nError403Message,
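The new getShareClaims helper above centralizes the checks that were previously inlined in checkWebClientShareCredentials: the cookie token must carry the WebShare audience, optionally match the client IP, and be bound to the share it is presented for. A simplified model of that validation order, using plain data instead of the real JWT types (illustrative only):

```go
package example

import (
	"errors"
	"slices"
)

var errInvalidToken = errors.New("invalid token")

// shareClaims is a reduced stand-in for the decoded share token claims.
type shareClaims struct {
	Audience []string // audience entries, including the client IP when IP binding is on
	Username string   // set to the share ID when the token was issued for a share
}

// validateShareToken mirrors the checks performed by getShareClaims above.
func validateShareToken(c shareClaims, shareID, ip string, requireIPMatch bool) error {
	if !slices.Contains(c.Audience, "WebShare") {
		return errInvalidToken
	}
	if requireIPMatch && !slices.Contains(c.Audience, ip) {
		return errInvalidToken
	}
	if c.Username != shareID {
		return errInvalidToken
	}
	return nil
}
```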
@@ -278,6 +278,9 @@ func updateEncryptedSecrets(fsConfig *vfs.Filesystem, currentFsConfig *vfs.Files
if fsConfig.S3Config.AccessSecret.IsNotPlainAndNotEmpty() {
fsConfig.S3Config.AccessSecret = currentFsConfig.S3Config.AccessSecret
}
if fsConfig.S3Config.SSECustomerKey.IsNotPlainAndNotEmpty() {
fsConfig.S3Config.SSECustomerKey = currentFsConfig.S3Config.SSECustomerKey
}
case sdk.AzureBlobFilesystemProvider:
if fsConfig.AzBlobConfig.AccountKey.IsNotPlainAndNotEmpty() {
fsConfig.AzBlobConfig.AccountKey = currentFsConfig.AzBlobConfig.AccountKey
@@ -27,6 +27,7 @@ import (
"net/url"
"os"
"path"
"slices"
"strconv"
"strings"
"sync"

@@ -71,8 +72,9 @@ type adminProfile struct {

type userProfile struct {
baseProfile
PublicKeys []string `json:"public_keys,omitempty"`
TLSCerts []string `json:"tls_certs,omitempty"`
AdditionalEmails []string `json:"additional_emails,omitempty"`
PublicKeys []string `json:"public_keys,omitempty"`
TLSCerts []string `json:"tls_certs,omitempty"`
}

func sendAPIResponse(w http.ResponseWriter, r *http.Request, err error, message string, code int) {

@@ -363,6 +365,16 @@ func streamJSONArray(w http.ResponseWriter, chunkSize int, dataGetter func(limit
streamData(w, []byte("]"))
}

func renderPNGImage(w http.ResponseWriter, r *http.Request, b []byte) {
if len(b) == 0 {
ctx := context.WithValue(r.Context(), render.StatusCtxKey, http.StatusNotFound)
render.PlainText(w, r.WithContext(ctx), http.StatusText(http.StatusNotFound))
return
}
w.Header().Set("Content-Type", "image/png")
streamData(w, b)
}

func getCompressedFileName(username string, files []string) string {
if len(files) == 1 {
name := path.Base(files[0])

@@ -691,7 +703,7 @@ func handleDefenderEventLoginFailed(ipAddr string, err error) error {
return err
}

func updateLoginMetrics(user *dataprovider.User, loginMethod, ip string, err error) {
func updateLoginMetrics(user *dataprovider.User, loginMethod, ip string, err error, r *http.Request) {
metric.AddLoginAttempt(loginMethod)
var protocol string
switch loginMethod {

@@ -701,6 +713,7 @@ func updateLoginMetrics(user *dataprovider.User, loginMethod, ip string, err err
protocol = common.ProtocolHTTP
}
if err == nil {
logger.LoginLog(user.Username, ip, loginMethod, protocol, "", r.UserAgent(), r.TLS != nil, "")
plugin.Handler.NotifyLogEvent(notifier.LogEventTypeLoginOK, protocol, user.Username, ip, "", nil)
common.DelayLogin(nil)
} else if err != common.ErrInternalFailure && err != common.ErrNoCredentials {

@@ -717,7 +730,7 @@ func updateLoginMetrics(user *dataprovider.User, loginMethod, ip string, err err
}

func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID string, checkSessions bool) error {
if util.Contains(user.Filters.DeniedProtocols, common.ProtocolHTTP) {
if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolHTTP) {
logger.Info(logSender, connectionID, "cannot login user %q, protocol HTTP is not allowed", user.Username)
return util.NewI18nError(
fmt.Errorf("protocol HTTP is not allowed for user %q", user.Username),

@@ -775,7 +788,8 @@ func getActiveUser(username string, r *http.Request) (dataprovider.User, error)
}

func handleForgotPassword(r *http.Request, username string, isAdmin bool) error {
var email, subject string
var emails []string
var subject string
var err error
var admin dataprovider.Admin
var user dataprovider.User

@@ -785,11 +799,13 @@ func handleForgotPassword(r *http.Request, username string, isAdmin bool) error
}
if isAdmin {
admin, err = getActiveAdmin(username, util.GetIPFromRemoteAddress(r.RemoteAddr))
email = admin.Email
if admin.Email != "" {
emails = []string{admin.Email}
}
subject = fmt.Sprintf("Email Verification Code for admin %q", username)
} else {
user, err = getActiveUser(username, r)
email = user.Email
emails = user.GetEmailAddresses()
subject = fmt.Sprintf("Email Verification Code for user %q", username)
if err == nil {
if !isUserAllowedToResetPassword(r, &user) {

@@ -810,7 +826,7 @@ func handleForgotPassword(r *http.Request, username string, isAdmin bool) error
}
return util.NewI18nError(util.NewGenericError("Error retrieving your account, please try again later"), util.I18nErrorGetUser)
}
if email == "" {
if len(emails) == 0 {
return util.NewI18nError(
util.NewValidationError("Your account does not have an email address, it is not possible to reset your password by sending an email verification code"),
util.I18nErrorPwdResetNoEmail,

@@ -825,7 +841,7 @@ func handleForgotPassword(r *http.Request, username string, isAdmin bool) error
return util.NewGenericError("Unable to render password reset template")
}
startTime := time.Now()
if err := smtp.SendEmail([]string{email}, nil, subject, body.String(), smtp.EmailContentTypeTextHTML); err != nil {
if err := smtp.SendEmail(emails, nil, subject, body.String(), smtp.EmailContentTypeTextHTML); err != nil {
logger.Warn(logSender, middleware.GetReqID(r.Context()), "unable to send password reset code via email: %v, elapsed: %v",
err, time.Since(startTime))
return util.NewI18nError(

@@ -833,8 +849,8 @@ func handleForgotPassword(r *http.Request, username string, isAdmin bool) error
util.I18nErrorPwdResetSendEmail,
)
}
logger.Debug(logSender, middleware.GetReqID(r.Context()), "reset code sent via email to %q, email: %q, is admin? %v, elapsed: %v",
username, email, isAdmin, time.Since(startTime))
logger.Debug(logSender, middleware.GetReqID(r.Context()), "reset code sent via email to %q, emails: %+v, is admin? %v, elapsed: %v",
username, emails, isAdmin, time.Since(startTime))
return resetCodesMgr.Add(c)
}

@@ -902,7 +918,7 @@ func isUserAllowedToResetPassword(r *http.Request, user *dataprovider.User) bool
if !user.CanResetPassword() {
return false
}
if util.Contains(user.Filters.DeniedProtocols, common.ProtocolHTTP) {
if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolHTTP) {
return false
}
if !user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolHTTP) {
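handleForgotPassword now collects all of the user's addresses via GetEmailAddresses and passes the slice straight to smtp.SendEmail, instead of a single email string. The behavior suggested by the first hunk of this diff (slices.Concat of the primary address and the additional ones) can be sketched as follows; this is an illustration, not the repository's exact helper:

```go
package example

import "slices"

// emailAddresses mirrors the pattern hinted at by GetEmailAddresses above:
// the primary address, if any, followed by the additional addresses, as a
// single slice suitable for an SMTP "to" list. Requires Go 1.22 for slices.Concat.
func emailAddresses(primary string, additional []string) []string {
	var res []string
	if primary != "" {
		res = append(res, primary)
	}
	return slices.Concat(res, additional)
}
```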
@ -18,6 +18,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/jwtauth/v5"
|
||||
|
@ -41,13 +42,13 @@ const (
|
|||
tokenAudienceAPIUser tokenAudience = "APIUser"
|
||||
tokenAudienceCSRF tokenAudience = "CSRF"
|
||||
tokenAudienceOAuth2 tokenAudience = "OAuth2"
|
||||
tokenAudienceWebLogin tokenAudience = "WebLogin"
|
||||
)
|
||||
|
||||
type tokenValidation = int
|
||||
|
||||
const (
|
||||
tokenValidationFull = iota
|
||||
tokenValidationNoIPMatch tokenValidation = iota
|
||||
tokenValidationModeDefault = 0
|
||||
tokenValidationModeNoIPMatch = 1
|
||||
tokenValidationModeUserSignature = 2
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -60,20 +61,74 @@ const (
|
|||
claimMustSetSecondFactorKey = "2fa_required"
|
||||
claimRequiredTwoFactorProtocols = "2fa_protos"
|
||||
claimHideUserPageSection = "hus"
|
||||
claimRef = "ref"
|
||||
basicRealm = "Basic realm=\"SFTPGo\""
|
||||
jwtCookieKey = "jwt"
|
||||
)
|
||||
|
||||
var (
|
||||
tokenDuration = 20 * time.Minute
|
||||
shareTokenDuration = 2 * time.Hour
|
||||
apiTokenDuration = 20 * time.Minute
|
||||
cookieTokenDuration = 20 * time.Minute
|
||||
shareTokenDuration = 2 * time.Hour
|
||||
// csrf token duration is greater than normal token duration to reduce issues
|
||||
// with the login form
|
||||
csrfTokenDuration = 6 * time.Hour
|
||||
tokenRefreshThreshold = 10 * time.Minute
|
||||
tokenValidationMode = tokenValidationFull
|
||||
csrfTokenDuration = 4 * time.Hour
|
||||
cookieRefreshThreshold = 10 * time.Minute
|
||||
maxTokenDuration = 12 * time.Hour
|
||||
tokenValidationMode = tokenValidationModeDefault
|
||||
)
|
||||
|
||||
func isTokenDurationValid(minutes int) bool {
|
||||
return minutes >= 1 && minutes <= 720
|
||||
}
|
||||
|
||||
func updateTokensDuration(api, cookie, share int) {
|
||||
if isTokenDurationValid(api) {
|
||||
apiTokenDuration = time.Duration(api) * time.Minute
|
||||
}
|
||||
if isTokenDurationValid(cookie) {
|
||||
cookieTokenDuration = time.Duration(cookie) * time.Minute
|
||||
cookieRefreshThreshold = cookieTokenDuration / 2
|
||||
if cookieTokenDuration > csrfTokenDuration {
|
||||
csrfTokenDuration = cookieTokenDuration
|
||||
}
|
||||
}
|
||||
if isTokenDurationValid(share) {
|
||||
shareTokenDuration = time.Duration(share) * time.Minute
|
||||
}
|
||||
logger.Debug(logSender, "", "API token duration %s, cookie token duration %s, cookie refresh threshold %s, share token duration %s",
|
||||
apiTokenDuration, cookieTokenDuration, cookieRefreshThreshold, shareTokenDuration)
|
||||
}
|
||||
|
||||
func getTokenDuration(audience tokenAudience) time.Duration {
|
||||
switch audience {
|
||||
case tokenAudienceWebShare:
|
||||
return shareTokenDuration
|
||||
case tokenAudienceWebLogin, tokenAudienceCSRF:
|
||||
return csrfTokenDuration
|
||||
case tokenAudienceAPI, tokenAudienceAPIUser:
|
||||
return apiTokenDuration
|
||||
case tokenAudienceWebAdmin, tokenAudienceWebClient:
|
||||
return cookieTokenDuration
|
||||
case tokenAudienceWebAdminPartial, tokenAudienceWebClientPartial, tokenAudienceOAuth2:
|
||||
return 5 * time.Minute
|
||||
default:
|
||||
logger.Error(logSender, "", "token duration not handled for audience: %q", audience)
|
||||
return 20 * time.Minute
|
||||
}
|
||||
}
|
||||
|
||||
func getMaxCookieDuration() time.Duration {
|
||||
result := csrfTokenDuration
|
||||
if shareTokenDuration > result {
|
||||
result = shareTokenDuration
|
||||
}
|
||||
if cookieTokenDuration > result {
|
||||
result = cookieTokenDuration
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type jwtTokenClaims struct {
|
||||
Username string
|
||||
Permissions []string
|
||||
|
@ -86,6 +141,9 @@ type jwtTokenClaims struct {
|
|||
MustChangePassword bool
|
||||
RequiredTwoFactorProtocols []string
|
||||
HideUserPageSections int
|
||||
JwtID string
|
||||
JwtIssuedAt time.Time
|
||||
Ref string
|
||||
}
|
||||
|
||||
func (c *jwtTokenClaims) hasUserAudience() bool {
|
||||
|
@ -103,6 +161,15 @@ func (c *jwtTokenClaims) asMap() map[string]any {
|
|||
|
||||
claims[claimUsernameKey] = c.Username
|
||||
claims[claimPermissionsKey] = c.Permissions
|
||||
if c.JwtID != "" {
|
||||
claims[jwt.JwtIDKey] = c.JwtID
|
||||
}
|
||||
if !c.JwtIssuedAt.IsZero() {
|
||||
claims[jwt.IssuedAtKey] = c.JwtIssuedAt
|
||||
}
|
||||
if c.Ref != "" {
|
||||
claims[claimRef] = c.Ref
|
||||
}
|
||||
if c.Role != "" {
|
||||
claims[claimRole] = c.Role
|
||||
}
|
||||
|
@ -169,6 +236,7 @@ func (c *jwtTokenClaims) Decode(token map[string]any) {
|
|||
c.Permissions = nil
|
||||
c.Username = c.decodeString(token[claimUsernameKey])
|
||||
c.Signature = c.decodeString(token[jwt.SubjectKey])
|
||||
c.JwtID = c.decodeString(token[jwt.JwtIDKey])
|
||||
|
||||
audience := token[jwt.AudienceKey]
|
||||
switch v := audience.(type) {
|
||||
|
@ -176,6 +244,10 @@ func (c *jwtTokenClaims) Decode(token map[string]any) {
|
|||
c.Audience = v
|
||||
}
|
||||
|
||||
if val, ok := token[claimRef]; ok {
|
||||
c.Ref = c.decodeString(val)
|
||||
}
|
||||
|
||||
if val, ok := token[claimAPIKey]; ok {
|
||||
c.APIKeyID = c.decodeString(val)
|
||||
}
|
||||
|
@ -212,20 +284,25 @@ func (c *jwtTokenClaims) Decode(token map[string]any) {
|
|||
}
|
||||
|
||||
func (c *jwtTokenClaims) hasPerm(perm string) bool {
|
||||
if util.Contains(c.Permissions, dataprovider.PermAdminAny) {
|
||||
if slices.Contains(c.Permissions, dataprovider.PermAdminAny) {
|
||||
return true
|
||||
}
|
||||
|
||||
return util.Contains(c.Permissions, perm)
|
||||
return slices.Contains(c.Permissions, perm)
|
||||
}
|
||||
|
||||
func (c *jwtTokenClaims) createToken(tokenAuth *jwtauth.JWTAuth, audience tokenAudience, ip string) (jwt.Token, string, error) {
|
||||
claims := c.asMap()
|
||||
now := time.Now().UTC()
|
||||
|
||||
claims[jwt.JwtIDKey] = xid.New().String()
|
||||
if _, ok := claims[jwt.JwtIDKey]; !ok {
|
||||
claims[jwt.JwtIDKey] = xid.New().String()
|
||||
}
|
||||
if _, ok := claims[jwt.IssuedAtKey]; !ok {
|
||||
claims[jwt.IssuedAtKey] = now
|
||||
}
|
||||
claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second)
|
||||
claims[jwt.ExpirationKey] = now.Add(tokenDuration)
|
||||
claims[jwt.ExpirationKey] = now.Add(getTokenDuration(audience))
|
||||
claims[jwt.AudienceKey] = []string{audience, ip}
|
||||
|
||||
return tokenAuth.Encode(claims)
|
||||
|
@ -257,25 +334,26 @@ func (c *jwtTokenClaims) createAndSetCookie(w http.ResponseWriter, r *http.Reque
|
|||
} else {
|
||||
basePath = webBaseClientPath
|
||||
}
|
||||
duration := tokenDuration
|
||||
if audience == tokenAudienceWebShare {
|
||||
duration = shareTokenDuration
|
||||
}
|
||||
setCookie(w, r, basePath, resp["access_token"].(string), getTokenDuration(audience))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setCookie(w http.ResponseWriter, r *http.Request, cookiePath, cookieValue string, duration time.Duration) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: jwtCookieKey,
|
||||
Value: resp["access_token"].(string),
|
||||
Path: basePath,
|
||||
Value: cookieValue,
|
||||
Path: cookiePath,
|
||||
Expires: time.Now().Add(duration),
|
||||
MaxAge: int(duration / time.Second),
|
||||
HttpOnly: true,
|
||||
Secure: isTLS(r),
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *jwtTokenClaims) removeCookie(w http.ResponseWriter, r *http.Request, cookiePath string) {
|
||||
func removeCookie(w http.ResponseWriter, r *http.Request, cookiePath string) {
|
||||
invalidateToken(r)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: jwtCookieKey,
|
||||
Value: "",
|
||||
|
@ -287,10 +365,9 @@ func (c *jwtTokenClaims) removeCookie(w http.ResponseWriter, r *http.Request, co
|
|||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
w.Header().Add("Cache-Control", `no-cache="Set-Cookie"`)
|
||||
invalidateToken(r)
|
||||
}
|
||||
|
||||
func tokenFromContext(r *http.Request) string {
|
||||
func oidcTokenFromContext(r *http.Request) string {
|
||||
if token, ok := r.Context().Value(oidcGeneratedToken).(string); ok {
|
||||
return token
|
||||
}
|
||||
|
@ -311,7 +388,7 @@ func isTokenInvalidated(r *http.Request) bool {
|
|||
var findTokenFns []func(r *http.Request) string
|
||||
findTokenFns = append(findTokenFns, jwtauth.TokenFromHeader)
|
||||
findTokenFns = append(findTokenFns, jwtauth.TokenFromCookie)
|
||||
findTokenFns = append(findTokenFns, tokenFromContext)
|
||||
findTokenFns = append(findTokenFns, oidcTokenFromContext)
|
||||
|
||||
isTokenFound := false
|
||||
for _, fn := range findTokenFns {
|
||||
|
@ -330,14 +407,23 @@ func isTokenInvalidated(r *http.Request) bool {
|
|||
func invalidateToken(r *http.Request) {
|
||||
tokenString := jwtauth.TokenFromHeader(r)
|
||||
if tokenString != "" {
|
||||
invalidatedJWTTokens.Add(tokenString, time.Now().Add(tokenDuration).UTC())
|
||||
invalidateTokenString(r, tokenString, apiTokenDuration)
|
||||
}
|
||||
tokenString = jwtauth.TokenFromCookie(r)
|
||||
if tokenString != "" {
|
||||
invalidatedJWTTokens.Add(tokenString, time.Now().Add(tokenDuration).UTC())
|
||||
invalidateTokenString(r, tokenString, getMaxCookieDuration())
|
||||
}
|
||||
}
|
||||
|
||||
func invalidateTokenString(r *http.Request, tokenString string, fallbackDuration time.Duration) {
|
||||
token, _, err := jwtauth.FromContext(r.Context())
|
||||
if err != nil || token == nil {
|
||||
invalidatedJWTTokens.Add(tokenString, time.Now().Add(fallbackDuration).UTC())
|
||||
return
|
||||
}
|
||||
invalidatedJWTTokens.Add(tokenString, token.Expiration().Add(1*time.Minute).UTC())
|
||||
}
|
||||
|
||||
func getUserFromToken(r *http.Request) *dataprovider.User {
|
||||
user := &dataprovider.User{}
|
||||
_, claims, err := jwtauth.FromContext(r.Context())
|
||||
|
@ -367,15 +453,40 @@ func getAdminFromToken(r *http.Request) *dataprovider.Admin {
|
|||
return admin
|
||||
}
|
||||
|
||||
func createCSRFToken(ip string) string {
|
||||
func createLoginCookie(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwtauth.JWTAuth, tokenID, basePath, ip string,
|
||||
) {
|
||||
c := jwtTokenClaims{
|
||||
JwtID: tokenID,
|
||||
}
|
||||
resp, err := c.createTokenResponse(csrfTokenAuth, tokenAudienceWebLogin, ip)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
setCookie(w, r, basePath, resp["access_token"].(string), csrfTokenDuration)
|
||||
}
|
||||
|
||||
func createCSRFToken(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwtauth.JWTAuth, tokenID,
|
||||
basePath string,
|
||||
) string {
|
||||
ip := util.GetIPFromRemoteAddress(r.RemoteAddr)
|
||||
claims := make(map[string]any)
|
||||
now := time.Now().UTC()
|
||||
|
||||
claims[jwt.JwtIDKey] = xid.New().String()
|
||||
claims[jwt.IssuedAtKey] = now
|
||||
claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second)
|
||||
claims[jwt.ExpirationKey] = now.Add(csrfTokenDuration)
|
||||
claims[jwt.AudienceKey] = []string{tokenAudienceCSRF, ip}
|
||||
|
||||
if tokenID != "" {
|
||||
createLoginCookie(w, r, csrfTokenAuth, tokenID, basePath, ip)
|
||||
claims[claimRef] = tokenID
|
||||
} else {
|
||||
if c, err := getTokenClaims(r); err == nil {
|
||||
claims[claimRef] = c.JwtID
|
||||
} else {
|
||||
logger.Error(logSender, "", "unable to add reference to CSRF token: %v", err)
|
||||
}
|
||||
}
|
||||
_, tokenString, err := csrfTokenAuth.Encode(claims)
|
||||
if err != nil {
|
||||
logger.Debug(logSender, "", "unable to create CSRF token: %v", err)
|
||||
|
@ -384,35 +495,84 @@ func createCSRFToken(ip string) string {
|
|||
return tokenString
|
||||
}
|
||||
|
||||
func verifyCSRFToken(tokenString, ip string) error {
|
||||
func verifyCSRFToken(r *http.Request, csrfTokenAuth *jwtauth.JWTAuth) error {
|
||||
tokenString := r.Form.Get(csrfFormToken)
|
||||
token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString)
|
||||
if err != nil || token == nil {
|
||||
logger.Debug(logSender, "", "error validating CSRF token %q: %v", tokenString, err)
|
||||
return fmt.Errorf("unable to verify form token: %v", err)
|
||||
}
|
||||
|
||||
if !util.Contains(token.Audience(), tokenAudienceCSRF) {
|
||||
if !slices.Contains(token.Audience(), tokenAudienceCSRF) {
|
||||
logger.Debug(logSender, "", "error validating CSRF token audience")
|
||||
return errors.New("the form token is not valid")
|
||||
}
|
||||
|
||||
if tokenValidationMode != tokenValidationNoIPMatch {
|
||||
if !util.Contains(token.Audience(), ip) {
|
||||
logger.Debug(logSender, "", "error validating CSRF token IP audience")
return errors.New("the form token is not valid")
}
if err := validateIPForToken(token, util.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil {
logger.Debug(logSender, "", "error validating CSRF token IP audience")
return errors.New("the form token is not valid")
}
return checkCSRFTokenRef(r, token)
}

func checkCSRFTokenRef(r *http.Request, token jwt.Token) error {
claims, err := getTokenClaims(r)
if err != nil {
logger.Debug(logSender, "", "error getting token claims for CSRF validation: %v", err)
return err
}
ref, ok := token.Get(claimRef)
if !ok {
logger.Debug(logSender, "", "error validating CSRF token, missing reference")
return errors.New("the form token is not valid")
}
if claims.JwtID == "" || claims.JwtID != ref.(string) {
logger.Debug(logSender, "", "error validating CSRF reference, id %q, reference %q", claims.JwtID, ref)
return errors.New("unexpected form token")
}

return nil
}

func createOAuth2Token(state, ip string) string {
func verifyLoginCookie(r *http.Request) error {
token, _, err := jwtauth.FromContext(r.Context())
if err != nil || token == nil {
logger.Debug(logSender, "", "error getting login token: %v", err)
return errInvalidToken
}
if isTokenInvalidated(r) {
logger.Debug(logSender, "", "the login token has been invalidated")
return errInvalidToken
}
if !slices.Contains(token.Audience(), tokenAudienceWebLogin) {
logger.Debug(logSender, "", "the token with id %q is not valid for audience %q", token.JwtID(), tokenAudienceWebLogin)
return errInvalidToken
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := validateIPForToken(token, ipAddr); err != nil {
return err
}
return nil
}

func verifyLoginCookieAndCSRFToken(r *http.Request, csrfTokenAuth *jwtauth.JWTAuth) error {
if err := verifyLoginCookie(r); err != nil {
return err
}
if err := verifyCSRFToken(r, csrfTokenAuth); err != nil {
return err
}
return nil
}

func createOAuth2Token(csrfTokenAuth *jwtauth.JWTAuth, state, ip string) string {
claims := make(map[string]any)
now := time.Now().UTC()

claims[jwt.JwtIDKey] = state
claims[jwt.IssuedAtKey] = now
claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second)
claims[jwt.ExpirationKey] = now.Add(3 * time.Minute)
claims[jwt.ExpirationKey] = now.Add(getTokenDuration(tokenAudienceOAuth2))
claims[jwt.AudienceKey] = []string{tokenAudienceOAuth2, ip}

_, tokenString, err := csrfTokenAuth.Encode(claims)

@@ -423,7 +583,7 @@ func createOAuth2Token(state, ip string) string {
return tokenString
}

func verifyOAuth2Token(tokenString, ip string) (string, error) {
func verifyOAuth2Token(csrfTokenAuth *jwtauth.JWTAuth, tokenString, ip string) (string, error) {
token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString)
if err != nil || token == nil {
logger.Debug(logSender, "", "error validating OAuth2 token %q: %v", tokenString, err)

@@ -433,16 +593,14 @@ func verifyOAuth2Token(tokenString, ip string) (string, error) {
)
}

if !util.Contains(token.Audience(), tokenAudienceOAuth2) {
if !slices.Contains(token.Audience(), tokenAudienceOAuth2) {
logger.Debug(logSender, "", "error validating OAuth2 token audience")
return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState)
}

if tokenValidationMode != tokenValidationNoIPMatch {
if !util.Contains(token.Audience(), ip) {
logger.Debug(logSender, "", "error validating OAuth2 token IP audience")
return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState)
}
if err := validateIPForToken(token, ip); err != nil {
logger.Debug(logSender, "", "error validating OAuth2 token IP audience")
return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState)
}
if val, ok := token.Get(jwt.JwtIDKey); ok {
if state, ok := val.(string); ok {

@@ -452,3 +610,53 @@ func verifyOAuth2Token(tokenString, ip string) (string, error) {
logger.Debug(logSender, "", "jti not found in OAuth2 token")
return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState)
}
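For reference, a minimal standalone sketch of the encode/verify round trip these two functions implement with go-chi/jwtauth; the signing key, audience strings and lifetime here are illustrative values, not taken from SFTPGo:

package main

import (
	"fmt"
	"time"

	"github.com/go-chi/jwtauth/v5"
	"github.com/lestrrat-go/jwx/v2/jwt"
)

func main() {
	auth := jwtauth.New("HS256", []byte("demo-signing-key"), nil)
	now := time.Now().UTC()

	// Encode a short-lived state token bound to the client IP via the audience claim.
	claims := map[string]any{
		jwt.JwtIDKey:      "random-state",
		jwt.IssuedAtKey:   now,
		jwt.NotBeforeKey:  now.Add(-30 * time.Second),
		jwt.ExpirationKey: now.Add(3 * time.Minute),
		jwt.AudienceKey:   []string{"OAuth2", "127.0.0.1"},
	}
	_, tokenString, err := auth.Encode(claims)
	if err != nil {
		panic(err)
	}

	// Verify the signature and standard claims, then read back the state (jti).
	token, err := jwtauth.VerifyToken(auth, tokenString)
	if err != nil {
		panic(err)
	}
	if val, ok := token.Get(jwt.JwtIDKey); ok {
		fmt.Println("state:", val.(string))
	}
}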
func validateIPForToken(token jwt.Token, ip string) error {
if tokenValidationMode&tokenValidationModeNoIPMatch == 0 {
if !slices.Contains(token.Audience(), ip) {
return errInvalidToken
}
}
return nil
}

func checkTokenSignature(r *http.Request, token jwt.Token) error {
if _, ok := r.Context().Value(oidcTokenKey).(string); ok {
return nil
}
var err error
if tokenValidationMode&tokenValidationModeUserSignature != 0 {
for _, audience := range token.Audience() {
switch audience {
case tokenAudienceAPI, tokenAudienceWebAdmin:
err = validateSignatureForToken(token, dataprovider.GetAdminSignature)
case tokenAudienceAPIUser, tokenAudienceWebClient:
err = validateSignatureForToken(token, dataprovider.GetUserSignature)
}
}
}
if err != nil {
invalidateToken(r)
}
return err
}

func validateSignatureForToken(token jwt.Token, getter func(string) (string, error)) error {
username := ""
if u, ok := token.Get(claimUsernameKey); ok {
c := jwtTokenClaims{}
username = c.decodeString(u)
}

signature, err := getter(username)
if err != nil {
logger.Debug(logSender, "", "unable to get signature for username %q: %v", username, err)
return errInvalidToken
}
if signature != "" && signature == token.Subject() {
return nil
}
logger.Debug(logSender, "", "signature mismatch for username %q, signature %q, token signature %q",
username, signature, token.Subject())
return errInvalidToken
}
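tokenValidationMode is now treated as a bitmask, so the IP check and the per-user signature check can be toggled independently; a small self-contained sketch, with flag values chosen purely for illustration (only the names mirror the hunk above):

package main

import "fmt"

// Illustrative single-bit flags mirroring the names used above.
const (
	tokenValidationModeDefault       = 0
	tokenValidationModeNoIPMatch     = 1 // skip the IP audience check
	tokenValidationModeUserSignature = 2 // also verify the per-user signature
)

func main() {
	mode := tokenValidationModeNoIPMatch | tokenValidationModeUserSignature

	// validateIPForToken only enforces the IP audience when the NoIPMatch bit is unset.
	fmt.Println("check IP:", mode&tokenValidationModeNoIPMatch == 0)
	// checkTokenSignature only runs when the UserSignature bit is set.
	fmt.Println("check signature:", mode&tokenValidationModeUserSignature != 0)
}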
@@ -97,6 +97,11 @@ func (c *Connection) ReadDir(name string) (vfs.DirLister, error) {
func (c *Connection) getFileReader(name string, offset int64, method string) (io.ReadCloser, error) {
c.UpdateLastActivity()

if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil {
c.Log(logger.LevelInfo, "denying file read due to transfer count limits")
return nil, util.NewI18nError(c.GetPermissionDeniedError(), util.I18nError403Message)
}

transferQuota := c.GetTransferQuota()
if !transferQuota.HasDownloadSpace() {
c.Log(logger.LevelInfo, "denying file read due to quota limits")

@@ -176,7 +181,7 @@ func (c *Connection) getFileWriter(name string) (io.WriteCloser, error) {
}

if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
_, _, err = fs.Rename(p, filePath)
_, _, err = fs.Rename(p, filePath, 0)
if err != nil {
c.Log(logger.LevelError, "error renaming existing file for atomic upload, source: %q, dest: %q, err: %+v",
p, filePath, err)

@@ -188,6 +193,10 @@ func (c *Connection) getFileWriter(name string) (io.WriteCloser, error) {
}

func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, requestPath string, isNewFile bool, fileSize int64) (io.WriteCloser, error) {
if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil {
c.Log(logger.LevelInfo, "denying file write due to transfer count limits")
return nil, util.NewI18nError(c.GetPermissionDeniedError(), util.I18nError403Message)
}
diskQuota, transferQuota := c.HasSpace(isNewFile, false, requestPath)
if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() {
c.Log(logger.LevelInfo, "denying file write due to quota limits")

@@ -213,10 +222,7 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request
if vfs.HasTruncateSupport(fs) {
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
if err == nil {
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
if vfolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
}
dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -fileSize, false)
} else {
dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
}
@@ -28,11 +28,10 @@ import (
"path/filepath"
"runtime"
"strings"
"sync"
"time"

"github.com/go-chi/chi/v5"
"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/v2/jwa"

"github.com/drakkan/sftpgo/v2/internal/acme"
"github.com/drakkan/sftpgo/v2/internal/common"

@@ -196,7 +195,6 @@ var (
cleanupTicker *time.Ticker
cleanupDone chan bool
invalidatedJWTTokens tokenManager
csrfTokenAuth *jwtauth.JWTAuth
webRootPath string
webBasePath string
webBaseAdminPath string

@@ -288,12 +286,88 @@ var (
installationCodeHint string
fnInstallationCodeResolver FnInstallationCodeResolver
configurationDir string
dbBrandingConfig brandingCache
)

func init() {
updateWebAdminURLs("")
updateWebClientURLs("")
acme.SetReloadHTTPDCertsFn(ReloadCertificateMgr)
common.SetUpdateBrandingFn(dbBrandingConfig.Set)
}

type brandingCache struct {
mu sync.RWMutex
configs *dataprovider.BrandingConfigs
}

func (b *brandingCache) Set(configs *dataprovider.BrandingConfigs) {
b.mu.Lock()
defer b.mu.Unlock()

b.configs = configs
}

func (b *brandingCache) getWebAdminLogo() []byte {
b.mu.RLock()
defer b.mu.RUnlock()

return b.configs.WebAdmin.Logo
}

func (b *brandingCache) getWebAdminFavicon() []byte {
b.mu.RLock()
defer b.mu.RUnlock()

return b.configs.WebAdmin.Favicon
}

func (b *brandingCache) getWebClientLogo() []byte {
b.mu.RLock()
defer b.mu.RUnlock()

return b.configs.WebClient.Logo
}

func (b *brandingCache) getWebClientFavicon() []byte {
b.mu.RLock()
defer b.mu.RUnlock()

return b.configs.WebClient.Favicon
}

func (b *brandingCache) mergeBrandingConfig(branding UIBranding, isWebClient bool) UIBranding {
b.mu.RLock()
defer b.mu.RUnlock()

var urlPrefix string
var cfg dataprovider.BrandingConfig
if isWebClient {
cfg = b.configs.WebClient
urlPrefix = "webclient"
} else {
cfg = b.configs.WebAdmin
urlPrefix = "webadmin"
}
if cfg.Name != "" {
branding.Name = cfg.Name
}
if cfg.ShortName != "" {
branding.ShortName = cfg.ShortName
}
if cfg.DisclaimerName != "" {
branding.DisclaimerName = cfg.DisclaimerName
}
if cfg.DisclaimerURL != "" {
branding.DisclaimerPath = cfg.DisclaimerURL
}
if len(cfg.Logo) > 0 {
branding.LogoPath = path.Join("/", "branding", urlPrefix, "logo.png")
}
if len(cfg.Favicon) > 0 {
branding.FaviconPath = path.Join("/", "branding", urlPrefix, "favicon.png")
}
return branding
}
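A reduced sketch of the read-mostly cache pattern brandingCache relies on, with a simplified struct standing in for dataprovider.BrandingConfigs:

package main

import (
	"fmt"
	"sync"
)

// simplified stand-in for dataprovider.BrandingConfigs
type brandingConfigs struct {
	WebAdminName string
}

type brandingCache struct {
	mu      sync.RWMutex
	configs *brandingConfigs
}

// Set replaces the cached configuration; writers hold the exclusive lock briefly.
func (b *brandingCache) Set(configs *brandingConfigs) {
	b.mu.Lock()
	defer b.mu.Unlock()
	b.configs = configs
}

// getName reads under the shared lock so many requests can read concurrently.
func (b *brandingCache) getName() string {
	b.mu.RLock()
	defer b.mu.RUnlock()
	return b.configs.WebAdminName
}

func main() {
	c := &brandingCache{configs: &brandingConfigs{}}
	c.Set(&brandingConfigs{WebAdminName: "My SFTPGo"})
	fmt.Println(c.getName())
}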
// FnInstallationCodeResolver defines a method to get the installation code.

@@ -411,19 +485,23 @@ type UIBranding struct {
// the default CSS files
DefaultCSS []string `json:"default_css" mapstructure:"default_css"`
// Additional CSS file paths, relative to "static_files_path", to include
ExtraCSS []string `json:"extra_css" mapstructure:"extra_css"`
ExtraCSS []string `json:"extra_css" mapstructure:"extra_css"`
DefaultLogoPath string `json:"-" mapstructure:"-"`
DefaultFaviconPath string `json:"-" mapstructure:"-"`
}

func (b *UIBranding) check() {
b.DefaultLogoPath = "/img/logo.png"
b.DefaultFaviconPath = "/favicon.png"
if b.LogoPath != "" {
b.LogoPath = util.CleanPath(b.LogoPath)
} else {
b.LogoPath = "/img/logo.png"
b.LogoPath = b.DefaultLogoPath
}
if b.FaviconPath != "" {
b.FaviconPath = util.CleanPath(b.FaviconPath)
} else {
b.FaviconPath = "/favicon.ico"
b.FaviconPath = b.DefaultFaviconPath
}
if b.DisclaimerPath != "" {
if !strings.HasPrefix(b.DisclaimerPath, "https://") && !strings.HasPrefix(b.DisclaimerPath, "http://") {

@@ -508,6 +586,9 @@ type Binding struct {
TLSCipherSuites []string `json:"tls_cipher_suites" mapstructure:"tls_cipher_suites"`
// HTTP protocols in preference order. Supported values: http/1.1, h2
Protocols []string `json:"tls_protocols" mapstructure:"tls_protocols"`
// Defines whether to use the common proxy protocol configuration or the
// binding-specific proxy header configuration.
ProxyMode int `json:"proxy_mode" mapstructure:"proxy_mode"`
// List of IP addresses and IP ranges allowed to set client IP proxy headers and
// X-Forwarded-Proto header.
ProxyAllowed []string `json:"proxy_allowed" mapstructure:"proxy_allowed"`

@@ -527,6 +608,8 @@ type Binding struct {
HideLoginURL int `json:"hide_login_url" mapstructure:"hide_login_url"`
// Enable the built-in OpenAPI renderer
RenderOpenAPI bool `json:"render_openapi" mapstructure:"render_openapi"`
// Languages defines the list of enabled translations for the WebAdmin and WebClient UI.
Languages []string `json:"languages" mapstructure:"languages"`
// Defining an OIDC configuration the web admin and web client UI will use OpenID to authenticate users.
OIDC OIDC `json:"oidc" mapstructure:"oidc"`
// Security defines security headers to add to HTTP responses and allows to restrict allowed hosts

@@ -553,6 +636,18 @@ func (b *Binding) checkBranding() {
}
}

func (b *Binding) webAdminBranding() UIBranding {
return dbBrandingConfig.mergeBrandingConfig(b.Branding.WebAdmin, false)
}

func (b *Binding) webClientBranding() UIBranding {
return dbBrandingConfig.mergeBrandingConfig(b.Branding.WebClient, true)
}

func (b *Binding) languages() []string {
return b.Languages
}

func (b *Binding) parseAllowedProxy() error {
if filepath.IsAbs(b.Address) && len(b.ProxyAllowed) > 0 {
// unix domain socket

@@ -670,6 +765,13 @@ func (b *Binding) isMutualTLSEnabled() bool {
return b.ClientAuthType == 1
}

func (b *Binding) listenerWrapper() func(net.Listener) (net.Listener, error) {
if b.ProxyMode == 1 {
return common.Config.GetProxyListener
}
return nil
}

type defenderStatus struct {
IsActive bool `json:"is_active"`
}

@@ -763,6 +865,12 @@ type Conf struct {
// By default all the available security checks are enabled. Set to 1 to disable the requirement
// that a token must be used by the same IP for which it was issued.
TokenValidation int `json:"token_validation" mapstructure:"token_validation"`
// CookieLifetime defines the duration of cookies for WebAdmin and WebClient
CookieLifetime int `json:"cookie_lifetime" mapstructure:"cookie_lifetime"`
// ShareCookieLifetime defines the duration of cookies for public shares
ShareCookieLifetime int `json:"share_cookie_lifetime" mapstructure:"share_cookie_lifetime"`
// JWTLifetime defines the duration of JWT tokens used in REST API
JWTLifetime int `json:"jwt_lifetime" mapstructure:"jwt_lifetime"`
// MaxUploadFileSize Defines the maximum request body size, in bytes, for Web Client/API HTTP upload requests.
// 0 means no limit
MaxUploadFileSize int64 `json:"max_upload_file_size" mapstructure:"max_upload_file_size"`

@@ -871,11 +979,7 @@ func (c *Conf) getKeyPairs(configDir string) []common.TLSKeyPair {
}

func (c *Conf) setTokenValidationMode() {
if c.TokenValidation == 1 {
tokenValidationMode = tokenValidationNoIPMatch
} else {
tokenValidationMode = tokenValidationFull
}
tokenValidationMode = c.TokenValidation
}

func (c *Conf) loadFromProvider() error {

@@ -884,6 +988,7 @@ func (c *Conf) loadFromProvider() error {
return fmt.Errorf("unable to load config from provider: %w", err)
}
configs.SetNilsToEmpty()
dbBrandingConfig.Set(configs.Branding)
if configs.ACME.Domain == "" || !configs.ACME.HasProtocol(common.ProtocolHTTP) {
return nil
}

@@ -969,7 +1074,6 @@ func (c *Conf) Initialize(configDir string, isShared int) error {
c.SigningPassphrase = passphrase
}

csrfTokenAuth = jwtauth.New(jwa.HS256.String(), getSigningKey(c.SigningPassphrase), nil)
hideSupportLink = c.HideSupportLink

exitChannel := make(chan error, 1)

@@ -1003,7 +1107,8 @@ func (c *Conf) Initialize(configDir string, isShared int) error {
maxUploadFileSize = c.MaxUploadFileSize
installationCode = c.Setup.InstallationCode
installationCodeHint = c.Setup.InstallationCodeHint
startCleanupTicker(tokenDuration / 2)
updateTokensDuration(c.JWTLifetime, c.CookieLifetime, c.ShareCookieLifetime)
startCleanupTicker(10 * time.Minute)
c.setTokenValidationMode()
return <-exitChannel
}

@@ -1217,11 +1322,14 @@ func stopCleanupTicker() {
}

func getSigningKey(signingPassphrase string) []byte {
var key []byte
if signingPassphrase != "" {
sk := sha256.Sum256([]byte(signingPassphrase))
return sk[:]
key = []byte(signingPassphrase)
} else {
key = util.GenerateRandomBytes(32)
}
return util.GenerateRandomBytes(32)
sk := sha256.Sum256(key)
return sk[:]
}
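With this change getSigningKey always returns a SHA-256 digest, whether the input is the configured passphrase or freshly generated random bytes; a standalone sketch of the same derivation, using crypto/rand in place of util.GenerateRandomBytes:

package main

import (
	"crypto/rand"
	"crypto/sha256"
	"fmt"
)

// getSigningKey mirrors the logic above: hash the passphrase when one is
// configured, otherwise hash freshly generated random bytes.
func getSigningKey(signingPassphrase string) []byte {
	var key []byte
	if signingPassphrase != "" {
		key = []byte(signingPassphrase)
	} else {
		key = make([]byte, 32)
		if _, err := rand.Read(key); err != nil {
			panic(err)
		}
	}
	sk := sha256.Sum256(key)
	return sk[:]
}

func main() {
	fmt.Printf("%x\n", getSigningKey("my passphrase"))
}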

// SetInstallationCodeResolver sets a function to call to resolve the installation code

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large
@@ -20,10 +20,10 @@ import (
"io/fs"
"net/http"
"net/url"
"slices"
"strings"

"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/rs/xid"
"github.com/sftpgo/sdk"

@@ -75,12 +75,6 @@ func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudi
return errInvalidToken
}

err = jwt.Validate(token)
if err != nil {
logger.Debug(logSender, "", "error validating jwt token: %v", err)
doRedirect(http.StatusText(http.StatusUnauthorized), err)
return errInvalidToken
}
if isTokenInvalidated(r) {
logger.Debug(logSender, "", "the token has been invalidated")
doRedirect("Your token is no longer valid", nil)

@@ -90,18 +84,20 @@ func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudi
if err := checkPartialAuth(w, r, audience, token.Audience()); err != nil {
return err
}
if !util.Contains(token.Audience(), audience) {
if !slices.Contains(token.Audience(), audience) {
logger.Debug(logSender, "", "the token is not valid for audience %q", audience)
doRedirect("Your token audience is not valid", nil)
return errInvalidToken
}
if tokenValidationMode != tokenValidationNoIPMatch {
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if !util.Contains(token.Audience(), ipAddr) {
logger.Debug(logSender, "", "the token with id %q is not valid for the ip address %q", token.JwtID(), ipAddr)
doRedirect("Your token is not valid", nil)
return errInvalidToken
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := validateIPForToken(token, ipAddr); err != nil {
logger.Debug(logSender, "", "the token with id %q is not valid for the ip address %q", token.JwtID(), ipAddr)
doRedirect("Your token is not valid", nil)
return err
}
if err := checkTokenSignature(r, token); err != nil {
doRedirect("Your token is no longer valid", nil)
return err
}
return nil
}

@@ -114,7 +110,7 @@ func (s *httpdServer) validateJWTPartialToken(w http.ResponseWriter, r *http.Req
} else {
notFoundFunc = s.renderClientNotFoundPage
}
if err != nil || token == nil || jwt.Validate(token) != nil {
if err != nil || token == nil {
notFoundFunc(w, r, nil)
return errInvalidToken
}

@@ -122,11 +118,17 @@ func (s *httpdServer) validateJWTPartialToken(w http.ResponseWriter, r *http.Req
notFoundFunc(w, r, nil)
return errInvalidToken
}
if !util.Contains(token.Audience(), audience) {
logger.Debug(logSender, "", "the token is not valid for audience %q", audience)
if !slices.Contains(token.Audience(), audience) {
logger.Debug(logSender, "", "the partial token with id %q is not valid for audience %q", token.JwtID(), audience)
notFoundFunc(w, r, nil)
return errInvalidToken
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := validateIPForToken(token, ipAddr); err != nil {
logger.Debug(logSender, "", "the partial token with id %q is not valid for the ip address %q", token.JwtID(), ipAddr)
notFoundFunc(w, r, nil)
return err
}

return nil
}

@@ -295,7 +297,7 @@ func (s *httpdServer) requireBuiltinLogin(next http.Handler) http.Handler {
})
}

func (s *httpdServer) checkPerm(perm string) func(next http.Handler) http.Handler {
func (s *httpdServer) checkPerms(perms ...string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, claims, err := jwtauth.FromContext(r.Context())

@@ -310,13 +312,15 @@ func (s *httpdServer) checkPerm(perm string) func(next http.Handler) http.Handle
tokenClaims := jwtTokenClaims{}
tokenClaims.Decode(claims)

if !tokenClaims.hasPerm(perm) {
if isWebRequest(r) {
s.renderForbiddenPage(w, r, util.NewI18nError(fs.ErrPermission, util.I18nError403Message))
} else {
sendAPIResponse(w, r, nil, http.StatusText(http.StatusForbidden), http.StatusForbidden)
for _, perm := range perms {
if !tokenClaims.hasPerm(perm) {
if isWebRequest(r) {
s.renderForbiddenPage(w, r, util.NewI18nError(fs.ErrPermission, util.I18nError403Message))
} else {
sendAPIResponse(w, r, nil, http.StatusText(http.StatusForbidden), http.StatusForbidden)
}
return
}
return
}

next.ServeHTTP(w, r)

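checkPerms now takes a variadic list and requires every permission to pass; a hedged sketch of how such a middleware could be mounted on a chi router (the permission names, route and the hasPerm callback are illustrative, not SFTPGo APIs):

package main

import (
	"net/http"

	"github.com/go-chi/chi/v5"
)

// requireAll is a simplified stand-in for (*httpdServer).checkPerms: the request
// proceeds only if every listed permission passes the check.
func requireAll(hasPerm func(r *http.Request, perm string) bool, perms ...string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			for _, perm := range perms {
				if !hasPerm(r, perm) {
					http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
					return
				}
			}
			next.ServeHTTP(w, r)
		})
	}
}

func main() {
	allow := func(_ *http.Request, _ string) bool { return true } // demo check
	r := chi.NewRouter()
	r.Group(func(router chi.Router) {
		router.Use(requireAll(allow, "manage_users", "view_events"))
		router.Get("/admin/users", func(w http.ResponseWriter, _ *http.Request) {
			w.Write([]byte("ok")) //nolint:errcheck
		})
	})
	http.ListenAndServe(":8080", r) //nolint:errcheck
}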
@@ -324,28 +328,30 @@ func (s *httpdServer) checkPerm(perm string) func(next http.Handler) http.Handle
}
}

func verifyCSRFHeader(next http.Handler) http.Handler {
func (s *httpdServer) verifyCSRFHeader(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokenString := r.Header.Get(csrfHeaderToken)
token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString)
token, err := jwtauth.VerifyToken(s.csrfTokenAuth, tokenString)
if err != nil || token == nil {
logger.Debug(logSender, "", "error validating CSRF header: %v", err)
sendAPIResponse(w, r, err, "Invalid token", http.StatusForbidden)
return
}

if !util.Contains(token.Audience(), tokenAudienceCSRF) {
if !slices.Contains(token.Audience(), tokenAudienceCSRF) {
logger.Debug(logSender, "", "error validating CSRF header token audience")
sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden)
return
}

if tokenValidationMode != tokenValidationNoIPMatch {
if !util.Contains(token.Audience(), util.GetIPFromRemoteAddress(r.RemoteAddr)) {
logger.Debug(logSender, "", "error validating CSRF header IP audience")
sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden)
return
}
if err := validateIPForToken(token, util.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil {
logger.Debug(logSender, "", "error validating CSRF header IP audience")
sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden)
return
}
if err := checkCSRFTokenRef(r, token); err != nil {
sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden)
return
}

next.ServeHTTP(w, r)

@@ -449,7 +455,7 @@ func checkAPIKeyAuth(tokenAuth *jwtauth.JWTAuth, scope dataprovider.APIKeyScope)
logger.Debug(logSender, "", "unable to authenticate user %q associated with api key %q: %v",
apiUser, apiKey, err)
updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: apiUser}},
dataprovider.LoginMethodPassword, util.GetIPFromRemoteAddress(r.RemoteAddr), err)
dataprovider.LoginMethodPassword, util.GetIPFromRemoteAddress(r.RemoteAddr), err, r)
code := http.StatusUnauthorized
if errors.Is(err, common.ErrInternalFailure) {
code = http.StatusInternalServerError

@@ -459,7 +465,7 @@ func checkAPIKeyAuth(tokenAuth *jwtauth.JWTAuth, scope dataprovider.APIKeyScope)
return
}
updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: apiUser}},
dataprovider.LoginMethodPassword, util.GetIPFromRemoteAddress(r.RemoteAddr), nil)
dataprovider.LoginMethodPassword, util.GetIPFromRemoteAddress(r.RemoteAddr), nil, r)
}
dataprovider.UpdateAPIKeyLastUse(&k) //nolint:errcheck

@@ -523,7 +529,7 @@ func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAu
if username == "" {
err := errors.New("the provided key is not associated with any user and no username was provided")
updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
dataprovider.LoginMethodPassword, ipAddr, err)
dataprovider.LoginMethodPassword, ipAddr, err, r)
return err
}
if err := common.Config.ExecutePostConnectHook(ipAddr, protocol); err != nil {

@@ -532,27 +538,27 @@ func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAu
user, err := dataprovider.GetUserWithGroupSettings(username, "")
if err != nil {
updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
dataprovider.LoginMethodPassword, ipAddr, err)
dataprovider.LoginMethodPassword, ipAddr, err, r)
return err
}
if !user.Filters.AllowAPIKeyAuth {
err := fmt.Errorf("API key authentication disabled for user %q", user.Username)
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r)
return err
}
if err := user.CheckLoginConditions(); err != nil {
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r)
return err
}
connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil {
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r)
return err
}
defer user.CloseFs() //nolint:errcheck
err = user.CheckFsRoot(connectionID)
if err != nil {
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r)
return common.ErrInternalFailure
}
c := jwtTokenClaims{

@@ -565,22 +571,22 @@ func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAu

resp, err := c.createTokenResponse(tokenAuth, tokenAudienceAPIUser, ipAddr)
if err != nil {
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r)
return err
}
r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", resp["access_token"]))
dataprovider.UpdateLastLogin(&user)
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, nil)
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, nil, r)

return nil
}

func checkPartialAuth(w http.ResponseWriter, r *http.Request, audience string, tokenAudience []string) error {
if audience == tokenAudienceWebAdmin && util.Contains(tokenAudience, tokenAudienceWebAdminPartial) {
if audience == tokenAudienceWebAdmin && slices.Contains(tokenAudience, tokenAudienceWebAdminPartial) {
http.Redirect(w, r, webAdminTwoFactorPath, http.StatusFound)
return errInvalidToken
}
if audience == tokenAudienceWebClient && util.Contains(tokenAudience, tokenAudienceWebClientPartial) {
if audience == tokenAudienceWebClient && slices.Contains(tokenAudience, tokenAudienceWebClientPartial) {
http.Redirect(w, r, webClientTwoFactorPath, http.StatusFound)
return errInvalidToken
}

@@ -20,6 +20,7 @@ import (
"fmt"
"net/http"
"net/url"
"slices"
"strings"
"time"

@@ -142,7 +143,7 @@ func (o *OIDC) initialize() error {
if o.RedirectBaseURL == "" {
return errors.New("oidc: redirect base URL cannot be empty")
}
if !util.Contains(o.Scopes, oidc.ScopeOpenID) {
if !slices.Contains(o.Scopes, oidc.ScopeOpenID) {
return fmt.Errorf("oidc: required scope %q is not set", oidc.ScopeOpenID)
}
if o.ClientSecretFile != "" {

@@ -210,21 +211,24 @@ func newOIDCPendingAuth(audience tokenAudience) oidcPendingAuth {
}

type oidcToken struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresAt int64 `json:"expires_at,omitempty"`
SessionID string `json:"session_id"`
IDToken string `json:"id_token"`
Nonce string `json:"nonce"`
Username string `json:"username"`
Permissions []string `json:"permissions"`
HideUserPageSections int `json:"hide_user_page_sections,omitempty"`
TokenRole string `json:"token_role,omitempty"` // SFTPGo role name
Role any `json:"role"` // oidc user role: SFTPGo user or admin
CustomFields *map[string]any `json:"custom_fields,omitempty"`
Cookie string `json:"cookie"`
UsedAt int64 `json:"used_at"`
AccessToken string `json:"access_token"`
TokenType string `json:"token_type,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresAt int64 `json:"expires_at,omitempty"`
SessionID string `json:"session_id"`
IDToken string `json:"id_token"`
Nonce string `json:"nonce"`
Username string `json:"username"`
Permissions []string `json:"permissions"`
HideUserPageSections int `json:"hide_user_page_sections,omitempty"`
MustSetTwoFactorAuth bool `json:"must_set_2fa,omitempty"`
MustChangePassword bool `json:"must_change_password,omitempty"`
RequiredTwoFactorProtocols []string `json:"required_two_factor_protocols,omitempty"`
TokenRole string `json:"token_role,omitempty"` // SFTPGo role name
Role any `json:"role"` // oidc user role: SFTPGo user or admin
CustomFields *map[string]any `json:"custom_fields,omitempty"`
Cookie string `json:"cookie"`
UsedAt int64 `json:"used_at"`
}

func (t *oidcToken) parseClaims(claims map[string]any, usernameField, roleField string, customFields []string,

@@ -397,6 +401,9 @@ func (t *oidcToken) refreshUser(r *http.Request) error {
}
t.Permissions = user.Filters.WebClient
t.TokenRole = user.Role
t.MustSetTwoFactorAuth = user.MustSetSecondFactor()
t.MustChangePassword = user.MustChangePassword()
t.RequiredTwoFactorProtocols = user.Filters.TwoFactorAuthProtocols
return nil
}

@@ -445,29 +452,32 @@ func (t *oidcToken) getUser(r *http.Request) error {
user = &u
}
if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolOIDC); err != nil {
updateLoginMetrics(user, dataprovider.LoginMethodIDP, ipAddr, err)
updateLoginMetrics(user, dataprovider.LoginMethodIDP, ipAddr, err, r)
return fmt.Errorf("access denied: %w", err)
}
if err := user.CheckLoginConditions(); err != nil {
updateLoginMetrics(user, dataprovider.LoginMethodIDP, ipAddr, err)
updateLoginMetrics(user, dataprovider.LoginMethodIDP, ipAddr, err, r)
return err
}
connectionID := fmt.Sprintf("%s_%s", common.ProtocolOIDC, xid.New().String())
if err := checkHTTPClientUser(user, r, connectionID, true); err != nil {
updateLoginMetrics(user, dataprovider.LoginMethodIDP, ipAddr, err)
updateLoginMetrics(user, dataprovider.LoginMethodIDP, ipAddr, err, r)
return err
}
defer user.CloseFs() //nolint:errcheck
err = user.CheckFsRoot(connectionID)
if err != nil {
logger.Warn(logSender, connectionID, "unable to check fs root: %v", err)
updateLoginMetrics(user, dataprovider.LoginMethodIDP, ipAddr, common.ErrInternalFailure)
updateLoginMetrics(user, dataprovider.LoginMethodIDP, ipAddr, common.ErrInternalFailure, r)
return err
}
updateLoginMetrics(user, dataprovider.LoginMethodIDP, ipAddr, nil)
updateLoginMetrics(user, dataprovider.LoginMethodIDP, ipAddr, nil, r)
dataprovider.UpdateLastLogin(user)
t.Permissions = user.Filters.WebClient
t.TokenRole = user.Role
t.MustSetTwoFactorAuth = user.MustSetSecondFactor()
t.MustChangePassword = user.MustChangePassword()
t.RequiredTwoFactorProtocols = user.Filters.TwoFactorAuthProtocols
return nil
}

@@ -542,11 +552,17 @@ func (s *httpdServer) oidcTokenAuthenticator(audience tokenAudience) func(next h
return
}
jwtTokenClaims := jwtTokenClaims{
JwtID: token.Cookie,
Username: token.Username,
Permissions: token.Permissions,
Role: token.TokenRole,
HideUserPageSections: token.HideUserPageSections,
}
if audience == tokenAudienceWebClient {
jwtTokenClaims.MustSetTwoFactorAuth = token.MustSetTwoFactorAuth
jwtTokenClaims.MustChangePassword = token.MustChangePassword
jwtTokenClaims.RequiredTwoFactorProtocols = token.RequiredTwoFactorProtocols
}
_, tokenString, err := jwtTokenClaims.createToken(s.tokenAuth, audience, util.GetIPFromRemoteAddress(r.RemoteAddr))
if err != nil {
setFlashMessage(w, r, newFlashMessage("Unable to create cookie", util.I18nError500Message))

@@ -33,7 +33,6 @@ import (

"github.com/coreos/go-oidc/v3/oidc"
"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/rs/xid"
"github.com/sftpgo/sdk"
"github.com/stretchr/testify/assert"

@@ -136,6 +135,8 @@ func TestOIDCInitialization(t *testing.T) {
}

func TestOIDCLoginLogout(t *testing.T) {
tokenValidationMode = 2

oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
require.True(t, ok)
server := getTestOIDCServer()

@@ -553,6 +554,8 @@ func TestOIDCLoginLogout(t *testing.T) {
assert.NoError(t, err)
err = dataprovider.DeleteUser(username, "", "", "")
assert.NoError(t, err)

tokenValidationMode = 0
}

func TestOIDCRefreshToken(t *testing.T) {

@@ -1586,12 +1589,9 @@ func TestOIDCWithLoginFormsDisabled(t *testing.T) {
tokenCookie = k
}
// we should be able to create admins without setting a password
if csrfTokenAuth == nil {
csrfTokenAuth = jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil)
}
adminUsername := "testAdmin"
form := make(url.Values)
form.Set(csrfFormToken, createCSRFToken(""))
form.Set(csrfFormToken, createCSRFToken(rr, r, server.csrfTokenAuth, tokenCookie, webBaseAdminPath))
form.Set("username", adminUsername)
form.Set("password", "")
form.Set("status", "1")

File diff suppressed because it is too large

@@ -67,6 +67,7 @@ type loginPage struct {
OpenIDLoginURL string
Title string
Branding UIBranding
Languages []string
FormDisabled bool
CheckRedirect bool
}

@@ -79,6 +80,7 @@ type twoFactorPage struct {
RecoveryURL string
Title string
Branding UIBranding
Languages []string
CheckRedirect bool
}

@@ -90,6 +92,7 @@ type forgotPwdPage struct {
LoginURL string
Title string
Branding UIBranding
Languages []string
CheckRedirect bool
}

@@ -101,6 +104,7 @@ type resetPwdPage struct {
LoginURL string
Title string
Branding UIBranding
Languages []string
CheckRedirect bool
}

@@ -25,18 +25,20 @@ import (
"net/url"
"os"
"path/filepath"
"slices"
"sort"
"strconv"
"strings"
"time"

"github.com/go-chi/render"
"github.com/rs/xid"
"github.com/sftpgo/sdk"
sdkkms "github.com/sftpgo/sdk/kms"

"github.com/drakkan/sftpgo/v2/internal/acme"
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/ftpd"
"github.com/drakkan/sftpgo/v2/internal/kms"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/mfa"

@@ -44,6 +46,7 @@ import (
"github.com/drakkan/sftpgo/v2/internal/smtp"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
"github.com/drakkan/sftpgo/v2/internal/webdavd"
)

type userPageMode int

@@ -150,7 +153,9 @@ type basePage struct {
HasSearcher bool
HasExternalLogin bool
LoggedUser *dataprovider.Admin
IsLoggedToShare bool
Branding UIBranding
Languages []string
}

type statusPage struct {

@@ -184,6 +189,7 @@ type userPage struct {
Roles []dataprovider.Role
CanImpersonate bool
FsWrapper fsWrapper
CanUseTLSCerts bool
}

type adminPage struct {

@@ -257,6 +263,7 @@ type setupPage struct {
HideSupportLink bool
Title string
Branding UIBranding
Languages []string
CheckRedirect bool
}

@@ -329,6 +336,7 @@ type configsPage struct {
RedactedSecret string
OAuth2TokenURL string
OAuth2RedirectURL string
WebClientBranding UIBranding
Error *util.I18nError
}

@@ -614,10 +622,10 @@ func isServerManagerResource(currentURL string) bool {
currentURL == webConfigsPath
}

func (s *httpdServer) getBasePageData(title, currentURL string, r *http.Request) basePage {
func (s *httpdServer) getBasePageData(title, currentURL string, w http.ResponseWriter, r *http.Request) basePage {
var csrfToken string
if currentURL != "" {
csrfToken = createCSRFToken(util.GetIPFromRemoteAddress(r.RemoteAddr))
csrfToken = createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath)
}
return basePage{
commonBasePage: getCommonBasePage(r),

@@ -662,7 +670,8 @@ func (s *httpdServer) getBasePageData(title, currentURL string, r *http.Request)
HasSearcher: plugin.Handler.HasSearcher(),
HasExternalLogin: isLoggedInWithOIDC(r),
CSRFToken: csrfToken,
Branding: s.binding.Branding.WebAdmin,
Branding: s.binding.webAdminBranding(),
Languages: s.binding.languages(),
}
}

@@ -677,7 +686,7 @@ func (s *httpdServer) renderMessagePageWithString(w http.ResponseWriter, r *http
err error, message, text string,
) {
data := messagePage{
basePage: s.getBasePageData(title, "", r),
basePage: s.getBasePageData(title, "", w, r),
Error: getI18nError(err),
Success: message,
Text: text,

@@ -712,60 +721,64 @@ func (s *httpdServer) renderNotFoundPage(w http.ResponseWriter, r *http.Request,
util.NewI18nError(err, util.I18nError404Message), "")
}

func (s *httpdServer) renderForgotPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) {
func (s *httpdServer) renderForgotPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := forgotPwdPage{
commonBasePage: getCommonBasePage(r),
CurrentURL: webAdminForgotPwdPath,
Error: err,
CSRFToken: createCSRFToken(ip),
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseAdminPath),
LoginURL: webAdminLoginPath,
Title: util.I18nForgotPwdTitle,
Branding: s.binding.Branding.WebAdmin,
Branding: s.binding.webAdminBranding(),
Languages: s.binding.languages(),
}
renderAdminTemplate(w, templateForgotPassword, data)
}

func (s *httpdServer) renderResetPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) {
func (s *httpdServer) renderResetPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := resetPwdPage{
commonBasePage: getCommonBasePage(r),
CurrentURL: webAdminResetPwdPath,
Error: err,
CSRFToken: createCSRFToken(ip),
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath),
LoginURL: webAdminLoginPath,
Title: util.I18nResetPwdTitle,
Branding: s.binding.Branding.WebAdmin,
Branding: s.binding.webAdminBranding(),
Languages: s.binding.languages(),
}
renderAdminTemplate(w, templateResetPassword, data)
}

func (s *httpdServer) renderTwoFactorPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) {
func (s *httpdServer) renderTwoFactorPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := twoFactorPage{
commonBasePage: getCommonBasePage(r),
Title: pageTwoFactorTitle,
CurrentURL: webAdminTwoFactorPath,
Error: err,
CSRFToken: createCSRFToken(ip),
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath),
RecoveryURL: webAdminTwoFactorRecoveryPath,
Branding: s.binding.Branding.WebAdmin,
Branding: s.binding.webAdminBranding(),
Languages: s.binding.languages(),
}
renderAdminTemplate(w, templateTwoFactor, data)
}

func (s *httpdServer) renderTwoFactorRecoveryPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) {
func (s *httpdServer) renderTwoFactorRecoveryPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := twoFactorPage{
commonBasePage: getCommonBasePage(r),
Title: pageTwoFactorRecoveryTitle,
CurrentURL: webAdminTwoFactorRecoveryPath,
Error: err,
CSRFToken: createCSRFToken(ip),
Branding: s.binding.Branding.WebAdmin,
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath),
Branding: s.binding.webAdminBranding(),
Languages: s.binding.languages(),
}
renderAdminTemplate(w, templateTwoFactorRecovery, data)
}

func (s *httpdServer) renderMFAPage(w http.ResponseWriter, r *http.Request) {
data := mfaPage{
basePage: s.getBasePageData(pageMFATitle, webAdminMFAPath, r),
basePage: s.getBasePageData(pageMFATitle, webAdminMFAPath, w, r),
TOTPConfigs: mfa.GetAvailableTOTPConfigNames(),
GenerateTOTPURL: webAdminTOTPGeneratePath,
ValidateTOTPURL: webAdminTOTPValidatePath,

@@ -784,7 +797,7 @@ func (s *httpdServer) renderMFAPage(w http.ResponseWriter, r *http.Request) {

func (s *httpdServer) renderProfilePage(w http.ResponseWriter, r *http.Request, err error) {
data := profilePage{
basePage: s.getBasePageData(util.I18nProfileTitle, webAdminProfilePath, r),
basePage: s.getBasePageData(util.I18nProfileTitle, webAdminProfilePath, w, r),
Error: getI18nError(err),
}
admin, err := dataprovider.AdminExists(data.LoggedUser.Username)

@@ -801,7 +814,7 @@ func (s *httpdServer) renderProfilePage(w http.ResponseWriter, r *http.Request,

func (s *httpdServer) renderChangePasswordPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := changePasswordPage{
basePage: s.getBasePageData(util.I18nChangePwdTitle, webChangeAdminPwdPath, r),
basePage: s.getBasePageData(util.I18nChangePwdTitle, webChangeAdminPwdPath, w, r),
Error: err,
}

@@ -810,7 +823,7 @@ func (s *httpdServer) renderChangePasswordPage(w http.ResponseWriter, r *http.Re

func (s *httpdServer) renderMaintenancePage(w http.ResponseWriter, r *http.Request, err error) {
data := maintenancePage{
basePage: s.getBasePageData(util.I18nMaintenanceTitle, webMaintenancePath, r),
basePage: s.getBasePageData(util.I18nMaintenanceTitle, webMaintenancePath, w, r),
BackupPath: webBackupPath,
RestorePath: webRestorePath,
Error: getI18nError(err),

@@ -832,30 +845,32 @@ func (s *httpdServer) renderConfigsPage(w http.ResponseWriter, r *http.Request,
configs.ACME.HTTP01Challenge.Port = 80
}
data := configsPage{
basePage: s.getBasePageData(util.I18nConfigsTitle, webConfigsPath, r),
basePage: s.getBasePageData(util.I18nConfigsTitle, webConfigsPath, w, r),
Configs: configs,
ConfigSection: section,
RedactedSecret: redactedSecret,
OAuth2TokenURL: webOAuth2TokenPath,
OAuth2RedirectURL: webOAuth2RedirectPath,
WebClientBranding: s.binding.webClientBranding(),
Error: getI18nError(err),
}

renderAdminTemplate(w, templateConfigs, data)
}

func (s *httpdServer) renderAdminSetupPage(w http.ResponseWriter, r *http.Request, username, ip string, err *util.I18nError) {
func (s *httpdServer) renderAdminSetupPage(w http.ResponseWriter, r *http.Request, username string, err *util.I18nError) {
data := setupPage{
commonBasePage: getCommonBasePage(r),
Title: util.I18nSetupTitle,
CurrentURL: webAdminSetupPath,
CSRFToken: createCSRFToken(ip),
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseAdminPath),
Username: username,
HasInstallationCode: installationCode != "",
InstallationCodeHint: installationCodeHint,
HideSupportLink: hideSupportLink,
Error: err,
Branding: s.binding.Branding.WebAdmin,
Branding: s.binding.webAdminBranding(),
Languages: s.binding.languages(),
}

renderAdminTemplate(w, templateSetup, data)

@@ -878,7 +893,7 @@ func (s *httpdServer) renderAddUpdateAdminPage(w http.ResponseWriter, r *http.Re
title = util.I18nUpdateAdminTitle
}
data := adminPage{
basePage: s.getBasePageData(title, currentURL, r),
basePage: s.getBasePageData(title, currentURL, w, r),
Admin: admin,
Groups: groups,
Roles: roles,

@@ -919,7 +934,7 @@ func (s *httpdServer) renderUserPage(w http.ResponseWriter, r *http.Request, use
}
}
user.FsConfig.RedactedSecret = redactedSecret
basePage := s.getBasePageData(title, currentURL, r)
basePage := s.getBasePageData(title, currentURL, w, r)
if (mode == userPageModeAdd || mode == userPageModeTemplate) && len(user.Groups) == 0 && admin != nil {
for _, group := range admin.Groups {
user.Groups = append(user.Groups, sdk.GroupMapping{

@@ -959,6 +974,7 @@ func (s *httpdServer) renderUserPage(w http.ResponseWriter, r *http.Request, use
Groups: groups,
Roles: roles,
CanImpersonate: os.Getuid() == 0,
CanUseTLSCerts: ftpd.GetStatus().IsActive || webdavd.GetStatus().IsActive,
FsWrapper: fsWrapper{
Filesystem: user.FsConfig,
IsUserPage: true,

@@ -984,7 +1000,7 @@ func (s *httpdServer) renderIPListPage(w http.ResponseWriter, r *http.Request, e
currentURL = fmt.Sprintf("%s/%d/%s", webIPListPath, entry.Type, url.PathEscape(entry.IPOrNet))
}
data := ipListPage{
basePage: s.getBasePageData(title, currentURL, r),
basePage: s.getBasePageData(title, currentURL, w, r),
Error: getI18nError(err),
Entry: &entry,
Mode: mode,

@@ -1005,7 +1021,7 @@ func (s *httpdServer) renderRolePage(w http.ResponseWriter, r *http.Request, rol
currentURL = fmt.Sprintf("%s/%s", webAdminRolePath, url.PathEscape(role.Name))
}
data := rolePage{
basePage: s.getBasePageData(title, currentURL, r),
basePage: s.getBasePageData(title, currentURL, w, r),
Error: getI18nError(err),
Role: &role,
Mode: mode,

@@ -1035,7 +1051,7 @@ func (s *httpdServer) renderGroupPage(w http.ResponseWriter, r *http.Request, gr
group.UserSettings.FsConfig.SetEmptySecretsIfNil()

data := groupPage{
basePage: s.getBasePageData(title, currentURL, r),
basePage: s.getBasePageData(title, currentURL, w, r),
Error: getI18nError(err),
Group: &group,
Mode: mode,

@@ -1080,7 +1096,7 @@ func (s *httpdServer) renderEventActionPage(w http.ResponseWriter, r *http.Reque
}

data := eventActionPage{
basePage: s.getBasePageData(title, currentURL, r),
basePage: s.getBasePageData(title, currentURL, w, r),
Action: action,
ActionTypes: dataprovider.EventActionTypes,
FsActions: dataprovider.FsActionTypes,

@@ -1111,7 +1127,7 @@ func (s *httpdServer) renderEventRulePage(w http.ResponseWriter, r *http.Request
}

data := eventRulePage{
basePage: s.getBasePageData(title, currentURL, r),
basePage: s.getBasePageData(title, currentURL, w, r),
Rule: rule,
TriggerTypes: dataprovider.EventTriggerTypes,
Actions: actions,

@@ -1145,7 +1161,7 @@ func (s *httpdServer) renderFolderPage(w http.ResponseWriter, r *http.Request, f
folder.FsConfig.SetEmptySecretsIfNil()

data := folderPage{
basePage: s.getBasePageData(title, currentURL, r),
basePage: s.getBasePageData(title, currentURL, w, r),
Error: getI18nError(err),
Folder: folder,
Mode: mode,

@@ -1487,13 +1503,13 @@ func getFiltersFromUserPostFields(r *http.Request) (sdk.BaseUserFilters, error)
filters.PasswordStrength = passwordStrength
filters.AccessTime = getAccessTimeRestrictionsFromPostFields(r)
hooks := r.Form["hooks"]
if util.Contains(hooks, "external_auth_disabled") {
if slices.Contains(hooks, "external_auth_disabled") {
filters.Hooks.ExternalAuthDisabled = true
}
if util.Contains(hooks, "pre_login_disabled") {
if slices.Contains(hooks, "pre_login_disabled") {
filters.Hooks.PreLoginDisabled = true
}
if util.Contains(hooks, "check_password_disabled") {
if slices.Contains(hooks, "check_password_disabled") {
filters.Hooks.CheckPasswordDisabled = true
}
filters.IsAnonymous = r.Form.Get("is_anonymous") != ""

@@ -1527,6 +1543,7 @@ func getS3Config(r *http.Request) (vfs.S3FsConfig, error) {
config.AccessKey = strings.TrimSpace(r.Form.Get("s3_access_key"))
config.RoleARN = strings.TrimSpace(r.Form.Get("s3_role_arn"))
config.AccessSecret = getSecretFromFormField(r, "s3_access_secret")
config.SSECustomerKey = getSecretFromFormField(r, "s3_sse_customer_key")
config.Endpoint = strings.TrimSpace(r.Form.Get("s3_endpoint"))
config.StorageClass = strings.TrimSpace(r.Form.Get("s3_storage_class"))
config.ACL = strings.TrimSpace(r.Form.Get("s3_acl"))

@@ -1583,7 +1600,7 @@ func getGCSConfig(r *http.Request) (vfs.GCSFsConfig, error) {
config.AutomaticCredentials = 0
}
credentials, _, err := r.FormFile("gcs_credential_file")
if err == http.ErrMissingFile {
if errors.Is(err, http.ErrMissingFile) {
return config, nil
}
if err != nil {

@@ -1852,6 +1869,10 @@ func getS3FsFromTemplate(fsConfig vfs.S3FsConfig, replacements map[string]string
payload := replacePlaceholders(fsConfig.AccessSecret.GetPayload(), replacements)
fsConfig.AccessSecret = kms.NewPlainSecret(payload)
}
if fsConfig.SSECustomerKey != nil && fsConfig.SSECustomerKey.IsPlain() {
payload := replacePlaceholders(fsConfig.SSECustomerKey.GetPayload(), replacements)
fsConfig.SSECustomerKey = kms.NewPlainSecret(payload)
}
return fsConfig
}

@@ -1970,6 +1991,13 @@ func updateRepeaterFormFields(r *http.Request) {
}
continue
}
if hasPrefixAndSuffix(k, "additional_emails[", "][additional_email]") {
email := strings.TrimSpace(r.Form.Get(k))
if email != "" {
r.Form.Add("additional_emails", email)
}
continue
}
if hasPrefixAndSuffix(k, "virtual_folders[", "][vfolder_path]") {
base, _ := strings.CutSuffix(k, "[vfolder_path]")
r.Form.Add("vfolder_path", strings.TrimSpace(r.Form.Get(k)))

@@ -2103,6 +2131,7 @@ func getUserFromPostFields(r *http.Request) (dataprovider.User, error) {
Filters: dataprovider.UserFilters{
BaseUserFilters: filters,
RequirePasswordChange: r.Form.Get("require_password_change") != "",
AdditionalEmails: r.Form["additional_emails"],
},
VirtualFolders: getVirtualFoldersFromPostFields(r),
FsConfig: fsConfig,

@@ -2199,6 +2228,28 @@ func getKeyValsFromPostFields(r *http.Request, key, val string) []dataprovider.K
return res
}

func getRenameConfigsFromPostFields(r *http.Request) []dataprovider.RenameConfig {
var res []dataprovider.RenameConfig
keys := r.Form["fs_rename_source"]
values := r.Form["fs_rename_target"]

for idx, k := range keys {
v := values[idx]
if k != "" && v != "" {
opts := r.Form["fs_rename_options"+strconv.Itoa(idx)]
res = append(res, dataprovider.RenameConfig{
KeyValue: dataprovider.KeyValue{
Key: k,
Value: v,
},
UpdateModTime: slices.Contains(opts, "1"),
})
}
}

return res
}

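The rename options are read from parallel form slices plus an indexed fs_rename_optionsN key; a small sketch of the expected form layout and the resulting parse, using url.Values directly and a simplified config struct (the field names follow the function above, the values are made up):

package main

import (
	"fmt"
	"net/url"
	"slices"
	"strconv"
)

type renameConfig struct {
	Source, Target string
	UpdateModTime  bool
}

func main() {
	form := url.Values{}
	// two rename pairs; only the first one sets the "update modification time" option
	form.Add("fs_rename_source", "/old/a.txt")
	form.Add("fs_rename_target", "/new/a.txt")
	form["fs_rename_options0"] = []string{"1"}
	form.Add("fs_rename_source", "/old/b.txt")
	form.Add("fs_rename_target", "/new/b.txt")

	var res []renameConfig
	for idx, k := range form["fs_rename_source"] {
		v := form["fs_rename_target"][idx]
		if k != "" && v != "" {
			opts := form["fs_rename_options"+strconv.Itoa(idx)]
			res = append(res, renameConfig{Source: k, Target: v, UpdateModTime: slices.Contains(opts, "1")})
		}
	}
	fmt.Printf("%+v\n", res)
}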
func getFoldersRetentionFromPostFields(r *http.Request) ([]dataprovider.FolderRetention, error) {
var res []dataprovider.FolderRetention
paths := r.Form["folder_retention_path"]

@@ -2214,7 +2265,7 @@ func getFoldersRetentionFromPostFields(r *http.Request) ([]dataprovider.FolderRe
res = append(res, dataprovider.FolderRetention{
Path: p,
Retention: retention,
DeleteEmptyDirs: util.Contains(opts, "1"),
DeleteEmptyDirs: slices.Contains(opts, "1"),
})
}
}

@@ -2308,6 +2359,8 @@ func updateRepeaterFormActionFields(r *http.Request) {
base, _ := strings.CutSuffix(k, "[fs_rename_source]")
r.Form.Add("fs_rename_source", strings.TrimSpace(r.Form.Get(k)))
r.Form.Add("fs_rename_target", strings.TrimSpace(r.Form.Get(base+"[fs_rename_target]")))
r.Form["fs_rename_options"+strconv.Itoa(len(r.Form["fs_rename_source"])-1)] =
r.Form[base+"[fs_rename_options][]"]
continue
}
if hasPrefixAndSuffix(k, "fs_copy[", "][fs_copy_source]") {

@@ -2396,7 +2449,7 @@ func getEventActionOptionsFromPostFields(r *http.Request) (dataprovider.BaseEven
},
FsConfig: dataprovider.EventActionFilesystemConfig{
Type: fsActionType,
Renames: getKeyValsFromPostFields(r, "fs_rename_source", "fs_rename_target"),
Renames: getRenameConfigsFromPostFields(r),
Deletes: getSliceFromDelimitedValues(r.Form.Get("fs_delete_paths"), ","),
MkDirs: getSliceFromDelimitedValues(r.Form.Get("fs_mkdir_paths"), ","),
Exist: getSliceFromDelimitedValues(r.Form.Get("fs_exist_paths"), ","),

@@ -2564,9 +2617,9 @@ func getEventRuleActionsFromPostFields(r *http.Request) []dataprovider.EventActi
},
Order: order + 1,
Options: dataprovider.EventActionOptions{
IsFailureAction: util.Contains(options, "1"),
StopOnFailure: util.Contains(options, "2"),
ExecuteSync: util.Contains(options, "3"),
IsFailureAction: slices.Contains(options, "1"),
StopOnFailure: slices.Contains(options, "2"),
ExecuteSync: slices.Contains(options, "3"),
},
})
}

@@ -2769,31 +2822,95 @@ func getSMTPConfigsFromPostFields(r *http.Request) *dataprovider.SMTPConfigs {
}
}

func getImageInputBytes(r *http.Request, fieldName, removeFieldName string, defaultVal []byte) ([]byte, error) {
var result []byte
remove := r.Form.Get(removeFieldName)
if remove == "" || remove == "0" {
result = defaultVal
}
f, _, err := r.FormFile(fieldName)
if err != nil {
if errors.Is(err, http.ErrMissingFile) {
return result, nil
}
return nil, err
}
defer f.Close()

return io.ReadAll(f)
}

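A sketch of how getImageInputBytes behaves when a file is actually uploaded, exercised with an httptest request; the form field names mirror the function above, the multipart payload is fabricated for illustration:

package main

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"net/http/httptest"
)

// getImageInputBytes mirrors the logic above: keep the default unless removal is
// requested, and prefer an uploaded file when one is present.
func getImageInputBytes(r *http.Request, fieldName, removeFieldName string, defaultVal []byte) ([]byte, error) {
	var result []byte
	remove := r.Form.Get(removeFieldName)
	if remove == "" || remove == "0" {
		result = defaultVal
	}
	f, _, err := r.FormFile(fieldName)
	if err != nil {
		if errors.Is(err, http.ErrMissingFile) {
			return result, nil
		}
		return nil, err
	}
	defer f.Close()
	return io.ReadAll(f)
}

func main() {
	// Build a multipart form with a logo file and no "remove" flag.
	body := &bytes.Buffer{}
	w := multipart.NewWriter(body)
	fw, _ := w.CreateFormFile("branding_webadmin_logo", "logo.png")
	fw.Write([]byte("fake-png-bytes")) //nolint:errcheck
	w.Close()                          //nolint:errcheck

	r := httptest.NewRequest(http.MethodPost, "/", body)
	r.Header.Set("Content-Type", w.FormDataContentType())
	r.ParseMultipartForm(1 << 20) //nolint:errcheck

	b, err := getImageInputBytes(r, "branding_webadmin_logo", "branding_webadmin_logo_remove", []byte("current"))
	fmt.Println(string(b), err) // prints the uploaded bytes
}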
func getBrandingConfigFromPostFields(r *http.Request, config *dataprovider.BrandingConfigs) (
|
||||
*dataprovider.BrandingConfigs, error,
|
||||
) {
|
||||
if config == nil {
|
||||
config = &dataprovider.BrandingConfigs{}
|
||||
}
|
||||
adminLogo, err := getImageInputBytes(r, "branding_webadmin_logo", "branding_webadmin_logo_remove", config.WebAdmin.Logo)
|
||||
if err != nil {
|
||||
return nil, util.NewI18nError(err, util.I18nErrorInvalidForm)
|
||||
}
|
||||
adminFavicon, err := getImageInputBytes(r, "branding_webadmin_favicon", "branding_webadmin_favicon_remove",
|
||||
config.WebAdmin.Favicon)
|
||||
if err != nil {
|
||||
return nil, util.NewI18nError(err, util.I18nErrorInvalidForm)
|
||||
}
|
||||
clientLogo, err := getImageInputBytes(r, "branding_webclient_logo", "branding_webclient_logo_remove",
|
||||
config.WebClient.Logo)
|
||||
if err != nil {
|
||||
return nil, util.NewI18nError(err, util.I18nErrorInvalidForm)
|
||||
}
|
||||
clientFavicon, err := getImageInputBytes(r, "branding_webclient_favicon", "branding_webclient_favicon_remove",
|
||||
config.WebClient.Favicon)
|
||||
if err != nil {
|
||||
return nil, util.NewI18nError(err, util.I18nErrorInvalidForm)
|
||||
}
|
||||
|
||||
branding := &dataprovider.BrandingConfigs{
|
||||
WebAdmin: dataprovider.BrandingConfig{
|
||||
Name: strings.TrimSpace(r.Form.Get("branding_webadmin_name")),
|
||||
ShortName: strings.TrimSpace(r.Form.Get("branding_webadmin_short_name")),
|
||||
Logo: adminLogo,
|
||||
Favicon: adminFavicon,
|
||||
DisclaimerName: strings.TrimSpace(r.Form.Get("branding_webadmin_disclaimer_name")),
|
||||
DisclaimerURL: strings.TrimSpace(r.Form.Get("branding_webadmin_disclaimer_url")),
|
||||
},
|
||||
WebClient: dataprovider.BrandingConfig{
|
||||
Name: strings.TrimSpace(r.Form.Get("branding_webclient_name")),
|
||||
ShortName: strings.TrimSpace(r.Form.Get("branding_webclient_short_name")),
|
||||
Logo: clientLogo,
|
||||
Favicon: clientFavicon,
|
||||
DisclaimerName: strings.TrimSpace(r.Form.Get("branding_webclient_disclaimer_name")),
|
||||
DisclaimerURL: strings.TrimSpace(r.Form.Get("branding_webclient_disclaimer_url")),
|
||||
},
|
||||
}
|
||||
return branding, nil
|
||||
}
|
||||
|
||||
func (s *httpdServer) handleWebAdminForgotPwd(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
|
||||
if !smtp.IsEnabled() {
|
||||
s.renderNotFoundPage(w, r, errors.New("this page does not exist"))
|
||||
return
|
||||
}
|
||||
s.renderForgotPwdPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr))
|
||||
s.renderForgotPwdPage(w, r, nil)
|
||||
}
|
||||
|
||||
func (s *httpdServer) handleWebAdminForgotPwdPost(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
|
||||
|
||||
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
s.renderForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr)
|
||||
s.renderForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
|
||||
return
|
||||
}
|
||||
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
|
||||
if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
|
||||
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
|
||||
return
|
||||
}
|
||||
err = handleForgotPassword(r, r.Form.Get("username"), true)
|
||||
if err != nil {
|
||||
s.renderForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorPwdResetGeneric), ipAddr)
|
||||
s.renderForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorPwdResetGeneric))
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, webAdminResetPwdPath, http.StatusFound)
|
||||
|
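getImageInputBytes above treats a missing multipart file field as "keep the current value" rather than as an error. A reduced, standalone sketch of that same pattern; the field name, size limit, route and address below are illustrative, not taken from the patch:

```go
package main

import (
	"errors"
	"io"
	"log"
	"net/http"
)

// readOptionalImage mirrors the pattern above: if the client did not upload
// anything for the field, return the value we already had instead of failing.
func readOptionalImage(r *http.Request, field string, current []byte) ([]byte, error) {
	f, _, err := r.FormFile(field)
	if err != nil {
		if errors.Is(err, http.ErrMissingFile) {
			return current, nil // no upload: keep the existing bytes
		}
		return nil, err
	}
	defer f.Close()
	return io.ReadAll(f)
}

func main() {
	http.HandleFunc("/branding", func(w http.ResponseWriter, r *http.Request) {
		if err := r.ParseMultipartForm(1 << 20); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		logo, err := readOptionalImage(r, "logo", nil)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		w.Write(logo)
	})
	log.Fatal(http.ListenAndServe("127.0.0.1:8080", nil))
}
```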
@@ -2805,17 +2922,17 @@ func (s *httpdServer) handleWebAdminPasswordReset(w http.ResponseWriter, r *http
s.renderNotFoundPage(w, r, errors.New("this page does not exist"))
return
}
- s.renderResetPwdPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr))
+ s.renderResetPwdPage(w, r, nil)
}

func (s *httpdServer) handleWebAdminTwoFactor(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
- s.renderTwoFactorPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr))
+ s.renderTwoFactorPage(w, r, nil)
}

func (s *httpdServer) handleWebAdminTwoFactorRecovery(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
- s.renderTwoFactorRecoveryPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr))
+ s.renderTwoFactorRecoveryPage(w, r, nil)
}

func (s *httpdServer) handleWebAdminMFA(w http.ResponseWriter, r *http.Request) {
@@ -2841,7 +2958,7 @@ func (s *httpdServer) handleWebAdminProfilePost(w http.ResponseWriter, r *http.R
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
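This same guard change repeats across the admin and client handlers below: the CSRF token is no longer read from the posted form and checked against the client IP, but verified from the request with the server's dedicated csrfTokenAuth signer. The body of that helper is not part of this excerpt; purely as a generic illustration of the idea (not SFTPGo's implementation), a token check bound to a server-side secret can be built from the standard library like this:

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// csrfToken derives a token from a per-session identifier and a server secret.
func csrfToken(secret, sessionID []byte) string {
	mac := hmac.New(sha256.New, secret)
	mac.Write(sessionID)
	return hex.EncodeToString(mac.Sum(nil))
}

// verifyCSRF recomputes the expected token and compares it in constant time.
func verifyCSRF(secret, sessionID []byte, token string) bool {
	expected := csrfToken(secret, sessionID)
	return hmac.Equal([]byte(expected), []byte(token))
}

func main() {
	secret := []byte("server-side secret")
	session := []byte("session-1234")
	tok := csrfToken(secret, session)
	fmt.Println(verifyCSRF(secret, session, tok))      // true
	fmt.Println(verifyCSRF(secret, session, "forged")) // false
}
```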
@@ -2886,7 +3003,7 @@ func (s *httpdServer) handleWebRestore(w http.ResponseWriter, r *http.Request) {
defer r.MultipartForm.RemoveAll() //nolint:errcheck

ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@@ -2947,7 +3064,7 @@ func getAllAdmins(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) handleGetWebAdmins(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

- data := s.getBasePageData(util.I18nAdminsTitle, webAdminsPath, r)
+ data := s.getBasePageData(util.I18nAdminsTitle, webAdminsPath, w, r)
renderAdminTemplate(w, templateAdmins, data)
}

@@ -2957,7 +3074,7 @@ func (s *httpdServer) handleWebAdminSetupGet(w http.ResponseWriter, r *http.Requ
http.Redirect(w, r, webAdminLoginPath, http.StatusFound)
return
}
- s.renderAdminSetupPage(w, r, "", util.GetIPFromRemoteAddress(r.RemoteAddr), nil)
+ s.renderAdminSetupPage(w, r, "", nil)
}

func (s *httpdServer) handleWebAddAdminGet(w http.ResponseWriter, r *http.Request) {
@@ -2998,7 +3115,7 @@ func (s *httpdServer) handleWebAddAdminPost(w http.ResponseWriter, r *http.Reque
admin.Password = util.GenerateUniqueID()
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@@ -3029,7 +3146,7 @@ func (s *httpdServer) handleWebUpdateAdminPost(w http.ResponseWriter, r *http.Re
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@@ -3082,7 +3199,7 @@ func (s *httpdServer) handleWebUpdateAdminPost(w http.ResponseWriter, r *http.Re
func (s *httpdServer) handleWebDefenderPage(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := defenderHostsPage{
- basePage: s.getBasePageData(util.I18nDefenderTitle, webDefenderPath, r),
+ basePage: s.getBasePageData(util.I18nDefenderTitle, webDefenderPath, w, r),
DefenderHostsURL: webDefenderHostsPath,
}

@@ -3116,7 +3233,7 @@ func (s *httpdServer) handleGetWebUsers(w http.ResponseWriter, r *http.Request)
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
}
- data := s.getBasePageData(util.I18nUsersTitle, webUsersPath, r)
+ data := s.getBasePageData(util.I18nUsersTitle, webUsersPath, w, r)
renderAdminTemplate(w, templateUsers, data)
}

@@ -3155,7 +3272,7 @@ func (s *httpdServer) handleWebTemplateFolderPost(w http.ResponseWriter, r *http
defer r.MultipartForm.RemoveAll() //nolint:errcheck

ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@@ -3170,7 +3287,6 @@ func (s *httpdServer) handleWebTemplateFolderPost(w http.ResponseWriter, r *http
templateFolder.FsConfig = fsConfig

var dump dataprovider.BackupData
- dump.Version = dataprovider.DumpVersion

foldersFields := getFoldersForTemplate(r)
for _, tmpl := range foldersFields {
@@ -3190,12 +3306,6 @@ func (s *httpdServer) handleWebTemplateFolderPost(w http.ResponseWriter, r *http
), "")
return
}
- if r.Form.Get("form_action") == "export_from_template" {
- w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"sftpgo-%v-folders-from-template.json\"",
- len(dump.Folders)))
- render.JSON(w, r, dump)
- return
- }
if err = RestoreFolders(dump.Folders, "", 1, 0, claims.Username, ipAddr, claims.Role); err != nil {
s.renderMessagePage(w, r, util.I18nTemplateFolderTitle, getRespStatus(err), err, "")
return
@@ -3218,6 +3328,7 @@ func (s *httpdServer) handleWebTemplateUserGet(w http.ResponseWriter, r *http.Re
user.SetEmptySecrets()
user.PublicKeys = nil
user.Email = ""
+ user.Filters.AdditionalEmails = nil
user.Description = ""
if user.ExpirationDate == 0 && admin.Filters.Preferences.DefaultUsersExpiration > 0 {
user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour * time.Duration(admin.Filters.Preferences.DefaultUsersExpiration)))
@@ -3255,13 +3366,12 @@ func (s *httpdServer) handleWebTemplateUserPost(w http.ResponseWriter, r *http.R
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}

var dump dataprovider.BackupData
- dump.Version = dataprovider.DumpVersion

userTmplFields := getUsersForTemplate(r)
for _, tmpl := range userTmplFields {
@@ -3270,14 +3380,10 @@ func (s *httpdServer) handleWebTemplateUserPost(w http.ResponseWriter, r *http.R
s.renderMessagePage(w, r, util.I18nTemplateUserTitle, http.StatusBadRequest, err, "")
return
}
// to create a template the "*" permission is required, so role admins cannot use
// this method, we don't need to force the role
dump.Users = append(dump.Users, u)
for _, folder := range u.VirtualFolders {
if !dump.HasFolder(folder.Name) {
dump.Folders = append(dump.Folders, folder.BaseVirtualFolder)
}
if claims.Role != "" {
u.Role = claims.Role
}
dump.Users = append(dump.Users, u)
}

if len(dump.Users) == 0 {
@@ -3288,12 +3394,6 @@ func (s *httpdServer) handleWebTemplateUserPost(w http.ResponseWriter, r *http.R
), "")
return
}
- if r.Form.Get("form_action") == "export_from_template" {
- w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"sftpgo-%v-users-from-template.json\"",
- len(dump.Users)))
- render.JSON(w, r, dump)
- return
- }
if err = RestoreUsers(dump.Users, "", 1, 0, claims.Username, ipAddr, claims.Role); err != nil {
s.renderMessagePage(w, r, util.I18nTemplateUserTitle, getRespStatus(err), err, "")
return
@@ -3352,7 +3452,7 @@ func (s *httpdServer) handleWebAddUserPost(w http.ResponseWriter, r *http.Reques
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@@ -3398,7 +3498,7 @@ func (s *httpdServer) handleWebUpdateUserPost(w http.ResponseWriter, r *http.Req
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@@ -3436,7 +3536,7 @@ func (s *httpdServer) handleWebUpdateUserPost(w http.ResponseWriter, r *http.Req
func (s *httpdServer) handleWebGetStatus(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := statusPage{
- basePage: s.getBasePageData(util.I18nStatusTitle, webStatusPath, r),
+ basePage: s.getBasePageData(util.I18nStatusTitle, webStatusPath, w, r),
Status: getServicesStatus(),
}
renderAdminTemplate(w, templateStatus, data)
@@ -3450,7 +3550,7 @@ func (s *httpdServer) handleWebGetConnections(w http.ResponseWriter, r *http.Req
return
}

- data := s.getBasePageData(util.I18nSessionsTitle, webConnectionsPath, r)
+ data := s.getBasePageData(util.I18nSessionsTitle, webConnectionsPath, w, r)
renderAdminTemplate(w, templateConnections, data)
}

@@ -3475,7 +3575,7 @@ func (s *httpdServer) handleWebAddFolderPost(w http.ResponseWriter, r *http.Requ
defer r.MultipartForm.RemoveAll() //nolint:errcheck

ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@@ -3536,7 +3636,7 @@ func (s *httpdServer) handleWebUpdateFolderPost(w http.ResponseWriter, r *http.R
defer r.MultipartForm.RemoveAll() //nolint:errcheck

ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@@ -3599,7 +3699,7 @@ func getAllFolders(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) handleWebGetFolders(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

- data := s.getBasePageData(util.I18nFoldersTitle, webFoldersPath, r)
+ data := s.getBasePageData(util.I18nFoldersTitle, webFoldersPath, w, r)
renderAdminTemplate(w, templateFolders, data)
}

@@ -3637,7 +3737,7 @@ func getAllGroups(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) handleWebGetGroups(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

- data := s.getBasePageData(util.I18nGroupsTitle, webGroupsPath, r)
+ data := s.getBasePageData(util.I18nGroupsTitle, webGroupsPath, w, r)
renderAdminTemplate(w, templateGroups, data)
}

@@ -3659,7 +3759,7 @@ func (s *httpdServer) handleWebAddGroupPost(w http.ResponseWriter, r *http.Reque
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@@ -3706,7 +3806,7 @@ func (s *httpdServer) handleWebUpdateGroupPost(w http.ResponseWriter, r *http.Re
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@@ -3759,7 +3859,7 @@ func getAllActions(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) handleWebGetEventActions(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

- data := s.getBasePageData(util.I18nActionsTitle, webAdminEventActionsPath, r)
+ data := s.getBasePageData(util.I18nActionsTitle, webAdminEventActionsPath, w, r)
renderAdminTemplate(w, templateEventActions, data)
}

@@ -3784,7 +3884,7 @@ func (s *httpdServer) handleWebAddEventActionPost(w http.ResponseWriter, r *http
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@@ -3830,7 +3930,7 @@ func (s *httpdServer) handleWebUpdateEventActionPost(w http.ResponseWriter, r *h
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@@ -3869,7 +3969,7 @@ func getAllRules(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) handleWebGetEventRules(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

- data := s.getBasePageData(util.I18nRulesTitle, webAdminEventRulesPath, r)
+ data := s.getBasePageData(util.I18nRulesTitle, webAdminEventRulesPath, w, r)
renderAdminTemplate(w, templateEventRules, data)
}

@@ -3895,7 +3995,7 @@ func (s *httpdServer) handleWebAddEventRulePost(w http.ResponseWriter, r *http.R
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- err = verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr)
+ err = verifyCSRFToken(r, s.csrfTokenAuth)
if err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
@@ -3942,7 +4042,7 @@ func (s *httpdServer) handleWebUpdateEventRulePost(w http.ResponseWriter, r *htt
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@@ -3989,7 +4089,7 @@ func getAllRoles(w http.ResponseWriter, r *http.Request) {

func (s *httpdServer) handleWebGetRoles(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
- data := s.getBasePageData(util.I18nRolesTitle, webAdminRolesPath, r)
+ data := s.getBasePageData(util.I18nRolesTitle, webAdminRolesPath, w, r)

renderAdminTemplate(w, templateRoles, data)
}
@@ -4012,7 +4112,7 @@ func (s *httpdServer) handleWebAddRolePost(w http.ResponseWriter, r *http.Reques
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@@ -4058,7 +4158,7 @@ func (s *httpdServer) handleWebUpdateRolePost(w http.ResponseWriter, r *http.Req
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@@ -4076,7 +4176,7 @@ func (s *httpdServer) handleWebGetEvents(w http.ResponseWriter, r *http.Request)
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

data := eventsPage{
- basePage: s.getBasePageData(util.I18nEventsTitle, webEventsPath, r),
+ basePage: s.getBasePageData(util.I18nEventsTitle, webEventsPath, w, r),
FsEventsSearchURL: webEventsFsSearchPath,
ProviderEventsSearchURL: webEventsProviderSearchPath,
LogEventsSearchURL: webEventsLogSearchPath,
@@ -4088,7 +4188,7 @@ func (s *httpdServer) handleWebIPListsPage(w http.ResponseWriter, r *http.Reques
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
rtlStatus, rtlProtocols := common.Config.GetRateLimitersStatus()
data := ipListsPage{
- basePage: s.getBasePageData(util.I18nIPListsTitle, webIPListsPath, r),
+ basePage: s.getBasePageData(util.I18nIPListsTitle, webIPListsPath, w, r),
RateLimitersStatus: rtlStatus,
RateLimitersProtocols: strings.Join(rtlProtocols, ", "),
IsAllowListEnabled: common.Config.IsAllowListEnabled(),
@@ -4126,7 +4226,7 @@ func (s *httpdServer) handleWebAddIPListEntryPost(w http.ResponseWriter, r *http
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@@ -4181,7 +4281,7 @@ func (s *httpdServer) handleWebUpdateIPListEntryPost(w http.ResponseWriter, r *h
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@@ -4217,13 +4317,15 @@ func (s *httpdServer) handleWebConfigsPost(w http.ResponseWriter, r *http.Reques
s.renderInternalServerErrorPage(w, r, err)
return
}
- err = r.ParseForm()
+ err = r.ParseMultipartForm(maxRequestSize)
if err != nil {
s.renderBadRequestPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return
}
+ defer r.MultipartForm.RemoveAll() //nolint:errcheck

ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
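The configs handler now parses the body as a multipart form (needed for the branding image uploads added below) instead of a plain URL-encoded form. A minimal sketch of that standard-library pattern with a capped request size; the limit, route and address are illustrative, not SFTPGo's values:

```go
package main

import (
	"fmt"
	"log"
	"net/http"
)

const maxRequestSize = 1 << 20 // illustrative cap

func configsHandler(w http.ResponseWriter, r *http.Request) {
	// Cap the request body, then parse the multipart form; uploaded parts are
	// spooled to memory or temporary files up to the given limit.
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	if err := r.ParseMultipartForm(maxRequestSize); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	// Remove any temporary files created for uploaded parts when done.
	defer r.MultipartForm.RemoveAll() //nolint:errcheck
	fmt.Fprintln(w, r.Form.Get("form_action"))
}

func main() {
	http.HandleFunc("/configs", configsHandler)
	log.Fatal(http.ListenAndServe("127.0.0.1:8080", nil))
}
```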
@@ -4247,6 +4349,15 @@ func (s *httpdServer) handleWebConfigsPost(w http.ResponseWriter, r *http.Reques
smtpConfigs := getSMTPConfigsFromPostFields(r)
updateSMTPSecrets(smtpConfigs, configs.SMTP)
configs.SMTP = smtpConfigs
+ case "branding_submit":
+ configSection = 4
+ brandingConfigs, err := getBrandingConfigFromPostFields(r, configs.Branding)
+ configs.Branding = brandingConfigs
+ if err != nil {
+ logger.Info(logSender, "", "unable to get branding config: %v", err)
+ s.renderConfigsPage(w, r, configs, err, configSection)
+ return
+ }
default:
s.renderBadRequestPage(w, r, errors.New("unsupported form action"))
return
@@ -4257,15 +4368,22 @@ func (s *httpdServer) handleWebConfigsPost(w http.ResponseWriter, r *http.Reques
s.renderConfigsPage(w, r, configs, err, configSection)
return
}
- if configSection == 3 {
+ postConfigsUpdate(configSection, configs)
+ s.renderMessagePage(w, r, util.I18nConfigsTitle, http.StatusOK, nil, util.I18nConfigsOK)
+ }

+ func postConfigsUpdate(section int, configs dataprovider.Configs) {
+ switch section {
+ case 3:
err := configs.SMTP.TryDecrypt()
if err == nil {
smtp.Activate(configs.SMTP)
} else {
logger.Error(logSender, "", "unable to decrypt SMTP configuration, cannot activate configuration: %v", err)
}
+ case 4:
+ dbBrandingConfig.Set(configs.Branding)
}
- s.renderMessagePage(w, r, util.I18nConfigsTitle, http.StatusOK, nil, util.I18nConfigsOK)
}

func (s *httpdServer) handleOAuth2TokenRedirect(w http.ResponseWriter, r *http.Request) {
@@ -4273,7 +4391,7 @@ func (s *httpdServer) handleOAuth2TokenRedirect(w http.ResponseWriter, r *http.R

stateToken := r.URL.Query().Get("state")

- state, err := verifyOAuth2Token(stateToken, util.GetIPFromRemoteAddress(r.RemoteAddr))
+ state, err := verifyOAuth2Token(s.csrfTokenAuth, stateToken, util.GetIPFromRemoteAddress(r.RemoteAddr))
if err != nil {
s.renderMessagePage(w, r, util.I18nOAuth2ErrorTitle, http.StatusBadRequest, err, "")
return
@@ -27,6 +27,7 @@ import (
"os"
"path"
"path/filepath"
+ "slices"
"strconv"
"strings"
"time"
@@ -81,21 +82,23 @@ func isZeroTime(t time.Time) bool {

type baseClientPage struct {
commonBasePage
- Title string
- CurrentURL string
- FilesURL string
- SharesURL string
- ShareURL string
- ProfileURL string
- PingURL string
- ChangePwdURL string
- LogoutURL string
- LoginURL string
- EditURL string
- MFAURL string
- CSRFToken string
- LoggedUser *dataprovider.User
- Branding UIBranding
+ Title string
+ CurrentURL string
+ FilesURL string
+ SharesURL string
+ ShareURL string
+ ProfileURL string
+ PingURL string
+ ChangePwdURL string
+ LogoutURL string
+ LoginURL string
+ EditURL string
+ MFAURL string
+ CSRFToken string
+ LoggedUser *dataprovider.User
+ IsLoggedToShare bool
+ Branding UIBranding
+ Languages []string
}

type dirMapping struct {
@@ -105,9 +108,10 @@ type dirMapping struct {

type viewPDFPage struct {
commonBasePage
- Title string
- URL string
- Branding UIBranding
+ Title string
+ URL string
+ Branding UIBranding
+ Languages []string
}

type editFilePage struct {
@@ -150,6 +154,7 @@ type shareLoginPage struct {
CSRFToken string
Title string
Branding UIBranding
+ Languages []string
}

type shareDownloadPage struct {
@@ -172,13 +177,15 @@ type clientMessagePage struct {

type clientProfilePage struct {
baseClientPage
- PublicKeys []string
- TLSCerts []string
- CanSubmit bool
- AllowAPIKeyAuth bool
- Email string
- Description string
- Error *util.I18nError
+ PublicKeys []string
+ TLSCerts []string
+ CanSubmit bool
+ AllowAPIKeyAuth bool
+ Email string
+ AdditionalEmails []string
+ AdditionalEmailsString string
+ Description string
+ Error *util.I18nError
}

type changeClientPasswordPage struct {
@@ -523,28 +530,30 @@ func loadClientTemplates(templatesPath string) {
clientTemplates[templateShareDownload] = shareDownloadTmpl
}

- func (s *httpdServer) getBaseClientPageData(title, currentURL string, r *http.Request) baseClientPage {
+ func (s *httpdServer) getBaseClientPageData(title, currentURL string, w http.ResponseWriter, r *http.Request) baseClientPage {
var csrfToken string
if currentURL != "" {
- csrfToken = createCSRFToken(util.GetIPFromRemoteAddress(r.RemoteAddr))
+ csrfToken = createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath)
}

data := baseClientPage{
- commonBasePage: getCommonBasePage(r),
- Title: title,
- CurrentURL: currentURL,
- FilesURL: webClientFilesPath,
- SharesURL: webClientSharesPath,
- ShareURL: webClientSharePath,
- ProfileURL: webClientProfilePath,
- PingURL: webClientPingPath,
- ChangePwdURL: webChangeClientPwdPath,
- LogoutURL: webClientLogoutPath,
- EditURL: webClientEditFilePath,
- MFAURL: webClientMFAPath,
- CSRFToken: csrfToken,
- LoggedUser: getUserFromToken(r),
- Branding: s.binding.Branding.WebClient,
+ commonBasePage: getCommonBasePage(r),
+ Title: title,
+ CurrentURL: currentURL,
+ FilesURL: webClientFilesPath,
+ SharesURL: webClientSharesPath,
+ ShareURL: webClientSharePath,
+ ProfileURL: webClientProfilePath,
+ PingURL: webClientPingPath,
+ ChangePwdURL: webChangeClientPwdPath,
+ LogoutURL: webClientLogoutPath,
+ EditURL: webClientEditFilePath,
+ MFAURL: webClientMFAPath,
+ CSRFToken: csrfToken,
+ LoggedUser: getUserFromToken(r),
+ IsLoggedToShare: false,
+ Branding: s.binding.webClientBranding(),
+ Languages: s.binding.languages(),
}
if !strings.HasPrefix(r.RequestURI, webClientPubSharesPath) {
data.LoginURL = webClientLoginPath
@@ -552,40 +561,43 @@ func (s *httpdServer) getBaseClientPageData(title, currentURL string, r *http.Re
return data
}

- func (s *httpdServer) renderClientForgotPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) {
+ func (s *httpdServer) renderClientForgotPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := forgotPwdPage{
commonBasePage: getCommonBasePage(r),
CurrentURL: webClientForgotPwdPath,
Error: err,
- CSRFToken: createCSRFToken(ip),
+ CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseClientPath),
LoginURL: webClientLoginPath,
Title: util.I18nForgotPwdTitle,
- Branding: s.binding.Branding.WebClient,
+ Branding: s.binding.webClientBranding(),
+ Languages: s.binding.languages(),
}
renderClientTemplate(w, templateForgotPassword, data)
}

- func (s *httpdServer) renderClientResetPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) {
+ func (s *httpdServer) renderClientResetPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := resetPwdPage{
commonBasePage: getCommonBasePage(r),
CurrentURL: webClientResetPwdPath,
Error: err,
- CSRFToken: createCSRFToken(ip),
+ CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath),
LoginURL: webClientLoginPath,
Title: util.I18nResetPwdTitle,
- Branding: s.binding.Branding.WebClient,
+ Branding: s.binding.webClientBranding(),
+ Languages: s.binding.languages(),
}
renderClientTemplate(w, templateResetPassword, data)
}

- func (s *httpdServer) renderShareLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) {
+ func (s *httpdServer) renderShareLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := shareLoginPage{
commonBasePage: getCommonBasePage(r),
Title: util.I18nShareLoginTitle,
CurrentURL: r.RequestURI,
Error: err,
- CSRFToken: createCSRFToken(ip),
- Branding: s.binding.Branding.WebClient,
+ CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseClientPath),
+ Branding: s.binding.webClientBranding(),
+ Languages: s.binding.languages(),
}
renderClientTemplate(w, templateShareLogin, data)
}
@@ -599,7 +611,7 @@ func renderClientTemplate(w http.ResponseWriter, tmplName string, data any) {

func (s *httpdServer) renderClientMessagePage(w http.ResponseWriter, r *http.Request, title string, statusCode int, err error, message string) {
data := clientMessagePage{
- baseClientPage: s.getBaseClientPageData(title, "", r),
+ baseClientPage: s.getBaseClientPageData(title, "", w, r),
Error: getI18nError(err),
Success: message,
}
@@ -627,15 +639,16 @@ func (s *httpdServer) renderClientNotFoundPage(w http.ResponseWriter, r *http.Re
util.NewI18nError(err, util.I18nError404Message), "")
}

- func (s *httpdServer) renderClientTwoFactorPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) {
+ func (s *httpdServer) renderClientTwoFactorPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := twoFactorPage{
commonBasePage: getCommonBasePage(r),
Title: pageTwoFactorTitle,
CurrentURL: webClientTwoFactorPath,
Error: err,
- CSRFToken: createCSRFToken(ip),
+ CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath),
RecoveryURL: webClientTwoFactorRecoveryPath,
- Branding: s.binding.Branding.WebClient,
+ Branding: s.binding.webClientBranding(),
+ Languages: s.binding.languages(),
}
if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) {
data.CurrentURL += "?next=" + url.QueryEscape(next)
@@ -643,21 +656,22 @@ func (s *httpdServer) renderClientTwoFactorPage(w http.ResponseWriter, r *http.R
renderClientTemplate(w, templateTwoFactor, data)
}

- func (s *httpdServer) renderClientTwoFactorRecoveryPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) {
+ func (s *httpdServer) renderClientTwoFactorRecoveryPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := twoFactorPage{
commonBasePage: getCommonBasePage(r),
Title: pageTwoFactorRecoveryTitle,
CurrentURL: webClientTwoFactorRecoveryPath,
Error: err,
- CSRFToken: createCSRFToken(ip),
- Branding: s.binding.Branding.WebClient,
+ CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath),
+ Branding: s.binding.webClientBranding(),
+ Languages: s.binding.languages(),
}
renderClientTemplate(w, templateTwoFactorRecovery, data)
}

func (s *httpdServer) renderClientMFAPage(w http.ResponseWriter, r *http.Request) {
data := clientMFAPage{
- baseClientPage: s.getBaseClientPageData(util.I18n2FATitle, webClientMFAPath, r),
+ baseClientPage: s.getBaseClientPageData(util.I18n2FATitle, webClientMFAPath, w, r),
TOTPConfigs: mfa.GetAvailableTOTPConfigNames(),
GenerateTOTPURL: webClientTOTPGeneratePath,
ValidateTOTPURL: webClientTOTPValidatePath,
@@ -681,7 +695,7 @@ func (s *httpdServer) renderEditFilePage(w http.ResponseWriter, r *http.Request,
title = util.I18nEditFileTitle
}
data := editFilePage{
- baseClientPage: s.getBaseClientPageData(title, webClientEditFilePath, r),
+ baseClientPage: s.getBaseClientPageData(title, webClientEditFilePath, w, r),
Path: fileName,
Name: path.Base(fileName),
CurrentDir: path.Dir(fileName),
@@ -705,7 +719,7 @@ func (s *httpdServer) renderAddUpdateSharePage(w http.ResponseWriter, r *http.Re
share.Password = redactedSecret
}
data := clientSharePage{
- baseClientPage: s.getBaseClientPageData(title, currentURL, r),
+ baseClientPage: s.getBaseClientPageData(title, currentURL, w, r),
Share: share,
Error: err,
IsAdd: isAdd,
@@ -739,9 +753,11 @@ func (s *httpdServer) renderSharedFilesPage(w http.ResponseWriter, r *http.Reque
err *util.I18nError, share dataprovider.Share,
) {
currentURL := path.Join(webClientPubSharesPath, share.ShareID, "browse")
- baseData := s.getBaseClientPageData(util.I18nSharedFilesTitle, currentURL, r)
+ baseData := s.getBaseClientPageData(util.I18nSharedFilesTitle, currentURL, w, r)
baseData.FilesURL = currentURL
baseSharePath := path.Join(webClientPubSharesPath, share.ShareID)
+ baseData.LogoutURL = path.Join(webClientPubSharesPath, share.ShareID, "logout")
+ baseData.IsLoggedToShare = share.Password != ""

data := filesPage{
baseClientPage: baseData,
@@ -769,28 +785,39 @@ func (s *httpdServer) renderSharedFilesPage(w http.ResponseWriter, r *http.Reque
renderClientTemplate(w, templateClientFiles, data)
}

- func (s *httpdServer) renderShareDownloadPage(w http.ResponseWriter, r *http.Request, downloadLink string) {
+ func (s *httpdServer) renderShareDownloadPage(w http.ResponseWriter, r *http.Request, share *dataprovider.Share,
+ downloadLink string,
+ ) {
data := shareDownloadPage{
- baseClientPage: s.getBaseClientPageData(util.I18nShareDownloadTitle, "", r),
+ baseClientPage: s.getBaseClientPageData(util.I18nShareDownloadTitle, "", w, r),
DownloadLink: downloadLink,
}
+ data.LogoutURL = ""
+ if share.Password != "" {
+ data.LogoutURL = path.Join(webClientPubSharesPath, share.ShareID, "logout")
+ }

renderClientTemplate(w, templateShareDownload, data)
}

- func (s *httpdServer) renderUploadToSharePage(w http.ResponseWriter, r *http.Request, share dataprovider.Share) {
+ func (s *httpdServer) renderUploadToSharePage(w http.ResponseWriter, r *http.Request, share *dataprovider.Share) {
currentURL := path.Join(webClientPubSharesPath, share.ShareID, "upload")
data := shareUploadPage{
- baseClientPage: s.getBaseClientPageData(util.I18nShareUploadTitle, currentURL, r),
- Share: &share,
+ baseClientPage: s.getBaseClientPageData(util.I18nShareUploadTitle, currentURL, w, r),
+ Share: share,
UploadBasePath: path.Join(webClientPubSharesPath, share.ShareID),
}
+ data.LogoutURL = ""
+ if share.Password != "" {
+ data.LogoutURL = path.Join(webClientPubSharesPath, share.ShareID, "logout")
+ }
renderClientTemplate(w, templateUploadToShare, data)
}

func (s *httpdServer) renderFilesPage(w http.ResponseWriter, r *http.Request, dirName string,
err *util.I18nError, user *dataprovider.User) {
data := filesPage{
- baseClientPage: s.getBaseClientPageData(util.I18nFilesTitle, webClientFilesPath, r),
+ baseClientPage: s.getBaseClientPageData(util.I18nFilesTitle, webClientFilesPath, w, r),
Error: err,
CurrentDir: url.QueryEscape(dirName),
DownloadURL: webClientDownloadZipPath,
@@ -816,7 +843,7 @@ func (s *httpdServer) renderFilesPage(w http.ResponseWriter, r *http.Request, di

func (s *httpdServer) renderClientProfilePage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := clientProfilePage{
- baseClientPage: s.getBaseClientPageData(util.I18nProfileTitle, webClientProfilePath, r),
+ baseClientPage: s.getBaseClientPageData(util.I18nProfileTitle, webClientProfilePath, w, r),
Error: err,
}
user, userMerged, errUser := dataprovider.GetUserVariants(data.LoggedUser.Username, "")
@@ -828,6 +855,8 @@ func (s *httpdServer) renderClientProfilePage(w http.ResponseWriter, r *http.Req
data.TLSCerts = user.Filters.TLSCerts
data.AllowAPIKeyAuth = user.Filters.AllowAPIKeyAuth
data.Email = user.Email
+ data.AdditionalEmails = user.Filters.AdditionalEmails
+ data.AdditionalEmailsString = strings.Join(data.AdditionalEmails, ", ")
data.Description = user.Description
data.CanSubmit = userMerged.CanUpdateProfile()
renderClientTemplate(w, templateClientProfile, data)
@@ -835,7 +864,7 @@ func (s *httpdServer) renderClientProfilePage(w http.ResponseWriter, r *http.Req

func (s *httpdServer) renderClientChangePasswordPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := changeClientPasswordPage{
- baseClientPage: s.getBaseClientPageData(util.I18nChangePwdTitle, webChangeClientPwdPath, r),
+ baseClientPage: s.getBaseClientPageData(util.I18nChangePwdTitle, webChangeClientPwdPath, w, r),
Error: err,
}

@@ -853,8 +882,7 @@ func (s *httpdServer) handleWebClientDownloadZip(w http.ResponseWriter, r *http.
s.renderClientBadRequestPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return
}
- ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@@ -1026,7 +1054,7 @@ func (s *httpdServer) handleClientUploadToShare(w http.ResponseWriter, r *http.R
http.Redirect(w, r, path.Join(webClientPubSharesPath, share.ShareID, "browse"), http.StatusFound)
return
}
- s.renderUploadToSharePage(w, r, share)
+ s.renderUploadToSharePage(w, r, &share)
}

func (s *httpdServer) handleShareGetFiles(w http.ResponseWriter, r *http.Request) {
@@ -1091,7 +1119,8 @@ func (s *httpdServer) handleShareViewPDF(w http.ResponseWriter, r *http.Request)
Title: path.Base(name),
URL: fmt.Sprintf("%s?path=%s&_=%d", path.Join(webClientPubSharesPath, share.ShareID, "getpdf"),
url.QueryEscape(name), time.Now().UTC().Unix()),
- Branding: s.binding.Branding.WebClient,
+ Branding: s.binding.webClientBranding(),
+ Languages: s.binding.languages(),
}
renderClientTemplate(w, templateClientViewPDF, data)
}
@@ -1442,7 +1471,7 @@ func (s *httpdServer) handleClientAddSharePost(w http.ResponseWriter, r *http.Re
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@@ -1451,7 +1480,7 @@ func (s *httpdServer) handleClientAddSharePost(w http.ResponseWriter, r *http.Re
share.LastUseAt = 0
share.Username = claims.Username
if share.Password == "" {
- if util.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) {
+ if slices.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) {
s.renderAddUpdateSharePage(w, r, share,
util.NewI18nError(util.NewValidationError("You are not allowed to share files/folders without password"), util.I18nErrorShareNoPwd),
true)
@@ -1510,7 +1539,7 @@ func (s *httpdServer) handleClientUpdateSharePost(w http.ResponseWriter, r *http
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@@ -1520,7 +1549,7 @@ func (s *httpdServer) handleClientUpdateSharePost(w http.ResponseWriter, r *http
updatedShare.Password = share.Password
}
if updatedShare.Password == "" {
- if util.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) {
+ if slices.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) {
s.renderAddUpdateSharePage(w, r, updatedShare,
util.NewI18nError(util.NewValidationError("You are not allowed to share files/folders without password"), util.I18nErrorShareNoPwd),
false)
@@ -1581,7 +1610,7 @@ func (s *httpdServer) handleClientGetShares(w http.ResponseWriter, r *http.Reque
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

data := clientSharesPage{
- baseClientPage: s.getBaseClientPageData(util.I18nSharesTitle, webClientSharesPath, r),
+ baseClientPage: s.getBaseClientPageData(util.I18nSharesTitle, webClientSharesPath, w, r),
BasePublicSharesURL: webClientPubSharesPath,
}
renderClientTemplate(w, templateClientShares, data)
@@ -1605,7 +1634,7 @@ func (s *httpdServer) handleWebClientProfilePost(w http.ResponseWriter, r *http.
return
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
@@ -1648,6 +1677,15 @@ func (s *httpdServer) handleWebClientProfilePost(w http.ResponseWriter, r *http.
if userMerged.CanChangeInfo() {
user.Email = strings.TrimSpace(r.Form.Get("email"))
user.Description = r.Form.Get("description")
+ for k := range r.Form {
+ if hasPrefixAndSuffix(k, "additional_emails[", "][additional_email]") {
+ email := strings.TrimSpace(r.Form.Get(k))
+ if email != "" {
+ r.Form.Add("additional_emails", email)
+ }
+ }
+ }
+ user.Filters.AdditionalEmails = r.Form["additional_emails"]
}
err = dataprovider.UpdateUser(&user, dataprovider.ActionExecutorSelf, ipAddr, user.Role)
if err != nil {
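The hunk above collects the repeated "additional_emails[N][additional_email]" fields produced by the profile form's repeater widget into a single slice. A self-contained sketch of the same collection logic over url.Values; hasPrefixAndSuffix here is a local stand-in for the project helper, and the keys and addresses are illustrative:

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// hasPrefixAndSuffix is a local stand-in for the helper used above.
func hasPrefixAndSuffix(s, prefix, suffix string) bool {
	return strings.HasPrefix(s, prefix) && strings.HasSuffix(s, suffix)
}

func main() {
	form := url.Values{
		"additional_emails[0][additional_email]": {" a@example.com "},
		"additional_emails[1][additional_email]": {"b@example.com"},
		"additional_emails[2][additional_email]": {"   "},
	}
	var emails []string
	for k := range form {
		if hasPrefixAndSuffix(k, "additional_emails[", "][additional_email]") {
			// Trim whitespace and drop empty repeater rows, as the handler does.
			if email := strings.TrimSpace(form.Get(k)); email != "" {
				emails = append(emails, email)
			}
		}
	}
	fmt.Println(emails) // order depends on map iteration, as with r.Form
}
```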
@@ -1664,12 +1702,12 @@ func (s *httpdServer) handleWebClientMFA(w http.ResponseWriter, r *http.Request)

func (s *httpdServer) handleWebClientTwoFactor(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
- s.renderClientTwoFactorPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr))
+ s.renderClientTwoFactorPage(w, r, nil)
}

func (s *httpdServer) handleWebClientTwoFactorRecovery(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
- s.renderClientTwoFactorRecoveryPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr))
+ s.renderClientTwoFactorRecoveryPage(w, r, nil)
}

func getShareFromPostFields(r *http.Request) (*dataprovider.Share, error) {
@@ -1721,26 +1759,25 @@ func (s *httpdServer) handleWebClientForgotPwd(w http.ResponseWriter, r *http.Re
s.renderClientNotFoundPage(w, r, errors.New("this page does not exist"))
return
}
- s.renderClientForgotPwdPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr))
+ s.renderClientForgotPwdPage(w, r, nil)
}

func (s *httpdServer) handleWebClientForgotPwdPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

- ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
err := r.ParseForm()
if err != nil {
- s.renderClientForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr)
+ s.renderClientForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return
}
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
+ if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
username := strings.TrimSpace(r.Form.Get("username"))
err = handleForgotPassword(r, username, false)
if err != nil {
- s.renderClientForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorPwdResetGeneric), ipAddr)
+ s.renderClientForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorPwdResetGeneric))
return
}
http.Redirect(w, r, webClientResetPwdPath, http.StatusFound)
@@ -1752,7 +1789,7 @@ func (s *httpdServer) handleWebClientPasswordReset(w http.ResponseWriter, r *htt
s.renderClientNotFoundPage(w, r, errors.New("this page does not exist"))
return
}
- s.renderClientResetPwdPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr))
+ s.renderClientResetPwdPage(w, r, nil)
}

func (s *httpdServer) handleClientViewPDF(w http.ResponseWriter, r *http.Request) {
@@ -1767,7 +1804,8 @@ func (s *httpdServer) handleClientViewPDF(w http.ResponseWriter, r *http.Request
commonBasePage: getCommonBasePage(r),
Title: path.Base(name),
URL: fmt.Sprintf("%s?path=%s&_=%d", webClientGetPDFPath, url.QueryEscape(name), time.Now().UTC().Unix()),
- Branding: s.binding.Branding.WebClient,
+ Branding: s.binding.webClientBranding(),
+ Languages: s.binding.languages(),
}
renderClientTemplate(w, templateClientViewPDF, data)
}
@@ -1855,43 +1893,46 @@ func (s *httpdServer) ensurePDF(w http.ResponseWriter, r *http.Request, name str

func (s *httpdServer) handleClientShareLoginGet(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
- s.renderShareLoginPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr))
+ s.renderShareLoginPage(w, r, nil)
}

func (s *httpdServer) handleClientShareLoginPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := r.ParseForm(); err != nil {
- s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr)
+ s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return
}
- if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
- s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr)
+ if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
+ s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
+ invalidateToken(r)
shareID := getURLParam(r, "id")
share, err := dataprovider.ShareExists(shareID, "")
if err != nil {
- s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials), ipAddr)
+ s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials))
return
}
match, err := share.CheckCredentials(strings.TrimSpace(r.Form.Get("share_password")))
if !match || err != nil {
- s.renderShareLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials),
- ipAddr)
- return
- }
- c := jwtTokenClaims{
- Username: shareID,
- }
- err = c.createAndSetCookie(w, r, s.tokenAuth, tokenAudienceWebShare, ipAddr)
- if err != nil {
- s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nError500Message), ipAddr)
+ s.renderShareLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return
}
next := path.Clean(r.URL.Query().Get("next"))
baseShareURL := path.Join(webClientPubSharesPath, share.ShareID)
isRedirect, redirectTo := checkShareRedirectURL(next, baseShareURL)
+ c := jwtTokenClaims{
+ Username: shareID,
+ }
+ if isRedirect {
+ c.Ref = next
+ }
+ err = c.createAndSetCookie(w, r, s.tokenAuth, tokenAudienceWebShare, ipAddr)
+ if err != nil {
+ s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nError500Message))
+ return
+ }
if isRedirect {
http.Redirect(w, r, redirectTo, http.StatusFound)
return
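After a successful share login the handler now remembers the requested "next" target in the token claims and only redirects when checkShareRedirectURL accepts it. That helper's body is not in this excerpt; as an illustrative guess at the kind of guard such a function performs (accept only cleaned paths that stay under the share's own base URL), not the actual implementation:

```go
package main

import (
	"fmt"
	"path"
	"strings"
)

// safeShareRedirect sketches an open-redirect guard: the target must be a
// cleaned path under the share's base URL. The real checkShareRedirectURL
// in SFTPGo may behave differently; base and targets below are illustrative.
func safeShareRedirect(next, base string) (bool, string) {
	next = path.Clean(next)
	if next == base || strings.HasPrefix(next, base+"/") {
		return true, next
	}
	return false, ""
}

func main() {
	base := "/web/client/pubshares/abc123"
	fmt.Println(safeShareRedirect("/web/client/pubshares/abc123/browse", base)) // allowed
	fmt.Println(safeShareRedirect("https://evil.example.com/", base))           // rejected
	fmt.Println(safeShareRedirect("/web/admin", base))                          // rejected
}
```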
@@ -1899,6 +1940,22 @@ func (s *httpdServer) handleClientShareLoginPost(w http.ResponseWriter, r *http.
s.renderClientMessagePage(w, r, util.I18nSharedFilesTitle, http.StatusOK, nil, util.I18nShareLoginOK)
}

+ func (s *httpdServer) handleClientShareLogout(w http.ResponseWriter, r *http.Request) {
+ r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)

+ shareID := getURLParam(r, "id")
+ ctx, claims, err := s.getShareClaims(r, shareID)
+ if err != nil {
+ s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, http.StatusForbidden,
+ util.NewI18nError(err, util.I18nErrorInvalidToken), "")
+ return
+ }
+ removeCookie(w, r.WithContext(ctx), webBaseClientPath)

+ redirectURL := path.Join(webClientPubSharesPath, shareID, fmt.Sprintf("login?next=%s", url.QueryEscape(claims.Ref)))
+ http.Redirect(w, r, redirectURL, http.StatusFound)
+ }

func (s *httpdServer) handleClientSharedFile(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeRead}
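The new handleClientShareLogout clears the share session cookie via removeCookie and bounces back to the share login page. The cookie helper itself is outside this excerpt; a generic illustration of how a session cookie is typically dropped with the standard library (cookie name and path are assumptions, not SFTPGo's values):

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// clearCookie shows the usual way to delete a cookie: overwrite it with an
// empty value and a negative MaxAge so the browser discards it immediately.
func clearCookie(w http.ResponseWriter, name, basePath string) {
	http.SetCookie(w, &http.Cookie{
		Name:     name,
		Value:    "",
		Path:     basePath,
		MaxAge:   -1,
		HttpOnly: true,
	})
}

func main() {
	rec := httptest.NewRecorder()
	clearCookie(rec, "jwt", "/web/client")
	fmt.Println(rec.Header().Get("Set-Cookie"))
}
```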
@ -1910,7 +1967,7 @@ func (s *httpdServer) handleClientSharedFile(w http.ResponseWriter, r *http.Requ
|
|||
if r.URL.RawQuery != "" {
|
||||
query = "?" + r.URL.RawQuery
|
||||
}
|
||||
s.renderShareDownloadPage(w, r, path.Join(webClientPubSharesPath, share.ShareID)+query)
|
||||
s.renderShareDownloadPage(w, r, &share, path.Join(webClientPubSharesPath, share.ShareID)+query)
|
||||
}
|
||||
|
||||
func (s *httpdServer) handleClientCheckExist(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -1985,7 +2042,7 @@ func doCheckExist(w http.ResponseWriter, r *http.Request, connection *Connection
|
|||
}
|
||||
existing := make([]map[string]any, 0)
|
||||
for _, info := range contents {
|
||||
if util.Contains(filesList.Files, info.Name()) {
|
||||
if slices.Contains(filesList.Files, info.Name()) {
|
||||
res := make(map[string]any)
|
||||
res["name"] = info.Name()
|
||||
if info.IsDir() {
|
||||
|
|
|
@ -37,7 +37,6 @@ import (
|
|||
"github.com/drakkan/sftpgo/v2/internal/httpclient"
|
||||
"github.com/drakkan/sftpgo/v2/internal/httpd"
|
||||
"github.com/drakkan/sftpgo/v2/internal/kms"
|
||||
"github.com/drakkan/sftpgo/v2/internal/util"
|
||||
"github.com/drakkan/sftpgo/v2/internal/version"
|
||||
"github.com/drakkan/sftpgo/v2/internal/vfs"
|
||||
)
|
||||
|
@ -1680,7 +1679,7 @@ func checkEventConditionOptions(expected, actual dataprovider.ConditionOptions)
|
|||
return errors.New("condition protocols mismatch")
|
||||
}
|
||||
for _, v := range expected.Protocols {
|
||||
if !util.Contains(actual.Protocols, v) {
|
||||
if !slices.Contains(actual.Protocols, v) {
|
||||
return errors.New("condition protocols content mismatch")
|
||||
}
|
||||
}
|
||||
|
@ -1696,7 +1695,7 @@ func checkEventConditionOptions(expected, actual dataprovider.ConditionOptions)
|
|||
return errors.New("condition provider objects mismatch")
|
||||
}
|
||||
for _, v := range expected.ProviderObjects {
|
||||
if !util.Contains(actual.ProviderObjects, v) {
|
||||
if !slices.Contains(actual.ProviderObjects, v) {
|
||||
return errors.New("condition provider objects content mismatch")
|
||||
}
|
||||
}
|
||||
|
@ -1714,7 +1713,7 @@ func checkEventConditions(expected, actual dataprovider.EventConditions) error {
|
|||
return errors.New("fs events mismatch")
|
||||
}
|
||||
for _, v := range expected.FsEvents {
|
||||
if !util.Contains(actual.FsEvents, v) {
|
||||
if !slices.Contains(actual.FsEvents, v) {
|
||||
return errors.New("fs events content mismatch")
|
||||
}
|
||||
}
|
||||
|
@ -1722,7 +1721,7 @@ func checkEventConditions(expected, actual dataprovider.EventConditions) error {
|
|||
return errors.New("provider events mismatch")
|
||||
}
|
||||
for _, v := range expected.ProviderEvents {
|
||||
if !util.Contains(actual.ProviderEvents, v) {
|
||||
if !slices.Contains(actual.ProviderEvents, v) {
|
||||
return errors.New("provider events content mismatch")
|
||||
}
|
||||
}
|
||||
|
@ -1957,7 +1956,7 @@ func checkAdmin(expected, actual *dataprovider.Admin) error {
|
|||
return errors.New("permissions mismatch")
|
||||
}
|
||||
for _, p := range expected.Permissions {
|
||||
if !util.Contains(actual.Permissions, p) {
|
||||
if !slices.Contains(actual.Permissions, p) {
|
||||
return errors.New("permissions content mismatch")
|
||||
}
|
||||
}
|
||||
|
@ -1975,7 +1974,7 @@ func compareAdminFilters(expected, actual dataprovider.AdminFilters) error {
|
|||
return errors.New("allow list mismatch")
|
||||
}
|
||||
for _, v := range expected.AllowList {
|
||||
if !util.Contains(actual.AllowList, v) {
|
||||
if !slices.Contains(actual.AllowList, v) {
|
||||
return errors.New("allow list content mismatch")
|
||||
}
|
||||
}
|
||||
|
@@ -2038,6 +2037,9 @@ func checkUser(expected *dataprovider.User, actual *dataprovider.User) error {
	if expected.Email != actual.Email {
		return errors.New("email mismatch")
	}
+	if !slices.Equal(expected.Filters.AdditionalEmails, actual.Filters.AdditionalEmails) {
+		return errors.New("additional emails mismatch")
+	}
	if expected.Filters.RequirePasswordChange != actual.Filters.RequirePasswordChange {
		return errors.New("require_password_change mismatch")
	}
@ -2066,7 +2068,7 @@ func compareUserPermissions(expected map[string][]string, actual map[string][]st
|
|||
for dir, perms := range expected {
|
||||
if actualPerms, ok := actual[dir]; ok {
|
||||
for _, v := range actualPerms {
|
||||
if !util.Contains(perms, v) {
|
||||
if !slices.Contains(perms, v) {
|
||||
return errors.New("permissions contents mismatch")
|
||||
}
|
||||
}
|
||||
|
@@ -2197,6 +2199,9 @@ func compareS3Config(expected *vfs.Filesystem, actual *vfs.Filesystem) error { /
	if err := checkEncryptedSecret(expected.S3Config.AccessSecret, actual.S3Config.AccessSecret); err != nil {
		return fmt.Errorf("fs S3 access secret mismatch: %v", err)
	}
+	if err := checkEncryptedSecret(expected.S3Config.SSECustomerKey, actual.S3Config.SSECustomerKey); err != nil {
+		return fmt.Errorf("fs S3 SSE customer key mismatch: %v", err)
+	}
	if expected.S3Config.Endpoint != actual.S3Config.Endpoint {
		return errors.New("fs S3 endpoint mismatch")
	}
@ -2319,7 +2324,7 @@ func compareSFTPFsConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error
|
|||
return errors.New("SFTPFs fingerprints mismatch")
|
||||
}
|
||||
for _, value := range actual.SFTPConfig.Fingerprints {
|
||||
if !util.Contains(expected.SFTPConfig.Fingerprints, value) {
|
||||
if !slices.Contains(expected.SFTPConfig.Fingerprints, value) {
|
||||
return errors.New("SFTPFs fingerprints mismatch")
|
||||
}
|
||||
}
|
||||
|
@ -2410,27 +2415,27 @@ func checkEncryptedSecret(expected, actual *kms.Secret) error {
|
|||
|
||||
func compareUserFilterSubStructs(expected sdk.BaseUserFilters, actual sdk.BaseUserFilters) error {
|
||||
for _, IPMask := range expected.AllowedIP {
|
||||
if !util.Contains(actual.AllowedIP, IPMask) {
|
||||
if !slices.Contains(actual.AllowedIP, IPMask) {
|
||||
return errors.New("allowed IP contents mismatch")
|
||||
}
|
||||
}
|
||||
for _, IPMask := range expected.DeniedIP {
|
||||
if !util.Contains(actual.DeniedIP, IPMask) {
|
||||
if !slices.Contains(actual.DeniedIP, IPMask) {
|
||||
return errors.New("denied IP contents mismatch")
|
||||
}
|
||||
}
|
||||
for _, method := range expected.DeniedLoginMethods {
|
||||
if !util.Contains(actual.DeniedLoginMethods, method) {
|
||||
if !slices.Contains(actual.DeniedLoginMethods, method) {
|
||||
return errors.New("denied login methods contents mismatch")
|
||||
}
|
||||
}
|
||||
for _, protocol := range expected.DeniedProtocols {
|
||||
if !util.Contains(actual.DeniedProtocols, protocol) {
|
||||
if !slices.Contains(actual.DeniedProtocols, protocol) {
|
||||
return errors.New("denied protocols contents mismatch")
|
||||
}
|
||||
}
|
||||
for _, options := range expected.WebClient {
|
||||
if !util.Contains(actual.WebClient, options) {
|
||||
if !slices.Contains(actual.WebClient, options) {
|
||||
return errors.New("web client options contents mismatch")
|
||||
}
|
||||
}
|
||||
|
@ -2439,7 +2444,7 @@ func compareUserFilterSubStructs(expected sdk.BaseUserFilters, actual sdk.BaseUs
|
|||
return errors.New("TLS certs mismatch")
|
||||
}
|
||||
for _, cert := range expected.TLSCerts {
|
||||
if !util.Contains(actual.TLSCerts, cert) {
|
||||
if !slices.Contains(actual.TLSCerts, cert) {
|
||||
return errors.New("TLS certs content mismatch")
|
||||
}
|
||||
}
|
||||
|
@ -2536,7 +2541,7 @@ func checkFilterMatch(expected []string, actual []string) bool {
|
|||
return false
|
||||
}
|
||||
for _, e := range expected {
|
||||
if !util.Contains(actual, strings.ToLower(e)) {
|
||||
if !slices.Contains(actual, strings.ToLower(e)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
@ -2579,7 +2584,7 @@ func compareUserBandwidthLimitFilters(expected sdk.BaseUserFilters, actual sdk.B
|
|||
return errors.New("bandwidth filters sources mismatch")
|
||||
}
|
||||
for _, source := range actual.BandwidthLimits[idx].Sources {
|
||||
if !util.Contains(l.Sources, source) {
|
||||
if !slices.Contains(l.Sources, source) {
|
||||
return errors.New("bandwidth filters source mismatch")
|
||||
}
|
||||
}
|
||||
|
@@ -2610,9 +2615,28 @@ func compareUserFilePatternsFilters(expected sdk.BaseUserFilters, actual sdk.Bas
	return nil
}

+func compareRenameConfigs(expected, actual []dataprovider.RenameConfig) error {
+	if len(expected) != len(actual) {
+		return errors.New("rename configs mismatch")
+	}
+	for _, ex := range expected {
+		found := false
+		for _, ac := range actual {
+			if ac.Key == ex.Key && ac.Value == ex.Value && ac.UpdateModTime == ex.UpdateModTime {
+				found = true
+				break
+			}
+		}
+		if !found {
+			return errors.New("rename configs mismatch")
+		}
+	}
+	return nil
+}
+
func compareKeyValues(expected, actual []dataprovider.KeyValue) error {
	if len(expected) != len(actual) {
-		return errors.New("kay values mismatch")
+		return errors.New("key values mismatch")
	}
	for _, ex := range expected {
		found := false

@@ -2623,7 +2647,7 @@ func compareKeyValues(expected, actual []dataprovider.KeyValue) error {
			}
		}
		if !found {
-			return errors.New("kay values mismatch")
+			return errors.New("key values mismatch")
		}
	}
	return nil
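compareRenameConfigs above checks that two slices hold the same elements regardless of order. The same idea in isolation, with a hypothetical renameConfig type standing in for the dataprovider struct:

package main

import (
	"errors"
	"fmt"
)

// renameConfig is a stand-in for the dataprovider type used by the test helper above.
type renameConfig struct {
	Key, Value    string
	UpdateModTime bool
}

// equalUnordered reports whether the two slices contain the same elements,
// ignoring order. O(n*m) is fine for small test fixtures.
func equalUnordered(expected, actual []renameConfig) error {
	if len(expected) != len(actual) {
		return errors.New("length mismatch")
	}
	for _, ex := range expected {
		found := false
		for _, ac := range actual {
			if ac == ex {
				found = true
				break
			}
		}
		if !found {
			return fmt.Errorf("missing element %+v", ex)
		}
	}
	return nil
}

func main() {
	a := []renameConfig{{Key: "/old", Value: "/new", UpdateModTime: true}}
	b := []renameConfig{{Key: "/old", Value: "/new", UpdateModTime: true}}
	fmt.Println(equalUnordered(a, b)) // <nil>
}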
@ -2689,7 +2713,7 @@ func compareEventActionEmailConfigFields(expected, actual dataprovider.EventActi
|
|||
return errors.New("email recipients mismatch")
|
||||
}
|
||||
for _, v := range expected.Recipients {
|
||||
if !util.Contains(actual.Recipients, v) {
|
||||
if !slices.Contains(actual.Recipients, v) {
|
||||
return errors.New("email recipients content mismatch")
|
||||
}
|
||||
}
|
||||
|
@ -2697,7 +2721,7 @@ func compareEventActionEmailConfigFields(expected, actual dataprovider.EventActi
|
|||
return errors.New("email bcc mismatch")
|
||||
}
|
||||
for _, v := range expected.Bcc {
|
||||
if !util.Contains(actual.Bcc, v) {
|
||||
if !slices.Contains(actual.Bcc, v) {
|
||||
return errors.New("email bcc content mismatch")
|
||||
}
|
||||
}
|
||||
|
@ -2714,7 +2738,7 @@ func compareEventActionEmailConfigFields(expected, actual dataprovider.EventActi
|
|||
return errors.New("email attachments mismatch")
|
||||
}
|
||||
for _, v := range expected.Attachments {
|
||||
if !util.Contains(actual.Attachments, v) {
|
||||
if !slices.Contains(actual.Attachments, v) {
|
||||
return errors.New("email attachments content mismatch")
|
||||
}
|
||||
}
|
||||
|
@ -2729,7 +2753,7 @@ func compareEventActionFsCompressFields(expected, actual dataprovider.EventActio
|
|||
return errors.New("fs compress paths mismatch")
|
||||
}
|
||||
for _, v := range expected.Paths {
|
||||
if !util.Contains(actual.Paths, v) {
|
||||
if !slices.Contains(actual.Paths, v) {
|
||||
return errors.New("fs compress paths content mismatch")
|
||||
}
|
||||
}
|
||||
|
@@ -2740,7 +2764,7 @@ func compareEventActionFsConfigFields(expected, actual dataprovider.EventActionF
	if expected.Type != actual.Type {
		return errors.New("fs type mismatch")
	}
-	if err := compareKeyValues(expected.Renames, actual.Renames); err != nil {
+	if err := compareRenameConfigs(expected.Renames, actual.Renames); err != nil {
		return errors.New("fs renames mismatch")
	}
	if err := compareKeyValues(expected.Copy, actual.Copy); err != nil {
@ -2750,7 +2774,7 @@ func compareEventActionFsConfigFields(expected, actual dataprovider.EventActionF
|
|||
return errors.New("fs deletes mismatch")
|
||||
}
|
||||
for _, v := range expected.Deletes {
|
||||
if !util.Contains(actual.Deletes, v) {
|
||||
if !slices.Contains(actual.Deletes, v) {
|
||||
return errors.New("fs deletes content mismatch")
|
||||
}
|
||||
}
|
||||
|
@ -2758,7 +2782,7 @@ func compareEventActionFsConfigFields(expected, actual dataprovider.EventActionF
|
|||
return errors.New("fs mkdirs mismatch")
|
||||
}
|
||||
for _, v := range expected.MkDirs {
|
||||
if !util.Contains(actual.MkDirs, v) {
|
||||
if !slices.Contains(actual.MkDirs, v) {
|
||||
return errors.New("fs mkdir content mismatch")
|
||||
}
|
||||
}
|
||||
|
@ -2766,7 +2790,7 @@ func compareEventActionFsConfigFields(expected, actual dataprovider.EventActionF
|
|||
return errors.New("fs exist mismatch")
|
||||
}
|
||||
for _, v := range expected.Exist {
|
||||
if !util.Contains(actual.Exist, v) {
|
||||
if !slices.Contains(actual.Exist, v) {
|
||||
return errors.New("fs exist content mismatch")
|
||||
}
|
||||
}
|
||||
|
@ -2797,7 +2821,7 @@ func compareEventActionCmdConfigFields(expected, actual dataprovider.EventAction
|
|||
return errors.New("cmd args mismatch")
|
||||
}
|
||||
for _, v := range expected.Args {
|
||||
if !util.Contains(actual.Args, v) {
|
||||
if !slices.Contains(actual.Args, v) {
|
||||
return errors.New("cmd args content mismatch")
|
||||
}
|
||||
}
|
||||
|
|
|
@@ -74,7 +74,7 @@ var (
	ErrInvalidSecret = errors.New("invalid secret")
	validSecretStatuses = []string{sdkkms.SecretStatusPlain, sdkkms.SecretStatusAES256GCM, sdkkms.SecretStatusSecretBox,
		sdkkms.SecretStatusVaultTransit, sdkkms.SecretStatusAWS, sdkkms.SecretStatusGCP, sdkkms.SecretStatusAzureKeyVault,
-		"OracleKeyVault", sdkkms.SecretStatusRedacted}
+		sdkkms.SecretStatusOracleKeyVault, sdkkms.SecretStatusRedacted}
	config Configuration
	secretProviders = make(map[string]registeredSecretProvider)
)
@@ -268,6 +268,26 @@ func ConnectionFailedLog(user, ip, loginType, protocol, errorString string) {
		Send()
}

+// LoginLog logs successful logins.
+func LoginLog(user, ip, loginMethod, protocol, connectionID, clientVersion string, encrypted bool, info string) {
+	ev := logger.Info()
+	ev.Timestamp().
+		Str("sender", "login").
+		Str("ip", ip).
+		Str("username", user).
+		Str("method", loginMethod).
+		Str("protocol", protocol)
+	if connectionID != "" {
+		ev.Str("connection_id", connectionID)
+	}
+	ev.Str("client", clientVersion).
+		Bool("encrypted", encrypted)
+	if info != "" {
+		ev.Str("info", info)
+	}
+	ev.Send()
+}
+
func isLogFilePathValid(logFilePath string) bool {
	cleanInput := filepath.Clean(logFilePath)
	if cleanInput == "." || cleanInput == ".." {
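LoginLog builds one structured record with chained typed fields. A stand-alone sketch of the same pattern using github.com/rs/zerolog directly (assumed here; the project wraps zerolog in its own logger package):

package main

import (
	"os"

	"github.com/rs/zerolog"
)

func main() {
	logger := zerolog.New(os.Stdout)

	// One event, many typed fields, emitted with Send().
	ev := logger.Info().Timestamp().
		Str("sender", "login").
		Str("username", "alice").
		Str("protocol", "SSH").
		Bool("encrypted", true)
	// Optional fields can be attached conditionally before sending.
	if connectionID := "c-123"; connectionID != "" {
		ev = ev.Str("connection_id", connectionID)
	}
	ev.Send()
}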
@@ -15,6 +15,7 @@
package logger

import (
+	"crypto/tls"
	"fmt"
	"net"
	"net/http"

@@ -50,17 +51,21 @@ func NewStructuredLogger(logger *zerolog.Logger) func(next http.Handler) http.Ha
// NewLogEntry creates a new log entry for an HTTP request
func (l *StructuredLogger) NewLogEntry(r *http.Request) middleware.LogEntry {
	scheme := "http"
+	cipherSuite := ""
	if r.TLS != nil {
		scheme = "https"
+		cipherSuite = tls.CipherSuiteName(r.TLS.CipherSuite)
	}

	fields := map[string]any{
-		"local_addr": getLocalAddress(r),
-		"remote_addr": r.RemoteAddr,
-		"proto": r.Proto,
-		"method": r.Method,
-		"user_agent": r.UserAgent(),
-		"uri": fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI)}
+		"local_addr":   getLocalAddress(r),
+		"remote_addr":  r.RemoteAddr,
+		"proto":        r.Proto,
+		"method":       r.Method,
+		"user_agent":   r.UserAgent(),
+		"uri":          fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI),
+		"cipher_suite": cipherSuite,
+	}

	reqID := middleware.GetReqID(r.Context())
	if reqID != "" {
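The new cipher_suite field relies on r.TLS being non-nil only for TLS requests and on crypto/tls translating the negotiated suite ID into a readable name. A minimal handler sketch illustrating just that lookup (hypothetical server setup, not the project's middleware):

package main

import (
	"crypto/tls"
	"fmt"
	"log"
	"net/http"
)

func main() {
	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		cipherSuite := ""
		if r.TLS != nil { // nil for plain HTTP requests
			cipherSuite = tls.CipherSuiteName(r.TLS.CipherSuite)
		}
		fmt.Fprintf(w, "cipher suite: %q\n", cipherSuite)
	})
	// A real deployment would use ListenAndServeTLS with a certificate;
	// over plain HTTP the cipher suite field stays empty.
	log.Fatal(http.ListenAndServe(":8080", nil))
}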
@@ -17,6 +17,7 @@ package plugin
import (
	"fmt"
	"path/filepath"
+	"slices"

	"github.com/hashicorp/go-hclog"
	"github.com/hashicorp/go-plugin"

@@ -25,14 +26,13 @@ import (

	"github.com/drakkan/sftpgo/v2/internal/kms"
	"github.com/drakkan/sftpgo/v2/internal/logger"
-	"github.com/drakkan/sftpgo/v2/internal/util"
)

var (
	validKMSSchemes = []string{sdkkms.SchemeAWS, sdkkms.SchemeGCP, sdkkms.SchemeVaultTransit,
-		sdkkms.SchemeAzureKeyVault, "ocikeyvault"}
+		sdkkms.SchemeAzureKeyVault, sdkkms.SchemeOracleKeyVault}
	validKMSEncryptedStatuses = []string{sdkkms.SecretStatusVaultTransit, sdkkms.SecretStatusAWS, sdkkms.SecretStatusGCP,
-		sdkkms.SecretStatusAzureKeyVault, "OracleKeyVault"}
+		sdkkms.SecretStatusAzureKeyVault, sdkkms.SecretStatusOracleKeyVault}
)

// KMSConfig defines configuration parameters for kms plugins

@@ -42,10 +42,10 @@ type KMSConfig struct {
}

func (c *KMSConfig) validate() error {
-	if !util.Contains(validKMSSchemes, c.Scheme) {
+	if !slices.Contains(validKMSSchemes, c.Scheme) {
		return fmt.Errorf("invalid kms scheme: %v", c.Scheme)
	}
-	if !util.Contains(validKMSEncryptedStatuses, c.EncryptedStatus) {
+	if !slices.Contains(validKMSEncryptedStatuses, c.EncryptedStatus) {
		return fmt.Errorf("invalid kms encrypted status: %v", c.EncryptedStatus)
	}
	return nil
@@ -16,6 +16,7 @@ package plugin

import (
	"fmt"
+	"slices"
	"sync"
	"time"


@@ -24,7 +25,6 @@ import (
	"github.com/sftpgo/sdk/plugin/notifier"

	"github.com/drakkan/sftpgo/v2/internal/logger"
-	"github.com/drakkan/sftpgo/v2/internal/util"
)

// NotifierConfig defines configuration parameters for notifiers plugins
@@ -50,97 +50,19 @@ func (c *NotifierConfig) hasActions() bool {
	return false
}

-type eventsQueue struct {
-	sync.RWMutex
+type notifierPlugin struct {
+	config   Config
+	notifier notifier.Notifier
+	client   *plugin.Client
+	mu       sync.RWMutex
	fsEvents       []*notifier.FsEvent
	providerEvents []*notifier.ProviderEvent
	logEvents      []*notifier.LogEvent
}

-func (q *eventsQueue) addFsEvent(event *notifier.FsEvent) {
-	q.Lock()
-	defer q.Unlock()
-
-	q.fsEvents = append(q.fsEvents, event)
-}
-
-func (q *eventsQueue) addProviderEvent(event *notifier.ProviderEvent) {
-	q.Lock()
-	defer q.Unlock()
-
-	q.providerEvents = append(q.providerEvents, event)
-}
-
-func (q *eventsQueue) addLogEvent(event *notifier.LogEvent) {
-	q.Lock()
-	defer q.Unlock()
-
-	q.logEvents = append(q.logEvents, event)
-}
-
-func (q *eventsQueue) popFsEvent() *notifier.FsEvent {
-	q.Lock()
-	defer q.Unlock()
-
-	if len(q.fsEvents) == 0 {
-		return nil
-	}
-	truncLen := len(q.fsEvents) - 1
-	ev := q.fsEvents[truncLen]
-	q.fsEvents[truncLen] = nil
-	q.fsEvents = q.fsEvents[:truncLen]
-
-	return ev
-}
-
-func (q *eventsQueue) popProviderEvent() *notifier.ProviderEvent {
-	q.Lock()
-	defer q.Unlock()
-
-	if len(q.providerEvents) == 0 {
-		return nil
-	}
-	truncLen := len(q.providerEvents) - 1
-	ev := q.providerEvents[truncLen]
-	q.providerEvents[truncLen] = nil
-	q.providerEvents = q.providerEvents[:truncLen]
-
-	return ev
-}
-
-func (q *eventsQueue) popLogEvent() *notifier.LogEvent {
-	q.Lock()
-	defer q.Unlock()
-
-	if len(q.logEvents) == 0 {
-		return nil
-	}
-	truncLen := len(q.logEvents) - 1
-	ev := q.logEvents[truncLen]
-	q.logEvents[truncLen] = nil
-	q.logEvents = q.logEvents[:truncLen]
-
-	return ev
-}
-
-func (q *eventsQueue) getSize() int {
-	q.RLock()
-	defer q.RUnlock()
-
-	return len(q.providerEvents) + len(q.fsEvents) + len(q.logEvents)
-}
-
-type notifierPlugin struct {
-	config   Config
-	notifier notifier.Notifier
-	client   *plugin.Client
-	queue    *eventsQueue
-}
-
func newNotifierPlugin(config Config) (*notifierPlugin, error) {
	p := &notifierPlugin{
		config: config,
-		queue:  &eventsQueue{},
	}
	if err := p.initialize(); err != nil {
		logger.Warn(logSender, "", "unable to create notifier plugin: %v, config %+v", err, config)
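The hunk above removes the separate eventsQueue type and keeps the pending events directly on the plugin struct, guarded by its own RWMutex. A stripped-down sketch of that pattern with a hypothetical event type (not the plugin's real API):

package main

import (
	"fmt"
	"sync"
)

type event struct{ name string }

type notifier struct {
	mu      sync.RWMutex
	pending []*event
}

// queueEvent appends an event under the write lock.
func (n *notifier) queueEvent(ev *event) {
	n.mu.Lock()
	defer n.mu.Unlock()
	n.pending = append(n.pending, ev)
}

// queueSize reads the length under the read lock.
func (n *notifier) queueSize() int {
	n.mu.RLock()
	defer n.mu.RUnlock()
	return len(n.pending)
}

// drain resends everything and clears the slice in one critical section.
func (n *notifier) drain(send func(*event)) {
	n.mu.Lock()
	defer n.mu.Unlock()
	for _, ev := range n.pending {
		send(ev)
	}
	n.pending = nil
}

func main() {
	n := &notifier{}
	n.queueEvent(&event{name: "upload"})
	fmt.Println(n.queueSize()) // 1
	n.drain(func(ev *event) { fmt.Println("resend", ev.name) })
	fmt.Println(n.queueSize()) // 0
}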
@@ -180,7 +102,7 @@ func (p *notifierPlugin) initialize() error {
		Managed: false,
		Logger: &logger.HCLogAdapter{
			Logger: hclog.New(&hclog.LoggerOptions{
-				Name: fmt.Sprintf("%v.%v", logSender, notifier.PluginName),
+				Name: fmt.Sprintf("%s.%s", logSender, notifier.PluginName),
				Level: pluginsLogLevel,
				DisableTime: true,
			}),
@@ -204,6 +126,34 @@ func (p *notifierPlugin) initialize() error {
	return nil
}

+func (p *notifierPlugin) queueSize() int {
+	p.mu.RLock()
+	defer p.mu.RUnlock()
+
+	return len(p.providerEvents) + len(p.fsEvents) + len(p.logEvents)
+}
+
+func (p *notifierPlugin) queueFsEvent(ev *notifier.FsEvent) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	p.fsEvents = append(p.fsEvents, ev)
+}
+
+func (p *notifierPlugin) queueProviderEvent(ev *notifier.ProviderEvent) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	p.providerEvents = append(p.providerEvents, ev)
+}
+
+func (p *notifierPlugin) queueLogEvent(ev *notifier.LogEvent) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	p.logEvents = append(p.logEvents, ev)
+}
+
func (p *notifierPlugin) canQueueEvent(timestamp int64) bool {
	if p.config.NotifierOptions.RetryMaxTime == 0 {
		return false
@ -214,107 +164,105 @@ func (p *notifierPlugin) canQueueEvent(timestamp int64) bool {
|
|||
return false
|
||||
}
|
||||
if p.config.NotifierOptions.RetryQueueMaxSize > 0 {
|
||||
return p.queue.getSize() < p.config.NotifierOptions.RetryQueueMaxSize
|
||||
return p.queueSize() < p.config.NotifierOptions.RetryQueueMaxSize
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *notifierPlugin) notifyFsAction(event *notifier.FsEvent) {
|
||||
if !util.Contains(p.config.NotifierOptions.FsEvents, event.Action) {
|
||||
if !slices.Contains(p.config.NotifierOptions.FsEvents, event.Action) {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
Handler.addTask()
|
||||
defer Handler.removeTask()
|
||||
|
||||
p.sendFsEvent(event)
|
||||
}()
|
||||
p.sendFsEvent(event)
|
||||
}
|
||||
|
||||
func (p *notifierPlugin) notifyProviderAction(event *notifier.ProviderEvent, object Renderer) {
|
||||
if !util.Contains(p.config.NotifierOptions.ProviderEvents, event.Action) ||
|
||||
!util.Contains(p.config.NotifierOptions.ProviderObjects, event.ObjectType) {
|
||||
if !slices.Contains(p.config.NotifierOptions.ProviderEvents, event.Action) ||
|
||||
!slices.Contains(p.config.NotifierOptions.ProviderObjects, event.ObjectType) {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
Handler.addTask()
|
||||
defer Handler.removeTask()
|
||||
|
||||
objectAsJSON, err := object.RenderAsJSON(event.Action != "delete")
|
||||
if err != nil {
|
||||
logger.Warn(logSender, "", "unable to render user as json for action %v: %v", event.Action, err)
|
||||
return
|
||||
}
|
||||
event.ObjectData = objectAsJSON
|
||||
p.sendProviderEvent(event)
|
||||
}()
|
||||
p.sendProviderEvent(event, object)
|
||||
}
|
||||
|
||||
func (p *notifierPlugin) notifyLogEvent(event *notifier.LogEvent) {
|
||||
go func() {
|
||||
p.sendLogEvent(event)
|
||||
}
|
||||
|
||||
func (p *notifierPlugin) sendFsEvent(ev *notifier.FsEvent) {
|
||||
go func(event *notifier.FsEvent) {
|
||||
Handler.addTask()
|
||||
defer Handler.removeTask()
|
||||
|
||||
p.sendLogEvent(event)
|
||||
}()
|
||||
if err := p.notifier.NotifyFsEvent(event); err != nil {
|
||||
logger.Warn(logSender, "", "unable to send fs action notification to plugin %v: %v", p.config.Cmd, err)
|
||||
if p.canQueueEvent(event.Timestamp) {
|
||||
p.queueFsEvent(event)
|
||||
}
|
||||
}
|
||||
}(ev)
|
||||
}
|
||||
|
||||
func (p *notifierPlugin) sendFsEvent(event *notifier.FsEvent) {
|
||||
if err := p.notifier.NotifyFsEvent(event); err != nil {
|
||||
logger.Warn(logSender, "", "unable to send fs action notification to plugin %v: %v", p.config.Cmd, err)
|
||||
if p.canQueueEvent(event.Timestamp) {
|
||||
p.queue.addFsEvent(event)
|
||||
func (p *notifierPlugin) sendProviderEvent(ev *notifier.ProviderEvent, object Renderer) {
|
||||
go func(event *notifier.ProviderEvent) {
|
||||
Handler.addTask()
|
||||
defer Handler.removeTask()
|
||||
|
||||
if object != nil {
|
||||
objectAsJSON, err := object.RenderAsJSON(event.Action != "delete")
|
||||
if err != nil {
|
||||
logger.Error(logSender, "", "unable to render user as json for action %q: %v", event.Action, err)
|
||||
} else {
|
||||
event.ObjectData = objectAsJSON
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := p.notifier.NotifyProviderEvent(event); err != nil {
|
||||
logger.Warn(logSender, "", "unable to send user action notification to plugin %v: %v", p.config.Cmd, err)
|
||||
if p.canQueueEvent(event.Timestamp) {
|
||||
p.queueProviderEvent(event)
|
||||
}
|
||||
}
|
||||
}(ev)
|
||||
}
|
||||
|
||||
func (p *notifierPlugin) sendProviderEvent(event *notifier.ProviderEvent) {
|
||||
if err := p.notifier.NotifyProviderEvent(event); err != nil {
|
||||
logger.Warn(logSender, "", "unable to send user action notification to plugin %v: %v", p.config.Cmd, err)
|
||||
if p.canQueueEvent(event.Timestamp) {
|
||||
p.queue.addProviderEvent(event)
|
||||
}
|
||||
}
|
||||
}
|
||||
func (p *notifierPlugin) sendLogEvent(ev *notifier.LogEvent) {
|
||||
go func(event *notifier.LogEvent) {
|
||||
Handler.addTask()
|
||||
defer Handler.removeTask()
|
||||
|
||||
func (p *notifierPlugin) sendLogEvent(event *notifier.LogEvent) {
|
||||
if err := p.notifier.NotifyLogEvent(event); err != nil {
|
||||
logger.Warn(logSender, "", "unable to send log event to plugin %v: %v", p.config.Cmd, err)
|
||||
if p.canQueueEvent(event.Timestamp) {
|
||||
p.queue.addLogEvent(event)
|
||||
if err := p.notifier.NotifyLogEvent(event); err != nil {
|
||||
logger.Warn(logSender, "", "unable to send log event to plugin %v: %v", p.config.Cmd, err)
|
||||
if p.canQueueEvent(event.Timestamp) {
|
||||
p.queueLogEvent(event)
|
||||
}
|
||||
}
|
||||
}
|
||||
}(ev)
|
||||
}
|
||||
|
||||
func (p *notifierPlugin) sendQueuedEvents() {
|
||||
queueSize := p.queue.getSize()
|
||||
queueSize := p.queueSize()
|
||||
if queueSize == 0 {
|
||||
return
|
||||
}
|
||||
logger.Debug(logSender, "", "check queued events for notifier %q, events size: %v", p.config.Cmd, queueSize)
|
||||
fsEv := p.queue.popFsEvent()
|
||||
for fsEv != nil {
|
||||
go func(ev *notifier.FsEvent) {
|
||||
p.sendFsEvent(ev)
|
||||
}(fsEv)
|
||||
fsEv = p.queue.popFsEvent()
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
providerEv := p.queue.popProviderEvent()
|
||||
for providerEv != nil {
|
||||
go func(ev *notifier.ProviderEvent) {
|
||||
p.sendProviderEvent(ev)
|
||||
}(providerEv)
|
||||
providerEv = p.queue.popProviderEvent()
|
||||
logger.Debug(logSender, "", "send queued events for notifier %q, events size: %v", p.config.Cmd, queueSize)
|
||||
|
||||
for _, ev := range p.fsEvents {
|
||||
p.sendFsEvent(ev)
|
||||
}
|
||||
logEv := p.queue.popLogEvent()
|
||||
for logEv != nil {
|
||||
go func(ev *notifier.LogEvent) {
|
||||
p.sendLogEvent(ev)
|
||||
}(logEv)
|
||||
logEv = p.queue.popLogEvent()
|
||||
p.fsEvents = nil
|
||||
|
||||
for _, ev := range p.providerEvents {
|
||||
p.sendProviderEvent(ev, nil)
|
||||
}
|
||||
logger.Debug(logSender, "", "queued events sent for notifier %q, new events size: %v", p.config.Cmd, p.queue.getSize())
|
||||
p.providerEvents = nil
|
||||
|
||||
for _, ev := range p.logEvents {
|
||||
p.sendLogEvent(ev)
|
||||
}
|
||||
p.logEvents = nil
|
||||
|
||||
logger.Debug(logSender, "", "%d queued events sent for notifier %q,", queueSize, p.config.Cmd)
|
||||
}
|
||||
|
|
|
@@ -24,6 +24,7 @@ import (
	"os"
	"os/exec"
	"path/filepath"
+	"slices"
	"strings"
	"sync"
	"sync/atomic"

@@ -336,7 +337,7 @@ func (m *Manager) NotifyLogEvent(event notifier.LogEventType, protocol, username
	var e *notifier.LogEvent

	for _, n := range m.notifiers {
-		if util.Contains(n.config.NotifierOptions.LogEvents, int(event)) {
+		if slices.Contains(n.config.NotifierOptions.LogEvents, int(event)) {
			if e == nil {
				message := ""
				if err != nil {

@@ -641,7 +642,9 @@ func (m *Manager) restartNotifierPlugin(config Config, idx int) {
	}

	m.notifLock.Lock()
-	plugin.queue = m.notifiers[idx].queue
+	plugin.fsEvents = m.notifiers[idx].fsEvents
+	plugin.providerEvents = m.notifiers[idx].providerEvents
+	plugin.logEvents = m.notifiers[idx].logEvents
	m.notifiers[idx] = plugin
	m.notifLock.Unlock()
	plugin.sendQueuedEvents()
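restartNotifierPlugin copies the queued events from the old plugin instance onto the restarted one before swapping it in under the manager's lock. A simplified sketch of that handover with hypothetical types (not the plugin manager's real API):

package main

import (
	"fmt"
	"sync"
)

type pluginInstance struct {
	fsEvents []string // stand-in for queued notifier events
}

type manager struct {
	mu        sync.Mutex
	notifiers []*pluginInstance
}

// restart replaces the notifier at idx with a fresh instance while
// preserving whatever the old instance still had queued.
func (m *manager) restart(idx int, fresh *pluginInstance) {
	m.mu.Lock()
	fresh.fsEvents = m.notifiers[idx].fsEvents // carry over pending events
	m.notifiers[idx] = fresh
	m.mu.Unlock()
	// The fresh instance can now retry the carried-over events.
}

func main() {
	m := &manager{notifiers: []*pluginInstance{{fsEvents: []string{"upload", "delete"}}}}
	m.restart(0, &pluginInstance{})
	fmt.Println(m.notifiers[0].fsEvents) // [upload delete]
}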
@@ -20,6 +20,7 @@ package service
import (
	"fmt"
	"math/rand"
+	"slices"
	"strings"

	"github.com/sftpgo/sdk"

@@ -211,7 +212,7 @@ func configurePortableSFTPService(port int, enabledSSHCommands []string) {
	} else {
		sftpdConf.Bindings[0].Port = 0
	}
-	if util.Contains(enabledSSHCommands, "*") {
+	if slices.Contains(enabledSSHCommands, "*") {
		sftpdConf.EnabledSSHCommands = sftpd.GetSupportedSSHCommands()
	} else {
		sftpdConf.EnabledSSHCommands = enabledSSHCommands
@@ -76,6 +76,10 @@ func (c *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
	if !c.User.HasPerm(dataprovider.PermDownload, path.Dir(request.Filepath)) {
		return nil, sftp.ErrSSHFxPermissionDenied
	}
+	if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil {
+		c.Log(logger.LevelInfo, "denying file read due to transfer count limits")
+		return nil, c.GetPermissionDeniedError()
+	}
	transferQuota := c.GetTransferQuota()
	if !transferQuota.HasDownloadSpace() {
		c.Log(logger.LevelInfo, "denying file read due to quota limits")

@@ -120,9 +124,14 @@ func (c *Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) {
	return c.handleFilewrite(request)
}

-func (c *Connection) handleFilewrite(request *sftp.Request) (sftp.WriterAtReaderAt, error) {
+func (c *Connection) handleFilewrite(request *sftp.Request) (sftp.WriterAtReaderAt, error) { //nolint:gocyclo
	c.UpdateLastActivity()

+	if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil {
+		c.Log(logger.LevelInfo, "denying file write due to transfer count limits")
+		return nil, c.GetPermissionDeniedError()
+	}
+
	if ok, _ := c.User.IsFileAllowed(request.Filepath); !ok {
		c.Log(logger.LevelWarn, "writing file %q is not allowed", request.Filepath)
		return nil, c.GetPermissionDeniedError()

@@ -221,9 +230,9 @@ func (c *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
		}
		modTime := time.Unix(0, 0)
		if request.Filepath != "/" {
-			lister.Add(vfs.NewFileInfo("..", true, 0, modTime, false))
+			lister.Prepend(vfs.NewFileInfo("..", true, 0, modTime, false))
		}
-		lister.Add(vfs.NewFileInfo(".", true, 0, modTime, false))
+		lister.Prepend(vfs.NewFileInfo(".", true, 0, modTime, false))
		return lister, nil
	case "Stat":
		if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(request.Filepath)) {

@@ -457,7 +466,7 @@ func (c *Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileO
	}

	if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
-		_, _, err = fs.Rename(resolvedPath, filePath)
+		_, _, err = fs.Rename(resolvedPath, filePath, 0)
		if err != nil {
			c.Log(logger.LevelError, "error renaming existing file for atomic upload, source: %q, dest: %q, err: %+v",
				resolvedPath, filePath, err)

@@ -559,13 +568,10 @@ func (c *Connection) getStatVFSFromQuotaResult(fs vfs.Fs, name string, quotaResu
func (c *Connection) updateQuotaAfterTruncate(requestPath string, fileSize int64) {
	vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
	if err == nil {
-		dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
-		if vfolder.IsIncludedInUserQuota() {
-			dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
-		}
-	} else {
-		dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
+		dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -fileSize, false)
+		return
	}
+	dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
}

func getOSOpenFlags(requestFlags sftp.FileOpenFlags) (flags int) {
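updateQuotaAfterTruncate now takes an early return on the virtual-folder branch and falls back to the user quota otherwise. The control-flow shape in isolation, with placeholder quota functions (not the real dataprovider API):

package main

import (
	"errors"
	"fmt"
)

var errNoFolder = errors.New("no virtual folder for path")

// lookupFolder and the update functions below are placeholders standing in
// for the provider calls; only the early-return shape matters here.
func lookupFolder(path string) (string, error) {
	if path == "/vdir/file.txt" {
		return "vdir", nil
	}
	return "", errNoFolder
}

func updateFolderQuota(folder string, size int64) { fmt.Println("folder", folder, size) }
func updateUserQuota(size int64)                  { fmt.Println("user", size) }

func updateQuotaAfterTruncate(path string, size int64) {
	if folder, err := lookupFolder(path); err == nil {
		// The folder-aware helper also covers the "included in user quota" case.
		updateFolderQuota(folder, -size)
		return
	}
	updateUserQuota(-size)
}

func main() {
	updateQuotaAfterTruncate("/vdir/file.txt", 100)
	updateQuotaAfterTruncate("/home/file.txt", 100)
}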
@@ -24,6 +24,7 @@ import (
	"os"
	"path/filepath"
	"runtime"
+	"slices"
	"testing"
	"time"


@@ -145,7 +146,7 @@ func (fs MockOsFs) Remove(name string, _ bool) error {
}

// Rename renames (moves) source to target
-func (fs MockOsFs) Rename(source, target string) (int, int64, error) {
+func (fs MockOsFs) Rename(source, target string, _ int) (int, int64, error) {
	if fs.err != nil {
		return -1, -1, fs.err
	}
@ -269,6 +270,7 @@ func TestReadWriteErrors(t *testing.T) {
|
|||
err = os.Remove(testfile)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, conn.GetTransfers(), 0)
|
||||
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
|
||||
}
|
||||
|
||||
func TestUnsupportedListOP(t *testing.T) {
|
||||
|
@ -418,7 +420,7 @@ func TestSupportedSSHCommands(t *testing.T) {
|
|||
assert.Equal(t, len(supportedSSHCommands), len(cmds))
|
||||
|
||||
for _, c := range cmds {
|
||||
assert.True(t, util.Contains(supportedSSHCommands, c))
|
||||
assert.True(t, slices.Contains(supportedSSHCommands, c))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -842,7 +844,7 @@ func TestRsyncOptions(t *testing.T) {
|
|||
}
|
||||
cmd, err := sshCmd.getSystemCommand()
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, util.Contains(cmd.cmd.Args, "--safe-links"),
|
||||
assert.True(t, slices.Contains(cmd.cmd.Args, "--safe-links"),
|
||||
"--safe-links must be added if the user has the create symlinks permission")
|
||||
|
||||
permissions["/"] = []string{dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermCreateDirs,
|
||||
|
@ -859,7 +861,7 @@ func TestRsyncOptions(t *testing.T) {
|
|||
}
|
||||
cmd, err = sshCmd.getSystemCommand()
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, util.Contains(cmd.cmd.Args, "--munge-links"),
|
||||
assert.True(t, slices.Contains(cmd.cmd.Args, "--munge-links"),
|
||||
"--munge-links must be added if the user has the create symlinks permission")
|
||||
|
||||
sshCmd.connection.User.VirtualFolders = append(sshCmd.connection.User.VirtualFolders, vfs.VirtualFolder{
|
||||
|
@ -1013,6 +1015,8 @@ func TestSystemCommandErrors(t *testing.T) {
|
|||
transfer.MaxWriteSize = -1
|
||||
_, err = transfer.copyFromReaderToWriter(sshCmd.connection.channel, dst)
|
||||
assert.True(t, transfer.Connection.IsQuotaExceededError(err))
|
||||
err = transfer.Close()
|
||||
assert.Error(t, err)
|
||||
|
||||
baseTransfer = common.NewBaseTransfer(nil, sshCmd.connection.BaseConnection, nil, "", "", "",
|
||||
common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{
|
||||
|
@ -1030,9 +1034,13 @@ func TestSystemCommandErrors(t *testing.T) {
|
|||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), common.ErrReadQuotaExceeded.Error())
|
||||
}
|
||||
err = transfer.Close()
|
||||
assert.Error(t, err)
|
||||
|
||||
err = os.RemoveAll(homeDir)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
|
||||
}
|
||||
|
||||
func TestCommandGetFsError(t *testing.T) {
|
||||
|
@ -1716,6 +1724,7 @@ func TestSCPUploadFiledata(t *testing.T) {
|
|||
if assert.Error(t, err) {
|
||||
assert.EqualError(t, err, common.ErrTransferClosed.Error())
|
||||
}
|
||||
transfer.Connection.RemoveTransfer(transfer)
|
||||
|
||||
mockSSHChannel = MockChannel{
|
||||
Buffer: bytes.NewBuffer(buf),
|
||||
|
@ -1727,9 +1736,12 @@ func TestSCPUploadFiledata(t *testing.T) {
|
|||
transfer.Connection.AddTransfer(transfer)
|
||||
err = scpCommand.getUploadFileData(2, transfer)
|
||||
assert.ErrorContains(t, err, os.ErrClosed.Error())
|
||||
transfer.Connection.RemoveTransfer(transfer)
|
||||
|
||||
err = os.Remove(testfile)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
|
||||
}
|
||||
|
||||
func TestUploadError(t *testing.T) {
|
||||
|
@ -2039,6 +2051,7 @@ func TestRecoverer(t *testing.T) {
|
|||
err = scpCmd.handle()
|
||||
assert.EqualError(t, err, common.ErrGenericFailure.Error())
|
||||
assert.Len(t, common.Connections.GetStats(""), 0)
|
||||
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
|
||||
}
|
||||
|
||||
func TestListernerAcceptErrors(t *testing.T) {
|
||||
|
@ -2169,6 +2182,7 @@ func TestMaxUserSessions(t *testing.T) {
|
|||
}
|
||||
common.Connections.Remove(connection.GetID())
|
||||
assert.Len(t, common.Connections.GetStats(""), 0)
|
||||
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
|
||||
}
|
||||
|
||||
func TestCanReadSymlink(t *testing.T) {
|
||||
|
|
|
@@ -227,6 +227,12 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *transfer) err
}

func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64, requestPath string) error {
+	if err := common.Connections.IsNewTransferAllowed(c.connection.User.Username); err != nil {
+		err := fmt.Errorf("denying file write due to transfer count limits")
+		c.connection.Log(logger.LevelInfo, "denying file write due to transfer count limits")
+		c.sendErrorMessage(nil, err)
+		return err
+	}
	diskQuota, transferQuota := c.connection.HasSpace(isNewFile, false, requestPath)
	if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() {
		err := fmt.Errorf("denying file write due to quota limits")

@@ -258,10 +264,7 @@ func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string,
	if vfs.HasTruncateSupport(fs) {
		vfolder, err := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath))
		if err == nil {
-			dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
-			if vfolder.IsIncludedInUserQuota() {
-				dataprovider.UpdateUserQuota(&c.connection.User, 0, -fileSize, false) //nolint:errcheck
-			}
+			dataprovider.UpdateUserFolderQuota(&vfolder, &c.connection.User, 0, -fileSize, false)
		} else {
			dataprovider.UpdateUserQuota(&c.connection.User, 0, -fileSize, false) //nolint:errcheck
		}

@@ -333,7 +336,7 @@ func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error
	}

	if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
-		_, _, err = fs.Rename(p, filePath)
+		_, _, err = fs.Rename(p, filePath, 0)
		if err != nil {
			c.connection.Log(logger.LevelError, "error renaming existing file for atomic upload, source: %q, dest: %q, err: %v",
				p, filePath, err)

@@ -504,6 +507,13 @@ func (c *scpCommand) sendDownloadFileData(fs vfs.Fs, filePath string, stat os.Fi

func (c *scpCommand) handleDownload(filePath string) error {
	c.connection.UpdateLastActivity()
+
+	if err := common.Connections.IsNewTransferAllowed(c.connection.User.Username); err != nil {
+		err := fmt.Errorf("denying file read due to transfer count limits")
+		c.connection.Log(logger.LevelInfo, "denying file read due to transfer count limits")
+		c.sendErrorMessage(nil, err)
+		return err
+	}
	transferQuota := c.connection.GetTransferQuota()
	if !transferQuota.HasDownloadSpace() {
		c.connection.Log(logger.LevelInfo, "denying file read due to quota limits")
@@ -26,6 +26,7 @@ import (
	"os"
	"path/filepath"
	"runtime/debug"
+	"slices"
	"strings"
	"sync"
	"time"

@@ -263,13 +264,13 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
func (c *Configuration) updateSupportedAuthentications() {
	serviceStatus.Authentications = util.RemoveDuplicates(serviceStatus.Authentications, false)

-	if util.Contains(serviceStatus.Authentications, dataprovider.LoginMethodPassword) &&
-		util.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodPublicKey) {
+	if slices.Contains(serviceStatus.Authentications, dataprovider.LoginMethodPassword) &&
+		slices.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodPublicKey) {
		serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyAndPassword)
	}

-	if util.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyboardInteractive) &&
-		util.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodPublicKey) {
+	if slices.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyboardInteractive) &&
+		slices.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodPublicKey) {
		serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyAndKeyboardInt)
	}
}
@ -422,7 +423,7 @@ func (c *Configuration) configureKeyAlgos(serverConfig *ssh.ServerConfig) error
|
|||
c.HostKeyAlgorithms = util.RemoveDuplicates(c.HostKeyAlgorithms, true)
|
||||
}
|
||||
for _, hostKeyAlgo := range c.HostKeyAlgorithms {
|
||||
if !util.Contains(supportedHostKeyAlgos, hostKeyAlgo) {
|
||||
if !slices.Contains(supportedHostKeyAlgos, hostKeyAlgo) {
|
||||
return fmt.Errorf("unsupported host key algorithm %q", hostKeyAlgo)
|
||||
}
|
||||
}
|
||||
|
@ -430,7 +431,7 @@ func (c *Configuration) configureKeyAlgos(serverConfig *ssh.ServerConfig) error
|
|||
if len(c.PublicKeyAlgorithms) > 0 {
|
||||
c.PublicKeyAlgorithms = util.RemoveDuplicates(c.PublicKeyAlgorithms, true)
|
||||
for _, algo := range c.PublicKeyAlgorithms {
|
||||
if !util.Contains(supportedPublicKeyAlgos, algo) {
|
||||
if !slices.Contains(supportedPublicKeyAlgos, algo) {
|
||||
return fmt.Errorf("unsupported public key authentication algorithm %q", algo)
|
||||
}
|
||||
}
|
||||
|
@ -472,7 +473,7 @@ func (c *Configuration) configureSecurityOptions(serverConfig *ssh.ServerConfig)
|
|||
if kex == keyExchangeCurve25519SHA256LibSSH {
|
||||
continue
|
||||
}
|
||||
if !util.Contains(supportedKexAlgos, kex) {
|
||||
if !slices.Contains(supportedKexAlgos, kex) {
|
||||
return fmt.Errorf("unsupported key-exchange algorithm %q", kex)
|
||||
}
|
||||
}
|
||||
|
@ -486,7 +487,7 @@ func (c *Configuration) configureSecurityOptions(serverConfig *ssh.ServerConfig)
|
|||
if len(c.Ciphers) > 0 {
|
||||
c.Ciphers = util.RemoveDuplicates(c.Ciphers, true)
|
||||
for _, cipher := range c.Ciphers {
|
||||
if !util.Contains(supportedCiphers, cipher) {
|
||||
if !slices.Contains(supportedCiphers, cipher) {
|
||||
return fmt.Errorf("unsupported cipher %q", cipher)
|
||||
}
|
||||
}
|
||||
|
@ -499,7 +500,7 @@ func (c *Configuration) configureSecurityOptions(serverConfig *ssh.ServerConfig)
|
|||
if len(c.MACs) > 0 {
|
||||
c.MACs = util.RemoveDuplicates(c.MACs, true)
|
||||
for _, mac := range c.MACs {
|
||||
if !util.Contains(supportedMACs, mac) {
|
||||
if !slices.Contains(supportedMACs, mac) {
|
||||
return fmt.Errorf("unsupported MAC algorithm %q", mac)
|
||||
}
|
||||
}
|
||||
|
@@ -608,10 +609,10 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
		return
	}

-	logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID,
-		"User %q logged in with %q, from ip %q, client version %q, negotiated algorithms: %+v",
-		user.Username, loginType, ipAddr, util.BytesToString(sconn.ClientVersion()),
-		sconn.Conn.(ssh.AlgorithmsConnMetadata).Algorithms())
+	logger.LoginLog(user.Username, ipAddr, loginType, common.ProtocolSSH, connectionID,
+		util.BytesToString(sconn.ClientVersion()), true,
+		fmt.Sprintf("negotiated algorithms: %+v", sconn.Conn.(ssh.AlgorithmsConnMetadata).Algorithms()))
+
	dataprovider.UpdateLastLogin(&user)

	sshConnection := common.NewSSHConnection(connectionID, conn)
@ -785,7 +786,7 @@ func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh.
|
|||
user.Username, user.HomeDir)
|
||||
return nil, fmt.Errorf("cannot login user with invalid home dir: %q", user.HomeDir)
|
||||
}
|
||||
if util.Contains(user.Filters.DeniedProtocols, common.ProtocolSSH) {
|
||||
if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolSSH) {
|
||||
logger.Info(logSender, connectionID, "cannot login user %q, protocol SSH is not allowed", user.Username)
|
||||
return nil, fmt.Errorf("protocol SSH is not allowed for user %q", user.Username)
|
||||
}
|
||||
|
@@ -830,14 +831,14 @@ func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh.
}

func (c *Configuration) checkSSHCommands() {
-	if util.Contains(c.EnabledSSHCommands, "*") {
+	if slices.Contains(c.EnabledSSHCommands, "*") {
		c.EnabledSSHCommands = GetSupportedSSHCommands()
		return
	}
	sshCommands := []string{}
	for _, command := range c.EnabledSSHCommands {
		command = strings.TrimSpace(command)
-		if util.Contains(supportedSSHCommands, command) {
+		if slices.Contains(supportedSSHCommands, command) {
			sshCommands = append(sshCommands, command)
		} else {
			logger.Warn(logSender, "", "unsupported ssh command: %q ignored", command)
@ -927,7 +928,7 @@ func (c *Configuration) checkHostKeyAutoGeneration(configDir string) error {
|
|||
func (c *Configuration) getHostKeyAlgorithms(keyFormat string) []string {
|
||||
var algos []string
|
||||
for _, algo := range algorithmsForKeyFormat(keyFormat) {
|
||||
if util.Contains(c.HostKeyAlgorithms, algo) {
|
||||
if slices.Contains(c.HostKeyAlgorithms, algo) {
|
||||
algos = append(algos, algo)
|
||||
}
|
||||
}
|
||||
|
@ -986,7 +987,7 @@ func (c *Configuration) checkAndLoadHostKeys(configDir string, serverConfig *ssh
|
|||
var algos []string
|
||||
for _, algo := range algorithmsForKeyFormat(signer.PublicKey().Type()) {
|
||||
if underlyingAlgo, ok := certKeyAlgoNames[algo]; ok {
|
||||
if util.Contains(mas.Algorithms(), underlyingAlgo) {
|
||||
if slices.Contains(mas.Algorithms(), underlyingAlgo) {
|
||||
algos = append(algos, algo)
|
||||
}
|
||||
}
|
||||
|
@ -1098,12 +1099,12 @@ func (c *Configuration) initializeCertChecker(configDir string) error {
|
|||
|
||||
func (c *Configuration) getPartialSuccessError(nextAuthMethods []string) error {
|
||||
err := &ssh.PartialSuccessError{}
|
||||
if c.PasswordAuthentication && util.Contains(nextAuthMethods, dataprovider.LoginMethodPassword) {
|
||||
if c.PasswordAuthentication && slices.Contains(nextAuthMethods, dataprovider.LoginMethodPassword) {
|
||||
err.Next.PasswordCallback = func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||
return c.validatePasswordCredentials(conn, password, dataprovider.SSHLoginMethodKeyAndPassword)
|
||||
}
|
||||
}
|
||||
if c.KeyboardInteractiveAuthentication && util.Contains(nextAuthMethods, dataprovider.SSHLoginMethodKeyboardInteractive) {
|
||||
if c.KeyboardInteractiveAuthentication && slices.Contains(nextAuthMethods, dataprovider.SSHLoginMethodKeyboardInteractive) {
|
||||
err.Next.KeyboardInteractiveCallback = func(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
|
||||
return c.validateKeyboardInteractiveCredentials(conn, client, dataprovider.SSHLoginMethodKeyAndKeyboardInt, true)
|
||||
}
|
||||
|
|
|
@@ -23,6 +23,7 @@ import (
	"crypto/sha512"
	"encoding/base64"
	"encoding/binary"
+	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"

@@ -37,6 +38,7 @@ import (
	"path"
	"path/filepath"
	"runtime"
+	"slices"
	"strconv"
	"strings"
	"sync"
@@ -782,6 +784,34 @@ func TestSFTPFsEscapeHomeDir(t *testing.T) {
	assert.NoError(t, err)
}

+func TestReadDirLongNames(t *testing.T) {
+	usePubKey := true
+	user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
+	assert.NoError(t, err)
+
+	conn, client, err := getSftpClient(user, usePubKey)
+	if assert.NoError(t, err) {
+		defer conn.Close()
+		defer client.Close()
+
+		numFiles := 1000
+		for i := 0; i < 1000; i++ {
+			fPath := filepath.Join(user.GetHomeDir(), hex.EncodeToString(util.GenerateRandomBytes(127)))
+			err = os.WriteFile(fPath, util.GenerateRandomBytes(30), 0666)
+			assert.NoError(t, err)
+		}
+
+		entries, err := client.ReadDir("/")
+		assert.NoError(t, err)
+		assert.Len(t, entries, numFiles)
+	}
+
+	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	assert.NoError(t, err)
+	err = os.RemoveAll(user.GetHomeDir())
+	assert.NoError(t, err)
+}
+
func TestGroupSettingsOverride(t *testing.T) {
	usePubKey := true
	g := getTestGroup()
@ -1172,6 +1202,7 @@ func TestConcurrency(t *testing.T) {
|
|||
assert.Eventually(t, func() bool {
|
||||
return len(common.Connections.GetStats("")) == 0
|
||||
}, 1*time.Second, 50*time.Millisecond)
|
||||
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
|
||||
|
||||
err = os.Remove(testFilePath)
|
||||
assert.NoError(t, err)
|
||||
|
@ -4361,6 +4392,78 @@ func TestMaxPerHostConnections(t *testing.T) {
|
|||
common.Config.MaxPerHostConnections = oldValue
|
||||
}
|
||||
|
||||
func TestMaxTransfers(t *testing.T) {
|
||||
oldValue := common.Config.MaxPerHostConnections
|
||||
common.Config.MaxPerHostConnections = 2
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return common.Connections.GetClientConnections() == 0
|
||||
}, 1000*time.Millisecond, 50*time.Millisecond)
|
||||
|
||||
usePubKey := true
|
||||
user := getTestUser(usePubKey)
|
||||
err := dataprovider.AddUser(&user, "", "", "")
|
||||
assert.NoError(t, err)
|
||||
user.Password = ""
|
||||
conn, client, err := getSftpClient(user, usePubKey)
|
||||
if assert.NoError(t, err) {
|
||||
assert.NoError(t, checkBasicSFTP(client))
|
||||
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
testFileSize := int64(65535)
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
assert.NoError(t, err)
|
||||
err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
|
||||
assert.NoError(t, err)
|
||||
|
||||
f1, err := client.Create("file1")
|
||||
assert.NoError(t, err)
|
||||
f2, err := client.Create("file2")
|
||||
assert.NoError(t, err)
|
||||
_, err = f1.Write([]byte(" "))
|
||||
assert.NoError(t, err)
|
||||
_, err = f2.Write([]byte(" "))
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
|
||||
assert.ErrorContains(t, err, sftp.ErrSSHFxPermissionDenied.Error())
|
||||
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
|
||||
err = scpUpload(testFilePath, remoteUpPath, false, false)
|
||||
assert.Error(t, err)
|
||||
|
||||
localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
|
||||
err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client)
|
||||
assert.ErrorContains(t, err, sftp.ErrSSHFxPermissionDenied.Error())
|
||||
|
||||
remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
|
||||
err = scpDownload(localDownloadPath, remoteDownPath, false, false)
|
||||
assert.Error(t, err)
|
||||
|
||||
err = f1.Close()
|
||||
assert.NoError(t, err)
|
||||
err = f2.Close()
|
||||
assert.NoError(t, err)
|
||||
err = os.Remove(testFilePath)
|
||||
assert.NoError(t, err)
|
||||
err = os.Remove(localDownloadPath)
|
||||
assert.NoError(t, err)
|
||||
err = client.Close()
|
||||
assert.NoError(t, err)
|
||||
err = conn.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
err = dataprovider.DeleteUser(user.Username, "", "", "")
|
||||
assert.NoError(t, err)
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
assert.NoError(t, err)
|
||||
assert.Eventually(t, func() bool {
|
||||
return common.Connections.GetTotalTransfers() == 0
|
||||
}, 1000*time.Millisecond, 50*time.Millisecond)
|
||||
|
||||
common.Config.MaxPerHostConnections = oldValue
|
||||
}
|
||||
|
||||
func TestMaxSessions(t *testing.T) {
|
||||
usePubKey := false
|
||||
u := getTestUser(usePubKey)
|
||||
|
@ -4910,6 +5013,7 @@ func TestBandwidthAndConnections(t *testing.T) {
|
|||
assert.Eventually(t, func() bool {
|
||||
return len(common.Connections.GetStats("")) == 0
|
||||
}, 10*time.Second, 200*time.Millisecond)
|
||||
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
|
||||
err = os.Remove(testFilePath)
|
||||
assert.NoError(t, err)
|
||||
err = os.Remove(localDownloadPath)
|
||||
|
@ -5493,13 +5597,13 @@ func TestNestedVirtualFolders(t *testing.T) {
|
|||
|
||||
folderGet, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(18769), folderGet.UsedQuotaSize)
|
||||
assert.Equal(t, 1, folderGet.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), folderGet.UsedQuotaSize)
|
||||
assert.Equal(t, 0, folderGet.UsedQuotaFiles)
|
||||
|
||||
folderGet, _, err = httpdtest.GetFolderByName(folderNameNested, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(27658), folderGet.UsedQuotaSize)
|
||||
assert.Equal(t, 1, folderGet.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), folderGet.UsedQuotaSize)
|
||||
assert.Equal(t, 0, folderGet.UsedQuotaFiles)
|
||||
|
||||
files, err := client.ReadDir("/")
|
||||
if assert.NoError(t, err) {
|
||||
|
@ -6166,8 +6270,8 @@ func TestVirtualFoldersQuotaValues(t *testing.T) {
|
|||
assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
|
||||
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize, f.UsedQuotaSize)
|
||||
assert.Equal(t, 1, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize, f.UsedQuotaSize)
|
||||
|
@ -6286,8 +6390,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
|
|||
assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
|
||||
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
|
||||
assert.Equal(t, 2, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
|
||||
|
@ -6311,8 +6415,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
|
|||
assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
|
||||
assert.Equal(t, 2, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
// rename a file inside vdir2, it isn't included inside user quota, so we have:
|
||||
// - vdir1/dir1/testFileName.rename
|
||||
// - vdir1/dir2/testFileName1
|
||||
|
@ -6330,8 +6434,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
|
|||
assert.Equal(t, 2, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
|
||||
assert.Equal(t, 2, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
// rename a file inside vdir2 overwriting an existing, we now have:
|
||||
// - vdir1/dir1/testFileName.rename
|
||||
// - vdir1/dir2/testFileName1
|
||||
|
@ -6348,8 +6452,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
|
|||
assert.Equal(t, 1, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
|
||||
assert.Equal(t, 2, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
// rename a file inside vdir1 overwriting an existing, we now have:
|
||||
// - vdir1/dir1/testFileName.rename (initial testFileName1)
|
||||
// - vdir2/dir1/testFileName.rename (initial testFileName1)
|
||||
|
@ -6361,8 +6465,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
|
|||
assert.Equal(t, testFileSize1, user.UsedQuotaSize)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
|
||||
assert.Equal(t, 1, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
|
||||
|
@ -6382,8 +6486,8 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
|
|||
assert.Equal(t, testFileSize1, user.UsedQuotaSize)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
|
||||
assert.Equal(t, 1, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
|
||||
|
@ -6504,8 +6608,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
|
|||
assert.Equal(t, testFileSize, user.UsedQuotaSize)
|
||||
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize, f.UsedQuotaSize)
|
||||
assert.Equal(t, 1, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize+testFileSize1+testFileSize1, f.UsedQuotaSize)
|
||||
|
@ -6523,8 +6627,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
|
|||
assert.Equal(t, testFileSize*2, user.UsedQuotaSize)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize*2, f.UsedQuotaSize)
|
||||
assert.Equal(t, 2, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize1*2, f.UsedQuotaSize)
|
||||
|
@ -6541,8 +6645,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
|
|||
assert.Equal(t, testFileSize, user.UsedQuotaSize)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize, f.UsedQuotaSize)
|
||||
assert.Equal(t, 1, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize1+testFileSize, f.UsedQuotaSize)
|
||||
|
@ -6558,8 +6662,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
|
|||
assert.Equal(t, testFileSize1, user.UsedQuotaSize)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
|
||||
assert.Equal(t, 1, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize, f.UsedQuotaSize)
|
||||
|
@ -6589,8 +6693,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
|
|||
assert.Equal(t, testFileSize1*3+testFileSize*2, user.UsedQuotaSize)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize1*3+testFileSize*2, f.UsedQuotaSize)
|
||||
assert.Equal(t, 5, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
|
@ -6604,8 +6708,8 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
|
|||
assert.Equal(t, testFileSize1*2+testFileSize, user.UsedQuotaSize)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize1*2+testFileSize, f.UsedQuotaSize)
|
||||
assert.Equal(t, 3, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
|
||||
|
@ -6726,8 +6830,8 @@ func TestQuotaRenameFromVirtualFolder(t *testing.T) {
|
|||
assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
|
||||
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
|
||||
assert.Equal(t, 1, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
|
||||
|
@ -6745,8 +6849,8 @@ func TestQuotaRenameFromVirtualFolder(t *testing.T) {
|
|||
assert.Equal(t, testFileSize+testFileSize1+testFileSize1, user.UsedQuotaSize)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
|
||||
assert.Equal(t, 1, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize, f.UsedQuotaSize)
|
||||
|
@ -6809,8 +6913,8 @@ func TestQuotaRenameFromVirtualFolder(t *testing.T) {
|
|||
assert.Equal(t, testFileSize*3+testFileSize1*3, user.UsedQuotaSize)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
|
||||
assert.Equal(t, 2, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
|
@ -6943,8 +7047,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
|
|||
assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
|
||||
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
|
||||
assert.Equal(t, 1, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
// rename a file from user home dir to vdir2, vdir2 is not included in user quota so we have:
|
||||
// - /vdir2/dir1/testFileName
|
||||
// - /vdir1/dir1/testFileName1
|
||||
|
@ -6983,8 +7087,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
|
|||
assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize, f.UsedQuotaSize)
|
||||
assert.Equal(t, 1, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize, f.UsedQuotaSize)
|
||||
|
@ -7000,8 +7104,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
|
|||
assert.Equal(t, testFileSize, user.UsedQuotaSize)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize, f.UsedQuotaSize)
|
||||
assert.Equal(t, 1, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
|
||||
|
@ -7023,8 +7127,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
|
|||
assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize, f.UsedQuotaSize)
|
||||
assert.Equal(t, 1, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
|
||||
|
@ -7041,8 +7145,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
|
|||
assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize*2+testFileSize1, f.UsedQuotaSize)
|
||||
assert.Equal(t, 3, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize1, f.UsedQuotaSize)
|
||||
|
@ -7067,8 +7171,8 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
|
|||
assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize*2+testFileSize1, f.UsedQuotaSize)
|
||||
assert.Equal(t, 3, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize1*2+testFileSize, f.UsedQuotaSize)
|
||||
|
@ -7336,8 +7440,8 @@ func TestVFolderQuotaSize(t *testing.T) {
|
|||
|
||||
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize, f.UsedQuotaSize)
|
||||
assert.Equal(t, 1, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize, f.UsedQuotaSize)
|
||||
|
@ -8607,8 +8711,8 @@ func TestUserAllowedLoginMethods(t *testing.T) {
|
|||
allowedMethods = user.GetAllowedLoginMethods()
|
||||
assert.Equal(t, 4, len(allowedMethods))
|
||||
|
||||
assert.True(t, util.Contains(allowedMethods, dataprovider.SSHLoginMethodKeyAndKeyboardInt))
|
||||
assert.True(t, util.Contains(allowedMethods, dataprovider.SSHLoginMethodKeyAndPassword))
|
||||
assert.True(t, slices.Contains(allowedMethods, dataprovider.SSHLoginMethodKeyAndKeyboardInt))
|
||||
assert.True(t, slices.Contains(allowedMethods, dataprovider.SSHLoginMethodKeyAndPassword))
|
||||
}
|
||||
|
||||
func TestUserPartialAuth(t *testing.T) {
|
||||
|
@ -9115,8 +9219,8 @@ func TestSSHCopy(t *testing.T) {
|
|||
assert.Equal(t, 2*testFileSize+2*testFileSize1, user.UsedQuotaSize)
|
||||
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
|
||||
assert.Equal(t, 2, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
|
||||
|
@ -9194,8 +9298,8 @@ func TestSSHCopy(t *testing.T) {
|
|||
assert.Equal(t, 5*testFileSize+4*testFileSize1, user.UsedQuotaSize)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2*testFileSize+2*testFileSize1, f.UsedQuotaSize)
|
||||
assert.Equal(t, 4, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
}
|
||||
// cross folder copy
|
||||
newDir := "newdir"
|
||||
|
@ -9829,6 +9933,62 @@ func TestBasicGitCommands(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSSHCommandMaxTransfers(t *testing.T) {
|
||||
if len(gitPath) == 0 || len(sshPath) == 0 || runtime.GOOS == osWindows {
|
||||
t.Skip("git and/or ssh command not found or OS is windows, unable to execute this test")
|
||||
}
|
||||
oldValue := common.Config.MaxPerHostConnections
|
||||
common.Config.MaxPerHostConnections = 2
|
||||
|
||||
usePubKey := true
|
||||
u := getTestUser(usePubKey)
|
||||
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
|
||||
assert.NoError(t, err)
|
||||
|
||||
repoName := "testrepo" //nolint:goconst
|
||||
clonePath := filepath.Join(homeBasePath, repoName)
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
assert.NoError(t, err)
|
||||
err = os.RemoveAll(filepath.Join(homeBasePath, repoName))
|
||||
assert.NoError(t, err)
|
||||
out, err := initGitRepo(filepath.Join(user.HomeDir, repoName))
|
||||
assert.NoError(t, err, "unexpected error, out: %v", string(out))
|
||||
conn, client, err := getSftpClient(user, usePubKey)
|
||||
if assert.NoError(t, err) {
|
||||
f1, err := client.Create("file1")
|
||||
assert.NoError(t, err)
|
||||
f2, err := client.Create("file2")
|
||||
assert.NoError(t, err)
|
||||
_, err = f1.Write([]byte(" "))
|
||||
assert.NoError(t, err)
|
||||
_, err = f2.Write([]byte(" "))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = cloneGitRepo(homeBasePath, "/"+repoName, user.Username)
|
||||
assert.Error(t, err)
|
||||
|
||||
err = f1.Close()
|
||||
assert.NoError(t, err)
|
||||
err = f2.Close()
|
||||
assert.NoError(t, err)
|
||||
err = client.Close()
|
||||
assert.NoError(t, err)
|
||||
err = conn.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
_, err = httpdtest.RemoveUser(user, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
assert.NoError(t, err)
|
||||
err = os.RemoveAll(clonePath)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
|
||||
|
||||
common.Config.MaxPerHostConnections = oldValue
|
||||
}
|
||||
|
||||
func TestGitIncludedVirtualFolders(t *testing.T) {
|
||||
if len(gitPath) == 0 || len(sshPath) == 0 || runtime.GOOS == osWindows {
|
||||
t.Skip("git and/or ssh command not found or OS is windows, unable to execute this test")
|
||||
|
@ -9891,8 +10051,8 @@ func TestGitIncludedVirtualFolders(t *testing.T) {
|
|||
|
||||
folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, user.UsedQuotaFiles, folder.UsedQuotaFiles)
|
||||
assert.Equal(t, user.UsedQuotaSize, folder.UsedQuotaSize)
|
||||
assert.Equal(t, 0, folder.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), folder.UsedQuotaSize)
|
||||
|
||||
_, err = httpdtest.RemoveUser(user, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
|
@ -10677,8 +10837,8 @@ func TestSCPVirtualFoldersQuota(t *testing.T) {
|
|||
assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
|
||||
f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedQuotaSize, f.UsedQuotaSize)
|
||||
assert.Equal(t, expectedQuotaFiles, f.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(0), f.UsedQuotaSize)
|
||||
assert.Equal(t, 0, f.UsedQuotaFiles)
|
||||
f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedQuotaSize, f.UsedQuotaSize)
|
||||
|
@ -11074,6 +11234,7 @@ func TestSCPErrors(t *testing.T) {
|
|||
err = cmd.Process.Kill()
|
||||
assert.NoError(t, err)
|
||||
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 2*time.Second, 100*time.Millisecond)
|
||||
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
|
||||
cmd = getScpUploadCommand(testFilePath, remoteUpPath, false, false)
|
||||
go func() {
|
||||
err := cmd.Run()
|
||||
|
@ -11086,6 +11247,7 @@ func TestSCPErrors(t *testing.T) {
|
|||
err = cmd.Process.Kill()
|
||||
assert.NoError(t, err)
|
||||
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 2*time.Second, 100*time.Millisecond)
|
||||
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
|
||||
err = os.Remove(testFilePath)
|
||||
assert.NoError(t, err)
|
||||
os.Remove(localPath)
|
||||
|
|
|
@ -27,6 +27,7 @@ import (
|
|||
"os/exec"
|
||||
"path"
|
||||
"runtime/debug"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -91,7 +92,7 @@ func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommand
|
|||
name, args, err := parseCommandPayload(msg.Command)
|
||||
connection.Log(logger.LevelDebug, "new ssh command: %q args: %v num args: %d user: %s, error: %v",
|
||||
name, args, len(args), connection.User.Username, err)
|
||||
if err == nil && util.Contains(enabledSSHCommands, name) {
|
||||
if err == nil && slices.Contains(enabledSSHCommands, name) {
|
||||
connection.command = msg.Command
|
||||
if name == scpCmdName && len(args) >= 2 {
|
||||
connection.SetProtocol(common.ProtocolSCP)
|
||||
|
@ -139,9 +140,9 @@ func (c *sshCommand) handle() (err error) {
|
|||
defer common.Connections.Remove(c.connection.GetID())
|
||||
|
||||
c.connection.UpdateLastActivity()
|
||||
if util.Contains(sshHashCommands, c.command) {
|
||||
if slices.Contains(sshHashCommands, c.command) {
|
||||
return c.handleHashCommands()
|
||||
} else if util.Contains(systemCommands, c.command) {
|
||||
} else if slices.Contains(systemCommands, c.command) {
|
||||
command, err := c.getSystemCommand()
|
||||
if err != nil {
|
||||
return c.sendErrorResponse(err)
|
||||
|
@ -192,13 +193,10 @@ func (c *sshCommand) handleSFTPGoRemove() error {
|
|||
func (c *sshCommand) updateQuota(sshDestPath string, filesNum int, filesSize int64) {
|
||||
vfolder, err := c.connection.User.GetVirtualFolderForPath(sshDestPath)
|
||||
if err == nil {
|
||||
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, filesNum, filesSize, false) //nolint:errcheck
|
||||
if vfolder.IsIncludedInUserQuota() {
|
||||
dataprovider.UpdateUserQuota(&c.connection.User, filesNum, filesSize, false) //nolint:errcheck
|
||||
}
|
||||
} else {
|
||||
dataprovider.UpdateUserQuota(&c.connection.User, filesNum, filesSize, false) //nolint:errcheck
|
||||
dataprovider.UpdateUserFolderQuota(&vfolder, &c.connection.User, filesNum, filesSize, false)
|
||||
return
|
||||
}
|
||||
dataprovider.UpdateUserQuota(&c.connection.User, filesNum, filesSize, false) //nolint:errcheck
|
||||
}
|
||||
|
||||
func (c *sshCommand) handleHashCommands() error {
|
||||
|
@ -248,11 +246,15 @@ func (c *sshCommand) handleHashCommands() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *sshCommand) executeSystemCommand(command systemCommand) error {
|
||||
func (c *sshCommand) executeSystemCommand(command systemCommand) error { //nolint:gocyclo
|
||||
sshDestPath := c.getDestPath()
|
||||
if !c.isLocalPath(sshDestPath) {
|
||||
return c.sendErrorResponse(errUnsupportedConfig)
|
||||
}
|
||||
if err := common.Connections.IsNewTransferAllowed(c.connection.User.Username); err != nil {
|
||||
err := fmt.Errorf("denying command due to transfer count limits")
|
||||
return c.sendErrorResponse(err)
|
||||
}
|
||||
diskQuota, transferQuota := c.connection.HasSpace(true, false, command.quotaCheckPath)
|
||||
if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() || !transferQuota.HasDownloadSpace() {
|
||||
return c.sendErrorResponse(common.ErrQuotaExceeded)
|
||||
|
@ -432,11 +434,11 @@ func (c *sshCommand) getSystemCommand() (systemCommand, error) {
|
|||
// If the user cannot create symlinks we add the option --munge-links, if it is not
|
||||
// already set. This should make symlinks unusable (but manually recoverable)
|
||||
if c.connection.User.HasPerm(dataprovider.PermCreateSymlinks, c.getDestPath()) {
|
||||
if !util.Contains(args, "--safe-links") {
|
||||
if !slices.Contains(args, "--safe-links") {
|
||||
args = append([]string{"--safe-links"}, args...)
|
||||
}
|
||||
} else {
|
||||
if !util.Contains(args, "--munge-links") {
|
||||
if !slices.Contains(args, "--munge-links") {
|
||||
args = append([]string{"--munge-links"}, args...)
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,7 @@ import (
"context"
"errors"
"fmt"
"slices"
"sync"
"time"

@@ -27,7 +28,6 @@ import (
"golang.org/x/oauth2/microsoft"

"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
)

// Supported OAuth2 providers
@@ -56,7 +56,7 @@ type OAuth2Config struct {

// Validate validates and initializes the configuration
func (c *OAuth2Config) Validate() error {
if !util.Contains(supportedOAuth2Providers, c.Provider) {
if !slices.Contains(supportedOAuth2Providers, c.Provider) {
return fmt.Errorf("smtp oauth2: unsupported provider %d", c.Provider)
}
if c.ClientID == "" {
@@ -279,15 +279,15 @@ func (c *Config) Initialize(configDir string, isService bool) error {
}

func (c *Config) getMailClientOptions() []mail.Option {
options := []mail.Option{mail.WithoutNoop()}
options := []mail.Option{mail.WithPort(c.Port), mail.WithoutNoop()}

switch c.Encryption {
case 1:
options = append(options, mail.WithSSLPort(false))
options = append(options, mail.WithSSL())
case 2:
options = append(options, mail.WithTLSPortPolicy(mail.TLSMandatory))
options = append(options, mail.WithTLSPolicy(mail.TLSMandatory))
default:
options = append(options, mail.WithTLSPortPolicy(mail.NoTLS))
options = append(options, mail.WithTLSPolicy(mail.NoTLS))
}
if c.User != "" {
options = append(options, mail.WithUsername(c.User))
@@ -317,7 +317,6 @@ func (c *Config) getMailClientOptions() []mail.Option {
}),
mail.WithDebugLog())
}
options = append(options, mail.WithPort(c.Port))
return options
}
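For context, the option slice assembled by `getMailClientOptions` is eventually handed to a go-mail client. A minimal sketch of that usage, assuming the `mail` package is `github.com/wneessen/go-mail` (the import path is not shown in this diff) and using placeholder host, port and credentials:

```go
package main

import (
	"fmt"
	"log"

	mail "github.com/wneessen/go-mail"
)

func main() {
	// Options assembled the way the patched getMailClientOptions does for the
	// Encryption == 2 branch above: the port is set up front, followed by the
	// TLS policy and the optional credentials.
	opts := []mail.Option{
		mail.WithPort(587), // c.Port in the real code
		mail.WithoutNoop(),
		mail.WithTLSPolicy(mail.TLSMandatory),
		mail.WithUsername("user@example.com"), // placeholder credentials
		mail.WithPassword("secret"),
	}

	client, err := mail.NewClient("smtp.example.com", opts...)
	if err != nil {
		log.Fatal(err)
	}
	_ = client // a real caller would now build a *mail.Msg and call client.DialAndSend(...)
	fmt.Println("SMTP client configured")
}
```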
@@ -416,12 +415,6 @@ func SendEmail(to, bcc []string, subject, body string, contentType EmailContentT
return config.sendEmail(to, bcc, subject, body, contentType, attachments...)
}

// ReloadProviderConf reloads the configuration from the provider
// and apply it if different from the active one
func ReloadProviderConf() {
loadConfigFromProvider() //nolint:errcheck
}

func loadConfigFromProvider() error {
configs, err := dataprovider.GetConfigs()
if err != nil {
@@ -135,9 +135,9 @@ func (c Conf) Initialize(configDir string) error {
}
logger.Debug(logSender, "", "configured TLS cipher suites: %v", config.CipherSuites)
httpServer.TLSConfig = config
return util.HTTPListenAndServe(httpServer, c.BindAddress, c.BindPort, true, logSender)
return util.HTTPListenAndServe(httpServer, c.BindAddress, c.BindPort, true, nil, logSender)
}
return util.HTTPListenAndServe(httpServer, c.BindAddress, c.BindPort, false, logSender)
return util.HTTPListenAndServe(httpServer, c.BindAddress, c.BindPort, false, nil, logSender)
}

// ReloadCertificateMgr reloads the certificate manager
@@ -241,9 +241,9 @@ const (
I18nErrorRestore = "maintenance.restore_error"
I18nErrorACMEGeneric = "acme.generic_error"
I18nErrorSMTPRequiredFields = "smtp.err_required_fields"
I18nErrorSMTPClientIDRequired = "smtp.client_id_required"
I18nErrorSMTPClientSecretRequired = "smtp.client_secret_required"
I18nErrorSMTPRefreshTokenRequired = "smtp.refresh_token_required"
I18nErrorClientIDRequired = "oauth2.client_id_required"
I18nErrorClientSecretRequired = "oauth2.client_secret_required"
I18nErrorRefreshTokenRequired = "oauth2.refresh_token_required"
I18nErrorURLRequired = "actions.http_url_required"
I18nErrorURLInvalid = "actions.http_url_invalid"
I18nErrorHTTPPartNameRequired = "actions.http_part_name_required"
@@ -304,6 +304,9 @@ const (
I18nErrorEvSyncUnsupportedFs = "rules.sync_unsupported_fs_event"
I18nErrorRuleFailureActionsOnly = "rules.only_failure_actions"
I18nErrorRuleSyncActionRequired = "rules.sync_action_required"
I18nErrorInvalidPNG = "branding.invalid_png"
I18nErrorInvalidPNGSize = "branding.invalid_png_size"
I18nErrorInvalidDisclaimerURL = "branding.invalid_disclaimer_url"
)

// NewI18nError returns a I18nError wrappring the provided error
@@ -50,7 +50,7 @@ import (
"unsafe"

"github.com/google/uuid"
"github.com/lithammer/shortuuid/v3"
"github.com/lithammer/shortuuid/v4"
"golang.org/x/crypto/ssh"

"github.com/drakkan/sftpgo/v2/internal/logger"
@@ -130,16 +130,6 @@ var bytesSizeTable = map[string]uint64{
"e": eByte,
}

// Contains reports whether v is present in elems.
func Contains[T comparable](elems []T, v T) bool {
for _, s := range elems {
if v == s {
return true
}
}
return false
}

// IsStringPrefixInSlice searches a string prefix in a slice and returns true
// if a matching prefix is found
func IsStringPrefixInSlice(obj string, list []string) bool {
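The hunk above deletes the project's generic `Contains` helper; throughout this diff its call sites move to the standard library's `slices.Contains`, available since Go 1.21. A minimal sketch of the equivalence, with an illustrative slice (the values are placeholders, not taken from SFTPGo):

```go
package main

import (
	"fmt"
	"slices"
)

func main() {
	// The removed util.Contains helper and slices.Contains (standard library,
	// Go 1.21+) have the same semantics for slices of comparable elements.
	enabledSSHCommands := []string{"scp", "md5sum", "rsync"} // illustrative values

	fmt.Println(slices.Contains(enabledSSHCommands, "rsync")) // true
	fmt.Println(slices.Contains(enabledSSHCommands, "git"))   // false
}
```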
@@ -578,7 +568,7 @@ func GenerateOpaqueString() string {
return hex.EncodeToString(randomBytes[:])
}

// GenerateUniqueID retuens an unique ID
// GenerateUniqueID returns an unique ID
func GenerateUniqueID() string {
u, err := uuid.NewRandom()
if err != nil {
@@ -589,7 +579,10 @@ func GenerateUniqueID() string {

// HTTPListenAndServe is a wrapper for ListenAndServe that support both tcp
// and Unix-domain sockets
func HTTPListenAndServe(srv *http.Server, address string, port int, isTLS bool, logSender string) error {
func HTTPListenAndServe(srv *http.Server, address string, port int, isTLS bool,
listenerWrapper func(net.Listener) (net.Listener, error),
logSender string,
) error {
var listener net.Listener
var err error

@@ -617,7 +610,12 @@ func HTTPListenAndServe(srv *http.Server, address string, port int, isTLS bool,
if err != nil {
return err
}

if listenerWrapper != nil {
listener, err = listenerWrapper(listener)
if err != nil {
return err
}
}
logger.Info(logSender, "", "server listener registered, address: %s TLS enabled: %t", listener.Addr().String(), isTLS)

defer listener.Close()
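The extra `listenerWrapper` parameter lets a caller decorate the already-bound listener before `HTTPListenAndServe` starts serving on it. Below is a minimal sketch of that idea outside SFTPGo; `loggingListener` and the address are purely illustrative assumptions, not code from this repository:

```go
package main

import (
	"log"
	"net"
	"net/http"
)

// loggingListener is a hypothetical net.Listener decorator of the kind that
// could be passed through the new listenerWrapper parameter.
type loggingListener struct {
	net.Listener
}

// Accept delegates to the wrapped listener and logs every accepted connection.
func (l *loggingListener) Accept() (net.Conn, error) {
	conn, err := l.Listener.Accept()
	if err == nil {
		log.Printf("accepted connection from %s", conn.RemoteAddr())
	}
	return conn, err
}

func main() {
	// The wrapper receives the already-bound listener and returns the one the
	// HTTP server will actually serve on.
	wrapper := func(ln net.Listener) (net.Listener, error) {
		return &loggingListener{Listener: ln}, nil
	}

	ln, err := net.Listen("tcp", "127.0.0.1:8080")
	if err != nil {
		log.Fatal(err)
	}
	wrapped, err := wrapper(ln)
	if err != nil {
		log.Fatal(err)
	}
	log.Fatal(http.Serve(wrapped, http.NotFoundHandler()))
}
```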
@@ -638,6 +636,11 @@ func GetTLSCiphersFromNames(cipherNames []string) []uint16 {
ciphers = append(ciphers, c.ID)
}
}
for _, c := range tls.InsecureCipherSuites() {
if c.Name == strings.TrimSpace(name) {
ciphers = append(ciphers, c.ID)
}
}
}

if len(ciphers) == 0 {
@@ -799,7 +802,9 @@ func GetRedactedURL(rawurl string) string {
return u.Redacted()
}

// GetTLSVersion returns the TLS version for integer:
// GetTLSVersion returns the TLS version from an integer value:
// - 10 means TLS 1.0
// - 11 means TLS 1.1
// - 12 means TLS 1.2
// - 13 means TLS 1.3
// default is TLS 1.2
@@ -807,6 +812,10 @@ func GetTLSVersion(val int) uint16 {
switch val {
case 13:
return tls.VersionTLS13
case 11:
return tls.VersionTLS11
case 10:
return tls.VersionTLS10
default:
return tls.VersionTLS12
}
@@ -18,7 +18,7 @@ package version
import "strings"

const (
version = "2.6.4"
version = "2.6.99-dev"
appName = "SFTPGo"
)
@ -42,7 +42,6 @@ import (
|
|||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
|
||||
"github.com/eikenb/pipeat"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/sftp"
|
||||
|
||||
|
@ -186,7 +185,11 @@ func (fs *AzureBlobFs) Stat(name string) (os.FileInfo, error) {
|
|||
if val := getAzureLastModified(attrs.Metadata); val > 0 {
|
||||
lastModified = util.GetTimeFromMsecSinceEpoch(val)
|
||||
}
|
||||
return NewFileInfo(name, isDir, util.GetIntFromPointer(attrs.ContentLength), lastModified, false), nil
|
||||
info := NewFileInfo(name, isDir, util.GetIntFromPointer(attrs.ContentLength), lastModified, false)
|
||||
if !isDir {
|
||||
info.setMetadataFromPointerVal(attrs.Metadata)
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
if !fs.IsNotExist(err) {
|
||||
return nil, err
|
||||
|
@ -209,7 +212,7 @@ func (fs *AzureBlobFs) Lstat(name string) (os.FileInfo, error) {
|
|||
|
||||
// Open opens the named file for reading
|
||||
func (fs *AzureBlobFs) Open(name string, offset int64) (File, PipeReader, func(), error) {
|
||||
r, w, err := pipeat.PipeInDir(fs.localTempDir)
|
||||
r, w, err := createPipeFn(fs.localTempDir, fs.config.DownloadPartSize*int64(fs.config.DownloadConcurrency)+1)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
@ -237,7 +240,7 @@ func (fs *AzureBlobFs) Create(name string, flag, checks int) (File, PipeWriter,
|
|||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
r, w, err := pipeat.PipeInDir(fs.localTempDir)
|
||||
r, w, err := createPipeFn(fs.localTempDir, fs.config.UploadPartSize+1024*1024)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
@ -302,19 +305,21 @@ func (fs *AzureBlobFs) Create(name string, flag, checks int) (File, PipeWriter,
|
|||
}
|
||||
|
||||
// Rename renames (moves) source to target.
|
||||
func (fs *AzureBlobFs) Rename(source, target string) (int, int64, error) {
|
||||
func (fs *AzureBlobFs) Rename(source, target string, checks int) (int, int64, error) {
|
||||
if source == target {
|
||||
return -1, -1, nil
|
||||
}
|
||||
_, err := fs.Stat(path.Dir(target))
|
||||
if err != nil {
|
||||
return -1, -1, err
|
||||
if checks&CheckParentDir != 0 {
|
||||
_, err := fs.Stat(path.Dir(target))
|
||||
if err != nil {
|
||||
return -1, -1, err
|
||||
}
|
||||
}
|
||||
fi, err := fs.Stat(source)
|
||||
if err != nil {
|
||||
return -1, -1, err
|
||||
}
|
||||
return fs.renameInternal(source, target, fi, 0)
|
||||
return fs.renameInternal(source, target, fi, 0, checks&CheckUpdateModTime != 0)
|
||||
}
|
||||
|
||||
// Remove removes the named file or (empty) directory.
|
||||
|
@ -396,7 +401,7 @@ func (fs *AzureBlobFs) Chtimes(name string, _, mtime time.Time, isUploading bool
|
|||
}
|
||||
found := false
|
||||
for k := range metadata {
|
||||
if strings.ToLower(k) == lastModifiedField {
|
||||
if strings.EqualFold(k, lastModifiedField) {
|
||||
metadata[k] = to.Ptr(strconv.FormatInt(mtime.UnixMilli(), 10))
|
||||
found = true
|
||||
break
|
||||
|
@ -661,9 +666,9 @@ func (fs *AzureBlobFs) ResolvePath(virtualPath string) (string, error) {
|
|||
}
|
||||
|
||||
// CopyFile implements the FsFileCopier interface
|
||||
func (fs *AzureBlobFs) CopyFile(source, target string, srcSize int64) (int, int64, error) {
|
||||
func (fs *AzureBlobFs) CopyFile(source, target string, srcInfo os.FileInfo) (int, int64, error) {
|
||||
numFiles := 1
|
||||
sizeDiff := srcSize
|
||||
sizeDiff := srcInfo.Size()
|
||||
attrs, err := fs.headObject(target)
|
||||
if err == nil {
|
||||
sizeDiff -= util.GetIntFromPointer(attrs.ContentLength)
|
||||
|
@ -673,7 +678,7 @@ func (fs *AzureBlobFs) CopyFile(source, target string, srcSize int64) (int, int6
|
|||
return 0, 0, err
|
||||
}
|
||||
}
|
||||
if err := fs.copyFileInternal(source, target); err != nil {
|
||||
if err := fs.copyFileInternal(source, target, srcInfo, true); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return numFiles, sizeDiff, nil
|
||||
|
@ -756,13 +761,13 @@ func (fs *AzureBlobFs) setConfigDefaults() {
|
|||
}
|
||||
}
|
||||
|
||||
func (fs *AzureBlobFs) copyFileInternal(source, target string) error {
|
||||
func (fs *AzureBlobFs) copyFileInternal(source, target string, srcInfo os.FileInfo, updateModTime bool) error {
|
||||
ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxLongTimeout))
|
||||
defer cancelFn()
|
||||
|
||||
srcBlob := fs.containerClient.NewBlockBlobClient(source)
|
||||
dstBlob := fs.containerClient.NewBlockBlobClient(target)
|
||||
resp, err := dstBlob.StartCopyFromURL(ctx, srcBlob.URL(), fs.getCopyOptions())
|
||||
resp, err := dstBlob.StartCopyFromURL(ctx, srcBlob.URL(), fs.getCopyOptions(srcInfo, updateModTime))
|
||||
if err != nil {
|
||||
metric.AZCopyObjectCompleted(err)
|
||||
return err
|
||||
|
@ -795,11 +800,13 @@ func (fs *AzureBlobFs) copyFileInternal(source, target string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (fs *AzureBlobFs) renameInternal(source, target string, fi os.FileInfo, recursion int) (int, int64, error) {
|
||||
func (fs *AzureBlobFs) renameInternal(source, target string, srcInfo os.FileInfo, recursion int,
|
||||
updateModTime bool,
|
||||
) (int, int64, error) {
|
||||
var numFiles int
|
||||
var filesSize int64
|
||||
|
||||
if fi.IsDir() {
|
||||
if srcInfo.IsDir() {
|
||||
if renameMode == 0 {
|
||||
hasContents, err := fs.hasContents(source)
|
||||
if err != nil {
|
||||
|
@ -813,7 +820,7 @@ func (fs *AzureBlobFs) renameInternal(source, target string, fi os.FileInfo, rec
|
|||
return numFiles, filesSize, err
|
||||
}
|
||||
if renameMode == 1 {
|
||||
files, size, err := doRecursiveRename(fs, source, target, fs.renameInternal, recursion)
|
||||
files, size, err := doRecursiveRename(fs, source, target, fs.renameInternal, recursion, updateModTime)
|
||||
numFiles += files
|
||||
filesSize += size
|
||||
if err != nil {
|
||||
|
@ -821,13 +828,13 @@ func (fs *AzureBlobFs) renameInternal(source, target string, fi os.FileInfo, rec
|
|||
}
|
||||
}
|
||||
} else {
|
||||
if err := fs.copyFileInternal(source, target); err != nil {
|
||||
if err := fs.copyFileInternal(source, target, srcInfo, updateModTime); err != nil {
|
||||
return numFiles, filesSize, err
|
||||
}
|
||||
numFiles++
|
||||
filesSize += fi.Size()
|
||||
filesSize += srcInfo.Size()
|
||||
}
|
||||
err := fs.skipNotExistErr(fs.Remove(source, fi.IsDir()))
|
||||
err := fs.skipNotExistErr(fs.Remove(source, srcInfo.IsDir()))
|
||||
return numFiles, filesSize, err
|
||||
}
|
||||
|
||||
|
@ -1108,11 +1115,27 @@ func (*AzureBlobFs) readFill(r io.Reader, buf []byte) (n int, err error) {
|
|||
return n, err
|
||||
}
|
||||
|
||||
func (fs *AzureBlobFs) getCopyOptions() *blob.StartCopyFromURLOptions {
|
||||
func (fs *AzureBlobFs) getCopyOptions(srcInfo os.FileInfo, updateModTime bool) *blob.StartCopyFromURLOptions {
|
||||
copyOptions := &blob.StartCopyFromURLOptions{}
|
||||
if fs.config.AccessTier != "" {
|
||||
copyOptions.Tier = (*blob.AccessTier)(&fs.config.AccessTier)
|
||||
}
|
||||
if updateModTime {
|
||||
metadata := make(map[string]*string)
|
||||
for k, v := range getMetadata(srcInfo) {
|
||||
if v != "" {
|
||||
if strings.EqualFold(k, lastModifiedField) {
|
||||
metadata[k] = to.Ptr("0")
|
||||
} else {
|
||||
metadata[k] = to.Ptr(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(metadata) > 0 {
|
||||
copyOptions.Metadata = metadata
|
||||
}
|
||||
}
|
||||
|
||||
return copyOptions
|
||||
}
|
||||
|
||||
|
@ -1135,8 +1158,8 @@ func checkDirectoryMarkers(contentType string, metadata map[string]*string) bool
|
|||
return true
|
||||
}
|
||||
for k, v := range metadata {
|
||||
if strings.ToLower(k) == azFolderKey {
|
||||
return strings.ToLower(util.GetStringFromPointer(v)) == "true"
|
||||
if strings.EqualFold(k, azFolderKey) {
|
||||
return strings.EqualFold(util.GetStringFromPointer(v), "true")
|
||||
}
|
||||
}
|
||||
return false
|
||||
|
@ -1264,6 +1287,7 @@ func (l *azureBlobDirLister) Next(limit int) ([]os.FileInfo, error) {
|
|||
name = strings.TrimPrefix(name, l.prefix)
|
||||
size := int64(0)
|
||||
isDir := false
|
||||
var metadata map[string]*string
|
||||
modTime := time.Unix(0, 0)
|
||||
if blobItem.Properties != nil {
|
||||
size = util.GetIntFromPointer(blobItem.Properties.ContentLength)
|
||||
|
@ -1276,12 +1300,16 @@ func (l *azureBlobDirLister) Next(limit int) ([]os.FileInfo, error) {
|
|||
continue
|
||||
}
|
||||
l.prefixes[name] = true
|
||||
} else {
|
||||
metadata = blobItem.Metadata
|
||||
}
|
||||
if val := getAzureLastModified(blobItem.Metadata); val > 0 {
|
||||
modTime = util.GetTimeFromMsecSinceEpoch(val)
|
||||
}
|
||||
}
|
||||
l.cache = append(l.cache, NewFileInfo(name, isDir, size, modTime, false))
|
||||
info := NewFileInfo(name, isDir, size, modTime, false)
|
||||
info.setMetadataFromPointerVal(metadata)
|
||||
l.cache = append(l.cache, info)
|
||||
}
|
||||
|
||||
return l.returnFromCache(limit), nil
|
||||
|
|
|
@ -24,7 +24,6 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/eikenb/pipeat"
|
||||
"github.com/minio/sio"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
|
||||
|
@ -89,7 +88,7 @@ func (fs *CryptFs) Open(name string, offset int64) (File, PipeReader, func(), er
|
|||
f.Close()
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
r, w, err := pipeat.PipeInDir(fs.localTempDir)
|
||||
r, w, err := createPipeFn(fs.localTempDir, 0)
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return nil, nil, nil, err
|
||||
|
@ -175,7 +174,7 @@ func (fs *CryptFs) Create(name string, _, _ int) (File, PipeWriter, func(), erro
|
|||
f.Close()
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
r, w, err := pipeat.PipeInDir(fs.localTempDir)
|
||||
r, w, err := createPipeFn(fs.localTempDir, 0)
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return nil, nil, nil, err
|
||||
|
|
|
@@ -18,6 +18,8 @@ import (
"os"
"path"
"time"

"github.com/drakkan/sftpgo/v2/internal/util"
)

// FileInfo implements os.FileInfo for a Cloud Storage file.
@@ -26,6 +28,7 @@ type FileInfo struct {
sizeInBytes int64
modTime time.Time
mode os.FileMode
metadata map[string]string
}

// NewFileInfo creates file info.
@@ -79,5 +82,36 @@ func (fi *FileInfo) SetMode(mode os.FileMode) {

// Sys provides the underlying data source (can return nil)
func (fi *FileInfo) Sys() any {
return fi.metadata
}

func (fi *FileInfo) setMetadata(value map[string]string) {
fi.metadata = value
}

func (fi *FileInfo) setMetadataFromPointerVal(value map[string]*string) {
if len(value) == 0 {
fi.metadata = nil
return
}

fi.metadata = map[string]string{}
for k, v := range value {
val := util.GetStringFromPointer(v)
if val != "" {
fi.metadata[k] = val
}
}
}

func getMetadata(fi os.FileInfo) map[string]string {
if fi.Sys() == nil {
return nil
}
if val, ok := fi.Sys().(map[string]string); ok {
if len(val) > 0 {
return val
}
}
return nil
}
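The new `metadata` field is exposed through the standard `Sys()` hook of `os.FileInfo`, and `getMetadata` recovers it from any file info that carries a `map[string]string` there. A minimal sketch of that round-trip with an illustrative `fs.FileInfo` implementation (the `metaFileInfo` type and the sample key below are assumptions, not SFTPGo code):

```go
package main

import (
	"fmt"
	"io/fs"
	"time"
)

// metaFileInfo is an illustrative fs.FileInfo whose Sys() carries object
// metadata as a map, mirroring how the vfs FileInfo above exposes it.
type metaFileInfo struct {
	name     string
	size     int64
	modTime  time.Time
	metadata map[string]string
}

func (fi *metaFileInfo) Name() string       { return fi.name }
func (fi *metaFileInfo) Size() int64        { return fi.size }
func (fi *metaFileInfo) Mode() fs.FileMode  { return 0 }
func (fi *metaFileInfo) ModTime() time.Time { return fi.modTime }
func (fi *metaFileInfo) IsDir() bool        { return false }
func (fi *metaFileInfo) Sys() any           { return fi.metadata }

// getMetadata mirrors the helper added in the diff: it returns the metadata
// map only when Sys() actually carries a non-empty map[string]string.
func getMetadata(fi fs.FileInfo) map[string]string {
	if fi.Sys() == nil {
		return nil
	}
	if val, ok := fi.Sys().(map[string]string); ok && len(val) > 0 {
		return val
	}
	return nil
}

func main() {
	fi := &metaFileInfo{
		name:     "file1.txt",
		size:     1024,
		modTime:  time.Now(),
		metadata: map[string]string{"example-key": "example-value"}, // placeholder metadata
	}
	fmt.Println(getMetadata(fi)) // map[example-key:example-value]
}
```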
@ -20,6 +20,7 @@ import (
|
|||
"github.com/sftpgo/sdk"
|
||||
|
||||
"github.com/drakkan/sftpgo/v2/internal/kms"
|
||||
"github.com/drakkan/sftpgo/v2/internal/util"
|
||||
)
|
||||
|
||||
// Filesystem defines filesystem details
|
||||
|
@ -38,6 +39,7 @@ type Filesystem struct {
|
|||
// SetEmptySecrets sets the secrets to empty
|
||||
func (f *Filesystem) SetEmptySecrets() {
|
||||
f.S3Config.AccessSecret = kms.NewEmptySecret()
|
||||
f.S3Config.SSECustomerKey = kms.NewEmptySecret()
|
||||
f.GCSConfig.Credentials = kms.NewEmptySecret()
|
||||
f.AzBlobConfig.AccountKey = kms.NewEmptySecret()
|
||||
f.AzBlobConfig.SASURL = kms.NewEmptySecret()
|
||||
|
@ -54,6 +56,9 @@ func (f *Filesystem) SetEmptySecretsIfNil() {
|
|||
if f.S3Config.AccessSecret == nil {
|
||||
f.S3Config.AccessSecret = kms.NewEmptySecret()
|
||||
}
|
||||
if f.S3Config.SSECustomerKey == nil {
|
||||
f.S3Config.SSECustomerKey = kms.NewEmptySecret()
|
||||
}
|
||||
if f.GCSConfig.Credentials == nil {
|
||||
f.GCSConfig.Credentials = kms.NewEmptySecret()
|
||||
}
|
||||
|
@ -90,6 +95,9 @@ func (f *Filesystem) SetNilSecretsIfEmpty() {
|
|||
if f.S3Config.AccessSecret != nil && f.S3Config.AccessSecret.IsEmpty() {
|
||||
f.S3Config.AccessSecret = nil
|
||||
}
|
||||
if f.S3Config.SSECustomerKey != nil && f.S3Config.SSECustomerKey.IsEmpty() {
|
||||
f.S3Config.SSECustomerKey = nil
|
||||
}
|
||||
if f.GCSConfig.Credentials != nil && f.GCSConfig.Credentials.IsEmpty() {
|
||||
f.GCSConfig.Credentials = nil
|
||||
}
|
||||
|
@ -232,8 +240,7 @@ func (f *Filesystem) Validate(additionalData string) error {
|
|||
f.CryptConfig = CryptFsConfig{}
|
||||
f.SFTPConfig = SFTPFsConfig{}
|
||||
return nil
|
||||
default:
|
||||
f.Provider = sdk.LocalFilesystemProvider
|
||||
case sdk.LocalFilesystemProvider:
|
||||
f.S3Config = S3FsConfig{}
|
||||
f.GCSConfig = GCSFsConfig{}
|
||||
f.AzBlobConfig = AzBlobFsConfig{}
|
||||
|
@ -241,6 +248,11 @@ func (f *Filesystem) Validate(additionalData string) error {
|
|||
f.SFTPConfig = SFTPFsConfig{}
|
||||
f.HTTPConfig = HTTPFsConfig{}
|
||||
return validateOSFsConfig(&f.OSConfig)
|
||||
default:
|
||||
return util.NewI18nError(
|
||||
util.NewValidationError("invalid filesystem provider"),
|
||||
util.I18nErrorFsValidation,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -249,6 +261,9 @@ func (f *Filesystem) HasRedactedSecret() bool {
|
|||
// TODO move vfs specific code into each *FsConfig struct
|
||||
switch f.Provider {
|
||||
case sdk.S3FilesystemProvider:
|
||||
if f.S3Config.SSECustomerKey.IsRedacted() {
|
||||
return true
|
||||
}
|
||||
return f.S3Config.AccessSecret.IsRedacted()
|
||||
case sdk.GCSFilesystemProvider:
|
||||
return f.GCSConfig.Credentials.IsRedacted()
|
||||
|
@ -323,7 +338,8 @@ func (f *Filesystem) GetACopy() Filesystem {
|
|||
ForcePathStyle: f.S3Config.ForcePathStyle,
|
||||
SkipTLSVerify: f.S3Config.SkipTLSVerify,
|
||||
},
|
||||
AccessSecret: f.S3Config.AccessSecret.Clone(),
|
||||
AccessSecret: f.S3Config.AccessSecret.Clone(),
|
||||
SSECustomerKey: f.S3Config.SSECustomerKey.Clone(),
|
||||
},
|
||||
GCSConfig: GCSFsConfig{
|
||||
BaseGCSFsConfig: sdk.BaseGCSFsConfig{
|
||||
|
|
|
@ -32,7 +32,6 @@ import (
|
|||
"time"
|
||||
|
||||
"cloud.google.com/go/storage"
|
||||
"github.com/eikenb/pipeat"
|
||||
"github.com/pkg/sftp"
|
||||
"github.com/rs/xid"
|
||||
"google.golang.org/api/googleapi"
|
||||
|
@ -89,13 +88,20 @@ func NewGCSFs(connectionID, localTempDir, mountPath string, config GCSFsConfig)
|
|||
}
|
||||
ctx := context.Background()
|
||||
if fs.config.AutomaticCredentials > 0 {
|
||||
fs.svc, err = storage.NewClient(ctx)
|
||||
fs.svc, err = storage.NewClient(ctx,
|
||||
storage.WithJSONReads(),
|
||||
option.WithUserAgent(version.GetVersionHash()),
|
||||
)
|
||||
} else {
|
||||
err = fs.config.Credentials.TryDecrypt()
|
||||
if err != nil {
|
||||
return fs, err
|
||||
}
|
||||
fs.svc, err = storage.NewClient(ctx, option.WithCredentialsJSON([]byte(fs.config.Credentials.GetPayload())))
|
||||
fs.svc, err = storage.NewClient(ctx,
|
||||
storage.WithJSONReads(),
|
||||
option.WithUserAgent(version.GetVersionHash()),
|
||||
option.WithCredentialsJSON([]byte(fs.config.Credentials.GetPayload())),
|
||||
)
|
||||
}
|
||||
return fs, err
|
||||
}
|
||||
|
@ -128,7 +134,7 @@ func (fs *GCSFs) Lstat(name string) (os.FileInfo, error) {
|
|||
|
||||
// Open opens the named file for reading
|
||||
func (fs *GCSFs) Open(name string, offset int64) (File, PipeReader, func(), error) {
|
||||
r, w, err := pipeat.PipeInDir(fs.localTempDir)
|
||||
r, w, err := createPipeFn(fs.localTempDir, 0)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
@ -176,7 +182,11 @@ func (fs *GCSFs) Create(name string, flag, checks int) (File, PipeWriter, func()
|
|||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
r, w, err := pipeat.PipeInDir(fs.localTempDir)
|
||||
chunkSize := googleapi.DefaultUploadChunkSize
|
||||
if fs.config.UploadPartSize > 0 {
|
||||
chunkSize = int(fs.config.UploadPartSize) * 1024 * 1024
|
||||
}
|
||||
r, w, err := createPipeFn(fs.localTempDir, int64(chunkSize+1024*1024))
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
@ -220,9 +230,7 @@ func (fs *GCSFs) Create(name string, flag, checks int) (File, PipeWriter, func()
|
|||
objectWriter = obj.NewWriter(ctx)
|
||||
}
|
||||
|
||||
if fs.config.UploadPartSize > 0 {
|
||||
objectWriter.ChunkSize = int(fs.config.UploadPartSize) * 1024 * 1024
|
||||
}
|
||||
objectWriter.ChunkSize = chunkSize
|
||||
if fs.config.UploadPartMaxTime > 0 {
|
||||
objectWriter.ChunkRetryDeadline = time.Duration(fs.config.UploadPartMaxTime) * time.Second
|
||||
}
|
||||
|
@ -255,19 +263,21 @@ func (fs *GCSFs) Create(name string, flag, checks int) (File, PipeWriter, func()
|
|||
}
|
||||
|
||||
// Rename renames (moves) source to target.
|
||||
func (fs *GCSFs) Rename(source, target string) (int, int64, error) {
|
||||
func (fs *GCSFs) Rename(source, target string, checks int) (int, int64, error) {
|
||||
if source == target {
|
||||
return -1, -1, nil
|
||||
}
|
||||
_, err := fs.Stat(path.Dir(target))
|
||||
if err != nil {
|
||||
return -1, -1, err
|
||||
if checks&CheckParentDir != 0 {
|
||||
_, err := fs.Stat(path.Dir(target))
|
||||
if err != nil {
|
||||
return -1, -1, err
|
||||
}
|
||||
}
|
||||
fi, err := fs.getObjectStat(source)
|
||||
if err != nil {
|
||||
return -1, -1, err
|
||||
}
|
||||
return fs.renameInternal(source, target, fi, 0)
|
||||
return fs.renameInternal(source, target, fi, 0, checks&CheckUpdateModTime != 0)
|
||||
}
|
||||
|
||||
// Remove removes the named file or (empty) directory.
|
||||
|
@ -636,9 +646,9 @@ func (fs *GCSFs) ResolvePath(virtualPath string) (string, error) {
|
|||
}
|
||||
|
||||
// CopyFile implements the FsFileCopier interface
|
||||
func (fs *GCSFs) CopyFile(source, target string, srcSize int64) (int, int64, error) {
|
||||
func (fs *GCSFs) CopyFile(source, target string, srcInfo os.FileInfo) (int, int64, error) {
|
||||
numFiles := 1
|
||||
sizeDiff := srcSize
|
||||
sizeDiff := srcInfo.Size()
|
||||
var conditions *storage.Conditions
|
||||
attrs, err := fs.headObject(target)
|
||||
if err == nil {
|
||||
|
@ -651,7 +661,7 @@ func (fs *GCSFs) CopyFile(source, target string, srcSize int64) (int, int64, err
|
|||
}
|
||||
conditions = &storage.Conditions{DoesNotExist: true}
|
||||
}
|
||||
if err := fs.copyFileInternal(source, target, conditions); err != nil {
|
||||
if err := fs.copyFileInternal(source, target, conditions, srcInfo, true); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return numFiles, sizeDiff, nil
|
||||
|
@ -679,7 +689,11 @@ func (fs *GCSFs) getObjectStat(name string) (os.FileInfo, error) {
|
|||
objectModTime = util.GetTimeFromMsecSinceEpoch(val)
|
||||
}
|
||||
isDir := attrs.ContentType == dirMimeType || strings.HasSuffix(attrs.Name, "/")
|
||||
return NewFileInfo(name, isDir, objSize, objectModTime, false), nil
|
||||
info := NewFileInfo(name, isDir, objSize, objectModTime, false)
|
||||
if !isDir {
|
||||
info.setMetadata(attrs.Metadata)
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
if !fs.IsNotExist(err) {
|
||||
return nil, err
|
||||
|
@ -749,7 +763,9 @@ func (fs *GCSFs) composeObjects(ctx context.Context, dst, partialObject *storage
|
|||
return err
|
||||
}
|
||||
|
||||
func (fs *GCSFs) copyFileInternal(source, target string, conditions *storage.Conditions) error {
|
||||
func (fs *GCSFs) copyFileInternal(source, target string, conditions *storage.Conditions,
|
||||
srcInfo os.FileInfo, updateModTime bool,
|
||||
) error {
|
||||
src := fs.svc.Bucket(fs.config.Bucket).Object(source)
|
||||
dst := fs.svc.Bucket(fs.config.Bucket).Object(target)
|
||||
if conditions != nil {
|
||||
|
@ -780,16 +796,25 @@ func (fs *GCSFs) copyFileInternal(source, target string, conditions *storage.Con
|
|||
if contentType != "" {
|
||||
copier.ContentType = contentType
|
||||
}
|
||||
metadata := getMetadata(srcInfo)
|
||||
if updateModTime && len(metadata) > 0 {
|
||||
delete(metadata, lastModifiedField)
|
||||
}
|
||||
if len(metadata) > 0 {
|
||||
copier.Metadata = metadata
|
||||
}
|
||||
_, err := copier.Run(ctx)
|
||||
metric.GCSCopyObjectCompleted(err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (fs *GCSFs) renameInternal(source, target string, fi os.FileInfo, recursion int) (int, int64, error) {
|
||||
func (fs *GCSFs) renameInternal(source, target string, srcInfo os.FileInfo, recursion int,
|
||||
updateModTime bool,
|
||||
) (int, int64, error) {
|
||||
var numFiles int
|
||||
var filesSize int64
|
||||
|
||||
if fi.IsDir() {
|
||||
if srcInfo.IsDir() {
|
||||
if renameMode == 0 {
|
||||
hasContents, err := fs.hasContents(source)
|
||||
if err != nil {
|
||||
|
@ -803,7 +828,7 @@ func (fs *GCSFs) renameInternal(source, target string, fi os.FileInfo, recursion
|
|||
return numFiles, filesSize, err
|
||||
}
|
||||
if renameMode == 1 {
|
||||
files, size, err := doRecursiveRename(fs, source, target, fs.renameInternal, recursion)
|
||||
files, size, err := doRecursiveRename(fs, source, target, fs.renameInternal, recursion, updateModTime)
|
||||
numFiles += files
|
||||
filesSize += size
|
||||
if err != nil {
|
||||
|
@ -811,13 +836,13 @@ func (fs *GCSFs) renameInternal(source, target string, fi os.FileInfo, recursion
|
|||
}
|
||||
}
|
||||
} else {
|
||||
if err := fs.copyFileInternal(source, target, nil); err != nil {
|
||||
if err := fs.copyFileInternal(source, target, nil, srcInfo, updateModTime); err != nil {
|
||||
return numFiles, filesSize, err
|
||||
}
|
||||
numFiles++
|
||||
filesSize += fi.Size()
|
||||
filesSize += srcInfo.Size()
|
||||
}
|
||||
err := fs.Remove(source, fi.IsDir())
|
||||
err := fs.Remove(source, srcInfo.IsDir())
|
||||
if fs.IsNotExist(err) {
|
||||
err = nil
|
||||
}
|
||||
|
@ -1002,7 +1027,9 @@ func (l *gcsDirLister) Next(limit int) ([]os.FileInfo, error) {
|
|||
if val := getLastModified(attrs.Metadata); val > 0 {
|
||||
modTime = util.GetTimeFromMsecSinceEpoch(val)
|
||||
}
|
||||
l.cache = append(l.cache, NewFileInfo(name, isDir, attrs.Size, modTime, false))
|
||||
info := NewFileInfo(name, isDir, attrs.Size, modTime, false)
|
||||
info.setMetadata(attrs.Metadata)
|
||||
l.cache = append(l.cache, info)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@@ -32,7 +32,6 @@ import (
"strings"
"time"

"github.com/eikenb/pipeat"
"github.com/pkg/sftp"
"github.com/sftpgo/sdk"

@@ -317,7 +316,7 @@ func (fs *HTTPFs) Lstat(name string) (os.FileInfo, error) {

// Open opens the named file for reading
func (fs *HTTPFs) Open(name string, offset int64) (File, PipeReader, func(), error) {
r, w, err := pipeat.PipeInDir(fs.localTempDir)
r, w, err := createPipeFn(fs.localTempDir, 0)
if err != nil {
return nil, nil, nil, err
}
@@ -351,7 +350,7 @@ func (fs *HTTPFs) Open(name string, offset int64) (File, PipeReader, func(), err

// Create creates or opens the named file for writing
func (fs *HTTPFs) Create(name string, flag, checks int) (File, PipeWriter, func(), error) {
r, w, err := pipeat.PipeInDir(fs.localTempDir)
r, w, err := createPipeFn(fs.localTempDir, 0)
if err != nil {
return nil, nil, nil, err
}
@@ -384,7 +383,7 @@ func (fs *HTTPFs) Create(name string, flag, checks int) (File, PipeWriter, func(
}

// Rename renames (moves) source to target.
func (fs *HTTPFs) Rename(source, target string) (int, int64, error) {
func (fs *HTTPFs) Rename(source, target string, checks int) (int, int64, error) {
if source == target {
return -1, -1, nil
}
@@ -397,6 +396,9 @@ func (fs *HTTPFs) Rename(source, target string) (int, int64, error) {
return -1, -1, err
}
defer resp.Body.Close()
if checks&CheckUpdateModTime != 0 {
fs.Chtimes(target, time.Now(), time.Now(), false) //nolint:errcheck
}
return -1, -1, nil
}
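As in the other storage backends touched by this diff, `Rename` now receives a `checks` bit mask (`CheckParentDir`, `CheckUpdateModTime`) so callers can enable individual behaviours. A minimal sketch of that flag pattern; the constant values and the `rename` stub are illustrative assumptions, not SFTPGo's definitions:

```go
package main

import "fmt"

// Illustrative bit flags in the spirit of the vfs rename "checks" parameter;
// the real CheckParentDir / CheckUpdateModTime constants live in the vfs package.
const (
	CheckParentDir     = 1 << iota // verify the target's parent directory exists
	CheckUpdateModTime             // refresh the modification-time metadata on the target
)

// rename is a stub that only shows how the individual flags are tested.
func rename(source, target string, checks int) {
	if checks&CheckParentDir != 0 {
		fmt.Println("would stat the parent of", target)
	}
	if checks&CheckUpdateModTime != 0 {
		fmt.Println("would refresh the modification time on", target)
	}
	fmt.Println("rename", source, "->", target)
}

func main() {
	rename("/a/file1", "/b/file1", CheckParentDir|CheckUpdateModTime)
}
```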