file.go 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793
  1. package repo
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "fmt"
  7. "strconv"
  8. "strings"
  9. "github.com/ente-io/stacktrace"
  10. log "github.com/sirupsen/logrus"
  11. "github.com/ente-io/museum/ente"
  12. "github.com/ente-io/museum/pkg/utils/s3config"
  13. "github.com/ente-io/museum/pkg/utils/time"
  14. "github.com/lib/pq"
  15. )
// FileRepository is an implementation of the FileRepo that
// persists and retrieves data from disk.
type FileRepository struct {
	// DB is the SQL database handle used for all queries and transactions.
	DB *sql.DB
	// S3Config supplies data center information (e.g. the current hot DC).
	S3Config *s3config.S3Config
	// QueueRepo is used to enqueue outdated object keys for later cleanup.
	QueueRepo *QueueRepository
	// ObjectRepo provides access to object_keys rows for files.
	ObjectRepo *ObjectRepository
	// ObjectCleanupRepo removes temporary object keys once an upload is finalized.
	ObjectCleanupRepo *ObjectCleanupRepository
	// ObjectCopiesRepo records replication state for objects across DCs.
	ObjectCopiesRepo *ObjectCopiesRepository
	// UsageRepo tracks per-user storage consumption.
	UsageRepo *UsageRepository
}
  27. // Create creates an entry in the database for the given file
  28. func (repo *FileRepository) Create(
  29. file ente.File,
  30. fileSize int64,
  31. thumbnailSize int64,
  32. usageDiff int64,
  33. collectionOwnerID int64,
  34. app ente.App,
  35. ) (ente.File, int64, error) {
  36. hotDC := repo.S3Config.GetHotDataCenter()
  37. dcsForNewEntry := pq.StringArray{hotDC}
  38. ctx := context.Background()
  39. tx, err := repo.DB.BeginTx(ctx, nil)
  40. if err != nil {
  41. return file, -1, stacktrace.Propagate(err, "")
  42. }
  43. if file.OwnerID != collectionOwnerID {
  44. return file, -1, stacktrace.Propagate(errors.New("both file and collection should belong to same owner"), "")
  45. }
  46. var fileID int64
  47. err = tx.QueryRowContext(ctx, `INSERT INTO files
  48. (owner_id, encrypted_metadata,
  49. file_decryption_header, thumbnail_decryption_header, metadata_decryption_header,
  50. magic_metadata, pub_magic_metadata, info, updation_time)
  51. VALUES($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING file_id`,
  52. file.OwnerID, file.Metadata.EncryptedData, file.File.DecryptionHeader,
  53. file.Thumbnail.DecryptionHeader, file.Metadata.DecryptionHeader,
  54. file.MagicMetadata, file.PubicMagicMetadata, file.Info,
  55. file.UpdationTime).Scan(&fileID)
  56. if err != nil {
  57. tx.Rollback()
  58. return file, -1, stacktrace.Propagate(err, "")
  59. }
  60. file.ID = fileID
  61. _, err = tx.ExecContext(ctx, `INSERT INTO collection_files
  62. (collection_id, file_id, encrypted_key, key_decryption_nonce, is_deleted, updation_time, c_owner_id, f_owner_id)
  63. VALUES($1, $2, $3, $4, $5, $6, $7, $8)`, file.CollectionID, file.ID,
  64. file.EncryptedKey, file.KeyDecryptionNonce, false, file.UpdationTime, file.OwnerID, collectionOwnerID)
  65. if err != nil {
  66. tx.Rollback()
  67. return file, -1, stacktrace.Propagate(err, "")
  68. }
  69. _, err = tx.ExecContext(ctx, `UPDATE collections SET updation_time = $1
  70. WHERE collection_id = $2`, file.UpdationTime, file.CollectionID)
  71. if err != nil {
  72. tx.Rollback()
  73. return file, -1, stacktrace.Propagate(err, "")
  74. }
  75. _, err = tx.ExecContext(ctx, `INSERT INTO object_keys(file_id, o_type, object_key, size, datacenters)
  76. VALUES($1, $2, $3, $4, $5)`, fileID, ente.FILE, file.File.ObjectKey, fileSize, dcsForNewEntry)
  77. if err != nil {
  78. tx.Rollback()
  79. if err.Error() == "pq: duplicate key value violates unique constraint \"object_keys_object_key_key\"" {
  80. return file, -1, ente.ErrDuplicateFileObjectFound
  81. }
  82. return file, -1, stacktrace.Propagate(err, "")
  83. }
  84. _, err = tx.ExecContext(ctx, `INSERT INTO object_keys(file_id, o_type, object_key, size, datacenters)
  85. VALUES($1, $2, $3, $4, $5)`, fileID, ente.THUMBNAIL, file.Thumbnail.ObjectKey, thumbnailSize, dcsForNewEntry)
  86. if err != nil {
  87. tx.Rollback()
  88. if err.Error() == "pq: duplicate key value violates unique constraint \"object_keys_object_key_key\"" {
  89. return file, -1, ente.ErrDuplicateThumbnailObjectFound
  90. }
  91. return file, -1, stacktrace.Propagate(err, "")
  92. }
  93. err = repo.ObjectCleanupRepo.RemoveTempObjectKey(ctx, tx, file.File.ObjectKey, hotDC)
  94. if err != nil {
  95. tx.Rollback()
  96. return file, -1, stacktrace.Propagate(err, "")
  97. }
  98. err = repo.ObjectCleanupRepo.RemoveTempObjectKey(ctx, tx, file.Thumbnail.ObjectKey, hotDC)
  99. if err != nil {
  100. tx.Rollback()
  101. return file, -1, stacktrace.Propagate(err, "")
  102. }
  103. usage, err := repo.updateUsage(ctx, tx, file.OwnerID, usageDiff)
  104. if err != nil {
  105. tx.Rollback()
  106. return file, -1, stacktrace.Propagate(err, "")
  107. }
  108. err = repo.markAsNeedingReplication(ctx, tx, file, hotDC)
  109. if err != nil {
  110. tx.Rollback()
  111. return file, -1, stacktrace.Propagate(err, "")
  112. }
  113. err = tx.Commit()
  114. if err != nil {
  115. return file, -1, stacktrace.Propagate(err, "")
  116. }
  117. return file, usage, stacktrace.Propagate(err, "")
  118. }
  119. // markAsNeedingReplication inserts new entries in object_copies, setting the
  120. // current hot DC as the source copy.
  121. //
  122. // The higher layer above us (file controller) would've already checked that the
  123. // object exists in the current hot DC (See `c.sizeOf` in file controller). This
  124. // would cover cases where the client fetched presigned upload URLs for say
  125. // hotDC1, but by the time they connected to museum, museum switched to using
  126. // hotDC2. So then when museum would try to fetch the file size from hotDC2, the
  127. // object won't be found there, and the upload would fail (which is the
  128. // behaviour we want, since hot DC swaps are not a frequent/expected operation,
  129. // we just wish to guarantee correctness if they do happen).
  130. func (repo *FileRepository) markAsNeedingReplication(ctx context.Context, tx *sql.Tx, file ente.File, hotDC string) error {
  131. if hotDC == repo.S3Config.GetHotBackblazeDC() {
  132. err := repo.ObjectCopiesRepo.CreateNewB2Object(ctx, tx, file.File.ObjectKey, true, true)
  133. if err != nil {
  134. return stacktrace.Propagate(err, "")
  135. }
  136. err = repo.ObjectCopiesRepo.CreateNewB2Object(ctx, tx, file.Thumbnail.ObjectKey, true, false)
  137. return stacktrace.Propagate(err, "")
  138. } else if hotDC == repo.S3Config.GetHotWasabiDC() {
  139. err := repo.ObjectCopiesRepo.CreateNewWasabiObject(ctx, tx, file.File.ObjectKey, true, true)
  140. if err != nil {
  141. return stacktrace.Propagate(err, "")
  142. }
  143. err = repo.ObjectCopiesRepo.CreateNewWasabiObject(ctx, tx, file.Thumbnail.ObjectKey, true, false)
  144. return stacktrace.Propagate(err, "")
  145. } else {
  146. // Bail out if we're trying to add a new entry for a file but the
  147. // primary hot DC is not one of the known types.
  148. err := fmt.Errorf("only B2 and Wasabi DCs can be used for as the primary hot storage; instead, it was %s", hotDC)
  149. return stacktrace.Propagate(err, "")
  150. }
  151. }
  152. // See markAsNeedingReplication - this variant is for updating only thumbnails.
  153. func (repo *FileRepository) markThumbnailAsNeedingReplication(ctx context.Context, tx *sql.Tx, thumbnailObjectKey string, hotDC string) error {
  154. if hotDC == repo.S3Config.GetHotBackblazeDC() {
  155. err := repo.ObjectCopiesRepo.CreateNewB2Object(ctx, tx, thumbnailObjectKey, true, false)
  156. return stacktrace.Propagate(err, "")
  157. } else if hotDC == repo.S3Config.GetHotWasabiDC() {
  158. err := repo.ObjectCopiesRepo.CreateNewWasabiObject(ctx, tx, thumbnailObjectKey, true, false)
  159. return stacktrace.Propagate(err, "")
  160. } else {
  161. // Bail out if we're trying to add a new entry for a file but the
  162. // primary hot DC is not one of the known types.
  163. err := fmt.Errorf("only B2 and Wasabi DCs can be used for as the primary hot storage; instead, it was %s", hotDC)
  164. return stacktrace.Propagate(err, "")
  165. }
  166. }
  167. // ResetNeedsReplication resets the replication status for an existing file
  168. func (repo *FileRepository) ResetNeedsReplication(file ente.File, hotDC string) error {
  169. if hotDC == repo.S3Config.GetHotBackblazeDC() {
  170. err := repo.ObjectCopiesRepo.ResetNeedsWasabiReplication(file.File.ObjectKey)
  171. if err != nil {
  172. return stacktrace.Propagate(err, "")
  173. }
  174. err = repo.ObjectCopiesRepo.ResetNeedsScalewayReplication(file.File.ObjectKey)
  175. if err != nil {
  176. return stacktrace.Propagate(err, "")
  177. }
  178. err = repo.ObjectCopiesRepo.ResetNeedsWasabiReplication(file.Thumbnail.ObjectKey)
  179. return stacktrace.Propagate(err, "")
  180. } else if hotDC == repo.S3Config.GetHotWasabiDC() {
  181. err := repo.ObjectCopiesRepo.ResetNeedsB2Replication(file.File.ObjectKey)
  182. if err != nil {
  183. return stacktrace.Propagate(err, "")
  184. }
  185. err = repo.ObjectCopiesRepo.ResetNeedsScalewayReplication(file.File.ObjectKey)
  186. if err != nil {
  187. return stacktrace.Propagate(err, "")
  188. }
  189. err = repo.ObjectCopiesRepo.ResetNeedsB2Replication(file.Thumbnail.ObjectKey)
  190. return stacktrace.Propagate(err, "")
  191. } else {
  192. // Bail out if we're trying to update the replication flags but the
  193. // primary hot DC is not one of the known types.
  194. err := fmt.Errorf("only B2 and Wasabi DCs can be used for as the primary hot storage; instead, it was %s", hotDC)
  195. return stacktrace.Propagate(err, "")
  196. }
  197. }
  198. // Update updates the entry in the database for the given file
  199. func (repo *FileRepository) Update(file ente.File, fileSize int64, thumbnailSize int64, usageDiff int64, oldObjects []string, isDuplicateRequest bool) error {
  200. hotDC := repo.S3Config.GetHotDataCenter()
  201. dcsForNewEntry := pq.StringArray{hotDC}
  202. ctx := context.Background()
  203. tx, err := repo.DB.BeginTx(ctx, nil)
  204. if err != nil {
  205. return stacktrace.Propagate(err, "")
  206. }
  207. _, err = tx.ExecContext(ctx, `UPDATE files SET encrypted_metadata = $1,
  208. file_decryption_header = $2, thumbnail_decryption_header = $3,
  209. metadata_decryption_header = $4, updation_time = $5 , info = $6 WHERE file_id = $7`,
  210. file.Metadata.EncryptedData, file.File.DecryptionHeader,
  211. file.Thumbnail.DecryptionHeader, file.Metadata.DecryptionHeader,
  212. file.UpdationTime, file.Info, file.ID)
  213. if err != nil {
  214. tx.Rollback()
  215. return stacktrace.Propagate(err, "")
  216. }
  217. updatedRows, err := tx.QueryContext(ctx, `UPDATE collection_files
  218. SET updation_time = $1 WHERE file_id = $2 RETURNING collection_id`, file.UpdationTime,
  219. file.ID)
  220. if err != nil {
  221. tx.Rollback()
  222. return stacktrace.Propagate(err, "")
  223. }
  224. defer updatedRows.Close()
  225. updatedCIDs := make([]int64, 0)
  226. for updatedRows.Next() {
  227. var cID int64
  228. err := updatedRows.Scan(&cID)
  229. if err != nil {
  230. return stacktrace.Propagate(err, "")
  231. }
  232. updatedCIDs = append(updatedCIDs, cID)
  233. }
  234. _, err = tx.ExecContext(ctx, `UPDATE collections SET updation_time = $1
  235. WHERE collection_id = ANY($2)`, file.UpdationTime, pq.Array(updatedCIDs))
  236. if err != nil {
  237. tx.Rollback()
  238. return stacktrace.Propagate(err, "")
  239. }
  240. _, err = tx.ExecContext(ctx, `DELETE FROM object_copies WHERE object_key = ANY($1)`,
  241. pq.Array(oldObjects))
  242. if err != nil {
  243. tx.Rollback()
  244. return stacktrace.Propagate(err, "")
  245. }
  246. _, err = tx.ExecContext(ctx, `UPDATE object_keys
  247. SET object_key = $1, size = $2, datacenters = $3 WHERE file_id = $4 AND o_type = $5`,
  248. file.File.ObjectKey, fileSize, dcsForNewEntry, file.ID, ente.FILE)
  249. if err != nil {
  250. tx.Rollback()
  251. return stacktrace.Propagate(err, "")
  252. }
  253. _, err = tx.ExecContext(ctx, `UPDATE object_keys
  254. SET object_key = $1, size = $2, datacenters = $3 WHERE file_id = $4 AND o_type = $5`,
  255. file.Thumbnail.ObjectKey, thumbnailSize, dcsForNewEntry, file.ID, ente.THUMBNAIL)
  256. if err != nil {
  257. tx.Rollback()
  258. return stacktrace.Propagate(err, "")
  259. }
  260. _, err = repo.updateUsage(ctx, tx, file.OwnerID, usageDiff)
  261. if err != nil {
  262. tx.Rollback()
  263. return stacktrace.Propagate(err, "")
  264. }
  265. err = repo.ObjectCleanupRepo.RemoveTempObjectKey(ctx, tx, file.File.ObjectKey, hotDC)
  266. if err != nil {
  267. tx.Rollback()
  268. return stacktrace.Propagate(err, "")
  269. }
  270. err = repo.ObjectCleanupRepo.RemoveTempObjectKey(ctx, tx, file.Thumbnail.ObjectKey, hotDC)
  271. if err != nil {
  272. tx.Rollback()
  273. return stacktrace.Propagate(err, "")
  274. }
  275. if isDuplicateRequest {
  276. // Skip markAsNeedingReplication for duplicate requests, it'd fail with
  277. // pq: duplicate key value violates unique constraint \"object_copies_pkey\"
  278. // and render our transaction uncommittable
  279. log.Infof("Skipping update of object_copies for a duplicate request to update file %d", file.ID)
  280. } else {
  281. err = repo.markAsNeedingReplication(ctx, tx, file, hotDC)
  282. if err != nil {
  283. tx.Rollback()
  284. return stacktrace.Propagate(err, "")
  285. }
  286. }
  287. err = repo.QueueRepo.AddItems(ctx, tx, OutdatedObjectsQueue, oldObjects)
  288. if err != nil {
  289. tx.Rollback()
  290. return stacktrace.Propagate(err, "")
  291. }
  292. err = tx.Commit()
  293. return stacktrace.Propagate(err, "")
  294. }
  295. // UpdateMagicAttributes updates the magic attributes for the list of files and update collection_files & collection
  296. // which have this file.
  297. func (repo *FileRepository) UpdateMagicAttributes(ctx context.Context, fileUpdates []ente.UpdateMagicMetadata, isPublicMetadata bool) error {
  298. updationTime := time.Microseconds()
  299. tx, err := repo.DB.BeginTx(ctx, nil)
  300. if err != nil {
  301. return stacktrace.Propagate(err, "")
  302. }
  303. fileIDs := make([]int64, 0)
  304. for _, update := range fileUpdates {
  305. update.MagicMetadata.Version = update.MagicMetadata.Version + 1
  306. fileIDs = append(fileIDs, update.ID)
  307. if isPublicMetadata {
  308. _, err = tx.ExecContext(ctx, `UPDATE files SET pub_magic_metadata = $1, updation_time = $2 WHERE file_id = $3`,
  309. update.MagicMetadata, updationTime, update.ID)
  310. } else {
  311. _, err = tx.ExecContext(ctx, `UPDATE files SET magic_metadata = $1, updation_time = $2 WHERE file_id = $3`,
  312. update.MagicMetadata, updationTime, update.ID)
  313. }
  314. if err != nil {
  315. if rollbackErr := tx.Rollback(); rollbackErr != nil {
  316. log.WithError(rollbackErr).Error("transaction rollback failed")
  317. return stacktrace.Propagate(rollbackErr, "")
  318. }
  319. return stacktrace.Propagate(err, "")
  320. }
  321. }
  322. // todo: full table scan, need to add index (for discussion: add user_id and idx {user_id, file_id}).
  323. updatedRows, err := tx.QueryContext(ctx, `UPDATE collection_files
  324. SET updation_time = $1 WHERE file_id = ANY($2) AND is_deleted= false RETURNING collection_id`, updationTime,
  325. pq.Array(fileIDs))
  326. if err != nil {
  327. if rollbackErr := tx.Rollback(); rollbackErr != nil {
  328. log.WithError(rollbackErr).Error("transaction rollback failed")
  329. return stacktrace.Propagate(rollbackErr, "")
  330. }
  331. return stacktrace.Propagate(err, "")
  332. }
  333. defer updatedRows.Close()
  334. updatedCIDs := make([]int64, 0)
  335. for updatedRows.Next() {
  336. var cID int64
  337. err := updatedRows.Scan(&cID)
  338. if err != nil {
  339. return stacktrace.Propagate(err, "")
  340. }
  341. updatedCIDs = append(updatedCIDs, cID)
  342. }
  343. _, err = tx.ExecContext(ctx, `UPDATE collections SET updation_time = $1
  344. WHERE collection_id = ANY($2)`, updationTime, pq.Array(updatedCIDs))
  345. if err != nil {
  346. if rollbackErr := tx.Rollback(); rollbackErr != nil {
  347. log.WithError(rollbackErr).Error("transaction rollback failed")
  348. return stacktrace.Propagate(rollbackErr, "")
  349. }
  350. return stacktrace.Propagate(err, "")
  351. }
  352. return tx.Commit()
  353. }
  354. // Update updates the entry in the database for the given file
  355. func (repo *FileRepository) UpdateThumbnail(ctx context.Context, fileID int64, userID int64, thumbnail ente.FileAttributes, thumbnailSize int64, usageDiff int64, oldThumbnailObject *string) error {
  356. hotDC := repo.S3Config.GetHotDataCenter()
  357. dcsForNewEntry := pq.StringArray{hotDC}
  358. tx, err := repo.DB.BeginTx(ctx, nil)
  359. if err != nil {
  360. return stacktrace.Propagate(err, "")
  361. }
  362. updationTime := time.Microseconds()
  363. _, err = tx.ExecContext(ctx, `UPDATE files SET
  364. thumbnail_decryption_header = $1,
  365. updation_time = $2 WHERE file_id = $3`,
  366. thumbnail.DecryptionHeader,
  367. updationTime, fileID)
  368. if err != nil {
  369. tx.Rollback()
  370. return stacktrace.Propagate(err, "")
  371. }
  372. updatedRows, err := tx.QueryContext(ctx, `UPDATE collection_files
  373. SET updation_time = $1 WHERE file_id = $2 RETURNING collection_id`, updationTime,
  374. fileID)
  375. if err != nil {
  376. tx.Rollback()
  377. return stacktrace.Propagate(err, "")
  378. }
  379. defer updatedRows.Close()
  380. updatedCIDs := make([]int64, 0)
  381. for updatedRows.Next() {
  382. var cID int64
  383. err := updatedRows.Scan(&cID)
  384. if err != nil {
  385. return stacktrace.Propagate(err, "")
  386. }
  387. updatedCIDs = append(updatedCIDs, cID)
  388. }
  389. _, err = tx.ExecContext(ctx, `UPDATE collections SET updation_time = $1
  390. WHERE collection_id = ANY($2)`, updationTime, pq.Array(updatedCIDs))
  391. if err != nil {
  392. tx.Rollback()
  393. return stacktrace.Propagate(err, "")
  394. }
  395. if oldThumbnailObject != nil {
  396. _, err = tx.ExecContext(ctx, `DELETE FROM object_copies WHERE object_key = $1`,
  397. *oldThumbnailObject)
  398. if err != nil {
  399. tx.Rollback()
  400. return stacktrace.Propagate(err, "")
  401. }
  402. }
  403. _, err = tx.ExecContext(ctx, `UPDATE object_keys
  404. SET object_key = $1, size = $2, datacenters = $3 WHERE file_id = $4 AND o_type = $5`,
  405. thumbnail.ObjectKey, thumbnailSize, dcsForNewEntry, fileID, ente.THUMBNAIL)
  406. if err != nil {
  407. tx.Rollback()
  408. return stacktrace.Propagate(err, "")
  409. }
  410. _, err = repo.updateUsage(ctx, tx, userID, usageDiff)
  411. if err != nil {
  412. tx.Rollback()
  413. return stacktrace.Propagate(err, "")
  414. }
  415. err = repo.ObjectCleanupRepo.RemoveTempObjectKey(ctx, tx, thumbnail.ObjectKey, hotDC)
  416. if err != nil {
  417. tx.Rollback()
  418. return stacktrace.Propagate(err, "")
  419. }
  420. err = repo.markThumbnailAsNeedingReplication(ctx, tx, thumbnail.ObjectKey, hotDC)
  421. if err != nil {
  422. return stacktrace.Propagate(err, "")
  423. }
  424. if oldThumbnailObject != nil {
  425. err = repo.QueueRepo.AddItems(ctx, tx, OutdatedObjectsQueue, []string{*oldThumbnailObject})
  426. if err != nil {
  427. tx.Rollback()
  428. return stacktrace.Propagate(err, "")
  429. }
  430. }
  431. err = tx.Commit()
  432. return stacktrace.Propagate(err, "")
  433. }
  434. // GetOwnerID returns the ownerID for a file
  435. func (repo *FileRepository) GetOwnerID(fileID int64) (int64, error) {
  436. row := repo.DB.QueryRow(`SELECT owner_id FROM files WHERE file_id = $1`,
  437. fileID)
  438. var ownerID int64
  439. err := row.Scan(&ownerID)
  440. return ownerID, stacktrace.Propagate(err, "failed to get file owner")
  441. }
  442. // GetOwnerToFileCountMap will return a map of ownerId & number of files owned by that owner
  443. func (repo *FileRepository) GetOwnerToFileCountMap(ctx context.Context, fileIDs []int64) (map[int64]int64, error) {
  444. rows, err := repo.DB.QueryContext(ctx, `SELECT owner_id, count(*) FROM files WHERE file_id = ANY($1) group by owner_id`,
  445. pq.Array(fileIDs))
  446. if err != nil {
  447. return nil, stacktrace.Propagate(err, "")
  448. }
  449. defer rows.Close()
  450. result := make(map[int64]int64, 0)
  451. for rows.Next() {
  452. var ownerID, count int64
  453. if err = rows.Scan(&ownerID, &count); err != nil {
  454. return nil, stacktrace.Propagate(err, "")
  455. }
  456. result[ownerID] = count
  457. }
  458. return result, nil
  459. }
  460. // GetOwnerToFileIDsMap will return a map of ownerId & number of files owned by that owner
  461. func (repo *FileRepository) GetOwnerToFileIDsMap(ctx context.Context, fileIDs []int64) (map[int64][]int64, error) {
  462. rows, err := repo.DB.QueryContext(ctx, `SELECT owner_id, file_id FROM files WHERE file_id = ANY($1)`,
  463. pq.Array(fileIDs))
  464. if err != nil {
  465. return nil, stacktrace.Propagate(err, "")
  466. }
  467. defer rows.Close()
  468. result := make(map[int64][]int64, 0)
  469. for rows.Next() {
  470. var ownerID, fileID int64
  471. if err = rows.Scan(&ownerID, &fileID); err != nil {
  472. return nil, stacktrace.Propagate(err, "")
  473. }
  474. if ownerFileIDs, ok := result[ownerID]; ok {
  475. result[ownerID] = append(ownerFileIDs, fileID)
  476. } else {
  477. result[ownerID] = []int64{fileID}
  478. }
  479. }
  480. return result, nil
  481. }
  482. func (repo *FileRepository) VerifyFileOwner(ctx context.Context, fileIDs []int64, ownerID int64, logger *log.Entry) error {
  483. countMap, err := repo.GetOwnerToFileCountMap(ctx, fileIDs)
  484. if err != nil {
  485. return stacktrace.Propagate(err, "failed to get owners info")
  486. }
  487. logger = logger.WithFields(log.Fields{
  488. "owner_id": ownerID,
  489. "file_ids": fileIDs,
  490. "owners_map": countMap,
  491. })
  492. if len(countMap) == 0 {
  493. logger.Error("all fileIDs are invalid")
  494. return stacktrace.Propagate(ente.ErrBadRequest, "")
  495. }
  496. if len(countMap) > 1 {
  497. logger.Error("files are owned by multiple users")
  498. return stacktrace.Propagate(ente.ErrPermissionDenied, "")
  499. }
  500. if filesOwned, ok := countMap[ownerID]; ok {
  501. if filesOwned != int64(len(fileIDs)) {
  502. logger.WithField("file_owned", filesOwned).Error("failed to find all fileIDs")
  503. return stacktrace.Propagate(ente.ErrBadRequest, "")
  504. }
  505. return nil
  506. } else {
  507. logger.Error("user is not an owner of any file")
  508. return stacktrace.Propagate(ente.ErrPermissionDenied, "")
  509. }
  510. }
  511. // GetOwnerAndMagicMetadata returns the ownerID and magicMetadata for given file id
  512. func (repo *FileRepository) GetOwnerAndMagicMetadata(fileID int64, publicMetadata bool) (int64, *ente.MagicMetadata, error) {
  513. var row *sql.Row
  514. if publicMetadata {
  515. row = repo.DB.QueryRow(`SELECT owner_id, pub_magic_metadata FROM files WHERE file_id = $1`,
  516. fileID)
  517. } else {
  518. row = repo.DB.QueryRow(`SELECT owner_id, magic_metadata FROM files WHERE file_id = $1`,
  519. fileID)
  520. }
  521. var ownerID int64
  522. var magicMetadata *ente.MagicMetadata
  523. err := row.Scan(&ownerID, &magicMetadata)
  524. return ownerID, magicMetadata, stacktrace.Propagate(err, "")
  525. }
  526. // GetSize returns the size of files indicated by fileIDs that are owned by the given userID.
  527. func (repo *FileRepository) GetSize(userID int64, fileIDs []int64) (int64, error) {
  528. row := repo.DB.QueryRow(`
  529. SELECT COALESCE(SUM(size), 0) FROM object_keys WHERE o_type = 'file' AND is_deleted = false AND file_id = ANY(SELECT file_id FROM files WHERE (file_id = ANY($1) AND owner_id = $2))`,
  530. pq.Array(fileIDs), userID)
  531. var size int64
  532. err := row.Scan(&size)
  533. if err != nil {
  534. return -1, stacktrace.Propagate(err, "")
  535. }
  536. return size, nil
  537. }
  538. // GetFileCountForUser returns the total number of files in the system for a given user.
  539. func (repo *FileRepository) GetFileCountForUser(userID int64, app ente.App) (int64, error) {
  540. row := repo.DB.QueryRow(`SELECT count(distinct files.file_id)
  541. FROM collection_files
  542. JOIN collections c on c.owner_id = $1 and c.collection_id = collection_files.collection_id
  543. JOIN files ON
  544. files.owner_id = $1 AND files.file_id = collection_files.file_id
  545. WHERE (c.app = $2 AND collection_files.is_deleted = false);`, userID, app)
  546. var fileCount int64
  547. err := row.Scan(&fileCount)
  548. if err != nil {
  549. return -1, stacktrace.Propagate(err, "")
  550. }
  551. return fileCount, nil
  552. }
  553. func (repo *FileRepository) GetFileAttributesFromObjectKey(objectKey string) (ente.File, error) {
  554. s3ObjectKeys, err := repo.ObjectRepo.GetAllFileObjectsByObjectKey(objectKey)
  555. if err != nil {
  556. return ente.File{}, stacktrace.Propagate(err, "")
  557. }
  558. if len(s3ObjectKeys) != 2 {
  559. return ente.File{}, stacktrace.Propagate(fmt.Errorf("unexpected file count: %d", len(s3ObjectKeys)), "")
  560. }
  561. var file ente.File
  562. file.ID = s3ObjectKeys[0].FileID // all file IDs should be same as per query in GetAllFileObjectsByObjectKey
  563. row := repo.DB.QueryRow(`SELECT owner_id, file_decryption_header, thumbnail_decryption_header, metadata_decryption_header, encrypted_metadata FROM files WHERE file_id = $1`, file.ID)
  564. err = row.Scan(&file.OwnerID,
  565. &file.File.DecryptionHeader, &file.Thumbnail.DecryptionHeader,
  566. &file.Metadata.DecryptionHeader,
  567. &file.Metadata.EncryptedData)
  568. if err != nil {
  569. return ente.File{}, err
  570. }
  571. for _, object := range s3ObjectKeys {
  572. if object.Type == ente.FILE {
  573. file.File.ObjectKey = object.ObjectKey
  574. file.File.Size = object.FileSize
  575. } else if object.Type == ente.THUMBNAIL {
  576. file.Thumbnail.ObjectKey = object.ObjectKey
  577. file.Thumbnail.Size = object.FileSize
  578. } else {
  579. err = fmt.Errorf("unexpted file type %s", object.Type)
  580. return ente.File{}, stacktrace.Propagate(err, "")
  581. }
  582. }
  583. return file, nil
  584. }
  585. func (repo *FileRepository) GetFileAttributesForCopy(fileIDs []int64) ([]ente.File, error) {
  586. result := make([]ente.File, 0)
  587. rows, err := repo.DB.Query(`SELECT file_id, owner_id, file_decryption_header, thumbnail_decryption_header, metadata_decryption_header, encrypted_metadata, pub_magic_metadata FROM files WHERE file_id = ANY($1)`, pq.Array(fileIDs))
  588. if err != nil {
  589. return nil, stacktrace.Propagate(err, "")
  590. }
  591. defer rows.Close()
  592. for rows.Next() {
  593. var file ente.File
  594. err := rows.Scan(&file.ID, &file.OwnerID, &file.File.DecryptionHeader, &file.Thumbnail.DecryptionHeader, &file.Metadata.DecryptionHeader, &file.Metadata.EncryptedData, &file.PubicMagicMetadata)
  595. if err != nil {
  596. return nil, stacktrace.Propagate(err, "")
  597. }
  598. result = append(result, file)
  599. }
  600. return result, nil
  601. }
