Browse Source

OIDC: execute pre-login hook after IDP authentication

so the SFTPGo users can be auto-created using the hook

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
Nicola Murino 3 years ago
parent
commit
c6b8644828

+ 2 - 1
common/common.go

@@ -83,6 +83,7 @@ const (
 	ProtocolHTTP          = "HTTP"
 	ProtocolHTTP          = "HTTP"
 	ProtocolHTTPShare     = "HTTPShare"
 	ProtocolHTTPShare     = "HTTPShare"
 	ProtocolDataRetention = "DataRetention"
 	ProtocolDataRetention = "DataRetention"
+	ProtocolOIDC          = "OIDC"
 )
 )
 
 
 // Upload modes
 // Upload modes
@@ -128,7 +129,7 @@ var (
 	periodicTimeoutTicker     *time.Ticker
 	periodicTimeoutTicker     *time.Ticker
 	periodicTimeoutTickerDone chan bool
 	periodicTimeoutTickerDone chan bool
 	supportedProtocols        = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
 	supportedProtocols        = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
-		ProtocolHTTP, ProtocolHTTPShare}
+		ProtocolHTTP, ProtocolHTTPShare, ProtocolOIDC}
 	disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP}
 	disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP}
 	// the map key is the protocol, for each protocol we can have multiple rate limiters
 	// the map key is the protocol, for each protocol we can have multiple rate limiters
 	rateLimiters map[string][]*rateLimiter
 	rateLimiters map[string][]*rateLimiter

+ 3 - 3
common/connection.go

