controller.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. package embedding
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "github.com/aws/aws-sdk-go/aws/awserr"
  9. "github.com/ente-io/museum/pkg/utils/array"
  10. "strconv"
  11. "strings"
  12. "sync"
  13. gTime "time"
  14. "github.com/aws/aws-sdk-go/aws"
  15. "github.com/aws/aws-sdk-go/service/s3"
  16. "github.com/aws/aws-sdk-go/service/s3/s3manager"
  17. "github.com/ente-io/museum/ente"
  18. "github.com/ente-io/museum/pkg/controller"
  19. "github.com/ente-io/museum/pkg/controller/access"
  20. "github.com/ente-io/museum/pkg/repo"
  21. "github.com/ente-io/museum/pkg/repo/embedding"
  22. "github.com/ente-io/museum/pkg/utils/auth"
  23. "github.com/ente-io/museum/pkg/utils/network"
  24. "github.com/ente-io/museum/pkg/utils/s3config"
  25. "github.com/ente-io/stacktrace"
  26. "github.com/gin-gonic/gin"
  27. log "github.com/sirupsen/logrus"
  28. )
  29. const (
  30. // maxEmbeddingDataSize is the min size of an embedding object in bytes
  31. minEmbeddingDataSize = 2048
  32. embeddingFetchTimeout = 10 * gTime.Second
  33. )
  34. // _fetchConfig is the configuration for the fetching objects from S3
  35. type _fetchConfig struct {
  36. RetryCount int
  37. InitialTimeout gTime.Duration
  38. MaxTimeout gTime.Duration
  39. }
  40. var _defaultFetchConfig = _fetchConfig{RetryCount: 3, InitialTimeout: 10 * gTime.Second, MaxTimeout: 30 * gTime.Second}
  41. var _b2FetchConfig = _fetchConfig{RetryCount: 3, InitialTimeout: 15 * gTime.Second, MaxTimeout: 30 * gTime.Second}
  42. type Controller struct {
  43. Repo *embedding.Repository
  44. AccessCtrl access.Controller
  45. ObjectCleanupController *controller.ObjectCleanupController
  46. S3Config *s3config.S3Config
  47. QueueRepo *repo.QueueRepository
  48. TaskLockingRepo *repo.TaskLockRepository
  49. FileRepo *repo.FileRepository
  50. CollectionRepo *repo.CollectionRepository
  51. HostName string
  52. cleanupCronRunning bool
  53. derivedStorageDataCenter string
  54. downloadManagerCache map[string]*s3manager.Downloader
  55. }
  56. func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller {
  57. embeddingDcs := []string{s3Config.GetHotBackblazeDC(), s3Config.GetHotWasabiDC(), s3Config.GetDerivedStorageDataCenter()}
  58. cache := make(map[string]*s3manager.Downloader, len(embeddingDcs))
  59. for i := range embeddingDcs {
  60. s3Client := s3Config.GetS3Client(embeddingDcs[i])
  61. cache[embeddingDcs[i]] = s3manager.NewDownloaderWithClient(&s3Client)
  62. }
  63. return &Controller{
  64. Repo: repo,
  65. AccessCtrl: accessCtrl,
  66. ObjectCleanupController: objectCleanupController,
  67. S3Config: s3Config,
  68. QueueRepo: queueRepo,
  69. TaskLockingRepo: taskLockingRepo,
  70. FileRepo: fileRepo,
  71. CollectionRepo: collectionRepo,
  72. HostName: hostName,
  73. derivedStorageDataCenter: s3Config.GetDerivedStorageDataCenter(),
  74. downloadManagerCache: cache,
  75. }
  76. }
  77. func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmbeddingRequest) (*ente.Embedding, error) {
  78. userID := auth.GetUserID(ctx.Request.Header)
  79. err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
  80. ActorUserId: userID,
  81. FileIDs: []int64{req.FileID},
  82. })
  83. if err != nil {
  84. return nil, stacktrace.Propagate(err, "User does not own file")
  85. }
  86. count, err := c.CollectionRepo.GetCollectionCount(req.FileID)
  87. if err != nil {
  88. return nil, stacktrace.Propagate(err, "")
  89. }
  90. if count < 1 {
  91. return nil, stacktrace.Propagate(ente.ErrNotFound, "")
  92. }
  93. version := 1
  94. if req.Version != nil {
  95. version = *req.Version
  96. }
  97. obj := ente.EmbeddingObject{
  98. Version: version,
  99. EncryptedEmbedding: req.EncryptedEmbedding,
  100. DecryptionHeader: req.DecryptionHeader,
  101. Client: network.GetPrettyUA(ctx.GetHeader("User-Agent")) + "/" + ctx.GetHeader("X-Client-Version"),
  102. }
  103. size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model), c.derivedStorageDataCenter)
  104. if uploadErr != nil {
  105. log.Error(uploadErr)
  106. return nil, stacktrace.Propagate(uploadErr, "")
  107. }
  108. embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version, c.derivedStorageDataCenter)
  109. embedding.Version = &version
  110. if err != nil {
  111. return nil, stacktrace.Propagate(err, "")
  112. }
  113. return &embedding, nil
  114. }
  115. func (c *Controller) GetDiff(ctx *gin.Context, req ente.GetEmbeddingDiffRequest) ([]ente.Embedding, error) {
  116. userID := auth.GetUserID(ctx.Request.Header)
  117. if req.Model == "" {
  118. req.Model = ente.GgmlClip
  119. }
  120. embeddings, err := c.Repo.GetDiff(ctx, userID, req.Model, *req.SinceTime, req.Limit)
  121. if err != nil {
  122. return nil, stacktrace.Propagate(err, "")
  123. }
  124. // Collect object keys for embeddings with missing data
  125. var objectKeys []string
  126. for i := range embeddings {
  127. if embeddings[i].EncryptedEmbedding == "" {
  128. objectKey := c.getObjectKey(userID, embeddings[i].FileID, embeddings[i].Model)
  129. objectKeys = append(objectKeys, objectKey)
  130. }
  131. }
  132. // Fetch missing embeddings in parallel
  133. if len(objectKeys) > 0 {
  134. embeddingObjects, err := c.getEmbeddingObjectsParallel(objectKeys, c.derivedStorageDataCenter)
  135. if err != nil {
  136. return nil, stacktrace.Propagate(err, "")
  137. }
  138. // Populate missing data in embeddings from fetched objects
  139. for i, obj := range embeddingObjects {
  140. for j := range embeddings {
  141. if embeddings[j].EncryptedEmbedding == "" && c.getObjectKey(userID, embeddings[j].FileID, embeddings[j].Model) == objectKeys[i] {
  142. embeddings[j].EncryptedEmbedding = obj.EncryptedEmbedding
  143. embeddings[j].DecryptionHeader = obj.DecryptionHeader
  144. }
  145. }
  146. }
  147. }
  148. return embeddings, nil
  149. }
  150. func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbeddingRequest) (*ente.GetFilesEmbeddingResponse, error) {
  151. userID := auth.GetUserID(ctx.Request.Header)
  152. if err := c._validateGetFileEmbeddingsRequest(ctx, userID, req); err != nil {
  153. return nil, stacktrace.Propagate(err, "")
  154. }
  155. userFileEmbeddings, err := c.Repo.GetFilesEmbedding(ctx, userID, req.Model, req.FileIDs)
  156. if err != nil {
  157. return nil, stacktrace.Propagate(err, "")
  158. }
  159. embeddingsWithData := make([]ente.Embedding, 0)
  160. noEmbeddingFileIds := make([]int64, 0)
  161. dbFileIds := make([]int64, 0)
  162. // fileIDs that were indexed, but they don't contain any embedding information
  163. for i := range userFileEmbeddings {
  164. dbFileIds = append(dbFileIds, userFileEmbeddings[i].FileID)
  165. if userFileEmbeddings[i].Size != nil && *userFileEmbeddings[i].Size < minEmbeddingDataSize {
  166. noEmbeddingFileIds = append(noEmbeddingFileIds, userFileEmbeddings[i].FileID)
  167. } else {
  168. embeddingsWithData = append(embeddingsWithData, userFileEmbeddings[i])
  169. }
  170. }
  171. pendingIndexFileIds := array.FindMissingElementsInSecondList(req.FileIDs, dbFileIds)
  172. errFileIds := make([]int64, 0)
  173. // Fetch missing userFileEmbeddings in parallel
  174. embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData, c.derivedStorageDataCenter)
  175. if err != nil {
  176. return nil, stacktrace.Propagate(err, "")
  177. }
  178. fetchedEmbeddings := make([]ente.Embedding, 0)
  179. // Populate missing data in userFileEmbeddings from fetched objects
  180. for _, obj := range embeddingObjects {
  181. if obj.err != nil {
  182. errFileIds = append(errFileIds, obj.dbEmbeddingRow.FileID)
  183. } else {
  184. fetchedEmbeddings = append(fetchedEmbeddings, ente.Embedding{
  185. FileID: obj.dbEmbeddingRow.FileID,
  186. Model: obj.dbEmbeddingRow.Model,
  187. EncryptedEmbedding: obj.embeddingObject.EncryptedEmbedding,
  188. DecryptionHeader: obj.embeddingObject.DecryptionHeader,
  189. UpdatedAt: obj.dbEmbeddingRow.UpdatedAt,
  190. Version: obj.dbEmbeddingRow.Version,
  191. })
  192. }
  193. }
  194. return &ente.GetFilesEmbeddingResponse{
  195. Embeddings: fetchedEmbeddings,
  196. PendingIndexFileIDs: pendingIndexFileIds,
  197. ErrFileIDs: errFileIds,
  198. NoEmbeddingFileIDs: noEmbeddingFileIds,
  199. }, nil
  200. }
  201. func (c *Controller) getObjectKey(userID int64, fileID int64, model string) string {
  202. return c.getEmbeddingObjectPrefix(userID, fileID) + model + ".json"
  203. }
  204. func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string {
  205. return strconv.FormatInt(userID, 10) + "/ml-data/" + strconv.FormatInt(fileID, 10) + "/"
  206. }
  207. // Get userId, model and fileID from the object key
  208. func (c *Controller) getEmbeddingObjectDetails(objectKey string) (userID int64, model string, fileID int64) {
  209. split := strings.Split(objectKey, "/")
  210. userID, _ = strconv.ParseInt(split[0], 10, 64)
  211. fileID, _ = strconv.ParseInt(split[2], 10, 64)
  212. model = strings.Split(split[3], ".")[0]
  213. return userID, model, fileID
  214. }
  215. // uploadObject uploads the embedding object to the object store and returns the object size
  216. func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string, dc string) (int, error) {
  217. embeddingObj, _ := json.Marshal(obj)
  218. s3Client := c.S3Config.GetS3Client(dc)
  219. s3Bucket := c.S3Config.GetBucket(dc)
  220. uploader := s3manager.NewUploaderWithClient(&s3Client)
  221. up := s3manager.UploadInput{
  222. Bucket: s3Bucket,
  223. Key: &key,
  224. Body: bytes.NewReader(embeddingObj),
  225. }
  226. result, err := uploader.Upload(&up)
  227. if err != nil {
  228. log.Error(err)
  229. return -1, stacktrace.Propagate(err, "")
  230. }
  231. log.Infof("Uploaded to bucket %s", result.Location)
  232. return len(embeddingObj), nil
  233. }
  234. var globalDiffFetchSemaphore = make(chan struct{}, 300)
  235. var globalFileFetchSemaphore = make(chan struct{}, 400)
  236. func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string, dc string) ([]ente.EmbeddingObject, error) {
  237. var wg sync.WaitGroup
  238. var errs []error
  239. embeddingObjects := make([]ente.EmbeddingObject, len(objectKeys))
  240. for i, objectKey := range objectKeys {
  241. wg.Add(1)
  242. globalDiffFetchSemaphore <- struct{}{} // Acquire from global semaphore
  243. go func(i int, objectKey string) {
  244. defer wg.Done()
  245. defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
  246. obj, err := c.getEmbeddingObject(context.Background(), objectKey, dc)
  247. if err != nil {
  248. errs = append(errs, err)
  249. log.Error("error fetching embedding object: "+objectKey, err)
  250. } else {
  251. embeddingObjects[i] = obj
  252. }
  253. }(i, objectKey)
  254. }
  255. wg.Wait()
  256. if len(errs) > 0 {
  257. return nil, stacktrace.Propagate(errors.New("failed to fetch some objects"), "")
  258. }
  259. return embeddingObjects, nil
  260. }
  261. type embeddingObjectResult struct {
  262. embeddingObject ente.EmbeddingObject
  263. dbEmbeddingRow ente.Embedding
  264. err error
  265. }
  266. func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding, dc string) ([]embeddingObjectResult, error) {
  267. var wg sync.WaitGroup
  268. embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows))
  269. for i, dbEmbeddingRow := range dbEmbeddingRows {
  270. wg.Add(1)
  271. globalFileFetchSemaphore <- struct{}{} // Acquire from global semaphore
  272. go func(i int, dbEmbeddingRow ente.Embedding) {
  273. defer wg.Done()
  274. defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
  275. objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
  276. obj, err := c.getEmbeddingObject(context.Background(), objectKey, dc)
  277. if err != nil {
  278. log.Error("error fetching embedding object: "+objectKey, err)
  279. embeddingObjects[i] = embeddingObjectResult{
  280. err: err,
  281. dbEmbeddingRow: dbEmbeddingRow,
  282. }
  283. } else {
  284. embeddingObjects[i] = embeddingObjectResult{
  285. embeddingObject: obj,
  286. dbEmbeddingRow: dbEmbeddingRow,
  287. }
  288. }
  289. }(i, dbEmbeddingRow)
  290. }
  291. wg.Wait()
  292. return embeddingObjects, nil
  293. }
  294. func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, dc string) (ente.EmbeddingObject, error) {
  295. opt := _defaultFetchConfig
  296. if dc == c.S3Config.GetHotBackblazeDC() {
  297. opt = _b2FetchConfig
  298. }
  299. ctxLogger := log.WithField("objectKey", objectKey).WithField("dc", dc)
  300. totalAttempts := opt.RetryCount + 1
  301. timeout := opt.InitialTimeout
  302. for i := 0; i < totalAttempts; i++ {
  303. if i > 0 {
  304. timeout = timeout * 2
  305. if timeout > opt.MaxTimeout {
  306. timeout = opt.MaxTimeout
  307. }
  308. }
  309. fetchCtx, cancel := context.WithTimeout(ctx, timeout)
  310. select {
  311. case <-ctx.Done():
  312. cancel()
  313. return ente.EmbeddingObject{}, stacktrace.Propagate(ctx.Err(), "")
  314. default:
  315. obj, err := c.downloadObject(fetchCtx, objectKey, dc)
  316. cancel() // Ensure cancel is called to release resources
  317. if err == nil {
  318. if i > 0 {
  319. ctxLogger.Infof("Fetched object after %d attempts", i)
  320. }
  321. return obj, nil
  322. }
  323. // Check if the error is due to context timeout or cancellation
  324. if err == nil && fetchCtx.Err() != nil {
  325. ctxLogger.Error("Fetch timed out or cancelled: ", fetchCtx.Err())
  326. } else {
  327. // check if the error is due to object not found
  328. if s3Err, ok := err.(awserr.RequestFailure); ok {
  329. if s3Err.Code() == s3.ErrCodeNoSuchKey {
  330. if c.derivedStorageDataCenter != c.S3Config.GetHotBackblazeDC() {
  331. // todo:(neeraj) Refactor this later to get available the DC from the DB and use that to
  332. // copy the object to currently active DC for derived storage
  333. // If derived and hot bucket are different, try to copy from hot bucket
  334. copyEmbeddingObject, err := c.copyEmbeddingObject(ctx, objectKey, c.S3Config.GetHotBackblazeDC(), c.derivedStorageDataCenter)
  335. if err == nil {
  336. ctxLogger.Info("Got the object from hot bucket object")
  337. return *copyEmbeddingObject, nil
  338. } else {
  339. ctxLogger.WithError(err).Error("Failed to copy from hot bucket object")
  340. }
  341. return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
  342. } else {
  343. ctxLogger.Error("Object not found: ", s3Err)
  344. return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
  345. }
  346. }
  347. }
  348. ctxLogger.Error("Failed to fetch object: ", err)
  349. }
  350. }
  351. }
  352. return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "")
  353. }
  354. func (c *Controller) downloadObject(ctx context.Context, objectKey string, dc string) (ente.EmbeddingObject, error) {
  355. var obj ente.EmbeddingObject
  356. buff := &aws.WriteAtBuffer{}
  357. bucket := c.S3Config.GetBucket(dc)
  358. downloader := c.downloadManagerCache[dc]
  359. _, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
  360. Bucket: bucket,
  361. Key: &objectKey,
  362. })
  363. if err != nil {
  364. return obj, err
  365. }
  366. err = json.Unmarshal(buff.Bytes(), &obj)
  367. if err != nil {
  368. return obj, stacktrace.Propagate(err, "unmarshal failed")
  369. }
  370. return obj, nil
  371. }
  372. // download the embedding object from hot bucket and upload to embeddings bucket
  373. func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string, srcDC, destDC string) (*ente.EmbeddingObject, error) {
  374. if srcDC == destDC {
  375. return nil, stacktrace.Propagate(errors.New("src and dest dc can not be same"), "")
  376. }
  377. obj, err := c.downloadObject(ctx, objectKey, srcDC)
  378. if err != nil {
  379. return nil, stacktrace.Propagate(err, fmt.Sprintf("failed to download object from %s", srcDC))
  380. }
  381. go func() {
  382. userID, modelName, fileID := c.getEmbeddingObjectDetails(objectKey)
  383. size, uploadErr := c.uploadObject(obj, objectKey, c.derivedStorageDataCenter)
  384. if uploadErr != nil {
  385. log.WithField("object", objectKey).Error("Failed to copy to embeddings bucket: ", uploadErr)
  386. }
  387. updateDcErr := c.Repo.AddNewDC(context.Background(), fileID, ente.Model(modelName), userID, size, destDC)
  388. if updateDcErr != nil {
  389. log.WithField("object", objectKey).Error("Failed to update dc in db: ", updateDcErr)
  390. return
  391. }
  392. }()
  393. return &obj, nil
  394. }
  395. func (c *Controller) _validateGetFileEmbeddingsRequest(ctx *gin.Context, userID int64, req ente.GetFilesEmbeddingRequest) error {
  396. if req.Model == "" {
  397. return ente.NewBadRequestWithMessage("model is required")
  398. }
  399. if len(req.FileIDs) == 0 {
  400. return ente.NewBadRequestWithMessage("fileIDs are required")
  401. }
  402. if len(req.FileIDs) > 200 {
  403. return ente.NewBadRequestWithMessage("fileIDs should be less than or equal to 200")
  404. }
  405. if err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
  406. ActorUserId: userID,
  407. FileIDs: req.FileIDs,
  408. }); err != nil {
  409. return stacktrace.Propagate(err, "User does not own some file(s)")
  410. }
  411. return nil
  412. }