// GetUsage gets the Storage usage of a user
// Deprecated: GetUsage is deprecated, use UsageRepository.GetUsage
func (repo *FileRepository) GetUsage(userID int64) (int64, error) {
	// Thin delegation kept for backward compatibility with existing callers.
	return repo.UsageRepo.GetUsage(userID)
}
  607. func (repo *FileRepository) DropFilesMetadata(ctx context.Context, fileIDs []int64) error {
  608. // ensure that the fileIDs are not present in object_keys
  609. rows, err := repo.DB.QueryContext(ctx, `SELECT distinct(file_id) FROM object_keys WHERE file_id = ANY($1)`, pq.Array(fileIDs))
  610. if err != nil {
  611. return stacktrace.Propagate(err, "")
  612. }
  613. defer rows.Close()
  614. fileIdsNotDeleted := make([]int64, 0)
  615. for rows.Next() {
  616. var fileID int64
  617. err := rows.Scan(&fileID)
  618. if err != nil {
  619. return stacktrace.Propagate(err, "")
  620. }
  621. fileIdsNotDeleted = append(fileIdsNotDeleted, fileID)
  622. }
  623. if len(fileIdsNotDeleted) > 0 {
  624. return stacktrace.Propagate(fmt.Errorf("fileIDs %v are still present in object_keys", fileIdsNotDeleted), "")
  625. }
  626. _, err = repo.DB.ExecContext(ctx, `
  627. UPDATE files SET encrypted_metadata = '-',
  628. metadata_decryption_header = '-',
  629. file_decryption_header = '-',
  630. thumbnail_decryption_header = '-',
  631. magic_metadata = NULL,
  632. pub_magic_metadata = NULL,
  633. info = NULL
  634. where file_id = ANY($1)`, pq.Array(fileIDs))
  635. return stacktrace.Propagate(err, "")
  636. }
  637. // GetDuplicateFiles returns the list of files for a user that are of the same size
  638. func (repo *FileRepository) GetDuplicateFiles(userID int64) ([]ente.DuplicateFiles, error) {
  639. rows, err := repo.DB.Query(`SELECT string_agg(o.file_id::character varying, ','), o.size FROM object_keys o JOIN files f ON f.file_id = o.file_id
  640. WHERE f.owner_id = $1 AND o.o_type = 'file' AND o.is_deleted = false
  641. GROUP BY size
  642. HAVING count(*) > 1;`, userID)
  643. if err != nil {
  644. return nil, stacktrace.Propagate(err, "")
  645. }
  646. defer rows.Close()
  647. result := make([]ente.DuplicateFiles, 0)
  648. for rows.Next() {
  649. var res string
  650. var size int64
  651. err := rows.Scan(&res, &size)
  652. if err != nil {
  653. return result, stacktrace.Propagate(err, "")
  654. }
  655. fileIDStrs := strings.Split(res, ",")
  656. fileIDs := make([]int64, 0)
  657. for _, fileIDStr := range fileIDStrs {
  658. fileID, err := strconv.ParseInt(fileIDStr, 10, 64)
  659. if err != nil {
  660. return result, stacktrace.Propagate(err, "")
  661. }
  662. fileIDs = append(fileIDs, fileID)
  663. }
  664. result = append(result, ente.DuplicateFiles{FileIDs: fileIDs, Size: size})
  665. }
  666. return result, nil
  667. }
  668. func (repo *FileRepository) GetLargeThumbnailFiles(userID int64, threshold int64) ([]int64, error) {
  669. rows, err := repo.DB.Query(`
  670. SELECT file_id FROM object_keys WHERE o_type = 'thumbnail' AND is_deleted = false AND size >= $2 AND file_id = ANY(SELECT file_id FROM files WHERE owner_id = $1)`,
  671. userID, threshold)
  672. if err != nil {
  673. return nil, stacktrace.Propagate(err, "")
  674. }
  675. defer rows.Close()
  676. result := make([]int64, 0)
  677. for rows.Next() {
  678. var fileID int64
  679. err := rows.Scan(&fileID)
  680. if err != nil {
  681. return result, stacktrace.Propagate(err, "")
  682. }
  683. result = append(result, fileID)
  684. }
  685. return result, nil
  686. }
  687. func (repo *FileRepository) GetTotalFileCount() (int64, error) {
  688. // 9,522,438 is the magic number that accommodates the bumping up of fileIDs
  689. // Doing this magic instead of count(*) since it's faster
  690. row := repo.DB.QueryRow(`select (select max(file_id) from files) - (select 9522438)`)
  691. var count int64
  692. err := row.Scan(&count)
  693. return count, stacktrace.Propagate(err, "")
  694. }
  695. func convertRowsToFiles(rows *sql.Rows) ([]ente.File, error) {
  696. defer rows.Close()
  697. files := make([]ente.File, 0)
  698. for rows.Next() {
  699. var (
  700. file ente.File
  701. updationTime float64
  702. )
  703. err := rows.Scan(&file.ID, &file.OwnerID, &file.CollectionID, &file.CollectionOwnerID,
  704. &file.EncryptedKey, &file.KeyDecryptionNonce,
  705. &file.File.DecryptionHeader, &file.Thumbnail.DecryptionHeader,
  706. &file.Metadata.DecryptionHeader,
  707. &file.Metadata.EncryptedData, &file.MagicMetadata, &file.PubicMagicMetadata,
  708. &file.Info, &file.IsDeleted, &updationTime)
  709. if err != nil {
  710. return files, stacktrace.Propagate(err, "")
  711. }
  712. file.UpdationTime = int64(updationTime)
  713. files = append(files, file)
  714. }
  715. return files, nil
  716. }
  717. // scheduleDeletion added a list of files's object ids to delete queue for deletion from datastore
  718. func (repo *FileRepository) scheduleDeletion(ctx context.Context, tx *sql.Tx, fileIDs []int64, userID int64) error {
  719. diff := int64(0)
  720. objectsToBeDeleted, err := repo.ObjectRepo.MarkObjectsAsDeletedForFileIDs(ctx, tx, fileIDs)
  721. if err != nil {
  722. return stacktrace.Propagate(err, "file object deletion failed for fileIDs: %v", fileIDs)
  723. }
  724. totalObjectSize := int64(0)
  725. for _, object := range objectsToBeDeleted {
  726. totalObjectSize += object.FileSize
  727. }
  728. diff = diff - (totalObjectSize)
  729. _, err = repo.updateUsage(ctx, tx, userID, diff)
  730. return stacktrace.Propagate(err, "")
  731. }
  732. // updateUsage updates the storage usage of a user and returns the updated value
  733. func (repo *FileRepository) updateUsage(ctx context.Context, tx *sql.Tx, userID int64, diff int64) (int64, error) {
  734. row := tx.QueryRowContext(ctx, `SELECT storage_consumed FROM usage WHERE user_id = $1 FOR UPDATE`, userID)
  735. var usage int64
  736. err := row.Scan(&usage)
  737. if err != nil {
  738. if errors.Is(err, sql.ErrNoRows) {
  739. usage = 0
  740. } else {
  741. return -1, stacktrace.Propagate(err, "")
  742. }
  743. }
  744. newUsage := usage + diff
  745. _, err = tx.ExecContext(ctx, `INSERT INTO usage (user_id, storage_consumed)
  746. VALUES ($1, $2)
  747. ON CONFLICT (user_id) DO UPDATE
  748. SET storage_consumed = $2`,
  749. userID, newUsage)
  750. if err != nil {
  751. return -1, stacktrace.Propagate(err, "")
  752. }
  753. return newUsage, nil
  754. }