controller.go 13 KB


  1. package embedding
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "github.com/ente-io/museum/pkg/utils/array"
  9. "strconv"
  10. "sync"
  11. gTime "time"
  12. "github.com/aws/aws-sdk-go/aws"
  13. "github.com/aws/aws-sdk-go/service/s3"
  14. "github.com/aws/aws-sdk-go/service/s3/s3manager"
  15. "github.com/ente-io/museum/ente"
  16. "github.com/ente-io/museum/pkg/controller"
  17. "github.com/ente-io/museum/pkg/controller/access"
  18. "github.com/ente-io/museum/pkg/repo"
  19. "github.com/ente-io/museum/pkg/repo/embedding"
  20. "github.com/ente-io/museum/pkg/utils/auth"
  21. "github.com/ente-io/museum/pkg/utils/network"
  22. "github.com/ente-io/museum/pkg/utils/s3config"
  23. "github.com/ente-io/museum/pkg/utils/time"
  24. "github.com/ente-io/stacktrace"
  25. "github.com/gin-gonic/gin"
  26. log "github.com/sirupsen/logrus"
  27. )
  28. const (
  29. // maxEmbeddingDataSize is the min size of an embedding object in bytes
  30. minEmbeddingDataSize = 2048
  31. embeddingFetchTimeout = 15 * gTime.Second
  32. )
  33. type Controller struct {
  34. Repo *embedding.Repository
  35. AccessCtrl access.Controller
  36. ObjectCleanupController *controller.ObjectCleanupController
  37. S3Config *s3config.S3Config
  38. QueueRepo *repo.QueueRepository
  39. TaskLockingRepo *repo.TaskLockRepository
  40. FileRepo *repo.FileRepository
  41. CollectionRepo *repo.CollectionRepository
  42. HostName string
  43. cleanupCronRunning bool
  44. }
  45. func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmbeddingRequest) (*ente.Embedding, error) {
  46. userID := auth.GetUserID(ctx.Request.Header)
  47. err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
  48. ActorUserId: userID,
  49. FileIDs: []int64{req.FileID},
  50. })
  51. if err != nil {
  52. return nil, stacktrace.Propagate(err, "User does not own file")
  53. }
  54. count, err := c.CollectionRepo.GetCollectionCount(req.FileID)
  55. if err != nil {
  56. return nil, stacktrace.Propagate(err, "")
  57. }
  58. if count < 1 {
  59. return nil, stacktrace.Propagate(ente.ErrNotFound, "")
  60. }
  61. version := 1
  62. if req.Version != nil {
  63. version = *req.Version
  64. }
  65. obj := ente.EmbeddingObject{
  66. Version: version,
  67. EncryptedEmbedding: req.EncryptedEmbedding,
  68. DecryptionHeader: req.DecryptionHeader,
  69. Client: network.GetPrettyUA(ctx.GetHeader("User-Agent")) + "/" + ctx.GetHeader("X-Client-Version"),
  70. }
  71. size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model))
  72. if uploadErr != nil {
  73. log.Error(uploadErr)
  74. return nil, stacktrace.Propagate(uploadErr, "")
  75. }
  76. embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version)
  77. embedding.Version = &version
  78. if err != nil {
  79. return nil, stacktrace.Propagate(err, "")
  80. }
  81. return &embedding, nil
  82. }
  83. func (c *Controller) GetDiff(ctx *gin.Context, req ente.GetEmbeddingDiffRequest) ([]ente.Embedding, error) {
  84. userID := auth.GetUserID(ctx.Request.Header)
  85. if req.Model == "" {
  86. req.Model = ente.GgmlClip
  87. }
  88. embeddings, err := c.Repo.GetDiff(ctx, userID, req.Model, *req.SinceTime, req.Limit)
  89. if err != nil {
  90. return nil, stacktrace.Propagate(err, "")
  91. }
  92. // Collect object keys for embeddings with missing data
  93. var objectKeys []string
  94. for i := range embeddings {
  95. if embeddings[i].EncryptedEmbedding == "" {
  96. objectKey := c.getObjectKey(userID, embeddings[i].FileID, embeddings[i].Model)
  97. objectKeys = append(objectKeys, objectKey)
  98. }
  99. }
  100. // Fetch missing embeddings in parallel
  101. if len(objectKeys) > 0 {
  102. embeddingObjects, err := c.getEmbeddingObjectsParallel(objectKeys)
  103. if err != nil {
  104. return nil, stacktrace.Propagate(err, "")
  105. }
  106. // Populate missing data in embeddings from fetched objects
  107. for i, obj := range embeddingObjects {
  108. for j := range embeddings {
  109. if embeddings[j].EncryptedEmbedding == "" && c.getObjectKey(userID, embeddings[j].FileID, embeddings[j].Model) == objectKeys[i] {
  110. embeddings[j].EncryptedEmbedding = obj.EncryptedEmbedding
  111. embeddings[j].DecryptionHeader = obj.DecryptionHeader
  112. }
  113. }
  114. }
  115. }
  116. return embeddings, nil
  117. }
  118. func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbeddingRequest) (*ente.GetFilesEmbeddingResponse, error) {
  119. userID := auth.GetUserID(ctx.Request.Header)
  120. if err := c._validateGetFileEmbeddingsRequest(ctx, userID, req); err != nil {
  121. return nil, stacktrace.Propagate(err, "")
  122. }
  123. userFileEmbeddings, err := c.Repo.GetFilesEmbedding(ctx, userID, req.Model, req.FileIDs)
  124. if err != nil {
  125. return nil, stacktrace.Propagate(err, "")
  126. }
  127. embeddingsWithData := make([]ente.Embedding, 0)
  128. noEmbeddingFileIds := make([]int64, 0)
  129. dbFileIds := make([]int64, 0)
  130. // fileIDs that were indexed but they don't contain any embedding information
  131. for i := range userFileEmbeddings {
  132. dbFileIds = append(dbFileIds, userFileEmbeddings[i].FileID)
  133. if userFileEmbeddings[i].Size != nil && *userFileEmbeddings[i].Size < minEmbeddingDataSize {
  134. noEmbeddingFileIds = append(noEmbeddingFileIds, userFileEmbeddings[i].FileID)
  135. } else {
  136. embeddingsWithData = append(embeddingsWithData, userFileEmbeddings[i])
  137. }
  138. }
  139. pendingIndexFileIds := array.FindMissingElementsInSecondList(req.FileIDs, dbFileIds)
  140. errFileIds := make([]int64, 0)
  141. // Fetch missing userFileEmbeddings in parallel
  142. embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData)
  143. if err != nil {
  144. return nil, stacktrace.Propagate(err, "")
  145. }
  146. fetchedEmbeddings := make([]ente.Embedding, 0)
  147. // Populate missing data in userFileEmbeddings from fetched objects
  148. for _, obj := range embeddingObjects {
  149. if obj.err != nil {
  150. errFileIds = append(errFileIds, obj.dbEmbeddingRow.FileID)
  151. } else {
  152. fetchedEmbeddings = append(fetchedEmbeddings, ente.Embedding{
  153. FileID: obj.dbEmbeddingRow.FileID,
  154. Model: obj.dbEmbeddingRow.Model,
  155. EncryptedEmbedding: obj.embeddingObject.EncryptedEmbedding,
  156. DecryptionHeader: obj.embeddingObject.DecryptionHeader,
  157. UpdatedAt: obj.dbEmbeddingRow.UpdatedAt,
  158. Version: obj.dbEmbeddingRow.Version,
  159. })
  160. }
  161. }
  162. return &ente.GetFilesEmbeddingResponse{
  163. Embeddings: fetchedEmbeddings,
  164. PendingIndexFileIDs: pendingIndexFileIds,
  165. ErrFileIDs: errFileIds,
  166. NoEmbeddingFileIDs: noEmbeddingFileIds,
  167. }, nil
  168. }
  169. func (c *Controller) DeleteAll(ctx *gin.Context) error {
  170. userID := auth.GetUserID(ctx.Request.Header)
  171. err := c.Repo.DeleteAll(ctx, userID)
  172. if err != nil {
  173. return stacktrace.Propagate(err, "")
  174. }
  175. return nil
  176. }
  // embeddingCleanupCronMu guards every read and write of
