Pārlūkot izejas kodu

Show info about public key during login

This will show the key fingerprint and the associated comment, or
"password" when password was used, during login.

Eg.:

```
message":"User id: 1, logged in with: \"public_key:SHA256:FV3+wlAKGzYy7+J02786fh8N8c06+jga/mdiSOSPT7g:jo@desktop\",
```

or

```
message":"User id: 1, logged in with: \"password\",
...`

Signed-off-by: Jo Vandeginste <Jo.Vandeginste@kuleuven.be>
Jo Vandeginste 6 gadi atpakaļ
vecāks
revīzija
25260297aa

+ 3 - 3
dataprovider/bolt.go

@@ -71,15 +71,15 @@ func (p BoltProvider) validateUserAndPass(username string, password string) (Use
 	return checkUserAndPass(user, password)
 	return checkUserAndPass(user, password)
 }
 }
 
 
-func (p BoltProvider) validateUserAndPubKey(username string, pubKey string) (User, error) {
+func (p BoltProvider) validateUserAndPubKey(username string, pubKey string) (User, string, error) {
 	var user User
 	var user User
 	if len(pubKey) == 0 {
 	if len(pubKey) == 0 {
-		return user, errors.New("Credentials cannot be null or empty")
+		return user, "", errors.New("Credentials cannot be null or empty")
 	}
 	}
 	user, err := p.userExists(username)
 	user, err := p.userExists(username)
 	if err != nil {
 	if err != nil {
 		logger.Warn(logSender, "", "error authenticating user: %v, error: %v", username, err)
 		logger.Warn(logSender, "", "error authenticating user: %v, error: %v", username, err)
-		return user, err
+		return user, "", err
 	}
 	}
 	return checkUserAndPubKey(user, pubKey)
 	return checkUserAndPubKey(user, pubKey)
 }
 }

+ 9 - 8
dataprovider/dataprovider.go

@@ -132,7 +132,7 @@ func GetProvider() Provider {
 // Provider interface that data providers must implement.
 // Provider interface that data providers must implement.
 type Provider interface {
 type Provider interface {
 	validateUserAndPass(username string, password string) (User, error)
 	validateUserAndPass(username string, password string) (User, error)
-	validateUserAndPubKey(username string, pubKey string) (User, error)
+	validateUserAndPubKey(username string, pubKey string) (User, string, error)
 	updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error
 	updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error
 	getUsedQuota(username string) (int, int64, error)
 	getUsedQuota(username string) (int, int64, error)
 	userExists(username string) (User, error)
 	userExists(username string) (User, error)
@@ -166,7 +166,7 @@ func CheckUserAndPass(p Provider, username string, password string) (User, error
 }
 }
 
 
 // CheckUserAndPubKey retrieves the SFTP user with the given username and public key if a match is found or an error
 // CheckUserAndPubKey retrieves the SFTP user with the given username and public key if a match is found or an error
-func CheckUserAndPubKey(p Provider, username string, pubKey string) (User, error) {
+func CheckUserAndPubKey(p Provider, username string, pubKey string) (User, string, error) {
 	return p.validateUserAndPubKey(username, pubKey)
 	return p.validateUserAndPubKey(username, pubKey)
 }
 }
 
 
@@ -297,21 +297,22 @@ func checkUserAndPass(user User, password string) (User, error) {
 	return user, err
 	return user, err
 }
 }
 
 
-func checkUserAndPubKey(user User, pubKey string) (User, error) {
+func checkUserAndPubKey(user User, pubKey string) (User, string, error) {
 	if len(user.PublicKeys) == 0 {
 	if len(user.PublicKeys) == 0 {
-		return user, errors.New("Invalid credentials")
+		return user, "", errors.New("Invalid credentials")
 	}
 	}
 	for i, k := range user.PublicKeys {
 	for i, k := range user.PublicKeys {
-		storedPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(k))
+		storedPubKey, comment, _, _, err := ssh.ParseAuthorizedKey([]byte(k))
 		if err != nil {
 		if err != nil {
 			logger.Warn(logSender, "", "error parsing stored public key %d for user %v: %v", i, user.Username, err)
 			logger.Warn(logSender, "", "error parsing stored public key %d for user %v: %v", i, user.Username, err)
-			return user, err
+			return user, "", err
 		}
 		}
 		if string(storedPubKey.Marshal()) == pubKey {
 		if string(storedPubKey.Marshal()) == pubKey {
-			return user, nil
+			fp := ssh.FingerprintSHA256(storedPubKey)
+			return user, fp + ":" + comment, nil
 		}
 		}
 	}
 	}
-	return user, errors.New("Invalid credentials")
+	return user, "", errors.New("Invalid credentials")
 }
 }
 
 
 func comparePbkdf2PasswordAndHash(password, hashedPassword string) (bool, error) {
 func comparePbkdf2PasswordAndHash(password, hashedPassword string) (bool, error) {

+ 1 - 1
dataprovider/mysql.go

@@ -41,7 +41,7 @@ func (p MySQLProvider) validateUserAndPass(username string, password string) (Us
 	return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
 	return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
 }
 }
 
 
-func (p MySQLProvider) validateUserAndPubKey(username string, publicKey string) (User, error) {
+func (p MySQLProvider) validateUserAndPubKey(username string, publicKey string) (User, string, error) {
 	return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
 	return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
 }
 }
 
 

+ 1 - 1
dataprovider/pgsql.go

@@ -39,7 +39,7 @@ func (p PGSQLProvider) validateUserAndPass(username string, password string) (Us
 	return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
 	return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
 }
 }
 
 
