diff --git a/go.mod b/go.mod index 1f246d05..3944612f 100644 --- a/go.mod +++ b/go.mod @@ -185,6 +185,7 @@ require ( replace ( github.com/fclairamb/ftpserverlib => github.com/drakkan/ftpserverlib v0.0.0-20240603150004-6a8f643fbf2e github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20240430173938-7ba8270c8e7f + github.com/pires/go-proxyproto => github.com/drakkan/go-proxyproto v0.0.0-20240811060125-2e92d08b5373 github.com/robfig/cron/v3 => github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0 golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20240726170110-f4e4a4627441 ) diff --git a/go.sum b/go.sum index b4d3310d..a5fcfe1d 100644 --- a/go.sum +++ b/go.sum @@ -119,6 +119,8 @@ github.com/drakkan/ftp v0.0.0-20240430173938-7ba8270c8e7f h1:S9JUlrOzjK58UKoLqqb github.com/drakkan/ftp v0.0.0-20240430173938-7ba8270c8e7f/go.mod h1:4p8lUl4vQ80L598CygL+3IFtm+3nggvvW/palOlViwE= github.com/drakkan/ftpserverlib v0.0.0-20240603150004-6a8f643fbf2e h1:VBpqQeChkGXSV1FXCtvd3BJTyB+DcMgiu7SfkpsGuKw= github.com/drakkan/ftpserverlib v0.0.0-20240603150004-6a8f643fbf2e/go.mod h1:aAwyOAC6IIe+IZeeGD1QjuE3GGDzqW/c5Xtn+Dp0JUM= +github.com/drakkan/go-proxyproto v0.0.0-20240811060125-2e92d08b5373 h1:0ltrbDRr7KT2aSgj2IXOzRraH2xdR+CWZjm5uC4ChXU= +github.com/drakkan/go-proxyproto v0.0.0-20240811060125-2e92d08b5373/go.mod h1:iknsfgnH8EkjrMeMyvfKByp9TiBZCKZM0jx2xmKqnVY= github.com/drakkan/webdav v0.0.0-20240503091431-218ec83910bb h1:067/Uo8cfeY7QC0yzWCr/RImuNcM0rLWAsBUyMks59o= github.com/drakkan/webdav v0.0.0-20240503091431-218ec83910bb/go.mod h1:zOVb1QDhwwqWn2L2qZ0U3swMSO4GTSNyIwXCGO/UGWE= github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001 h1:/ZshrfQzayqRSBDodmp3rhNCHJCff+utvgBuWRbiqu4= @@ -300,8 +302,6 @@ github.com/otiai10/mint v1.5.1 h1:XaPLeE+9vGbuyEHem1JNk3bYc7KKqyI/na0/mLd/Kks= github.com/otiai10/mint v1.5.1/go.mod h1:MJm72SBthJjz8qhefc4z1PYEieWmy8Bku7CjcAqyUSM= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= -github.com/pires/go-proxyproto v0.7.0 h1:IukmRewDQFWC7kfnb66CSomk2q/seBuilHBYFwyq0Hs= -github.com/pires/go-proxyproto v0.7.0/go.mod h1:Vz/1JPY/OACxWGQNIRY2BeyDmpoaWmEP40O9LbuiFR4= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/internal/common/common.go b/internal/common/common.go index 39bb2ce5..510c642a 100644 --- a/internal/common/common.go +++ b/internal/common/common.go @@ -631,7 +631,7 @@ func (c *Configuration) GetProxyListener(listener net.Listener) (*proxyproto.Lis return &proxyproto.Listener{ Listener: listener, - Policy: getProxyPolicy(c.proxyAllowed, c.proxySkipped, defaultPolicy), + ConnPolicy: getProxyPolicy(c.proxyAllowed, c.proxySkipped, defaultPolicy), ReadHeaderTimeout: 10 * time.Second, }, nil } @@ -806,13 +806,13 @@ func (c *Configuration) ExecutePostConnectHook(ipAddr, protocol string) error { return nil } -func getProxyPolicy(allowed, skipped []func(net.IP) bool, def proxyproto.Policy) proxyproto.PolicyFunc { - return func(upstream net.Addr) (proxyproto.Policy, error) { - upstreamIP, err := util.GetIPFromNetAddr(upstream) +func getProxyPolicy(allowed, skipped []func(net.IP) bool, def proxyproto.Policy) proxyproto.ConnPolicyFunc { + return func(connPolicyOptions proxyproto.ConnPolicyOptions) (proxyproto.Policy, error) { + upstreamIP, err := util.GetIPFromNetAddr(connPolicyOptions.Upstream) if err != nil { // Something is wrong with the source IP, better reject the - // connection if a proxy header is found. - return proxyproto.REJECT, err + // connection. + return proxyproto.REJECT, proxyproto.ErrInvalidUpstream } for _, skippedFrom := range skipped { @@ -831,7 +831,7 @@ func getProxyPolicy(allowed, skipped []func(net.IP) bool, def proxyproto.Policy) } if def == proxyproto.REQUIRE { - return proxyproto.REJECT, nil + return proxyproto.REJECT, proxyproto.ErrInvalidUpstream } return def, nil } diff --git a/internal/common/common_test.go b/internal/common/common_test.go index adcac1db..a3e04a92 100644 --- a/internal/common/common_test.go +++ b/internal/common/common_test.go @@ -1041,9 +1041,13 @@ func TestQuotaScansRole(t *testing.T) { func TestProxyPolicy(t *testing.T) { addr := net.TCPAddr{} + downstream := net.TCPAddr{IP: net.ParseIP("1.1.1.1")} p := getProxyPolicy(nil, nil, proxyproto.IGNORE) - policy, err := p(&addr) - assert.Error(t, err) + policy, err := p(proxyproto.ConnPolicyOptions{ + Upstream: &addr, + Downstream: &downstream, + }) + assert.ErrorIs(t, err, proxyproto.ErrInvalidUpstream) assert.Equal(t, proxyproto.REJECT, policy) ip1 := net.ParseIP("10.8.1.1") ip2 := net.ParseIP("10.8.1.2") @@ -1053,30 +1057,54 @@ func TestProxyPolicy(t *testing.T) { skipped, err := util.ParseAllowedIPAndRanges([]string{ip2.String(), ip3.String()}) assert.NoError(t, err) p = getProxyPolicy(allowed, skipped, proxyproto.IGNORE) - policy, err = p(&net.TCPAddr{IP: ip1}) + policy, err = p(proxyproto.ConnPolicyOptions{ + Upstream: &net.TCPAddr{IP: ip1}, + Downstream: &downstream, + }) assert.NoError(t, err) assert.Equal(t, proxyproto.USE, policy) - policy, err = p(&net.TCPAddr{IP: ip2}) + policy, err = p(proxyproto.ConnPolicyOptions{ + Upstream: &net.TCPAddr{IP: ip2}, + Downstream: &downstream, + }) assert.NoError(t, err) assert.Equal(t, proxyproto.SKIP, policy) - policy, err = p(&net.TCPAddr{IP: ip3}) + policy, err = p(proxyproto.ConnPolicyOptions{ + Upstream: &net.TCPAddr{IP: ip3}, + Downstream: &downstream, + }) assert.NoError(t, err) assert.Equal(t, proxyproto.SKIP, policy) - policy, err = p(&net.TCPAddr{IP: net.ParseIP("10.8.1.4")}) + policy, err = p(proxyproto.ConnPolicyOptions{ + Upstream: &net.TCPAddr{IP: net.ParseIP("10.8.1.4")}, + Downstream: &downstream, + }) assert.NoError(t, err) assert.Equal(t, proxyproto.IGNORE, policy) p = getProxyPolicy(allowed, skipped, proxyproto.REQUIRE) - policy, err = p(&net.TCPAddr{IP: ip1}) + policy, err = p(proxyproto.ConnPolicyOptions{ + Upstream: &net.TCPAddr{IP: ip1}, + Downstream: &downstream, + }) assert.NoError(t, err) assert.Equal(t, proxyproto.REQUIRE, policy) - policy, err = p(&net.TCPAddr{IP: ip2}) + policy, err = p(proxyproto.ConnPolicyOptions{ + Upstream: &net.TCPAddr{IP: ip2}, + Downstream: &downstream, + }) assert.NoError(t, err) assert.Equal(t, proxyproto.SKIP, policy) - policy, err = p(&net.TCPAddr{IP: ip3}) + policy, err = p(proxyproto.ConnPolicyOptions{ + Upstream: &net.TCPAddr{IP: ip3}, + Downstream: &downstream, + }) assert.NoError(t, err) assert.Equal(t, proxyproto.SKIP, policy) - policy, err = p(&net.TCPAddr{IP: net.ParseIP("10.8.1.5")}) - assert.NoError(t, err) + policy, err = p(proxyproto.ConnPolicyOptions{ + Upstream: &net.TCPAddr{IP: net.ParseIP("10.8.1.5")}, + Downstream: &downstream, + }) + assert.ErrorIs(t, err, proxyproto.ErrInvalidUpstream) assert.Equal(t, proxyproto.REJECT, policy) } @@ -1091,12 +1119,12 @@ func TestProxyProtocolVersion(t *testing.T) { c.ProxyProtocol = 1 proxyListener, err := c.GetProxyListener(nil) assert.NoError(t, err) - assert.NotNil(t, proxyListener.Policy) + assert.NotNil(t, proxyListener.ConnPolicy) c.ProxyProtocol = 2 proxyListener, err = c.GetProxyListener(nil) assert.NoError(t, err) - assert.NotNil(t, proxyListener.Policy) + assert.NotNil(t, proxyListener.ConnPolicy) } func TestStartupHook(t *testing.T) { diff --git a/internal/common/protocol_test.go b/internal/common/protocol_test.go index f80ffcac..66e6d4d4 100644 --- a/internal/common/protocol_test.go +++ b/internal/common/protocol_test.go @@ -9031,9 +9031,8 @@ func TestHTTPFs(t *testing.T) { func TestProxyProtocol(t *testing.T) { resp, err := httpclient.Get(fmt.Sprintf("http://%v", httpProxyAddr)) - if assert.NoError(t, err) { - defer resp.Body.Close() - assert.Equal(t, http.StatusOK, resp.StatusCode) + if !assert.Error(t, err) { + resp.Body.Close() } } diff --git a/internal/sftpd/sftpd_test.go b/internal/sftpd/sftpd_test.go index 9279b4a7..869d0972 100644 --- a/internal/sftpd/sftpd_test.go +++ b/internal/sftpd/sftpd_test.go @@ -1196,11 +1196,8 @@ func TestProxyProtocol(t *testing.T) { defer client.Close() assert.NoError(t, checkBasicSFTP(client)) } - conn, client, err = getSftpClientWithAddr(user, usePubKey, "127.0.0.1:2224") - if assert.NoError(t, err) { - defer client.Close() - defer conn.Close() - } + _, _, err = getSftpClientWithAddr(user, usePubKey, "127.0.0.1:2224") + assert.Error(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) diff --git a/internal/vfs/azblobfs.go b/internal/vfs/azblobfs.go index 067d612c..3d4e92c1 100644 --- a/internal/vfs/azblobfs.go +++ b/internal/vfs/azblobfs.go @@ -394,7 +394,17 @@ func (fs *AzureBlobFs) Chtimes(name string, _, mtime time.Time, isUploading bool if metadata == nil { metadata = make(map[string]*string) } - metadata[lastModifiedField] = to.Ptr(strconv.FormatInt(mtime.UnixMilli(), 10)) + found := false + for k := range metadata { + if strings.ToLower(k) == lastModifiedField { + metadata[k] = to.Ptr(strconv.FormatInt(mtime.UnixMilli(), 10)) + found = true + break + } + } + if !found { + metadata[lastModifiedField] = to.Ptr(strconv.FormatInt(mtime.UnixMilli(), 10)) + } ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn()