Prechádzať zdrojové kódy

[server] Increase embedding fetch limit (#1300)

## Description

Also use different semaphore than existing diff API

## Tests
Neeraj Gupta 1 rok pred
rodič
commit
2fe703df92

+ 9 - 7
server/pkg/controller/embedding/controller.go

@@ -275,7 +275,9 @@ func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) (int, er
 	return len(embeddingObj), nil
 }
 
-var globalFetchSemaphore = make(chan struct{}, 300)
+var globalDiffFetchSemaphore = make(chan struct{}, 300)
+
+var globalFileFetchSemaphore = make(chan struct{}, 400)
 
 func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.EmbeddingObject, error) {
 	var wg sync.WaitGroup
@@ -285,10 +287,10 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
 
 	for i, objectKey := range objectKeys {
 		wg.Add(1)
-		globalFetchSemaphore <- struct{}{} // Acquire from global semaphore
+		globalDiffFetchSemaphore <- struct{}{} // Acquire from global semaphore
 		go func(i int, objectKey string) {
 			defer wg.Done()
-			defer func() { <-globalFetchSemaphore }() // Release back to global semaphore
+			defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
 
 			obj, err := c.getEmbeddingObject(objectKey, downloader)
 			if err != nil {
@@ -322,10 +324,10 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
 
 	for i, dbEmbeddingRow := range dbEmbeddingRows {
 		wg.Add(1)
-		globalFetchSemaphore <- struct{}{} // Acquire from global semaphore
+		globalFileFetchSemaphore <- struct{}{} // Acquire from global semaphore
 		go func(i int, dbEmbeddingRow ente.Embedding) {
 			defer wg.Done()
-			defer func() { <-globalFetchSemaphore }() // Release back to global semaphore
+			defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
 			objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
 			obj, err := c.getEmbeddingObject(objectKey, downloader)
 			if err != nil {
@@ -373,8 +375,8 @@ func (c *Controller) _validateGetFileEmbeddingsRequest(ctx *gin.Context, userID
 	if len(req.FileIDs) == 0 {
 		return ente.NewBadRequestWithMessage("fileIDs are required")
 	}
-	if len(req.FileIDs) > 100 {
-		return ente.NewBadRequestWithMessage("fileIDs should be less than or equal to 100")
+	if len(req.FileIDs) > 200 {
+		return ente.NewBadRequestWithMessage("fileIDs should be less than or equal to 200")
 	}
 	if err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
 		ActorUserId: userID,