from typing import Optional

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
from botocore.exceptions import ClientError


class SageMakerMonitoringScheduleOperator(SageMakerBaseOperator):
    """Create or update a SageMaker model monitoring schedule.

    The operator first runs the requested operation. When a ``create`` call
    is rejected by the service with a ``ClientError`` it switches to
    ``update`` and retries once with the same configuration.

    :param config: Request payload for the monitoring schedule. It must
        contain ``MonitoringScheduleName``.
    :param wait_for_completion: When true, poll the schedule status until it
        leaves the ``Pending`` state.
    :param print_log: Stored on the operator; not used by the code in this
        class.
    :param check_interval: Polling interval forwarded to the hook's
        ``check_status``.
    :param max_ingestion_time: Polling limit forwarded to the hook's
        ``check_status``.
    :param operation: ``"create"`` or ``"update"`` (case-insensitive). A
        failed ``create`` is automatically retried as ``update``.
    :param action_if_job_exists: ``"increment"`` or ``"fail"``. The value is
        validated and stored; it is not used by the code in this class.
    :raises AirflowException: If ``operation`` or ``action_if_job_exists``
        has an unsupported value.
    """

    # operation name -> (name of the method performing it, verb used in logs)
    _OPERATIONS = {
        "create": ("create_monitoring_job", "Creating"),
        "update": ("update_monitoring_job", "Updating"),
    }

    # The only request members update_monitoring_schedule accepts; anything
    # else (for example Tags, which is valid for create) must be dropped.
    _UPDATE_REQUEST_KEYS = ("MonitoringScheduleName", "MonitoringScheduleConfig")

    def __init__(
        self,
        *,
        config: dict,
        wait_for_completion: bool = True,
        print_log: bool = True,
        check_interval: int = 30,
        max_ingestion_time: Optional[int] = None,
        operation: str = "create",  # If name found, this will automatically be changed to "update"
        action_if_job_exists: str = "increment",
        **kwargs,
    ):
        super().__init__(config=config, **kwargs)

        if action_if_job_exists not in ("increment", "fail"):
            raise AirflowException(
                "Argument action_if_job_exists accepts only 'increment' and 'fail'. "
                f"Provided value: '{action_if_job_exists}'."
            )

        # Validate the normalized value so that e.g. "Create" is accepted,
        # consistent with the lower-cased value stored on the instance.
        normalized_operation = operation.lower()
        if normalized_operation not in self._OPERATIONS:
            raise AirflowException(
                "Argument operation accepts only 'create' and 'update'. "
                f"Provided value: '{operation}'."
            )

        self.action_if_job_exists = action_if_job_exists
        self.wait_for_completion = wait_for_completion
        self.print_log = print_log
        self.check_interval = check_interval
        self.max_ingestion_time = max_ingestion_time
        self.operation = normalized_operation

        self._create_integer_fields()

    def _monitoring_job_definition(self) -> dict:
        """Return the inline ``MonitoringJobDefinition`` dict, or ``{}`` if absent.

        The returned dict is the object stored in ``self.config`` (not a
        copy), so callers may mutate it in place.
        """
        schedule_config = self.config.get("MonitoringScheduleConfig")
        if not isinstance(schedule_config, dict):
            return {}
        job_definition = schedule_config.get("MonitoringJobDefinition")
        return job_definition if isinstance(job_definition, dict) else {}

    def _create_integer_fields(self) -> None:
        """Set fields which should be casted to integers."""
        job_definition_path = ["MonitoringScheduleConfig", "MonitoringJobDefinition"]
        cluster_path = job_definition_path + ["MonitoringResources", "ClusterConfig"]

        self.integer_fields = [
            cluster_path + ["InstanceCount"],
            cluster_path + ["VolumeSizeInGB"],
        ]

        # Looked up defensively: the inline job definition may be missing,
        # in which case there is no stopping condition to convert.
        if "StoppingCondition" in self._monitoring_job_definition():
            self.integer_fields.append(
                job_definition_path + ["StoppingCondition", "MaxRuntimeInSeconds"]
            )

    def expand_role(self) -> None:
        """Expand every ``RoleArn`` in the config through the IAM hook.

        Both a top-level ``RoleArn`` and the one nested in
        ``MonitoringScheduleConfig.MonitoringJobDefinition`` are handled.
        The hook is only created when there is something to expand.
        """
        holders = [
            section
            for section in (self.config, self._monitoring_job_definition())
            if "RoleArn" in section
        ]
        if not holders:
            return

        hook = AwsBaseHook(self.aws_conn_id, client_type="iam")
        for section in holders:
            section["RoleArn"] = hook.expand_role(section["RoleArn"])

    def _run_operation(self, operation: str) -> dict:
        """Log and run ``operation`` with the operator's wait settings."""
        try:
            method_name, log_str = self._OPERATIONS[operation]
        except KeyError:
            raise AirflowException(
                "Argument operation accepts only 'create' and 'update'. "
                f"Provided value: '{operation}'."
            ) from None

        self.log.info(
            "%s SageMaker Model Monitoring Job %s.",
            log_str,
            self.config["MonitoringScheduleName"],
        )
        return getattr(self, method_name)(
            self.config,
            wait_for_completion=self.wait_for_completion,
            check_interval=self.check_interval,
            max_ingestion_time=self.max_ingestion_time,
        )

    def execute(self, context) -> dict:
        """Create or update the monitoring schedule and return its description.

        :param context: Airflow task context (unused here).
        :return: ``{"MonitorScheduling": <describe_monitoring_schedule response>}``.
        :raises AirflowException: If the operation is unsupported or the
            service response status code is not 200.
        :raises ClientError: If an ``update`` call fails, including the
            update attempted as a fallback after a failed ``create``.
        """
        self.preprocess_config()

        try:
            response = self._run_operation(self.operation)
        except ClientError as error:  # Botocore throws a ClientError if the model monitor is already created
            _, log_str = self._OPERATIONS[self.operation]
            self.log.info("Attempt failed: %s SageMaker Model Monitoring Job.", log_str)
            if self.operation == "update":
                raise

            # Keep the reason for the failed create visible in the task log
            # before retrying as an update.
            self.log.info("Create attempt error: %s", error)
            self.operation = "update"
            response = self._run_operation(self.operation)

        if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
            raise AirflowException(
                f"SageMaker Model Monitor scheduling failed: {response}"
            )
        return {
            "MonitorScheduling": self.describe_monitoring_job(
                self.config["MonitoringScheduleName"]
            )
        }

    def _wait_for_schedule(
        self,
        schedule_name: str,
        check_interval: int,
        max_ingestion_time: Optional[int],
    ) -> None:
        """Poll the schedule status through the hook until it leaves ``Pending``."""
        self.hook.check_status(
            job_name=schedule_name,
            key="MonitoringScheduleStatus",
            describe_function=self.describe_monitoring_job,
            check_interval=check_interval,
            max_ingestion_time=max_ingestion_time,
            non_terminal_states={
                "Pending"
            },  # possible states: Pending, Failed, Scheduled, or Stopped
        )

    def create_monitoring_job(
        self,
        config: dict,
        wait_for_completion: bool = True,
        check_interval: int = 30,
        max_ingestion_time: Optional[int] = None,
    ):
        """Create a monitoring schedule from ``config``.

        :param config: Keyword arguments for ``create_monitoring_schedule``.
        :param wait_for_completion: Poll until the schedule leaves ``Pending``.
        :param check_interval: Polling interval forwarded to ``check_status``.
        :param max_ingestion_time: Polling limit forwarded to ``check_status``.
        :return: The ``create_monitoring_schedule`` response.
        """
        response = self.hook.get_conn().create_monitoring_schedule(**config)
        if wait_for_completion:
            self._wait_for_schedule(
                config["MonitoringScheduleName"], check_interval, max_ingestion_time
            )
        return response

    def update_monitoring_job(
        self,
        config: dict,
        wait_for_completion: bool = True,
        check_interval: int = 30,
        max_ingestion_time: Optional[int] = None,
    ):
        """Update an existing monitoring schedule from ``config``.

        Only the members accepted by ``update_monitoring_schedule`` are sent,
        so a configuration written for ``create`` (for example one carrying
        ``Tags``) can be reused unchanged.

        :param config: Monitoring schedule configuration.
        :param wait_for_completion: Poll until the schedule leaves ``Pending``.
        :param check_interval: Polling interval forwarded to ``check_status``.
        :param max_ingestion_time: Polling limit forwarded to ``check_status``.
        :return: The ``update_monitoring_schedule`` response.
        """
        update_request = {
            key: config[key] for key in self._UPDATE_REQUEST_KEYS if key in config
        }
        response = self.hook.get_conn().update_monitoring_schedule(**update_request)
        if wait_for_completion:
            self._wait_for_schedule(
                config["MonitoringScheduleName"], check_interval, max_ingestion_time
            )
        return response

    def describe_monitoring_job(self, name: str) -> dict:
        """Return the ``describe_monitoring_schedule`` response for ``name``."""
        return self.hook.get_conn().describe_monitoring_schedule(
            MonitoringScheduleName=name
        )
