Browse Source

feat(server): link via profile.sub (#1055)

Jason Rasmussen 2 years ago
parent
commit
99854e90be

+ 1 - 0
server/apps/immich/src/api-v1/auth/auth.service.spec.ts

@@ -39,6 +39,7 @@ describe('AuthService', () => {
     userRepositoryMock = {
     userRepositoryMock = {
       get: jest.fn(),
       get: jest.fn(),
       getAdmin: jest.fn(),
       getAdmin: jest.fn(),
+      getByOAuthId: jest.fn(),
       getByEmail: jest.fn(),
       getByEmail: jest.fn(),
       getList: jest.fn(),
       getList: jest.fn(),
       create: jest.fn(),
       create: jest.fn(),

+ 24 - 1
server/apps/immich/src/api-v1/oauth/oauth.service.spec.ts

@@ -20,12 +20,14 @@ const mockConfig = (config: Partial<OAuthConfig>) => {
 };
 };
 
 
 const email = 'user@immich.com';
 const email = 'user@immich.com';
+const sub = 'my-auth-user-sub';
 
 
 const user = {
 const user = {
   id: 'user',
   id: 'user',
   email,
   email,
   firstName: 'user',
   firstName: 'user',
   lastName: 'imimch',
   lastName: 'imimch',
+  oauthId: '',
 } as UserEntity;
 } as UserEntity;
 
 
 const loginResponse = {
 const loginResponse = {
@@ -53,13 +55,14 @@ describe('OAuthService', () => {
         authorizationUrl: jest.fn().mockReturnValue('http://authorization-url'),
         authorizationUrl: jest.fn().mockReturnValue('http://authorization-url'),
         callbackParams: jest.fn().mockReturnValue({ state: 'state' }),
         callbackParams: jest.fn().mockReturnValue({ state: 'state' }),
         callback: jest.fn().mockReturnValue({ access_token: 'access-token' }),
         callback: jest.fn().mockReturnValue({ access_token: 'access-token' }),
-        userinfo: jest.fn().mockResolvedValue({ email }),
+        userinfo: jest.fn().mockResolvedValue({ sub, email }),
       }),
       }),
     } as any);
     } as any);
 
 
     userRepositoryMock = {
     userRepositoryMock = {
       get: jest.fn(),
       get: jest.fn(),
       getAdmin: jest.fn(),
       getAdmin: jest.fn(),
+      getByOAuthId: jest.fn(),
       getByEmail: jest.fn(),
       getByEmail: jest.fn(),
       getList: jest.fn(),
       getList: jest.fn(),
       create: jest.fn(),
       create: jest.fn(),
@@ -132,6 +135,26 @@ describe('OAuthService', () => {
       expect(userRepositoryMock.getByEmail).toHaveBeenCalledTimes(1);
       expect(userRepositoryMock.getByEmail).toHaveBeenCalledTimes(1);
     });
     });
 
 
