diff --git a/server/ente/errors.go b/server/ente/errors.go index 49aed7151..96e7bd4a1 100644 --- a/server/ente/errors.go +++ b/server/ente/errors.go @@ -149,6 +149,12 @@ var ErrCastPermissionDenied = ApiError{ HttpStatusCode: http.StatusForbidden, } +var ErrCastIPMismatch = ApiError{ + Code: "CAST_IP_MISMATCH", + Message: "IP mismatch", + HttpStatusCode: http.StatusForbidden, +} + type ErrorCode string const ( diff --git a/server/pkg/controller/cast/controller.go b/server/pkg/controller/cast/controller.go index 3b76420cc..4432e149f 100644 --- a/server/pkg/controller/cast/controller.go +++ b/server/pkg/controller/cast/controller.go @@ -2,12 +2,15 @@ package cast import ( "context" + "github.com/ente-io/museum/ente" "github.com/ente-io/museum/ente/cast" "github.com/ente-io/museum/pkg/controller/access" castRepo "github.com/ente-io/museum/pkg/repo/cast" "github.com/ente-io/museum/pkg/utils/auth" + "github.com/ente-io/museum/pkg/utils/network" "github.com/ente-io/stacktrace" "github.com/gin-gonic/gin" + "github.com/sirupsen/logrus" ) type Controller struct { @@ -24,12 +27,24 @@ func NewController(castRepo *castRepo.Repository, } } -func (c *Controller) RegisterDevice(ctx context.Context, request *cast.RegisterDeviceRequest) (string, error) { - return c.CastRepo.AddCode(ctx, request.DeviceCode, request.PublicKey) +func (c *Controller) RegisterDevice(ctx *gin.Context, request *cast.RegisterDeviceRequest) (string, error) { + return c.CastRepo.AddCode(ctx, request.DeviceCode, request.PublicKey, network.GetClientIP(ctx)) } -func (c *Controller) GetPublicKey(ctx context.Context, deviceCode string) (string, error) { - return c.CastRepo.GetPubKey(ctx, deviceCode) +func (c *Controller) GetPublicKey(ctx *gin.Context, deviceCode string) (string, error) { + pubKey, ip, err := c.CastRepo.GetPubKeyAndIp(ctx, deviceCode) + if err != nil { + return "", stacktrace.Propagate(err, "") + } + if ip != network.GetClientIP(ctx) { + logrus.WithFields(logrus.Fields{ + "deviceCode": deviceCode, + "ip": ip, + "clientIP": network.GetClientIP(ctx), + }).Warn("GetPublicKey: IP mismatch") + return "", &ente.ErrCastIPMismatch + } + return pubKey, nil } func (c *Controller) GetEncCastData(ctx context.Context, deviceCode string) (*string, error) { diff --git a/server/pkg/repo/cast/repo.go b/server/pkg/repo/cast/repo.go index 451842bcc..a8e02338e 100644 --- a/server/pkg/repo/cast/repo.go +++ b/server/pkg/repo/cast/repo.go @@ -14,7 +14,7 @@ type Repository struct { DB *sql.DB } -func (r *Repository) AddCode(ctx context.Context, code *string, pubKey string) (string, error) { +func (r *Repository) AddCode(ctx context.Context, code *string, pubKey string, ip string) (string, error) { var codeValue string var err error if code == nil || *code == "" { @@ -25,7 +25,7 @@ func (r *Repository) AddCode(ctx context.Context, code *string, pubKey string) ( } else { codeValue = strings.TrimSpace(*code) } - _, err = r.DB.ExecContext(ctx, "INSERT INTO casting (code, public_key, id) VALUES ($1, $2, $3)", codeValue, pubKey, uuid.New()) + _, err = r.DB.ExecContext(ctx, "INSERT INTO casting (code, public_key, id, ip) VALUES ($1, $2, $3, $4)", codeValue, pubKey, uuid.New(), ip) if err != nil { return "", err } @@ -38,17 +38,17 @@ func (r *Repository) InsertCastData(ctx context.Context, castUserID int64, code return err } -func (r *Repository) GetPubKey(ctx context.Context, code string) (string, error) { - var pubKey string - row := r.DB.QueryRowContext(ctx, "SELECT public_key FROM casting WHERE code = $1 and is_deleted=false", code) - err := row.Scan(&pubKey) +func (r *Repository) GetPubKeyAndIp(ctx context.Context, code string) (string, string, error) { + var pubKey, ip string + row := r.DB.QueryRowContext(ctx, "SELECT public_key, ip FROM casting WHERE code = $1 and is_deleted=false", code) + err := row.Scan(&pubKey, &ip) if err != nil { if err == sql.ErrNoRows { - return "", ente.ErrNotFoundError.NewErr("code not found") + return "", "", ente.ErrNotFoundError.NewErr("code not found") } - return "", err + return "", "", err } - return pubKey, nil + return pubKey, ip, nil } func (r *Repository) GetEncCastData(ctx context.Context, code string) (*string, error) {