@@ -1250,7 +1250,7 @@ func (c *BaseConnection) IsNotExistError(err error) bool {
 	switch c.protocol {
 	switch c.protocol {
 	case ProtocolSFTP:
 	case ProtocolSFTP:
 		return errors.Is(err, sftp.ErrSSHFxNoSuchFile)
 		return errors.Is(err, sftp.ErrSSHFxNoSuchFile)
-	case ProtocolWebDAV, ProtocolFTP, ProtocolHTTP, ProtocolHTTPShare, ProtocolDataRetention:
+	case ProtocolWebDAV, ProtocolFTP, ProtocolHTTP, ProtocolOIDC, ProtocolHTTPShare, ProtocolDataRetention:
 		return errors.Is(err, os.ErrNotExist)
 		return errors.Is(err, os.ErrNotExist)
 	default:
 	default:
 		return errors.Is(err, ErrNotExist)
 		return errors.Is(err, ErrNotExist)
@@ -1272,7 +1272,7 @@ func (c *BaseConnection) GetPermissionDeniedError() error {
 	switch c.protocol {
 	switch c.protocol {
 	case ProtocolSFTP:
 	case ProtocolSFTP:
 		return sftp.ErrSSHFxPermissionDenied
 		return sftp.ErrSSHFxPermissionDenied
-	case ProtocolWebDAV, ProtocolFTP, ProtocolHTTP, ProtocolHTTPShare, ProtocolDataRetention:
+	case ProtocolWebDAV, ProtocolFTP, ProtocolHTTP, ProtocolOIDC, ProtocolHTTPShare, ProtocolDataRetention:
 		return os.ErrPermission
 		return os.ErrPermission
 	default:
 	default:
 		return ErrPermissionDenied
 		return ErrPermissionDenied
@@ -1284,7 +1284,7 @@ func (c *BaseConnection) GetNotExistError() error {
 	switch c.protocol {
 	switch c.protocol {
 	case ProtocolSFTP:
 	case ProtocolSFTP:
 		return sftp.ErrSSHFxNoSuchFile
 		return sftp.ErrSSHFxNoSuchFile
-	case ProtocolWebDAV, ProtocolFTP, ProtocolHTTP, ProtocolHTTPShare, ProtocolDataRetention:
+	case ProtocolWebDAV, ProtocolFTP, ProtocolHTTP, ProtocolOIDC, ProtocolHTTPShare, ProtocolDataRetention:
 		return os.ErrNotExist
 		return os.ErrNotExist
 	default:
 	default:
 		return ErrNotExist
 		return ErrNotExist

+ 4 - 2
common/connection_test.go

@@ -15,6 +15,7 @@ import (
 
 
 	"github.com/drakkan/sftpgo/v2/dataprovider"
 	"github.com/drakkan/sftpgo/v2/dataprovider"
 	"github.com/drakkan/sftpgo/v2/kms"
 	"github.com/drakkan/sftpgo/v2/kms"
+	"github.com/drakkan/sftpgo/v2/util"
 	"github.com/drakkan/sftpgo/v2/vfs"
 	"github.com/drakkan/sftpgo/v2/vfs"
 )
 )
 
 
@@ -262,13 +263,14 @@ func TestUpdateQuotaAfterRename(t *testing.T) {
 func TestErrorsMapping(t *testing.T) {
 func TestErrorsMapping(t *testing.T) {
 	fs := vfs.NewOsFs("", os.TempDir(), "")
 	fs := vfs.NewOsFs("", os.TempDir(), "")
 	conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{BaseUser: sdk.BaseUser{HomeDir: os.TempDir()}})
 	conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{BaseUser: sdk.BaseUser{HomeDir: os.TempDir()}})
+	osErrorsProtocols := []string{ProtocolWebDAV, ProtocolFTP, ProtocolHTTP, ProtocolHTTPShare,
+		ProtocolDataRetention, ProtocolOIDC}
 	for _, protocol := range supportedProtocols {
 	for _, protocol := range supportedProtocols {
 		conn.SetProtocol(protocol)
 		conn.SetProtocol(protocol)
 		err := conn.GetFsError(fs, os.ErrNotExist)
 		err := conn.GetFsError(fs, os.ErrNotExist)
 		if protocol == ProtocolSFTP {
 		if protocol == ProtocolSFTP {
 			assert.ErrorIs(t, err, sftp.ErrSSHFxNoSuchFile)
 			assert.ErrorIs(t, err, sftp.ErrSSHFxNoSuchFile)
-		} else if protocol == ProtocolWebDAV || protocol == ProtocolFTP || protocol == ProtocolHTTP ||
-			protocol == ProtocolHTTPShare || protocol == ProtocolDataRetention {
+		} else if util.IsStringInSlice(protocol, osErrorsProtocols) {
 			assert.EqualError(t, err, os.ErrNotExist.Error())
 			assert.EqualError(t, err, os.ErrNotExist.Error())
 		} else {
 		} else {
 			assert.EqualError(t, err, ErrNotExist.Error())
 			assert.EqualError(t, err, ErrNotExist.Error())

+ 11 - 0
dataprovider/dataprovider.go

@@ -1050,6 +1050,17 @@ func CheckKeyboardInteractiveAuth(username, authHook string, client ssh.Keyboard
 	return doKeyboardInteractiveAuth(&user, authHook, client, ip, protocol)
 	return doKeyboardInteractiveAuth(&user, authHook, client, ip, protocol)
 }
 }
 
 
+// GetUserAfterIDPAuth returns the SFTPGo user with the specified username
+// after a successful authentication with an external identity provider.
+// If a pre-login hook is defined it will be executed so the SFTPGo user
+// can be created if it does not exist
+func GetUserAfterIDPAuth(username, ip, protocol string) (User, error) {
+	if config.PreLoginHook != "" {
+		return executePreLoginHook(username, LoginMethodIDP, ip, protocol)
+	}
+	return UserExists(username)
+}
+
 // GetDefenderHosts returns hosts that are banned or for which some violations have been detected
 // GetDefenderHosts returns hosts that are banned or for which some violations have been detected
 func GetDefenderHosts(from int64, limit int) ([]DefenderEntry, error) {
 func GetDefenderHosts(from int64, limit int) ([]DefenderEntry, error) {
 	return provider.getDefenderHosts(from, limit)
 	return provider.getDefenderHosts(from, limit)

+ 1 - 0
dataprovider/user.go

@@ -70,6 +70,7 @@ const (
 	SSHLoginMethodKeyAndKeyboardInt   = "publickey+keyboard-interactive"
 	SSHLoginMethodKeyAndKeyboardInt   = "publickey+keyboard-interactive"
 	LoginMethodTLSCertificate         = "TLSCertificate"
 	LoginMethodTLSCertificate         = "TLSCertificate"
 	LoginMethodTLSCertificateAndPwd   = "TLSCertificate+password"
 	LoginMethodTLSCertificateAndPwd   = "TLSCertificate+password"
+	LoginMethodIDP                    = "IDP"
 )
 )
 
 
 var (
 var (

+ 4 - 2
docs/dynamic-user-mod.md

@@ -6,9 +6,9 @@ To enable dynamic user modification, you must set the absolute path of your prog
 The external program can read the following environment variables to get info about the user trying to login:
 The external program can read the following environment variables to get info about the user trying to login:
 
 
 - `SFTPGO_LOGIND_USER`, it contains the user trying to login serialized as JSON. A JSON serialized user id equal to zero means the user does not exist inside SFTPGo
 - `SFTPGO_LOGIND_USER`, it contains the user trying to login serialized as JSON. A JSON serialized user id equal to zero means the user does not exist inside SFTPGo
-- `SFTPGO_LOGIND_METHOD`, possible values are: `password`, `publickey`, `keyboard-interactive`, `TLSCertificate`
+- `SFTPGO_LOGIND_METHOD`, possible values are: `password`, `publickey`, `keyboard-interactive`, `TLSCertificate`, `IDP` (external identity provider)
 - `SFTPGO_LOGIND_IP`, ip address of the user trying to login
 - `SFTPGO_LOGIND_IP`, ip address of the user trying to login
-- `SFTPGO_LOGIND_PROTOCOL`, possible values are `SSH`, `FTP`, `DAV`, `HTTP`
+- `SFTPGO_LOGIND_PROTOCOL`, possible values are `SSH`, `FTP`, `DAV`, `HTTP`, `OIDC` (OpenID Connect)
 
 
 The program must write, on its standard output:
 The program must write, on its standard output:
 
 
@@ -35,6 +35,8 @@ If an error happens while executing the hook then login will be denied.
 "Dynamic user creation or modification" and "External Authentication" are mutually exclusive, they are quite similar, the difference is that "External Authentication" returns an already authenticated user while using "Dynamic users modification" you simply create or update a user. The authentication will be checked inside SFTPGo.
 "Dynamic user creation or modification" and "External Authentication" are mutually exclusive, they are quite similar, the difference is that "External Authentication" returns an already authenticated user while using "Dynamic users modification" you simply create or update a user. The authentication will be checked inside SFTPGo.
 In other words while using "External Authentication" the external program receives the credentials of the user trying to login (for example the cleartext password) and it needs to validate them. While using "Dynamic users modification" the pre-login program receives the user stored inside the dataprovider (it includes the hashed password if any) and it can modify it, after the modification SFTPGo will check the credentials of the user trying to login.
 In other words while using "External Authentication" the external program receives the credentials of the user trying to login (for example the cleartext password) and it needs to validate them. While using "Dynamic users modification" the pre-login program receives the user stored inside the dataprovider (it includes the hashed password if any) and it can modify it, after the modification SFTPGo will check the credentials of the user trying to login.
 
 
+For SFTPGo users (not admins) authenticating using an external identity provider such as OpenID Connect, the pre-login hook will be executed after a successful authentication against the external IDP so that you can create/update the SFTPGo user matching the one authenticated against the identity provider. This is the only case where the pre-login hook is executed even if an external authentication hook is defined.
+
 You can disable the hook on a per-user basis.
 You can disable the hook on a per-user basis.
 
 
 Let's see a very basic example. Our sample program will grant access to the existing user `test_user` only in the time range 10:00-18:00. Other users will not be modified since the program will terminate with no output.
 Let's see a very basic example. Our sample program will grant access to the existing user `test_user` only in the time range 10:00-18:00. Other users will not be modified since the program will terminate with no output.

+ 1 - 1
docs/full-configuration.md

@@ -252,7 +252,7 @@ The configuration file contains the following sections:
       - `sts_seconds`, integer. Defines the max-age of the `Strict-Transport-Security` header. This header will be included for `https` responses or for HTTP request if the request includes a defined HTTPS proxy header. Default: `0`, which would NOT include the header.
       - `sts_seconds`, integer. Defines the max-age of the `Strict-Transport-Security` header. This header will be included for `https` responses or for HTTP request if the request includes a defined HTTPS proxy header. Default: `0`, which would NOT include the header.
       - `sts_include_subdomains`, boolean. Set to `true`, the `includeSubdomains` will be appended to the `Strict-Transport-Security` header. Default: `false`.
       - `sts_include_subdomains`, boolean. Set to `true`, the `includeSubdomains` will be appended to the `Strict-Transport-Security` header. Default: `false`.
       - `sts_preload`, boolean. Set to true, the `preload` flag will be appended to the `Strict-Transport-Security` header. Default: `false`.
       - `sts_preload`, boolean. Set to true, the `preload` flag will be appended to the `Strict-Transport-Security` header. Default: `false`.
-      - `content_type_nosniff`, boolean. Set to `true` to add the `X-Content-Type-Options` header with the value `nosniff` Default: `false`.
+      - `content_type_nosniff`, boolean. Set to `true` to add the `X-Content-Type-Options` header with the value `nosniff`. Default: `false`.
       - `content_security_policy`, string. Allows to set the `Content-Security-Policy` header value. Default: blank.
       - `content_security_policy`, string. Allows to set the `Content-Security-Policy` header value. Default: blank.
       - `permissions_policy`, string. Allows to set the `Permissions-Policy` header value. Default: blank.
       - `permissions_policy`, string. Allows to set the `Permissions-Policy` header value. Default: blank.
       - `cross_origin_opener_policy`, string. Allows to set the `Cross-Origin-Opener-Policy` header value. Default: blank.
       - `cross_origin_opener_policy`, string. Allows to set the `Cross-Origin-Opener-Policy` header value. Default: blank.

+ 2 - 0
docs/oidc.md

@@ -97,3 +97,5 @@ And the following is an example ID token which allows the SFTPGo user `user1` to
     "preferred_username": "user1"
     "preferred_username": "user1"
 }
 }
 ```
 ```
+
+SFTPGo users (not admins) can be created/updated after successful OpenID authentication by defining a [pre-login hook](./dynamic-user-mod.md).

+ 2 - 2
docs/post-connect-hook.md

@@ -9,7 +9,7 @@ The `post_connect_hook` can be defined as the absolute path of your program or a
 If the hook defines an external program it can read the following environment variables:
 If the hook defines an external program it can read the following environment variables:
 
 
 - `SFTPGO_CONNECTION_IP`
 - `SFTPGO_CONNECTION_IP`
-- `SFTPGO_CONNECTION_PROTOCOL`
+- `SFTPGO_CONNECTION_PROTOCOL`, possible values are `SSH`, `FTP`, `DAV`, `HTTP`, `OIDC` (OpenID Connect)
 
 
 If the external command completes with a zero exit status the connection will be accepted otherwise rejected.
 If the external command completes with a zero exit status the connection will be accepted otherwise rejected.
 
 
@@ -19,7 +19,7 @@ The program must finish within 20 seconds.
 If the hook defines an HTTP URL then this URL will be invoked as HTTP GET with the following query parameters:
 If the hook defines an HTTP URL then this URL will be invoked as HTTP GET with the following query parameters:
 
 
 - `ip`
 - `ip`
-- `protocol`
+- `protocol`, possible values are `SSH`, `FTP`, `DAV`, `HTTP`, `OIDC` (OpenID Connect)
 
 
 The connection is accepted if the HTTP response code is `200` otherwise rejected.
 The connection is accepted if the HTTP response code is `200` otherwise rejected.
 
 

+ 2 - 2
docs/post-disconnect-hook.md

@@ -9,7 +9,7 @@ The `post_disconnect_hook` can be defined as the absolute path of your program o
 If the hook defines an external program it can read the following environment variables:
 If the hook defines an external program it can read the following environment variables:
 
 
 - `SFTPGO_CONNECTION_IP`
 - `SFTPGO_CONNECTION_IP`
-- `SFTPGO_CONNECTION_PROTOCOL`
+- `SFTPGO_CONNECTION_PROTOCOL`, possible values are `SSH`, `FTP`, `DAV`, `HTTP`, `OIDC` (OpenID Connect)
 - `SFTPGO_CONNECTION_USERNAME`, can be empty if the channel is closed before user authentication
 - `SFTPGO_CONNECTION_USERNAME`, can be empty if the channel is closed before user authentication
 - `SFTPGO_CONNECTION_DURATION`, connection duration in milliseconds
 - `SFTPGO_CONNECTION_DURATION`, connection duration in milliseconds
 
 
@@ -19,7 +19,7 @@ The program must finish within 20 seconds.
 If the hook defines an HTTP URL then this URL will be invoked as HTTP GET with the following query parameters:
 If the hook defines an HTTP URL then this URL will be invoked as HTTP GET with the following query parameters:
 
 
 - `ip`
 - `ip`
-- `protocol`
+- `protocol`, possible values are `SSH`, `FTP`, `DAV`, `HTTP`, `OIDC` (OpenID Connect)
 - `username`, can be empty if the channel is closed before user authentication
 - `username`, can be empty if the channel is closed before user authentication
 - `connection_duration`, connection duration in milliseconds
 - `connection_duration`, connection duration in milliseconds
 
 

+ 2 - 2
docs/post-login-hook.md

@@ -10,9 +10,9 @@ If the hook defines an external program it can reads the following environment v
 
 
 - `SFTPGO_LOGIND_USER`, it contains the user serialized as JSON. The username is empty if the connection is closed for authentication timeout
 - `SFTPGO_LOGIND_USER`, it contains the user serialized as JSON. The username is empty if the connection is closed for authentication timeout
 - `SFTPGO_LOGIND_IP`
 - `SFTPGO_LOGIND_IP`
-- `SFTPGO_LOGIND_METHOD`, possible values are `publickey`, `password`, `keyboard-interactive`, `publickey+password`, `publickey+keyboard-interactive`, `TLSCertificate`, `TLSCertificate+password` or `no_auth_tryed`
+- `SFTPGO_LOGIND_METHOD`, possible values are `publickey`, `password`, `keyboard-interactive`, `publickey+password`, `publickey+keyboard-interactive`, `TLSCertificate`, `TLSCertificate+password` or `no_auth_tryed`, `IDP` (external identity provider)
 - `SFTPGO_LOGIND_STATUS`, 1 means login OK, 0 login KO
 - `SFTPGO_LOGIND_STATUS`, 1 means login OK, 0 login KO
-- `SFTPGO_LOGIND_PROTOCOL`, possible values are `SSH`, `FTP`, `DAV`, `HTTP`
+- `SFTPGO_LOGIND_PROTOCOL`, possible values are `SSH`, `FTP`, `DAV`, `HTTP`, `OIDC` (OpenID Connect)
 
 
 Previous global environment variables aren't cleared when the script is called.
 Previous global environment variables aren't cleared when the script is called.
 The program must finish within 20 seconds.
 The program must finish within 20 seconds.

+ 4 - 5
go.mod

@@ -3,14 +3,14 @@ module github.com/drakkan/sftpgo/v2
 go 1.17
 go 1.17
 
 
 require (
 require (
-	cloud.google.com/go/storage v1.20.0
+	cloud.google.com/go/storage v1.21.0
 	github.com/Azure/azure-storage-blob-go v0.14.0
 	github.com/Azure/azure-storage-blob-go v0.14.0
 	github.com/GehirnInc/crypt v0.0.0-20200316065508-bb7000b8a962
 	github.com/GehirnInc/crypt v0.0.0-20200316065508-bb7000b8a962
 	github.com/alexedwards/argon2id v0.0.0-20211130144151-3585854a6387
 	github.com/alexedwards/argon2id v0.0.0-20211130144151-3585854a6387
-	github.com/aws/aws-sdk-go v1.43.0
+	github.com/aws/aws-sdk-go v1.43.2
 	github.com/cockroachdb/cockroach-go/v2 v2.2.8
 	github.com/cockroachdb/cockroach-go/v2 v2.2.8
 	github.com/coreos/go-oidc/v3 v3.1.0
 	github.com/coreos/go-oidc/v3 v3.1.0
-	github.com/eikenb/pipeat v0.0.0-20210603033007-44fc3ffce52b
+	github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001
 	github.com/fclairamb/ftpserverlib v0.17.1-0.20220212161409-5157f18d716f
 	github.com/fclairamb/ftpserverlib v0.17.1-0.20220212161409-5157f18d716f
 	github.com/fclairamb/go-log v0.2.0
 	github.com/fclairamb/go-log v0.2.0
 	github.com/go-chi/chi/v5 v5.0.8-0.20220103230436-7dbe9a0bd10f
 	github.com/go-chi/chi/v5 v5.0.8-0.20220103230436-7dbe9a0bd10f
@@ -128,7 +128,7 @@ require (
 	golang.org/x/tools v0.1.9 // indirect
 	golang.org/x/tools v0.1.9 // indirect
 	golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
 	golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
 	google.golang.org/appengine v1.6.7 // indirect
 	google.golang.org/appengine v1.6.7 // indirect
-	google.golang.org/genproto v0.0.0-20220217155828-d576998c0009 // indirect
+	google.golang.org/genproto v0.0.0-20220218161850-94dd64e39d7c // indirect
 	google.golang.org/grpc v1.44.0 // indirect
 	google.golang.org/grpc v1.44.0 // indirect
 	google.golang.org/protobuf v1.27.1 // indirect
 	google.golang.org/protobuf v1.27.1 // indirect
 	gopkg.in/ini.v1 v1.66.4 // indirect
 	gopkg.in/ini.v1 v1.66.4 // indirect
@@ -138,7 +138,6 @@ require (
 )
 )
 
 
 replace (
 replace (
-	github.com/eikenb/pipeat => github.com/drakkan/pipeat v0.0.0-20210805162858-70e57fa8a639
 	github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9
 	github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9
 	golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20220215181150-74469fa99b22
 	golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20220215181150-74469fa99b22
 	golang.org/x/net => github.com/drakkan/net v0.0.0-20220130095023-bd85f1236c34
 	golang.org/x/net => github.com/drakkan/net v0.0.0-20220130095023-bd85f1236c34

+ 9 - 8
go.sum

@@ -72,8 +72,8 @@ cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RX
 cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0=
 cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0=
 cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3fOKtUw0Xmo=
 cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3fOKtUw0Xmo=
 cloud.google.com/go/storage v1.16.1/go.mod h1:LaNorbty3ehnU3rEjXSNV/NRgQA0O8Y+uh6bPe5UOk4=
 cloud.google.com/go/storage v1.16.1/go.mod h1:LaNorbty3ehnU3rEjXSNV/NRgQA0O8Y+uh6bPe5UOk4=
-cloud.google.com/go/storage v1.20.0 h1:kv3rQ3clEQdxqokkCCgQo+bxPqcuXiROjxvnKb8Oqdk=
-cloud.google.com/go/storage v1.20.0/go.mod h1:TiC1o6FxNCG8y5gB7rqCsFZCIYPMPZCO81ppOoEPLGI=
+cloud.google.com/go/storage v1.21.0 h1:HwnT2u2D309SFDHQII6m18HlrCi3jAXhUMTLOWXYH14=
+cloud.google.com/go/storage v1.21.0/go.mod h1:XmRlxkgPjlBONznT2dDUU/5XlpU2OjMnKuqnZI01LAA=
 cloud.google.com/go/trace v0.1.0/go.mod h1:wxEwsoeRVPbeSkt7ZC9nWCgmoKQRAoySN7XHW2AmI7g=
 cloud.google.com/go/trace v0.1.0/go.mod h1:wxEwsoeRVPbeSkt7ZC9nWCgmoKQRAoySN7XHW2AmI7g=
 contrib.go.opencensus.io/exporter/aws v0.0.0-20200617204711-c478e41e60e9/go.mod h1:uu1P0UCM/6RbsMrgPa98ll8ZcHM858i/AD06a9aLRCA=
 contrib.go.opencensus.io/exporter/aws v0.0.0-20200617204711-c478e41e60e9/go.mod h1:uu1P0UCM/6RbsMrgPa98ll8ZcHM858i/AD06a9aLRCA=
 contrib.go.opencensus.io/exporter/stackdriver v0.13.8/go.mod h1:huNtlWx75MwO7qMs0KrMxPZXzNNWebav1Sq/pm02JdQ=
 contrib.go.opencensus.io/exporter/stackdriver v0.13.8/go.mod h1:huNtlWx75MwO7qMs0KrMxPZXzNNWebav1Sq/pm02JdQ=
@@ -143,8 +143,8 @@ github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgI
 github.com/aws/aws-sdk-go v1.15.27/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0=
 github.com/aws/aws-sdk-go v1.15.27/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0=
 github.com/aws/aws-sdk-go v1.37.0/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro=
 github.com/aws/aws-sdk-go v1.37.0/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro=
 github.com/aws/aws-sdk-go v1.40.34/go.mod h1:585smgzpB/KqRA+K3y/NL/oYRqQvpNJYvLm+LY1U59Q=
 github.com/aws/aws-sdk-go v1.40.34/go.mod h1:585smgzpB/KqRA+K3y/NL/oYRqQvpNJYvLm+LY1U59Q=
-github.com/aws/aws-sdk-go v1.43.0 h1:y4UrPbxU/mIL08qksVPE/nwH9IXuC1udjOaNyhEe+pI=
-github.com/aws/aws-sdk-go v1.43.0/go.mod h1:OGr6lGMAKGlG9CVrYnWYDKIyb829c6EVBRjxqjmPepc=
+github.com/aws/aws-sdk-go v1.43.2 h1:T6LuKCNu8CYXXDn3xJoldh8FbdvuVH7C9aSuLNrlht0=
+github.com/aws/aws-sdk-go v1.43.2/go.mod h1:OGr6lGMAKGlG9CVrYnWYDKIyb829c6EVBRjxqjmPepc=
 github.com/aws/aws-sdk-go-v2 v1.9.0/go.mod h1:cK/D0BBs0b/oWPIcX/Z/obahJK1TT7IPVjy53i/mX/4=
 github.com/aws/aws-sdk-go-v2 v1.9.0/go.mod h1:cK/D0BBs0b/oWPIcX/Z/obahJK1TT7IPVjy53i/mX/4=
 github.com/aws/aws-sdk-go-v2/config v1.7.0/go.mod h1:w9+nMZ7soXCe5nT46Ri354SNhXDQ6v+V5wqDjnZE+GY=
 github.com/aws/aws-sdk-go-v2/config v1.7.0/go.mod h1:w9+nMZ7soXCe5nT46Ri354SNhXDQ6v+V5wqDjnZE+GY=
 github.com/aws/aws-sdk-go-v2/credentials v1.4.0/go.mod h1:dgGR+Qq7Wjcd4AOAW5Rf5Tnv3+x7ed6kETXyS9WCuAY=
 github.com/aws/aws-sdk-go-v2/credentials v1.4.0/go.mod h1:dgGR+Qq7Wjcd4AOAW5Rf5Tnv3+x7ed6kETXyS9WCuAY=
@@ -224,8 +224,8 @@ github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9 h1:LPH1dEblAOO/LoG7yHP
 github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9/go.mod h1:2lmrmq866uF2tnje75wQHzmPXhmSWUt7Gyx2vgK1RCU=
 github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9/go.mod h1:2lmrmq866uF2tnje75wQHzmPXhmSWUt7Gyx2vgK1RCU=
 github.com/drakkan/net v0.0.0-20220130095023-bd85f1236c34 h1:DRayAKtBRaVU3jg58b/HCbkRleByBD5q6NkN1wcJ2RU=
 github.com/drakkan/net v0.0.0-20220130095023-bd85f1236c34 h1:DRayAKtBRaVU3jg58b/HCbkRleByBD5q6NkN1wcJ2RU=
 github.com/drakkan/net v0.0.0-20220130095023-bd85f1236c34/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
 github.com/drakkan/net v0.0.0-20220130095023-bd85f1236c34/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
-github.com/drakkan/pipeat v0.0.0-20210805162858-70e57fa8a639 h1:8tfGdb4kg/YCvAbIrsMazgoNtnqdOqQVDKW12uUCuuU=
-github.com/drakkan/pipeat v0.0.0-20210805162858-70e57fa8a639/go.mod h1:kltMsfRMTHSFdMbK66XdS8mfMW77+FZA1fGY1xYMF84=
+github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001 h1:/ZshrfQzayqRSBDodmp3rhNCHJCff+utvgBuWRbiqu4=
+github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001/go.mod h1:kltMsfRMTHSFdMbK66XdS8mfMW77+FZA1fGY1xYMF84=
 github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
 github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
 github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
 github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
 github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
 github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
@@ -1177,8 +1177,9 @@ google.golang.org/genproto v0.0.0-20220126215142-9970aeb2e350/go.mod h1:5CzLGKJ6
 google.golang.org/genproto v0.0.0-20220201184016-50beb8ab5c44/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
 google.golang.org/genproto v0.0.0-20220201184016-50beb8ab5c44/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
 google.golang.org/genproto v0.0.0-20220207164111-0872dc986b00/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
 google.golang.org/genproto v0.0.0-20220207164111-0872dc986b00/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
 google.golang.org/genproto v0.0.0-20220211171837-173942840c17/go.mod h1:kGP+zUP2Ddo0ayMi4YuN7C3WZyJvGLZRh8Z5wnAqvEI=
 google.golang.org/genproto v0.0.0-20220211171837-173942840c17/go.mod h1:kGP+zUP2Ddo0ayMi4YuN7C3WZyJvGLZRh8Z5wnAqvEI=
-google.golang.org/genproto v0.0.0-20220217155828-d576998c0009 h1:8QEZX8dJDqdCxQVLRWzEKGOkOzuDx0AU4+bQX6LwmU4=
-google.golang.org/genproto v0.0.0-20220217155828-d576998c0009/go.mod h1:kGP+zUP2Ddo0ayMi4YuN7C3WZyJvGLZRh8Z5wnAqvEI=
+google.golang.org/genproto v0.0.0-20220216160803-4663080d8bc8/go.mod h1:kGP+zUP2Ddo0ayMi4YuN7C3WZyJvGLZRh8Z5wnAqvEI=
+google.golang.org/genproto v0.0.0-20220218161850-94dd64e39d7c h1:TU4rFa5APdKTq0s6B7WTsH6Xmx0Knj86s6Biz56mErE=
+google.golang.org/genproto v0.0.0-20220218161850-94dd64e39d7c/go.mod h1:kGP+zUP2Ddo0ayMi4YuN7C3WZyJvGLZRh8Z5wnAqvEI=
 google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
 google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
 google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
 google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
 google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
 google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=

+ 4 - 3
httpd/api_http_user.go

@@ -32,13 +32,14 @@ func getUserConnection(w http.ResponseWriter, r *http.Request) (*Connection, err
 		return nil, err
 		return nil, err
 	}
 	}
 	connID := xid.New().String()
 	connID := xid.New().String()
-	connectionID := fmt.Sprintf("%v_%v", common.ProtocolHTTP, connID)
+	protocol := getProtocolFromRequest(r)
+	connectionID := fmt.Sprintf("%v_%v", protocol, connID)
 	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
 	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
 		sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
 		sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
 		return nil, err
 		return nil, err
 	}
 	}
 	connection := &Connection{
 	connection := &Connection{
-		BaseConnection: common.NewBaseConnection(connID, common.ProtocolHTTP, util.GetHTTPLocalAddress(r),
+		BaseConnection: common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r),
 			r.RemoteAddr, user),
 			r.RemoteAddr, user),
 		request: r,
 		request: r,
 	}
 	}
@@ -552,7 +553,7 @@ func doChangeUserPassword(r *http.Request, currentPassword, newPassword, confirm
 		return errors.New("invalid token claims")
 		return errors.New("invalid token claims")
 	}
 	}
 	user, err := dataprovider.CheckUserAndPass(claims.Username, currentPassword, util.GetIPFromRemoteAddress(r.RemoteAddr),
 	user, err := dataprovider.CheckUserAndPass(claims.Username, currentPassword, util.GetIPFromRemoteAddress(r.RemoteAddr),
-		common.ProtocolHTTP)
+		getProtocolFromRequest(r))
 	if err != nil {
 	if err != nil {
 		return util.NewValidationError("current password does not match")
 		return util.NewValidationError("current password does not match")
 	}
 	}

+ 20 - 6
httpd/api_utils.go

@@ -477,18 +477,25 @@ func parseRangeRequest(bytesRange string, size int64) (int64, int64, error) {
 	return start, size, err
 	return start, size, err
 }
 }
 
 
-func updateLoginMetrics(user *dataprovider.User, ip string, err error) {
-	metric.AddLoginAttempt(dataprovider.LoginMethodPassword)
+func updateLoginMetrics(user *dataprovider.User, loginMethod, ip string, err error) {
+	metric.AddLoginAttempt(loginMethod)
+	var protocol string
+	switch loginMethod {
+	case dataprovider.LoginMethodIDP:
+		protocol = common.ProtocolOIDC
+	default:
+		protocol = common.ProtocolHTTP
+	}
 	if err != nil && err != common.ErrInternalFailure && err != common.ErrNoCredentials {
 	if err != nil && err != common.ErrInternalFailure && err != common.ErrNoCredentials {
-		logger.ConnectionFailedLog(user.Username, ip, dataprovider.LoginMethodPassword, common.ProtocolHTTP, err.Error())
+		logger.ConnectionFailedLog(user.Username, ip, loginMethod, protocol, err.Error())
 		event := common.HostEventLoginFailed
 		event := common.HostEventLoginFailed
 		if _, ok := err.(*util.RecordNotFoundError); ok {
 		if _, ok := err.(*util.RecordNotFoundError); ok {
 			event = common.HostEventUserNotFound
 			event = common.HostEventUserNotFound
 		}
 		}
 		common.AddDefenderEvent(ip, event)
 		common.AddDefenderEvent(ip, event)
 	}
 	}
-	metric.AddLoginResult(dataprovider.LoginMethodPassword, err)
-	dataprovider.ExecutePostLoginHook(user, dataprovider.LoginMethodPassword, ip, common.ProtocolHTTP, err)
+	metric.AddLoginResult(loginMethod, err)
+	dataprovider.ExecutePostLoginHook(user, loginMethod, ip, protocol, err)
 }
 }
 
 
 func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID string) error {
 func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID string) error {
@@ -496,7 +503,7 @@ func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID
 		logger.Info(logSender, connectionID, "cannot login user %#v, protocol HTTP is not allowed", user.Username)
 		logger.Info(logSender, connectionID, "cannot login user %#v, protocol HTTP is not allowed", user.Username)
 		return fmt.Errorf("protocol HTTP is not allowed for user %#v", user.Username)
 		return fmt.Errorf("protocol HTTP is not allowed for user %#v", user.Username)
 	}
 	}
-	if !user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, nil) {
+	if !isLoggedInWithOIDC(r) && !user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, nil) {
 		logger.Info(logSender, connectionID, "cannot login user %#v, password login method is not allowed", user.Username)
 		logger.Info(logSender, connectionID, "cannot login user %#v, password login method is not allowed", user.Username)
 		return fmt.Errorf("login method password is not allowed for user %#v", user.Username)
 		return fmt.Errorf("login method password is not allowed for user %#v", user.Username)
 	}
 	}
@@ -635,3 +642,10 @@ func isUserAllowedToResetPassword(r *http.Request, user *dataprovider.User) bool
 	}
 	}
 	return true
 	return true
 }
 }
+
+func getProtocolFromRequest(r *http.Request) string {
+	if isLoggedInWithOIDC(r) {
+		return common.ProtocolOIDC
+	}
+	return common.ProtocolHTTP
+}

+ 13 - 10
httpd/middleware.go

@@ -376,31 +376,34 @@ func authenticateAdminWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTA
 
 
 func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAuth, r *http.Request) error {
 func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAuth, r *http.Request) error {
 	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
 	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
+	protocol := common.ProtocolHTTP
 	if username == "" {
 	if username == "" {
 		err := errors.New("the provided key is not associated with any user and no username was provided")
 		err := errors.New("the provided key is not associated with any user and no username was provided")
-		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, ipAddr, err)
+		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
+			dataprovider.LoginMethodPassword, ipAddr, err)
 		return err
 		return err
 	}
 	}