// Controller.cleanupCronRunning. Overlapping cleanup runs are expected (that is
// what the skip guard is for). If they execute on different goroutines, an
// unsynchronized check-then-set of the flag is a data race and can let two runs
// through at once.
var embeddingCleanupCronMu sync.Mutex

// CleanupDeletedEmbeddings clears all embeddings for deleted files from the object store.
//
// Each run takes up to 200 items that are ready for deletion from
// repo.DeleteEmbeddingsQueue and processes them one by one. If a previous run
// on this Controller is still in progress, the call returns immediately.
func (c *Controller) CleanupDeletedEmbeddings() {
	log.Info("Cleaning up deleted embeddings")
	embeddingCleanupCronMu.Lock()
	alreadyRunning := c.cleanupCronRunning
	if !alreadyRunning {
		c.cleanupCronRunning = true
	}
	embeddingCleanupCronMu.Unlock()
	if alreadyRunning {
		log.Info("Skipping CleanupDeletedEmbeddings cron run as another instance is still running")
		return
	}
	defer func() {
		embeddingCleanupCronMu.Lock()
		c.cleanupCronRunning = false
		embeddingCleanupCronMu.Unlock()
	}()
	items, err := c.QueueRepo.GetItemsReadyForDeletion(repo.DeleteEmbeddingsQueue, 200)
	if err != nil {
		log.WithError(err).Error("Failed to fetch items from queue")
		return
	}
	for _, i := range items {
		c.deleteEmbedding(i)
	}
}
  // deleteEmbedding removes everything stored for the file referenced by one
