controller.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
  1. package embedding
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "github.com/aws/aws-sdk-go/aws/awserr"
  9. "github.com/ente-io/museum/pkg/utils/array"
  10. "strconv"
  11. "sync"
  12. gTime "time"
  13. "github.com/aws/aws-sdk-go/aws"
  14. "github.com/aws/aws-sdk-go/service/s3"
  15. "github.com/aws/aws-sdk-go/service/s3/s3manager"
  16. "github.com/ente-io/museum/ente"
  17. "github.com/ente-io/museum/pkg/controller"
  18. "github.com/ente-io/museum/pkg/controller/access"
  19. "github.com/ente-io/museum/pkg/repo"
  20. "github.com/ente-io/museum/pkg/repo/embedding"
  21. "github.com/ente-io/museum/pkg/utils/auth"
  22. "github.com/ente-io/museum/pkg/utils/network"
  23. "github.com/ente-io/museum/pkg/utils/s3config"
  24. "github.com/ente-io/museum/pkg/utils/time"
  25. "github.com/ente-io/stacktrace"
  26. "github.com/gin-gonic/gin"
  27. log "github.com/sirupsen/logrus"
  28. )
  const (
  	// minEmbeddingDataSize is the minimum size, in bytes, of a stored embedding
  	// object for it to be treated as carrying embedding data. GetFilesEmbedding
  	// reports rows whose recorded size is below this as NoEmbeddingFileIDs and
  	// does not download them.
  	minEmbeddingDataSize = 2048
  	// embeddingFetchTimeout bounds each individual download attempt of an
  	// embedding object from the object store.
  	embeddingFetchTimeout = 15 * gTime.Second
  )
  34. type Controller struct {
  35. Repo *embedding.Repository
  36. AccessCtrl access.Controller
  37. ObjectCleanupController *controller.ObjectCleanupController
  38. S3Config *s3config.S3Config
  39. QueueRepo *repo.QueueRepository
  40. TaskLockingRepo *repo.TaskLockRepository
  41. FileRepo *repo.FileRepository
  42. CollectionRepo *repo.CollectionRepository
  43. HostName string
  44. cleanupCronRunning bool
  45. }
  46. func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmbeddingRequest) (*ente.Embedding, error) {
  47. userID := auth.GetUserID(ctx.Request.Header)
  48. err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
  49. ActorUserId: userID,
  50. FileIDs: []int64{req.FileID},
  51. })
  52. if err != nil {
  53. return nil, stacktrace.Propagate(err, "User does not own file")
  54. }
  55. count, err := c.CollectionRepo.GetCollectionCount(req.FileID)
  56. if err != nil {
  57. return nil, stacktrace.Propagate(err, "")
  58. }
  59. if count < 1 {
  60. return nil, stacktrace.Propagate(ente.ErrNotFound, "")
  61. }
  62. version := 1
  63. if req.Version != nil {
  64. version = *req.Version
  65. }
  66. obj := ente.EmbeddingObject{
  67. Version: version,
  68. EncryptedEmbedding: req.EncryptedEmbedding,
  69. DecryptionHeader: req.DecryptionHeader,
  70. Client: network.GetPrettyUA(ctx.GetHeader("User-Agent")) + "/" + ctx.GetHeader("X-Client-Version"),
  71. }
  72. size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model))
  73. if uploadErr != nil {
  74. log.Error(uploadErr)
  75. return nil, stacktrace.Propagate(uploadErr, "")
  76. }
  77. embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version)
  78. embedding.Version = &version
  79. if err != nil {
  80. return nil, stacktrace.Propagate(err, "")
  81. }
  82. return &embedding, nil
  83. }
  84. func (c *Controller) GetDiff(ctx *gin.Context, req ente.GetEmbeddingDiffRequest) ([]ente.Embedding, error) {
  85. userID := auth.GetUserID(ctx.Request.Header)
  86. if req.Model == "" {
  87. req.Model = ente.GgmlClip
  88. }
  89. embeddings, err := c.Repo.GetDiff(ctx, userID, req.Model, *req.SinceTime, req.Limit)
  90. if err != nil {
  91. return nil, stacktrace.Propagate(err, "")
  92. }
  93. // Collect object keys for embeddings with missing data
  94. var objectKeys []string
  95. for i := range embeddings {
  96. if embeddings[i].EncryptedEmbedding == "" {
  97. objectKey := c.getObjectKey(userID, embeddings[i].FileID, embeddings[i].Model)
  98. objectKeys = append(objectKeys, objectKey)
  99. }
  100. }
  101. // Fetch missing embeddings in parallel
  102. if len(objectKeys) > 0 {
  103. embeddingObjects, err := c.getEmbeddingObjectsParallel(objectKeys)
  104. if err != nil {
  105. return nil, stacktrace.Propagate(err, "")
  106. }
  107. // Populate missing data in embeddings from fetched objects
  108. for i, obj := range embeddingObjects {
  109. for j := range embeddings {
  110. if embeddings[j].EncryptedEmbedding == "" && c.getObjectKey(userID, embeddings[j].FileID, embeddings[j].Model) == objectKeys[i] {
  111. embeddings[j].EncryptedEmbedding = obj.EncryptedEmbedding
  112. embeddings[j].DecryptionHeader = obj.DecryptionHeader
  113. }
  114. }
  115. }
  116. }
  117. return embeddings, nil
  118. }
  119. func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbeddingRequest) (*ente.GetFilesEmbeddingResponse, error) {
  120. userID := auth.GetUserID(ctx.Request.Header)
  121. if err := c._validateGetFileEmbeddingsRequest(ctx, userID, req); err != nil {
  122. return nil, stacktrace.Propagate(err, "")
  123. }
  124. userFileEmbeddings, err := c.Repo.GetFilesEmbedding(ctx, userID, req.Model, req.FileIDs)
  125. if err != nil {
  126. return nil, stacktrace.Propagate(err, "")
  127. }
  128. embeddingsWithData := make([]ente.Embedding, 0)
  129. noEmbeddingFileIds := make([]int64, 0)
  130. dbFileIds := make([]int64, 0)
  131. // fileIDs that were indexed, but they don't contain any embedding information
  132. for i := range userFileEmbeddings {
  133. dbFileIds = append(dbFileIds, userFileEmbeddings[i].FileID)
  134. if userFileEmbeddings[i].Size != nil && *userFileEmbeddings[i].Size < minEmbeddingDataSize {
  135. noEmbeddingFileIds = append(noEmbeddingFileIds, userFileEmbeddings[i].FileID)
  136. } else {
  137. embeddingsWithData = append(embeddingsWithData, userFileEmbeddings[i])
  138. }
  139. }
  140. pendingIndexFileIds := array.FindMissingElementsInSecondList(req.FileIDs, dbFileIds)
  141. errFileIds := make([]int64, 0)
  142. // Fetch missing userFileEmbeddings in parallel
  143. embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData)
  144. if err != nil {
  145. return nil, stacktrace.Propagate(err, "")
  146. }
  147. fetchedEmbeddings := make([]ente.Embedding, 0)
  148. // Populate missing data in userFileEmbeddings from fetched objects
  149. for _, obj := range embeddingObjects {
  150. if obj.err != nil {
  151. errFileIds = append(errFileIds, obj.dbEmbeddingRow.FileID)
  152. } else {
  153. fetchedEmbeddings = append(fetchedEmbeddings, ente.Embedding{
  154. FileID: obj.dbEmbeddingRow.FileID,
  155. Model: obj.dbEmbeddingRow.Model,
  156. EncryptedEmbedding: obj.embeddingObject.EncryptedEmbedding,
  157. DecryptionHeader: obj.embeddingObject.DecryptionHeader,
  158. UpdatedAt: obj.dbEmbeddingRow.UpdatedAt,
  159. Version: obj.dbEmbeddingRow.Version,
  160. })
  161. }
  162. }
  163. return &ente.GetFilesEmbeddingResponse{
  164. Embeddings: fetchedEmbeddings,
  165. PendingIndexFileIDs: pendingIndexFileIds,
  166. ErrFileIDs: errFileIds,
  167. NoEmbeddingFileIDs: noEmbeddingFileIds,
  168. }, nil
  169. }
  170. func (c *Controller) DeleteAll(ctx *gin.Context) error {
  171. userID := auth.GetUserID(ctx.Request.Header)
  172. err := c.Repo.DeleteAll(ctx, userID)
  173. if err != nil {
  174. return stacktrace.Propagate(err, "")
  175. }
  176. return nil
  177. }
  178. // CleanupDeletedEmbeddings clears all embeddings for deleted files from the object store
  179. func (c *Controller) CleanupDeletedEmbeddings() {
  180. log.Info("Cleaning up deleted embeddings")
  181. if c.cleanupCronRunning {
  182. log.Info("Skipping CleanupDeletedEmbeddings cron run as another instance is still running")
  183. return
  184. }
  185. c.cleanupCronRunning = true
  186. defer func() {
  187. c.cleanupCronRunning = false
  188. }()
  189. items, err := c.QueueRepo.GetItemsReadyForDeletion(repo.DeleteEmbeddingsQueue, 200)
  190. if err != nil {
  191. log.WithError(err).Error("Failed to fetch items from queue")
  192. return
  193. }
  194. for _, i := range items {
  195. c.deleteEmbedding(i)
  196. }
  197. }
  198. func (c *Controller) deleteEmbedding(qItem repo.QueueItem) {
  199. lockName := fmt.Sprintf("Embedding:%s", qItem.Item)
  200. lockStatus, err := c.TaskLockingRepo.AcquireLock(lockName, time.MicrosecondsAfterHours(1), c.HostName)
  201. ctxLogger := log.WithField("item", qItem.Item).WithField("queue_id", qItem.Id)
  202. if err != nil || !lockStatus {
  203. ctxLogger.Warn("unable to acquire lock")
  204. return
  205. }
  206. defer func() {
  207. err = c.TaskLockingRepo.ReleaseLock(lockName)
  208. if err != nil {
  209. ctxLogger.Errorf("Error while releasing lock %s", err)
  210. }
  211. }()
  212. ctxLogger.Info("Deleting all embeddings")
  213. fileID, _ := strconv.ParseInt(qItem.Item, 10, 64)
  214. ownerID, err := c.FileRepo.GetOwnerID(fileID)
  215. if err != nil {
  216. ctxLogger.WithError(err).Error("Failed to fetch ownerID")
  217. return
  218. }
  219. prefix := c.getEmbeddingObjectPrefix(ownerID, fileID)
  220. err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetHotDataCenter())
  221. if err != nil {
  222. ctxLogger.WithError(err).Error("Failed to delete all objects")
  223. return
  224. }
  225. err = c.Repo.Delete(fileID)
  226. if err != nil {
  227. ctxLogger.WithError(err).Error("Failed to remove from db")
  228. return
  229. }
  230. err = c.QueueRepo.DeleteItem(repo.DeleteEmbeddingsQueue, qItem.Item)
  231. if err != nil {
  232. ctxLogger.WithError(err).Error("Failed to remove item from the queue")
  233. return
  234. }
  235. ctxLogger.Info("Successfully deleted all embeddings")
  236. }
  237. func (c *Controller) getObjectKey(userID int64, fileID int64, model string) string {
  238. return c.getEmbeddingObjectPrefix(userID, fileID) + model + ".json"
  239. }
  240. func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string {
  241. return strconv.FormatInt(userID, 10) + "/ml-data/" + strconv.FormatInt(fileID, 10) + "/"
  242. }
  243. // uploadObject uploads the embedding object to the object store and returns the object size
  244. func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) (int, error) {
  245. embeddingObj, _ := json.Marshal(obj)
  246. uploader := s3manager.NewUploaderWithClient(c.S3Config.GetHotS3Client())
  247. up := s3manager.UploadInput{
  248. Bucket: c.S3Config.GetHotBucket(),
  249. Key: &key,
  250. Body: bytes.NewReader(embeddingObj),
  251. }
  252. result, err := uploader.Upload(&up)
  253. if err != nil {
  254. log.Error(err)
  255. return -1, stacktrace.Propagate(err, "")
  256. }
  257. log.Infof("Uploaded to bucket %s", result.Location)
  258. return len(embeddingObj), nil
  259. }
  // Process-wide counting semaphores that cap concurrent embedding-object
  // downloads across all requests. A download holds one slot while it runs.
  var (
  	// globalDiffFetchSemaphore bounds downloads made by getEmbeddingObjectsParallel.
  	globalDiffFetchSemaphore = make(chan struct{}, 300)
  	// globalFileFetchSemaphore bounds downloads made by getEmbeddingObjectsParallelV2.
  	globalFileFetchSemaphore = make(chan struct{}, 400)
  )
  262. func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.EmbeddingObject, error) {
  263. var wg sync.WaitGroup
  264. var errs []error
  265. embeddingObjects := make([]ente.EmbeddingObject, len(objectKeys))
  266. downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
  267. for i, objectKey := range objectKeys {
  268. wg.Add(1)
  269. globalDiffFetchSemaphore <- struct{}{} // Acquire from global semaphore
  270. go func(i int, objectKey string) {
  271. defer wg.Done()
  272. defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
  273. obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
  274. if err != nil {
  275. errs = append(errs, err)
  276. log.Error("error fetching embedding object: "+objectKey, err)
  277. } else {
  278. embeddingObjects[i] = obj
  279. }
  280. }(i, objectKey)
  281. }
  282. wg.Wait()
  283. if len(errs) > 0 {
  284. return nil, stacktrace.Propagate(errors.New("failed to fetch some objects"), "")
  285. }
  286. return embeddingObjects, nil
  287. }
  288. type embeddingObjectResult struct {
  289. embeddingObject ente.EmbeddingObject
  290. dbEmbeddingRow ente.Embedding
  291. err error
  292. }
  293. func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) {
  294. var wg sync.WaitGroup
  295. embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows))
  296. downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
  297. for i, dbEmbeddingRow := range dbEmbeddingRows {
  298. wg.Add(1)
  299. globalFileFetchSemaphore <- struct{}{} // Acquire from global semaphore
  300. go func(i int, dbEmbeddingRow ente.Embedding) {
  301. defer wg.Done()
  302. defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
  303. objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
  304. obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
  305. if err != nil {
  306. log.Error("error fetching embedding object: "+objectKey, err)
  307. embeddingObjects[i] = embeddingObjectResult{
  308. err: err,
  309. dbEmbeddingRow: dbEmbeddingRow,
  310. }
  311. } else {
  312. embeddingObjects[i] = embeddingObjectResult{
  313. embeddingObject: obj,
  314. dbEmbeddingRow: dbEmbeddingRow,
  315. }
  316. }
  317. }(i, dbEmbeddingRow)
  318. }
  319. wg.Wait()
  320. return embeddingObjects, nil
  321. }
  // getOptions tunes the retry loop in getEmbeddingObject. Passing a nil
  // *getOptions there selects 3 retries with embeddingFetchTimeout per attempt.
  type getOptions struct {
  	// RetryCount is the number of additional attempts after the first.
  	RetryCount int
  	// FetchTimeOut bounds each individual download attempt.
  	FetchTimeOut gTime.Duration
  }
  326. func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, opt *getOptions) (ente.EmbeddingObject, error) {
  327. if opt == nil {
  328. opt = &getOptions{
  329. RetryCount: 3,
  330. FetchTimeOut: embeddingFetchTimeout,
  331. }
  332. }
  333. ctxLogger := log.WithField("objectKey", objectKey)
  334. totalAttempts := opt.RetryCount + 1
  335. for i := 0; i < totalAttempts; i++ {
  336. // Create a new context with a timeout for each fetch
  337. fetchCtx, cancel := context.WithTimeout(ctx, opt.FetchTimeOut)
  338. select {
  339. case <-ctx.Done():
  340. cancel()
  341. return ente.EmbeddingObject{}, stacktrace.Propagate(ctx.Err(), "")
  342. default:
  343. obj, err := c.downloadObject(fetchCtx, objectKey, downloader)
  344. cancel() // Ensure cancel is called to release resources
  345. if err == nil {
  346. return obj, nil
  347. }
  348. // Check if the error is due to context timeout or cancellation
  349. if fetchCtx.Err() != nil {
  350. ctxLogger.Error("Fetch timed out or cancelled: ", fetchCtx.Err())
  351. } else {
  352. // check if the error is due to object not found
  353. if s3Err, ok := err.(awserr.Error); ok {
  354. if s3Err.Code() == s3.ErrCodeNoSuchKey {
  355. ctxLogger.Warn("Object not found: ", s3Err)
  356. return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
  357. }
  358. }
  359. ctxLogger.Error("Failed to fetch object: ", err)
  360. }
  361. }
  362. }
  363. return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "")
  364. }
  365. func (c *Controller) downloadObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
  366. var obj ente.EmbeddingObject
  367. buff := &aws.WriteAtBuffer{}
  368. _, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
  369. Bucket: c.S3Config.GetHotBucket(),
  370. Key: &objectKey,
  371. })
  372. if err != nil {
  373. return obj, stacktrace.Propagate(err, "downloadFailed")
  374. }
  375. err = json.Unmarshal(buff.Bytes(), &obj)
  376. if err != nil {
  377. return obj, stacktrace.Propagate(err, "unmarshal failed")
  378. }
  379. return obj, nil
  380. }
  381. func (c *Controller) _validateGetFileEmbeddingsRequest(ctx *gin.Context, userID int64, req ente.GetFilesEmbeddingRequest) error {
  382. if req.Model == "" {
  383. return ente.NewBadRequestWithMessage("model is required")
  384. }
  385. if len(req.FileIDs) == 0 {
  386. return ente.NewBadRequestWithMessage("fileIDs are required")
  387. }
  388. if len(req.FileIDs) > 200 {
  389. return ente.NewBadRequestWithMessage("fileIDs should be less than or equal to 200")
  390. }
  391. if err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
  392. ActorUserId: userID,
  393. FileIDs: req.FileIDs,
  394. }); err != nil {
  395. return stacktrace.Propagate(err, "User does not own some file(s)")
  396. }
  397. return nil
  398. }