-	if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolHTTP); err != nil {
+	if err := common.Config.ExecutePostConnectHook(ipAddr, protocol); err != nil {
 		return err
 		return err
 	}
 	}
 	user, err := dataprovider.UserExists(username)
 	user, err := dataprovider.UserExists(username)
 	if err != nil {
 	if err != nil {
-		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, ipAddr, err)
+		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
+			dataprovider.LoginMethodPassword, ipAddr, err)
 		return err
 		return err
 	}
 	}
 	if !user.Filters.AllowAPIKeyAuth {
 	if !user.Filters.AllowAPIKeyAuth {
 		err := fmt.Errorf("API key authentication disabled for user %#v", user.Username)
 		err := fmt.Errorf("API key authentication disabled for user %#v", user.Username)
-		updateLoginMetrics(&user, ipAddr, err)
+		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
 		return err
 		return err
 	}
 	}
 	if err := user.CheckLoginConditions(); err != nil {
 	if err := user.CheckLoginConditions(); err != nil {
-		updateLoginMetrics(&user, ipAddr, err)
+		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
 		return err
 		return err
 	}
 	}
-	connectionID := fmt.Sprintf("%v_%v", common.ProtocolHTTP, xid.New().String())
+	connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
 	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
 	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
-		updateLoginMetrics(&user, ipAddr, err)
+		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
 		return err
 		return err
 	}
 	}
 	lastLogin := util.GetTimeFromMsecSinceEpoch(user.LastLogin)
 	lastLogin := util.GetTimeFromMsecSinceEpoch(user.LastLogin)
