from dataclasses import dataclass
from typing import Any, Dict, Optional

from airflow.providers.amazon.aws.operators.sagemaker_transform import SageMakerTransformOperator

from conductor.internal.dag_utils import deep_update
from conductor.operators.operator_iface import IConfiguredOperator, TaskWrapper
from conductor.utils.naming import timestamp_name


@dataclass()
class SageMakerTransformOutputs:
    """Values a SageMaker transform task exposes to downstream tasks."""

    # S3 location reported as the transform job's output path.
    s3_url: str


class ConfiguredSageMakerTransformOperator(IConfiguredOperator):
    """Factory for SageMaker batch-transform tasks bound to this DAG's resources."""

    def generate_tasks(
        self,
        *,
        task_id: str,
        model_name: str,
        config: Optional[Dict[str, Any]] = None,
    ) -> TaskWrapper[SageMakerTransformOperator, SageMakerTransformOutputs]:
        """Create a ``SageMakerTransformOperator`` and describe its outputs.

        Args:
            task_id: Airflow task id. It is also used as the path component for
                the default S3 output URL and as the prefix of the transform
                job name.
            model_name: Value placed in the job config's ``ModelName``.
            config: Optional overrides merged on top of the default job config
                with ``deep_update``. NOTE(review): the defaults contain no
                ``TransformInput`` entry. SageMaker's CreateTransformJob
                requires one, so presumably callers always supply it here.
                Confirm.

        Returns:
            A ``TaskWrapper`` pairing the operator with a
            ``SageMakerTransformOutputs``. Its ``s3_url`` is the
            ``TransformOutput.S3OutputPath`` in the final, merged config. That
            is the default DAG-resource URL unless ``config`` overrides it.
        """
        default_output_s3_url = self.dag_resources.s3_url_for_path([task_id])
        base_config: Dict[str, Any] = {
            # 63 is presumably the length cap for timestamp_name; it matches
            # SageMaker's TransformJobName limit. Confirm against
            # timestamp_name.
            "TransformJobName": timestamp_name(
                f"{task_id}-{self.dag_resources.name}", 63
            ),
            "ModelName": model_name,
            "TransformResources": {"InstanceType": "ml.m4.xlarge", "InstanceCount": 1},
            "TransformOutput": {
                "S3OutputPath": default_output_s3_url,
                "AssembleWith": "Line",
                "Accept": "text/csv",
            },
            "MaxPayloadInMB": 4,
        }
        if config is not None:
            # deep_update's return value is ignored, so it is relied on to
            # modify base_config in place.
            deep_update(base_config, config)

        # Report where the job will actually write. ``config`` may override
        # TransformOutput.S3OutputPath, and returning the default URL would
        # then point downstream tasks at the wrong prefix.
        transform_output = base_config.get("TransformOutput")
        output_s3_url = (
            transform_output.get("S3OutputPath", default_output_s3_url)
            if isinstance(transform_output, dict)
            else default_output_s3_url
        )

        return TaskWrapper[SageMakerTransformOperator, SageMakerTransformOutputs](
            SageMakerTransformOperator(
                # NOTE(review): per the Airflow AWS hook docs, aws_conn_id=None
                # presumably falls back to boto3's default credential chain.
                config=base_config, task_id=task_id, aws_conn_id=None, dag=self.dag
            ),
            SageMakerTransformOutputs(output_s3_url),
        )
