controller.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. package embedding
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "errors"
  7. "github.com/aws/aws-sdk-go/aws/awserr"
  8. "github.com/ente-io/museum/pkg/utils/array"
  9. "strconv"
  10. "sync"
  11. gTime "time"
  12. "github.com/aws/aws-sdk-go/aws"
  13. "github.com/aws/aws-sdk-go/service/s3"
  14. "github.com/aws/aws-sdk-go/service/s3/s3manager"
  15. "github.com/ente-io/museum/ente"
  16. "github.com/ente-io/museum/pkg/controller"
  17. "github.com/ente-io/museum/pkg/controller/access"
  18. "github.com/ente-io/museum/pkg/repo"
  19. "github.com/ente-io/museum/pkg/repo/embedding"
  20. "github.com/ente-io/museum/pkg/utils/auth"
  21. "github.com/ente-io/museum/pkg/utils/network"
  22. "github.com/ente-io/museum/pkg/utils/s3config"
  23. "github.com/ente-io/stacktrace"
  24. "github.com/gin-gonic/gin"
  25. log "github.com/sirupsen/logrus"
  26. )
  const (
	// minEmbeddingDataSize is the minimum size, in bytes, of a stored embedding
	// object. GetFilesEmbedding treats rows whose recorded size is below this
	// threshold as "indexed, but without embedding data".
	minEmbeddingDataSize = 2048
	// embeddingFetchTimeout is the default per-attempt timeout used by
	// getEmbeddingObject when no options are supplied.
	embeddingFetchTimeout = 15 * gTime.Second
)
  32. type Controller struct {
  33. Repo *embedding.Repository
  34. AccessCtrl access.Controller
  35. ObjectCleanupController *controller.ObjectCleanupController
  36. S3Config *s3config.S3Config
  37. QueueRepo *repo.QueueRepository
  38. TaskLockingRepo *repo.TaskLockRepository
  39. FileRepo *repo.FileRepository
  40. CollectionRepo *repo.CollectionRepository
  41. HostName string
  42. cleanupCronRunning bool
  43. derivedStorageS3Client *s3.S3
  44. derivedStorageDataCenter string
  45. areDerivedAndHotBucketSame bool
  46. }
  47. func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller {
  48. return &Controller{
  49. Repo: repo,
  50. AccessCtrl: accessCtrl,
  51. ObjectCleanupController: objectCleanupController,
  52. S3Config: s3Config,
  53. QueueRepo: queueRepo,
  54. TaskLockingRepo: taskLockingRepo,
  55. FileRepo: fileRepo,
  56. CollectionRepo: collectionRepo,
  57. HostName: hostName,
  58. derivedStorageS3Client: s3Config.GetDerivedStorageS3Client(),
  59. derivedStorageDataCenter: s3Config.GetDerivedStorageDataCenter(),
  60. areDerivedAndHotBucketSame: s3Config.GetDerivedStorageDataCenter() == s3Config.GetDerivedStorageDataCenter(),
  61. }
  62. }
  63. func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmbeddingRequest) (*ente.Embedding, error) {
  64. userID := auth.GetUserID(ctx.Request.Header)
  65. err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
  66. ActorUserId: userID,
  67. FileIDs: []int64{req.FileID},
  68. })
  69. if err != nil {
  70. return nil, stacktrace.Propagate(err, "User does not own file")
  71. }
  72. count, err := c.CollectionRepo.GetCollectionCount(req.FileID)
  73. if err != nil {
  74. return nil, stacktrace.Propagate(err, "")
  75. }
  76. if count < 1 {
  77. return nil, stacktrace.Propagate(ente.ErrNotFound, "")
  78. }
  79. version := 1
  80. if req.Version != nil {
  81. version = *req.Version
  82. }
  83. obj := ente.EmbeddingObject{
  84. Version: version,
  85. EncryptedEmbedding: req.EncryptedEmbedding,
  86. DecryptionHeader: req.DecryptionHeader,
  87. Client: network.GetPrettyUA(ctx.GetHeader("User-Agent")) + "/" + ctx.GetHeader("X-Client-Version"),
  88. }
  89. size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model), c.derivedStorageDataCenter)
  90. if uploadErr != nil {
  91. log.Error(uploadErr)
  92. return nil, stacktrace.Propagate(uploadErr, "")
  93. }
  94. embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version, c.derivedStorageDataCenter)
  95. embedding.Version = &version
  96. if err != nil {
  97. return nil, stacktrace.Propagate(err, "")
  98. }
  99. return &embedding, nil
  100. }
  101. func (c *Controller) GetDiff(ctx *gin.Context, req ente.GetEmbeddingDiffRequest) ([]ente.Embedding, error) {
  102. userID := auth.GetUserID(ctx.Request.Header)
  103. if req.Model == "" {
  104. req.Model = ente.GgmlClip
  105. }
  106. embeddings, err := c.Repo.GetDiff(ctx, userID, req.Model, *req.SinceTime, req.Limit)
  107. if err != nil {
  108. return nil, stacktrace.Propagate(err, "")
  109. }
  110. // Collect object keys for embeddings with missing data
  111. var objectKeys []string
  112. for i := range embeddings {
  113. if embeddings[i].EncryptedEmbedding == "" {
  114. objectKey := c.getObjectKey(userID, embeddings[i].FileID, embeddings[i].Model)
  115. objectKeys = append(objectKeys, objectKey)
  116. }
  117. }
  118. // Fetch missing embeddings in parallel
  119. if len(objectKeys) > 0 {
  120. embeddingObjects, err := c.getEmbeddingObjectsParallel(objectKeys)
  121. if err != nil {
  122. return nil, stacktrace.Propagate(err, "")
  123. }
  124. // Populate missing data in embeddings from fetched objects
  125. for i, obj := range embeddingObjects {
  126. for j := range embeddings {
  127. if embeddings[j].EncryptedEmbedding == "" && c.getObjectKey(userID, embeddings[j].FileID, embeddings[j].Model) == objectKeys[i] {
  128. embeddings[j].EncryptedEmbedding = obj.EncryptedEmbedding
  129. embeddings[j].DecryptionHeader = obj.DecryptionHeader
  130. }
  131. }
  132. }
  133. }
  134. return embeddings, nil
  135. }
  136. func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbeddingRequest) (*ente.GetFilesEmbeddingResponse, error) {
  137. userID := auth.GetUserID(ctx.Request.Header)
  138. if err := c._validateGetFileEmbeddingsRequest(ctx, userID, req); err != nil {
  139. return nil, stacktrace.Propagate(err, "")
  140. }
  141. userFileEmbeddings, err := c.Repo.GetFilesEmbedding(ctx, userID, req.Model, req.FileIDs)
  142. if err != nil {
  143. return nil, stacktrace.Propagate(err, "")
  144. }
  145. embeddingsWithData := make([]ente.Embedding, 0)
  146. noEmbeddingFileIds := make([]int64, 0)
  147. dbFileIds := make([]int64, 0)
  148. // fileIDs that were indexed, but they don't contain any embedding information
  149. for i := range userFileEmbeddings {
  150. dbFileIds = append(dbFileIds, userFileEmbeddings[i].FileID)
  151. if userFileEmbeddings[i].Size != nil && *userFileEmbeddings[i].Size < minEmbeddingDataSize {
  152. noEmbeddingFileIds = append(noEmbeddingFileIds, userFileEmbeddings[i].FileID)
  153. } else {
  154. embeddingsWithData = append(embeddingsWithData, userFileEmbeddings[i])
  155. }
  156. }
  157. pendingIndexFileIds := array.FindMissingElementsInSecondList(req.FileIDs, dbFileIds)
  158. errFileIds := make([]int64, 0)
  159. // Fetch missing userFileEmbeddings in parallel
  160. embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData)
  161. if err != nil {
  162. return nil, stacktrace.Propagate(err, "")
  163. }
  164. fetchedEmbeddings := make([]ente.Embedding, 0)
  165. // Populate missing data in userFileEmbeddings from fetched objects
  166. for _, obj := range embeddingObjects {
  167. if obj.err != nil {
  168. errFileIds = append(errFileIds, obj.dbEmbeddingRow.FileID)
  169. } else {
  170. fetchedEmbeddings = append(fetchedEmbeddings, ente.Embedding{
  171. FileID: obj.dbEmbeddingRow.FileID,
  172. Model: obj.dbEmbeddingRow.Model,
  173. EncryptedEmbedding: obj.embeddingObject.EncryptedEmbedding,
  174. DecryptionHeader: obj.embeddingObject.DecryptionHeader,
  175. UpdatedAt: obj.dbEmbeddingRow.UpdatedAt,
  176. Version: obj.dbEmbeddingRow.Version,
  177. })
  178. }
  179. }
  180. return &ente.GetFilesEmbeddingResponse{
  181. Embeddings: fetchedEmbeddings,
  182. PendingIndexFileIDs: pendingIndexFileIds,
  183. ErrFileIDs: errFileIds,
  184. NoEmbeddingFileIDs: noEmbeddingFileIds,
  185. }, nil
  186. }
  187. func (c *Controller) getObjectKey(userID int64, fileID int64, model string) string {
  188. return c.getEmbeddingObjectPrefix(userID, fileID) + model + ".json"
  189. }
  190. func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string {
  191. return strconv.FormatInt(userID, 10) + "/ml-data/" + strconv.FormatInt(fileID, 10) + "/"
  192. }
  193. // uploadObject uploads the embedding object to the object store and returns the object size
  194. func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string, dc string) (int, error) {
  195. embeddingObj, _ := json.Marshal(obj)
  196. s3Client := c.S3Config.GetS3Client(dc)
  197. s3Bucket := c.S3Config.GetBucket(dc)
  198. uploader := s3manager.NewUploaderWithClient(&s3Client)
  199. up := s3manager.UploadInput{
  200. Bucket: s3Bucket,
  201. Key: &key,
  202. Body: bytes.NewReader(embeddingObj),
  203. }
  204. result, err := uploader.Upload(&up)
  205. if err != nil {
  206. log.Error(err)
  207. return -1, stacktrace.Propagate(err, "")
  208. }
  209. log.Infof("Uploaded to bucket %s", result.Location)
  210. return len(embeddingObj), nil
  211. }
  // Process-wide counting semaphores (buffered channels) that cap the number of