+    it('should link an existing user', async () => {
+      configServiceMock.get.mockImplementation(
+        mockConfig({
+          OAUTH_ENABLED: true,
+          OAUTH_AUTO_REGISTER: false,
+        }),
+      );
+      sut = new OAuthService(immichJwtServiceMock, configServiceMock, userRepositoryMock);
+      jest.spyOn(sut['logger'], 'debug').mockImplementation(() => null);
+      jest.spyOn(sut['logger'], 'warn').mockImplementation(() => null);
+      userRepositoryMock.getByEmail.mockResolvedValue(user);
+      userRepositoryMock.update.mockResolvedValue(user);
+      immichJwtServiceMock.createLoginResponse.mockResolvedValue(loginResponse);
+
+      await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' })).resolves.toEqual(loginResponse);
+
+      expect(userRepositoryMock.getByEmail).toHaveBeenCalledTimes(1);
+      expect(userRepositoryMock.update).toHaveBeenCalledWith(user.id, { oauthId: sub });
+    });
+
     it('should allow auto registering by default', async () => {
     it('should allow auto registering by default', async () => {
       configServiceMock.get.mockImplementation(mockConfig({ OAUTH_ENABLED: true }));
       configServiceMock.get.mockImplementation(mockConfig({ OAUTH_ENABLED: true }));
       sut = new OAuthService(immichJwtServiceMock, configServiceMock, userRepositoryMock);
       sut = new OAuthService(immichJwtServiceMock, configServiceMock, userRepositoryMock);

+ 12 - 2
server/apps/immich/src/api-v1/oauth/oauth.service.ts

@@ -63,8 +63,17 @@ export class OAuthService {
     const profile = await client.userinfo<OAuthProfile>(tokens.access_token || '');
     const profile = await client.userinfo<OAuthProfile>(tokens.access_token || '');
 
 
     this.logger.debug(`Logging in with OAuth: ${JSON.stringify(profile)}`);
     this.logger.debug(`Logging in with OAuth: ${JSON.stringify(profile)}`);
-    let user = await this.userRepository.getByEmail(profile.email);
+    let user = await this.userRepository.getByOAuthId(profile.sub);
 
 
+    // link existing user
+    if (!user) {
+      const emailUser = await this.userRepository.getByEmail(profile.email);
+      if (emailUser) {
+        user = await this.userRepository.update(emailUser.id, { oauthId: profile.sub });
+      }
+    }
+
+    // register new user
     if (!user) {
     if (!user) {
       if (!this.autoRegister) {
       if (!this.autoRegister) {
         this.logger.warn(
         this.logger.warn(
@@ -73,11 +82,12 @@ export class OAuthService {
         throw new BadRequestException(`User does not exist and auto registering is disabled.`);
         throw new BadRequestException(`User does not exist and auto registering is disabled.`);
       }
       }
 
 
-      this.logger.log(`Registering new user: ${profile.email}`);
+      this.logger.log(`Registering new user: ${profile.email}/${profile.sub}`);
       user = await this.userRepository.create({
       user = await this.userRepository.create({
         firstName: profile.given_name || '',
         firstName: profile.given_name || '',
         lastName: profile.family_name || '',
         lastName: profile.family_name || '',
         email: profile.email,
         email: profile.email,
+        oauthId: profile.sub,
       });
       });
     }
     }
 
 

+ 5 - 0
server/apps/immich/src/api-v1/user/user-repository.ts

@@ -8,6 +8,7 @@ export interface IUserRepository {
   get(id: string, withDeleted?: boolean): Promise<UserEntity | null>;
   get(id: string, withDeleted?: boolean): Promise<UserEntity | null>;
   getAdmin(): Promise<UserEntity | null>;
   getAdmin(): Promise<UserEntity | null>;
   getByEmail(email: string, withPassword?: boolean): Promise<UserEntity | null>;
   getByEmail(email: string, withPassword?: boolean): Promise<UserEntity | null>;
+  getByOAuthId(oauthId: string): Promise<UserEntity | null>;
   getList(filter?: { excludeId?: string }): Promise<UserEntity[]>;
   getList(filter?: { excludeId?: string }): Promise<UserEntity[]>;
   create(user: Partial<UserEntity>): Promise<UserEntity>;
   create(user: Partial<UserEntity>): Promise<UserEntity>;
   update(id: string, user: Partial<UserEntity>): Promise<UserEntity>;
   update(id: string, user: Partial<UserEntity>): Promise<UserEntity>;
@@ -41,6 +42,10 @@ export class UserRepository implements IUserRepository {
     return builder.getOne();
     return builder.getOne();
   }
   }
 
 
+  public async getByOAuthId(oauthId: string): Promise<UserEntity | null> {
+    return this.userRepository.findOne({ where: { oauthId } });
+  }
+
   public async getList({ excludeId }: { excludeId?: string } = {}): Promise<UserEntity[]> {
   public async getList({ excludeId }: { excludeId?: string } = {}): Promise<UserEntity[]> {
     if (!excludeId) {
     if (!excludeId) {
       return this.userRepository.find(); // TODO: this should also be ordered the same as below
       return this.userRepository.find(); // TODO: this should also be ordered the same as below

+ 3 - 0
server/apps/immich/src/api-v1/user/user.service.spec.ts

@@ -27,6 +27,7 @@ describe('UserService', () => {
     firstName: 'admin_first_name',
     firstName: 'admin_first_name',
     lastName: 'admin_last_name',
     lastName: 'admin_last_name',
     isAdmin: true,
     isAdmin: true,
+    oauthId: '',
     shouldChangePassword: false,
     shouldChangePassword: false,
     profileImagePath: '',
     profileImagePath: '',
     createdAt: '2021-01-01',
     createdAt: '2021-01-01',
@@ -40,6 +41,7 @@ describe('UserService', () => {
     firstName: 'immich_first_name',
     firstName: 'immich_first_name',
     lastName: 'immich_last_name',
     lastName: 'immich_last_name',
     isAdmin: false,
     isAdmin: false,
+    oauthId: '',
     shouldChangePassword: false,
     shouldChangePassword: false,
     profileImagePath: '',
     profileImagePath: '',
     createdAt: '2021-01-01',
     createdAt: '2021-01-01',
@@ -53,6 +55,7 @@ describe('UserService', () => {
     firstName: 'updated_immich_first_name',
     firstName: 'updated_immich_first_name',
     lastName: 'updated_immich_last_name',
     lastName: 'updated_immich_last_name',
     isAdmin: false,
     isAdmin: false,
+    oauthId: '',
     shouldChangePassword: true,
     shouldChangePassword: true,
     profileImagePath: '',
     profileImagePath: '',
     createdAt: '2021-01-01',
     createdAt: '2021-01-01',

+ 1 - 0
server/apps/immich/src/modules/immich-jwt/immich-jwt.service.spec.ts

@@ -52,6 +52,7 @@ describe('ImmichJwtService', () => {
         email: 'test@immich.com',
         email: 'test@immich.com',
         password: 'changeme',
         password: 'changeme',
         salt: '123',
         salt: '123',
+        oauthId: '',
         profileImagePath: '',
         profileImagePath: '',
         shouldChangePassword: false,
         shouldChangePassword: false,
         createdAt: 'today',
         createdAt: 'today',

+ 1 - 0
server/apps/immich/test/test-utils.ts

@@ -20,6 +20,7 @@ export function newUserRepositoryMock(): jest.Mocked<IUserRepository> {
     get: jest.fn(),
     get: jest.fn(),
     getAdmin: jest.fn(),
     getAdmin: jest.fn(),
     getByEmail: jest.fn(),
     getByEmail: jest.fn(),
+    getByOAuthId: jest.fn(),
     getList: jest.fn(),
     getList: jest.fn(),
     create: jest.fn(),
     create: jest.fn(),
     update: jest.fn(),
     update: jest.fn(),

+ 3 - 0
server/libs/database/src/entities/user.entity.ts

@@ -23,6 +23,9 @@ export class UserEntity {
   @Column({ default: '', select: false })
   @Column({ default: '', select: false })
   salt?: string;
   salt?: string;
 
 
+  @Column({ default: '', select: false })
+  oauthId!: string;
+
   @Column({ default: '' })
   @Column({ default: '' })
   profileImagePath!: string;
   profileImagePath!: string;
 
 

+ 14 - 0
server/libs/database/src/migrations/1670104716264-OAuthId.ts

@@ -0,0 +1,14 @@
+import { MigrationInterface, QueryRunner } from "typeorm";
+
+export class OAuthId1670104716264 implements MigrationInterface {
+    name = 'OAuthId1670104716264'
+
+    public async up(queryRunner: QueryRunner): Promise<void> {
+        await queryRunner.query(`ALTER TABLE "users" ADD "oauthId" character varying NOT NULL DEFAULT ''`);
+    }
+
+    public async down(queryRunner: QueryRunner): Promise<void> {
+        await queryRunner.query(`ALTER TABLE "users" DROP COLUMN "oauthId"`);
+    }
+
+}