package repo import ( "database/sql" "github.com/ente-io/museum/ente" "github.com/ente-io/museum/pkg/utils/network" "github.com/ente-io/museum/pkg/utils/time" "github.com/ente-io/stacktrace" ) // UserAuthRepository defines the methods for inserting, updating and retrieving // one time tokens (currently) used for email verification. type UserAuthRepository struct { DB *sql.DB } // AddOTT saves the provided one time token for the specified user func (repo *UserAuthRepository) AddOTT(emailHash string, app ente.App, ott string, expirationTime int64) error { _, err := repo.DB.Exec(`INSERT INTO otts(email_hash, ott, creation_time, expiration_time, app) VALUES($1, $2, $3, $4, $5) ON CONFLICT ON CONSTRAINT unique_otts_emailhash_app_ott DO UPDATE SET creation_time = $3, expiration_time = $4`, emailHash, ott, time.Microseconds(), expirationTime, app) return stacktrace.Propagate(err, "") } // RemoveOTT removes the specified OTT (to be used when an OTT has been consumed) func (repo *UserAuthRepository) RemoveOTT(emailHash string, ott string, app ente.App) error { _, err := repo.DB.Exec(`DELETE FROM otts WHERE email_hash = $1 AND ott = $2 AND app = $3`, emailHash, ott, app) return stacktrace.Propagate(err, "") } // RemoveExpiredOTTs removes all OTTs that have expired func (repo *UserAuthRepository) RemoveExpiredOTTs() error { _, err := repo.DB.Exec(`DELETE FROM otts WHERE expiration_time <= $1`, time.Microseconds()) return stacktrace.Propagate(err, "") } // GetTokenCreationTime return the creation_time for the given token func (repo *UserAuthRepository) GetTokenCreationTime(token string) (int64, error) { row := repo.DB.QueryRow(`SELECT creation_time from tokens where token = $1`, token) var result int64 if err := row.Scan(&result); err != nil { return 0, stacktrace.Propagate(err, "Failed to scan row") } return result, nil } // GetValidOTTs returns the list of OTTs that haven't expired for a given user func (repo *UserAuthRepository) GetValidOTTs(emailHash string, app ente.App) ([]string, error) { rows, err := repo.DB.Query(`SELECT ott FROM otts WHERE email_hash = $1 AND app = $2 AND expiration_time > $3`, emailHash, app, time.Microseconds()) if err != nil { return nil, stacktrace.Propagate(err, "") } defer rows.Close() otts := make([]string, 0) for rows.Next() { var ott string err := rows.Scan(&ott) if err != nil { return otts, stacktrace.Propagate(err, "") } otts = append(otts, ott) } return otts, nil } func (repo *UserAuthRepository) GetMaxWrongAttempts(emailHash string, app ente.App) (int, error) { row := repo.DB.QueryRow(`SELECT COALESCE(MAX(wrong_attempt),0) FROM otts WHERE email_hash = $1 AND expiration_time > $2 AND app = $3`, emailHash, time.Microseconds(), app) var wrongAttempt int if err := row.Scan(&wrongAttempt); err != nil { return 0, stacktrace.Propagate(err, "Failed to scan row") } return wrongAttempt, nil } // RecordWrongAttemptForActiveOtt increases the wrong_attempt count for given emailHash and active ott. // Assuming tha we keep deleting expired OTT, max(wrong_attempt) can be used to track brute-force attack func (repo *UserAuthRepository) RecordWrongAttemptForActiveOtt(emailHash string, app ente.App) error { _, err := repo.DB.Exec(`UPDATE otts SET wrong_attempt = otts.wrong_attempt + 1 WHERE email_hash = $1 AND expiration_time > $2 AND app=$3`, emailHash, time.Microseconds(), app) if err != nil { return stacktrace.Propagate(err, "Failed to update wrong attempt count") } return nil } // AddToken saves the provided long lived token for the specified user func (repo *UserAuthRepository) AddToken(userID int64, app ente.App, token string, ip string, userAgent string) error { _, err := repo.DB.Exec(`INSERT INTO tokens(user_id, app, token, creation_time, ip, user_agent) VALUES($1, $2, $3, $4, $5, $6)`, userID, app, token, time.Microseconds(), ip, userAgent) return stacktrace.Propagate(err, "") } // GetUserIDWithToken returns the userID associated with a given token func (repo *UserAuthRepository) GetUserIDWithToken(token string, app ente.App) (int64, error) { row := repo.DB.QueryRow(`SELECT user_id FROM tokens WHERE token = $1 AND app = $2 AND is_deleted = false`, token, app) var id int64 err := row.Scan(&id) if err != nil { return -1, stacktrace.Propagate(err, "") } return id, nil } // RemoveToken marks the specified token (to be used when a user logs out) as deleted func (repo *UserAuthRepository) RemoveToken(userID int64, token string) error { _, err := repo.DB.Exec(`UPDATE tokens SET is_deleted = true WHERE user_id = $1 AND token = $2`, userID, token) return stacktrace.Propagate(err, "") } // UpdateLastUsedAt updates the last used at timestamp for the particular token func (repo *UserAuthRepository) UpdateLastUsedAt(userID int64, token string, ip string, userAgent string) error { _, err := repo.DB.Exec(`UPDATE tokens SET ip = $1, user_agent = $2, last_used_at = $3 WHERE user_id = $4 AND token = $5`, ip, userAgent, time.Microseconds(), userID, token) return stacktrace.Propagate(err, "") } // RemoveAllOtherTokens marks the all tokens apart from the specified one for a user as deleted func (repo *UserAuthRepository) RemoveAllOtherTokens(userID int64, token string) error { _, err := repo.DB.Exec(`UPDATE tokens SET is_deleted = true WHERE user_id = $1 AND token <> $2`, userID, token) return stacktrace.Propagate(err, "") } func (repo *UserAuthRepository) RemoveDeletedTokens(expiryTime int64) error { _, err := repo.DB.Exec(`DELETE FROM tokens WHERE is_deleted = true AND last_used_at < $1`, expiryTime) return stacktrace.Propagate(err, "") } // RemoveAllTokens marks the all tokens for a user as deleted func (repo *UserAuthRepository) RemoveAllTokens(userID int64) error { _, err := repo.DB.Exec(`UPDATE tokens SET is_deleted = true WHERE user_id = $1`, userID) return stacktrace.Propagate(err, "") } // GetActiveSessions returns the list of tokens that are valid for a given user func (repo *UserAuthRepository) GetActiveSessions(userID int64, app ente.App) ([]ente.Session, error) { rows, err := repo.DB.Query(`SELECT token, creation_time, ip, user_agent, last_used_at FROM tokens WHERE user_id = $1 AND app = $2 AND is_deleted = false`, userID, app) if err != nil { return nil, stacktrace.Propagate(err, "") } defer rows.Close() sessions := make([]ente.Session, 0) for rows.Next() { var ip sql.NullString var userAgent sql.NullString var session ente.Session err := rows.Scan(&session.Token, &session.CreationTime, &ip, &userAgent, &session.LastUsedTime) if err != nil { return nil, stacktrace.Propagate(err, "") } if ip.Valid { session.IP = ip.String } else { session.IP = "Unknown IP" } if userAgent.Valid { session.UA = userAgent.String session.PrettyUA = network.GetPrettyUA(userAgent.String) } else { session.UA = "Unknown Device" session.PrettyUA = "Unknown Device" } sessions = append(sessions, session) } return sessions, nil }