Browse Source

WebClient: redirect to the requested URL after login

This feature is only useful and enabled for file manager urls

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
Nicola Murino 2 years ago
parent
commit
9d60972743
7 changed files with 217 additions and 53 deletions
  1. 6 6
      go.mod
  2. 12 11
      go.sum
  3. 143 0
      internal/httpd/httpd_test.go
  4. 4 0
      internal/httpd/middleware.go
  5. 41 29
      internal/httpd/server.go
  6. 10 6
      internal/httpd/webclient.go
  7. 1 1
      pkgs/build.sh

+ 6 - 6
go.mod

@@ -4,7 +4,7 @@ go 1.20
 
 
 require (
 require (
 	cloud.google.com/go/storage v1.30.1
 	cloud.google.com/go/storage v1.30.1
-	github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0
+	github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.1
 	github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0
 	github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0
 	github.com/GehirnInc/crypt v0.0.0-20230320061759-8cc1b52080c5
 	github.com/GehirnInc/crypt v0.0.0-20230320061759-8cc1b52080c5
 	github.com/alexedwards/argon2id v0.0.0-20230305115115-4b3c3280a736
 	github.com/alexedwards/argon2id v0.0.0-20230305115115-4b3c3280a736
@@ -25,8 +25,8 @@ require (
 	github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001
 	github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001
 	github.com/fclairamb/ftpserverlib v0.21.0
 	github.com/fclairamb/ftpserverlib v0.21.0
 	github.com/fclairamb/go-log v0.4.1
 	github.com/fclairamb/go-log v0.4.1
-	github.com/go-acme/lego/v4 v4.12.0
-	github.com/go-chi/chi/v5 v5.0.8
+	github.com/go-acme/lego/v4 v4.12.1
+	github.com/go-chi/chi/v5 v5.0.9-0.20230502103705-7f280968675b
 	github.com/go-chi/jwtauth/v5 v5.1.0
 	github.com/go-chi/jwtauth/v5 v5.1.0
 	github.com/go-chi/render v1.0.2
 	github.com/go-chi/render v1.0.2
 	github.com/go-sql-driver/mysql v1.7.1
 	github.com/go-sql-driver/mysql v1.7.1
@@ -35,7 +35,7 @@ require (
 	github.com/google/uuid v1.3.0
 	github.com/google/uuid v1.3.0
 	github.com/hashicorp/go-hclog v1.5.0
 	github.com/hashicorp/go-hclog v1.5.0
 	github.com/hashicorp/go-plugin v1.4.10
 	github.com/hashicorp/go-plugin v1.4.10
-	github.com/hashicorp/go-retryablehttp v0.7.2
+	github.com/hashicorp/go-retryablehttp v0.7.4
 	github.com/jackc/pgx/v5 v5.3.2-0.20230603125928-d9560c78b8e6
 	github.com/jackc/pgx/v5 v5.3.2-0.20230603125928-d9560c78b8e6
 	github.com/jlaffaye/ftp v0.0.0-20201112195030-9aae4d151126
 	github.com/jlaffaye/ftp v0.0.0-20201112195030-9aae4d151126
 	github.com/klauspost/compress v1.16.5
 	github.com/klauspost/compress v1.16.5
@@ -114,7 +114,7 @@ require (
 	github.com/golang/protobuf v1.5.3 // indirect
 	github.com/golang/protobuf v1.5.3 // indirect
 	github.com/google/go-cmp v0.5.9 // indirect
 	github.com/google/go-cmp v0.5.9 // indirect
 	github.com/google/s2a-go v0.1.4 // indirect
 	github.com/google/s2a-go v0.1.4 // indirect
-	github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect
+	github.com/googleapis/enterprise-certificate-proxy v0.2.4 // indirect
 	github.com/googleapis/gax-go/v2 v2.10.0 // indirect
 	github.com/googleapis/gax-go/v2 v2.10.0 // indirect
 	github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
 	github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
 	github.com/hashicorp/hcl v1.0.0 // indirect
 	github.com/hashicorp/hcl v1.0.0 // indirect
@@ -172,5 +172,5 @@ require (
 replace (
 replace (
 	github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9
 	github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9
 	github.com/robfig/cron/v3 => github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0
 	github.com/robfig/cron/v3 => github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0
-	golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20230512104844-219592fc3028
+	golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20230608154636-e9d673c2a1a8
 )
 )

+ 12 - 11
go.sum

@@ -430,8 +430,8 @@ github.com/Azure/azure-sdk-for-go/sdk/azcore v1.1.1/go.mod h1:uGG2W01BaETf0Ozp+Q
 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.1.2/go.mod h1:uGG2W01BaETf0Ozp+QxxKJdMBNRWPdstHG0Fmdwn1/U=
 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.1.2/go.mod h1:uGG2W01BaETf0Ozp+QxxKJdMBNRWPdstHG0Fmdwn1/U=
 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.0/go.mod h1:tZoQYdDZNOiIjdSn0dVWVfl0NEPGOJqVLzSrcFk4Is0=
 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.0/go.mod h1:tZoQYdDZNOiIjdSn0dVWVfl0NEPGOJqVLzSrcFk4Is0=
 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.1/go.mod h1:DffdKW9RFqa5VgmsjUOsS7UE7eiA5iAvYUs63bhKQ0M=
 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.1/go.mod h1:DffdKW9RFqa5VgmsjUOsS7UE7eiA5iAvYUs63bhKQ0M=
-github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 h1:8kDqDngH+DmVBiCtIjCFTGa7MBnsIOkF9IccInFEbjk=
-github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q=
+github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.1 h1:SEy2xmstIphdPwNBUi7uhvjyjhVKISfwjfOJmuy7kg4=
+github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.1/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q=
 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.0.0/go.mod h1:+6sju8gk8FRmSajX3Oz4G5Gm7P+mbqE9FVaXXFYTkCM=
 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.0.0/go.mod h1:+6sju8gk8FRmSajX3Oz4G5Gm7P+mbqE9FVaXXFYTkCM=
 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.1.0/go.mod h1:bhXu1AjYL+wutSL/kpSq6s7733q2Rb0yuot9Zgfqa/0=
 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.1.0/go.mod h1:bhXu1AjYL+wutSL/kpSq6s7733q2Rb0yuot9Zgfqa/0=
 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1 h1:T8quHYlUGyb/oqtSTwqlCr1ilJHrDv+ZtpSfo+hm1BU=
 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1 h1:T8quHYlUGyb/oqtSTwqlCr1ilJHrDv+ZtpSfo+hm1BU=
@@ -876,8 +876,8 @@ github.com/docker/spdystream v0.0.0-20160310174837-449fdfce4d96/go.mod h1:Qh8CwZ
 github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE=
 github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE=
 github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0 h1:EW9gIJRmt9lzk66Fhh4S8VEtURA6QHZqGeSRE9Nb2/U=
 github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0 h1:EW9gIJRmt9lzk66Fhh4S8VEtURA6QHZqGeSRE9Nb2/U=
 github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
 github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
-github.com/drakkan/crypto v0.0.0-20230512104844-219592fc3028 h1:qUrs/afB0gubJUY5kOmxLx1euFlXn9yUMUhli7Njob8=
-github.com/drakkan/crypto v0.0.0-20230512104844-219592fc3028/go.mod h1:FPowDKc1rEQhN3Xf48AhpBr8eSNzpEYaAQczEYcuAVU=
+github.com/drakkan/crypto v0.0.0-20230608154636-e9d673c2a1a8 h1:0BDuAXFFCOqYrcOkArbc4MRE6jRvN0oiPgYAewp9FyI=
+github.com/drakkan/crypto v0.0.0-20230608154636-e9d673c2a1a8/go.mod h1:FPowDKc1rEQhN3Xf48AhpBr8eSNzpEYaAQczEYcuAVU=
 github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9 h1:LPH1dEblAOO/LoG7yHPMtBLXhQmjaga91/DDjWk9jWA=
 github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9 h1:LPH1dEblAOO/LoG7yHPMtBLXhQmjaga91/DDjWk9jWA=
 github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9/go.mod h1:2lmrmq866uF2tnje75wQHzmPXhmSWUt7Gyx2vgK1RCU=
 github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9/go.mod h1:2lmrmq866uF2tnje75wQHzmPXhmSWUt7Gyx2vgK1RCU=
 github.com/drakkan/webdav v0.0.0-20230227175313-32996838bcd8 h1:tdkLkSKtYd3WSDsZXGJDKsakiNstLQJPN5HjnqCkf2c=
 github.com/drakkan/webdav v0.0.0-20230227175313-32996838bcd8 h1:tdkLkSKtYd3WSDsZXGJDKsakiNstLQJPN5HjnqCkf2c=
@@ -951,10 +951,10 @@ github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeME
 github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
 github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
 github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M=
 github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M=
 github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U=
 github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U=
-github.com/go-acme/lego/v4 v4.12.0 h1:jox3II6YRjt1EXvrymSQuSNgEUOcbUkF2je0kyuv6YM=
-github.com/go-acme/lego/v4 v4.12.0/go.mod h1:UZoOlhVmUYP/N0z4tEbfUjoCNHRZNObzqWZtT76DIsc=
-github.com/go-chi/chi/v5 v5.0.8 h1:lD+NLqFcAi1ovnVZpsnObHGW4xb4J8lNmoYVfECH1Y0=
-github.com/go-chi/chi/v5 v5.0.8/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
+github.com/go-acme/lego/v4 v4.12.1 h1:Cy3FS7wADLNBqCLpz2wdfdNrThW9rZy8RCAfnUrL2uE=
+github.com/go-acme/lego/v4 v4.12.1/go.mod h1:UZoOlhVmUYP/N0z4tEbfUjoCNHRZNObzqWZtT76DIsc=
+github.com/go-chi/chi/v5 v5.0.9-0.20230502103705-7f280968675b h1:fOhf/SzZ2dPT7wFY5MPJAR/4HUusHgf8xT8XWqVeDtY=
+github.com/go-chi/chi/v5 v5.0.9-0.20230502103705-7f280968675b/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
 github.com/go-chi/jwtauth/v5 v5.1.0 h1:wJyf2YZ/ohPvNJBwPOzZaQbyzwgMZZceE1m8FOzXLeA=
 github.com/go-chi/jwtauth/v5 v5.1.0 h1:wJyf2YZ/ohPvNJBwPOzZaQbyzwgMZZceE1m8FOzXLeA=
 github.com/go-chi/jwtauth/v5 v5.1.0/go.mod h1:MA93hc1au3tAQwCKry+fI4LqJ5MIVN4XSsglOo+lSc8=
 github.com/go-chi/jwtauth/v5 v5.1.0/go.mod h1:MA93hc1au3tAQwCKry+fI4LqJ5MIVN4XSsglOo+lSc8=
 github.com/go-chi/render v1.0.2 h1:4ER/udB0+fMWB2Jlf15RV3F4A2FDuYi/9f+lFttR/Lg=
 github.com/go-chi/render v1.0.2 h1:4ER/udB0+fMWB2Jlf15RV3F4A2FDuYi/9f+lFttR/Lg=
@@ -1215,8 +1215,9 @@ github.com/googleapis/enterprise-certificate-proxy v0.0.0-20220520183353-fd19c99
 github.com/googleapis/enterprise-certificate-proxy v0.1.0/go.mod h1:17drOmN3MwGY7t0e+Ei9b45FFGA3fBs3x36SsCg1hq8=
 github.com/googleapis/enterprise-certificate-proxy v0.1.0/go.mod h1:17drOmN3MwGY7t0e+Ei9b45FFGA3fBs3x36SsCg1hq8=
 github.com/googleapis/enterprise-certificate-proxy v0.2.0/go.mod h1:8C0jb7/mgJe/9KK8Lm7X9ctZC2t60YyIpYEI16jx0Qg=
 github.com/googleapis/enterprise-certificate-proxy v0.2.0/go.mod h1:8C0jb7/mgJe/9KK8Lm7X9ctZC2t60YyIpYEI16jx0Qg=
 github.com/googleapis/enterprise-certificate-proxy v0.2.1/go.mod h1:AwSRAtLfXpU5Nm3pW+v7rGDHp09LsPtGY9MduiEsR9k=
 github.com/googleapis/enterprise-certificate-proxy v0.2.1/go.mod h1:AwSRAtLfXpU5Nm3pW+v7rGDHp09LsPtGY9MduiEsR9k=
-github.com/googleapis/enterprise-certificate-proxy v0.2.3 h1:yk9/cqRKtT9wXZSsRH9aurXEpJX+U6FLtpYTdC3R06k=
 github.com/googleapis/enterprise-certificate-proxy v0.2.3/go.mod h1:AwSRAtLfXpU5Nm3pW+v7rGDHp09LsPtGY9MduiEsR9k=
 github.com/googleapis/enterprise-certificate-proxy v0.2.3/go.mod h1:AwSRAtLfXpU5Nm3pW+v7rGDHp09LsPtGY9MduiEsR9k=
+github.com/googleapis/enterprise-certificate-proxy v0.2.4 h1:uGy6JWR/uMIILU8wbf+OkstIrNiMjGpEIyhx8f6W7s4=
+github.com/googleapis/enterprise-certificate-proxy v0.2.4/go.mod h1:AwSRAtLfXpU5Nm3pW+v7rGDHp09LsPtGY9MduiEsR9k=
 github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
 github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
 github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
 github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
 github.com/googleapis/gax-go/v2 v2.1.0/go.mod h1:Q3nei7sK6ybPYH7twZdmQpAd1MKb7pfu6SK+H1/DsU0=
 github.com/googleapis/gax-go/v2 v2.1.0/go.mod h1:Q3nei7sK6ybPYH7twZdmQpAd1MKb7pfu6SK+H1/DsU0=
@@ -1298,8 +1299,8 @@ github.com/hashicorp/go-plugin v1.4.10 h1:xUbmA4jC6Dq163/fWcp8P3JuHilrHHMLNRxzGQ
 github.com/hashicorp/go-plugin v1.4.10/go.mod h1:6/1TEzT0eQznvI/gV2CM29DLSkAK/e58mUWKVsPaph0=
 github.com/hashicorp/go-plugin v1.4.10/go.mod h1:6/1TEzT0eQznvI/gV2CM29DLSkAK/e58mUWKVsPaph0=
 github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs=
 github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs=
 github.com/hashicorp/go-retryablehttp v0.7.1/go.mod h1:vAew36LZh98gCBJNLH42IQ1ER/9wtLZZ8meHqQvEYWY=
 github.com/hashicorp/go-retryablehttp v0.7.1/go.mod h1:vAew36LZh98gCBJNLH42IQ1ER/9wtLZZ8meHqQvEYWY=
-github.com/hashicorp/go-retryablehttp v0.7.2 h1:AcYqCvkpalPnPF2pn0KamgwamS42TqUDDYFRKq/RAd0=
-github.com/hashicorp/go-retryablehttp v0.7.2/go.mod h1:Jy/gPYAdjqffZ/yFGCFV2doI5wjtH1ewM9u8iYVjtX8=
+github.com/hashicorp/go-retryablehttp v0.7.4 h1:ZQgVdpTdAL7WpMIwLzCfbalOcSUdkDZnpUv3/+BxzFA=
+github.com/hashicorp/go-retryablehttp v0.7.4/go.mod h1:Jy/gPYAdjqffZ/yFGCFV2doI5wjtH1ewM9u8iYVjtX8=
 github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU=
 github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU=
 github.com/hashicorp/go-rootcerts v1.0.2/go.mod h1:pqUvnprVnM5bf7AOirdbb01K4ccR319Vf4pU3K5EGc8=
 github.com/hashicorp/go-rootcerts v1.0.2/go.mod h1:pqUvnprVnM5bf7AOirdbb01K4ccR319Vf4pU3K5EGc8=
 github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerXegt+ozgdvDeDU=
 github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerXegt+ozgdvDeDU=

+ 143 - 0
internal/httpd/httpd_test.go

@@ -3196,6 +3196,62 @@ func TestUpdateUserPassword(t *testing.T) {
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 }
 }
 
 
+func TestLoginRedirectNext(t *testing.T) {
+	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
+	assert.NoError(t, err)
+
+	uri := webClientFilesPath + "?path=%2F"
+	req, err := http.NewRequest(http.MethodGet, uri, nil)
+	assert.NoError(t, err)
+	req.RequestURI = uri
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusFound, rr)
+	redirectURI := rr.Header().Get("Location")
+	assert.Equal(t, webClientLoginPath+"?next="+url.QueryEscape(uri), redirectURI)
+	// render the login page
+	req, err = http.NewRequest(http.MethodGet, redirectURI, nil)
+	assert.NoError(t, err)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr)
+	assert.Contains(t, rr.Body.String(), fmt.Sprintf("action=%q", redirectURI))
+	// now login the user and check the redirect
+	csrfToken, err := getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr)
+	assert.NoError(t, err)
+	form := getLoginForm(defaultUsername, defaultPassword, csrfToken)
+	req, err = http.NewRequest(http.MethodPost, redirectURI, bytes.NewBuffer([]byte(form.Encode())))
+	assert.NoError(t, err)
+	req.RemoteAddr = defaultRemoteAddr
+	req.RequestURI = redirectURI
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusFound, rr)
+	assert.Equal(t, uri, rr.Header().Get("Location"))
+	// unsafe URI
+	unsafeURI := webClientLoginPath + "?next=" + url.QueryEscape("http://example.net")
+	req, err = http.NewRequest(http.MethodPost, unsafeURI, bytes.NewBuffer([]byte(form.Encode())))
+	assert.NoError(t, err)
+	req.RemoteAddr = defaultRemoteAddr
+	req.RequestURI = unsafeURI
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusFound, rr)
+	assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
+	unsupportedURI := webClientLoginPath + "?next=" + url.QueryEscape(webClientProfilePath)
+	req, err = http.NewRequest(http.MethodPost, unsupportedURI, bytes.NewBuffer([]byte(form.Encode())))
+	assert.NoError(t, err)
+	req.RemoteAddr = defaultRemoteAddr
+	req.RequestURI = unsupportedURI
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusFound, rr)
+	assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
+
+	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	assert.NoError(t, err)
+	err = os.RemoveAll(user.GetHomeDir())
+	assert.NoError(t, err)
+}
+
 func TestMustChangePasswordRequirement(t *testing.T) {
 func TestMustChangePasswordRequirement(t *testing.T) {
 	u := getTestUser()
 	u := getTestUser()
 	u.Filters.RequirePasswordChange = true
 	u.Filters.RequirePasswordChange = true
@@ -9929,6 +9985,93 @@ func TestWebUserTwoFactorLogin(t *testing.T) {
 	checkResponseCode(t, http.StatusInternalServerError, rr)
 	checkResponseCode(t, http.StatusInternalServerError, rr)
 }
 }
 
 
+func TestWebUserTwoFactoryLoginRedirect(t *testing.T) {
+	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
+	assert.NoError(t, err)
+	configName, _, secret, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
+	assert.NoError(t, err)
+
+	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
+	assert.NoError(t, err)
+	userTOTPConfig := dataprovider.UserTOTPConfig{
+		Enabled:    true,
+		ConfigName: configName,
+		Secret:     kms.NewPlainSecret(secret),
+		Protocols:  []string{common.ProtocolHTTP},
+	}
+	asJSON, err := json.Marshal(userTOTPConfig)
+	assert.NoError(t, err)
+	req, err := http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
+	assert.NoError(t, err)
+	setBearerForReq(req, token)
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr)
+
+	csrfToken, err := getCSRFToken(httpBaseURL + webClientLoginPath)
+	assert.NoError(t, err)
+	form := getLoginForm(defaultUsername, defaultPassword, csrfToken)
+	uri := webClientFilesPath + "?path=%2F"
+	loginURI := webClientLoginPath + "?next=" + url.QueryEscape(uri)
+	expectedURI := webClientTwoFactorPath + "?next=" + url.QueryEscape(uri)
+	req, err = http.NewRequest(http.MethodPost, loginURI, bytes.NewBuffer([]byte(form.Encode())))
+	assert.NoError(t, err)
+	req.RemoteAddr = defaultRemoteAddr
+	req.RequestURI = loginURI
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	rr = executeRequest(req)
+	assert.Equal(t, http.StatusFound, rr.Code)
+	assert.Equal(t, expectedURI, rr.Header().Get("Location"))
+	cookie, err := getCookieFromResponse(rr)
+	assert.NoError(t, err)
+	// test unsafe redirects
+	externalURI := webClientLoginPath + "?next=" + url.QueryEscape("https://example.com")
+	req, err = http.NewRequest(http.MethodPost, externalURI, bytes.NewBuffer([]byte(form.Encode())))
+	assert.NoError(t, err)
+	req.RemoteAddr = defaultRemoteAddr
+	req.RequestURI = externalURI
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	rr = executeRequest(req)
+	assert.Equal(t, http.StatusFound, rr.Code)
+	assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location"))
+	internalURI := webClientLoginPath + "?next=" + url.QueryEscape(webClientMFAPath)
+	req, err = http.NewRequest(http.MethodPost, internalURI, bytes.NewBuffer([]byte(form.Encode())))
+	assert.NoError(t, err)
+	req.RemoteAddr = defaultRemoteAddr
+	req.RequestURI = internalURI
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	rr = executeRequest(req)
+	assert.Equal(t, http.StatusFound, rr.Code)
+	assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location"))
+	// render two factor page
+	req, err = http.NewRequest(http.MethodGet, expectedURI, nil)
+	assert.NoError(t, err)
+	req.RequestURI = expectedURI
+	setJWTCookieForReq(req, cookie)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr)
+	assert.Contains(t, rr.Body.String(), fmt.Sprintf("action=%q", expectedURI))
+	// login with the passcode
+	passcode, err := generateTOTPPasscode(secret)
+	assert.NoError(t, err)
+	form = make(url.Values)
+	form.Set("passcode", passcode)
+	form.Set(csrfFormToken, csrfToken)
+	req, err = http.NewRequest(http.MethodPost, expectedURI, bytes.NewBuffer([]byte(form.Encode())))
+	assert.NoError(t, err)
+	setJWTCookieForReq(req, cookie)
+	req.RemoteAddr = defaultRemoteAddr
+	req.RequestURI = expectedURI
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	rr = executeRequest(req)
+	assert.Equal(t, http.StatusFound, rr.Code)
+	assert.Equal(t, uri, rr.Header().Get("Location"))
+
+	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	assert.NoError(t, err)
+	err = os.RemoveAll(user.GetHomeDir())
+	assert.NoError(t, err)
+}
+
 func TestSearchEvents(t *testing.T) {
 func TestSearchEvents(t *testing.T) {
 	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
 	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
 	assert.NoError(t, err)
 	assert.NoError(t, err)

+ 4 - 0
internal/httpd/middleware.go

@@ -18,6 +18,7 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"net/http"
 	"net/http"
+	"net/url"
 	"strings"
 	"strings"
 
 
 	"github.com/go-chi/jwtauth/v5"
 	"github.com/go-chi/jwtauth/v5"
@@ -52,6 +53,9 @@ func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudi
 		redirectPath = webAdminLoginPath
 		redirectPath = webAdminLoginPath
 	} else {
 	} else {
 		redirectPath = webClientLoginPath
 		redirectPath = webClientLoginPath
+		if uri := r.RequestURI; strings.HasPrefix(uri, webClientFilesPath) {
+			redirectPath += "?next=" + url.QueryEscape(uri)
+		}
 	}
 	}
 
 
 	isAPIToken := (audience == tokenAudienceAPI || audience == tokenAudienceAPIUser)
 	isAPIToken := (audience == tokenAudienceAPI || audience == tokenAudienceAPIUser)

