Quellcode durchsuchen

[server] Support for copying files owned by others (#1496)

## Description

## Tests
Tested as part of https://github.com/ente-io/ente/pull/1484
- Verified that the client was able to download and decrypt the metadata
- Verified that storage was going up.
Neeraj Gupta vor 1 Jahr
Ursprung
Commit
b7d002feef

+ 11 - 1
server/cmd/museum/main.go

@@ -5,6 +5,7 @@ import (
 	"database/sql"
 	b64 "encoding/base64"
 	"fmt"
+	"github.com/ente-io/museum/pkg/controller/file_copy"
 	"net/http"
 	"os"
 	"os/signal"
@@ -389,9 +390,17 @@ func main() {
 		timeout.WithHandler(healthCheckHandler.PingDBStats),
 		timeout.WithResponse(timeOutResponse),
 	))
+	fileCopyCtrl := &file_copy.FileCopyController{
+		FileController: fileController,
+		CollectionCtrl: collectionController,
+		S3Config:       s3Config,
+		ObjectRepo:     objectRepo,
+		FileRepo:       fileRepo,
+	}
 
 	fileHandler := &api.FileHandler{
-		Controller: fileController,
+		Controller:   fileController,
+		FileCopyCtrl: fileCopyCtrl,
 	}
 	privateAPI.GET("/files/upload-urls", fileHandler.GetUploadURLs)
 	privateAPI.GET("/files/multipart-upload-urls", fileHandler.GetMultipartUploadURLs)
@@ -400,6 +409,7 @@ func main() {
 	privateAPI.GET("/files/preview/:fileID", fileHandler.GetThumbnail)
 	privateAPI.GET("/files/preview/v2/:fileID", fileHandler.GetThumbnail)
 	privateAPI.POST("/files", fileHandler.CreateOrUpdate)
+	privateAPI.POST("/files/copy", fileHandler.CopyFiles)
 	privateAPI.PUT("/files/update", fileHandler.Update)
 	privateAPI.POST("/files/trash", fileHandler.Trash)
 	privateAPI.POST("/files/size", fileHandler.GetSize)

+ 11 - 0
server/ente/collection.go

@@ -103,6 +103,17 @@ type AddFilesRequest struct {
 	Files        []CollectionFileItem `json:"files" binding:"required"`
 }
 
// CopyFileSyncRequest is the request object for copying the given
// CollectionFileItems from the source collection into the destination
// collection. Each item carries the file key re-encrypted with the
// destination collection's key.
type CopyFileSyncRequest struct {
	SrcCollectionID     int64                `json:"srcCollectionID" binding:"required"`
	DstCollection       int64                `json:"dstCollectionID" binding:"required"`
	CollectionFileItems []CollectionFileItem `json:"files" binding:"required"`
}

// CopyResponse maps each source fileID to the fileID of its newly created copy.
type CopyResponse struct {
	OldToNewFileIDMap map[int64]int64 `json:"oldToNewFileIDMap"`
}
+
 // RemoveFilesRequest represents a request to remove files from a collection
 type RemoveFilesRequest struct {
 	CollectionID int64 `json:"collectionID" binding:"required"`

+ 24 - 1
server/pkg/api/file.go

@@ -1,6 +1,8 @@
 package api
 
 import (
+	"fmt"
+	"github.com/ente-io/museum/pkg/controller/file_copy"
 	"net/http"
 	"os"
 	"strconv"
@@ -20,11 +22,13 @@ import (
 
 // FileHandler exposes request handlers for all encrypted file related requests
 type FileHandler struct {
-	Controller *controller.FileController
+	Controller   *controller.FileController
+	FileCopyCtrl *file_copy.FileCopyController
 }
 
 // DefaultMaxBatchSize is the default maximum API batch size unless specified otherwise
 const DefaultMaxBatchSize = 1000
+const DefaultCopyBatchSize = 100
 
 // CreateOrUpdate creates an entry for a file
 func (h *FileHandler) CreateOrUpdate(c *gin.Context) {
@@ -58,6 +62,25 @@ func (h *FileHandler) CreateOrUpdate(c *gin.Context) {
 	c.JSON(http.StatusOK, response)
 }
 
+// CopyFiles copies files that are owned by another user
+func (h *FileHandler) CopyFiles(c *gin.Context) {
+	var req ente.CopyFileSyncRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		handler.Error(c, stacktrace.Propagate(err, ""))
+		return
+	}
+	if len(req.CollectionFileItems) > DefaultCopyBatchSize {
+		handler.Error(c, stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("more than %d items", DefaultCopyBatchSize)), ""))
+		return
+	}
+	response, err := h.FileCopyCtrl.CopyFiles(c, req)
+	if err != nil {
+		handler.Error(c, stacktrace.Propagate(err, ""))
+		return
+	}
+	c.JSON(http.StatusOK, response)
+}
+
 // Update updates already existing file
 func (h *FileHandler) Update(c *gin.Context) {
 	enteApp := auth.GetApp(c)

+ 35 - 0
server/pkg/controller/collection.go

@@ -464,6 +464,41 @@ func (c *CollectionController) isRemoveAllowed(ctx *gin.Context, actorUserID int
 	return nil
 }
 
+func (c *CollectionController) IsCopyAllowed(ctx *gin.Context, actorUserID int64, req ente.CopyFileSyncRequest) error {
+	// verify that srcCollectionID is accessible by actorUserID
+	if _, err := c.AccessCtrl.GetCollection(ctx, &access.GetCollectionParams{
+		CollectionID: req.SrcCollectionID,
+		ActorUserID:  actorUserID,
+	}); err != nil {
+		return stacktrace.Propagate(err, "failed to verify srcCollection access")
+	}
+	// verify that dstCollectionID is owned by actorUserID
+	if _, err := c.AccessCtrl.GetCollection(ctx, &access.GetCollectionParams{
+		CollectionID: req.DstCollection,
+		ActorUserID:  actorUserID,
+		VerifyOwner:  true,
+	}); err != nil {
+		return stacktrace.Propagate(err, "failed to ownership of the dstCollection access")
+	}
+	// verify that all FileIDs exists in the srcCollection
+	fileIDs := make([]int64, len(req.CollectionFileItems))
+	for idx, file := range req.CollectionFileItems {
+		fileIDs[idx] = file.ID
+	}
+	if err := c.CollectionRepo.VerifyAllFileIDsExistsInCollection(ctx, req.SrcCollectionID, fileIDs); err != nil {
+		return stacktrace.Propagate(err, "failed to verify fileIDs in srcCollection")
+	}
+	dsMap, err := c.FileRepo.GetOwnerToFileIDsMap(ctx, fileIDs)
+	if err != nil {
+		return err
+	}
+	// verify that none of the file belongs to actorUserID
+	if _, ok := dsMap[actorUserID]; ok {
+		return ente.NewBadRequestWithMessage("can not copy files owned by actor")
+	}
+	return nil
+}
+
 // GetDiffV2 returns the changes in user's collections since a timestamp, along with hasMore bool flag.
 func (c *CollectionController) GetDiffV2(ctx *gin.Context, cID int64, userID int64, sinceTime int64) ([]ente.File, bool, error) {
 	reqContextLogger := log.WithFields(log.Fields{

+ 206 - 0
server/pkg/controller/file_copy/file_copy.go

@@ -0,0 +1,206 @@
+package file_copy
+
+import (
+	"fmt"
+	"github.com/aws/aws-sdk-go/service/s3"
+	"github.com/ente-io/museum/ente"
+	"github.com/ente-io/museum/pkg/controller"
+	"github.com/ente-io/museum/pkg/repo"
+	"github.com/ente-io/museum/pkg/utils/auth"
+	"github.com/ente-io/museum/pkg/utils/s3config"
+	enteTime "github.com/ente-io/museum/pkg/utils/time"
+	"github.com/gin-contrib/requestid"
+	"github.com/gin-gonic/gin"
+	"github.com/sirupsen/logrus"
+	"golang.org/x/sync/errgroup"
+	"sync"
+	"time"
+)
+
+const ()
+
+type FileCopyController struct {
+	S3Config       *s3config.S3Config
+	FileController *controller.FileController
+	FileRepo       *repo.FileRepository
+	CollectionCtrl *controller.CollectionController
+	ObjectRepo     *repo.ObjectRepository
+}
+
// copyS3ObjectReq pairs the S3 object to copy from with the destination
// object key (a freshly issued upload key) that the copy should be written to.
type copyS3ObjectReq struct {
	SourceS3Object ente.S3ObjectKey
	DestObjectKey  string
}
+
// fileCopyInternal carries everything needed to copy one file: the source
// file's attributes, the destination collection, the re-encrypted file key,
// and the S3 copy requests for the file and its thumbnail.
type fileCopyInternal struct {
	SourceFile       ente.File
	DestCollectionID int64
	// The FileKey is encrypted with the destination collection's key
	EncryptedFileKey      string
	EncryptedFileKeyNonce string
	FileCopyReq           *copyS3ObjectReq
	ThumbCopyReq          *copyS3ObjectReq
}
+
+func (fci fileCopyInternal) newFile(ownedID int64) ente.File {
+	newFileAttributes := fci.SourceFile.File
+	newFileAttributes.ObjectKey = fci.FileCopyReq.DestObjectKey
+	newThumbAttributes := fci.SourceFile.Thumbnail
+	newThumbAttributes.ObjectKey = fci.ThumbCopyReq.DestObjectKey
+	return ente.File{
+		OwnerID:            ownedID,
+		CollectionID:       fci.DestCollectionID,
+		EncryptedKey:       fci.EncryptedFileKey,
+		KeyDecryptionNonce: fci.EncryptedFileKeyNonce,
+		File:               newFileAttributes,
+		Thumbnail:          newThumbAttributes,
+		Metadata:           fci.SourceFile.Metadata,
+		UpdationTime:       enteTime.Microseconds(),
+		IsDeleted:          false,
+	}
+}
+
+func (fc *FileCopyController) CopyFiles(c *gin.Context, req ente.CopyFileSyncRequest) (*ente.CopyResponse, error) {
+	userID := auth.GetUserID(c.Request.Header)
+	app := auth.GetApp(c)
+	logger := logrus.WithFields(logrus.Fields{"req_id": requestid.Get(c), "user_id": userID})
+	err := fc.CollectionCtrl.IsCopyAllowed(c, userID, req)
+	if err != nil {
+		return nil, err
+	}
+	fileIDs := make([]int64, 0, len(req.CollectionFileItems))
+	fileToCollectionFileMap := make(map[int64]*ente.CollectionFileItem, len(req.CollectionFileItems))
+	for i := range req.CollectionFileItems {
+		item := &req.CollectionFileItems[i]
+		fileToCollectionFileMap[item.ID] = item
+		fileIDs = append(fileIDs, item.ID)
+	}
+	s3ObjectsToCopy, err := fc.ObjectRepo.GetObjectsForFileIDs(fileIDs)
+	if err != nil {
+		return nil, err
+	}
+	// note: this assumes that preview existingFilesToCopy for videos are not tracked inside the object_keys table
+	if len(s3ObjectsToCopy) != 2*len(fileIDs) {
+		return nil, ente.NewInternalError(fmt.Sprintf("expected %d objects, got %d", 2*len(fileIDs), len(s3ObjectsToCopy)))
+	}
+	// todo:(neeraj) if the total size is greater than 1GB, do an early check if the user can upload the existingFilesToCopy
+	var totalSize int64
+	for _, obj := range s3ObjectsToCopy {
+		totalSize += obj.FileSize
+	}
+	logger.WithField("totalSize", totalSize).Info("total size of existingFilesToCopy to copy")
+
+	// request the uploadUrls using existing method. This is to ensure that orphan objects are automatically cleaned up
+	// todo:(neeraj) optimize this method by removing the need for getting a signed url for each object
+	uploadUrls, err := fc.FileController.GetUploadURLs(c, userID, len(s3ObjectsToCopy), app)
+	if err != nil {
+		return nil, err
+	}
+	existingFilesToCopy, err := fc.FileRepo.GetFileAttributesForCopy(fileIDs)
+	if err != nil {
+		return nil, err
+	}
+	if len(existingFilesToCopy) != len(fileIDs) {
+		return nil, ente.NewInternalError(fmt.Sprintf("expected %d existingFilesToCopy, got %d", len(fileIDs), len(existingFilesToCopy)))
+	}
+	fileOGS3Object := make(map[int64]*copyS3ObjectReq)
+	fileThumbS3Object := make(map[int64]*copyS3ObjectReq)
+	for i, s3Obj := range s3ObjectsToCopy {
+		if s3Obj.Type == ente.FILE {
+			fileOGS3Object[s3Obj.FileID] = &copyS3ObjectReq{
+				SourceS3Object: s3Obj,
+				DestObjectKey:  uploadUrls[i].ObjectKey,
+			}
+		} else if s3Obj.Type == ente.THUMBNAIL {
+			fileThumbS3Object[s3Obj.FileID] = &copyS3ObjectReq{
+				SourceS3Object: s3Obj,
+				DestObjectKey:  uploadUrls[i].ObjectKey,
+			}
+		} else {
+			return nil, ente.NewInternalError(fmt.Sprintf("unexpected object type %s", s3Obj.Type))
+		}
+	}
+	fileCopyList := make([]fileCopyInternal, 0, len(existingFilesToCopy))
+	for i := range existingFilesToCopy {
+		file := existingFilesToCopy[i]
+		collectionItem := fileToCollectionFileMap[file.ID]
+		if collectionItem.ID != file.ID {
+			return nil, ente.NewInternalError(fmt.Sprintf("expected collectionItem.ID %d, got %d", file.ID, collectionItem.ID))
+		}
+		fileCopy := fileCopyInternal{
+			SourceFile:            file,
+			DestCollectionID:      req.DstCollection,
+			EncryptedFileKey:      fileToCollectionFileMap[file.ID].EncryptedKey,
+			EncryptedFileKeyNonce: fileToCollectionFileMap[file.ID].KeyDecryptionNonce,
+			FileCopyReq:           fileOGS3Object[file.ID],
+			ThumbCopyReq:          fileThumbS3Object[file.ID],
+		}
+		fileCopyList = append(fileCopyList, fileCopy)
+	}
+	oldToNewFileIDMap := make(map[int64]int64)
+	var wg sync.WaitGroup
+	errChan := make(chan error, len(fileCopyList))
+
+	for _, fileCopy := range fileCopyList {
+		wg.Add(1)
+		go func(fileCopy fileCopyInternal) {
+			defer wg.Done()
+			newFile, err := fc.createCopy(c, fileCopy, userID, app)
+			if err != nil {
+				errChan <- err
+				return
+			}
+			oldToNewFileIDMap[fileCopy.SourceFile.ID] = newFile.ID
+		}(fileCopy)
+	}
+
+	// Wait for all goroutines to finish
+	wg.Wait()
+
+	// Close the error channel and check if there were any errors
+	close(errChan)
+	if err, ok := <-errChan; ok {
+		return nil, err
+	}
+	return &ente.CopyResponse{OldToNewFileIDMap: oldToNewFileIDMap}, nil
+}
+
+func (fc *FileCopyController) createCopy(c *gin.Context, fcInternal fileCopyInternal, userID int64, app ente.App) (*ente.File, error) {
+	// using HotS3Client copy the File and Thumbnail
+	s3Client := fc.S3Config.GetHotS3Client()
+	hotBucket := fc.S3Config.GetHotBucket()
+	g := new(errgroup.Group)
+	g.Go(func() error {
+		return copyS3Object(s3Client, hotBucket, fcInternal.FileCopyReq)
+	})
+	g.Go(func() error {
+		return copyS3Object(s3Client, hotBucket, fcInternal.ThumbCopyReq)
+	})
+	if err := g.Wait(); err != nil {
+		return nil, err
+	}
+	file := fcInternal.newFile(userID)
+	newFile, err := fc.FileController.Create(c, userID, file, "", app)
+	if err != nil {
+		return nil, err
+	}
+	return &newFile, nil
+}
+
+// Helper function for S3 object copying.
+func copyS3Object(s3Client *s3.S3, bucket *string, req *copyS3ObjectReq) error {
+	copySource := fmt.Sprintf("%s/%s", *bucket, req.SourceS3Object.ObjectKey)
+	copyInput := &s3.CopyObjectInput{
+		Bucket:     bucket,
+		CopySource: &copySource,
+		Key:        &req.DestObjectKey,
+	}
+	start := time.Now()
+	_, err := s3Client.CopyObject(copyInput)
+	elapsed := time.Since(start)
+	if err != nil {
+		return fmt.Errorf("failed to copy (%s) from %s to %s: %w", req.SourceS3Object.Type, copySource, req.DestObjectKey, err)
+	}
+	logrus.WithField("duration", elapsed).WithField("size", req.SourceS3Object.FileSize).Infof("copied (%s) from %s to %s", req.SourceS3Object.Type, copySource, req.DestObjectKey)
+	return nil
+}

+ 24 - 0
server/pkg/repo/collection.go

@@ -374,6 +374,30 @@ func (repo *CollectionRepository) DoesFileExistInCollections(fileID int64, cIDs
 	return exists, stacktrace.Propagate(err, "")
 }
 
+// VerifyAllFileIDsExistsInCollection returns error if the fileIDs don't exist in the collection
+func (repo *CollectionRepository) VerifyAllFileIDsExistsInCollection(ctx context.Context, cID int64, fileIDs []int64) error {
+	fileIdMap := make(map[int64]bool)
+	rows, err := repo.DB.QueryContext(ctx, `SELECT file_id FROM collection_files WHERE collection_id = $1 AND is_deleted = $2 AND file_id = ANY ($3)`,
+		cID, false, pq.Array(fileIDs))
+	if err != nil {
+		return stacktrace.Propagate(err, "")
+	}
+	for rows.Next() {
+		var fileID int64
+		if err := rows.Scan(&fileID); err != nil {
+			return stacktrace.Propagate(err, "")
+		}
+		fileIdMap[fileID] = true
+	}
+	// find fileIds that are not present in the collection
+	for _, fileID := range fileIDs {
+		if _, ok := fileIdMap[fileID]; !ok {
+			return stacktrace.Propagate(fmt.Errorf("fileID %d not found in collection %d", fileID, cID), "")
+		}
+	}
+	return nil
+}
+
 // GetCollectionShareeRole returns true if the collection is shared with the user
 func (repo *CollectionRepository) GetCollectionShareeRole(cID int64, userID int64) (*ente.CollectionParticipantRole, error) {
 	var role *ente.CollectionParticipantRole

+ 18 - 0
server/pkg/repo/file.go

@@ -612,6 +612,24 @@ func (repo *FileRepository) GetFileAttributesFromObjectKey(objectKey string) (en
 	return file, nil
 }
 
+func (repo *FileRepository) GetFileAttributesForCopy(fileIDs []int64) ([]ente.File, error) {
+	result := make([]ente.File, 0)
+	rows, err := repo.DB.Query(`SELECT file_id, owner_id, file_decryption_header, thumbnail_decryption_header, metadata_decryption_header, encrypted_metadata, pub_magic_metadata FROM files WHERE file_id = ANY($1)`, pq.Array(fileIDs))
+	if err != nil {
+		return nil, stacktrace.Propagate(err, "")
+	}
+	defer rows.Close()
+	for rows.Next() {
+		var file ente.File
+		err := rows.Scan(&file.ID, &file.OwnerID, &file.File.DecryptionHeader, &file.Thumbnail.DecryptionHeader, &file.Metadata.DecryptionHeader, &file.Metadata.EncryptedData, &file.PubicMagicMetadata)
+		if err != nil {
+			return nil, stacktrace.Propagate(err, "")
+		}
+		result = append(result, file)
+	}
+	return result, nil
+}
+
 // GetUsage  gets the Storage usage of a user
 // Deprecated: GetUsage is deprecated, use UsageRepository.GetUsage
 func (repo *FileRepository) GetUsage(userID int64) (int64, error) {

+ 9 - 0
server/pkg/repo/object.go

@@ -44,6 +44,15 @@ func (repo *ObjectRepository) MarkObjectReplicated(objectKey string, datacenter
 	return result.RowsAffected()
 }
 
// GetObjectsForFileIDs returns the non-deleted S3 object keys (all object
// types, e.g. file and thumbnail) for the given fileIDs.
func (repo *ObjectRepository) GetObjectsForFileIDs(fileIDs []int64) ([]ente.S3ObjectKey, error) {
	rows, err := repo.DB.Query(`SELECT file_id, o_type, object_key, size FROM object_keys 
		WHERE file_id = ANY($1) AND is_deleted=false`, pq.Array(fileIDs))
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	return convertRowsToObjectKeys(rows)
}
+
 // GetObject returns the ente.S3ObjectKey key for a file id and type
 func (repo *ObjectRepository) GetObject(fileID int64, objType ente.ObjectType) (ente.S3ObjectKey, error) {
 	// todo: handling of deleted objects