import pytest
import jwt
from flask_jwt_extended import create_access_token, decode_token
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.orm import scoped_session

from app.extensions.utils.time_helper import (
    get_jwt_access_expired_time_delta,
    get_jwt_refresh_expired_time_delta,
)
from app.persistence.model import BlacklistModel, JwtModel
from core.domains.authentication.dto.auth_dto import GetBlacklistDto
from core.domains.authentication.repository.auth_repository import (
    AuthenticationRepository,
)
from core.domains.user.repository.user_repository import UserRepository
from tests.seeder.conftest import make_random_today_date
from tests.seeder.factory import make_custom_jwt


def test_create_blacklist_when_get_blacklist_dto(session: scoped_session, create_user):
    """
        given : Blacklist DTO (user_id, access_token)
        when : user logs out
        then : a Blacklist row is created and saved to the DB
    """
    user_id = create_user.id
    dto = GetBlacklistDto(user_id=user_id, access_token=create_access_token(user_id))

    AuthenticationRepository().create_blacklist(dto=dto)

    saved_row = (
        session.query(BlacklistModel)
        .filter_by(user_id=user_id)
        .first()
    )

    assert saved_row.user_id == user_id
    assert saved_row.access_token is not None


def test_update_token_when_get_user_id(
    db: SQLAlchemy, session: scoped_session, create_user
):
    """
        given : user, JWT
        when : login flow -> an existing user logs in again
        then : the user's existing JWT row is updated with new tokens/expiries

    The post-update row is read through a second scoped session bound to its
    own connection, so the check does not come from `session`'s identity map.
    """
    UserRepository().create_jwt(create_user.id)
    token_before = session.query(JwtModel).filter_by(user_id=create_user.id).first()
    assert token_before is not None

    # Snapshot the pre-update values now: `token_before` belongs to `session`'s
    # identity map, so reading its attributes after `update_jwt` could observe
    # refreshed (post-update) values and make the comparisons meaningless.
    before_id = token_before.id
    before_user_id = token_before.user_id
    before_access_token = token_before.access_token
    before_refresh_token = token_before.refresh_token
    before_access_expired_at = token_before.access_expired_at
    before_refresh_expired_at = token_before.refresh_expired_at

    # Create a second session on a dedicated connection.
    connection = db.engine.connect()
    session_2 = db.create_scoped_session(options=dict(bind=connection, binds={}))

    try:
        # Update the token.
        AuthenticationRepository().update_jwt(create_user.id)
        # Query through the second session.
        token_after = (
            session_2.query(JwtModel).filter_by(user_id=create_user.id).first()
        )
        assert token_after is not None

        assert before_user_id == create_user.id
        assert before_user_id == token_after.user_id
        # Same row was updated in place, not a new row inserted.
        assert before_id == token_after.id

        # Compare tokens before and after the update.
        assert before_access_token != token_after.access_token
        assert before_refresh_token != token_after.refresh_token
        assert before_access_expired_at != token_after.access_expired_at
        assert before_refresh_expired_at != token_after.refresh_expired_at
    finally:
        # Always release the second session and its connection, even when an
        # assertion fails, so they do not leak into later tests.
        session_2.remove()
        connection.close()


def test_verify_token_when_get_valid_token_then_decode_success(
    session: scoped_session, create_user
):
    """
        given : JWT pair stored for the user by UserRepository.create_jwt
        when : tokens are decoded while still valid
        then : decoding succeeds and each token carries the user's identity
               and its own type ("access" / "refresh")
    """
    user_id = create_user.id
    UserRepository().create_jwt(user_id)
    jwt_row = session.query(JwtModel).filter_by(user_id=user_id).first()

    # Decode access first, then refresh (dicts keep insertion order).
    claims_by_type = {
        "access": decode_token(jwt_row.access_token),
        "refresh": decode_token(jwt_row.refresh_token),
    }

    for expected_type, claims in claims_by_type.items():
        assert claims.get("identity") == user_id
        assert claims.get("type") == expected_type


def test_verify_access_token_when_get_invalid_token_then_decode_fail(
    session: scoped_session, create_user
):
    """
        given : access JWT built by make_custom_jwt with a past `now`
                (make_random_today_date(1, 0) -- presumably one day ago)
        when : it is decoded
        then : decoding fails with jwt.ExpiredSignatureError
    """
    # NOTE(review): `session` is unused here; presumably requested so the
    # DB/app fixtures are set up -- confirm against conftest.
    token_kwargs = dict(
        now=make_random_today_date(1, 0),
        token_type="access",
        delta=get_jwt_access_expired_time_delta(),
    )
    expired_access_token = make_custom_jwt(create_user.id, **token_kwargs)

    with pytest.raises(jwt.ExpiredSignatureError):
        decode_token(expired_access_token)


def test_verify_refresh_token_when_get_invalid_token_then_decode_fail(
    session: scoped_session, create_user
):
    """
        given : refresh JWT built by make_custom_jwt with a past `now`
                (make_random_today_date(15, 0) -- presumably 15 days ago,
                i.e. more than two weeks)
        when : it is decoded
        then : decoding fails with jwt.ExpiredSignatureError
    """
    # NOTE(review): this assumes the configured refresh lifetime is shorter
    # than 15 days -- confirm against get_jwt_refresh_expired_time_delta().
    token_kwargs = dict(
        now=make_random_today_date(15, 0),
        token_type="refresh",
        delta=get_jwt_refresh_expired_time_delta(),
    )
    expired_refresh_token = make_custom_jwt(create_user.id, **token_kwargs)

    with pytest.raises(jwt.ExpiredSignatureError):
        decode_token(expired_refresh_token)