+ 41 - 29
internal/httpd/server.go

@@ -23,6 +23,7 @@ import (
 	"log"
 	"log"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
+	"net/url"
 	"path/filepath"
 	"path/filepath"
 	"strings"
 	"strings"
 	"time"
 	"time"
@@ -160,7 +161,7 @@ func (s *httpdServer) refreshCookie(next http.Handler) http.Handler {
 	})
 	})
 }
 }
 
 
-func (s *httpdServer) renderClientLoginPage(w http.ResponseWriter, error, ip string) {
+func (s *httpdServer) renderClientLoginPage(w http.ResponseWriter, r *http.Request, error, ip string) {
 	data := loginPage{
 	data := loginPage{
 		CurrentURL:   webClientLoginPath,
 		CurrentURL:   webClientLoginPath,
 		Version:      version.Get().Version,
 		Version:      version.Get().Version,
@@ -170,6 +171,9 @@ func (s *httpdServer) renderClientLoginPage(w http.ResponseWriter, error, ip str
 		Branding:     s.binding.Branding.WebClient,
 		Branding:     s.binding.Branding.WebClient,
 		FormDisabled: s.binding.isWebClientLoginFormDisabled(),
 		FormDisabled: s.binding.isWebClientLoginFormDisabled(),
 	}
 	}
+	if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) {
+		data.CurrentURL += "?next=" + url.QueryEscape(next)
+	}
 	if s.binding.showAdminLoginURL() {
 	if s.binding.showAdminLoginURL() {
 		data.AltLoginURL = webAdminLoginPath
 		data.AltLoginURL = webAdminLoginPath
 		data.AltLoginName = s.binding.Branding.WebAdmin.ShortName
 		data.AltLoginName = s.binding.Branding.WebAdmin.ShortName
@@ -217,7 +221,7 @@ func (s *httpdServer) handleClientWebLogin(w http.ResponseWriter, r *http.Reques
 		http.Redirect(w, r, webAdminSetupPath, http.StatusFound)
 		http.Redirect(w, r, webAdminSetupPath, http.StatusFound)
 		return
 		return
 	}
 	}
-	s.renderClientLoginPage(w, getFlashMessage(w, r), util.GetIPFromRemoteAddress(r.RemoteAddr))
+	s.renderClientLoginPage(w, r, getFlashMessage(w, r), util.GetIPFromRemoteAddress(r.RemoteAddr))
 }
 }
 
 
 func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Request) {
 func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Request) {
@@ -225,7 +229,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
 
 
 	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
 	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
 	if err := r.ParseForm(); err != nil {
 	if err := r.ParseForm(); err != nil {
-		s.renderClientLoginPage(w, err.Error(), ipAddr)
+		s.renderClientLoginPage(w, r, err.Error(), ipAddr)
 		return
 		return
 	}
 	}
 	protocol := common.ProtocolHTTP
 	protocol := common.ProtocolHTTP
@@ -234,33 +238,33 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
 	if username == "" || password == "" {
 	if username == "" || password == "" {
 		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
 		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
 			dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials)
 			dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials)
-		s.renderClientLoginPage(w, "Invalid credentials", ipAddr)
+		s.renderClientLoginPage(w, r, "Invalid credentials", ipAddr)
 		return
 		return
 	}
 	}
 	if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
 	if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
 		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
 		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
 			dataprovider.LoginMethodPassword, ipAddr, err)
 			dataprovider.LoginMethodPassword, ipAddr, err)
