from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, Optional

from twitch_airflow_components.operators.sagemaker_processing import SageMakerProcessingOperator

from conductor.docker import CONTAINER_ROOT
from conductor.internal.dag_utils import deep_update
from conductor.operators.operator_iface import IConfiguredOperator, TaskWrapper
from conductor.utils.naming import timestamp_name


@dataclass
class ProcessingEntrypoint:
    """Pairs a callable with the keyword arguments recorded for it.

    Instances are stored per task id by
    ``ConfiguredSageMakerProcessingOperator.generate_tasks``.
    NOTE(review): presumably the container's ``process`` entrypoint looks the
    record up by task id and invokes ``entrypoint(**kwargs)`` -- confirm
    against the consumer of ``task_container_params``.
    """

    # Callable registered for the task.
    entrypoint: Callable
    # Extra keyword arguments captured when the task was generated.
    kwargs: Dict[str, Any]


class ConfiguredSageMakerProcessingOperator(IConfiguredOperator):
    """``IConfiguredOperator`` that emits ``SageMakerProcessingOperator`` tasks."""

    def generate_tasks(
        self,
        *,
        task_id: str,
        entrypoint: Callable,
        config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> TaskWrapper[SageMakerProcessingOperator, None]:
        """Register ``entrypoint`` for ``task_id`` and build its processing task.

        Args:
            task_id: Id of the generated task; also used as the key under which
                the entrypoint is recorded, as part of the processing job name,
                and as the second container argument.
            entrypoint: Callable recorded in ``self.task_container_params``.
            config: Optional overrides applied on top of the default job
                configuration through ``deep_update``.
            **kwargs: Stored next to ``entrypoint`` in the recorded
                ``ProcessingEntrypoint``.

        Returns:
            A ``TaskWrapper`` holding the constructed
            ``SageMakerProcessingOperator`` and ``None``.
        """
        environment_name = self.project_resources.env_name
        environment = self.project_resources.env

        # Record what this task is associated with before building the job.
        registration = ProcessingEntrypoint(entrypoint=entrypoint, kwargs=kwargs)
        self.task_container_params[task_id] = registration

        # Default job description; each piece is built in the order the final
        # mapping lists it.
        cluster_defaults = {
            "InstanceCount": 1,
            "InstanceType": "ml.t3.medium",
            "VolumeSizeInGB": 1,
        }
        # NOTE(review): 63 is presumably the maximum job-name length -- confirm
        # against timestamp_name's signature.
        job_name = timestamp_name(
            f"{self.project_resources.name}-{task_id}-{self.dag_resources.name}",
            63,
        )
        app_specification = {
            "ImageUri": self.project_resources.ecr_url(),
            "ContainerEntrypoint": ["process"],
            "ContainerArguments": [self.dag_resources.name, task_id],
        }
        container_environment = {
            "CONDUCTOR_ENV": environment_name,
            "CONDUCTOR_GIT_BRANCH": self.project_resources.branch,
            "CONDUCTOR_COMMIT_HASH": self.project_resources.commit_hash,
            "CDK_OUTDIR": str(Path(CONTAINER_ROOT, "cdk")),
            "AWS_DEFAULT_REGION": environment.default_region,
        }
        execution_role = self.project_resources.sagemaker_execution_role()

        job_config: Dict[str, Any] = {
            "ProcessingResources": {"ClusterConfig": cluster_defaults},
            "ProcessingJobName": job_name,
            "AppSpecification": app_specification,
            "Environment": container_environment,
            "RoleArn": execution_role,
        }

        # Network settings are only attached when the environment defines a VPC.
        if environment.vpc is not None:
            job_config["NetworkConfig"] = {
                "VpcConfig": {
                    "SecurityGroupIds": environment.vpc.security_groups,
                    "Subnets": environment.vpc.subnets,
                }
            }

        # Caller-supplied values are applied over the defaults.
        if config is not None:
            deep_update(job_config, config)

        operator = SageMakerProcessingOperator(
            config=job_config, task_id=task_id, aws_conn_id=None, dag=self.dag
        )
        return TaskWrapper[SageMakerProcessingOperator, None](operator, None)
