from dataclasses import dataclass
from typing import Any, Dict, Optional

from twitch_airflow_components.operators.sagemaker_model_monitoring_schedule import SageMakerMonitoringScheduleOperator

from conductor.internal.dag_utils import deep_update
from conductor.operators.operator_iface import IConfiguredOperator, TaskWrapper


@dataclass
class SageMakerMonitoringScheduleOutputs:
    """Values a monitoring-schedule task exposes to the rest of the DAG."""

    # Name of the SageMaker monitoring schedule the task manages.
    schedule_name: str
    # Where the monitoring reports are written. The producer in this module
    # stores a Jinja template string here (resolved from the task's XCom when
    # rendered), not a literal URL.
    report_s3_url: str


class ConfiguredSageMakerMonitoringScheduleOperator(IConfiguredOperator):
    """Configure a ``SageMakerMonitoringScheduleOperator`` for an endpoint.

    Builds the monitoring-schedule request (``MonitoringType`` "DataQuality")
    from a handful of keyword arguments plus project/DAG defaults, and wraps
    the resulting operator together with its outputs.
    """

    @staticmethod
    def _default_schedule_name(endpoint_name: str, task_id: str) -> str:
        """Return the default schedule name for an endpoint/task pair.

        The name depends only on its inputs, so it is stable across repeated
        DAG runs: subsequent runs of the operator update the monitor created
        in the first run instead of creating a new one.
        """
        suffix = "-default-monitor"
        # Length cap enforced on the generated name; SageMaker documents a
        # 63-character maximum for MonitoringScheduleName.
        max_len = 63
        prefix = f"{endpoint_name}-{task_id}"
        # Slicing is a no-op when the prefix already fits. The suffix is always
        # kept intact; only the prefix is truncated.
        return prefix[: max_len - len(suffix)] + suffix

    def generate_tasks(
        self,
        *,
        task_id: str,
        endpoint_name: str,
        statistics_json_url: str,
        constraints_json_url: str,
        instance_count: int = 1,
        instance_type: str = "ml.m5.large",
        volume_size_in_gb: int = 30,
        schedule_expression: str = "cron(0 * ? * * *)",
        report_s3_url: Optional[str] = None,
        schedule_name: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
    ) -> TaskWrapper[
        SageMakerMonitoringScheduleOperator, SageMakerMonitoringScheduleOutputs
    ]:
        """Create the monitoring-schedule task and describe its outputs.

        Args:
            task_id: Airflow task id; also used in the default schedule name
                and the default report location.
            endpoint_name: Endpoint whose traffic is monitored.
            statistics_json_url: S3 URI of the baseline statistics resource.
            constraints_json_url: S3 URI of the baseline constraints resource.
            instance_count: Number of instances in the monitoring cluster.
            instance_type: Instance type of the monitoring cluster.
            volume_size_in_gb: Volume size, in GB, of the monitoring cluster.
            schedule_expression: Schedule expression for the monitor.
            report_s3_url: S3 URI for monitoring output. When ``None`` it is
                obtained from ``self.dag_resources.s3_url_for_path([task_id])``.
            schedule_name: Name of the schedule. When ``None`` a stable name is
                derived from ``endpoint_name`` and ``task_id``.
            config: Optional overrides deep-merged over the generated request.

        Returns:
            A ``TaskWrapper`` holding the operator and a
            ``SageMakerMonitoringScheduleOutputs`` whose ``report_s3_url`` is a
            Jinja template that pulls the report URI from this task's XCom.
        """
        if schedule_name is None:
            schedule_name = self._default_schedule_name(endpoint_name, task_id)

        if report_s3_url is None:
            report_s3_url = self.dag_resources.s3_url_for_path([task_id])

        base_config: Dict[str, Any] = {
            "MonitoringScheduleName": schedule_name,
            "MonitoringScheduleConfig": {
                "ScheduleConfig": {
                    "ScheduleExpression": schedule_expression,
                },
                "MonitoringJobDefinition": {
                    "BaselineConfig": {
                        "ConstraintsResource": {
                            "S3Uri": constraints_json_url,
                        },
                        "StatisticsResource": {
                            "S3Uri": statistics_json_url,
                        },
                    },
                    "MonitoringInputs": [
                        {
                            "EndpointInput": {
                                "EndpointName": endpoint_name,
                                "LocalPath": "/opt/ml/processing/input/dataset",
                            }
                        },
                    ],
                    "MonitoringOutputConfig": {
                        "MonitoringOutputs": [
                            {
                                "S3Output": {
                                    "S3Uri": report_s3_url,
                                    "LocalPath": "/opt/ml/processing/output/dataset",
                                }
                            },
                        ],
                    },
                    "MonitoringResources": {
                        "ClusterConfig": {
                            "InstanceCount": instance_count,
                            "InstanceType": instance_type,
                            "VolumeSizeInGB": volume_size_in_gb,
                        }
                    },
                    "MonitoringAppSpecification": {
                        # default container for monitoring job
                        # NOTE(review): the image URI is pinned to us-west-2;
                        # other regions need an override through `config`.
                        "ImageUri": "159807026194.dkr.ecr.us-west-2.amazonaws.com/sagemaker-model-monitor-analyzer",
                    },
                    "RoleArn": self.project_resources.sagemaker_execution_role(),
                },
                "MonitoringType": "DataQuality",
            },
        }

        if config is not None:
            deep_update(base_config, config)

        # `config` may have replaced the schedule name; report the name that is
        # actually sent to the operator rather than the pre-merge value.
        effective_schedule_name = base_config.get(
            "MonitoringScheduleName", schedule_name
        )

        # Jinja template resolved when rendered: it reads the report URI back
        # from this task's XCom, so it reflects any `config` override of the
        # output location. No trailing whitespace, otherwise the rendered
        # S3 URI would end with a stray blank.
        response_s3_url = (
            "{{ti.xcom_pull(task_ids='"
            + task_id
            + "')['MonitorScheduling']['MonitoringScheduleConfig']['MonitoringJobDefinition']"
            + "['MonitoringOutputConfig']['MonitoringOutputs'][0]['S3Output']['S3Uri']}}"
        )

        operator = SageMakerMonitoringScheduleOperator(
            config=base_config,
            task_id=task_id,
            aws_conn_id=None,
            dag=self.dag,
        )
        outputs = SageMakerMonitoringScheduleOutputs(
            schedule_name=effective_schedule_name,
            report_s3_url=response_s3_url,
        )
        return TaskWrapper[
            SageMakerMonitoringScheduleOperator, SageMakerMonitoringScheduleOutputs
        ](operator, outputs)