@@ -409,7 +412,7 @@ func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAu
 		defer user.CloseFs() //nolint:errcheck
 		defer user.CloseFs() //nolint:errcheck
 		err = user.CheckFsRoot(connectionID)
 		err = user.CheckFsRoot(connectionID)
 		if err != nil {
 		if err != nil {
-			updateLoginMetrics(&user, ipAddr, common.ErrInternalFailure)
+			updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
 			return common.ErrInternalFailure
 			return common.ErrInternalFailure
 		}
 		}
 	}
 	}
@@ -422,12 +425,12 @@ func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAu
 
 
 	resp, err := c.createTokenResponse(tokenAuth, tokenAudienceAPIUser)
 	resp, err := c.createTokenResponse(tokenAuth, tokenAudienceAPIUser)
 	if err != nil {
 	if err != nil {
-		updateLoginMetrics(&user, ipAddr, common.ErrInternalFailure)
+		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
 		return err
 		return err
 	}
 	}
 	r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", resp["access_token"]))
 	r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", resp["access_token"]))
 	dataprovider.UpdateLastLogin(&user)
 	dataprovider.UpdateLastLogin(&user)
-	updateLoginMetrics(&user, ipAddr, nil)
+	updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, nil)
 
 
 	return nil
 	return nil
 }
 }

