Przeglądaj źródła

refactor(server): partner core (#2678)

* refactor(server): partner core

* refactor(server): partner access check
Jason Rasmussen 2 lat temu
rodzic
commit
6ce35d47f5

+ 5 - 3
server/apps/immich/src/api-v1/asset/asset.service.spec.ts

@@ -9,9 +9,9 @@ import { AssetCountByUserIdResponseDto } from './response-dto/asset-count-by-use
 import { DownloadService } from '../../modules/download/download.service';
 import { AlbumRepository, IAlbumRepository } from '../album/album-repository';
 import {
+  IAccessRepository,
   ICryptoRepository,
   IJobRepository,
-  IPartnerRepository,
   ISharedLinkRepository,
   IStorageRepository,
   JobName,
@@ -20,6 +20,7 @@ import {
   assetEntityStub,
   authStub,
   fileStub,
+  newAccessRepositoryMock,
   newCryptoRepositoryMock,
   newJobRepositoryMock,
   newSharedLinkRepositoryMock,
@@ -131,10 +132,10 @@ const _getArchivedAssetsCountByUserId = (): AssetCountByUserIdResponseDto => {
 describe('AssetService', () => {
   let sut: AssetService;
   let a: Repository<AssetEntity>; // TO BE DELETED AFTER FINISHED REFACTORING
+  let accessMock: jest.Mocked<IAccessRepository>;
   let assetRepositoryMock: jest.Mocked<IAssetRepository>;
   let albumRepositoryMock: jest.Mocked<IAlbumRepository>;
   let downloadServiceMock: jest.Mocked<Partial<DownloadService>>;
-  let partnerRepositoryMock: jest.Mocked<IPartnerRepository>;
   let sharedLinkRepositoryMock: jest.Mocked<ISharedLinkRepository>;
   let cryptoMock: jest.Mocked<ICryptoRepository>;
   let jobMock: jest.Mocked<IJobRepository>;
@@ -173,12 +174,14 @@ describe('AssetService', () => {
       downloadArchive: jest.fn(),
     };
 
+    accessMock = newAccessRepositoryMock();
     sharedLinkRepositoryMock = newSharedLinkRepositoryMock();
     jobMock = newJobRepositoryMock();
     cryptoMock = newCryptoRepositoryMock();
     storageMock = newStorageRepositoryMock();
 
     sut = new AssetService(
+      accessMock,
       assetRepositoryMock,
       albumRepositoryMock,
       a,
@@ -187,7 +190,6 @@ describe('AssetService', () => {
       jobMock,
       cryptoMock,
       storageMock,
-      partnerRepositoryMock,
     );
 
     when(assetRepositoryMock.get)

+ 5 - 7
server/apps/immich/src/api-v1/asset/asset.service.ts

@@ -25,12 +25,12 @@ import { CuratedObjectsResponseDto } from './response-dto/curated-objects-respon
 import {
   AssetResponseDto,
   getLivePhotoMotionFilename,
+  IAccessRepository,
   ImmichReadStream,
   IStorageRepository,
   JobName,
   mapAsset,
   mapAssetWithoutExif,
-  PartnerCore,
 } from '@app/domain';
 import { CreateAssetDto, UploadFile } from './dto/create-asset.dto';
 import { DeleteAssetResponseDto, DeleteAssetStatusEnum } from './response-dto/delete-asset-response.dto';
@@ -55,7 +55,6 @@ import { DownloadService } from '../../modules/download/download.service';
 import { DownloadDto } from './dto/download-library.dto';
 import { IAlbumRepository } from '../album/album-repository';
 import { SharedLinkCore } from '@app/domain';
-import { IPartnerRepository } from '@app/domain';
 import { ISharedLinkRepository } from '@app/domain';
 import { DownloadFilesDto } from './dto/download-files.dto';
 import { CreateAssetsShareLinkDto } from './dto/create-asset-shared-link.dto';
@@ -82,9 +81,9 @@ export class AssetService {
   readonly logger = new Logger(AssetService.name);
   private shareCore: SharedLinkCore;
   private assetCore: AssetCore;
-  private partnerCore: PartnerCore;
 
   constructor(
+    @Inject(IAccessRepository) private accessRepository: IAccessRepository,
     @Inject(IAssetRepository) private _assetRepository: IAssetRepository,
     @Inject(IAlbumRepository) private _albumRepository: IAlbumRepository,
     @InjectRepository(AssetEntity)
@@ -94,11 +93,9 @@ export class AssetService {
     @Inject(IJobRepository) private jobRepository: IJobRepository,
     @Inject(ICryptoRepository) cryptoRepository: ICryptoRepository,
     @Inject(IStorageRepository) private storageRepository: IStorageRepository,
-    @Inject(IPartnerRepository) private partnerRepository: IPartnerRepository,
   ) {
     this.assetCore = new AssetCore(_assetRepository, jobRepository);
     this.shareCore = new SharedLinkCore(sharedLinkRepository, cryptoRepository);
-    this.partnerCore = new PartnerCore(partnerRepository);
   }
 
   public async uploadFile(
@@ -581,7 +578,7 @@ export class AssetService {
         }
 
         // Step 3: Check if any partner owns the asset
-        const canAccess = await this.partnerCore.hasAssetAccess(assetId, authUser.id);
+        const canAccess = await this.accessRepository.hasPartnerAssetAccess(authUser.id, assetId);
         if (canAccess) {
           continue;
         }
@@ -601,7 +598,8 @@ export class AssetService {
 
   private async checkUserAccess(authUser: AuthUserDto, userId: string) {
     // Check if userId shares assets with authUser
-    if (!(await this.partnerCore.get({ sharedById: userId, sharedWithId: authUser.id }))) {
+    const canAccess = await this.accessRepository.hasPartnerAccess(authUser.id, userId);
+    if (!canAccess) {
       throw new ForbiddenException();
     }
   }

+ 6 - 0
server/libs/domain/src/access/access.repository.ts

@@ -0,0 +1,6 @@
+export const IAccessRepository = 'IAccessRepository';
+
+export interface IAccessRepository {
+  hasPartnerAccess(userId: string, partnerId: string): Promise<boolean>;
+  hasPartnerAssetAccess(userId: string, assetId: string): Promise<boolean>;
+}

+ 1 - 0
server/libs/domain/src/access/index.ts

@@ -0,0 +1 @@
+export * from './access.repository';

+ 2 - 1
server/libs/domain/src/index.ts

@@ -1,3 +1,4 @@
+export * from './access';
 export * from './album';
 export * from './api-key';
 export * from './asset';
@@ -13,10 +14,10 @@ export * from './job';
 export * from './media';
 export * from './metadata';
 export * from './oauth';
+export * from './partner';
 export * from './person';
 export * from './search';
 export * from './server-info';
-export * from './partner';
 export * from './shared-link';
 export * from './smart-info';
 export * from './storage';

+ 0 - 1
server/libs/domain/src/partner/index.ts

@@ -1,3 +1,2 @@
-export * from './partner.core';
 export * from './partner.repository';
 export * from './partner.service';

+ 0 - 33
server/libs/domain/src/partner/partner.core.ts

@@ -1,33 +0,0 @@
-import { PartnerEntity } from '@app/infra/entities';
-import { IPartnerRepository, PartnerIds } from './partner.repository';
-
-export enum PartnerDirection {
-  SharedBy = 'shared-by',
-  SharedWith = 'shared-with',
-}
-
-export class PartnerCore {
-  constructor(private repository: IPartnerRepository) {}
-
-  async getAll(userId: string, direction: PartnerDirection): Promise<PartnerEntity[]> {
-    const partners = await this.repository.getAll(userId);
-    const key = direction === PartnerDirection.SharedBy ? 'sharedById' : 'sharedWithId';
-    return partners.filter((partner) => partner[key] === userId);
-  }
-
-  get(ids: PartnerIds): Promise<PartnerEntity | null> {
-    return this.repository.get(ids);
-  }
-
-  async create(ids: PartnerIds): Promise<PartnerEntity> {
-    return this.repository.create(ids);
-  }
-
-  async remove(ids: PartnerIds): Promise<void> {
-    await this.repository.remove(ids as PartnerEntity);
-  }
-
-  hasAssetAccess(assetId: string, userId: string): Promise<boolean> {
-    return this.repository.hasAssetAccess(assetId, userId);
-  }
-}

+ 5 - 1
server/libs/domain/src/partner/partner.repository.ts

@@ -5,6 +5,11 @@ export interface PartnerIds {
   sharedWithId: string;
 }
 
+export enum PartnerDirection {
+  SharedBy = 'shared-by',
+  SharedWith = 'shared-with',
+}
+
 export const IPartnerRepository = 'IPartnerRepository';
 
 export interface IPartnerRepository {
@@ -12,5 +17,4 @@ export interface IPartnerRepository {
   get(partner: PartnerIds): Promise<PartnerEntity | null>;
   create(partner: PartnerIds): Promise<PartnerEntity>;
   remove(entity: PartnerEntity): Promise<void>;
-  hasAssetAccess(assetId: string, userId: string): Promise<boolean>;
 }

+ 1 - 2
server/libs/domain/src/partner/partner.service.spec.ts

@@ -1,7 +1,6 @@
 import { BadRequestException } from '@nestjs/common';
 import { authStub, newPartnerRepositoryMock, partnerStub } from '../../test';
-import { PartnerDirection } from './partner.core';
-import { IPartnerRepository } from './partner.repository';
+import { IPartnerRepository, PartnerDirection } from './partner.repository';
 import { PartnerService } from './partner.service';
 
 const responseDto = {

+ 9 - 12
server/libs/domain/src/partner/partner.service.ts

@@ -1,41 +1,38 @@
 import { PartnerEntity } from '@app/infra/entities';
 import { BadRequestException, Inject, Injectable } from '@nestjs/common';
 import { AuthUserDto } from '../auth';
-import { IPartnerRepository, PartnerCore, PartnerDirection, PartnerIds } from '../partner';
+import { IPartnerRepository, PartnerDirection, PartnerIds } from '../partner';
 import { mapUser, UserResponseDto } from '../user';
 
 @Injectable()
 export class PartnerService {
-  private partnerCore: PartnerCore;
-
-  constructor(@Inject(IPartnerRepository) partnerRepository: IPartnerRepository) {
-    this.partnerCore = new PartnerCore(partnerRepository);
-  }
+  constructor(@Inject(IPartnerRepository) private repository: IPartnerRepository) {}
 
   async create(authUser: AuthUserDto, sharedWithId: string): Promise<UserResponseDto> {
     const partnerId: PartnerIds = { sharedById: authUser.id, sharedWithId };
-    const exists = await this.partnerCore.get(partnerId);
+    const exists = await this.repository.get(partnerId);
     if (exists) {
       throw new BadRequestException(`Partner already exists`);
     }
 
-    const partner = await this.partnerCore.create(partnerId);
+    const partner = await this.repository.create(partnerId);
     return this.map(partner, PartnerDirection.SharedBy);
   }
 
   async remove(authUser: AuthUserDto, sharedWithId: string): Promise<void> {
     const partnerId: PartnerIds = { sharedById: authUser.id, sharedWithId };
-    const partner = await this.partnerCore.get(partnerId);
+    const partner = await this.repository.get(partnerId);
     if (!partner) {
       throw new BadRequestException('Partner not found');
     }
 
-    await this.partnerCore.remove(partner);
+    await this.repository.remove(partner);
   }
 
   async getAll(authUser: AuthUserDto, direction: PartnerDirection): Promise<UserResponseDto[]> {
-    const partners = await this.partnerCore.getAll(authUser.id, direction);
-    return partners.map((partner) => this.map(partner, direction));
+    const partners = await this.repository.getAll(authUser.id);
+    const key = direction === PartnerDirection.SharedBy ? 'sharedById' : 'sharedWithId';
+    return partners.filter((partner) => partner[key] === authUser.id).map((partner) => this.map(partner, direction));
   }
 
   private map(partner: PartnerEntity, direction: PartnerDirection): UserResponseDto {

+ 8 - 0
server/libs/domain/test/access.repository.mock.ts

@@ -0,0 +1,8 @@
+import { IAccessRepository } from '../src';
+
+export const newAccessRepositoryMock = (): jest.Mocked<IAccessRepository> => {
+  return {
+    hasPartnerAccess: jest.fn(),
+    hasPartnerAssetAccess: jest.fn(),
+  };
+};

+ 1 - 0
server/libs/domain/test/index.ts

@@ -1,3 +1,4 @@
+export * from './access.repository.mock';
 export * from './album.repository.mock';
 export * from './api-key.repository.mock';
 export * from './asset.repository.mock';

+ 0 - 1
server/libs/domain/test/partner.repository.mock.ts

@@ -6,6 +6,5 @@ export const newPartnerRepositoryMock = (): jest.Mocked<IPartnerRepository> => {
     remove: jest.fn(),
     getAll: jest.fn(),
     get: jest.fn(),
-    hasAssetAccess: jest.fn(),
   };
 };

+ 3 - 0
server/libs/infra/src/infra.module.ts

@@ -1,4 +1,5 @@
 import {
+  IAccessRepository,
   IAlbumRepository,
   IAssetRepository,
   ICommunicationRepository,
@@ -30,6 +31,7 @@ import { databaseConfig } from './database.config';
 import { databaseEntities } from './entities';
 import { bullConfig, bullQueues } from './infra.config';
 import {
+  AccessRepository,
   AlbumRepository,
   APIKeyRepository,
   AssetRepository,
@@ -53,6 +55,7 @@ import {
 } from './repositories';
 
 const providers: Provider[] = [
+  { provide: IAccessRepository, useClass: AccessRepository },
   { provide: IAlbumRepository, useClass: AlbumRepository },
   { provide: IAssetRepository, useClass: AssetRepository },
   { provide: ICommunicationRepository, useClass: CommunicationRepository },

+ 38 - 0
server/libs/infra/src/repositories/access.repository.ts

@@ -0,0 +1,38 @@
+import { IAccessRepository } from '@app/domain';
+import { InjectRepository } from '@nestjs/typeorm';
+import { Repository } from 'typeorm';
+import { PartnerEntity } from '../entities';
+
+export class AccessRepository implements IAccessRepository {
+  constructor(@InjectRepository(PartnerEntity) private partnerRepository: Repository<PartnerEntity>) {}
+
+  hasPartnerAccess(userId: string, partnerId: string): Promise<boolean> {
+    return this.partnerRepository.exist({
+      where: {
+        sharedWithId: userId,
+        sharedById: partnerId,
+      },
+    });
+  }
+
+  hasPartnerAssetAccess(userId: string, assetId: string): Promise<boolean> {
+    return this.partnerRepository.exist({
+      where: {
+        sharedWith: {
+          id: userId,
+        },
+        sharedBy: {
+          assets: {
+            id: assetId,
+          },
+        },
+      },
+      relations: {
+        sharedWith: true,
+        sharedBy: {
+          assets: true,
+        },
+      },
+    });
+  }
+}

+ 1 - 0
server/libs/infra/src/repositories/index.ts

@@ -1,3 +1,4 @@
+export * from './access.repository';
 export * from './album.repository';
 export * from './api-key.repository';
 export * from './asset.repository';

+ 0 - 23
server/libs/infra/src/repositories/partner.repository.ts

@@ -24,27 +24,4 @@ export class PartnerRepository implements IPartnerRepository {
   async remove(entity: PartnerEntity): Promise<void> {
     await this.repository.remove(entity);
   }
-
-  async hasAssetAccess(assetId: string, userId: string): Promise<boolean> {
-    const count = await this.repository.count({
-      where: {
-        sharedWith: {
-          id: userId,
-        },
-        sharedBy: {
-          assets: {
-            id: assetId,
-          },
-        },
-      },
-      relations: {
-        sharedWith: true,
-        sharedBy: {
-          assets: true,
-        },
-      },
-    });
-
-    return count == 1;
-  }
 }