// concurrent embedding-object downloads across all requests.
var (
	// globalDiffFetchSemaphore bounds downloads started by
	// getEmbeddingObjectsParallel.
	globalDiffFetchSemaphore = make(chan struct{}, 300)
	// globalFileFetchSemaphore bounds downloads started by
	// getEmbeddingObjectsParallelV2.
	globalFileFetchSemaphore = make(chan struct{}, 400)
)
  214. func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.EmbeddingObject, error) {
  215. var wg sync.WaitGroup
  216. var errs []error
  217. embeddingObjects := make([]ente.EmbeddingObject, len(objectKeys))
  218. downloader := s3manager.NewDownloaderWithClient(c.derivedStorageS3Client)
  219. for i, objectKey := range objectKeys {
  220. wg.Add(1)
  221. globalDiffFetchSemaphore <- struct{}{} // Acquire from global semaphore
  222. go func(i int, objectKey string) {
  223. defer wg.Done()
  224. defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
  225. obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
  226. if err != nil {
  227. errs = append(errs, err)
  228. log.Error("error fetching embedding object: "+objectKey, err)
  229. } else {
  230. embeddingObjects[i] = obj
  231. }
  232. }(i, objectKey)
  233. }
  234. wg.Wait()
  235. if len(errs) > 0 {
  236. return nil, stacktrace.Propagate(errors.New("failed to fetch some objects"), "")
  237. }
  238. return embeddingObjects, nil
  239. }
  240. type embeddingObjectResult struct {
  241. embeddingObject ente.EmbeddingObject
  242. dbEmbeddingRow ente.Embedding
  243. err error
  244. }
  245. func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) {
  246. var wg sync.WaitGroup
  247. embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows))
  248. downloader := s3manager.NewDownloaderWithClient(c.derivedStorageS3Client)
  249. for i, dbEmbeddingRow := range dbEmbeddingRows {
  250. wg.Add(1)
  251. globalFileFetchSemaphore <- struct{}{} // Acquire from global semaphore
  252. go func(i int, dbEmbeddingRow ente.Embedding) {
  253. defer wg.Done()
  254. defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
  255. objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
  256. obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
  257. if err != nil {
  258. log.Error("error fetching embedding object: "+objectKey, err)
  259. embeddingObjects[i] = embeddingObjectResult{
  260. err: err,
  261. dbEmbeddingRow: dbEmbeddingRow,
  262. }
  263. } else {
  264. embeddingObjects[i] = embeddingObjectResult{
  265. embeddingObject: obj,
  266. dbEmbeddingRow: dbEmbeddingRow,
  267. }
  268. }
  269. }(i, dbEmbeddingRow)
  270. }
  271. wg.Wait()
  272. return embeddingObjects, nil
  273. }
  // getOptions tunes the retry behaviour of getEmbeddingObject. A nil