+ 9 - 9
httpd/oidc.go

@@ -278,32 +278,32 @@ func (t *oidcToken) getUser(r *http.Request) error {
 		dataprovider.UpdateAdminLastLogin(&admin)
 		dataprovider.UpdateAdminLastLogin(&admin)
 		return nil
 		return nil
 	}
 	}
-	user, err := dataprovider.UserExists(t.Username)
+	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
+	user, err := dataprovider.GetUserAfterIDPAuth(t.Username, ipAddr, common.ProtocolOIDC)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
-	if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolHTTP); err != nil {
-		updateLoginMetrics(&user, ipAddr, err)
+	if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolOIDC); err != nil {
+		updateLoginMetrics(&user, dataprovider.LoginMethodIDP, ipAddr, err)
 		return fmt.Errorf("access denied by post connect hook: %w", err)
 		return fmt.Errorf("access denied by post connect hook: %w", err)
 	}
 	}
 	if err := user.CheckLoginConditions(); err != nil {
 	if err := user.CheckLoginConditions(); err != nil {
-		updateLoginMetrics(&user, ipAddr, err)
+		updateLoginMetrics(&user, dataprovider.LoginMethodIDP, ipAddr, err)
 		return err
 		return err
 	}
 	}
-	connectionID := fmt.Sprintf("%v_%v", common.ProtocolHTTP, xid.New().String())
+	connectionID := fmt.Sprintf("%v_%v", common.ProtocolOIDC, xid.New().String())
 	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
 	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
-		updateLoginMetrics(&user, ipAddr, err)
+		updateLoginMetrics(&user, dataprovider.LoginMethodIDP, ipAddr, err)
 		return err
 		return err
 	}
 	}
 	defer user.CloseFs() //nolint:errcheck
 	defer user.CloseFs() //nolint:errcheck
 	err = user.CheckFsRoot(connectionID)
 	err = user.CheckFsRoot(connectionID)
 	if err != nil {
 	if err != nil {
 		logger.Warn(logSender, connectionID, "unable to check fs root: %v", err)
 		logger.Warn(logSender, connectionID, "unable to check fs root: %v", err)
-		updateLoginMetrics(&user, ipAddr, common.ErrInternalFailure)
+		updateLoginMetrics(&user, dataprovider.LoginMethodIDP, ipAddr, common.ErrInternalFailure)
 		return err
 		return err
 	}
 	}
-	updateLoginMetrics(&user, ipAddr, nil)
+	updateLoginMetrics(&user, dataprovider.LoginMethodIDP, ipAddr, nil)
 	dataprovider.UpdateLastLogin(&user)
 	dataprovider.UpdateLastLogin(&user)
 	t.Permissions = user.Filters.WebClient
 	t.Permissions = user.Filters.WebClient
 	return nil
 	return nil

+ 141 - 17
httpd/oidc_test.go