-		s.renderClientLoginPage(w, err.Error(), ipAddr)
+		s.renderClientLoginPage(w, r, err.Error(), ipAddr)
 		return
 		return
 	}
 	}
 
 
 	if err := common.Config.ExecutePostConnectHook(ipAddr, protocol); err != nil {
 	if err := common.Config.ExecutePostConnectHook(ipAddr, protocol); err != nil {
 		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
 		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
 			dataprovider.LoginMethodPassword, ipAddr, err)
 			dataprovider.LoginMethodPassword, ipAddr, err)
-		s.renderClientLoginPage(w, fmt.Sprintf("access denied: %v", err), ipAddr)
+		s.renderClientLoginPage(w, r, fmt.Sprintf("access denied: %v", err), ipAddr)
 		return
 		return
 	}
 	}
 
 
 	user, err := dataprovider.CheckUserAndPass(username, password, ipAddr, protocol)
 	user, err := dataprovider.CheckUserAndPass(username, password, ipAddr, protocol)
 	if err != nil {
 	if err != nil {
 		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
 		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
-		s.renderClientLoginPage(w, dataprovider.ErrInvalidCredentials.Error(), ipAddr)
+		s.renderClientLoginPage(w, r, dataprovider.ErrInvalidCredentials.Error(), ipAddr)
 		return
 		return
 	}
 	}
 	connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
 	connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
 	if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil {
 	if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil {
 		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
 		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
-		s.renderClientLoginPage(w, err.Error(), ipAddr)
+		s.renderClientLoginPage(w, r, err.Error(), ipAddr)
 		return
 		return
 	}
 	}
 
 
