Bladeren bron

[server] Add validation logic for file copy

Neeraj Gupta 1 jaar geleden
bovenliggende
commit
aabb884828
2 gewijzigde bestanden met toevoegingen van 59 en 0 verwijderingen
  1. 35 0
      server/pkg/controller/collection.go
  2. 24 0
      server/pkg/repo/collection.go

+ 35 - 0
server/pkg/controller/collection.go

@@ -464,6 +464,41 @@ func (c *CollectionController) isRemoveAllowed(ctx *gin.Context, actorUserID int
 	return nil
 }
 
+func (c *CollectionController) isCopyAllowed(ctx *gin.Context, actorUserID int64, req ente.CopyFileSyncRequest) error {
+	// verify that srcCollectionID is accessible by actorUserID
+	if _, err := c.AccessCtrl.GetCollection(ctx, &access.GetCollectionParams{
+		CollectionID: req.SrcCollectionID,
+		ActorUserID:  actorUserID,
+	}); err != nil {
+		return stacktrace.Propagate(err, "failed to verify srcCollection access")
+	}
+	// verify that dstCollectionID is owned by actorUserID
+	if _, err := c.AccessCtrl.GetCollection(ctx, &access.GetCollectionParams{
+		CollectionID: req.DstCollection,
+		ActorUserID:  actorUserID,
+		VerifyOwner:  true,
+	}); err != nil {
+		return stacktrace.Propagate(err, "failed to ownership of the dstCollection access")
+	}
+	// verify that all FileIDs exists in the srcCollection
+	fileIDs := make([]int64, len(req.Files))
+	for idx, file := range req.Files {
+		fileIDs[idx] = file.ID
+	}
+	if err := c.CollectionRepo.VerifyAllFileIDsExistsInCollection(ctx, req.SrcCollectionID, fileIDs); err != nil {
+		return stacktrace.Propagate(err, "failed to verify fileIDs in srcCollection")
+	}
+	dsMap, err := c.FileRepo.GetOwnerToFileIDsMap(ctx, fileIDs)
+	if err != nil {
+		return err
+	}
+	// verify that none of the file belongs to actorUserID
+	if _, ok := dsMap[actorUserID]; ok {
+		return ente.NewBadRequestWithMessage("can not copy files owned by actor")
+	}
+	return nil
+}
+
 // GetDiffV2 returns the changes in user's collections since a timestamp, along with hasMore bool flag.
 func (c *CollectionController) GetDiffV2(ctx *gin.Context, cID int64, userID int64, sinceTime int64) ([]ente.File, bool, error) {
 	reqContextLogger := log.WithFields(log.Fields{

+ 24 - 0
server/pkg/repo/collection.go

@@ -374,6 +374,30 @@ func (repo *CollectionRepository) DoesFileExistInCollections(fileID int64, cIDs
 	return exists, stacktrace.Propagate(err, "")
 }
 
+// VerifyAllFileIDsExistsInCollection returns error if the fileIDs don't exist in the collection
+func (repo *CollectionRepository) VerifyAllFileIDsExistsInCollection(ctx context.Context, cID int64, fileIDs []int64) error {
+	fileIdMap := make(map[int64]bool)
+	rows, err := repo.DB.QueryContext(ctx, `SELECT file_id FROM collection_files WHERE collection_id = $1 AND is_deleted = $2 AND file_id = ALL ($3)`,
+		cID, false, pq.Array(fileIDs))
+	if err != nil {
+		return stacktrace.Propagate(err, "")
+	}
+	for rows.Next() {
+		var fileID int64
+		if err := rows.Scan(&fileID); err != nil {
+			return stacktrace.Propagate(err, "")
+		}
+		fileIdMap[fileID] = true
+	}
+	// find fileIds that are not present in the collection
+	for _, fileID := range fileIDs {
+		if _, ok := fileIdMap[fileID]; !ok {
+			return stacktrace.Propagate(fmt.Errorf("fileID %d not found in collection %d", fileID, cID), "")
+		}
+	}
+	return nil
+}
+
 // GetCollectionShareeRole returns true if the collection is shared with the user
 func (repo *CollectionRepository) GetCollectionShareeRole(cID int64, userID int64) (*ente.CollectionParticipantRole, error) {
 	var role *ente.CollectionParticipantRole