398 lines
17 KiB
Go
398 lines
17 KiB
Go
package repo
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/ente-io/museum/pkg/repo/passkey"
|
|
storageBonusRepo "github.com/ente-io/museum/pkg/repo/storagebonus"
|
|
"github.com/ente-io/stacktrace"
|
|
"github.com/lib/pq"
|
|
|
|
"github.com/ente-io/museum/ente"
|
|
"github.com/ente-io/museum/pkg/utils/crypto"
|
|
"github.com/ente-io/museum/pkg/utils/time"
|
|
)
|
|
|
|
// DELETED_EMAIL_HASH_FORMAT is the fmt template (one %d verb, filled with the
// user ID) whose result is written into users.email_hash when an account is
// deleted, replacing the real email hash.
const DELETED_EMAIL_HASH_FORMAT = "deleted+%d@ente.io"
|
|
|
|
// UserRepository defines the methods for inserting, updating and retrieving
|
|
// user entities from the underlying repository
|
|
type UserRepository struct {
|
|
DB *sql.DB
|
|
SecretEncryptionKey []byte
|
|
HashingKey []byte
|
|
StorageBonusRepo *storageBonusRepo.Repository
|
|
PasskeysRepository *passkey.Repository
|
|
}
|
|
|
|
// Get returns a user indicated by the userID
|
|
func (repo *UserRepository) Get(userID int64) (ente.User, error) {
|
|
var user ente.User
|
|
var encryptedEmail, nonce []byte
|
|
row := repo.DB.QueryRow(`SELECT user_id, encrypted_email, email_decryption_nonce, email_hash, family_admin_id, creation_time, is_two_factor_enabled, email_mfa FROM users WHERE user_id = $1`, userID)
|
|
err := row.Scan(&user.ID, &encryptedEmail, &nonce, &user.Hash, &user.FamilyAdminID, &user.CreationTime, &user.IsTwoFactorEnabled, &user.IsEmailMFAEnabled)
|
|
if err != nil {
|
|
return ente.User{}, stacktrace.Propagate(err, "")
|
|
}
|
|
// We should not be calling Get user for a deleted account. The one valid
|
|
// use case is for internal/Admin APIs, where please we should instead be
|
|
// using GetUserByIDInternal.
|
|
if strings.EqualFold(user.Hash, fmt.Sprintf(DELETED_EMAIL_HASH_FORMAT, userID)) {
|
|
return user, stacktrace.Propagate(ente.ErrUserDeleted, fmt.Sprintf("user account is deleted %d", userID))
|
|
}
|
|
email, err := crypto.Decrypt(encryptedEmail, repo.SecretEncryptionKey, nonce)
|
|
if err != nil {
|
|
return ente.User{}, stacktrace.Propagate(err, "")
|
|
}
|
|
user.Email = email
|
|
return user, nil
|
|
}
|
|
|
|
// GetUserByIDInternal returns a user indicated by the id. Strickly use this method for internal APIs only.
|
|
func (repo *UserRepository) GetUserByIDInternal(id int64) (ente.User, error) {
|
|
var user ente.User
|
|
var encryptedEmail, nonce []byte
|
|
row := repo.DB.QueryRow(`SELECT user_id, encrypted_email, email_decryption_nonce, email_hash, family_admin_id, creation_time FROM users WHERE user_id = $1 AND encrypted_email IS NOT NULL`, id)
|
|
err := row.Scan(&user.ID, &encryptedEmail, &nonce, &user.Hash, &user.FamilyAdminID, &user.CreationTime)
|
|
if err != nil {
|
|
return ente.User{}, stacktrace.Propagate(err, "")
|
|
}
|
|
email, err := crypto.Decrypt(encryptedEmail, repo.SecretEncryptionKey, nonce)
|
|
if err != nil {
|
|
return ente.User{}, stacktrace.Propagate(err, "")
|
|
}
|
|
user.Email = email
|
|
return user, nil
|
|
}
|
|
|
|
// Delete removes the email_hash and encrypted email information for the user. It replaces email_hash with placeholder value
|
|
// based on DELETED_EMAIL_HASH_FORMAT
|
|
func (repo *UserRepository) Delete(userID int64) error {
|
|
emailHash := fmt.Sprintf(DELETED_EMAIL_HASH_FORMAT, userID)
|
|
_, err := repo.DB.Exec(`UPDATE users SET encrypted_email = null, email_decryption_nonce = null, email_hash = $1 WHERE user_id = $2`, emailHash, userID)
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
|
|
// GetFamilyAdminID returns the *familyAdminID for the given userID
|
|
func (repo *UserRepository) GetFamilyAdminID(userID int64) (*int64, error) {
|
|
row := repo.DB.QueryRow(`SELECT family_admin_id FROM users WHERE user_id = $1`, userID)
|
|
var familyAdminID *int64
|
|
err := row.Scan(&familyAdminID)
|
|
if err != nil {
|
|
return nil, stacktrace.Propagate(err, "")
|
|
}
|
|
return familyAdminID, nil
|
|
}
|
|
|
|
// GetUserByEmailHash returns a user indicated by the emailHash
|
|
func (repo *UserRepository) GetUserByEmailHash(emailHash string) (ente.User, error) {
|
|
var user ente.User
|
|
row := repo.DB.QueryRow(`SELECT user_id, email_hash, creation_time FROM users WHERE email_hash = $1`, emailHash)
|
|
err := row.Scan(&user.ID, &user.Hash, &user.CreationTime)
|
|
if err != nil {
|
|
return ente.User{}, stacktrace.Propagate(err, "")
|
|
}
|
|
return user, nil
|
|
}
|
|
|
|
// GetAll returns all users between sinceTime and tillTime (exclusive).
|
|
func (repo *UserRepository) GetAll(sinceTime int64, tillTime int64) ([]ente.User, error) {
|
|
rows, err := repo.DB.Query(`SELECT user_id, encrypted_email, email_decryption_nonce, email_hash, creation_time FROM users WHERE creation_time > $1 AND creation_time < $2 AND encrypted_email IS NOT NULL ORDER BY creation_time`, sinceTime, tillTime)
|
|
if err != nil {
|
|
return nil, stacktrace.Propagate(err, "")
|
|
}
|
|
defer rows.Close()
|
|
users := make([]ente.User, 0)
|
|
for rows.Next() {
|
|
var user ente.User
|
|
var encryptedEmail, nonce []byte
|
|
err := rows.Scan(&user.ID, &encryptedEmail, &nonce, &user.Hash, &user.CreationTime)
|
|
|
|
if err != nil {
|
|
return users, stacktrace.Propagate(err, "")
|
|
}
|
|
email, err := crypto.Decrypt(encryptedEmail, repo.SecretEncryptionKey, nonce)
|
|
if err != nil {
|
|
return nil, stacktrace.Propagate(err, "")
|
|
}
|
|
user.Email = email
|
|
users = append(users, user)
|
|
}
|
|
return users, nil
|
|
}
|
|
|
|
// GetUserUsageWithSubData will return current storage usage & basic information about subscription for given list
|
|
// of users. It's primarily used for fetching storage utilisation for a family/group of users
|
|
func (repo *UserRepository) GetUserUsageWithSubData(ctx context.Context, userIds []int64) ([]ente.UserUsageWithSubData, error) {
|
|
rows, err := repo.DB.QueryContext(ctx, `select encrypted_email, email_decryption_nonce, u.user_id, coalesce(storage_consumed , 0) as storage_used, storage, expiry_time
|
|
from users as u
|
|
left join (select storage_consumed, user_id from usage where user_id = ANY($1)) as us
|
|
on us.user_id=u.user_id
|
|
left join (select user_id,expiry_time, storage from subscriptions where user_id = ANY($1)) as s
|
|
on s.user_id = u.user_id
|
|
where u.user_id = ANY($1)`, pq.Array(userIds))
|
|
if err != nil {
|
|
return nil, stacktrace.Propagate(err, "")
|
|
}
|
|
defer rows.Close()
|
|
result := make([]ente.UserUsageWithSubData, 0)
|
|
for rows.Next() {
|
|
var (
|
|
usageData ente.UserUsageWithSubData
|
|
encryptedEmail, nonce []byte
|
|
)
|
|
err = rows.Scan(&encryptedEmail, &nonce, &usageData.UserID, &usageData.StorageConsumed, &usageData.Storage, &usageData.ExpiryTime)
|
|
if err != nil {
|
|
return result, stacktrace.Propagate(err, "")
|
|
}
|
|
email, err := crypto.Decrypt(encryptedEmail, repo.SecretEncryptionKey, nonce)
|
|
if err != nil {
|
|
return nil, stacktrace.Propagate(err, "failed to decrypt email")
|
|
}
|
|
usageData.Email = &email
|
|
result = append(result, usageData)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// Create creates a user with a given email address and returns the generated
|
|
// userID
|
|
func (repo *UserRepository) Create(encryptedEmail ente.EncryptionResult, emailHash string, source *string) (int64, error) {
|
|
var userID int64
|
|
err := repo.DB.QueryRow(`INSERT INTO users(encrypted_email, email_decryption_nonce, email_hash, creation_time, source) VALUES($1, $2, $3, $4, $5) RETURNING user_id`,
|
|
encryptedEmail.Cipher, encryptedEmail.Nonce, emailHash, time.Microseconds(), source).Scan(&userID)
|
|
if err != nil {
|
|
return -1, stacktrace.Propagate(err, "")
|
|
}
|
|
return userID, nil
|
|
}
|
|
|
|
// UpdateDeleteFeedback for a given user in the delete_feedback column of type jsonb
|
|
func (repo *UserRepository) UpdateDeleteFeedback(userID int64, feedback map[string]string) error {
|
|
// Convert the feedback map into JSON
|
|
feedbackJSON, err := json.Marshal(feedback)
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "Failed to marshal feedback into JSON")
|
|
}
|
|
// Execute the update query with the JSON
|
|
_, err = repo.DB.Exec(`UPDATE users SET delete_feedback = $1 WHERE user_id = $2`, feedbackJSON, userID)
|
|
return stacktrace.Propagate(err, "Failed to update delete feedback")
|
|
}
|
|
|
|
// UpdateEmail updates the email address of a user
|
|
func (repo *UserRepository) UpdateEmail(userID int64, encryptedEmail ente.EncryptionResult, emailHash string) error {
|
|
_, err := repo.DB.Exec(`UPDATE users SET encrypted_email = $1, email_decryption_nonce = $2, email_hash = $3 WHERE user_id = $4`, encryptedEmail.Cipher, encryptedEmail.Nonce, emailHash, userID)
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
|
|
// GetUserIDWithEmail returns the userID associated with a provided email
|
|
func (repo *UserRepository) GetUserIDWithEmail(email string) (int64, error) {
|
|
sanitizedEmail := strings.ToLower(strings.TrimSpace(email))
|
|
emailHash, err := crypto.GetHash(sanitizedEmail, repo.HashingKey)
|
|
if err != nil {
|
|
return -1, stacktrace.Propagate(err, "")
|
|
}
|
|
row := repo.DB.QueryRow(`SELECT user_id FROM users WHERE email_hash = $1`, emailHash)
|
|
var userID int64
|
|
err = row.Scan(&userID)
|
|
if err != nil {
|
|
return -1, stacktrace.Propagate(err, "")
|
|
}
|
|
return userID, nil
|
|
}
|
|
|
|
// GetKeyAttributes gets the key attributes for a given user
|
|
func (repo *UserRepository) GetKeyAttributes(userID int64) (ente.KeyAttributes, error) {
|
|
row := repo.DB.QueryRow(`SELECT kek_salt, kek_hash_bytes, encrypted_key, key_decryption_nonce, public_key, encrypted_secret_key, secret_key_decryption_nonce, mem_limit, ops_limit, master_key_encrypted_with_recovery_key, master_key_decryption_nonce, recovery_key_encrypted_with_master_key, recovery_key_decryption_nonce FROM key_attributes WHERE user_id = $1`, userID)
|
|
var (
|
|
keyAttributes ente.KeyAttributes
|
|
kekHashBytes []byte
|
|
masterKeyEncryptedWithRecoveryKey sql.NullString
|
|
masterKeyDecryptionNonce sql.NullString
|
|
recoveryKeyEncryptedWithMasterKey sql.NullString
|
|
recoveryKeyDecryptionNonce sql.NullString
|
|
)
|
|
err := row.Scan(&keyAttributes.KEKSalt,
|
|
&kekHashBytes,
|
|
&keyAttributes.EncryptedKey,
|
|
&keyAttributes.KeyDecryptionNonce,
|
|
&keyAttributes.PublicKey,
|
|
&keyAttributes.EncryptedSecretKey,
|
|
&keyAttributes.SecretKeyDecryptionNonce,
|
|
&keyAttributes.MemLimit,
|
|
&keyAttributes.OpsLimit,
|
|
&masterKeyEncryptedWithRecoveryKey,
|
|
&masterKeyDecryptionNonce,
|
|
&recoveryKeyEncryptedWithMasterKey,
|
|
&recoveryKeyDecryptionNonce,
|
|
)
|
|
if err != nil {
|
|
return ente.KeyAttributes{}, stacktrace.Propagate(err, "")
|
|
}
|
|
keyAttributes.KEKHash = string(kekHashBytes)
|
|
if masterKeyEncryptedWithRecoveryKey.Valid {
|
|
keyAttributes.MasterKeyEncryptedWithRecoveryKey = masterKeyEncryptedWithRecoveryKey.String
|
|
}
|
|
if masterKeyDecryptionNonce.Valid {
|
|
keyAttributes.MasterKeyDecryptionNonce = masterKeyDecryptionNonce.String
|
|
}
|
|
if recoveryKeyEncryptedWithMasterKey.Valid {
|
|
keyAttributes.RecoveryKeyEncryptedWithMasterKey = recoveryKeyEncryptedWithMasterKey.String
|
|
}
|
|
if recoveryKeyDecryptionNonce.Valid {
|
|
keyAttributes.RecoveryKeyDecryptionNonce = recoveryKeyDecryptionNonce.String
|
|
}
|
|
|
|
return keyAttributes, nil
|
|
}
|
|
|
|
// SetKeyAttributes sets the key attributes for a given user
|
|
func (repo *UserRepository) SetKeyAttributes(userID int64, keyAttributes ente.KeyAttributes) error {
|
|
_, err := repo.DB.Exec(`INSERT INTO key_attributes(user_id, kek_salt, kek_hash_bytes, encrypted_key, key_decryption_nonce, public_key, encrypted_secret_key, secret_key_decryption_nonce, mem_limit, ops_limit, master_key_encrypted_with_recovery_key, master_key_decryption_nonce, recovery_key_encrypted_with_master_key, recovery_key_decryption_nonce) VALUES($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)`,
|
|
userID, keyAttributes.KEKSalt, []byte(keyAttributes.KEKHash),
|
|
keyAttributes.EncryptedKey, keyAttributes.KeyDecryptionNonce,
|
|
keyAttributes.PublicKey, keyAttributes.EncryptedSecretKey,
|
|
keyAttributes.SecretKeyDecryptionNonce, keyAttributes.MemLimit, keyAttributes.OpsLimit,
|
|
keyAttributes.MasterKeyEncryptedWithRecoveryKey, keyAttributes.MasterKeyDecryptionNonce,
|
|
keyAttributes.RecoveryKeyEncryptedWithMasterKey, keyAttributes.RecoveryKeyDecryptionNonce)
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
|
|
// UpdateKeys sets the keys of a user
|
|
func (repo *UserRepository) UpdateKeys(userID int64, keys ente.UpdateKeysRequest) error {
|
|
_, err := repo.DB.Exec(`UPDATE key_attributes SET kek_salt = $1, encrypted_key = $2, key_decryption_nonce = $3, mem_limit = $4, ops_limit = $5 WHERE user_id = $6`,
|
|
keys.KEKSalt, keys.EncryptedKey, keys.KeyDecryptionNonce, keys.MemLimit, keys.OpsLimit, userID)
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
|
|
// SetRecoveryKeyAttributes sets the recovery key and related attributes for a user
|
|
func (repo *UserRepository) SetRecoveryKeyAttributes(userID int64, keys ente.SetRecoveryKeyRequest) error {
|
|
_, err := repo.DB.Exec(`UPDATE key_attributes SET master_key_encrypted_with_recovery_key = $1, master_key_decryption_nonce = $2, recovery_key_encrypted_with_master_key = $3, recovery_key_decryption_nonce = $4 WHERE user_id = $5`,
|
|
keys.MasterKeyEncryptedWithRecoveryKey, keys.MasterKeyDecryptionNonce, keys.RecoveryKeyEncryptedWithMasterKey, keys.RecoveryKeyDecryptionNonce, userID)
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
|
|
// GetPublicKey returns the public key of a user
|
|
func (repo *UserRepository) GetPublicKey(userID int64) (string, error) {
|
|
row := repo.DB.QueryRow(`SELECT public_key FROM key_attributes WHERE user_id = $1`, userID)
|
|
var publicKey string
|
|
err := row.Scan(&publicKey)
|
|
return publicKey, stacktrace.Propagate(err, "")
|
|
}
|
|
|
|
// GetUsersWithIndividualPlanWhoHaveExceededStorageQuota returns list of users who have consumed their storage quota
|
|
// and they are not part of any family plan
|
|
func (repo *UserRepository) GetUsersWithIndividualPlanWhoHaveExceededStorageQuota() ([]ente.User, error) {
|
|
rows, err := repo.DB.Query(`
|
|
SELECT users.user_id, users.encrypted_email, users.email_decryption_nonce, users.email_hash, usage.storage_consumed, subscriptions.storage
|
|
FROM users
|
|
INNER JOIN usage
|
|
ON users.user_id = usage.user_id
|
|
INNER JOIN subscriptions
|
|
ON users.user_id = subscriptions.user_id AND usage.storage_consumed > subscriptions.storage AND users.encrypted_email IS NOT NULL AND users.family_admin_id IS NULL;
|
|
`)
|
|
if err != nil {
|
|
return nil, stacktrace.Propagate(err, "")
|
|
}
|
|
refBonus, addOnBonus, bonusErr := repo.StorageBonusRepo.GetAllUsersSurplusBonus(context.Background())
|
|
if bonusErr != nil {
|
|
return nil, stacktrace.Propagate(bonusErr, "failed to fetch bonusInfo")
|
|
}
|
|
defer rows.Close()
|
|
users := make([]ente.User, 0)
|
|
for rows.Next() {
|
|
var user ente.User
|
|
var encryptedEmail, nonce []byte
|
|
var storageConsumed, subStorage int64
|
|
err := rows.Scan(&user.ID, &encryptedEmail, &nonce, &user.Hash, &storageConsumed, &subStorage)
|
|
if err != nil {
|
|
return users, stacktrace.Propagate(err, "")
|
|
}
|
|
// ignore deleted users
|
|
if strings.EqualFold(user.Hash, fmt.Sprintf(DELETED_EMAIL_HASH_FORMAT, &user.ID)) || len(encryptedEmail) == 0 {
|
|
continue
|
|
}
|
|
if refBonusStorage, ok := refBonus[user.ID]; ok {
|
|
addOnBonusStorage := addOnBonus[user.ID]
|
|
// cap usable ref bonus to the subscription storage + addOnBonus
|
|
if refBonusStorage > (subStorage + addOnBonusStorage) {
|
|
refBonusStorage = subStorage + addOnBonusStorage
|
|
}
|
|
if (storageConsumed) <= (subStorage + refBonusStorage + addOnBonusStorage) {
|
|
continue
|
|
}
|
|
}
|
|
email, err := crypto.Decrypt(encryptedEmail, repo.SecretEncryptionKey, nonce)
|
|
if err != nil {
|
|
return users, stacktrace.Propagate(err, "")
|
|
}
|
|
user.Email = email
|
|
users = append(users, user)
|
|
}
|
|
return users, nil
|
|
}
|
|
|
|
// SetTwoFactorSecret sets the two factor secret for a user
|
|
func (repo *UserRepository) SetTwoFactorSecret(userID int64, secret ente.EncryptionResult, secretHash string, recoveryEncryptedTwoFactorSecret string, recoveryTwoFactorSecretDecryptionNonce string) error {
|
|
_, err := repo.DB.Exec(`INSERT INTO two_factor(user_id,encrypted_two_factor_secret,two_factor_secret_decryption_nonce,two_factor_secret_hash,recovery_encrypted_two_factor_secret,recovery_two_factor_secret_decryption_nonce)
|
|
VALUES($1, $2, $3, $4, $5, $6)
|
|
ON CONFLICT (user_id) DO UPDATE
|
|
SET encrypted_two_factor_secret = $2,
|
|
two_factor_secret_decryption_nonce = $3,
|
|
two_factor_secret_hash = $4,
|
|
recovery_encrypted_two_factor_secret = $5,
|
|
recovery_two_factor_secret_decryption_nonce = $6
|
|
`,
|
|
userID, secret.Cipher, secret.Nonce, secretHash, recoveryEncryptedTwoFactorSecret, recoveryTwoFactorSecretDecryptionNonce)
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
|
|
// IsTwoFactorEnabled checks if a user's two factor is enabled or not
|
|
func (repo *UserRepository) IsTwoFactorEnabled(userID int64) (bool, error) {
|
|
var twoFAStatus bool
|
|
row := repo.DB.QueryRow(`SELECT is_two_factor_enabled FROM users WHERE user_id = $1`, userID)
|
|
err := row.Scan(&twoFAStatus)
|
|
if err != nil {
|
|
return false, stacktrace.Propagate(err, "")
|
|
}
|
|
return twoFAStatus, nil
|
|
}
|
|
|
|
func (repo *UserRepository) HasPasskeys(userID int64) (hasPasskeys bool, err error) {
|
|
passkeys, err := repo.PasskeysRepository.GetUserPasskeys(userID)
|
|
hasPasskeys = len(passkeys) > 0
|
|
return
|
|
}
|
|
|
|
func (repo *UserRepository) GetEmailsFromHashes(hashes []string) ([]string, error) {
|
|
rows, err := repo.DB.Query(`
|
|
SELECT users.encrypted_email, users.email_decryption_nonce
|
|
FROM users
|
|
WHERE users.email_hash = ANY($1);
|
|
`, pq.Array(hashes))
|
|
if err != nil {
|
|
return nil, stacktrace.Propagate(err, "")
|
|
}
|
|
defer rows.Close()
|
|
emails := make([]string, 0)
|
|
for rows.Next() {
|
|
var encryptedEmail, nonce []byte
|
|
err := rows.Scan(&encryptedEmail, &nonce)
|
|
if err != nil {
|
|
return emails, stacktrace.Propagate(err, "")
|
|
}
|
|
email, err := crypto.Decrypt(encryptedEmail, repo.SecretEncryptionKey, nonce)
|
|
if err != nil {
|
|
return emails, stacktrace.Propagate(err, "")
|
|
}
|
|
emails = append(emails, email)
|
|
}
|
|
return emails, nil
|
|
}
|