Browse Source

feat(server): Per user asset access control (#993)

* Limit asset access to owner

* Check public albums for asset

* Clean up

* Fix test

* Rename repository method

* Simplify control flow

* Revert "Simplify control flow"

This reverts commit 7bc3cbf687aa28cf1cd2674b2bfff6da03024199.

* Revert Makefile change
Matthias Rupp 2 years ago
parent
commit
e8bbad6772

+ 1 - 1
Makefile

@@ -26,4 +26,4 @@ prod-scale:
 	docker-compose -f ./docker/docker-compose.yml up --build -V --scale immich-server=3 --scale immich-microservices=3 --remove-orphans
 
 api:
-	cd ./server && npm run api:generate
+	cd ./server && npm run api:generate

+ 14 - 0
server/apps/immich/src/api-v1/album/album-repository.ts

@@ -25,6 +25,7 @@ export interface IAlbumRepository {
   updateAlbum(album: AlbumEntity, updateAlbumDto: UpdateAlbumDto): Promise<AlbumEntity>;
   getListByAssetId(userId: string, assetId: string): Promise<AlbumEntity[]>;
   getCountByUserId(userId: string): Promise<AlbumCountResponseDto>;
+  getSharedWithUserAlbumCount(userId: string, assetId: string): Promise<number>;
 }
 
 export const ALBUM_REPOSITORY = 'ALBUM_REPOSITORY';
@@ -283,4 +284,17 @@ export class AlbumRepository implements IAlbumRepository {
 
     return this.albumRepository.save(album);
   }
+
+  async getSharedWithUserAlbumCount(userId: string, assetId: string): Promise<number> {
+    const result = await this
+        .userAlbumRepository
+        .createQueryBuilder('usa')
+        .select('count(aa)', 'count')
+        .innerJoin('asset_album', 'aa', 'aa.albumId = usa.albumId')
+        .where('aa.assetId = :assetId', { assetId })
+        .andWhere('usa.sharedUserId = :userId', { userId })
+        .getRawOne();
+
+    return result.count;
+  }
 }

+ 2 - 0
server/apps/immich/src/api-v1/album/album.service.spec.ts

@@ -123,6 +123,7 @@ describe('Album service', () => {
       updateAlbum: jest.fn(),
       getListByAssetId: jest.fn(),
       getCountByUserId: jest.fn(),
+      getSharedWithUserAlbumCount: jest.fn(),
     };
 
     assetRepositoryMock = {
@@ -142,6 +143,7 @@ describe('Album service', () => {
       getAssetWithNoThumbnail: jest.fn(),
       getAssetWithNoSmartInfo: jest.fn(),
       getExistingAssets: jest.fn(),
+      countByIdAndUser: jest.fn(),
     };
 
     downloadServiceMock = {

+ 10 - 0
server/apps/immich/src/api-v1/asset/asset-repository.ts

@@ -43,6 +43,7 @@ export interface IAssetRepository {
     userId: string,
     checkDuplicateAssetDto: CheckExistingAssetsDto,
   ): Promise<CheckExistingAssetsResponseDto>;
+  countByIdAndUser(assetId: string, userId: string): Promise<number>;
 }
 
 export const ASSET_REPOSITORY = 'ASSET_REPOSITORY';
@@ -343,4 +344,13 @@ export class AssetRepository implements IAssetRepository {
     });
     return new CheckExistingAssetsResponseDto(existingAssets.map((a) => a.deviceAssetId));
   }
+
+  async countByIdAndUser(assetId: string, userId: string): Promise<number> {
+    return await this.assetRepository.count({
+        where: {
+          id: assetId,
+          userId
+      }
+    });
+  }
 }

+ 16 - 6
server/apps/immich/src/api-v1/asset/asset.controller.ts

@@ -21,7 +21,7 @@ import { FileFieldsInterceptor } from '@nestjs/platform-express';
 import { assetUploadOption } from '../../config/asset-upload.config';
 import { AuthUserDto, GetAuthUser } from '../../decorators/auth-user.decorator';
 import { ServeFileDto } from './dto/serve-file.dto';
-import { Response as Res} from 'express';
+import { Response as Res } from 'express';
 import { BackgroundTaskService } from '../../modules/background-task/background-task.service';
 import { DeleteAssetDto } from './dto/delete-asset.dto';
 import { SearchAssetDto } from './dto/search-asset.dto';
