Browse Source

WebClient: redirect to the requested URL after login

This feature is only useful and enabled for file manager urls

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
Nicola Murino 2 years ago
parent
commit
9d60972743
7 changed files with 217 additions and 53 deletions
  1. 6 6
      go.mod
  2. 12 11
      go.sum
  3. 143 0
      internal/httpd/httpd_test.go
  4. 4 0
      internal/httpd/middleware.go
  5. 41 29
      internal/httpd/server.go
  6. 10 6
      internal/httpd/webclient.go
  7. 1 1
      pkgs/build.sh

+ 6 - 6
go.mod

@@ -4,7 +4,7 @@ go 1.20
 
 require (
 	cloud.google.com/go/storage v1.30.1
-	github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0
+	github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.1
 	github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0
 	github.com/GehirnInc/crypt v0.0.0-20230320061759-8cc1b52080c5
 	github.com/alexedwards/argon2id v0.0.0-20230305115115-4b3c3280a736
@@ -25,8 +25,8 @@ require (
 	github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001
 	github.com/fclairamb/ftpserverlib v0.21.0
 	github.com/fclairamb/go-log v0.4.1
-	github.com/go-acme/lego/v4 v4.12.0
-	github.com/go-chi/chi/v5 v5.0.8
+	github.com/go-acme/lego/v4 v4.12.1
+	github.com/go-chi/chi/v5 v5.0.9-0.20230502103705-7f280968675b
 	github.com/go-chi/jwtauth/v5 v5.1.0
 	github.com/go-chi/render v1.0.2
 	github.com/go-sql-driver/mysql v1.7.1
@@ -35,7 +35,7 @@ require (
 	github.com/google/uuid v1.3.0
 	github.com/hashicorp/go-hclog v1.5.0
 	github.com/hashicorp/go-plugin v1.4.10
-	github.com/hashicorp/go-retryablehttp v0.7.2
+	github.com/hashicorp/go-retryablehttp v0.7.4
 	github.com/jackc/pgx/v5 v5.3.2-0.20230603125928-d9560c78b8e6
 	github.com/jlaffaye/ftp v0.0.0-20201112195030-9aae4d151126
 	github.com/klauspost/compress v1.16.5
@@ -114,7 +114,7 @@ require (
 	github.com/golang/protobuf v1.5.3 // indirect
 	github.com/google/go-cmp v0.5.9 // indirect
 	github.com/google/s2a-go v0.1.4 // indirect
-	github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect
+	github.com/googleapis/enterprise-certificate-proxy v0.2.4 // indirect
 	github.com/googleapis/gax-go/v2 v2.10.0 // indirect
 	github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
 	github.com/hashicorp/hcl v1.0.0 // indirect
@@ -172,5 +172,5 @@ require (
 replace (
 	github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9
 	github.com/robfig/cron/v3 => github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0
-	golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20230512104844-219592fc3028
+	golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20230608154636-e9d673c2a1a8
 )

+ 12 - 11
go.sum

@@ -430,8 +430,8 @@ github.com/Azure/azure-sdk-for-go/sdk/azcore v1.1.1/go.mod h1:uGG2W01BaETf0Ozp+Q
 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.1.2/go.mod h1:uGG2W01BaETf0Ozp+QxxKJdMBNRWPdstHG0Fmdwn1/U=
 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.0/go.mod h1:tZoQYdDZNOiIjdSn0dVWVfl0NEPGOJqVLzSrcFk4Is0=
 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.1/go.mod h1:DffdKW9RFqa5VgmsjUOsS7UE7eiA5iAvYUs63bhKQ0M=
-github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 h1:8kDqDngH+DmVBiCtIjCFTGa7MBnsIOkF9IccInFEbjk=
-github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q=
+github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.1 h1:SEy2xmstIphdPwNBUi7uhvjyjhVKISfwjfOJmuy7kg4=
+github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.1/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q=
 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.0.0/go.mod h1:+6sju8gk8FRmSajX3Oz4G5Gm7P+mbqE9FVaXXFYTkCM=
 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.1.0/go.mod h1:bhXu1AjYL+wutSL/kpSq6s7733q2Rb0yuot9Zgfqa/0=
 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1 h1:T8quHYlUGyb/oqtSTwqlCr1ilJHrDv+ZtpSfo+hm1BU=
@@ -876,8 +876,8 @@ github.com/docker/spdystream v0.0.0-20160310174837-449fdfce4d96/go.mod h1:Qh8CwZ
 github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE=
 github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0 h1:EW9gIJRmt9lzk66Fhh4S8VEtURA6QHZqGeSRE9Nb2/U=
 github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
-github.com/drakkan/crypto v0.0.0-20230512104844-219592fc3028 h1:qUrs/afB0gubJUY5kOmxLx1euFlXn9yUMUhli7Njob8=
-github.com/drakkan/crypto v0.0.0-20230512104844-219592fc3028/go.mod h1:FPowDKc1rEQhN3Xf48AhpBr8eSNzpEYaAQczEYcuAVU=
+github.com/drakkan/crypto v0.0.0-20230608154636-e9d673c2a1a8 h1:0BDuAXFFCOqYrcOkArbc4MRE6jRvN0oiPgYAewp9FyI=
+github.com/drakkan/crypto v0.0.0-20230608154636-e9d673c2a1a8/go.mod h1:FPowDKc1rEQhN3Xf48AhpBr8eSNzpEYaAQczEYcuAVU=
 github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9 h1:LPH1dEblAOO/LoG7yHPMtBLXhQmjaga91/DDjWk9jWA=
 github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9/go.mod h1:2lmrmq866uF2tnje75wQHzmPXhmSWUt7Gyx2vgK1RCU=
 github.com/drakkan/webdav v0.0.0-20230227175313-32996838bcd8 h1:tdkLkSKtYd3WSDsZXGJDKsakiNstLQJPN5HjnqCkf2c=
@@ -951,10 +951,10 @@ github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeME
 github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
 github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M=
 github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U=
-github.com/go-acme/lego/v4 v4.12.0 h1:jox3II6YRjt1EXvrymSQuSNgEUOcbUkF2je0kyuv6YM=
-github.com/go-acme/lego/v4 v4.12.0/go.mod h1:UZoOlhVmUYP/N0z4tEbfUjoCNHRZNObzqWZtT76DIsc=
-github.com/go-chi/chi/v5 v5.0.8 h1:lD+NLqFcAi1ovnVZpsnObHGW4xb4J8lNmoYVfECH1Y0=
-github.com/go-chi/chi/v5 v5.0.8/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
+github.com/go-acme/lego/v4 v4.12.1 h1:Cy3FS7wADLNBqCLpz2wdfdNrThW9rZy8RCAfnUrL2uE=
+github.com/go-acme/lego/v4 v4.12.1/go.mod h1:UZoOlhVmUYP/N0z4tEbfUjoCNHRZNObzqWZtT76DIsc=
+github.com/go-chi/chi/v5 v5.0.9-0.20230502103705-7f280968675b h1:fOhf/SzZ2dPT7wFY5MPJAR/4HUusHgf8xT8XWqVeDtY=
+github.com/go-chi/chi/v5 v5.0.9-0.20230502103705-7f280968675b/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
 github.com/go-chi/jwtauth/v5 v5.1.0 h1:wJyf2YZ/ohPvNJBwPOzZaQbyzwgMZZceE1m8FOzXLeA=
 github.com/go-chi/jwtauth/v5 v5.1.0/go.mod h1:MA93hc1au3tAQwCKry+fI4LqJ5MIVN4XSsglOo+lSc8=
 github.com/go-chi/render v1.0.2 h1:4ER/udB0+fMWB2Jlf15RV3F4A2FDuYi/9f+lFttR/Lg=
@@ -1215,8 +1215,9 @@ github.com/googleapis/enterprise-certificate-proxy v0.0.0-20220520183353-fd19c99
 github.com/googleapis/enterprise-certificate-proxy v0.1.0/go.mod h1:17drOmN3MwGY7t0e+Ei9b45FFGA3fBs3x36SsCg1hq8=
 github.com/googleapis/enterprise-certificate-proxy v0.2.0/go.mod h1:8C0jb7/mgJe/9KK8Lm7X9ctZC2t60YyIpYEI16jx0Qg=
 github.com/googleapis/enterprise-certificate-proxy v0.2.1/go.mod h1:AwSRAtLfXpU5Nm3pW+v7rGDHp09LsPtGY9MduiEsR9k=
-github.com/googleapis/enterprise-certificate-proxy v0.2.3 h1:yk9/cqRKtT9wXZSsRH9aurXEpJX+U6FLtpYTdC3R06k=
 github.com/googleapis/enterprise-certificate-proxy v0.2.3/go.mod h1:AwSRAtLfXpU5Nm3pW+v7rGDHp09LsPtGY9MduiEsR9k=
+github.com/googleapis/enterprise-certificate-proxy v0.2.4 h1:uGy6JWR/uMIILU8wbf+OkstIrNiMjGpEIyhx8f6W7s4=
+github.com/googleapis/enterprise-certificate-proxy v0.2.4/go.mod h1:AwSRAtLfXpU5Nm3pW+v7rGDHp09LsPtGY9MduiEsR9k=
 github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
 github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
 github.com/googleapis/gax-go/v2 v2.1.0/go.mod h1:Q3nei7sK6ybPYH7twZdmQpAd1MKb7pfu6SK+H1/DsU0=
@@ -1298,8 +1299,8 @@ github.com/hashicorp/go-plugin v1.4.10 h1:xUbmA4jC6Dq163/fWcp8P3JuHilrHHMLNRxzGQ
 github.com/hashicorp/go-plugin v1.4.10/go.mod h1:6/1TEzT0eQznvI/gV2CM29DLSkAK/e58mUWKVsPaph0=
 github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs=
 github.com/hashicorp/go-retryablehttp v0.7.1/go.mod h1:vAew36LZh98gCBJNLH42IQ1ER/9wtLZZ8meHqQvEYWY=
-github.com/hashicorp/go-retryablehttp v0.7.2 h1:AcYqCvkpalPnPF2pn0KamgwamS42TqUDDYFRKq/RAd0=
-github.com/hashicorp/go-retryablehttp v0.7.2/go.mod h1:Jy/gPYAdjqffZ/yFGCFV2doI5wjtH1ewM9u8iYVjtX8=
+github.com/hashicorp/go-retryablehttp v0.7.4 h1:ZQgVdpTdAL7WpMIwLzCfbalOcSUdkDZnpUv3/+BxzFA=
+github.com/hashicorp/go-retryablehttp v0.7.4/go.mod h1:Jy/gPYAdjqffZ/yFGCFV2doI5wjtH1ewM9u8iYVjtX8=
 github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU=
 github.com/hashicorp/go-rootcerts v1.0.2/go.mod h1:pqUvnprVnM5bf7AOirdbb01K4ccR319Vf4pU3K5EGc8=
 github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerXegt+ozgdvDeDU=

+ 143 - 0
internal/httpd/httpd_test.go

@@ -3196,6 +3196,62 @@ func TestUpdateUserPassword(t *testing.T) {
 	assert.NoError(t, err)
 }
 
+func TestLoginRedirectNext(t *testing.T) {
+	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
+	assert.NoError(t, err)
+
+	uri := webClientFilesPath + "?path=%2F"
+	req, err := http.NewRequest(http.MethodGet, uri, nil)
+	assert.NoError(t, err)
+	req.RequestURI = uri
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusFound, rr)
+	redirectURI := rr.Header().Get("Location")
+	assert.Equal(t, webClientLoginPath+"?next="+url.QueryEscape(uri), redirectURI)
+	// render the login page
+	req, err = http.NewRequest(http.MethodGet, redirectURI, nil)
+	assert.NoError(t, err)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr)
+	assert.Contains(t, rr.Body.String(), fmt.Sprintf("action=%q", redirectURI))
+	// now login the user and check the redirect
+	csrfToken, err := getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr)
+	assert.NoError(t, err)
+	form := getLoginForm(defaultUsername, defaultPassword, csrfToken)
+	req, err = http.NewRequest(http.MethodPost, redirectURI, bytes.NewBuffer([]byte(form.Encode())))
+	assert.NoError(t, err)
+	req.RemoteAddr = defaultRemoteAddr
+	req.RequestURI = redirectURI
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusFound, rr)
+	assert.Equal(t, uri, rr.Header().Get("Location"))
+	// unsafe URI
+	unsafeURI := webClientLoginPath + "?next=" + url.QueryEscape("http://example.net")
+	req, err = http.NewRequest(http.MethodPost, unsafeURI, bytes.NewBuffer([]byte(form.Encode())))
+	assert.NoError(t, err)
+	req.RemoteAddr = defaultRemoteAddr
+	req.RequestURI = unsafeURI
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusFound, rr)
+	assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
+	unsupportedURI := webClientLoginPath + "?next=" + url.QueryEscape(webClientProfilePath)
+	req, err = http.NewRequest(http.MethodPost, unsupportedURI, bytes.NewBuffer([]byte(form.Encode())))
+	assert.NoError(t, err)
+	req.RemoteAddr = defaultRemoteAddr
+	req.RequestURI = unsupportedURI
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusFound, rr)
+	assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
+
+	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	assert.NoError(t, err)
+	err = os.RemoveAll(user.GetHomeDir())
+	assert.NoError(t, err)
+}
+
 func TestMustChangePasswordRequirement(t *testing.T) {
 	u := getTestUser()
 	u.Filters.RequirePasswordChange = true
@@ -9929,6 +9985,93 @@ func TestWebUserTwoFactorLogin(t *testing.T) {
 	checkResponseCode(t, http.StatusInternalServerError, rr)
 }
 
+func TestWebUserTwoFactoryLoginRedirect(t *testing.T) {
+	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
+	assert.NoError(t, err)
+	configName, _, secret, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
+	assert.NoError(t, err)
+
+	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
+	assert.NoError(t, err)
+	userTOTPConfig := dataprovider.UserTOTPConfig{
+		Enabled:    true,
+		ConfigName: configName,
+		Secret:     kms.NewPlainSecret(secret),
+		Protocols:  []string{common.ProtocolHTTP},
+	}
+	asJSON, err := json.Marshal(userTOTPConfig)
+	assert.NoError(t, err)
+	req, err := http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
+	assert.NoError(t, err)
+	setBearerForReq(req, token)
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr)
+
+	csrfToken, err := getCSRFToken(httpBaseURL + webClientLoginPath)
+	assert.NoError(t, err)
+	form := getLoginForm(defaultUsername, defaultPassword, csrfToken)
+	uri := webClientFilesPath + "?path=%2F"
+	loginURI := webClientLoginPath + "?next=" + url.QueryEscape(uri)
+	expectedURI := webClientTwoFactorPath + "?next=" + url.QueryEscape(uri)
+	req, err = http.NewRequest(http.MethodPost, loginURI, bytes.NewBuffer([]byte(form.Encode())))
+	assert.NoError(t, err)
+	req.RemoteAddr = defaultRemoteAddr
+	req.RequestURI = loginURI
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	rr = executeRequest(req)
+	assert.Equal(t, http.StatusFound, rr.Code)
+	assert.Equal(t, expectedURI, rr.Header().Get("Location"))
+	cookie, err := getCookieFromResponse(rr)
+	assert.NoError(t, err)
+	// test unsafe redirects
+	externalURI := webClientLoginPath + "?next=" + url.QueryEscape("https://example.com")
+	req, err = http.NewRequest(http.MethodPost, externalURI, bytes.NewBuffer([]byte(form.Encode())))
+	assert.NoError(t, err)
+	req.RemoteAddr = defaultRemoteAddr
+	req.RequestURI = externalURI
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	rr = executeRequest(req)
+	assert.Equal(t, http.StatusFound, rr.Code)
+	assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location"))
+	internalURI := webClientLoginPath + "?next=" + url.QueryEscape(webClientMFAPath)
+	req, err = http.NewRequest(http.MethodPost, internalURI, bytes.NewBuffer([]byte(form.Encode())))
+	assert.NoError(t, err)
+	req.RemoteAddr = defaultRemoteAddr
+	req.RequestURI = internalURI
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	rr = executeRequest(req)
+	assert.Equal(t, http.StatusFound, rr.Code)
+	assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location"))
+	// render two factor page
+	req, err = http.NewRequest(http.MethodGet, expectedURI, nil)
+	assert.NoError(t, err)
+	req.RequestURI = expectedURI
+	setJWTCookieForReq(req, cookie)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr)
+	assert.Contains(t, rr.Body.String(), fmt.Sprintf("action=%q", expectedURI))
+	// login with the passcode
+	passcode, err := generateTOTPPasscode(secret)
+	assert.NoError(t, err)
+	form = make(url.Values)
+	form.Set("passcode", passcode)
+	form.Set(csrfFormToken, csrfToken)
+	req, err = http.NewRequest(http.MethodPost, expectedURI, bytes.NewBuffer([]byte(form.Encode())))
+	assert.NoError(t, err)
+	setJWTCookieForReq(req, cookie)
+	req.RemoteAddr = defaultRemoteAddr
+	req.RequestURI = expectedURI
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	rr = executeRequest(req)
+	assert.Equal(t, http.StatusFound, rr.Code)
+	assert.Equal(t, uri, rr.Header().Get("Location"))
+
+	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	assert.NoError(t, err)
+	err = os.RemoveAll(user.GetHomeDir())
+	assert.NoError(t, err)
+}
+
 func TestSearchEvents(t *testing.T) {
 	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
 	assert.NoError(t, err)

+ 4 - 0
internal/httpd/middleware.go

@@ -18,6 +18,7 @@ import (
 	"errors"
 	"fmt"
 	"net/http"
+	"net/url"
 	"strings"
 
 	"github.com/go-chi/jwtauth/v5"
@@ -52,6 +53,9 @@ func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudi
 		redirectPath = webAdminLoginPath
 	} else {
 		redirectPath = webClientLoginPath
+		if uri := r.RequestURI; strings.HasPrefix(uri, webClientFilesPath) {
+			redirectPath += "?next=" + url.QueryEscape(uri)
+		}
 	}
 
 	isAPIToken := (audience == tokenAudienceAPI || audience == tokenAudienceAPIUser)

+ 41 - 29
internal/httpd/server.go

@@ -23,6 +23,7 @@ import (
 	"log"
 	"net"
 	"net/http"
+	"net/url"
 	"path/filepath"
 	"strings"
 	"time"
@@ -160,7 +161,7 @@ func (s *httpdServer) refreshCookie(next http.Handler) http.Handler {
 	})
 }
 
