from abc import ABC
from typing import Any, Dict, Generic, Sequence, TypeVar, Union

from airflow.models.dag import DAG
from airflow.models.taskmixin import TaskMixin
from conductor_cdk import DAGResources

from conductor.config import ProjectResources


class IConfiguredOperator(ABC):
    """
    Abstract base for operators that are configured against one specific DAG.

    Holds the project-wide resources, the resources of the target DAG, the task container
    parameters and the DAG itself. At construction time it checks that ``dag_resources`` and
    ``dag`` refer to the same ``dag_id``.
    """

    project_resources: ProjectResources
    dag_resources: DAGResources
    # NOTE(review): presumably parameters for the container the task runs in — confirm with callers.
    task_container_params: Dict[str, Any]
    dag: DAG

    def __init__(
        self,
        project_resources: ProjectResources,
        dag_resources: DAGResources,
        task_container_params: Dict[str, Any],
        dag: DAG,
    ) -> None:
        """
        :param project_resources: Project-wide resources, stored as-is.
        :param dag_resources: Resources of the DAG this operator is configured for.
        :param task_container_params: Task container parameters, stored as-is.
        :param dag: The Airflow DAG this operator is configured for; its ``dag_id`` must equal
            ``dag_resources.dag_id()``.
        :raises ValueError: if ``dag_resources.dag_id()`` differs from ``dag.dag_id``.
        """
        # Validate explicitly rather than with ``assert``: assertions are removed under
        # ``python -O``, which would silently accept a mismatched DAG/resources pair.
        resources_dag_id = dag_resources.dag_id()
        if resources_dag_id != dag.dag_id:
            raise ValueError(
                f"dag_resources.dag_id() ({resources_dag_id!r}) does not match "
                f"dag.dag_id ({dag.dag_id!r})"
            )
        self.project_resources = project_resources
        self.dag_resources = dag_resources
        self.task_container_params = task_container_params
        self.dag = dag


# Type of the object wrapped by a TaskWrapper; restricted to TaskMixin subclasses so the
# wrapper can delegate roots/leaves/set_upstream/set_downstream to it.
TTask = TypeVar("TTask", bound=TaskMixin)
# Type of the typed metadata (e.g. return values) a TaskWrapper exposes to downstream operators.
TOutputs = TypeVar("TOutputs")


class TaskWrapper(ABC, TaskMixin, Generic[TTask, TOutputs]):
    """
    TaskWrapper is a pass-through wrapper for task objects that can be extended by IConfiguredOperators
    to pass return values and other metadata to subsequent operators. This allows us to type check return values
    that do not exist on the underlying task.

    The dependency-graph operations (``roots``, ``leaves``, ``set_upstream``, ``set_downstream`` and
    ``update_relative``) are all delegated to the wrapped ``task``. A wrapper can therefore be used in
    ``>>`` / ``<<`` chains anywhere the task itself could.
    """

    task: TTask
    outputs: TOutputs

    def __init__(self, task: TTask, outputs: TOutputs):
        """
        :param task: The underlying task object; all dependency wiring is forwarded to it.
        :param outputs: Typed metadata exposed to subsequent operators.
        """
        self.task = task
        self.outputs = outputs

    def __repr__(self) -> str:
        return f"{type(self).__name__}(task={self.task!r}, outputs={self.outputs!r})"

    @property
    def roots(self):
        """Return the wrapped task's root operators (List[BaseOperator])."""
        return self.task.roots

    @property
    def leaves(self):
        """Return the wrapped task's leaf operators (List[BaseOperator])."""
        return self.task.leaves

    def set_upstream(self, other: Union[TaskMixin, Sequence[TaskMixin]], *args: Any, **kwargs: Any) -> None:
        """
        Set a task or a task list to be directly upstream from the current task.

        Any extra arguments are forwarded unchanged to the wrapped task. Examples are
        ``edge_modifier`` on Airflow versions whose ``set_upstream`` accepts it.
        """
        self.task.set_upstream(other, *args, **kwargs)

    def set_downstream(self, other: Union[TaskMixin, Sequence[TaskMixin]], *args: Any, **kwargs: Any) -> None:
        """
        Set a task or a task list to be directly downstream from the current task.

        Any extra arguments are forwarded unchanged to the wrapped task. Examples are
        ``edge_modifier`` on Airflow versions whose ``set_downstream`` accepts it.
        """
        self.task.set_downstream(other, *args, **kwargs)

    def update_relative(self, other: TaskMixin, *args: Any, **kwargs: Any) -> None:
        """
        Forward ``update_relative`` to the wrapped task.

        Airflow invokes ``update_relative`` on the non-initiating side of a dependency, for example
        on this wrapper for ``op >> wrapper``. Without this override, the call would hit the
        inherited ``TaskMixin`` implementation instead of the wrapped task, so a wrapped object
        that customises it (such as a TaskGroup) would never be told about the new edge.
        Extra arguments (e.g. ``upstream`` or ``edge_modifier``, depending on the Airflow version)
        are passed through unchanged.
        """
        self.task.update_relative(other, *args, **kwargs)