@@ -269,7 +273,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
 	if err != nil {
 	if err != nil {
 		logger.Warn(logSender, connectionID, "unable to check fs root: %v", err)
 		logger.Warn(logSender, connectionID, "unable to check fs root: %v", err)
 		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
 		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
-		s.renderClientLoginPage(w, err.Error(), ipAddr)
+		s.renderClientLoginPage(w, r, err.Error(), ipAddr)
 		return
 		return
 	}
 	}
 	s.loginUser(w, r, &user, connectionID, ipAddr, false, s.renderClientLoginPage)
 	s.loginUser(w, r, &user, connectionID, ipAddr, false, s.renderClientLoginPage)
@@ -281,7 +285,7 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r
 	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
 	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
 	err := r.ParseForm()
 	err := r.ParseForm()
 	if err != nil {
 	if err != nil {
-		s.renderClientResetPwdPage(w, err.Error(), ipAddr)
+		s.renderClientResetPwdPage(w, r, err.Error(), ipAddr)
 		return
 		return
 	}
 	}
 	if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
 	if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
@@ -291,12 +295,12 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r
 	_, user, err := handleResetPassword(r, strings.TrimSpace(r.Form.Get("code")),
 	_, user, err := handleResetPassword(r, strings.TrimSpace(r.Form.Get("code")),
 		strings.TrimSpace(r.Form.Get("password")), false)
 		strings.TrimSpace(r.Form.Get("password")), false)
 	if err != nil {
 	if err != nil {
-		s.renderClientResetPwdPage(w, err.Error(), ipAddr)
+		s.renderClientResetPwdPage(w, r, err.Error(), ipAddr)
 		return
 		return
 	}
 	}
 	connectionID := fmt.Sprintf("%v_%v", getProtocolFromRequest(r), xid.New().String())
 	connectionID := fmt.Sprintf("%v_%v", getProtocolFromRequest(r), xid.New().String())
 	if err := checkHTTPClientUser(user, r, connectionID, true); err != nil {
 	if err := checkHTTPClientUser(user, r, connectionID, true); err != nil {
-		s.renderClientResetPwdPage(w, fmt.Sprintf("Password reset successfully but unable to login: %v", err.Error()), ipAddr)
+		s.renderClientResetPwdPage(w, r, fmt.Sprintf("Password reset successfully but unable to login: %v", err.Error()), ipAddr)
 		return
 		return
 	}
 	}
 
 
