controller.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. package embedding
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "errors"
  7. "github.com/aws/aws-sdk-go/aws/awserr"
  8. "github.com/ente-io/museum/pkg/utils/array"
  9. "strconv"
  10. "sync"
  11. gTime "time"
  12. "github.com/aws/aws-sdk-go/aws"
  13. "github.com/aws/aws-sdk-go/service/s3"
  14. "github.com/aws/aws-sdk-go/service/s3/s3manager"
  15. "github.com/ente-io/museum/ente"
  16. "github.com/ente-io/museum/pkg/controller"
  17. "github.com/ente-io/museum/pkg/controller/access"
  18. "github.com/ente-io/museum/pkg/repo"
  19. "github.com/ente-io/museum/pkg/repo/embedding"
  20. "github.com/ente-io/museum/pkg/utils/auth"
  21. "github.com/ente-io/museum/pkg/utils/network"
  22. "github.com/ente-io/museum/pkg/utils/s3config"
  23. "github.com/ente-io/stacktrace"
  24. "github.com/gin-gonic/gin"
  25. log "github.com/sirupsen/logrus"
  26. )
  27. const (
  28. // maxEmbeddingDataSize is the min size of an embedding object in bytes
  29. minEmbeddingDataSize = 2048
  30. embeddingFetchTimeout = 15 * gTime.Second
  31. )
  32. // _fetchConfig is the configuration for the fetching objects from S3
  33. type _fetchConfig struct {
  34. RetryCount int
  35. FetchTimeOut gTime.Duration
  36. }
  37. var _defaultFetchConfig = _fetchConfig{RetryCount: 3, FetchTimeOut: 15 * gTime.Second}
  38. type Controller struct {
  39. Repo *embedding.Repository
  40. AccessCtrl access.Controller
  41. ObjectCleanupController *controller.ObjectCleanupController
  42. S3Config *s3config.S3Config
  43. QueueRepo *repo.QueueRepository
  44. TaskLockingRepo *repo.TaskLockRepository
  45. FileRepo *repo.FileRepository
  46. CollectionRepo *repo.CollectionRepository
  47. HostName string
  48. cleanupCronRunning bool
  49. derivedStorageDataCenter string
  50. }
  51. func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller {
  52. return &Controller{
  53. Repo: repo,
  54. AccessCtrl: accessCtrl,
  55. ObjectCleanupController: objectCleanupController,
  56. S3Config: s3Config,
  57. QueueRepo: queueRepo,
  58. TaskLockingRepo: taskLockingRepo,
  59. FileRepo: fileRepo,
  60. CollectionRepo: collectionRepo,
  61. HostName: hostName,
  62. derivedStorageDataCenter: s3Config.GetDerivedStorageDataCenter(),
  63. }
  64. }
  65. func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmbeddingRequest) (*ente.Embedding, error) {
  66. userID := auth.GetUserID(ctx.Request.Header)
  67. err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
  68. ActorUserId: userID,
  69. FileIDs: []int64{req.FileID},
  70. })
  71. if err != nil {
  72. return nil, stacktrace.Propagate(err, "User does not own file")
  73. }
  74. count, err := c.CollectionRepo.GetCollectionCount(req.FileID)
  75. if err != nil {
  76. return nil, stacktrace.Propagate(err, "")
  77. }
  78. if count < 1 {
  79. return nil, stacktrace.Propagate(ente.ErrNotFound, "")
  80. }
  81. version := 1
  82. if req.Version != nil {
  83. version = *req.Version
  84. }
  85. obj := ente.EmbeddingObject{
  86. Version: version,
  87. EncryptedEmbedding: req.EncryptedEmbedding,
  88. DecryptionHeader: req.DecryptionHeader,
  89. Client: network.GetPrettyUA(ctx.GetHeader("User-Agent")) + "/" + ctx.GetHeader("X-Client-Version"),
  90. }
  91. size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model), c.derivedStorageDataCenter)
  92. if uploadErr != nil {
  93. log.Error(uploadErr)
  94. return nil, stacktrace.Propagate(uploadErr, "")
  95. }
  96. embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version, c.derivedStorageDataCenter)
  97. embedding.Version = &version
  98. if err != nil {
  99. return nil, stacktrace.Propagate(err, "")
  100. }
  101. return &embedding, nil
  102. }
  103. func (c *Controller) GetDiff(ctx *gin.Context, req ente.GetEmbeddingDiffRequest) ([]ente.Embedding, error) {
  104. userID := auth.GetUserID(ctx.Request.Header)
  105. if req.Model == "" {
  106. req.Model = ente.GgmlClip
  107. }
  108. embeddings, err := c.Repo.GetDiff(ctx, userID, req.Model, *req.SinceTime, req.Limit)
  109. if err != nil {
  110. return nil, stacktrace.Propagate(err, "")
  111. }
  112. // Collect object keys for embeddings with missing data
  113. var objectKeys []string
  114. for i := range embeddings {
  115. if embeddings[i].EncryptedEmbedding == "" {
  116. objectKey := c.getObjectKey(userID, embeddings[i].FileID, embeddings[i].Model)
  117. objectKeys = append(objectKeys, objectKey)
  118. }
  119. }
  120. // Fetch missing embeddings in parallel
  121. if len(objectKeys) > 0 {
  122. embeddingObjects, err := c.getEmbeddingObjectsParallel(objectKeys)
  123. if err != nil {
  124. return nil, stacktrace.Propagate(err, "")
  125. }
  126. // Populate missing data in embeddings from fetched objects
  127. for i, obj := range embeddingObjects {
  128. for j := range embeddings {
  129. if embeddings[j].EncryptedEmbedding == "" && c.getObjectKey(userID, embeddings[j].FileID, embeddings[j].Model) == objectKeys[i] {
  130. embeddings[j].EncryptedEmbedding = obj.EncryptedEmbedding
  131. embeddings[j].DecryptionHeader = obj.DecryptionHeader
  132. }
  133. }
  134. }
  135. }
  136. return embeddings, nil
  137. }
  138. func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbeddingRequest) (*ente.GetFilesEmbeddingResponse, error) {
  139. userID := auth.GetUserID(ctx.Request.Header)
  140. if err := c._validateGetFileEmbeddingsRequest(ctx, userID, req); err != nil {
  141. return nil, stacktrace.Propagate(err, "")
  142. }
  143. userFileEmbeddings, err := c.Repo.GetFilesEmbedding(ctx, userID, req.Model, req.FileIDs)
  144. if err != nil {
  145. return nil, stacktrace.Propagate(err, "")
  146. }
  147. embeddingsWithData := make([]ente.Embedding, 0)
  148. noEmbeddingFileIds := make([]int64, 0)
  149. dbFileIds := make([]int64, 0)
  150. // fileIDs that were indexed, but they don't contain any embedding information
  151. for i := range userFileEmbeddings {
  152. dbFileIds = append(dbFileIds, userFileEmbeddings[i].FileID)
  153. if userFileEmbeddings[i].Size != nil && *userFileEmbeddings[i].Size < minEmbeddingDataSize {
  154. noEmbeddingFileIds = append(noEmbeddingFileIds, userFileEmbeddings[i].FileID)
  155. } else {
  156. embeddingsWithData = append(embeddingsWithData, userFileEmbeddings[i])
  157. }
  158. }
  159. pendingIndexFileIds := array.FindMissingElementsInSecondList(req.FileIDs, dbFileIds)
  160. errFileIds := make([]int64, 0)
  161. // Fetch missing userFileEmbeddings in parallel
  162. embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData)
  163. if err != nil {
  164. return nil, stacktrace.Propagate(err, "")
  165. }
  166. fetchedEmbeddings := make([]ente.Embedding, 0)
  167. // Populate missing data in userFileEmbeddings from fetched objects
  168. for _, obj := range embeddingObjects {
  169. if obj.err != nil {
  170. errFileIds = append(errFileIds, obj.dbEmbeddingRow.FileID)
  171. } else {
  172. fetchedEmbeddings = append(fetchedEmbeddings, ente.Embedding{
  173. FileID: obj.dbEmbeddingRow.FileID,
  174. Model: obj.dbEmbeddingRow.Model,
  175. EncryptedEmbedding: obj.embeddingObject.EncryptedEmbedding,
  176. DecryptionHeader: obj.embeddingObject.DecryptionHeader,
  177. UpdatedAt: obj.dbEmbeddingRow.UpdatedAt,
  178. Version: obj.dbEmbeddingRow.Version,
  179. })
  180. }
  181. }
  182. return &ente.GetFilesEmbeddingResponse{
  183. Embeddings: fetchedEmbeddings,
  184. PendingIndexFileIDs: pendingIndexFileIds,
  185. ErrFileIDs: errFileIds,
  186. NoEmbeddingFileIDs: noEmbeddingFileIds,
  187. }, nil
  188. }
  189. func (c *Controller) getObjectKey(userID int64, fileID int64, model string) string {
  190. return c.getEmbeddingObjectPrefix(userID, fileID) + model + ".json"
  191. }
  192. func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string {
  193. return strconv.FormatInt(userID, 10) + "/ml-data/" + strconv.FormatInt(fileID, 10) + "/"
  194. }
  195. // uploadObject uploads the embedding object to the object store and returns the object size
  196. func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string, dc string) (int, error) {
  197. embeddingObj, _ := json.Marshal(obj)
  198. s3Client := c.S3Config.GetS3Client(dc)
  199. s3Bucket := c.S3Config.GetBucket(dc)
  200. uploader := s3manager.NewUploaderWithClient(&s3Client)
  201. up := s3manager.UploadInput{
  202. Bucket: s3Bucket,
  203. Key: &key,
  204. Body: bytes.NewReader(embeddingObj),
  205. }
  206. result, err := uploader.Upload(&up)
  207. if err != nil {
  208. log.Error(err)
  209. return -1, stacktrace.Propagate(err, "")
  210. }
  211. log.Infof("Uploaded to bucket %s", result.Location)
  212. return len(embeddingObj), nil
  213. }
  214. var globalDiffFetchSemaphore = make(chan struct{}, 300)
  215. var globalFileFetchSemaphore = make(chan struct{}, 400)
  216. func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.EmbeddingObject, error) {
  217. var wg sync.WaitGroup
  218. var errs []error
  219. embeddingObjects := make([]ente.EmbeddingObject, len(objectKeys))
  220. s3Client := c.S3Config.GetS3Client(c.derivedStorageDataCenter)
  221. downloader := s3manager.NewDownloaderWithClient(&s3Client)
  222. for i, objectKey := range objectKeys {
  223. wg.Add(1)
  224. globalDiffFetchSemaphore <- struct{}{} // Acquire from global semaphore
  225. go func(i int, objectKey string) {
  226. defer wg.Done()
  227. defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
  228. obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
  229. if err != nil {
  230. errs = append(errs, err)
  231. log.Error("error fetching embedding object: "+objectKey, err)
  232. } else {
  233. embeddingObjects[i] = obj
  234. }
  235. }(i, objectKey)
  236. }
  237. wg.Wait()
  238. if len(errs) > 0 {
  239. return nil, stacktrace.Propagate(errors.New("failed to fetch some objects"), "")
  240. }
  241. return embeddingObjects, nil
  242. }
  243. type embeddingObjectResult struct {
  244. embeddingObject ente.EmbeddingObject
  245. dbEmbeddingRow ente.Embedding
  246. err error
  247. }
  248. func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) {
  249. var wg sync.WaitGroup
  250. embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows))
  251. s3Client := c.S3Config.GetS3Client(c.derivedStorageDataCenter)
  252. downloader := s3manager.NewDownloaderWithClient(&s3Client)
  253. for i, dbEmbeddingRow := range dbEmbeddingRows {
  254. wg.Add(1)
  255. globalFileFetchSemaphore <- struct{}{} // Acquire from global semaphore
  256. go func(i int, dbEmbeddingRow ente.Embedding) {
  257. defer wg.Done()
  258. defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
  259. objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
  260. obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
  261. if err != nil {
  262. log.Error("error fetching embedding object: "+objectKey, err)
  263. embeddingObjects[i] = embeddingObjectResult{
  264. err: err,
  265. dbEmbeddingRow: dbEmbeddingRow,
  266. }
  267. } else {
  268. embeddingObjects[i] = embeddingObjectResult{
  269. embeddingObject: obj,
  270. dbEmbeddingRow: dbEmbeddingRow,
  271. }
  272. }
  273. }(i, dbEmbeddingRow)
  274. }
  275. wg.Wait()
  276. return embeddingObjects, nil
  277. }
  278. func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
  279. opt := _defaultFetchConfig
  280. ctxLogger := log.WithField("objectKey", objectKey)
  281. totalAttempts := opt.RetryCount + 1
  282. for i := 0; i < totalAttempts; i++ {
  283. // Create a new context with a timeout for each fetch
  284. fetchCtx, cancel := context.WithTimeout(ctx, opt.FetchTimeOut)
  285. select {
  286. case <-ctx.Done():
  287. cancel()
  288. return ente.EmbeddingObject{}, stacktrace.Propagate(ctx.Err(), "")
  289. default:
  290. obj, err := c.downloadObject(fetchCtx, objectKey, downloader, c.derivedStorageDataCenter)
  291. cancel() // Ensure cancel is called to release resources
  292. if err == nil {
  293. if i > 0 {
  294. ctxLogger.Infof("Fetched object after %d attempts", i)
  295. }
  296. return obj, nil
  297. }
  298. // Check if the error is due to context timeout or cancellation
  299. if err == nil && fetchCtx.Err() != nil {
  300. ctxLogger.Error("Fetch timed out or cancelled: ", fetchCtx.Err())
  301. } else {
  302. // check if the error is due to object not found
  303. if s3Err, ok := err.(awserr.RequestFailure); ok {
  304. if s3Err.Code() == s3.ErrCodeNoSuchKey {
  305. if c.derivedStorageDataCenter == c.S3Config.GetHotBackblazeDC() {
  306. ctxLogger.Error("Object not found: ", s3Err)
  307. return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
  308. } else {
  309. // If derived and hot bucket are different, try to copy from hot bucket
  310. copyEmbeddingObject, err := c.copyEmbeddingObject(ctx, objectKey)
  311. if err == nil {
  312. ctxLogger.Info("Got the object from hot bucket object")
  313. return *copyEmbeddingObject, nil
  314. } else {
  315. ctxLogger.WithError(err).Error("Failed to copy from hot bucket object")
  316. }
  317. return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
  318. }
  319. }
  320. }
  321. ctxLogger.Error("Failed to fetch object: ", err)
  322. }
  323. }
  324. }
  325. return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "")
  326. }
  327. func (c *Controller) downloadObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, dc string) (ente.EmbeddingObject, error) {
  328. var obj ente.EmbeddingObject
  329. buff := &aws.WriteAtBuffer{}
  330. bucket := c.S3Config.GetBucket(dc)
  331. _, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
  332. Bucket: bucket,
  333. Key: &objectKey,
  334. })
  335. if err != nil {
  336. return obj, err
  337. }
  338. err = json.Unmarshal(buff.Bytes(), &obj)
  339. if err != nil {
  340. return obj, stacktrace.Propagate(err, "unmarshal failed")
  341. }
  342. return obj, nil
  343. }
  344. // download the embedding object from hot bucket and upload to embeddings bucket
  345. func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) (*ente.EmbeddingObject, error) {
  346. if c.derivedStorageDataCenter == c.S3Config.GetHotBackblazeDC() {
  347. return nil, stacktrace.Propagate(errors.New("derived DC bucket and hot DC are same"), "")
  348. }
  349. downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
  350. obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotBackblazeDC())
  351. if err != nil {
  352. return nil, stacktrace.Propagate(err, "failed to download from hot bucket")
  353. }
  354. go func() {
  355. _, err = c.uploadObject(obj, objectKey, c.derivedStorageDataCenter)
  356. if err != nil {
  357. log.WithField("object", objectKey).Error("Failed to copy to embeddings bucket: ", err)
  358. }
  359. }()
  360. return &obj, nil
  361. }
  362. func (c *Controller) _validateGetFileEmbeddingsRequest(ctx *gin.Context, userID int64, req ente.GetFilesEmbeddingRequest) error {
  363. if req.Model == "" {
  364. return ente.NewBadRequestWithMessage("model is required")
  365. }
  366. if len(req.FileIDs) == 0 {
  367. return ente.NewBadRequestWithMessage("fileIDs are required")
  368. }
  369. if len(req.FileIDs) > 200 {
  370. return ente.NewBadRequestWithMessage("fileIDs should be less than or equal to 200")
  371. }
  372. if err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
  373. ActorUserId: userID,
  374. FileIDs: req.FileIDs,
  375. }); err != nil {
  376. return stacktrace.Propagate(err, "User does not own some file(s)")
  377. }
  378. return nil
  379. }