-func (s *httpdServer) renderClientLoginPage(w http.ResponseWriter, error, ip string) {
+func (s *httpdServer) renderClientLoginPage(w http.ResponseWriter, r *http.Request, error, ip string) {
 	data := loginPage{
 		CurrentURL:   webClientLoginPath,
 		Version:      version.Get().Version,
@@ -170,6 +171,9 @@ func (s *httpdServer) renderClientLoginPage(w http.ResponseWriter, error, ip str
 		Branding:     s.binding.Branding.WebClient,
 		FormDisabled: s.binding.isWebClientLoginFormDisabled(),
 	}
+	if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) {
+		data.CurrentURL += "?next=" + url.QueryEscape(next)
+	}
 	if s.binding.showAdminLoginURL() {
 		data.AltLoginURL = webAdminLoginPath
 		data.AltLoginName = s.binding.Branding.WebAdmin.ShortName
@@ -217,7 +221,7 @@ func (s *httpdServer) handleClientWebLogin(w http.ResponseWriter, r *http.Reques
 		http.Redirect(w, r, webAdminSetupPath, http.StatusFound)
 		return
 	}
-	s.renderClientLoginPage(w, getFlashMessage(w, r), util.GetIPFromRemoteAddress(r.RemoteAddr))
+	s.renderClientLoginPage(w, r, getFlashMessage(w, r), util.GetIPFromRemoteAddress(r.RemoteAddr))
 }
 
 func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Request) {
@@ -225,7 +229,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
 
 	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
 	if err := r.ParseForm(); err != nil {
-		s.renderClientLoginPage(w, err.Error(), ipAddr)
+		s.renderClientLoginPage(w, r, err.Error(), ipAddr)
 		return
 	}
 	protocol := common.ProtocolHTTP
@@ -234,33 +238,33 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
 	if username == "" || password == "" {
 		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
 			dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials)
-		s.renderClientLoginPage(w, "Invalid credentials", ipAddr)
+		s.renderClientLoginPage(w, r, "Invalid credentials", ipAddr)
 		return
 	}
 	if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
 		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
 			dataprovider.LoginMethodPassword, ipAddr, err)