@@ -304,7 +308,7 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r
 	err = user.CheckFsRoot(connectionID)
 	err = user.CheckFsRoot(connectionID)
 	if err != nil {
 	if err != nil {
 		logger.Warn(logSender, connectionID, "unable to check fs root: %v", err)
 		logger.Warn(logSender, connectionID, "unable to check fs root: %v", err)
-		s.renderClientResetPwdPage(w, fmt.Sprintf("Password reset successfully but unable to login: %s", err.Error()), ipAddr)
+		s.renderClientResetPwdPage(w, r, fmt.Sprintf("Password reset successfully but unable to login: %s", err.Error()), ipAddr)
 		return
 		return
 	}
 	}
 	s.loginUser(w, r, user, connectionID, ipAddr, false, s.renderClientResetPwdPage)
 	s.loginUser(w, r, user, connectionID, ipAddr, false, s.renderClientResetPwdPage)
@@ -319,17 +323,17 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter
 	}
 	}
 	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
 	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
 	if err := r.ParseForm(); err != nil {
 	if err := r.ParseForm(); err != nil {
-		s.renderClientTwoFactorRecoveryPage(w, err.Error(), ipAddr)
+		s.renderClientTwoFactorRecoveryPage(w, r, err.Error(), ipAddr)
 		return
 		return
 	}
 	}
 	username := claims.Username
 	username := claims.Username
 	recoveryCode := strings.TrimSpace(r.Form.Get("recovery_code"))
 	recoveryCode := strings.TrimSpace(r.Form.Get("recovery_code"))
 	if username == "" || recoveryCode == "" {
 	if username == "" || recoveryCode == "" {
-		s.renderClientTwoFactorRecoveryPage(w, "Invalid credentials", ipAddr)
+		s.renderClientTwoFactorRecoveryPage(w, r, "Invalid credentials", ipAddr)
 		return
 		return
 	}
 	}
 	if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
 	if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
