from abc import ABC, abstractmethod
from typing import Any, Dict

import jwt
import requests
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from jwt import DecodeError
from requests import Response

from app.extensions.utils.log_helper import logger_
from core.domains.oauth.enum.oauth_enum import (
    OAuthKakaoEnum,
    OAuthGoogleEnum,
    OAuthAppleEnum,
)
from core.exceptions import (
    InvalidRequestException,
    NotFoundException,
    TokenValidationErrorException,
)

# Module-level logger obtained from the project's log helper, named after this module.
logger = logger_.getLogger(__name__)


class OAuthBase:
    def __init__(self, auth_header: str | None):
        self._access_token: str = self._get_oauth_token(auth_header)
        self._base_headers: dict = {
            "Content-Type": "application/x-www-form-urlencoded;charset=utf-8",
            "Cache-Control": "no-cache",
        }

    @property
    def access_token(self):
        return self._access_token

    def _get_oauth_token(self, auth_header: str | None) -> str:
        prefix = "Bearer"

        if not auth_header:
            raise NotFoundException(msg="Not found Authorization in Header")

        bearer, _, token = auth_header.partition(" ")

        if bearer != prefix:
            raise InvalidRequestException(message="Invalid Token format")

        if not token:
            raise NotFoundException(msg="Not found Token in Header")

        return token

    def request_validation(self) -> Response:
        """
            각 플랫폼마다 개별 구현하기
        """
        pass


class OAuthKakao(OAuthBase):
    """Validate a Kakao access token by requesting the Kakao user-info URL.

    The URL is ``OAuthKakaoEnum.API_BASE_URL`` joined with
    ``OAuthKakaoEnum.USER_INFO_END_POINT``. ``request_validation`` returns the
    raw ``requests.Response``; this class does not inspect its status itself.
    """

    def __init__(
        self,
        auth_header: str | None,
        timeout: float | tuple[float, float] = 10.0,
    ):
        """
        Args:
            auth_header: Raw ``Authorization`` header value (``Bearer <token>``).
            timeout: Passed to ``requests`` as its timeout (seconds, or a
                ``(connect, read)`` pair). Without one, a stalled connection
                would block the caller indefinitely.
        """
        super().__init__(auth_header)
        self._base_headers.update({"Authorization": f"Bearer {self._access_token}"})
        self._url: str = (
            OAuthKakaoEnum.API_BASE_URL.value + OAuthKakaoEnum.USER_INFO_END_POINT.value
        )
        self._timeout = timeout

    def request_validation(self) -> Response:
        """GET the user-info URL with the bearer token attached."""
        return requests.get(
            url=self._url, headers=self._base_headers, timeout=self._timeout
        )


class OAuthGoogle(OAuthBase):
    """Validate a Google access token by requesting ``OAuthGoogleEnum.USER_INFO_URL``.

    ``request_validation`` returns the raw ``requests.Response``; this class
    does not inspect its status itself.
    """

    def __init__(
        self,
        auth_header: str | None,
        timeout: float | tuple[float, float] = 10.0,
    ):
        """
        Args:
            auth_header: Raw ``Authorization`` header value (``Bearer <token>``).
            timeout: Passed to ``requests`` as its timeout (seconds, or a
                ``(connect, read)`` pair). Without one, a stalled connection
                would block the caller indefinitely.
        """
        super().__init__(auth_header)
        self._base_headers.update({"Authorization": f"Bearer {self._access_token}"})
        self._url: str = OAuthGoogleEnum.USER_INFO_URL.value
        self._timeout = timeout

    def request_validation(self) -> Response:
        """GET the user-info URL with the bearer token attached."""
        return requests.get(
            url=self._url, headers=self._base_headers, timeout=self._timeout
        )


