import unittest
from unittest import mock

import psycopg2.extras
import pytest
from airflow.models import Connection

from twitch_airflow_components.hooks.postgres import PostgresHook


class TestPostgresHookConn(unittest.TestCase):
    """Connection-building tests for the Twitch ``PostgresHook``.

    ``psycopg2.connect`` is patched in every test, so no database is contacted;
    each test asserts on the keyword arguments the hook passes to it.
    """

    def setUp(self):
        super().setUp()

        self.connection = Connection(
            login="login", password="password", host="host", schema="schema"
        )

        # Subclass with its own ``conn_name_attr`` so a test can redirect the hook
        # to another connection id by setting ``test_conn_id`` on the instance.
        class UnitTestPostgresHook(PostgresHook):
            conn_name_attr = "test_conn_id"

        self.db_hook = UnitTestPostgresHook()
        # Stub the connection lookup: every call returns ``self.connection``, so
        # tests can tweak that object (extra, host, ...) before calling get_conn().
        self.db_hook.get_connection = mock.Mock()
        self.db_hook.get_connection.return_value = self.connection

    @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
    def test_get_conn_non_default_id(self, mock_connect):
        """The connection id is read from the attribute named by conn_name_attr."""
        # Standalone pragma: a trailing one ends up off the assignment's line
        # (and is ignored by pylint) once the formatter wraps the statement.
        # pylint: disable=attribute-defined-outside-init
        self.db_hook.test_conn_id = "non_default"
        self.db_hook.get_conn()
        mock_connect.assert_called_once_with(
            user="login", password="password", host="host", dbname="schema", port=None
        )
        self.db_hook.get_connection.assert_called_once_with("non_default")

    @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
    def test_get_conn(self, mock_connect):
        """Connection fields map onto psycopg2 keyword arguments."""
        self.db_hook.get_conn()
        mock_connect.assert_called_once_with(
            user="login", password="password", host="host", dbname="schema", port=None
        )

    @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
    def test_get_conn_cursor(self, mock_connect):
        """A ``cursor`` extra is translated into a psycopg2 ``cursor_factory``."""
        self.connection.extra = '{"cursor": "dictcursor"}'
        self.db_hook.get_conn()
        mock_connect.assert_called_once_with(
            cursor_factory=psycopg2.extras.DictCursor,
            user="login",
            password="password",
            host="host",
            dbname="schema",
            port=None,
        )

    @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
    def test_get_conn_with_invalid_cursor(self, mock_connect):
        """An unknown ``cursor`` extra raises ValueError without connecting."""
        self.connection.extra = '{"cursor": "mycursor"}'
        with pytest.raises(ValueError):
            self.db_hook.get_conn()
        # The cursor factory is an argument to connect(), so it is resolved first;
        # an invalid name must therefore fail before any connection is opened.
        mock_connect.assert_not_called()

    @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
    def test_get_conn_from_connection(self, mock_connect):
        """A Connection passed to the constructor is used as-is."""
        conn = Connection(
            login="login-conn", password="password-conn", host="host", schema="schema"
        )
        hook = PostgresHook(connection=conn)
        hook.get_conn()
        mock_connect.assert_called_once_with(
            user="login-conn",
            password="password-conn",
            host="host",
            dbname="schema",
            port=None,
        )

    @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
    def test_get_conn_from_connection_with_schema(self, mock_connect):
        """An explicit ``schema`` argument overrides the Connection's schema."""
        conn = Connection(
            login="login-conn", password="password-conn", host="host", schema="schema"
        )
        hook = PostgresHook(connection=conn, schema="schema-override")
        hook.get_conn()
        mock_connect.assert_called_once_with(
            user="login-conn",
            password="password-conn",
            host="host",
            dbname="schema-override",
            port=None,
        )

    @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
    @mock.patch(
        "airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_client_type"
    )
    def test_get_conn_rds_iam_postgres(self, mock_client, mock_connect):
        """With ``iam`` set, the password is an RDS auth token and port is 5432."""
        self.connection.extra = '{"iam":true}'
        mock_client.return_value.generate_db_auth_token.return_value = "aws_token"
        self.db_hook.get_conn()
        mock_connect.assert_called_once_with(
            user="login", password="aws_token", host="host", dbname="schema", port=5432
        )

    @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
    def test_get_conn_extra(self, mock_connect):
        """Unrecognised extras are forwarded to psycopg2 as keyword arguments."""
        self.connection.extra = '{"connect_timeout": 3}'
        self.db_hook.get_conn()
        mock_connect.assert_called_once_with(
            user="login",
            password="password",
            host="host",
            dbname="schema",
            port=None,
            connect_timeout=3,
        )

    @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
    @mock.patch(
        "airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_client_type"
    )
    def test_get_conn_rds_iam_redshift(self, mock_client, mock_connect):
        """Redshift IAM uses cluster credentials and leaves the Connection intact."""
        self.connection.extra = '{"iam":true, "redshift":true, "cluster-identifier": "different-identifier"}'
        self.connection.host = (
            "cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com"
        )
        login = f"IAM:{self.connection.login}"
        mock_client.return_value.get_cluster_credentials.return_value = {
            "DbPassword": "aws_token",
            "DbUser": login,
        }
        self.db_hook.get_conn()
        # The explicit "cluster-identifier" extra wins over the one in the host name.
        get_cluster_credentials_call = mock.call(
            DbUser=self.connection.login,
            DbName=self.connection.schema,
            ClusterIdentifier="different-identifier",
            AutoCreate=False,
        )
        mock_client.return_value.get_cluster_credentials.assert_has_calls(
            [get_cluster_credentials_call]
        )
        mock_connect.assert_called_once_with(
            user=login,
            password="aws_token",
            host=self.connection.host,
            dbname="schema",
            port=5439,
        )
        # Verify that the connection object has not been mutated: a second call
        # must request credentials with the original (un-prefixed) login again,
        # and the shared Connection must still hold its original credentials.
        self.db_hook.get_conn()
        mock_client.return_value.get_cluster_credentials.assert_has_calls(
            [get_cluster_credentials_call, get_cluster_credentials_call]
        )
        self.assertEqual(self.connection.login, "login")
        self.assertEqual(self.connection.password, "password")