// DeleteEmbeddingsQueue item, in this order:
//  1. all objects under the owner's "<ownerID>/ml-data/<fileID>/" prefix in the hot data center,
//  2. the file's embedding rows in the DB,
//  3. the queue item itself.
//
// Processing is guarded by a task lock named after the item and acquired with
// this host's name, so concurrent workers do not process the same item. On any
// failure the function logs and returns early, leaving the queue item in place;
// presumably a later cleanup run will pick it up again.
func (c *Controller) deleteEmbedding(qItem repo.QueueItem) {
	lockName := fmt.Sprintf("Embedding:%s", qItem.Item)
	ctxLogger := log.WithField("item", qItem.Item).WithField("queue_id", qItem.Id)
	lockStatus, err := c.TaskLockingRepo.AcquireLock(lockName, time.MicrosecondsAfterHours(1), c.HostName)
	if err != nil {
		ctxLogger.WithError(err).Warn("unable to acquire lock")
		return
	}
	if !lockStatus {
		ctxLogger.Warn("unable to acquire lock")
		return
	}
	defer func() {
		// Use a dedicated variable so the deferred release never overwrites err.
		if releaseErr := c.TaskLockingRepo.ReleaseLock(lockName); releaseErr != nil {
			ctxLogger.Errorf("Error while releasing lock %s", releaseErr)
		}
	}()
	ctxLogger.Info("Deleting all embeddings")
	fileID, err := strconv.ParseInt(qItem.Item, 10, 64)
	if err != nil {
		// Report a malformed item instead of continuing with fileID 0.
		ctxLogger.WithError(err).Error("Failed to parse fileID from queue item")
		return
	}
	ownerID, err := c.FileRepo.GetOwnerID(fileID)
	if err != nil {
		ctxLogger.WithError(err).Error("Failed to fetch ownerID")
		return
	}
	prefix := c.getEmbeddingObjectPrefix(ownerID, fileID)
	err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetHotDataCenter())
	if err != nil {
		ctxLogger.WithError(err).Error("Failed to delete all objects")
		return
	}
	err = c.Repo.Delete(fileID)
	if err != nil {
		ctxLogger.WithError(err).Error("Failed to remove from db")
		return
	}
	err = c.QueueRepo.DeleteItem(repo.DeleteEmbeddingsQueue, qItem.Item)
	if err != nil {
		ctxLogger.WithError(err).Error("Failed to remove item from the queue")
		return
	}
	ctxLogger.Info("Successfully deleted all embeddings")
}
  236. func (c *Controller) getObjectKey(userID int64, fileID int64, model string) string {
  237. return c.getEmbeddingObjectPrefix(userID, fileID) + model + ".json"
  238. }
  // getEmbeddingObjectPrefix returns "<userID>/ml-data/<fileID>/", the key prefix