class OAuthApple(OAuthBase):
    """Validate an Apple sign-in that was performed through Firebase.

    The bearer token is treated as a Firebase ID token (a JWT). It is verified
    against the certificates served at ``OAuthAppleEnum.FIREBASE_AUTH_URL`` (a
    JSON object mapping key id to a PEM X.509 certificate), and the Apple
    subject is read from the ``firebase.identities["apple.com"]`` claim.
    """

    # Firebase ID tokens are RS256-signed. The ``alg`` in the token header is
    # attacker-controlled (it is read before verification), so only accept
    # algorithms from this allow-list.
    _ALLOWED_ALGORITHMS: tuple[str, ...] = ("RS256",)

    def __init__(
        self,
        auth_header: str | None,
        timeout: float | tuple[float, float] = 10.0,
    ):
        """
        Args:
            auth_header: Raw ``Authorization`` header value (``Bearer <token>``).
            timeout: Passed to ``requests`` when fetching the certificates.
        """
        super().__init__(auth_header)
        self._url: str = OAuthAppleEnum.FIREBASE_AUTH_URL.value
        self._timeout = timeout

    def request_validation(self) -> str:
        """Verify the Firebase ID token and return the Apple subject.

        Returns:
            The first entry of ``firebase.identities["apple.com"]`` in the
            verified token payload.

        Raises:
            InvalidRequestException: the certificates cannot be fetched, the
                token is malformed, its header lacks ``kid``/``alg`` or uses a
                non-allowed algorithm, verification fails, or the payload has
                no Apple identity.
            NotFoundException: no certificate matches the token's ``kid``.
            TokenValidationErrorException: the Apple identity is empty.
        """
        public_keys = self._load_firebase_public_keys()
        token_header = self._get_checked_token_header()

        firebase_cert_key = public_keys.get(token_header["kid"])
        if not firebase_cert_key:
            raise NotFoundException(
                msg="Not found correct Firebase cert key, Failed Apple OAuth Login"
            )

        decoded_token = self._get_decoded_firebase_token(
            token=self._access_token,
            cert=firebase_cert_key,
            algorithm=token_header["alg"],
        )
        apple_sub = self._extract_apple_sub(decoded_token)

        if not apple_sub:
            raise TokenValidationErrorException(
                msg="Not found Apple_sub in decoded token, Failed Apple OAuth Login"
            )
        return apple_sub

    def _load_firebase_public_keys(self) -> dict:
        """Fetch and return the key-id -> PEM certificate mapping."""
        response: Response = self._get_firebase_auth_keys()
        try:
            response.raise_for_status()
        except requests.HTTPError as e:
            # Report the raw body: an error response is not guaranteed to be
            # JSON, and calling .json() here would mask the HTTP error.
            raise InvalidRequestException(
                message=f"Failed get auth public keys from Firebase, error:{response.text}"
            ) from e
        return response.json()

    def _get_checked_token_header(self) -> dict:
        """Return the unverified JWT header after checking the fields used later."""
        try:
            header = jwt.get_unverified_header(self._access_token)
        except DecodeError as e:
            raise InvalidRequestException("Invalid JWT format, not enough segment") from e

        # ``kid`` selects the certificate and ``alg`` drives verification, so
        # BOTH must be present, and ``alg`` must be on the allow-list.
        if not (header.get("kid") and header.get("alg")) or (
            header["alg"] not in self._ALLOWED_ALGORITHMS
        ):
            raise InvalidRequestException(
                "Invalid auth public key header, Not Firebase Apple's id_token"
            )
        return header

    @staticmethod
    def _extract_apple_sub(decoded_token: Dict[str, Any]) -> Any:
        """Return the first ``apple.com`` identity from the ``firebase`` claim."""
        try:
            return decoded_token["firebase"]["identities"]["apple.com"][0]
        except (KeyError, IndexError, TypeError) as e:
            raise InvalidRequestException("Invalid firebase token, Can not decoded") from e

    def _get_firebase_auth_keys(self) -> Response:
        """GET the Firebase certificate list, bounded by ``self._timeout``."""
        return requests.get(
            url=self._url, headers=self._base_headers, timeout=self._timeout
        )

    def _get_decoded_firebase_token(
        self, token: str, cert: str, algorithm: str
    ) -> Dict[str, Any]:
        """Verify ``token`` with the public key inside ``cert`` and return its claims.

        Args:
            token: The encoded JWT.
            cert: PEM-encoded X.509 certificate holding the verification key.
            algorithm: The single signing algorithm to accept.

        Raises:
            InvalidRequestException: the certificate cannot be parsed or the
                token fails signature/claim validation (including audience).
        """
        # NOTE(review): Firebase's verification guide also requires checking
        # ``iss`` (https://securetoken.google.com/<projectId>); only the
        # audience is checked here — confirm whether that is intentional.
        try:
            # Parse inside the try so a malformed certificate surfaces as an
            # InvalidRequestException rather than an unhandled ValueError.
            public_key = x509.load_pem_x509_certificate(
                data=cert.encode("utf-8"), backend=default_backend()
            ).public_key()
            return jwt.decode(
                jwt=token,
                key=public_key,
                algorithms=[algorithm],
                audience=OAuthAppleEnum.FIREBASE_AUDIENCE.value,
            )
        except Exception as e:
            logger.error(f"[OAuthApple][get_decoded_firebase_token] error : {e}")
            raise InvalidRequestException(message=str(e)) from e