-		s.renderClientLoginPage(w, err.Error(), ipAddr)
+		s.renderClientLoginPage(w, r, err.Error(), ipAddr)
 		return
 	}
 
 	if err := common.Config.ExecutePostConnectHook(ipAddr, protocol); err != nil {
 		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
 			dataprovider.LoginMethodPassword, ipAddr, err)
-		s.renderClientLoginPage(w, fmt.Sprintf("access denied: %v", err), ipAddr)
+		s.renderClientLoginPage(w, r, fmt.Sprintf("access denied: %v", err), ipAddr)
 		return
 	}
 
 	user, err := dataprovider.CheckUserAndPass(username, password, ipAddr, protocol)
 	if err != nil {
 		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
-		s.renderClientLoginPage(w, dataprovider.ErrInvalidCredentials.Error(), ipAddr)
+		s.renderClientLoginPage(w, r, dataprovider.ErrInvalidCredentials.Error(), ipAddr)
 		return
 	}
 	connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
 	if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil {
 		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
-		s.renderClientLoginPage(w, err.Error(), ipAddr)
+		s.renderClientLoginPage(w, r, err.Error(), ipAddr)
 		return
 	}
 
@@ -269,7 +273,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
 	if err != nil {
 		logger.Warn(logSender, connectionID, "unable to check fs root: %v", err)
 		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
-		s.renderClientLoginPage(w, err.Error(), ipAddr)
+		s.renderClientLoginPage(w, r, err.Error(), ipAddr)
 		return
 	}
 	s.loginUser(w, r, &user, connectionID, ipAddr, false, s.renderClientLoginPage)