-		s.renderClientTwoFactorRecoveryPage(w, err.Error(), ipAddr)
+		s.renderClientTwoFactorRecoveryPage(w, r, err.Error(), ipAddr)
 		return
 		return
 	}
 	}
 	user, userMerged, err := dataprovider.GetUserVariants(username, "")
 	user, userMerged, err := dataprovider.GetUserVariants(username, "")
@@ -337,11 +341,11 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter
 		if errors.Is(err, util.ErrNotFound) {
 		if errors.Is(err, util.ErrNotFound) {
 			handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck
 			handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck
 		}
 		}
-		s.renderClientTwoFactorRecoveryPage(w, "Invalid credentials", ipAddr)
+		s.renderClientTwoFactorRecoveryPage(w, r, "Invalid credentials", ipAddr)
 		return
 		return
 	}
 	}
 	if !userMerged.Filters.TOTPConfig.Enabled || !util.Contains(userMerged.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) {
 	if !userMerged.Filters.TOTPConfig.Enabled || !util.Contains(userMerged.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) {
-		s.renderClientTwoFactorPage(w, "Two factory authentication is not enabled", ipAddr)
+		s.renderClientTwoFactorPage(w, r, "Two factory authentication is not enabled", ipAddr)
 		return
 		return
 	}
 	}
 	for idx, code := range user.Filters.RecoveryCodes {
 	for idx, code := range user.Filters.RecoveryCodes {
@@ -351,7 +355,7 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter
 		}
 		}
 		if code.Secret.GetPayload() == recoveryCode {
 		if code.Secret.GetPayload() == recoveryCode {
 			if code.Used {
 			if code.Used {
-				s.renderClientTwoFactorRecoveryPage(w, "This recovery code was already used", ipAddr)
+				s.renderClientTwoFactorRecoveryPage(w, r, "This recovery code was already used", ipAddr)
 				return
 				return
 			}
 			}
 			user.Filters.RecoveryCodes[idx].Used = true
 			user.Filters.RecoveryCodes[idx].Used = true
@@ -368,7 +372,7 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter
 		}
 		}
 	}
 	}
 	handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck
 	handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck
-	s.renderClientTwoFactorRecoveryPage(w, "Invalid recovery code", ipAddr)
+	s.renderClientTwoFactorRecoveryPage(w, r, "Invalid recovery code", ipAddr)
 }
 }
 
 
 func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *http.Request) {
 func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *http.Request) {
@@ -380,7 +384,7 @@ func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *htt
 	}
 	}
 	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
 	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
 	if err := r.ParseForm(); err != nil {
 	if err := r.ParseForm(); err != nil {
-		s.renderClientTwoFactorPage(w, err.Error(), ipAddr)
+		s.renderClientTwoFactorPage(w, r, err.Error(), ipAddr)
 		return
 		return
 	}
 	}
 	username := claims.Username
 	username := claims.Username
@@ -388,25 +392,25 @@ func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *htt
 	if username == "" || passcode == "" {
 	if username == "" || passcode == "" {
 		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
 		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
 			dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials)
 			dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials)
-		s.renderClientTwoFactorPage(w, "Invalid credentials", ipAddr)
+		s.renderClientTwoFactorPage(w, r, "Invalid credentials", ipAddr)
 		return
 		return
 	}
 	}
 	if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
 	if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
 		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
 		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
 			dataprovider.LoginMethodPassword, ipAddr, err)
 			dataprovider.LoginMethodPassword, ipAddr, err)
-		s.renderClientTwoFactorPage(w, err.Error(), ipAddr)
+		s.renderClientTwoFactorPage(w, r, err.Error(), ipAddr)
 		return
 		return
 	}
 	}
 	user, err := dataprovider.GetUserWithGroupSettings(username, "")
 	user, err := dataprovider.GetUserWithGroupSettings(username, "")
 	if err != nil {
 	if err != nil {
 		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
 		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
 			dataprovider.LoginMethodPassword, ipAddr, err)
 			dataprovider.LoginMethodPassword, ipAddr, err)
-		s.renderClientTwoFactorPage(w, "Invalid credentials", ipAddr)
+		s.renderClientTwoFactorPage(w, r, "Invalid credentials", ipAddr)
 		return
 		return
 	}
 	}
 	if !user.Filters.TOTPConfig.Enabled || !util.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) {
 	if !user.Filters.TOTPConfig.Enabled || !util.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) {
 		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
 		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
-		s.renderClientTwoFactorPage(w, "Two factory authentication is not enabled", ipAddr)
+		s.renderClientTwoFactorPage(w, r, "Two factory authentication is not enabled", ipAddr)
 		return
 		return
 	}
 	}
 	err = user.Filters.TOTPConfig.Secret.Decrypt()
 	err = user.Filters.TOTPConfig.Secret.Decrypt()
@@ -419,7 +423,7 @@ func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *htt
 		user.Filters.TOTPConfig.Secret.GetPayload())
 		user.Filters.TOTPConfig.Secret.GetPayload())
 	if !match || err != nil {
 	if !match || err != nil {
 		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, dataprovider.ErrInvalidCredentials)
 		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, dataprovider.ErrInvalidCredentials)
-		s.renderClientTwoFactorPage(w, "Invalid authentication code", ipAddr)
+		s.renderClientTwoFactorPage(w, r, "Invalid authentication code", ipAddr)
 		return
 		return
 	}
 	}
 	connectionID := fmt.Sprintf("%s_%s", getProtocolFromRequest(r), xid.New().String())
 	connectionID := fmt.Sprintf("%s_%s", getProtocolFromRequest(r), xid.New().String())
