controller.go 17 KB


  1. package embedding
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "github.com/aws/aws-sdk-go/aws/awserr"
  9. "github.com/ente-io/museum/pkg/utils/array"
  10. "strconv"
  11. "sync"
  12. gTime "time"
  13. "github.com/aws/aws-sdk-go/aws"
  14. "github.com/aws/aws-sdk-go/service/s3"
  15. "github.com/aws/aws-sdk-go/service/s3/s3manager"
  16. "github.com/ente-io/museum/ente"
  17. "github.com/ente-io/museum/pkg/controller"
  18. "github.com/ente-io/museum/pkg/controller/access"
  19. "github.com/ente-io/museum/pkg/repo"
  20. "github.com/ente-io/museum/pkg/repo/embedding"
  21. "github.com/ente-io/museum/pkg/utils/auth"
  22. "github.com/ente-io/museum/pkg/utils/network"
  23. "github.com/ente-io/museum/pkg/utils/s3config"
  24. "github.com/ente-io/museum/pkg/utils/time"
  25. "github.com/ente-io/stacktrace"
  26. "github.com/gin-gonic/gin"
  27. log "github.com/sirupsen/logrus"
  28. )
  const (
	// minEmbeddingDataSize is the minimum size, in bytes, of a serialized
	// embedding object that actually carries embedding data. GetFilesEmbedding
	// reports rows whose stored size is below this as "indexed, but no
	// embedding" instead of fetching their objects.
	minEmbeddingDataSize = 2048
	// embeddingFetchTimeout is the default per-attempt timeout used by
	// getEmbeddingObject when downloading an embedding object.
	embeddingFetchTimeout = 15 * gTime.Second
)
  34. type Controller struct {
  35. Repo *embedding.Repository
  36. AccessCtrl access.Controller
  37. ObjectCleanupController *controller.ObjectCleanupController
  38. S3Config *s3config.S3Config
  39. QueueRepo *repo.QueueRepository
  40. TaskLockingRepo *repo.TaskLockRepository
  41. FileRepo *repo.FileRepository
  42. CollectionRepo *repo.CollectionRepository
  43. HostName string
  44. cleanupCronRunning bool
  45. embeddingS3Client *s3.S3
  46. embeddingBucket *string
  47. areEmbeddingAndHotBucketSame bool
  48. }
  49. func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller {
  50. return &Controller{
  51. Repo: repo,
  52. AccessCtrl: accessCtrl,
  53. ObjectCleanupController: objectCleanupController,
  54. S3Config: s3Config,
  55. QueueRepo: queueRepo,
  56. TaskLockingRepo: taskLockingRepo,
  57. FileRepo: fileRepo,
  58. CollectionRepo: collectionRepo,
  59. HostName: hostName,
  60. embeddingS3Client: s3Config.GetEmbeddingsS3Client(),
  61. embeddingBucket: s3Config.GetEmbeddingsBucket(),
  62. areEmbeddingAndHotBucketSame: s3Config.GetEmbeddingsBucket() == s3Config.GetHotBucket(),
  63. }
  64. }
  65. func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmbeddingRequest) (*ente.Embedding, error) {
  66. userID := auth.GetUserID(ctx.Request.Header)
  67. err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
  68. ActorUserId: userID,
  69. FileIDs: []int64{req.FileID},
  70. })
  71. if err != nil {
  72. return nil, stacktrace.Propagate(err, "User does not own file")
  73. }
  74. count, err := c.CollectionRepo.GetCollectionCount(req.FileID)
  75. if err != nil {
  76. return nil, stacktrace.Propagate(err, "")
  77. }
  78. if count < 1 {
  79. return nil, stacktrace.Propagate(ente.ErrNotFound, "")
  80. }
  81. version := 1
  82. if req.Version != nil {
  83. version = *req.Version
  84. }
  85. obj := ente.EmbeddingObject{
  86. Version: version,
  87. EncryptedEmbedding: req.EncryptedEmbedding,
  88. DecryptionHeader: req.DecryptionHeader,
  89. Client: network.GetPrettyUA(ctx.GetHeader("User-Agent")) + "/" + ctx.GetHeader("X-Client-Version"),
  90. }
  91. size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model))
  92. if uploadErr != nil {
  93. log.Error(uploadErr)
  94. return nil, stacktrace.Propagate(uploadErr, "")
  95. }
  96. embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version)
  97. embedding.Version = &version
  98. if err != nil {
  99. return nil, stacktrace.Propagate(err, "")
  100. }
  101. return &embedding, nil
  102. }
  103. func (c *Controller) GetDiff(ctx *gin.Context, req ente.GetEmbeddingDiffRequest) ([]ente.Embedding, error) {
  104. userID := auth.GetUserID(ctx.Request.Header)
  105. if req.Model == "" {
  106. req.Model = ente.GgmlClip
  107. }
  108. embeddings, err := c.Repo.GetDiff(ctx, userID, req.Model, *req.SinceTime, req.Limit)
  109. if err != nil {
  110. return nil, stacktrace.Propagate(err, "")
  111. }
  112. // Collect object keys for embeddings with missing data
  113. var objectKeys []string
  114. for i := range embeddings {
  115. if embeddings[i].EncryptedEmbedding == "" {
  116. objectKey := c.getObjectKey(userID, embeddings[i].FileID, embeddings[i].Model)
  117. objectKeys = append(objectKeys, objectKey)
  118. }
  119. }
  120. // Fetch missing embeddings in parallel
  121. if len(objectKeys) > 0 {
  122. embeddingObjects, err := c.getEmbeddingObjectsParallel(objectKeys)
  123. if err != nil {
  124. return nil, stacktrace.Propagate(err, "")
  125. }
  126. // Populate missing data in embeddings from fetched objects
  127. for i, obj := range embeddingObjects {
  128. for j := range embeddings {
  129. if embeddings[j].EncryptedEmbedding == "" && c.getObjectKey(userID, embeddings[j].FileID, embeddings[j].Model) == objectKeys[i] {
  130. embeddings[j].EncryptedEmbedding = obj.EncryptedEmbedding
  131. embeddings[j].DecryptionHeader = obj.DecryptionHeader
  132. }
  133. }
  134. }
  135. }
  136. return embeddings, nil
  137. }
  138. func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbeddingRequest) (*ente.GetFilesEmbeddingResponse, error) {
  139. userID := auth.GetUserID(ctx.Request.Header)
  140. if err := c._validateGetFileEmbeddingsRequest(ctx, userID, req); err != nil {
  141. return nil, stacktrace.Propagate(err, "")
  142. }
  143. userFileEmbeddings, err := c.Repo.GetFilesEmbedding(ctx, userID, req.Model, req.FileIDs)
  144. if err != nil {
  145. return nil, stacktrace.Propagate(err, "")
  146. }
  147. embeddingsWithData := make([]ente.Embedding, 0)
  148. noEmbeddingFileIds := make([]int64, 0)
  149. dbFileIds := make([]int64, 0)
  150. // fileIDs that were indexed, but they don't contain any embedding information
  151. for i := range userFileEmbeddings {
  152. dbFileIds = append(dbFileIds, userFileEmbeddings[i].FileID)
  153. if userFileEmbeddings[i].Size != nil && *userFileEmbeddings[i].Size < minEmbeddingDataSize {
  154. noEmbeddingFileIds = append(noEmbeddingFileIds, userFileEmbeddings[i].FileID)
  155. } else {
  156. embeddingsWithData = append(embeddingsWithData, userFileEmbeddings[i])
  157. }
  158. }
  159. pendingIndexFileIds := array.FindMissingElementsInSecondList(req.FileIDs, dbFileIds)
  160. errFileIds := make([]int64, 0)
  161. // Fetch missing userFileEmbeddings in parallel
  162. embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData)
  163. if err != nil {
  164. return nil, stacktrace.Propagate(err, "")
  165. }
  166. fetchedEmbeddings := make([]ente.Embedding, 0)
  167. // Populate missing data in userFileEmbeddings from fetched objects
  168. for _, obj := range embeddingObjects {
  169. if obj.err != nil {
  170. errFileIds = append(errFileIds, obj.dbEmbeddingRow.FileID)
  171. } else {
  172. fetchedEmbeddings = append(fetchedEmbeddings, ente.Embedding{
  173. FileID: obj.dbEmbeddingRow.FileID,
  174. Model: obj.dbEmbeddingRow.Model,
  175. EncryptedEmbedding: obj.embeddingObject.EncryptedEmbedding,
  176. DecryptionHeader: obj.embeddingObject.DecryptionHeader,
  177. UpdatedAt: obj.dbEmbeddingRow.UpdatedAt,
  178. Version: obj.dbEmbeddingRow.Version,
  179. })
  180. }
  181. }
  182. return &ente.GetFilesEmbeddingResponse{
  183. Embeddings: fetchedEmbeddings,
  184. PendingIndexFileIDs: pendingIndexFileIds,
  185. ErrFileIDs: errFileIds,
  186. NoEmbeddingFileIDs: noEmbeddingFileIds,
  187. }, nil
  188. }
  189. func (c *Controller) DeleteAll(ctx *gin.Context) error {
  190. userID := auth.GetUserID(ctx.Request.Header)
  191. err := c.Repo.DeleteAll(ctx, userID)
  192. if err != nil {
  193. return stacktrace.Propagate(err, "")
  194. }
  195. return nil
  196. }
  197. // CleanupDeletedEmbeddings clears all embeddings for deleted files from the object store
  198. func (c *Controller) CleanupDeletedEmbeddings() {
  199. log.Info("Cleaning up deleted embeddings")
  200. if c.cleanupCronRunning {
  201. log.Info("Skipping CleanupDeletedEmbeddings cron run as another instance is still running")
  202. return
  203. }
  204. c.cleanupCronRunning = true
  205. defer func() {
  206. c.cleanupCronRunning = false
  207. }()
  208. items, err := c.QueueRepo.GetItemsReadyForDeletion(repo.DeleteEmbeddingsQueue, 200)
  209. if err != nil {
  210. log.WithError(err).Error("Failed to fetch items from queue")
  211. return
  212. }
  213. for _, i := range items {
  214. c.deleteEmbedding(i)
  215. }
  216. }
  217. func (c *Controller) deleteEmbedding(qItem repo.QueueItem) {
  218. lockName := fmt.Sprintf("Embedding:%s", qItem.Item)
  219. lockStatus, err := c.TaskLockingRepo.AcquireLock(lockName, time.MicrosecondsAfterHours(1), c.HostName)
  220. ctxLogger := log.WithField("item", qItem.Item).WithField("queue_id", qItem.Id)
  221. if err != nil || !lockStatus {
  222. ctxLogger.Warn("unable to acquire lock")
  223. return
  224. }
  225. defer func() {
  226. err = c.TaskLockingRepo.ReleaseLock(lockName)
  227. if err != nil {
  228. ctxLogger.Errorf("Error while releasing lock %s", err)
  229. }
  230. }()
  231. ctxLogger.Info("Deleting all embeddings")
  232. fileID, _ := strconv.ParseInt(qItem.Item, 10, 64)
  233. ownerID, err := c.FileRepo.GetOwnerID(fileID)
  234. if err != nil {
  235. ctxLogger.WithError(err).Error("Failed to fetch ownerID")
  236. return
  237. }
  238. prefix := c.getEmbeddingObjectPrefix(ownerID, fileID)
  239. err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetEmbeddingsDataCenter())
  240. if err != nil {
  241. ctxLogger.WithError(err).Error("Failed to delete all objects")
  242. return
  243. }
  244. // if Embeddings DC is different from hot DC, delete from hot DC as well
  245. if !c.areEmbeddingAndHotBucketSame {
  246. err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetHotDataCenter())
  247. if err != nil {
  248. ctxLogger.WithError(err).Error("Failed to delete all objects from hot DC")
  249. return
  250. }
  251. }
  252. err = c.Repo.Delete(fileID)
  253. if err != nil {
  254. ctxLogger.WithError(err).Error("Failed to remove from db")
  255. return
  256. }
  257. err = c.QueueRepo.DeleteItem(repo.DeleteEmbeddingsQueue, qItem.Item)
  258. if err != nil {
  259. ctxLogger.WithError(err).Error("Failed to remove item from the queue")
  260. return
  261. }
  262. ctxLogger.Info("Successfully deleted all embeddings")
  263. }
  264. func (c *Controller) getObjectKey(userID int64, fileID int64, model string) string {
  265. return c.getEmbeddingObjectPrefix(userID, fileID) + model + ".json"
  266. }
  267. func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string {
  268. return strconv.FormatInt(userID, 10) + "/ml-data/" + strconv.FormatInt(fileID, 10) + "/"
  269. }
  270. // uploadObject uploads the embedding object to the object store and returns the object size
  271. func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) (int, error) {
  272. embeddingObj, _ := json.Marshal(obj)
  273. uploader := s3manager.NewUploaderWithClient(c.embeddingS3Client)
  274. up := s3manager.UploadInput{
  275. Bucket: c.embeddingBucket,
  276. Key: &key,
  277. Body: bytes.NewReader(embeddingObj),
  278. }
  279. result, err := uploader.Upload(&up)
  280. if err != nil {
  281. log.Error(err)
  282. return -1, stacktrace.Propagate(err, "")
  283. }
  284. log.Infof("Uploaded to bucket %s", result.Location)
  285. return len(embeddingObj), nil
  286. }
  // Process-wide limits on concurrent embedding-object downloads, shared by all