@@ -281,7 +285,7 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r
 	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
 	err := r.ParseForm()
 	if err != nil {
-		s.renderClientResetPwdPage(w, err.Error(), ipAddr)
+		s.renderClientResetPwdPage(w, r, err.Error(), ipAddr)
 		return
 	}
 	if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
@@ -291,12 +295,12 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r
 	_, user, err := handleResetPassword(r, strings.TrimSpace(r.Form.Get("code")),
 		strings.TrimSpace(r.Form.Get("password")), false)
 	if err != nil {
-		s.renderClientResetPwdPage(w, err.Error(), ipAddr)
+		s.renderClientResetPwdPage(w, r, err.Error(), ipAddr)
 		return
 	}
 	connectionID := fmt.Sprintf("%v_%v", getProtocolFromRequest(r), xid.New().String())
 	if err := checkHTTPClientUser(user, r, connectionID, true); err != nil {
-		s.renderClientResetPwdPage(w, fmt.Sprintf("Password reset successfully but unable to login: %v", err.Error()), ipAddr)
+		s.renderClientResetPwdPage(w, r, fmt.Sprintf("Password reset successfully but unable to login: %v", err.Error()), ipAddr)
 		return
 	}
 
@@ -304,7 +308,7 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r
 	err = user.CheckFsRoot(connectionID)
 	if err != nil {
 		logger.Warn(logSender, connectionID, "unable to check fs root: %v", err)
-		s.renderClientResetPwdPage(w, fmt.Sprintf("Password reset successfully but unable to login: %s", err.Error()), ipAddr)
+		s.renderClientResetPwdPage(w, r, fmt.Sprintf("Password reset successfully but unable to login: %s", err.Error()), ipAddr)
 		return
 	}
 	s.loginUser(w, r, user, connectionID, ipAddr, false, s.renderClientResetPwdPage)
