Minor refactor

This commit is contained in:
Neeraj Gupta 2024-05-16 12:23:25 +05:30
parent 3c7d86da8d
commit e0738db6ae

View file

@ -45,7 +45,7 @@ type Controller struct {
HostName string
cleanupCronRunning bool
derivedStorageS3Client *s3.S3
derivedStorageBucket *string
derivedStorageDataCenter string
areDerivedAndHotBucketSame bool
}
@ -61,8 +61,8 @@ func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanup
CollectionRepo: collectionRepo,
HostName: hostName,
derivedStorageS3Client: s3Config.GetDerivedStorageS3Client(),
derivedStorageBucket: s3Config.GetDerivedStorageBucket(),
areDerivedAndHotBucketSame: s3Config.GetDerivedStorageBucket() == s3Config.GetHotBucket(),
derivedStorageDataCenter: s3Config.GetDerivedStorageDataCenter(),
// True only when derived storage lives in the same data center as the hot
// bucket; compare against the hot DC, not the derived DC with itself.
areDerivedAndHotBucketSame: s3Config.GetDerivedStorageDataCenter() == s3Config.GetHotDataCenter(),
}
}
@ -96,12 +96,12 @@ func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmb
DecryptionHeader: req.DecryptionHeader,
Client: network.GetPrettyUA(ctx.GetHeader("User-Agent")) + "/" + ctx.GetHeader("X-Client-Version"),
}
size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model))
size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model), c.derivedStorageDataCenter)
if uploadErr != nil {
log.Error(uploadErr)
return nil, stacktrace.Propagate(uploadErr, "")
}
embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version, c.S3Config.GetDerivedStorageDataCenter())
embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version, c.derivedStorageDataCenter)
embedding.Version = &version
if err != nil {
return nil, stacktrace.Propagate(err, "")
@ -217,11 +217,13 @@ func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string
}
// uploadObject uploads the embedding object to the object store and returns the object size
func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) (int, error) {
func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string, dc string) (int, error) {
embeddingObj, _ := json.Marshal(obj)
uploader := s3manager.NewUploaderWithClient(c.derivedStorageS3Client)
s3Client := c.S3Config.GetS3Client(dc)
s3Bucket := c.S3Config.GetBucket(dc)
uploader := s3manager.NewUploaderWithClient(&s3Client)
up := s3manager.UploadInput{
Bucket: c.derivedStorageBucket,
Bucket: s3Bucket,
Key: &key,
Body: bytes.NewReader(embeddingObj),
}
@ -331,7 +333,7 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d
cancel()
return ente.EmbeddingObject{}, stacktrace.Propagate(ctx.Err(), "")
default:
obj, err := c.downloadObject(fetchCtx, objectKey, downloader, c.derivedStorageBucket)
obj, err := c.downloadObject(fetchCtx, objectKey, downloader, c.derivedStorageDataCenter)
cancel() // Ensure cancel is called to release resources
if err == nil {
if i > 0 {
@ -368,9 +370,10 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d
return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "")
}
func (c *Controller) downloadObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, bucket *string) (ente.EmbeddingObject, error) {
func (c *Controller) downloadObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, dc string) (ente.EmbeddingObject, error) {
var obj ente.EmbeddingObject
buff := &aws.WriteAtBuffer{}
bucket := c.S3Config.GetBucket(dc)
_, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
Bucket: bucket,
Key: &objectKey,
@ -387,16 +390,16 @@ func (c *Controller) downloadObject(ctx context.Context, objectKey string, downl
// download the embedding object from hot bucket and upload to embeddings bucket
func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) (*ente.EmbeddingObject, error) {
if c.derivedStorageBucket == c.S3Config.GetHotBucket() {
return nil, stacktrace.Propagate(errors.New("embedding bucket and hot bucket are same"), "")
if c.derivedStorageDataCenter == c.S3Config.GetHotDataCenter() {
return nil, stacktrace.Propagate(errors.New("derived DC bucket and hot DC are same"), "")
}
downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotBucket())
obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotDataCenter())
if err != nil {
return nil, stacktrace.Propagate(err, "failed to download from hot bucket")
}
go func() {
_, err = c.uploadObject(obj, objectKey)
_, err = c.uploadObject(obj, objectKey, c.derivedStorageDataCenter)
if err != nil {
log.WithField("object", objectKey).Error("Failed to copy to embeddings bucket: ", err)
}