@@ -2,12 +2,14 @@ package httpd
 
 
 import (
 import (
 	"context"
 	"context"
+	"encoding/json"
 	"fmt"
 	"fmt"
 	"net/http"
 	"net/http"
 	"net/http/httptest"
 	"net/http/httptest"
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
 	"reflect"
 	"reflect"
+	"runtime"
 	"testing"
 	"testing"
 	"time"
 	"time"
 	"unsafe"
 	"unsafe"
@@ -782,23 +784,6 @@ func TestOIDCToken(t *testing.T) {
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 }
 }
 
 
-func getTestOIDCServer() *httpdServer {
-	return &httpdServer{
-		binding: Binding{
-			OIDC: OIDC{
-				ClientID:        "sftpgo-client",
-				ClientSecret:    "jRsmE0SWnuZjP7djBqNq0mrf8QN77j2c",
-				ConfigURL:       fmt.Sprintf("http://%v/auth/realms/sftpgo", oidcMockAddr),
-				RedirectBaseURL: "http://127.0.0.1:8081/",
-				UsernameField:   "preferred_username",
-				RoleField:       "sftpgo_role",
-			},
-		},
-		enableWebAdmin:  true,
-		enableWebClient: true,
-	}
-}
-
 func TestOIDCManager(t *testing.T) {
 func TestOIDCManager(t *testing.T) {
 	require.Len(t, oidcMgr.pendingAuths, 0)
 	require.Len(t, oidcMgr.pendingAuths, 0)
 	authReq := newOIDCPendingAuth(tokenAudienceWebAdmin)
 	authReq := newOIDCPendingAuth(tokenAudienceWebAdmin)
@@ -881,3 +866,142 @@ func TestOIDCManager(t *testing.T) {
 	oidcMgr.removeToken(newToken.Cookie)
 	oidcMgr.removeToken(newToken.Cookie)
 	require.Len(t, oidcMgr.tokens, 0)
 	require.Len(t, oidcMgr.tokens, 0)
 }
 }
+
+func TestOIDCPreLoginHook(t *testing.T) {
+	if runtime.GOOS == osWindows {
+		t.Skip("this test is not available on Windows")
+	}
+	username := "test_oidc_user_prelogin"
+	u := dataprovider.User{
+		BaseUser: sdk.BaseUser{
+			Username: username,
+			Password: "unused",
+			HomeDir:  filepath.Join(os.TempDir(), username),
+			Status:   1,
+			Permissions: map[string][]string{
+				"/": {dataprovider.PermAny},
+			},
+		},
+	}
+	preLoginPath := filepath.Join(os.TempDir(), "prelogin.sh")
+	providerConf := dataprovider.GetProviderConfig()
+	err := dataprovider.Close()
+	assert.NoError(t, err)
+	err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm)
+	assert.NoError(t, err)
+	newProviderConf := providerConf
+	newProviderConf.PreLoginHook = preLoginPath
+	err = dataprovider.Initialize(newProviderConf, "..", true)
+	assert.NoError(t, err)
+	server := getTestOIDCServer()
+	err = server.binding.OIDC.initialize()
+	assert.NoError(t, err)
+	server.initializeRouter()
+
+	_, err = dataprovider.UserExists(username)
+	_, ok := err.(*util.RecordNotFoundError)
+	assert.True(t, ok)
+	// now login with OIDC
+	authReq := newOIDCPendingAuth(tokenAudienceWebClient)
+	oidcMgr.addPendingAuth(authReq)
+	token := &oauth2.Token{
+		AccessToken: "1234",
+		Expiry:      time.Now().Add(5 * time.Minute),
+	}
+	token = token.WithExtra(map[string]interface{}{
+		"id_token": "id_token_val",
+	})
+	server.binding.OIDC.oauth2Config = &mockOAuth2Config{
+		tokenSource: &mockTokenSource{},
+		authCodeURL: webOIDCRedirectPath,
+		token:       token,
+	}
+	idToken := &oidc.IDToken{
+		Nonce:  authReq.Nonce,
+		Expiry: time.Now().Add(5 * time.Minute),
+	}
+	setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+username+`"}`))
+	server.binding.OIDC.verifier = &mockOIDCVerifier{
+		err:   nil,
+		token: idToken,
+	}
+	rr := httptest.NewRecorder()
+	r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
+	assert.NoError(t, err)
+	server.router.ServeHTTP(rr, r)
+	assert.Equal(t, http.StatusFound, rr.Code)
+	assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
+	_, err = dataprovider.UserExists(username)
+	assert.NoError(t, err)
+
+	err = dataprovider.DeleteUser(username, "", "")
+	assert.NoError(t, err)
+
+	err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, true), os.ModePerm)
+	assert.NoError(t, err)
+
+	authReq = newOIDCPendingAuth(tokenAudienceWebClient)
+	oidcMgr.addPendingAuth(authReq)
+	idToken = &oidc.IDToken{
+		Nonce:  authReq.Nonce,
+		Expiry: time.Now().Add(5 * time.Minute),
+	}
+	setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+username+`"}`))
+	server.binding.OIDC.verifier = &mockOIDCVerifier{
+		err:   nil,
+		token: idToken,
+	}
+	rr = httptest.NewRecorder()
+	r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
+	assert.NoError(t, err)
+	server.router.ServeHTTP(rr, r)
+	assert.Equal(t, http.StatusFound, rr.Code)
+	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
+	_, err = dataprovider.UserExists(username)
+	_, ok = err.(*util.RecordNotFoundError)
+	assert.True(t, ok)
+	if assert.Len(t, oidcMgr.tokens, 1) {
+		for k := range oidcMgr.tokens {
+			oidcMgr.removeToken(k)
+		}
+	}
+	require.Len(t, oidcMgr.pendingAuths, 0)
+	require.Len(t, oidcMgr.tokens, 0)
+
+	err = dataprovider.Close()
+	assert.NoError(t, err)
+	err = dataprovider.Initialize(providerConf, "..", true)
+	assert.NoError(t, err)
+	err = os.Remove(preLoginPath)
+	assert.NoError(t, err)
+}
+
+func getTestOIDCServer() *httpdServer {
+	return &httpdServer{
+		binding: Binding{
+			OIDC: OIDC{
+				ClientID:        "sftpgo-client",
+				ClientSecret:    "jRsmE0SWnuZjP7djBqNq0mrf8QN77j2c",
+				ConfigURL:       fmt.Sprintf("http://%v/auth/realms/sftpgo", oidcMockAddr),
+				RedirectBaseURL: "http://127.0.0.1:8081/",
+				UsernameField:   "preferred_username",
+				RoleField:       "sftpgo_role",
+			},
+		},
+		enableWebAdmin:  true,
+		enableWebClient: true,
+	}
+}
+
+func getPreLoginScriptContent(user dataprovider.User, nonJSONResponse bool) []byte {
+	content := []byte("#!/bin/sh\n\n")
+	if nonJSONResponse {
+		content = append(content, []byte("echo 'text response'\n")...)
+		return content
+	}
+	if len(user.Username) > 0 {
+		u, _ := json.Marshal(user)
+		content = append(content, []byte(fmt.Sprintf("echo '%v'\n", string(u)))...)
+	}
+	return content
+}

+ 32 - 26
httpd/server.go

@@ -199,33 +199,36 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
 		return
 		return
 	}
 	}
 	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
 	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
+	protocol := common.ProtocolHTTP
 	username := r.Form.Get("username")
 	username := r.Form.Get("username")
 	password := r.Form.Get("password")
 	password := r.Form.Get("password")
 	if username == "" || password == "" {
 	if username == "" || password == "" {
-		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, ipAddr, common.ErrNoCredentials)
+		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
+			dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials)
 		s.renderClientLoginPage(w, "Invalid credentials")
 		s.renderClientLoginPage(w, "Invalid credentials")
 		return
 		return
 	}
 	}
 	if err := verifyCSRFToken(r.Form.Get(csrfFormToken)); err != nil {
 	if err := verifyCSRFToken(r.Form.Get(csrfFormToken)); err != nil {
-		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, ipAddr, err)
+		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
+			dataprovider.LoginMethodPassword, ipAddr, err)
 		s.renderClientLoginPage(w, err.Error())
 		s.renderClientLoginPage(w, err.Error())
 		return
 		return
 	}
 	}
 
 
-	if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolHTTP); err != nil {
+	if err := common.Config.ExecutePostConnectHook(ipAddr, protocol); err != nil {
 		s.renderClientLoginPage(w, fmt.Sprintf("access denied by post connect hook: %v", err))
 		s.renderClientLoginPage(w, fmt.Sprintf("access denied by post connect hook: %v", err))
 		return
 		return
 	}
 	}
 
 
-	user, err := dataprovider.CheckUserAndPass(username, password, ipAddr, common.ProtocolHTTP)
+	user, err := dataprovider.CheckUserAndPass(username, password, ipAddr, protocol)
 	if err != nil {
 	if err != nil {
-		updateLoginMetrics(&user, ipAddr, err)
+		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
 		s.renderClientLoginPage(w, dataprovider.ErrInvalidCredentials.Error())
 		s.renderClientLoginPage(w, dataprovider.ErrInvalidCredentials.Error())
 		return
 		return
 	}
 	}
-	connectionID := fmt.Sprintf("%v_%v", common.ProtocolHTTP, xid.New().String())
+	connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
 	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
 	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
-		updateLoginMetrics(&user, ipAddr, err)
+		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
 		s.renderClientLoginPage(w, err.Error())
 		s.renderClientLoginPage(w, err.Error())
 		return
 		return
 	}
 	}
@@ -234,7 +237,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
 	err = user.CheckFsRoot(connectionID)
 	err = user.CheckFsRoot(connectionID)
 	if err != nil {
 	if err != nil {
 		logger.Warn(logSender, connectionID, "unable to check fs root: %v", err)
 		logger.Warn(logSender, connectionID, "unable to check fs root: %v", err)
-		updateLoginMetrics(&user, ipAddr, common.ErrInternalFailure)
+		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
 		s.renderClientLoginPage(w, err.Error())
 		s.renderClientLoginPage(w, err.Error())
 		return
 		return
 	}
 	}
@@ -261,7 +264,7 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r
 		renderClientResetPwdPage(w, err.Error())
 		renderClientResetPwdPage(w, err.Error())
 		return
 		return
 	}
 	}
-	connectionID := fmt.Sprintf("%v_%v", common.ProtocolHTTP, xid.New().String())
+	connectionID := fmt.Sprintf("%v_%v", getProtocolFromRequest(r), xid.New().String())
 	if err := checkHTTPClientUser(user, r, connectionID); err != nil {
 	if err := checkHTTPClientUser(user, r, connectionID); err != nil {
 		renderClientResetPwdPage(w, fmt.Sprintf("Password reset successfully but unable to login: %v", err.Error()))
 		renderClientResetPwdPage(w, fmt.Sprintf("Password reset successfully but unable to login: %v", err.Error()))
 		return
 		return
@@ -325,7 +328,7 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter
 				renderClientInternalServerErrorPage(w, r, errors.New("unable to set the recovery code as used"))
 				renderClientInternalServerErrorPage(w, r, errors.New("unable to set the recovery code as used"))
 				return
 				return
 			}
 			}
-			connectionID := fmt.Sprintf("%v_%v", common.ProtocolHTTP, xid.New().String())
+			connectionID := fmt.Sprintf("%v_%v", getProtocolFromRequest(r), xid.New().String())
 			s.loginUser(w, r, &user, connectionID, util.GetIPFromRemoteAddress(r.RemoteAddr), true,
 			s.loginUser(w, r, &user, connectionID, util.GetIPFromRemoteAddress(r.RemoteAddr), true,
 				renderClientTwoFactorRecoveryPage)
 				renderClientTwoFactorRecoveryPage)
 			return
 			return