// *getOptions selects the defaults there (3 retries, embeddingFetchTimeout
// per attempt).
type getOptions struct {
	// RetryCount is the number of extra attempts after the first one.
	RetryCount int
	// FetchTimeOut bounds each individual download attempt.
	FetchTimeOut gTime.Duration
}
  278. func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, opt *getOptions) (ente.EmbeddingObject, error) {
  279. if opt == nil {
  280. opt = &getOptions{
  281. RetryCount: 3,
  282. FetchTimeOut: embeddingFetchTimeout,
  283. }
  284. }
  285. ctxLogger := log.WithField("objectKey", objectKey)
  286. totalAttempts := opt.RetryCount + 1
  287. for i := 0; i < totalAttempts; i++ {
  288. // Create a new context with a timeout for each fetch
  289. fetchCtx, cancel := context.WithTimeout(ctx, opt.FetchTimeOut)
  290. select {
  291. case <-ctx.Done():
  292. cancel()
  293. return ente.EmbeddingObject{}, stacktrace.Propagate(ctx.Err(), "")
  294. default:
  295. obj, err := c.downloadObject(fetchCtx, objectKey, downloader, c.derivedStorageDataCenter)
  296. cancel() // Ensure cancel is called to release resources
  297. if err == nil {
  298. if i > 0 {
  299. ctxLogger.Infof("Fetched object after %d attempts", i)
  300. }
  301. return obj, nil
  302. }
  303. // Check if the error is due to context timeout or cancellation
  304. if err == nil && fetchCtx.Err() != nil {
  305. ctxLogger.Error("Fetch timed out or cancelled: ", fetchCtx.Err())
  306. } else {
  307. // check if the error is due to object not found
  308. if s3Err, ok := err.(awserr.RequestFailure); ok {
  309. if s3Err.Code() == s3.ErrCodeNoSuchKey {
  310. if c.areDerivedAndHotBucketSame {
  311. ctxLogger.Error("Object not found: ", s3Err)
  312. } else {
  313. // If derived and hot bucket are different, try to copy from hot bucket
  314. copyEmbeddingObject, err := c.copyEmbeddingObject(ctx, objectKey)
  315. if err == nil {
  316. ctxLogger.Info("Got the object from hot bucket object")
  317. return *copyEmbeddingObject, nil
  318. } else {
  319. ctxLogger.WithError(err).Error("Failed to copy from hot bucket object")
  320. }
  321. return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
  322. }
  323. }
  324. }
  325. ctxLogger.Error("Failed to fetch object: ", err)
  326. }
  327. }
  328. }
  329. return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "")
  330. }
  331. func (c *Controller) downloadObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, dc string) (ente.EmbeddingObject, error) {
  332. var obj ente.EmbeddingObject
  333. buff := &aws.WriteAtBuffer{}
  334. bucket := c.S3Config.GetBucket(dc)
  335. _, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
  336. Bucket: bucket,
  337. Key: &objectKey,
  338. })
  339. if err != nil {
  340. return obj, err
  341. }
  342. err = json.Unmarshal(buff.Bytes(), &obj)
  343. if err != nil {
  344. return obj, stacktrace.Propagate(err, "unmarshal failed")
  345. }
  346. return obj, nil
  347. }
  348. // download the embedding object from hot bucket and upload to embeddings bucket
  349. func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) (*ente.EmbeddingObject, error) {
  350. if c.derivedStorageDataCenter == c.S3Config.GetHotDataCenter() {
  351. return nil, stacktrace.Propagate(errors.New("derived DC bucket and hot DC are same"), "")
  352. }
  353. downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
  354. obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotDataCenter())
  355. if err != nil {
  356. return nil, stacktrace.Propagate(err, "failed to download from hot bucket")
  357. }
  358. go func() {
  359. _, err = c.uploadObject(obj, objectKey, c.derivedStorageDataCenter)
  360. if err != nil {
  361. log.WithField("object", objectKey).Error("Failed to copy to embeddings bucket: ", err)
  362. }
  363. }()
  364. return &obj, nil
  365. }
  366. func (c *Controller) _validateGetFileEmbeddingsRequest(ctx *gin.Context, userID int64, req ente.GetFilesEmbeddingRequest) error {
  367. if req.Model == "" {
  368. return ente.NewBadRequestWithMessage("model is required")
  369. }
  370. if len(req.FileIDs) == 0 {
  371. return ente.NewBadRequestWithMessage("fileIDs are required")
  372. }
  373. if len(req.FileIDs) > 200 {
  374. return ente.NewBadRequestWithMessage("fileIDs should be less than or equal to 200")
  375. }
  376. if err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
  377. ActorUserId: userID,
  378. FileIDs: req.FileIDs,
  379. }); err != nil {
  380. return stacktrace.Propagate(err, "User does not own some file(s)")
  381. }
  382. return nil
  383. }