diff --git a/server/cmd/museum/main.go b/server/cmd/museum/main.go index fc2300d93..84c34189d 100644 --- a/server/cmd/museum/main.go +++ b/server/cmd/museum/main.go @@ -839,7 +839,7 @@ func setupAndStartCrons(userAuthRepo *repo.UserAuthRepository, publicCollectionR schedule(c, "@every 24h", func() { _ = userAuthRepo.RemoveDeletedTokens(timeUtil.MicrosecondBeforeDays(30)) - _ = castDb.DeleteOldCodes(context.Background(), timeUtil.MicrosecondBeforeDays(1)) + _ = castDb.DeleteOldSessions(context.Background(), timeUtil.MicrosecondBeforeDays(7)) _ = publicCollectionRepo.CleanupAccessHistory(context.Background()) }) @@ -897,6 +897,8 @@ func setupAndStartCrons(userAuthRepo *repo.UserAuthRepository, publicCollectionR }) schedule(c, "@every 30m", func() { + // delete unclaimed codes older than 60 minutes + _ = castDb.DeleteUnclaimedCodes(context.Background(), timeUtil.MicrosecondsBeforeMinutes(60)) dataCleanupCtrl.DeleteDataCron() }) diff --git a/server/ente/errors.go b/server/ente/errors.go index 49aed7151..96e7bd4a1 100644 --- a/server/ente/errors.go +++ b/server/ente/errors.go @@ -149,6 +149,12 @@ var ErrCastPermissionDenied = ApiError{ HttpStatusCode: http.StatusForbidden, } +var ErrCastIPMismatch = ApiError{ + Code: "CAST_IP_MISMATCH", + Message: "IP mismatch", + HttpStatusCode: http.StatusForbidden, +} + type ErrorCode string const ( diff --git a/server/migrations/84_add_cast_column.down.sql b/server/migrations/84_add_cast_column.down.sql new file mode 100644 index 000000000..c08fed94e --- /dev/null +++ b/server/migrations/84_add_cast_column.down.sql @@ -0,0 +1 @@ +ALTER TABLE casting DROP COLUMN IF EXISTS ip; \ No newline at end of file diff --git a/server/migrations/84_add_cast_column.up.sql b/server/migrations/84_add_cast_column.up.sql new file mode 100644 index 000000000..828c2e57c --- /dev/null +++ b/server/migrations/84_add_cast_column.up.sql @@ -0,0 +1,5 @@ +--- Delete all rows from casting table and add a non-nullable column called ip +BEGIN; +DELETE FROM casting; +ALTER TABLE casting ADD COLUMN ip text NOT NULL; +COMMIT; diff --git a/server/pkg/controller/cast/controller.go b/server/pkg/controller/cast/controller.go index 3b76420cc..4432e149f 100644 --- a/server/pkg/controller/cast/controller.go +++ b/server/pkg/controller/cast/controller.go @@ -2,12 +2,15 @@ package cast import ( "context" + "github.com/ente-io/museum/ente" "github.com/ente-io/museum/ente/cast" "github.com/ente-io/museum/pkg/controller/access" castRepo "github.com/ente-io/museum/pkg/repo/cast" "github.com/ente-io/museum/pkg/utils/auth" + "github.com/ente-io/museum/pkg/utils/network" "github.com/ente-io/stacktrace" "github.com/gin-gonic/gin" + "github.com/sirupsen/logrus" ) type Controller struct { @@ -24,12 +27,24 @@ func NewController(castRepo *castRepo.Repository, } } -func (c *Controller) RegisterDevice(ctx context.Context, request *cast.RegisterDeviceRequest) (string, error) { - return c.CastRepo.AddCode(ctx, request.DeviceCode, request.PublicKey) +func (c *Controller) RegisterDevice(ctx *gin.Context, request *cast.RegisterDeviceRequest) (string, error) { + return c.CastRepo.AddCode(ctx, request.DeviceCode, request.PublicKey, network.GetClientIP(ctx)) } -func (c *Controller) GetPublicKey(ctx context.Context, deviceCode string) (string, error) { - return c.CastRepo.GetPubKey(ctx, deviceCode) +func (c *Controller) GetPublicKey(ctx *gin.Context, deviceCode string) (string, error) { + pubKey, ip, err := c.CastRepo.GetPubKeyAndIp(ctx, deviceCode) + if err != nil { + return "", stacktrace.Propagate(err, "") + } + if ip != network.GetClientIP(ctx) { + logrus.WithFields(logrus.Fields{ + "deviceCode": deviceCode, + "ip": ip, + "clientIP": network.GetClientIP(ctx), + }).Warn("GetPublicKey: IP mismatch") + return "", &ente.ErrCastIPMismatch + } + return pubKey, nil } func (c *Controller) GetEncCastData(ctx context.Context, deviceCode string) (*string, error) { diff --git a/server/pkg/middleware/rate_limit.go b/server/pkg/middleware/rate_limit.go index 08e0f00b6..076c050c9 100644 --- a/server/pkg/middleware/rate_limit.go +++ b/server/pkg/middleware/rate_limit.go @@ -150,6 +150,7 @@ func (r *RateLimitMiddleware) getLimiter(reqPath string, reqMethod string) *limi reqPath == "/public-collection/verify-password" || reqPath == "/family/accept-invite" || reqPath == "/users/srp/attributes" || + (reqPath == "/cast/device-info/" && reqMethod == "POST") || reqPath == "/users/srp/verify-session" || reqPath == "/family/invite-info/:token" || reqPath == "/family/add-member" || diff --git a/server/pkg/repo/cast/repo.go b/server/pkg/repo/cast/repo.go index 306c1d481..89ebc4083 100644 --- a/server/pkg/repo/cast/repo.go +++ b/server/pkg/repo/cast/repo.go @@ -7,6 +7,7 @@ import ( "github.com/ente-io/museum/pkg/utils/random" "github.com/ente-io/stacktrace" "github.com/google/uuid" + log "github.com/sirupsen/logrus" "strings" ) @@ -14,7 +15,7 @@ type Repository struct { DB *sql.DB } -func (r *Repository) AddCode(ctx context.Context, code *string, pubKey string) (string, error) { +func (r *Repository) AddCode(ctx context.Context, code *string, pubKey string, ip string) (string, error) { var codeValue string var err error if code == nil || *code == "" { @@ -25,7 +26,7 @@ func (r *Repository) AddCode(ctx context.Context, code *string, pubKey string) ( } else { codeValue = strings.TrimSpace(*code) } - _, err = r.DB.ExecContext(ctx, "INSERT INTO casting (code, public_key, id) VALUES ($1, $2, $3)", codeValue, pubKey, uuid.New()) + _, err = r.DB.ExecContext(ctx, "INSERT INTO casting (code, public_key, id, ip) VALUES ($1, $2, $3, $4)", codeValue, pubKey, uuid.New(), ip) if err != nil { return "", err } @@ -38,17 +39,17 @@ func (r *Repository) InsertCastData(ctx context.Context, castUserID int64, code return err } -func (r *Repository) GetPubKey(ctx context.Context, code string) (string, error) { - var pubKey string - row := r.DB.QueryRowContext(ctx, "SELECT public_key FROM casting WHERE code = $1 and is_deleted=false", code) - err := row.Scan(&pubKey) +func (r *Repository) GetPubKeyAndIp(ctx context.Context, code string) (string, string, error) { + var pubKey, ip string + row := r.DB.QueryRowContext(ctx, "SELECT public_key, ip FROM casting WHERE code = $1 and is_deleted=false", code) + err := row.Scan(&pubKey, &ip) if err != nil { if err == sql.ErrNoRows { - return "", ente.ErrNotFoundError.NewErr("code not found") + return "", "", ente.ErrNotFoundError.NewErr("code not found") } - return "", err + return "", "", err } - return pubKey, nil + return pubKey, ip, nil } func (r *Repository) GetEncCastData(ctx context.Context, code string) (*string, error) { @@ -89,12 +90,27 @@ func (r *Repository) UpdateLastUsedAtForToken(ctx context.Context, token string) return nil } -// DeleteOldCodes that are not associated with a collection and are older than the given time -func (r *Repository) DeleteOldCodes(ctx context.Context, expirtyTime int64) error { - _, err := r.DB.ExecContext(ctx, "DELETE FROM casting WHERE last_used_at < $1 and is_deleted=false and collection_id is null", expirtyTime) +// DeleteUnclaimedCodes that are not associated with a collection and are older than the given time +func (r *Repository) DeleteUnclaimedCodes(ctx context.Context, expiryTime int64) error { + result, err := r.DB.ExecContext(ctx, "DELETE FROM casting WHERE last_used_at < $1 and is_deleted=false and collection_id is null", expiryTime) if err != nil { return err } + if rows, rErr := result.RowsAffected(); rErr == nil && rows > 0 { + log.Infof("Deleted %d unclaimed codes", rows) + } + return nil +} + +// DeleteOldSessions where last used at is older than the given time +func (r *Repository) DeleteOldSessions(ctx context.Context, expiryTime int64) error { + result, err := r.DB.ExecContext(ctx, "DELETE FROM casting WHERE last_used_at < $1", expiryTime) + if err != nil { + return err + } + if rows, rErr := result.RowsAffected(); rErr == nil && rows > 0 { + log.Infof("Deleted %d old sessions", rows) + } return nil }