-func (p PGSQLProvider) validateUserAndPubKey(username string, publicKey string) (User, error) {
+func (p PGSQLProvider) validateUserAndPubKey(username string, publicKey string) (User, string, error) {
 	return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
 	return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
 }
 }
 
 

+ 3 - 3
dataprovider/sqlcommon.go

@@ -37,15 +37,15 @@ func sqlCommonValidateUserAndPass(username string, password string, dbHandle *sq
 	return checkUserAndPass(user, password)
 	return checkUserAndPass(user, password)
 }
 }
 
 
-func sqlCommonValidateUserAndPubKey(username string, pubKey string, dbHandle *sql.DB) (User, error) {
+func sqlCommonValidateUserAndPubKey(username string, pubKey string, dbHandle *sql.DB) (User, string, error) {
 	var user User
 	var user User
 	if len(pubKey) == 0 {
 	if len(pubKey) == 0 {
-		return user, errors.New("Credentials cannot be null or empty")
+		return user, "", errors.New("Credentials cannot be null or empty")
 	}
 	}
 	user, err := getUserByUsername(username, dbHandle)
 	user, err := getUserByUsername(username, dbHandle)
 	if err != nil {
 	if err != nil {
 		logger.Warn(logSender, "", "error authenticating user: %v, error: %v", username, err)
 		logger.Warn(logSender, "", "error authenticating user: %v, error: %v", username, err)
-		return user, err
+		return user, "", err
 	}
 	}
 	return checkUserAndPubKey(user, pubKey)
 	return checkUserAndPubKey(user, pubKey)
 }
 }

+ 1 - 1
dataprovider/sqlite.go

@@ -52,7 +52,7 @@ func (p SQLiteProvider) validateUserAndPass(username string, password string) (U
 	return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
 	return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
 }
 }
 
 
-func (p SQLiteProvider) validateUserAndPubKey(username string, publicKey string) (User, error) {
+func (p SQLiteProvider) validateUserAndPubKey(username string, publicKey string) (User, string, error) {
 	return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
 	return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
 }
 }
 
 

+ 1 - 1
sftpd/internal_test.go