@@ -319,17 +323,17 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter
 	}
 	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
 	if err := r.ParseForm(); err != nil {
-		s.renderClientTwoFactorRecoveryPage(w, err.Error(), ipAddr)
+		s.renderClientTwoFactorRecoveryPage(w, r, err.Error(), ipAddr)
 		return
 	}
 	username := claims.Username
 	recoveryCode := strings.TrimSpace(r.Form.Get("recovery_code"))
 	if username == "" || recoveryCode == "" {
-		s.renderClientTwoFactorRecoveryPage(w, "Invalid credentials", ipAddr)
+		s.renderClientTwoFactorRecoveryPage(w, r, "Invalid credentials", ipAddr)
 		return
 	}
 	if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
-		s.renderClientTwoFactorRecoveryPage(w, err.Error(), ipAddr)
+		s.renderClientTwoFactorRecoveryPage(w, r, err.Error(), ipAddr)
 		return
 	}
 	user, userMerged, err := dataprovider.GetUserVariants(username, "")
@@ -337,11 +341,11 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter
 		if errors.Is(err, util.ErrNotFound) {
 			handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck
 		}
-		s.renderClientTwoFactorRecoveryPage(w, "Invalid credentials", ipAddr)
+		s.renderClientTwoFactorRecoveryPage(w, r, "Invalid credentials", ipAddr)
 		return
 	}
 	if !userMerged.Filters.TOTPConfig.Enabled || !util.Contains(userMerged.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) {
-		s.renderClientTwoFactorPage(w, "Two factory authentication is not enabled", ipAddr)
+		s.renderClientTwoFactorPage(w, r, "Two factory authentication is not enabled", ipAddr)
 		return
 	}
 	for idx, code := range user.Filters.RecoveryCodes {
@@ -351,7 +355,7 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter
 		}
 		if code.Secret.GetPayload() == recoveryCode {
 			if code.Used {
-				s.renderClientTwoFactorRecoveryPage(w, "This recovery code was already used", ipAddr)
+				s.renderClientTwoFactorRecoveryPage(w, r, "This recovery code was already used", ipAddr)
 				return
 			}
 			user.Filters.RecoveryCodes[idx].Used = true
@@ -368,7 +372,7 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter
 		}
 	}
 	handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck
