controller.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. package embedding
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "errors"
  7. "github.com/aws/aws-sdk-go/aws/awserr"
  8. "github.com/ente-io/museum/pkg/utils/array"
  9. "strconv"
  10. "sync"
  11. gTime "time"
  12. "github.com/aws/aws-sdk-go/aws"
  13. "github.com/aws/aws-sdk-go/service/s3"
  14. "github.com/aws/aws-sdk-go/service/s3/s3manager"
  15. "github.com/ente-io/museum/ente"
  16. "github.com/ente-io/museum/pkg/controller"
  17. "github.com/ente-io/museum/pkg/controller/access"
  18. "github.com/ente-io/museum/pkg/repo"
  19. "github.com/ente-io/museum/pkg/repo/embedding"
  20. "github.com/ente-io/museum/pkg/utils/auth"
  21. "github.com/ente-io/museum/pkg/utils/network"
  22. "github.com/ente-io/museum/pkg/utils/s3config"
  23. "github.com/ente-io/stacktrace"
  24. "github.com/gin-gonic/gin"
  25. log "github.com/sirupsen/logrus"
  26. )
  const (
	// minEmbeddingDataSize is the minimum size, in bytes, of a stored embedding
	// object for it to be treated as carrying embedding data. Rows whose
	// recorded size is below this are reported as "no embedding" by
	// GetFilesEmbedding instead of being fetched.
	minEmbeddingDataSize = 2048
	// embeddingFetchTimeout is the default timeout applied to each individual
	// object-store fetch attempt in getEmbeddingObject.
	embeddingFetchTimeout = 15 * gTime.Second
  )
  32. type Controller struct {
  33. Repo *embedding.Repository
  34. AccessCtrl access.Controller
  35. ObjectCleanupController *controller.ObjectCleanupController
  36. S3Config *s3config.S3Config
  37. QueueRepo *repo.QueueRepository
  38. TaskLockingRepo *repo.TaskLockRepository
  39. FileRepo *repo.FileRepository
  40. CollectionRepo *repo.CollectionRepository
  41. HostName string
  42. cleanupCronRunning bool
  43. derivedStorageDataCenter string
  44. areDerivedAndHotBucketSame bool
  45. }
  46. func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller {
  47. return &Controller{
  48. Repo: repo,
  49. AccessCtrl: accessCtrl,
  50. ObjectCleanupController: objectCleanupController,
  51. S3Config: s3Config,
  52. QueueRepo: queueRepo,
  53. TaskLockingRepo: taskLockingRepo,
  54. FileRepo: fileRepo,
  55. CollectionRepo: collectionRepo,
  56. HostName: hostName,
  57. derivedStorageDataCenter: s3Config.GetDerivedStorageDataCenter(),
  58. areDerivedAndHotBucketSame: s3Config.GetDerivedStorageDataCenter() == s3Config.GetDerivedStorageDataCenter(),
  59. }
  60. }
  61. func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmbeddingRequest) (*ente.Embedding, error) {
  62. userID := auth.GetUserID(ctx.Request.Header)
  63. err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
  64. ActorUserId: userID,
  65. FileIDs: []int64{req.FileID},
  66. })
  67. if err != nil {
  68. return nil, stacktrace.Propagate(err, "User does not own file")
  69. }
  70. count, err := c.CollectionRepo.GetCollectionCount(req.FileID)
  71. if err != nil {
  72. return nil, stacktrace.Propagate(err, "")
  73. }
  74. if count < 1 {
  75. return nil, stacktrace.Propagate(ente.ErrNotFound, "")
  76. }
  77. version := 1
  78. if req.Version != nil {
  79. version = *req.Version
  80. }
  81. obj := ente.EmbeddingObject{
  82. Version: version,
  83. EncryptedEmbedding: req.EncryptedEmbedding,
  84. DecryptionHeader: req.DecryptionHeader,
  85. Client: network.GetPrettyUA(ctx.GetHeader("User-Agent")) + "/" + ctx.GetHeader("X-Client-Version"),
  86. }
  87. size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model), c.derivedStorageDataCenter)
  88. if uploadErr != nil {
  89. log.Error(uploadErr)
  90. return nil, stacktrace.Propagate(uploadErr, "")
  91. }
  92. embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version, c.derivedStorageDataCenter)
  93. embedding.Version = &version
  94. if err != nil {
  95. return nil, stacktrace.Propagate(err, "")
  96. }
  97. return &embedding, nil
  98. }
  99. func (c *Controller) GetDiff(ctx *gin.Context, req ente.GetEmbeddingDiffRequest) ([]ente.Embedding, error) {
  100. userID := auth.GetUserID(ctx.Request.Header)
  101. if req.Model == "" {
  102. req.Model = ente.GgmlClip
  103. }
  104. embeddings, err := c.Repo.GetDiff(ctx, userID, req.Model, *req.SinceTime, req.Limit)
  105. if err != nil {
  106. return nil, stacktrace.Propagate(err, "")
  107. }
  108. // Collect object keys for embeddings with missing data
  109. var objectKeys []string
  110. for i := range embeddings {
  111. if embeddings[i].EncryptedEmbedding == "" {
  112. objectKey := c.getObjectKey(userID, embeddings[i].FileID, embeddings[i].Model)
  113. objectKeys = append(objectKeys, objectKey)
  114. }
  115. }
  116. // Fetch missing embeddings in parallel
  117. if len(objectKeys) > 0 {
  118. embeddingObjects, err := c.getEmbeddingObjectsParallel(objectKeys)
  119. if err != nil {
  120. return nil, stacktrace.Propagate(err, "")
  121. }
  122. // Populate missing data in embeddings from fetched objects
  123. for i, obj := range embeddingObjects {
  124. for j := range embeddings {
  125. if embeddings[j].EncryptedEmbedding == "" && c.getObjectKey(userID, embeddings[j].FileID, embeddings[j].Model) == objectKeys[i] {
  126. embeddings[j].EncryptedEmbedding = obj.EncryptedEmbedding
  127. embeddings[j].DecryptionHeader = obj.DecryptionHeader
  128. }
  129. }
  130. }
  131. }
  132. return embeddings, nil
  133. }
  134. func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbeddingRequest) (*ente.GetFilesEmbeddingResponse, error) {
  135. userID := auth.GetUserID(ctx.Request.Header)
  136. if err := c._validateGetFileEmbeddingsRequest(ctx, userID, req); err != nil {
  137. return nil, stacktrace.Propagate(err, "")
  138. }
  139. userFileEmbeddings, err := c.Repo.GetFilesEmbedding(ctx, userID, req.Model, req.FileIDs)
  140. if err != nil {
  141. return nil, stacktrace.Propagate(err, "")
  142. }
  143. embeddingsWithData := make([]ente.Embedding, 0)
  144. noEmbeddingFileIds := make([]int64, 0)
  145. dbFileIds := make([]int64, 0)
  146. // fileIDs that were indexed, but they don't contain any embedding information
  147. for i := range userFileEmbeddings {
  148. dbFileIds = append(dbFileIds, userFileEmbeddings[i].FileID)
  149. if userFileEmbeddings[i].Size != nil && *userFileEmbeddings[i].Size < minEmbeddingDataSize {
  150. noEmbeddingFileIds = append(noEmbeddingFileIds, userFileEmbeddings[i].FileID)
  151. } else {
  152. embeddingsWithData = append(embeddingsWithData, userFileEmbeddings[i])
  153. }
  154. }
  155. pendingIndexFileIds := array.FindMissingElementsInSecondList(req.FileIDs, dbFileIds)
  156. errFileIds := make([]int64, 0)
  157. // Fetch missing userFileEmbeddings in parallel
  158. embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData)
  159. if err != nil {
  160. return nil, stacktrace.Propagate(err, "")
  161. }
  162. fetchedEmbeddings := make([]ente.Embedding, 0)
  163. // Populate missing data in userFileEmbeddings from fetched objects
  164. for _, obj := range embeddingObjects {
  165. if obj.err != nil {
  166. errFileIds = append(errFileIds, obj.dbEmbeddingRow.FileID)
  167. } else {
  168. fetchedEmbeddings = append(fetchedEmbeddings, ente.Embedding{
  169. FileID: obj.dbEmbeddingRow.FileID,
  170. Model: obj.dbEmbeddingRow.Model,
  171. EncryptedEmbedding: obj.embeddingObject.EncryptedEmbedding,
  172. DecryptionHeader: obj.embeddingObject.DecryptionHeader,
  173. UpdatedAt: obj.dbEmbeddingRow.UpdatedAt,
  174. Version: obj.dbEmbeddingRow.Version,
  175. })
  176. }
  177. }
  178. return &ente.GetFilesEmbeddingResponse{
  179. Embeddings: fetchedEmbeddings,
  180. PendingIndexFileIDs: pendingIndexFileIds,
  181. ErrFileIDs: errFileIds,
  182. NoEmbeddingFileIDs: noEmbeddingFileIds,
  183. }, nil
  184. }
  185. func (c *Controller) getObjectKey(userID int64, fileID int64, model string) string {
  186. return c.getEmbeddingObjectPrefix(userID, fileID) + model + ".json"
  187. }
  188. func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string {
  189. return strconv.FormatInt(userID, 10) + "/ml-data/" + strconv.FormatInt(fileID, 10) + "/"
  190. }
  191. // uploadObject uploads the embedding object to the object store and returns the object size
  192. func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string, dc string) (int, error) {
  193. embeddingObj, _ := json.Marshal(obj)
  194. s3Client := c.S3Config.GetS3Client(dc)
  195. s3Bucket := c.S3Config.GetBucket(dc)
  196. uploader := s3manager.NewUploaderWithClient(&s3Client)
  197. up := s3manager.UploadInput{
  198. Bucket: s3Bucket,
  199. Key: &key,
  200. Body: bytes.NewReader(embeddingObj),
  201. }
  202. result, err := uploader.Upload(&up)
  203. if err != nil {
  204. log.Error(err)
  205. return -1, stacktrace.Propagate(err, "")
  206. }
  207. log.Infof("Uploaded to bucket %s", result.Location)
  208. return len(embeddingObj), nil
  209. }
  // Process-wide counting semaphores bounding concurrent object-store fetches.
  // A slot is acquired by sending and released by receiving.
  var (
	// globalDiffFetchSemaphore limits fetches issued by getEmbeddingObjectsParallel (GetDiff).
	globalDiffFetchSemaphore = make(chan struct{}, 300)
	// globalFileFetchSemaphore limits fetches issued by getEmbeddingObjectsParallelV2 (GetFilesEmbedding).
	globalFileFetchSemaphore = make(chan struct{}, 400)
  )
  212. func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.EmbeddingObject, error) {
  213. var wg sync.WaitGroup
  214. var errs []error
  215. embeddingObjects := make([]ente.EmbeddingObject, len(objectKeys))
  216. s3Client := c.S3Config.GetS3Client(c.derivedStorageDataCenter)
  217. downloader := s3manager.NewDownloaderWithClient(&s3Client)
  218. for i, objectKey := range objectKeys {
  219. wg.Add(1)
  220. globalDiffFetchSemaphore <- struct{}{} // Acquire from global semaphore
  221. go func(i int, objectKey string) {
  222. defer wg.Done()
  223. defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
  224. obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
  225. if err != nil {
  226. errs = append(errs, err)
  227. log.Error("error fetching embedding object: "+objectKey, err)
  228. } else {
  229. embeddingObjects[i] = obj
  230. }
  231. }(i, objectKey)
  232. }
  233. wg.Wait()
  234. if len(errs) > 0 {
  235. return nil, stacktrace.Propagate(errors.New("failed to fetch some objects"), "")
  236. }
  237. return embeddingObjects, nil
  238. }
  239. type embeddingObjectResult struct {
  240. embeddingObject ente.EmbeddingObject
  241. dbEmbeddingRow ente.Embedding
  242. err error
  243. }
  244. func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) {
  245. var wg sync.WaitGroup
  246. embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows))
  247. s3Client := c.S3Config.GetS3Client(c.derivedStorageDataCenter)
  248. downloader := s3manager.NewDownloaderWithClient(&s3Client)
  249. for i, dbEmbeddingRow := range dbEmbeddingRows {
  250. wg.Add(1)
  251. globalFileFetchSemaphore <- struct{}{} // Acquire from global semaphore
  252. go func(i int, dbEmbeddingRow ente.Embedding) {
  253. defer wg.Done()
  254. defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
  255. objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
  256. obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
  257. if err != nil {
  258. log.Error("error fetching embedding object: "+objectKey, err)
  259. embeddingObjects[i] = embeddingObjectResult{
  260. err: err,
  261. dbEmbeddingRow: dbEmbeddingRow,
  262. }
  263. } else {
  264. embeddingObjects[i] = embeddingObjectResult{
  265. embeddingObject: obj,
  266. dbEmbeddingRow: dbEmbeddingRow,
  267. }
  268. }
  269. }(i, dbEmbeddingRow)
  270. }
  271. wg.Wait()
  272. return embeddingObjects, nil
  273. }
  // getOptions configures getEmbeddingObject's retry loop.
  type getOptions struct {
	// RetryCount is the number of attempts made after the first one.
	RetryCount int
	// FetchTimeOut bounds each individual fetch attempt.
	FetchTimeOut gTime.Duration
  }
  278. func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, opt *getOptions) (ente.EmbeddingObject, error) {
  279. if opt == nil {
  280. opt = &getOptions{
  281. RetryCount: 3,
  282. FetchTimeOut: embeddingFetchTimeout,
  283. }
  284. }
  285. ctxLogger := log.WithField("objectKey", objectKey)
  286. totalAttempts := opt.RetryCount + 1
  287. for i := 0; i < totalAttempts; i++ {
  288. // Create a new context with a timeout for each fetch
  289. fetchCtx, cancel := context.WithTimeout(ctx, opt.FetchTimeOut)
  290. select {
  291. case <-ctx.Done():
  292. cancel()
  293. return ente.EmbeddingObject{}, stacktrace.Propagate(ctx.Err(), "")
  294. default:
  295. obj, err := c.downloadObject(fetchCtx, objectKey, downloader, c.derivedStorageDataCenter)
  296. cancel() // Ensure cancel is called to release resources
  297. if err == nil {
  298. if i > 0 {
  299. ctxLogger.Infof("Fetched object after %d attempts", i)
  300. }
  301. return obj, nil
  302. }
  303. // Check if the error is due to context timeout or cancellation
  304. if err == nil && fetchCtx.Err() != nil {
  305. ctxLogger.Error("Fetch timed out or cancelled: ", fetchCtx.Err())
  306. } else {
  307. // check if the error is due to object not found
  308. if s3Err, ok := err.(awserr.RequestFailure); ok {
  309. if s3Err.Code() == s3.ErrCodeNoSuchKey {
  310. if c.derivedStorageDataCenter == c.S3Config.GetHotDataCenter() {
  311. ctxLogger.Error("Object not found: ", s3Err)
  312. } else {
  313. // If derived and hot bucket are different, try to copy from hot bucket
  314. copyEmbeddingObject, err := c.copyEmbeddingObject(ctx, objectKey)
  315. if err == nil {
  316. ctxLogger.Info("Got the object from hot bucket object")
  317. return *copyEmbeddingObject, nil
  318. } else {
  319. ctxLogger.WithError(err).Error("Failed to copy from hot bucket object")
  320. }
  321. return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
  322. }
  323. }
  324. }
  325. ctxLogger.Error("Failed to fetch object: ", err)
  326. }
  327. }
  328. }
  329. return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "")
  330. }
  331. func (c *Controller) downloadObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, dc string) (ente.EmbeddingObject, error) {
  332. var obj ente.EmbeddingObject
  333. buff := &aws.WriteAtBuffer{}
  334. bucket := c.S3Config.GetBucket(dc)
  335. _, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
  336. Bucket: bucket,
  337. Key: &objectKey,
  338. })
  339. if err != nil {
  340. return obj, err
  341. }
  342. err = json.Unmarshal(buff.Bytes(), &obj)
  343. if err != nil {
  344. return obj, stacktrace.Propagate(err, "unmarshal failed")
  345. }
  346. return obj, nil
  347. }
  348. // download the embedding object from hot bucket and upload to embeddings bucket
  349. func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) (*ente.EmbeddingObject, error) {
  350. if c.derivedStorageDataCenter == c.S3Config.GetHotDataCenter() {
  351. return nil, stacktrace.Propagate(errors.New("derived DC bucket and hot DC are same"), "")
  352. }
  353. downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
  354. obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotDataCenter())
  355. if err != nil {
  356. return nil, stacktrace.Propagate(err, "failed to download from hot bucket")
  357. }
  358. go func() {
  359. _, err = c.uploadObject(obj, objectKey, c.derivedStorageDataCenter)
  360. if err != nil {
  361. log.WithField("object", objectKey).Error("Failed to copy to embeddings bucket: ", err)
  362. }
  363. }()
  364. return &obj, nil
  365. }
  366. func (c *Controller) _validateGetFileEmbeddingsRequest(ctx *gin.Context, userID int64, req ente.GetFilesEmbeddingRequest) error {
  367. if req.Model == "" {
  368. return ente.NewBadRequestWithMessage("model is required")
  369. }
  370. if len(req.FileIDs) == 0 {
  371. return ente.NewBadRequestWithMessage("fileIDs are required")
  372. }
  373. if len(req.FileIDs) > 200 {
  374. return ente.NewBadRequestWithMessage("fileIDs should be less than or equal to 200")
  375. }
  376. if err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
  377. ActorUserId: userID,
  378. FileIDs: req.FileIDs,
  379. }); err != nil {
  380. return stacktrace.Propagate(err, "User does not own some file(s)")
  381. }
  382. return nil
  383. }