// under which every embedding object of a file is stored.
func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string {
	buf := make([]byte, 0, 48)
	buf = strconv.AppendInt(buf, userID, 10)
	buf = append(buf, "/ml-data/"...)
	buf = strconv.AppendInt(buf, fileID, 10)
	buf = append(buf, '/')
	return string(buf)
}
  // uploadObject marshals obj to JSON, uploads it to the hot bucket under key,
// and returns the size in bytes of the uploaded JSON payload.
//
// On failure (marshal or upload) it logs the error and returns -1 with the
// propagated error.
func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) (int, error) {
	embeddingObj, err := json.Marshal(obj)
	if err != nil {
		// Previously ignored: uploading a nil body here would have "succeeded"
		// with an empty object.
		log.Error(err)
		return -1, stacktrace.Propagate(err, "failed to marshal embedding object")
	}
	uploader := s3manager.NewUploaderWithClient(c.S3Config.GetHotS3Client())
	up := s3manager.UploadInput{
		Bucket: c.S3Config.GetHotBucket(),
		Key:    &key,
		Body:   bytes.NewReader(embeddingObj),
	}
	result, err := uploader.Upload(&up)
	if err != nil {
		log.Error(err)
		return -1, stacktrace.Propagate(err, "")
	}
	log.Infof("Uploaded to bucket %s", result.Location)
	return len(embeddingObj), nil
}
  // Process-wide concurrency caps for object-store fetches, shared by all
// requests. Sending a value acquires a slot; receiving one releases it.
var (
	// globalDiffFetchSemaphore bounds concurrent fetches in getEmbeddingObjectsParallel.
	globalDiffFetchSemaphore = make(chan struct{}, 300)
	// globalFileFetchSemaphore bounds concurrent fetches in getEmbeddingObjectsParallelV2.
	globalFileFetchSemaphore = make(chan struct{}, 400)
)
  // getEmbeddingObjectsParallel downloads the embedding objects at objectKeys
// concurrently and returns them in the same order as objectKeys.
//
// Concurrency is bounded by globalDiffFetchSemaphore, and each object is
// fetched via getEmbeddingObject, which retries downloads. If any object fails
// to download or decode, the failure is logged and an error is returned with
// no objects.
func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.EmbeddingObject, error) {
	var wg sync.WaitGroup
	embeddingObjects := make([]ente.EmbeddingObject, len(objectKeys))
	// One error slot per key: each goroutine writes only its own index, so no
	// lock is needed. Appending to a shared slice from goroutines is a data race.
	fetchErrs := make([]error, len(objectKeys))
	downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
	for i, objectKey := range objectKeys {
		wg.Add(1)
		globalDiffFetchSemaphore <- struct{}{} // Acquire from global semaphore
		go func(i int, objectKey string) {
			defer wg.Done()
			defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
			obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
			if err != nil {
				fetchErrs[i] = err
				log.Error("error fetching embedding object: "+objectKey, err)
				return
			}
			embeddingObjects[i] = obj
		}(i, objectKey)
	}
	wg.Wait()
	for _, err := range fetchErrs {
		if err != nil {
			return nil, stacktrace.Propagate(errors.New("failed to fetch some objects"), "")
		}
	}
	return embeddingObjects, nil
}
  287. type embeddingObjectResult struct {
  288. embeddingObject ente.EmbeddingObject
  289. dbEmbeddingRow ente.Embedding
  290. err error
  291. }
  // getEmbeddingObjectsParallelV2 fetches the object-store payload for each of
