[Server] Fix unique ott constraint for multiple apps (#1386)

## Description

## Tests
  Wrong attempt tracking is working fine
 Same ott can be issued for different app types
 For same app type, unique ott is issued
This commit is contained in:
Neeraj Gupta 2024-04-09 11:25:29 +05:30 committed by GitHub
commit b8968d2904
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 33 additions and 14 deletions

View file

@ -0,0 +1,9 @@
BEGIN;
ALTER TABLE
otts DROP CONSTRAINT IF EXISTS unique_otts_emailhash_app_ott;
ALTER TABLE
otts
ADD
CONSTRAINT unique_otts_emailhash_ott UNIQUE (ott, email_hash);
COMMIT;

View file

@ -0,0 +1,9 @@
BEGIN;
ALTER TABLE
otts DROP CONSTRAINT IF EXISTS unique_otts_emailhash_ott;
ALTER TABLE
otts
ADD
CONSTRAINT unique_otts_emailhash_app_ott UNIQUE (ott,app, email_hash);
COMMIT;

View file

@ -136,23 +136,24 @@ func (c *UserController) SendEmailOTT(context *gin.Context, email string, client
// verifyEmailOtt should be deprecated in favor of verifyEmailOttWithSession once clients are updated.
func (c *UserController) verifyEmailOtt(context *gin.Context, email string, ott string) error {
ott = strings.TrimSpace(ott)
app := auth.GetApp(context)
emailHash, err := crypto.GetHash(email, c.HashingKey)
if err != nil {
return stacktrace.Propagate(err, "")
}
wrongAttempt, err := c.UserAuthRepo.GetMaxWrongAttempts(emailHash)
wrongAttempt, err := c.UserAuthRepo.GetMaxWrongAttempts(emailHash, app)
if err != nil {
return stacktrace.Propagate(err, "")
}
if wrongAttempt >= OTTWrongAttemptLimit {
msg := "Too many wrong attempts for ott verification"
msg := fmt.Sprintf("Too many wrong ott verification attemp for app %s", app)
go c.DiscordController.NotifyPotentialAbuse(msg)
return stacktrace.Propagate(ente.ErrTooManyBadRequest, "User needs to wait before active ott are expired")
}
otts, err := c.UserAuthRepo.GetValidOTTs(emailHash, auth.GetApp(context))
log.Info("Valid otts for " + emailHash + " are " + strings.Join(otts, ","))
otts, err := c.UserAuthRepo.GetValidOTTs(emailHash, app)
log.Infof("Valid ott (app: %s) for %s are %s", app, emailHash, strings.Join(otts, ","))
if err != nil {
return stacktrace.Propagate(err, "")
}
@ -166,12 +167,12 @@ func (c *UserController) verifyEmailOtt(context *gin.Context, email string, ott
}
}
if !isValidOTT {
if err = c.UserAuthRepo.RecordWrongAttemptForActiveOtt(emailHash); err != nil {
if err = c.UserAuthRepo.RecordWrongAttemptForActiveOtt(emailHash, app); err != nil {
log.WithError(err).Warn("Failed to track wrong attempt")
}
return stacktrace.Propagate(ente.ErrIncorrectOTT, "")
}
err = c.UserAuthRepo.RemoveOTT(emailHash, ott)
err = c.UserAuthRepo.RemoveOTT(emailHash, ott, app)
if err != nil {
return stacktrace.Propagate(err, "")
}

View file

@ -20,14 +20,14 @@ type UserAuthRepository struct {
func (repo *UserAuthRepository) AddOTT(emailHash string, app ente.App, ott string, expirationTime int64) error {
_, err := repo.DB.Exec(`INSERT INTO otts(email_hash, ott, creation_time, expiration_time, app)
VALUES($1, $2, $3, $4, $5)
ON CONFLICT ON CONSTRAINT unique_otts_emailhash_ott DO UPDATE SET creation_time = $3, expiration_time = $4`,
ON CONFLICT ON CONSTRAINT unique_otts_emailhash_app_ott DO UPDATE SET creation_time = $3, expiration_time = $4`,
emailHash, ott, time.Microseconds(), expirationTime, app)
return stacktrace.Propagate(err, "")
}
// RemoveOTT removes the specified OTT (to be used when an OTT has been consumed)
func (repo *UserAuthRepository) RemoveOTT(emailHash string, ott string) error {
_, err := repo.DB.Exec(`DELETE FROM otts WHERE email_hash = $1 AND ott = $2`, emailHash, ott)
func (repo *UserAuthRepository) RemoveOTT(emailHash string, ott string, app ente.App) error {
_, err := repo.DB.Exec(`DELETE FROM otts WHERE email_hash = $1 AND ott = $2 AND app = $3`, emailHash, ott, app)
return stacktrace.Propagate(err, "")
}
@ -69,9 +69,9 @@ func (repo *UserAuthRepository) GetValidOTTs(emailHash string, app ente.App) ([]
return otts, nil
}
func (repo *UserAuthRepository) GetMaxWrongAttempts(emailHash string) (int, error) {
row := repo.DB.QueryRow(`SELECT COALESCE(MAX(wrong_attempt),0) FROM otts WHERE email_hash = $1 AND expiration_time > $2`,
emailHash, time.Microseconds())
func (repo *UserAuthRepository) GetMaxWrongAttempts(emailHash string, app ente.App) (int, error) {
row := repo.DB.QueryRow(`SELECT COALESCE(MAX(wrong_attempt),0) FROM otts WHERE email_hash = $1 AND expiration_time > $2 AND app = $3`,
emailHash, time.Microseconds(), app)
var wrongAttempt int
if err := row.Scan(&wrongAttempt); err != nil {
return 0, stacktrace.Propagate(err, "Failed to scan row")
@ -81,9 +81,9 @@ func (repo *UserAuthRepository) GetMaxWrongAttempts(emailHash string) (int, erro
// RecordWrongAttemptForActiveOtt increases the wrong_attempt count for given emailHash and active ott.
// Assuming tha we keep deleting expired OTT, max(wrong_attempt) can be used to track brute-force attack
func (repo *UserAuthRepository) RecordWrongAttemptForActiveOtt(emailHash string) error {
func (repo *UserAuthRepository) RecordWrongAttemptForActiveOtt(emailHash string, app ente.App) error {
_, err := repo.DB.Exec(`UPDATE otts SET wrong_attempt = otts.wrong_attempt + 1
WHERE email_hash = $1 AND expiration_time > $2`, emailHash, time.Microseconds())
WHERE email_hash = $1 AND expiration_time > $2 AND app=$3`, emailHash, time.Microseconds(), app)
if err != nil {
return stacktrace.Propagate(err, "Failed to update wrong attempt count")
}