// requests. A slot is taken by sending into the channel and released by
// receiving from it.
var (
	// globalDiffFetchSemaphore bounds getEmbeddingObjectsParallel.
	globalDiffFetchSemaphore = make(chan struct{}, 300)
	// globalFileFetchSemaphore bounds getEmbeddingObjectsParallelV2.
	globalFileFetchSemaphore = make(chan struct{}, 400)
)
  289. func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.EmbeddingObject, error) {
  290. var wg sync.WaitGroup
  291. var errs []error
  292. embeddingObjects := make([]ente.EmbeddingObject, len(objectKeys))
  293. downloader := s3manager.NewDownloaderWithClient(c.embeddingS3Client)
  294. for i, objectKey := range objectKeys {
  295. wg.Add(1)
  296. globalDiffFetchSemaphore <- struct{}{} // Acquire from global semaphore
  297. go func(i int, objectKey string) {
  298. defer wg.Done()
  299. defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
  300. obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
  301. if err != nil {
  302. errs = append(errs, err)
  303. log.Error("error fetching embedding object: "+objectKey, err)
  304. } else {
  305. embeddingObjects[i] = obj
  306. }
  307. }(i, objectKey)
  308. }
  309. wg.Wait()
  310. if len(errs) > 0 {
  311. return nil, stacktrace.Propagate(errors.New("failed to fetch some objects"), "")
  312. }
  313. return embeddingObjects, nil
  314. }
  315. type embeddingObjectResult struct {
  316. embeddingObject ente.EmbeddingObject
  317. dbEmbeddingRow ente.Embedding
  318. err error
  319. }
  320. func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) {
  321. var wg sync.WaitGroup
  322. embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows))
  323. downloader := s3manager.NewDownloaderWithClient(c.embeddingS3Client)
  324. for i, dbEmbeddingRow := range dbEmbeddingRows {
  325. wg.Add(1)
  326. globalFileFetchSemaphore <- struct{}{} // Acquire from global semaphore
  327. go func(i int, dbEmbeddingRow ente.Embedding) {
  328. defer wg.Done()
  329. defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
  330. objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
  331. obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
  332. if err != nil {
  333. log.Error("error fetching embedding object: "+objectKey, err)
  334. embeddingObjects[i] = embeddingObjectResult{
  335. err: err,
  336. dbEmbeddingRow: dbEmbeddingRow,
  337. }
  338. } else {
  339. embeddingObjects[i] = embeddingObjectResult{
  340. embeddingObject: obj,
  341. dbEmbeddingRow: dbEmbeddingRow,
  342. }
  343. }
  344. }(i, dbEmbeddingRow)
  345. }
  346. wg.Wait()
  347. return embeddingObjects, nil
  348. }
  // getOptions tunes getEmbeddingObject. Passing a nil *getOptions there
// selects 3 retries and embeddingFetchTimeout per attempt.
type getOptions struct {
	// RetryCount is the number of attempts made after the first one.
	RetryCount int
	// FetchTimeOut bounds each individual download attempt.
	FetchTimeOut gTime.Duration
}
  353. func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, opt *getOptions) (ente.EmbeddingObject, error) {
  354. if opt == nil {
  355. opt = &getOptions{
  356. RetryCount: 3,
  357. FetchTimeOut: embeddingFetchTimeout,
  358. }
  359. }
  360. ctxLogger := log.WithField("objectKey", objectKey)
  361. totalAttempts := opt.RetryCount + 1
  362. for i := 0; i < totalAttempts; i++ {
  363. // Create a new context with a timeout for each fetch
  364. fetchCtx, cancel := context.WithTimeout(ctx, opt.FetchTimeOut)
  365. select {
  366. case <-ctx.Done():
  367. cancel()
  368. return ente.EmbeddingObject{}, stacktrace.Propagate(ctx.Err(), "")
  369. default:
  370. obj, err := c.downloadObject(fetchCtx, objectKey, downloader, c.embeddingBucket)
  371. cancel() // Ensure cancel is called to release resources
  372. if err == nil {
  373. if i > 0 {
  374. ctxLogger.Infof("Fetched object after %d attempts", i)
  375. }
  376. return obj, nil
  377. }
  378. // Check if the error is due to context timeout or cancellation
  379. if err == nil && fetchCtx.Err() != nil {
  380. ctxLogger.Error("Fetch timed out or cancelled: ", fetchCtx.Err())
  381. } else {
  382. // check if the error is due to object not found
  383. if s3Err, ok := err.(awserr.RequestFailure); ok {
  384. if s3Err.Code() == s3.ErrCodeNoSuchKey {
  385. if c.areEmbeddingAndHotBucketSame {
  386. ctxLogger.Error("Object not found: ", s3Err)
  387. } else {
  388. // If embedding and hot bucket are different, try to copy from hot bucket
  389. copyEmbeddingObject, err := c.copyEmbeddingObject(ctx, objectKey)
  390. if err == nil {
  391. ctxLogger.Info("Got the object from hot bucket object")
  392. return *copyEmbeddingObject, nil
  393. } else {
  394. ctxLogger.WithError(err).Error("Failed to copy from hot bucket object")
  395. }
  396. return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
  397. }
  398. }
  399. }
  400. ctxLogger.Error("Failed to fetch object: ", err)
  401. }
  402. }
  403. }
  404. return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "")
  405. }
  406. func (c *Controller) downloadObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, bucket *string) (ente.EmbeddingObject, error) {
  407. var obj ente.EmbeddingObject
  408. buff := &aws.WriteAtBuffer{}
  409. _, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
  410. Bucket: bucket,
  411. Key: &objectKey,
  412. })
  413. if err != nil {
  414. return obj, err
  415. }
  416. err = json.Unmarshal(buff.Bytes(), &obj)
  417. if err != nil {
  418. return obj, stacktrace.Propagate(err, "unmarshal failed")
  419. }
  420. return obj, nil
  421. }
  422. // download the embedding object from hot bucket and upload to embeddings bucket
  423. func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) (*ente.EmbeddingObject, error) {
  424. if c.embeddingBucket == c.S3Config.GetHotBucket() {
  425. return nil, stacktrace.Propagate(errors.New("embedding bucket and hot bucket are same"), "")
  426. }
  427. downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
  428. obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotBucket())
  429. if err != nil {
  430. return nil, stacktrace.Propagate(err, "failed to download from hot bucket")
  431. }
  432. go func() {
  433. _, err = c.uploadObject(obj, objectKey)
  434. if err != nil {
  435. log.WithField("object", objectKey).Error("Failed to copy to embeddings bucket: ", err)
  436. }
  437. }()
  438. return &obj, nil
  439. }
  440. func (c *Controller) _validateGetFileEmbeddingsRequest(ctx *gin.Context, userID int64, req ente.GetFilesEmbeddingRequest) error {
  441. if req.Model == "" {
  442. return ente.NewBadRequestWithMessage("model is required")
  443. }
  444. if len(req.FileIDs) == 0 {
  445. return ente.NewBadRequestWithMessage("fileIDs are required")
  446. }
  447. if len(req.FileIDs) > 200 {
  448. return ente.NewBadRequestWithMessage("fileIDs should be less than or equal to 200")
  449. }
  450. if err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
  451. ActorUserId: userID,
  452. FileIDs: req.FileIDs,
  453. }); err != nil {
  454. return stacktrace.Propagate(err, "User does not own some file(s)")
  455. }
  456. return nil
  457. }