from dataclasses import dataclass
from typing import Any, Dict, Optional, Type

from twitch_airflow_components.operators.sagemaker_training import SageMakerTrainingOperator

from conductor.internal.dag_utils import deep_update
from conductor.operators.operator_iface import IConfiguredOperator, TaskWrapper
from conductor.types.model import Model
from conductor.utils.naming import timestamp_name


@dataclass
class SageMakerTrainingOutputs:
    """Values a SageMaker training task exposes to whoever consumes its result."""

    # Where the trained model artifacts live. In this module the field is
    # filled with an Airflow template expression (not a literal URL) that
    # reads the location from the training task's XCom.
    s3_url: str

class ConfiguredSageMakerTrainingOperator(IConfiguredOperator):
    """Creates a SageMaker training task pre-populated from project/DAG resources."""

    def generate_tasks(
        self,
        *,
        task_id: str,
        model_cls: Type[Model],
        config: Optional[Dict[str, Any]] = None,
    ) -> TaskWrapper[SageMakerTrainingOperator, SageMakerTrainingOutputs]:
        """Build the training operator for ``model_cls`` together with its outputs.

        A default training-job request is assembled from the project and DAG
        resources, a ``VpcConfig`` section is added when the environment
        defines a VPC, and finally ``config`` (if given) is merged on top via
        ``deep_update`` so callers can override any default.

        Args:
            task_id: Airflow task id; also used in the job name, the S3 output
                prefix and the XCom lookup of the produced artifacts.
            model_cls: Model class whose module and class name are passed to
                the training container through environment variables.
            config: Optional overrides merged into the default request.

        Returns:
            A ``TaskWrapper`` pairing the ``SageMakerTrainingOperator`` with a
            ``SageMakerTrainingOutputs`` whose ``s3_url`` is a template
            expression pulling the model artifact location from XCom.
        """
        project = self.project_resources
        env_name = project.env_name
        environment = project.env
        dag_resources = self.dag_resources

        artifacts_prefix = dag_resources.s3_url_for_path([task_id])
        # NOTE(review): 63 is presumably the maximum name length enforced by
        # timestamp_name -- confirm against that helper.
        training_job_name = timestamp_name(
            f"{project.name}-{task_id}-{dag_resources.name}", 63
        )

        resource_config = {
            "InstanceCount": 1,
            "InstanceType": "ml.t3.medium",
            "VolumeSizeInGB": 1,
        }
        algorithm_spec = {
            "TrainingImage": project.ecr_url(),
            "TrainingInputMode": "File",
        }
        execution_role = project.sagemaker_execution_role()
        # Environment variables handed to the training container: which model
        # class to load plus the conductor environment / git provenance.
        container_env = {
            "MODEL_CLS_MODULE": model_cls.__module__,
            "MODEL_CLS_NAME": model_cls.__name__,
            "CONDUCTOR_ENV": env_name,
            "CONDUCTOR_GIT_BRANCH": project.branch,
            "CONDUCTOR_COMMIT_HASH": project.commit_hash,
            "AWS_DEFAULT_REGION": environment.default_region,
        }

        training_request: Dict[str, Any] = {
            "TrainingJobName": training_job_name,
            "ResourceConfig": resource_config,
            "AlgorithmSpecification": algorithm_spec,
            "RoleArn": execution_role,
            "OutputDataConfig": {"S3OutputPath": artifacts_prefix},
            "Environment": container_env,
            "StoppingCondition": {"MaxRuntimeInSeconds": 86400},  # One day.
            "EnableNetworkIsolation": False,
        }

        # Template expression evaluated at run time: it reads the artifact
        # location out of the value this task pushes to XCom.
        artifacts_template = "".join(
            (
                "{{ti.xcom_pull(task_ids='",
                task_id,
                "')['Training']['ModelArtifacts']['S3ModelArtifacts']}}",
            )
        )

        vpc = environment.vpc
        if vpc is not None:
            training_request["VpcConfig"] = {
                "SecurityGroupIds": vpc.security_groups,
                "Subnets": vpc.subnets,
            }

        # Caller overrides are applied last so they win over every default,
        # including the VPC section added above.
        if config is not None:
            deep_update(training_request, config)

        training_operator = SageMakerTrainingOperator(
            config=training_request,
            task_id=task_id,
            aws_conn_id=None,
            dag=self.dag,
        )
        outputs = SageMakerTrainingOutputs(s3_url=artifacts_template)
        return TaskWrapper[SageMakerTrainingOperator, SageMakerTrainingOutputs](
            training_operator,
            outputs,
        )