@@ -86,10 +86,12 @@ export class AssetController {
 
   @Get('/download/:assetId')
   async downloadFile(
+    @GetAuthUser() authUser: AuthUserDto,
     @Response({ passthrough: true }) res: Res,
     @Query(new ValidationPipe({ transform: true })) query: ServeFileDto,
     @Param('assetId') assetId: string,
   ): Promise<any> {
+    await this.assetService.checkAssetsAccess(authUser, [assetId]);
     return this.assetService.downloadFile(query, assetId, res);
   }
 
@@ -110,22 +112,26 @@ export class AssetController {
   @Get('/file/:assetId')
   @Header('Cache-Control', 'max-age=3600')
   async serveFile(
+    @GetAuthUser() authUser: AuthUserDto,
     @Headers() headers: Record<string, string>,
     @Response({ passthrough: true }) res: Res,
     @Query(new ValidationPipe({ transform: true })) query: ServeFileDto,
     @Param('assetId') assetId: string,
   ): Promise<any> {
+    await this.assetService.checkAssetsAccess(authUser, [assetId]);
     return this.assetService.serveFile(assetId, query, res, headers);
   }
 
   @Get('/thumbnail/:assetId')
   @Header('Cache-Control', 'max-age=3600')
   async getAssetThumbnail(
+    @GetAuthUser() authUser: AuthUserDto,
     @Headers() headers: Record<string, string>,
     @Response({ passthrough: true }) res: Res,
     @Param('assetId') assetId: string,
     @Query(new ValidationPipe({ transform: true })) query: GetAssetThumbnailDto,
   ): Promise<any> {
+    await this.assetService.checkAssetsAccess(authUser, [assetId]);
     return this.assetService.getAssetThumbnail(assetId, query, res, headers);
   }
 
@@ -203,7 +209,8 @@ export class AssetController {
     @GetAuthUser() authUser: AuthUserDto,
     @Param('assetId') assetId: string,
   ): Promise<AssetResponseDto> {
-    return await this.assetService.getAssetById(authUser, assetId);
+    await this.assetService.checkAssetsAccess(authUser, [assetId]);
+    return await this.assetService.getAssetById(assetId);
   }
 
   /**
@@ -215,7 +222,8 @@ export class AssetController {
     @Param('assetId') assetId: string,
     @Body() dto: UpdateAssetDto,
   ): Promise<AssetResponseDto> {
-    return await this.assetService.updateAssetById(authUser, assetId, dto);
+    await this.assetService.checkAssetsAccess(authUser, [assetId], true);
+    return await this.assetService.updateAssetById(assetId, dto);
   }
 
   @Delete('/')
@@ -223,17 +231,19 @@ export class AssetController {
     @GetAuthUser() authUser: AuthUserDto,
     @Body(ValidationPipe) assetIds: DeleteAssetDto,
   ): Promise<DeleteAssetResponseDto[]> {
+    await this.assetService.checkAssetsAccess(authUser, assetIds.ids, true);
+
     const deleteAssetList: AssetResponseDto[] = [];
 
     for (const id of assetIds.ids) {
-      const assets = await this.assetService.getAssetById(authUser, id);
+      const assets = await this.assetService.getAssetById(id);
       if (!assets) {
         continue;
       }
       deleteAssetList.push(assets);
 
       if (assets.livePhotoVideoId) {
-        const livePhotoVideo = await this.assetService.getAssetById(authUser, assets.livePhotoVideoId);
+        const livePhotoVideo = await this.assetService.getAssetById(assets.livePhotoVideoId);
         if (livePhotoVideo) {
           deleteAssetList.push(livePhotoVideo);
           assetIds.ids = [...assetIds.ids, livePhotoVideo.id];
@@ -241,7 +251,7 @@ export class AssetController {
       }
     }
 
-    const result = await this.assetService.deleteAssetById(authUser, assetIds);
+    const result = await this.assetService.deleteAssetById(assetIds);
 
     result.forEach((res) => {
       deleteAssetList.filter((a) => a.id == res.id && res.status == DeleteAssetStatusEnum.SUCCESS);

+ 10 - 1
server/apps/immich/src/api-v1/asset/asset.module.ts

@@ -10,13 +10,18 @@ import { CommunicationModule } from '../communication/communication.module';
 import { QueueNameEnum } from '@app/job/constants/queue-name.constant';
 import { AssetRepository, ASSET_REPOSITORY } from './asset-repository';
 import { DownloadModule } from '../../modules/download/download.module';
+import { ALBUM_REPOSITORY, AlbumRepository } from '../album/album-repository';
+import { AlbumEntity } from '@app/database/entities/album.entity';
+import { UserAlbumEntity } from '@app/database/entities/user-album.entity';
+import { UserEntity } from '@app/database/entities/user.entity';
+import { AssetAlbumEntity } from '@app/database/entities/asset-album.entity';
 
 @Module({
   imports: [
     CommunicationModule,
     BackgroundTaskModule,
     DownloadModule,
-    TypeOrmModule.forFeature([AssetEntity]),
+    TypeOrmModule.forFeature([AssetEntity, AlbumEntity, UserAlbumEntity, UserEntity, AssetAlbumEntity]),
     BullModule.registerQueue({
       name: QueueNameEnum.ASSET_UPLOADED,
       defaultJobOptions: {
@@ -42,6 +47,10 @@ import { DownloadModule } from '../../modules/download/download.module';
       provide: ASSET_REPOSITORY,
       useClass: AssetRepository,
     },
+    {
+      provide: ALBUM_REPOSITORY,
+      useClass: AlbumRepository,
+    },
   ],
   exports: [AssetService],
 })

+ 4 - 0
server/apps/immich/src/api-v1/asset/asset.service.spec.ts

@@ -11,11 +11,13 @@ import { DownloadService } from '../../modules/download/download.service';
 import { BackgroundTaskService } from '../../modules/background-task/background-task.service';
 import { IAssetUploadedJob, IVideoTranscodeJob } from '@app/job';
 import { Queue } from 'bull';
+import { IAlbumRepository } from "../album/album-repository";
 
 describe('AssetService', () => {
   let sui: AssetService;
   let a: Repository<AssetEntity>; // TO BE DELETED AFTER FINISHED REFACTORING
   let assetRepositoryMock: jest.Mocked<IAssetRepository>;
+  let albumRepositoryMock: jest.Mocked<IAlbumRepository>;
   let downloadServiceMock: jest.Mocked<Partial<DownloadService>>;
   let backgroundTaskServiceMock: jest.Mocked<BackgroundTaskService>;
   let assetUploadedQueueMock: jest.Mocked<Queue<IAssetUploadedJob>>;
@@ -122,6 +124,7 @@ describe('AssetService', () => {
       getAssetWithNoThumbnail: jest.fn(),
       getAssetWithNoSmartInfo: jest.fn(),
       getExistingAssets: jest.fn(),
+      countByIdAndUser: jest.fn(),
     };
 
     downloadServiceMock = {
@@ -130,6 +133,7 @@ describe('AssetService', () => {
 
     sui = new AssetService(
       assetRepositoryMock,
+      albumRepositoryMock,
       a,
       backgroundTaskServiceMock,
       assetUploadedQueueMock,

+ 27 - 8
server/apps/immich/src/api-v1/asset/asset.service.ts

@@ -54,6 +54,7 @@ import { InjectQueue } from '@nestjs/bull';
 import { Queue } from 'bull';
 import { DownloadService } from '../../modules/download/download.service';
 import { DownloadDto } from './dto/download-library.dto';
+import { ALBUM_REPOSITORY, IAlbumRepository } from '../album/album-repository';
 
 const fileInfo = promisify(stat);
 
@@ -63,6 +64,9 @@ export class AssetService {
     @Inject(ASSET_REPOSITORY)
     private _assetRepository: IAssetRepository,
 
+    @Inject(ALBUM_REPOSITORY)
+    private _albumRepository: IAlbumRepository,
+
     @InjectRepository(AssetEntity)
     private assetRepository: Repository<AssetEntity>,
 
@@ -221,22 +225,18 @@ export class AssetService {
     return assets.map((asset) => mapAsset(asset));
   }
 
-  public async getAssetById(authUser: AuthUserDto, assetId: string): Promise<AssetResponseDto> {
+  public async getAssetById(assetId: string): Promise<AssetResponseDto> {
     const asset = await this._assetRepository.getById(assetId);
 
     return mapAsset(asset);
   }
 
-  public async updateAssetById(authUser: AuthUserDto, assetId: string, dto: UpdateAssetDto): Promise<AssetResponseDto> {
+  public async updateAssetById(assetId: string, dto: UpdateAssetDto): Promise<AssetResponseDto> {
     const asset = await this._assetRepository.getById(assetId);
     if (!asset) {
       throw new BadRequestException('Asset not found');
     }
 
-    if (authUser.id !== asset.userId) {
-      throw new ForbiddenException('Not the owner');
-    }
-
     const updatedAsset = await this._assetRepository.update(asset, dto);
 
     return mapAsset(updatedAsset);
@@ -496,14 +496,13 @@ export class AssetService {
     }
   }
 
-  public async deleteAssetById(authUser: AuthUserDto, assetIds: DeleteAssetDto): Promise<DeleteAssetResponseDto[]> {
+  public async deleteAssetById(assetIds: DeleteAssetDto): Promise<DeleteAssetResponseDto[]> {
     const result: DeleteAssetResponseDto[] = [];
 
     const target = assetIds.ids;
     for (const assetId of target) {
       const res = await this.assetRepository.delete({
         id: assetId,
-        userId: authUser.id,
       });
 
       if (res.affected) {
@@ -642,6 +641,26 @@ export class AssetService {
   getAssetCountByUserId(authUser: AuthUserDto): Promise<AssetCountByUserIdResponseDto> {
     return this._assetRepository.getAssetCountByUserId(authUser.id);
   }
+
+  async checkAssetsAccess(authUser: AuthUserDto, assetIds: string[], mustBeOwner = false) {
+    for (const assetId of assetIds) {
+      // Step 1: Check if user owns asset
+      if ((await this._assetRepository.countByIdAndUser(assetId, authUser.id)) == 1) {
+        continue;
+      }
+
+      // Avoid additional checks if ownership is required
+      if (!mustBeOwner) {
+        // Step 2: Check if asset is part of an album shared with me
+        if ((await this._albumRepository.getSharedWithUserAlbumCount(authUser.id, assetId)) > 0) {
+          continue;
+        }
+
+        //TODO: Step 3: Check if asset is part of a public album
+      }
+      throw new ForbiddenException();
+    }
+  }
 }
 
 async function processETag(path: string, res: Res, headers: Record<string, string>): Promise<boolean> {