// dbEmbeddingRows concurrently, bounded by globalFileFetchSemaphore. Each fetch
// is a single attempt (no retries) under an embeddingFetchTimeout deadline.
//
// The result has one entry per input row, in input order. Per-row failures are
// recorded in that entry's err field and logged; the returned error is
// always nil.
func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) {
	results := make([]embeddingObjectResult, len(dbEmbeddingRows))
	downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
	var wg sync.WaitGroup
	for idx := range dbEmbeddingRows {
		row := dbEmbeddingRows[idx]
		wg.Add(1)
		globalFileFetchSemaphore <- struct{}{} // take a slot before spawning
		go func(slot int, row ente.Embedding) {
			defer wg.Done()
			defer func() { <-globalFileFetchSemaphore }() // give the slot back
			key := c.getObjectKey(userID, row.FileID, row.Model)
			fetchCtx, cancel := context.WithTimeout(context.Background(), embeddingFetchTimeout)
			defer cancel()
			res := embeddingObjectResult{dbEmbeddingRow: row}
			if obj, err := c.getEmbeddingObjectWithRetries(fetchCtx, key, downloader, 0); err != nil {
				log.Error("error fetching embedding object: "+key, err)
				res.err = err
			} else {
				res.embeddingObject = obj
			}
			results[slot] = res
		}(idx, row)
	}
	wg.Wait()
	return results, nil
}
  323. func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
  324. return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 3)
  325. }
  // getEmbeddingObjectWithRetries downloads the object at objectKey from the hot
// bucket and JSON-decodes it into an ente.EmbeddingObject.
//
// Retry behavior:
//   - A failed download is retried up to retryCount more times; retryCount <= 0
//     means a single attempt.
//   - Retries stop early once ctx is done, because further attempts with a
//     cancelled or expired context cannot succeed.
//   - Decode errors are not retried.
//
// Every failure is logged. On error the returned object is the zero value.
func (c *Controller) getEmbeddingObjectWithRetries(ctx context.Context, objectKey string, downloader *s3manager.Downloader, retryCount int) (ente.EmbeddingObject, error) {
	var buff *aws.WriteAtBuffer
	for attempt := 0; ; attempt++ {
		// Fresh buffer per attempt so bytes from a failed attempt are never reused.
		buff = &aws.WriteAtBuffer{}
		_, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
			Bucket: c.S3Config.GetHotBucket(),
			Key:    &objectKey,
		})
		if err == nil {
			break
		}
		log.Error(err)
		if attempt >= retryCount || ctx.Err() != nil {
			return ente.EmbeddingObject{}, stacktrace.Propagate(err, "")
		}
	}
	var obj ente.EmbeddingObject
	if err := json.Unmarshal(buff.Bytes(), &obj); err != nil {
		log.Error(err)
		return ente.EmbeddingObject{}, stacktrace.Propagate(err, "")
	}
	return obj, nil
}
  347. func (c *Controller) _validateGetFileEmbeddingsRequest(ctx *gin.Context, userID int64, req ente.GetFilesEmbeddingRequest) error {
  348. if req.Model == "" {
  349. return ente.NewBadRequestWithMessage("model is required")
  350. }
  351. if len(req.FileIDs) == 0 {
  352. return ente.NewBadRequestWithMessage("fileIDs are required")
  353. }
  354. if len(req.FileIDs) > 200 {
  355. return ente.NewBadRequestWithMessage("fileIDs should be less than or equal to 200")
  356. }
  357. if err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
  358. ActorUserId: userID,
  359. FileIDs: req.FileIDs,
  360. }); err != nil {
  361. return stacktrace.Propagate(err, "User does not own some file(s)")
  362. }
  363. return nil
  364. }