@@ -144,7 +144,7 @@ func TestUploadFiles(t *testing.T) {
 func TestWithInvalidHome(t *testing.T) {
 func TestWithInvalidHome(t *testing.T) {
 	u := dataprovider.User{}
 	u := dataprovider.User{}
 	u.HomeDir = "home_rel_path"
 	u.HomeDir = "home_rel_path"
-	_, err := loginUser(u)
+	_, err := loginUser(u, "password")
 	if err == nil {
 	if err == nil {
 		t.Errorf("login a user with an invalid home_dir must fail")
 		t.Errorf("login a user with an invalid home_dir must fail")
 	}
 	}

+ 10 - 6
sftpd/server.go

@@ -210,6 +210,7 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
 	logger.Debug(logSender, "", "accepted inbound connection, ip: %v", conn.RemoteAddr().String())
 	logger.Debug(logSender, "", "accepted inbound connection, ip: %v", conn.RemoteAddr().String())
 
 
 	var user dataprovider.User
 	var user dataprovider.User
+	var loginType string
 
 
 	err = json.Unmarshal([]byte(sconn.Permissions.Extensions["user"]), &user)
 	err = json.Unmarshal([]byte(sconn.Permissions.Extensions["user"]), &user)
 
 
@@ -217,6 +218,7 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
 		logger.Warn(logSender, "", "Unable to deserialize user info, cannot serve connection: %v", err)
 		logger.Warn(logSender, "", "Unable to deserialize user info, cannot serve connection: %v", err)
 		return
 		return
 	}
 	}
+	loginType = sconn.Permissions.Extensions["login_type"]
 
 
 	connectionID := hex.EncodeToString(sconn.SessionID())
 	connectionID := hex.EncodeToString(sconn.SessionID())
 
 
@@ -230,8 +232,8 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
 		lock:          new(sync.Mutex),
 		lock:          new(sync.Mutex),
 		sshConn:       sconn,
 		sshConn:       sconn,
 	}
 	}
-	logger.Info(logSender, connectionID, "User id: %d, name: %#v, home_dir: %#v",
-		user.ID, user.Username, user.HomeDir)
+	logger.Info(logSender, connectionID, "User id: %d, logged in with: %#v, name: %#v, home_dir: %#v",
+		user.ID, loginType, user.Username, user.HomeDir)
 
 
 	go ssh.DiscardRequests(reqs)
 	go ssh.DiscardRequests(reqs)
 
 
@@ -317,7 +319,7 @@ func (c Configuration) createHandler(connection Connection) sftp.Handlers {
 	}
 	}
 }
 }
 
 
-func loginUser(user dataprovider.User) (*ssh.Permissions, error) {
+func loginUser(user dataprovider.User, loginType string) (*ssh.Permissions, error) {
 	if !filepath.IsAbs(user.HomeDir) {
 	if !filepath.IsAbs(user.HomeDir) {
 		logger.Warn(logSender, "", "user %v has invalid home dir: %v. Home dir must be an absolute path, login not allowed",
 		logger.Warn(logSender, "", "user %v has invalid home dir: %v. Home dir must be an absolute path, login not allowed",
 			user.Username, user.HomeDir)
 			user.Username, user.HomeDir)
@@ -348,6 +350,7 @@ func loginUser(user dataprovider.User) (*ssh.Permissions, error) {
 	p := &ssh.Permissions{}
 	p := &ssh.Permissions{}
 	p.Extensions = make(map[string]string)
 	p.Extensions = make(map[string]string)
 	p.Extensions["user"] = string(json)
 	p.Extensions["user"] = string(json)
+	p.Extensions["login_type"] = loginType
 	return p, nil
 	return p, nil
 }
 }
 
 
@@ -370,9 +373,10 @@ func (c *Configuration) checkHostKeys(configDir string) error {
 func (c Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubKey string) (*ssh.Permissions, error) {
 func (c Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubKey string) (*ssh.Permissions, error) {
 	var err error
 	var err error
 	var user dataprovider.User
 	var user dataprovider.User
+	var keyID string
 
 
-	if user, err = dataprovider.CheckUserAndPubKey(dataProvider, conn.User(), pubKey); err == nil {
-		return loginUser(user)
+	if user, keyID, err = dataprovider.CheckUserAndPubKey(dataProvider, conn.User(), pubKey); err == nil {
+		return loginUser(user, "public_key:"+keyID)
 	}
 	}
 	return nil, err
 	return nil, err
 }
 }
@@ -382,7 +386,7 @@ func (c Configuration) validatePasswordCredentials(conn ssh.ConnMetadata, pass [
 	var user dataprovider.User
 	var user dataprovider.User
 
 
 	if user, err = dataprovider.CheckUserAndPass(dataProvider, conn.User(), string(pass)); err == nil {
 	if user, err = dataprovider.CheckUserAndPass(dataProvider, conn.User(), string(pass)); err == nil {
-		return loginUser(user)
+		return loginUser(user, "password")
 	}
 	}
 	return nil, err
 	return nil, err
 }
 }