-	s.renderClientTwoFactorRecoveryPage(w, "Invalid recovery code", ipAddr)
+	s.renderClientTwoFactorRecoveryPage(w, r, "Invalid recovery code", ipAddr)
 }
 
 func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *http.Request) {
@@ -380,7 +384,7 @@ func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *htt
 	}
 	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
 	if err := r.ParseForm(); err != nil {
-		s.renderClientTwoFactorPage(w, err.Error(), ipAddr)
+		s.renderClientTwoFactorPage(w, r, err.Error(), ipAddr)
 		return
 	}
 	username := claims.Username
@@ -388,25 +392,25 @@ func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *htt
 	if username == "" || passcode == "" {
 		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
 			dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials)
-		s.renderClientTwoFactorPage(w, "Invalid credentials", ipAddr)
+		s.renderClientTwoFactorPage(w, r, "Invalid credentials", ipAddr)
 		return
 	}
 	if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
 		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
 			dataprovider.LoginMethodPassword, ipAddr, err)
-		s.renderClientTwoFactorPage(w, err.Error(), ipAddr)
+		s.renderClientTwoFactorPage(w, r, err.Error(), ipAddr)
 		return
 	}
 	user, err := dataprovider.GetUserWithGroupSettings(username, "")
 	if err != nil {
 		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
 			dataprovider.LoginMethodPassword, ipAddr, err)