@@ -703,7 +707,7 @@ func (s *httpdServer) handleWebAdminSetupPost(w http.ResponseWriter, r *http.Req
 
 
 func (s *httpdServer) loginUser(
 func (s *httpdServer) loginUser(
 	w http.ResponseWriter, r *http.Request, user *dataprovider.User, connectionID, ipAddr string,
 	w http.ResponseWriter, r *http.Request, user *dataprovider.User, connectionID, ipAddr string,
-	isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, error, ip string),
+	isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, error, ip string),
 ) {
 ) {
 	c := jwtTokenClaims{
 	c := jwtTokenClaims{
 		Username:                   user.Username,
 		Username:                   user.Username,
@@ -725,18 +729,26 @@ func (s *httpdServer) loginUser(
 	if err != nil {
 	if err != nil {
 		logger.Warn(logSender, connectionID, "unable to set user login cookie %v", err)
 		logger.Warn(logSender, connectionID, "unable to set user login cookie %v", err)
 		updateLoginMetrics(user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
 		updateLoginMetrics(user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
-		errorFunc(w, err.Error(), ipAddr)
+		errorFunc(w, r, err.Error(), ipAddr)
 		return
 		return
 	}
 	}
 	if isSecondFactorAuth {
 	if isSecondFactorAuth {
 		invalidateToken(r)
 		invalidateToken(r)
 	}
 	}
 	if audience == tokenAudienceWebClientPartial {
 	if audience == tokenAudienceWebClientPartial {
-		http.Redirect(w, r, webClientTwoFactorPath, http.StatusFound)
+		redirectPath := webClientTwoFactorPath
+		if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) {
+			redirectPath += "?next=" + url.QueryEscape(next)
+		}
+		http.Redirect(w, r, redirectPath, http.StatusFound)
 		return
 		return
 	}
 	}
 	updateLoginMetrics(user, dataprovider.LoginMethodPassword, ipAddr, err)
 	updateLoginMetrics(user, dataprovider.LoginMethodPassword, ipAddr, err)
 	dataprovider.UpdateLastLogin(user)
 	dataprovider.UpdateLastLogin(user)
+	if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) {
+		http.Redirect(w, r, next, http.StatusFound)
+		return
+	}
 	http.Redirect(w, r, webClientFilesPath, http.StatusFound)
 	http.Redirect(w, r, webClientFilesPath, http.StatusFound)
 }
 }
 
 

+ 10 - 6
internal/httpd/webclient.go

@@ -402,7 +402,7 @@ func (s *httpdServer) renderClientForgotPwdPage(w http.ResponseWriter, error, ip
 	renderClientTemplate(w, templateForgotPassword, data)
 	renderClientTemplate(w, templateForgotPassword, data)
 }
 }
 
 
-func (s *httpdServer) renderClientResetPwdPage(w http.ResponseWriter, error, ip string) {
+func (s *httpdServer) renderClientResetPwdPage(w http.ResponseWriter, _ *http.Request, error, ip string) {
 	data := resetPwdPage{
 	data := resetPwdPage{
 		CurrentURL: webClientResetPwdPath,
 		CurrentURL: webClientResetPwdPath,
 		Error:      error,
 		Error:      error,
@@ -467,7 +467,7 @@ func (s *httpdServer) renderClientNotFoundPage(w http.ResponseWriter, r *http.Re
 	s.renderClientMessagePage(w, r, page404Title, page404Body, http.StatusNotFound, err, "")
 	s.renderClientMessagePage(w, r, page404Title, page404Body, http.StatusNotFound, err, "")
 }
 }
 
 
-func (s *httpdServer) renderClientTwoFactorPage(w http.ResponseWriter, error, ip string) {
+func (s *httpdServer) renderClientTwoFactorPage(w http.ResponseWriter, r *http.Request, error, ip string) {
 	data := twoFactorPage{
 	data := twoFactorPage{
 		CurrentURL:  webClientTwoFactorPath,
 		CurrentURL:  webClientTwoFactorPath,
 		Version:     version.Get().Version,
 		Version:     version.Get().Version,
@@ -477,10 +477,13 @@ func (s *httpdServer) renderClientTwoFactorPage(w http.ResponseWriter, error, ip
 		RecoveryURL: webClientTwoFactorRecoveryPath,
 		RecoveryURL: webClientTwoFactorRecoveryPath,
 		Branding:    s.binding.Branding.WebClient,
 		Branding:    s.binding.Branding.WebClient,
 	}
 	}
+	if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) {
+		data.CurrentURL += "?next=" + url.QueryEscape(next)
+	}
 	renderClientTemplate(w, templateTwoFactor, data)
 	renderClientTemplate(w, templateTwoFactor, data)
 }
 }
 
 
-func (s *httpdServer) renderClientTwoFactorRecoveryPage(w http.ResponseWriter, error, ip string) {
+func (s *httpdServer) renderClientTwoFactorRecoveryPage(w http.ResponseWriter, _ *http.Request, error, ip string) {
 	data := twoFactorPage{
 	data := twoFactorPage{
 		CurrentURL: webClientTwoFactorRecoveryPath,
 		CurrentURL: webClientTwoFactorRecoveryPath,
 		Version:    version.Get().Version,
 		Version:    version.Get().Version,
@@ -1284,12 +1287,12 @@ func (s *httpdServer) handleWebClientMFA(w http.ResponseWriter, r *http.Request)
 
 
 func (s *httpdServer) handleWebClientTwoFactor(w http.ResponseWriter, r *http.Request) {
 func (s *httpdServer) handleWebClientTwoFactor(w http.ResponseWriter, r *http.Request) {
 	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
 	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
-	s.renderClientTwoFactorPage(w, "", util.GetIPFromRemoteAddress(r.RemoteAddr))
+	s.renderClientTwoFactorPage(w, r, "", util.GetIPFromRemoteAddress(r.RemoteAddr))
 }
 }
 
 
 func (s *httpdServer) handleWebClientTwoFactorRecovery(w http.ResponseWriter, r *http.Request) {
 func (s *httpdServer) handleWebClientTwoFactorRecovery(w http.ResponseWriter, r *http.Request) {
 	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
 	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
-	s.renderClientTwoFactorRecoveryPage(w, "", util.GetIPFromRemoteAddress(r.RemoteAddr))
+	s.renderClientTwoFactorRecoveryPage(w, r, "", util.GetIPFromRemoteAddress(r.RemoteAddr))
 }
 }
 
 
 func getShareFromPostFields(r *http.Request) (*dataprovider.Share, error) {
 func getShareFromPostFields(r *http.Request) (*dataprovider.Share, error) {
@@ -1371,7 +1374,7 @@ func (s *httpdServer) handleWebClientPasswordReset(w http.ResponseWriter, r *htt
 		s.renderClientNotFoundPage(w, r, errors.New("this page does not exist"))
 		s.renderClientNotFoundPage(w, r, errors.New("this page does not exist"))
 		return
 		return
 	}
 	}
-	s.renderClientResetPwdPage(w, "", util.GetIPFromRemoteAddress(r.RemoteAddr))
+	s.renderClientResetPwdPage(w, r, "", util.GetIPFromRemoteAddress(r.RemoteAddr))
 }
 }
 
 
 func (s *httpdServer) handleClientViewPDF(w http.ResponseWriter, r *http.Request) {
 func (s *httpdServer) handleClientViewPDF(w http.ResponseWriter, r *http.Request) {
@@ -1500,6 +1503,7 @@ func (s *httpdServer) handleClientShareLoginPost(w http.ResponseWriter, r *http.
 	next := path.Clean(r.URL.Query().Get("next"))
 	next := path.Clean(r.URL.Query().Get("next"))
 	if strings.HasPrefix(next, path.Join(webClientPubSharesPath, share.ShareID)) {
 	if strings.HasPrefix(next, path.Join(webClientPubSharesPath, share.ShareID)) {
 		http.Redirect(w, r, next, http.StatusFound)
 		http.Redirect(w, r, next, http.StatusFound)
+		return
 	}
 	}
 	s.renderClientMessagePage(w, r, "Share Login OK", "Share login successful, you can now use your link",
 	s.renderClientMessagePage(w, r, "Share Login OK", "Share login successful, you can now use your link",
 		http.StatusOK, nil, "")
 		http.StatusOK, nil, "")

+ 1 - 1
pkgs/build.sh

@@ -1,6 +1,6 @@
 #!/bin/bash
 #!/bin/bash
 
 
-NFPM_VERSION=2.29.0
+NFPM_VERSION=2.30.1
 NFPM_ARCH=${NFPM_ARCH:-amd64}
 NFPM_ARCH=${NFPM_ARCH:-amd64}
 if [ -z ${SFTPGO_VERSION} ]
 if [ -z ${SFTPGO_VERSION} ]
 then
 then