FastAPI - JWT 인증 구현하기

프로필 사진mingke

FastAPI Logo

목차

소개

오늘은 며칠전 JWT인증을 만들어 달라는 요청을 받아서 구현했던 내용을 한 번 정리해보려고 합니다.

간단하게 JWT (JSON Web Token)를 설명하면 웹서비스 인증 시스템에 널리 사용되는 컴팩트한 토큰입니다. 헤더, 페이로드, 시그니처 세부분으로 나뉘어져 구성되어 있습니다.

FastAPI 공식 문서에 소개된 것을 조금 변형하여 구현해보도록 하겠습니다.

JWT 구현

  • jose 라이브러리를 사용합니다.
  • 설치 python-jose , jose 설치하면 안됩니다.
pip install python-jose

Encoder 구현

  • 우선 구현할 인코더 추상클래스를 만들어서 메소드 정의부터 해봅니다.
  • encoding할 데이터와 키 만료시간을 입력받아 인코딩합니다.
    • TIMEZONE은 서울로 했습니다. TIMEZONE은 알아서
  • 아주 간단하게 Encoder가 완료되었습니다.
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from zoneinfo import ZoneInfo
from jose import jwt, JWTError
 
class AbstractJWTEncoder(ABC):
    """
    JWT 인코더 추상클래스
    encode 메소드를 구현
 
    :param data: JWT에 담을 데이터
    :param expires_delta: JWT 만료 시간
    :param secret_key: JWT 암호화 키
    :param algorithm: JWT 암호화 알고리즘
    """
 
    @abstractmethod
    def encode(
        self, data: dict, expires_delta: int, secret_key: str, algorithm: str
    ) -> str:
        pass
 
class JWTEncoder(AbstractJWTEncoder):
    def encode(
        self, data: dict, expires_delta: int, secret_key: str, algorithm: str
    ) -> str:
        to_encode = data.copy()
        expire = datetime.now(ZoneInfo("Asia/Seoul")) + timedelta(minutes=expires_delta)
        to_encode.update({"exp": expire})
        return jwt.encode(to_encode, secret_key, algorithm=algorithm)

Decoder구현

  • 인코더와 마찬가지로 디코더 추상클래스를 만들어봅니다.
  • encoding된 JWT와 인코딩할 때 사용했던 동일한 secret_key와 알고리즘을 입력받아 디코딩합니다.
class AbstractJWTDecoder(ABC):
    """
    JWT 디코더 추상클래스
    decode 메소드를 구현
 
    :param token: JWT 토큰
    :param secret_key: JWT 암호화 키
    :param algorithm: JWT 암호화 알고리즘
    """
 
    @abstractmethod
    def decode(self, token: str, secret_key: str, algorithm: str) -> dict | None:
        pass
 
class JWTDecoder(AbstractJWTDecoder):
    def decode(self, token: str, secret_key: str, algorithm: str) -> dict | None:
        try:
            return jwt.decode(token, secret_key, algorithms=[algorithm])
        except JWTError:
            return None

JWT Service 구현

  • 서비스 객체에 access token, refresh token 생성 및 token 검사 기능 등을 구현해줍니다.
  • 인코더와 디코더를 주입해서 기능을 완성합니다.
class JWTService:
    """
    JWT 로그인시 access token, refresh token을 생성하는 로직
    """
 
    def __init__(
        self,
        encoder: JWTEncoder,
        decoder: JWTDecoder,
        algorithm: str = None,
        secret_key: str = None,
        access_token_expire_time: int = None,
        refresh_token_expire_time: int = None,
    ):
        self.encoder = encoder
        self.decoder = decoder
        self.algorithm = algorithm
        self.secret_key = secret_key
        self.access_token_expire_time = access_token_expire_time
        self.refresh_token_expire_time = refresh_token_expire_time
 
    def create_access_token(self, data: dict) -> str:
        return self._create_token(data, self.access_token_expire_time)
 
    def create_refresh_token(self, data: dict) -> str:
        return self._create_token(data, self.refresh_token_expire_time)
 
    def _create_token(self, data: dict, expires_delta: int) -> str:
        return self.encoder.encode(data, expires_delta, self.secret_key, self.algorithm)
 
    def check_token_expired(self, token: str) -> dict | None:
        payload = self.decoder.decode(token, self.secret_key, self.algorithm)
 
        now = datetime.timestamp(datetime.now(ZoneInfo("Asia/Seoul")))
        if payload and payload["exp"] < now:
            return None
 
        return payload

로그인 사용 예시

  • 생성한 jwt_service로 access_token과 refresh_token을 생성하여 클라이언트에 전달합니다.
jwt_service = JWTService(JWTEncoder(), JWTDecoder())
 
async def perform_login(user: Annotated[User, Depends(validate_user)]) -> User:
    data = {"id": str(user.id)}
    access_token = jwt_service.create_access_token(data)
    refresh_token = jwt_service.create_refresh_token(data)
    return {
        "access_token": access_token,
        "refresh_token": refresh_token,
    }

인증 예시

  • 공식 문서의 get_current_user 구현을 변형했습니다. 전체적인 로직은 비슷합니다.
    • Authorization 헤더를 가져와서 검증
    • 유효한 토큰인지 검증하고 디코딩
    • 디코딩한 데이터를 이용해 유저 정보 가져오기
    • 중간에 문제가 있으면 예외 발생
    • GetCurrentUser dependency만들어서 route에 주입
async def validate_token(request: Request) -> str | None:
    authorization = request.headers.get("Authorization")
    scheme, param = get_authorization_scheme_param(authorization)
    if not authorization or scheme.lower() != "bearer":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token is invalid",
            headers={"WWW-Authenticate": "Bearer"},
            code="INVALID_TOKEN",
        )
    return param
 
class JWTAuthentication:
    async def __call__(
        self,
        db: DB,
        token: Annotated[str, Depends(validate_token)],
    ):
        try:
            valid_payload = jwt_service.check_token_expired(token)
            if valid_payload:
                id = valid_payload.get("id")
            else:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Expired or changed token.",
                    headers={"WWW-Authenticate": "Bearer"},
                    code="EXPIRED_TOKEN",
                )
        except HTTPException as e:
            raise e
        else:
            user = await user_repository.get_user_by_id(db, id)
            if user is None:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Non-existent user",
                    headers={"WWW-Authenticate": "Bearer"},
                    code="INVALID_USER",
                )
            else:
                return user
 
get_current_user = JWTAuthentication()
GetCurrentUser = Annotated[User, Depends(get_current_user)]

마무리

JWT 구현 방법은 다양하게 있겠지만 전략패턴으로 한 번 구현해봤습니다. 이상으로 포스팅을 마치겠습니다.

Loading...