123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235 |
- package user
- import (
- "context"
- "database/sql"
- "errors"
- "github.com/ente-io/museum/ente"
- "github.com/ente-io/museum/pkg/utils/auth"
- "github.com/ente-io/stacktrace"
- "github.com/gin-gonic/gin"
- "github.com/google/uuid"
- "github.com/kong/go-srp"
- "github.com/sirupsen/logrus"
- "net/http"
- )
// Srp4096Params is the argument this file hands to srp.GetParams whenever it
// builds an SRP server (both when a session is created and when it is
// verified), so the two sides always use the same parameter set.
// NOTE(review): presumably this selects go-srp's 4096-bit group — confirm
// against the go-srp documentation.
const (
	Srp4096Params = 4096
)
- func (c *UserController) SetupSRP(context *gin.Context, userID int64, req ente.SetupSRPRequest) (*ente.SetupSRPResponse, error) {
- srpB, sessionID, err := c.createAndInsertSRPSession(context, req.SrpUserID, req.SRPVerifier, req.SRPA)
- if err != nil {
- return nil, stacktrace.Propagate(err, "")
- }
- setupID, err := c.UserAuthRepo.InsertTempSRPSetup(context, req, userID, sessionID)
- if err != nil {
- return nil, stacktrace.Propagate(err, "failed to add entry in setup table")
- }
- return &ente.SetupSRPResponse{
- SetupID: *setupID,
- SRPB: *srpB,
- }, nil
- }
- func (c *UserController) CompleteSRPSetup(context *gin.Context, req ente.CompleteSRPSetupRequest) (*ente.CompleteSRPSetupResponse, error) {
- userID := auth.GetUserID(context.Request.Header)
- setup, err := c.UserAuthRepo.GetTempSRPSetupEntity(context, req.SetupID)
- if err != nil {
- return nil, stacktrace.Propagate(err, "")
- }
- srpM2, err := c.verifySRPSession(context, setup.Verifier, setup.SessionID, req.SRPM1)
- if err != nil {
- return nil, err
- }
- err = c.UserAuthRepo.InsertSRPAuth(context, userID, setup.SRPUserID, setup.Verifier, setup.Salt)
- if err != nil {
- return nil, stacktrace.Propagate(err, "failed to add entry in srp auth")
- }
- return &ente.CompleteSRPSetupResponse{
- SetupID: req.SetupID,
- SRPM2: *srpM2,
- }, nil
- }
- // UpdateSrpAndKeyAttributes updates the SRP and keys attributes if the SRP setup is successfully done
- func (c *UserController) UpdateSrpAndKeyAttributes(context *gin.Context,
- userID int64,
- req ente.UpdateSRPAndKeysRequest,
- shouldClearTokens bool,
- ) (*ente.UpdateSRPSetupResponse, error) {
- setup, err := c.UserAuthRepo.GetTempSRPSetupEntity(context, req.SetupID)
- if err != nil {
- return nil, stacktrace.Propagate(err, "")
- }
- srpM2, err := c.verifySRPSession(context, setup.Verifier, setup.SessionID, req.SRPM1)
- if err != nil {
- return nil, err
- }
- err = c.UserAuthRepo.InsertOrUpdateSRPAuthAndKeyAttr(context, userID, req, setup)
- if err != nil {
- return nil, stacktrace.Propagate(err, "failed to add entry in srp auth")
- }
- if shouldClearTokens {
- token := auth.GetToken(context)
- err = c.UserAuthRepo.RemoveAllOtherTokens(userID, token)
- if err != nil {
- return nil, err
- }
- } else {
- logrus.WithField("user_id", userID).Info("not clearing tokens")
- }
- return &ente.UpdateSRPSetupResponse{
- SetupID: req.SetupID,
- SRPM2: *srpM2,
- }, nil
- }
- func (c *UserController) GetSRPAttributes(context *gin.Context, email string) (*ente.GetSRPAttributesResponse, error) {
- userID, err := c.UserRepo.GetUserIDWithEmail(email)
- if err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- return nil, stacktrace.Propagate(ente.ErrNotFound, "user does not exist")
- } else {
- return nil, stacktrace.Propagate(err, "failed to get user")
- }
- }
- srpAttributes, err := c.UserAuthRepo.GetSRPAttributes(userID)
- if err != nil {
- return nil, stacktrace.Propagate(err, "")
- }
- return srpAttributes, nil
- }
- func (c *UserController) CreateSrpSession(context *gin.Context, req ente.CreateSRPSessionRequest) (*ente.CreateSRPSessionResponse, error) {
- srpAuthEntity, err := c.UserAuthRepo.GetSRPAuthEntityBySRPUserID(context, req.SRPUserID)
- if err != nil {
- return nil, err
- }
- isEmailMFAEnabled, err := c.UserAuthRepo.IsEmailMFAEnabled(context, srpAuthEntity.UserID)
- if err != nil {
- return nil, stacktrace.Propagate(err, "")
- }
- if *isEmailMFAEnabled {
- return nil, stacktrace.Propagate(&ente.ApiError{
- Code: "EMAIL_MFA_ENABLED",
- Message: "Email MFA is enabled",
- HttpStatusCode: http.StatusConflict,
- }, "email mfa is enabled")
- }
- srpBBase64, sessionID, err := c.createAndInsertSRPSession(context, req.SRPUserID, srpAuthEntity.Verifier, req.SRPA)
- if err != nil {
- return nil, stacktrace.Propagate(err, "")
- }
- return &ente.CreateSRPSessionResponse{
- SRPB: *srpBBase64,
- SessionID: *sessionID,
- }, nil
- }
- func (c *UserController) VerifySRPSession(context *gin.Context, req ente.VerifySRPSessionRequest) (*ente.EmailAuthorizationResponse, error) {
- srpAuthEntity, err := c.UserAuthRepo.GetSRPAuthEntityBySRPUserID(context, req.SRPUserID)
- if err != nil {
- return nil, stacktrace.Propagate(err, "")
- }
- srpM2, err := c.verifySRPSession(context, srpAuthEntity.Verifier, req.SessionID, req.SRPM1)
- if err != nil {
- return nil, stacktrace.Propagate(err, "")
- }
- user, err := c.UserRepo.Get(srpAuthEntity.UserID)
- if err != nil {
- return nil, err
- }
- verResponse, err := c.onVerificationSuccess(context, user.Email, nil)
- if err != nil {
- return nil, stacktrace.Propagate(err, "")
- }
- verResponse.SrpM2 = srpM2
- return &verResponse, nil
- }
- func (c *UserController) createAndInsertSRPSession(
- gContext *gin.Context,
- srpUserID uuid.UUID,
- srpVerifier string,
- srpA string,
- ) (*string, *uuid.UUID, error) {
- serverSecret := srp.GenKey()
- srpParams := srp.GetParams(Srp4096Params)
- srpServer := srp.NewServer(srpParams, convertStringToBytes(srpVerifier), serverSecret)
- if srpServer == nil {
- return nil, nil, stacktrace.NewError("server is nil")
- }
- srpServer.SetA(convertStringToBytes(srpA))
- srpB := srpServer.ComputeB()
- if srpB == nil {
- return nil, nil, stacktrace.NewError("srpB is nil")
- }
- sessionID, err := c.UserAuthRepo.AddSRPSession(srpUserID, convertBytesToString(serverSecret), srpA)
- if err != nil {
- return nil, nil, stacktrace.Propagate(err, "")
- }
- srpBBase64 := convertBytesToString(srpB)
- return &srpBBase64, &sessionID, nil
- }
- func (c *UserController) verifySRPSession(ctx context.Context,
- srpVerifier string,
- sessionID uuid.UUID,
- srpM1 string,
- ) (*string, error) {
- srpSession, err := c.UserAuthRepo.GetSrpSessionEntity(ctx, sessionID)
- if err != nil {
- return nil, stacktrace.Propagate(err, "")
- }
- if srpSession.IsVerified {
- return nil, stacktrace.Propagate(&ente.ApiError{
- Code: "SESSION_ALREADY_VERIFIED",
- HttpStatusCode: http.StatusGone,
- }, "")
- } else if srpSession.AttemptCount >= 5 {
- return nil, stacktrace.Propagate(&ente.ApiError{
- Code: "TOO_MANY_WRONG_ATTEMPTS",
- HttpStatusCode: http.StatusGone,
- }, "")
- }
- srpParams := srp.GetParams(Srp4096Params)
- srpServer := srp.NewServer(srpParams, convertStringToBytes(srpVerifier), convertStringToBytes(srpSession.ServerKey))
- if srpServer == nil {
- return nil, stacktrace.NewError("server is nil")
- }
- srpServer.SetA(convertStringToBytes(srpSession.SRP_A))
- srpM2Bytes, err := srpServer.CheckM1(convertStringToBytes(srpM1))
- if err != nil {
- err2 := c.UserAuthRepo.IncrementSrpSessionAttemptCount(ctx, sessionID)
- if err2 != nil {
- return nil, stacktrace.Propagate(err2, "")
- }
- return nil, stacktrace.Propagate(ente.ErrInvalidPassword, "failed to verify srp session")
- } else {
- err2 := c.UserAuthRepo.SetSrpSessionVerified(ctx, sessionID)
- if err2 != nil {
- return nil, stacktrace.Propagate(err2, "")
- }
- }
- srpM2 := convertBytesToString(srpM2Bytes)
- return &srpM2, nil
- }
|