@@ -375,7 +378,7 @@ func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *htt
 		renderClientTwoFactorPage(w, "Invalid authentication code")
 		renderClientTwoFactorPage(w, "Invalid authentication code")
 		return
 		return
 	}
 	}
-	connectionID := fmt.Sprintf("%v_%v", common.ProtocolHTTP, xid.New().String())
+	connectionID := fmt.Sprintf("%v_%v", getProtocolFromRequest(r), xid.New().String())
 	s.loginUser(w, r, &user, connectionID, util.GetIPFromRemoteAddress(r.RemoteAddr), true, renderClientTwoFactorPage)
 	s.loginUser(w, r, &user, connectionID, util.GetIPFromRemoteAddress(r.RemoteAddr), true, renderClientTwoFactorPage)
 }
 }
 
 
@@ -646,7 +649,7 @@ func (s *httpdServer) loginUser(
 	err := c.createAndSetCookie(w, r, s.tokenAuth, audience)
 	err := c.createAndSetCookie(w, r, s.tokenAuth, audience)
 	if err != nil {
 	if err != nil {
 		logger.Warn(logSender, connectionID, "unable to set user login cookie %v", err)
 		logger.Warn(logSender, connectionID, "unable to set user login cookie %v", err)
-		updateLoginMetrics(user, ipAddr, common.ErrInternalFailure)
+		updateLoginMetrics(user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
 		errorFunc(w, err.Error())
 		errorFunc(w, err.Error())
 		return
 		return
 	}
 	}
@@ -657,7 +660,7 @@ func (s *httpdServer) loginUser(
 		http.Redirect(w, r, webClientTwoFactorPath, http.StatusFound)
 		http.Redirect(w, r, webClientTwoFactorPath, http.StatusFound)
 		return
 		return
 	}
 	}
-	updateLoginMetrics(user, ipAddr, err)
+	updateLoginMetrics(user, dataprovider.LoginMethodPassword, ipAddr, err)
 	dataprovider.UpdateLastLogin(user)
 	dataprovider.UpdateLastLogin(user)
 	http.Redirect(w, r, webClientFilesPath, http.StatusFound)
 	http.Redirect(w, r, webClientFilesPath, http.StatusFound)
 }
 }
@@ -708,33 +711,36 @@ func (s *httpdServer) getUserToken(w http.ResponseWriter, r *http.Request) {
 	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
 	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
 	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
 	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
 	username, password, ok := r.BasicAuth()
 	username, password, ok := r.BasicAuth()
+	protocol := common.ProtocolHTTP
 	if !ok {
 	if !ok {
-		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, ipAddr, common.ErrNoCredentials)
+		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
+			dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials)
 		w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
 		w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
 		sendAPIResponse(w, r, nil, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
 		sendAPIResponse(w, r, nil, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
 		return
 		return
 	}
 	}
 	if username == "" || password == "" {
 	if username == "" || password == "" {
-		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, ipAddr, common.ErrNoCredentials)
+		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
+			dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials)
 		w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
 		w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
 		sendAPIResponse(w, r, nil, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
 		sendAPIResponse(w, r, nil, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
 		return
 		return
 	}
 	}
-	if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolHTTP); err != nil {
+	if err := common.Config.ExecutePostConnectHook(ipAddr, protocol); err != nil {
 		sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
 		sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
 		return
 		return
 	}
 	}
-	user, err := dataprovider.CheckUserAndPass(username, password, ipAddr, common.ProtocolHTTP)
+	user, err := dataprovider.CheckUserAndPass(username, password, ipAddr, protocol)
 	if err != nil {
 	if err != nil {
 		w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
 		w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
-		updateLoginMetrics(&user, ipAddr, err)
+		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
 		sendAPIResponse(w, r, dataprovider.ErrInvalidCredentials, http.StatusText(http.StatusUnauthorized),
 		sendAPIResponse(w, r, dataprovider.ErrInvalidCredentials, http.StatusText(http.StatusUnauthorized),
 			http.StatusUnauthorized)
 			http.StatusUnauthorized)
 		return
 		return
 	}
 	}
-	connectionID := fmt.Sprintf("%v_%v", common.ProtocolHTTP, xid.New().String())
+	connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
 	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
 	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
-		updateLoginMetrics(&user, ipAddr, err)
+		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
 		sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
 		sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
 		return
 		return
 	}
 	}
@@ -744,14 +750,14 @@ func (s *httpdServer) getUserToken(w http.ResponseWriter, r *http.Request) {
 		if passcode == "" {
 		if passcode == "" {
 			logger.Debug(logSender, "", "TOTP enabled for user %#v and not passcode provided, authentication refused", user.Username)
 			logger.Debug(logSender, "", "TOTP enabled for user %#v and not passcode provided, authentication refused", user.Username)
 			w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
 			w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
-			updateLoginMetrics(&user, ipAddr, dataprovider.ErrInvalidCredentials)
+			updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, dataprovider.ErrInvalidCredentials)
 			sendAPIResponse(w, r, dataprovider.ErrInvalidCredentials, http.StatusText(http.StatusUnauthorized),
 			sendAPIResponse(w, r, dataprovider.ErrInvalidCredentials, http.StatusText(http.StatusUnauthorized),
 				http.StatusUnauthorized)
 				http.StatusUnauthorized)
 			return
 			return
 		}
 		}
 		err = user.Filters.TOTPConfig.Secret.Decrypt()
 		err = user.Filters.TOTPConfig.Secret.Decrypt()
 		if err != nil {
 		if err != nil {
-			updateLoginMetrics(&user, ipAddr, common.ErrInternalFailure)
+			updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
 			sendAPIResponse(w, r, fmt.Errorf("unable to decrypt TOTP secret: %w", err), http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
 			sendAPIResponse(w, r, fmt.Errorf("unable to decrypt TOTP secret: %w", err), http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
 			return
 			return
 		}
 		}
@@ -760,7 +766,7 @@ func (s *httpdServer) getUserToken(w http.ResponseWriter, r *http.Request) {
 		if !match || err != nil {
 		if !match || err != nil {
 			logger.Debug(logSender, "invalid passcode for user %#v, match? %v, err: %v", user.Username, match, err)
 			logger.Debug(logSender, "invalid passcode for user %#v, match? %v, err: %v", user.Username, match, err)
 			w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
 			w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
-			updateLoginMetrics(&user, ipAddr, dataprovider.ErrInvalidCredentials)
+			updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, dataprovider.ErrInvalidCredentials)
 			sendAPIResponse(w, r, dataprovider.ErrInvalidCredentials, http.StatusText(http.StatusUnauthorized),
 			sendAPIResponse(w, r, dataprovider.ErrInvalidCredentials, http.StatusText(http.StatusUnauthorized),
 				http.StatusUnauthorized)
 				http.StatusUnauthorized)
 			return
 			return
@@ -771,7 +777,7 @@ func (s *httpdServer) getUserToken(w http.ResponseWriter, r *http.Request) {
 	err = user.CheckFsRoot(connectionID)
 	err = user.CheckFsRoot(connectionID)
 	if err != nil {
 	if err != nil {
 		logger.Warn(logSender, connectionID, "unable to check fs root: %v", err)
 		logger.Warn(logSender, connectionID, "unable to check fs root: %v", err)
-		updateLoginMetrics(&user, ipAddr, common.ErrInternalFailure)
+		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
 		sendAPIResponse(w, r, err, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
 		sendAPIResponse(w, r, err, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
 		return
 		return
 	}
 	}
@@ -789,11 +795,11 @@ func (s *httpdServer) generateAndSendUserToken(w http.ResponseWriter, r *http.Re
 	resp, err := c.createTokenResponse(s.tokenAuth, tokenAudienceAPIUser)
 	resp, err := c.createTokenResponse(s.tokenAuth, tokenAudienceAPIUser)
 
 
 	if err != nil {
 	if err != nil {
-		updateLoginMetrics(&user, ipAddr, common.ErrInternalFailure)
+		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
 		sendAPIResponse(w, r, err, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
 		sendAPIResponse(w, r, err, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
 		return
 		return
 	}
 	}
-	updateLoginMetrics(&user, ipAddr, err)
+	updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
 	dataprovider.UpdateLastLogin(&user)
 	dataprovider.UpdateLastLogin(&user)
 
 
 	render.JSON(w, r, resp)
 	render.JSON(w, r, resp)

+ 12 - 8
httpd/webclient.go

@@ -584,13 +584,14 @@ func handleWebClientDownloadZip(w http.ResponseWriter, r *http.Request) {
 	}
 	}
 
 
 	connID := xid.New().String()
 	connID := xid.New().String()
-	connectionID := fmt.Sprintf("%v_%v", common.ProtocolHTTP, connID)
+	protocol := getProtocolFromRequest(r)
+	connectionID := fmt.Sprintf("%v_%v", protocol, connID)
 	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
 	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
 		renderClientForbiddenPage(w, r, err.Error())
 		renderClientForbiddenPage(w, r, err.Error())
 		return
 		return
 	}
 	}
 	connection := &Connection{
 	connection := &Connection{
-		BaseConnection: common.NewBaseConnection(connID, common.ProtocolHTTP, util.GetHTTPLocalAddress(r),
+		BaseConnection: common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r),
 			r.RemoteAddr, user),
 			r.RemoteAddr, user),
 		request: r,
 		request: r,
 	}
 	}
@@ -727,13 +728,14 @@ func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http.
 	}
 	}
 
 
 	connID := xid.New().String()
 	connID := xid.New().String()
-	connectionID := fmt.Sprintf("%v_%v", common.ProtocolHTTP, connID)
+	protocol := getProtocolFromRequest(r)
+	connectionID := fmt.Sprintf("%v_%v", protocol, connID)
 	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
 	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
 		sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
 		sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
 		return
 		return
 	}
 	}
 	connection := &Connection{
 	connection := &Connection{
-		BaseConnection: common.NewBaseConnection(connID, common.ProtocolHTTP, util.GetHTTPLocalAddress(r),
+		BaseConnection: common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r),
 			r.RemoteAddr, user),
 			r.RemoteAddr, user),
 		request: r,
 		request: r,
 	}
 	}