-		s.renderClientTwoFactorPage(w, "Invalid credentials", ipAddr)
+		s.renderClientTwoFactorPage(w, r, "Invalid credentials", ipAddr)
 		return
 	}
 	if !user.Filters.TOTPConfig.Enabled || !util.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) {
 		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
-		s.renderClientTwoFactorPage(w, "Two factory authentication is not enabled", ipAddr)
+		s.renderClientTwoFactorPage(w, r, "Two factory authentication is not enabled", ipAddr)
 		return
 	}
 	err = user.Filters.TOTPConfig.Secret.Decrypt()
@@ -419,7 +423,7 @@ func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *htt
 		user.Filters.TOTPConfig.Secret.GetPayload())
 	if !match || err != nil {
 		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, dataprovider.ErrInvalidCredentials)
-		s.renderClientTwoFactorPage(w, "Invalid authentication code", ipAddr)
+		s.renderClientTwoFactorPage(w, r, "Invalid authentication code", ipAddr)
 		return
 	}
 	connectionID := fmt.Sprintf("%s_%s", getProtocolFromRequest(r), xid.New().String())
@@ -703,7 +707,7 @@ func (s *httpdServer) handleWebAdminSetupPost(w http.ResponseWriter, r *http.Req
 
 func (s *httpdServer) loginUser(
 	w http.ResponseWriter, r *http.Request, user *dataprovider.User, connectionID, ipAddr string,
-	isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, error, ip string),
+	isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, error, ip string),
 ) {
 	c := jwtTokenClaims{
 		Username:                   user.Username,
@@ -725,18 +729,26 @@ func (s *httpdServer) loginUser(
 	if err != nil {
 		logger.Warn(logSender, connectionID, "unable to set user login cookie %v", err)
 		updateLoginMetrics(user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
-		errorFunc(w, err.Error(), ipAddr)
+		errorFunc(w, r, err.Error(), ipAddr)
 		return
 	}
 	if isSecondFactorAuth {
 		invalidateToken(r)
 	}
 	if audience == tokenAudienceWebClientPartial {
-		http.Redirect(w, r, webClientTwoFactorPath, http.StatusFound)
+		redirectPath := webClientTwoFactorPath
+		if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) {
+			redirectPath += "?next=" + url.QueryEscape(next)
+		}
+		http.Redirect(w, r, redirectPath, http.StatusFound)
 		return
 	}
 	updateLoginMetrics(user, dataprovider.LoginMethodPassword, ipAddr, err)
 	dataprovider.UpdateLastLogin(user)
+	if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) {
+		http.Redirect(w, r, next, http.StatusFound)
+		return
+	}
 	http.Redirect(w, r, webClientFilesPath, http.StatusFound)
 }
 

+ 10 - 6
internal/httpd/webclient.go

@@ -402,7 +402,7 @@ func (s *httpdServer) renderClientForgotPwdPage(w http.ResponseWriter, error, ip
 	renderClientTemplate(w, templateForgotPassword, data)
 }
 
-func (s *httpdServer) renderClientResetPwdPage(w http.ResponseWriter, error, ip string) {
+func (s *httpdServer) renderClientResetPwdPage(w http.ResponseWriter, _ *http.Request, error, ip string) {
 	data := resetPwdPage{
 		CurrentURL: webClientResetPwdPath,
 		Error:      error,
@@ -467,7 +467,7 @@ func (s *httpdServer) renderClientNotFoundPage(w http.ResponseWriter, r *http.Re
 	s.renderClientMessagePage(w, r, page404Title, page404Body, http.StatusNotFound, err, "")
 }
 
-func (s *httpdServer) renderClientTwoFactorPage(w http.ResponseWriter, error, ip string) {
+func (s *httpdServer) renderClientTwoFactorPage(w http.ResponseWriter, r *http.Request, error, ip string) {
 	data := twoFactorPage{
 		CurrentURL:  webClientTwoFactorPath,
 		Version:     version.Get().Version,
@@ -477,10 +477,13 @@ func (s *httpdServer) renderClientTwoFactorPage(w http.ResponseWriter, error, ip
 		RecoveryURL: webClientTwoFactorRecoveryPath,
 		Branding:    s.binding.Branding.WebClient,
 	}
+	if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) {
+		data.CurrentURL += "?next=" + url.QueryEscape(next)
+	}
 	renderClientTemplate(w, templateTwoFactor, data)
 }
 
-func (s *httpdServer) renderClientTwoFactorRecoveryPage(w http.ResponseWriter, error, ip string) {
+func (s *httpdServer) renderClientTwoFactorRecoveryPage(w http.ResponseWriter, _ *http.Request, error, ip string) {
 	data := twoFactorPage{
 		CurrentURL: webClientTwoFactorRecoveryPath,
 		Version:    version.Get().Version,
@@ -1284,12 +1287,12 @@ func (s *httpdServer) handleWebClientMFA(w http.ResponseWriter, r *http.Request)
 
 func (s *httpdServer) handleWebClientTwoFactor(w http.ResponseWriter, r *http.Request) {
 	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
-	s.renderClientTwoFactorPage(w, "", util.GetIPFromRemoteAddress(r.RemoteAddr))
+	s.renderClientTwoFactorPage(w, r, "", util.GetIPFromRemoteAddress(r.RemoteAddr))
 }
 
 func (s *httpdServer) handleWebClientTwoFactorRecovery(w http.ResponseWriter, r *http.Request) {
 	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
-	s.renderClientTwoFactorRecoveryPage(w, "", util.GetIPFromRemoteAddress(r.RemoteAddr))
+	s.renderClientTwoFactorRecoveryPage(w, r, "", util.GetIPFromRemoteAddress(r.RemoteAddr))
 }
 
 func getShareFromPostFields(r *http.Request) (*dataprovider.Share, error) {
@@ -1371,7 +1374,7 @@ func (s *httpdServer) handleWebClientPasswordReset(w http.ResponseWriter, r *htt
 		s.renderClientNotFoundPage(w, r, errors.New("this page does not exist"))
 		return
 	}
-	s.renderClientResetPwdPage(w, "", util.GetIPFromRemoteAddress(r.RemoteAddr))
+	s.renderClientResetPwdPage(w, r, "", util.GetIPFromRemoteAddress(r.RemoteAddr))
 }
 
 func (s *httpdServer) handleClientViewPDF(w http.ResponseWriter, r *http.Request) {
@@ -1500,6 +1503,7 @@ func (s *httpdServer) handleClientShareLoginPost(w http.ResponseWriter, r *http.
 	next := path.Clean(r.URL.Query().Get("next"))
 	if strings.HasPrefix(next, path.Join(webClientPubSharesPath, share.ShareID)) {
 		http.Redirect(w, r, next, http.StatusFound)
+		return
 	}
 	s.renderClientMessagePage(w, r, "Share Login OK", "Share login successful, you can now use your link",
 		http.StatusOK, nil, "")

+ 1 - 1
pkgs/build.sh

@@ -1,6 +1,6 @@
 #!/bin/bash
 
-NFPM_VERSION=2.29.0
+NFPM_VERSION=2.30.1
 NFPM_ARCH=${NFPM_ARCH:-amd64}
 if [ -z ${SFTPGO_VERSION} ]
 then