diff --git a/backend/src/auth/auth.controller.ts b/backend/src/auth/auth.controller.ts index 4bc197d..20c876f 100644 --- a/backend/src/auth/auth.controller.ts +++ b/backend/src/auth/auth.controller.ts @@ -21,6 +21,7 @@ import { AuthRegisterDTO } from "./dto/authRegister.dto"; import { AuthSignInDTO } from "./dto/authSignIn.dto"; import { AuthSignInTotpDTO } from "./dto/authSignInTotp.dto"; import { EnableTotpDTO } from "./dto/enableTotp.dto"; +import { TokenDTO } from "./dto/token.dto"; import { UpdatePasswordDTO } from "./dto/updatePassword.dto"; import { VerifyTotpDTO } from "./dto/verifyTotp.dto"; import { JwtGuard } from "./guard/jwt.guard"; @@ -45,8 +46,8 @@ export class AuthController { response = this.addTokensToResponse( response, - result.accessToken, - result.refreshToken + result.refreshToken, + result.accessToken ); return result; @@ -64,8 +65,8 @@ export class AuthController { if (result.accessToken && result.refreshToken) { response = this.addTokensToResponse( response, - result.accessToken, - result.refreshToken + result.refreshToken, + result.accessToken ); } @@ -83,17 +84,28 @@ export class AuthController { response = this.addTokensToResponse( response, - result.accessToken, - result.refreshToken + result.refreshToken, + result.accessToken ); - return result; + return new TokenDTO().from(result); } @Patch("password") @UseGuards(JwtGuard) - async updatePassword(@GetUser() user: User, @Body() dto: UpdatePasswordDTO) { - await this.authService.updatePassword(user, dto.oldPassword, dto.password); + async updatePassword( + @GetUser() user: User, + @Res({ passthrough: true }) response: Response, + @Body() dto: UpdatePasswordDTO + ) { + const result = await this.authService.updatePassword( + user, + dto.oldPassword, + dto.password + ); + + response = this.addTokensToResponse(response, result.refreshToken); + return new TokenDTO().from(result); } @Post("token") @@ -108,7 +120,7 @@ export class AuthController { request.cookies.refresh_token ); response.cookie("access_token", accessToken); - return { accessToken }; + return new TokenDTO().from({ accessToken }); } @Post("signOut") @@ -146,15 +158,16 @@ export class AuthController { private addTokensToResponse( response: Response, - accessToken: string, - refreshToken: string + refreshToken?: string, + accessToken?: string ) { - response.cookie("access_token", accessToken); - response.cookie("refresh_token", refreshToken, { - path: "/api/auth/token", - httpOnly: true, - maxAge: 1000 * 60 * 60 * 24 * 30 * 3, - }); + if (accessToken) response.cookie("access_token", accessToken); + if (refreshToken) + response.cookie("refresh_token", refreshToken, { + path: "/api/auth/token", + httpOnly: true, + maxAge: 1000 * 60 * 60 * 24 * 30 * 3, + }); return response; } diff --git a/backend/src/auth/auth.service.ts b/backend/src/auth/auth.service.ts index e309589..8254ab9 100644 --- a/backend/src/auth/auth.service.ts +++ b/backend/src/auth/auth.service.ts @@ -87,10 +87,16 @@ export class AuthService { const hash = await argon.hash(newPassword); - return await this.prisma.user.update({ + await this.prisma.refreshToken.deleteMany({ + where: { userId: user.id }, + }); + + await this.prisma.user.update({ where: { id: user.id }, data: { password: hash }, }); + + return this.createRefreshToken(user.id); } async createAccessToken(user: User, refreshTokenId: string) { @@ -112,7 +118,12 @@ export class AuthService { refreshTokenId: string; }; - await this.prisma.refreshToken.delete({ where: { id: refreshTokenId } }); + await this.prisma.refreshToken + .delete({ where: { id: refreshTokenId } }) + .catch((e) => { + // Ignore error if refresh token doesn't exist + if (e.code != "P2025") throw e; + }); } async refreshAccessToken(refreshToken: string) { diff --git a/backend/src/auth/dto/token.dto.ts b/backend/src/auth/dto/token.dto.ts new file mode 100644 index 0000000..8810a1e --- /dev/null +++ b/backend/src/auth/dto/token.dto.ts @@ -0,0 +1,15 @@ +import { Expose, plainToClass } from "class-transformer"; + +export class TokenDTO { + @Expose() + accessToken: string; + + @Expose() + refreshToken: string; + + from(partial: Partial) { + return plainToClass(TokenDTO, partial, { + excludeExtraneousValues: true, + }); + } +}