@@ -804,13 +806,14 @@ func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Reques
 	}
 	}
 
 
 	connID := xid.New().String()
 	connID := xid.New().String()
-	connectionID := fmt.Sprintf("%v_%v", common.ProtocolHTTP, connID)
+	protocol := getProtocolFromRequest(r)
+	connectionID := fmt.Sprintf("%v_%v", protocol, connID)
 	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
 	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
 		renderClientForbiddenPage(w, r, err.Error())
 		renderClientForbiddenPage(w, r, err.Error())
 		return
 		return
 	}
 	}
 	connection := &Connection{
 	connection := &Connection{
-		BaseConnection: common.NewBaseConnection(connID, common.ProtocolHTTP, util.GetHTTPLocalAddress(r),
+		BaseConnection: common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r),
 			r.RemoteAddr, user),
 			r.RemoteAddr, user),
 		request: r,
 		request: r,
 	}
 	}
@@ -863,13 +866,14 @@ func handleClientEditFile(w http.ResponseWriter, r *http.Request) {
 	}
 	}
 
 
 	connID := xid.New().String()
 	connID := xid.New().String()
-	connectionID := fmt.Sprintf("%v_%v", common.ProtocolHTTP, connID)
+	protocol := getProtocolFromRequest(r)
+	connectionID := fmt.Sprintf("%v_%v", protocol, connID)
 	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
 	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
 		renderClientForbiddenPage(w, r, err.Error())
 		renderClientForbiddenPage(w, r, err.Error())
 		return
 		return
 	}
 	}
 	connection := &Connection{
 	connection := &Connection{
-		BaseConnection: common.NewBaseConnection(connID, common.ProtocolHTTP, util.GetHTTPLocalAddress(r),
+		BaseConnection: common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r),
 			r.RemoteAddr, user),
 			r.RemoteAddr, user),
 		request: r,
 		request: r,
 	}
 	}

+ 33 - 1
metric/metric.go

@@ -20,6 +20,7 @@ const (
 	loginMethodKeyAndKeyboardInt    = "publickey+keyboard-interactive"
 	loginMethodKeyAndKeyboardInt    = "publickey+keyboard-interactive"
 	loginMethodTLSCertificate       = "TLSCertificate"
 	loginMethodTLSCertificate       = "TLSCertificate"
 	loginMethodTLSCertificateAndPwd = "TLSCertificate+password"
 	loginMethodTLSCertificateAndPwd = "TLSCertificate+password"
+	loginMethodIDP                  = "IDP"
 )
 )
 
 
 func init() {
 func init() {
@@ -259,6 +260,27 @@ var (
 		Help: "The total number of failed logins using  public key + keyboard interactive",
 		Help: "The total number of failed logins using  public key + keyboard interactive",
 	})
 	})
 
 
+	// totalIDPLoginAttempts is the metric that reports the total number of
+	// login attempts using identity providers
+	totalIDPLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{
+		Name: "sftpgo_idp_login_attempts_total",
+		Help: "The total number of login attempts using Identity Providers",
+	})
+
+	// totalIDPLoginOK is the metric that reports the total number of
+	// successful logins using identity providers
+	totalIDPLoginOK = promauto.NewCounter(prometheus.CounterOpts{
+		Name: "sftpgo_idp_login_ok_total",
+		Help: "The total number of successful logins using Identity Providers",
+	})
+
+	// totalIDPLoginFailed is the metric that reports the total number of
+	// failed logins using identity providers
+	totalIDPLoginFailed = promauto.NewCounter(prometheus.CounterOpts{
+		Name: "sftpgo_idp_login_ko_total",
+		Help: "The total number of failed logins using  Identity Providers",
+	})
+
 	totalHTTPRequests = promauto.NewCounter(prometheus.CounterOpts{
 	totalHTTPRequests = promauto.NewCounter(prometheus.CounterOpts{
 		Name: "sftpgo_http_req_total",
 		Name: "sftpgo_http_req_total",
 		Help: "The total number of HTTP requests served",
 		Help: "The total number of HTTP requests served",
@@ -582,7 +604,6 @@ func TransferCompleted(bytesSent, bytesReceived int64, transferKind int, err err
 		} else {
 		} else {
 			totalUploadErrors.Inc()
 			totalUploadErrors.Inc()
 		}
 		}
-		totalUploadSize.Add(float64(bytesReceived))
 	} else {
 	} else {
 		// download
 		// download
 		if err == nil {
 		if err == nil {
@@ -590,6 +611,11 @@ func TransferCompleted(bytesSent, bytesReceived int64, transferKind int, err err
 		} else {
 		} else {
 			totalDownloadErrors.Inc()
 			totalDownloadErrors.Inc()
 		}
 		}
+	}
+	if bytesReceived > 0 {
+		totalUploadSize.Add(float64(bytesReceived))
+	}
+	if bytesSent > 0 {
 		totalDownloadSize.Add(float64(bytesSent))
 		totalDownloadSize.Add(float64(bytesSent))
 	}
 	}
 }
 }
@@ -826,6 +852,8 @@ func AddLoginAttempt(authMethod string) {
 		totalTLSCertLoginAttempts.Inc()
 		totalTLSCertLoginAttempts.Inc()
 	case loginMethodTLSCertificateAndPwd:
 	case loginMethodTLSCertificateAndPwd:
 		totalTLSCertAndPwdLoginAttempts.Inc()
 		totalTLSCertAndPwdLoginAttempts.Inc()
+	case loginMethodIDP:
+		totalIDPLoginAttempts.Inc()
 	default:
 	default:
 		totalPasswordLoginAttempts.Inc()
 		totalPasswordLoginAttempts.Inc()
 	}
 	}
@@ -846,6 +874,8 @@ func incLoginOK(authMethod string) {
 		totalTLSCertLoginOK.Inc()
 		totalTLSCertLoginOK.Inc()
 	case loginMethodTLSCertificateAndPwd:
 	case loginMethodTLSCertificateAndPwd:
 		totalTLSCertAndPwdLoginOK.Inc()
 		totalTLSCertAndPwdLoginOK.Inc()
+	case loginMethodIDP:
+		totalIDPLoginOK.Inc()
 	default:
 	default:
 		totalPasswordLoginOK.Inc()
 		totalPasswordLoginOK.Inc()
 	}
 	}
@@ -866,6 +896,8 @@ func incLoginFailed(authMethod string) {
 		totalTLSCertLoginFailed.Inc()
 		totalTLSCertLoginFailed.Inc()
 	case loginMethodTLSCertificateAndPwd:
 	case loginMethodTLSCertificateAndPwd:
 		totalTLSCertAndPwdLoginFailed.Inc()
 		totalTLSCertAndPwdLoginFailed.Inc()
+	case loginMethodIDP:
+		totalIDPLoginFailed.Inc()
 	default:
 	default:
 		totalPasswordLoginFailed.Inc()
 		totalPasswordLoginFailed.Inc()
 	}
 	}

+ 2 - 0
openapi/openapi.yaml

@@ -4356,6 +4356,7 @@ components:
         - DAV
         - DAV
         - HTTP
         - HTTP
         - DataRetention
         - DataRetention
+        - OIDC
       description: |
       description: |
         Protocols:
         Protocols:
           * `SSH` - SSH commands
           * `SSH` - SSH commands
@@ -4364,6 +4365,7 @@ components:
           * `DAV` - WebDAV
           * `DAV` - WebDAV
           * `HTTP` - WebClient/REST API
           * `HTTP` - WebClient/REST API
           * `DataRetention` - the event is generated by a data retention check
           * `DataRetention` - the event is generated by a data retention check
+          * `OIDC` - OpenID Connect
     WebClientOptions:
     WebClientOptions:
       type: string
       type: string
       enum:
       enum: