[Airflow] 코드 리팩토링 2. TaskFactory

DAGFactory를 만들려다 보니 자연스럽게 TaskFactory에 대한 필요성을 느꼈습니다. Medium에서 찾아본 DAGFactory에 대한 포스트들도 대부분 TaskFactory도 같이 만들던데, 직접 구현해보니 의식의 흐름과 같이 TaskFactory가 필요함을 알 수 있었습니다.

코드 : https://github.com/eprj453/airflow-factory


TaskFactory

Task는 DAG의 구성요소입니다. Task의 종류가 전부 다르기 때문에 다양한 형태의 Task를 집어넣어도 결과물인 DAG를 만들어내는 DAGFactory를 만들었습니다.

그렇다면 Task는 어떤 식으로 만들어야할까요? Task는 Operator라는 Airflow가 제공해주는 클래스의 인스턴스입니다. Task와 Operator가 같은 것이냐?하고 누군가 물어본다면 저는 Task는 논리적인 단위, Operator는 물리적이고 직접적으로 구현한 객체라고 답할 것 같습니다. 관점에 따라 같다고 볼 수도 있겠으나 다른 개념으로 보는 것이 맞다고 생각합니다.

어쨌든 Task는 Operator 클래스로 만듭니다. 여기서 문제는? Operator 클래스 또한 그 종류가 매우 많다는 것입니다.

https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/operators/index.html


모든 Operator를 DAG마다 선언해준다면 아래와 같이 작성해야 합니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
task_1 = PythonOperator(
	arg_1=....,
  arg_2=.....
  ...
  ...
)

task_2 = BashOperator(
	arg_1=....,
  arg_2=.....
  ...
  ...
)

task_3 = MyCustomOperator(
	arg_1=....,
  arg_2=.....
  ...
  ...
)
...

dag = DAGFactory(tasks = [task_1, task_2, task_3])

이렇게 인스턴스를 쫙 나열하면 일어날 수 있는 문제는 무엇이 있을까요?

  • airflow 버전을 업데이트 하려고 한다. 그런데 모든 Operator마다 필수 인자가 하나씩 생긴다면?
  • 모든 DAG 스크립트마다 필요한 Operator 클래스를 전부 import 해야한다. 만약 모듈 경로를 수정해야 한다면?

물론 하나의 .py 파일에 여러 개의 DAG 인스턴스를 만드는 것도 가능합니다. 그러나 DAG는 사용자가 생각하는 Job의 단위가 되고 이를 하나의 .py 파일에 여러 개 선언하는 것은 새로운 Job을 생성하거나 관리하는 측면에서 좋은 방법은 아닌 것 같습니다.

하나의 .py 파일에 DAG 인스턴스 여러 개를 생성한다면, 위와 같은 인스턴스 직접 선언 방식은 코드의 규모가 커질수록 엄청난 재앙을 불러올 수 있습니다.

TaskFactory가 필요하다는 공감대는 충분히 느꼈으니, 이제 만들어보겠습니다. TaskFactory가 Operator 클래스의 인스턴스를 만들려면 필수로 가지고 알아야 하는 정보는 다음과 같습니다.

  • Operator 클래스 종류
  • Argument 정보

이에 TaskFactory가 받을 정보를 담고 있는 AirflowTask 클래스를 만들었습니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class AirflowTask:
    """
    class AirflowTask

    Task의 구성 요소인 Operator, base_args를 인자로 받아 클래스로 가져갑니다.
    Task object로 미리 만들어놓지 않는 이유는 Task object에 DAG를 나중에 지정할 수 없기 때문입니다.

    이렇게 만들어진 AirflowTask Object는 dag_factory에 전달되며,
    DAGFactory의 클래스 메서드인 _add_task_to_dag 메서드에서 비로소 Task Object로 DAG에 할당됩니다.

    """
    def __init__(self, operator, base_args):
        self.operator = operator
        self.base_args = base_args

물론 클래스로 명시해 하나의 객체로 가져가는게 개발자의 오류도 줄일 수 있고 직관적이기도 하지만, 이렇게밖에 할 수 없는 이유가 있습니다. Operator 인스턴스는 필수 인자로 DAG 인스턴스를 가지고 있어야 하기 때문입니다.

말로는 조금 어려우니 그림을 보겠습니다. 제가 최초에 원했던 workflow는 아래와 같습니다.

image

여기서 문제가 생깁니다. TaskFactory는 Task를 만들어서 DAGFactory에 넘겨줘야 하는데, Task를 만들기 위해서는 아직 참조하지도 못하는 DAG 인스턴스를 갖고 있어야 합니다. 그렇기 때문에 Operator의 종류와 인자만을 따로 AirflowTask로 묶어 DAGFactory로 전달한 뒤, Task와 DAG 생성 모두 DAGFactory에서 하게 됩니다.

결국 코드는 이런 형태로 동작하게 됩니다.

image


Operator 선택은 Enum class를 사용하려고 했으나, Enum의 반환 객체로 클래스를 선택할 수 없어 부득이하게 OperatorSelector 클래스를 하나 더 만들었습니다. class method를 사용해 Enum과 유사한 형태로 사용할 수 있도록 했습니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class OperatorSelector:
    @classmethod
    def PYTHON(cls):
        return PythonOperator

    @classmethod
    def BASH(cls):
        return BashOperator

    @classmethod
    def S3_TO_REDSHIFT(cls):
        return S3ToRedshiftOperator
	
 
operator_1 = OperatorSelector.PYTHON()
operator_2 = OperatorSelector.S3_TO_REDSHIFT()
...
...

이제, TaskFactory에 아래와 같이 인스턴스를 전달해주면 TaskFactory는 DAGFactory를 통해 DAG 인스턴스를 만들 것입니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

task_1 = AirflowTask(
    operator=OperatorSelector.PYTHON(),
    base_args=PythonOperatorArgs(
        task_id="print_hello",
        python_callable=custom_function.print_hello,
        op_kwargs={
            "name": "jack"
        },
    )
)

task_2 = AirflowTask(
    operator=OperatorSelector.PYTHON(),
    base_args=PythonOperatorArgs(
        task_id="print_hello",
        python_callable=custom_function.print_hello,
        op_kwargs={
            "name": "sam"
        },
    )
)

...
...
...


TaskFactory 코드 전문

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32

class TaskFactory:
    """
    class TaskFactory
    DAG Instance와 AirflowTask Object 혹은 list(AirflowTask Object)로
    1 unit의 Task / Tasks를 만드는 클래스입니다.

    """
    @classmethod
    def generate_task(
            cls,
            dag: DAG,
            airflow_tasks: Union[AirflowTask, List[AirflowTask]]) \
            -> Union[BaseOperator, List[BaseOperator]]:

        if type(airflow_tasks) != list:  # 단일 Task
            task = airflow_tasks.operator(
                dag=dag,
                **airflow_tasks.base_args.to_args()
            )
            print(task.task_id)
            return task

        else:
            tasks = []
            for airflow_task in airflow_tasks:
                task = airflow_task.operator(
                    dag=dag,
                    **airflow_task.base_args.to_args()
                )
                tasks.append(task)
            return tasks

먼 길을 돌아 DAGFactory로 다시 돌아왔습니다. 이제 DAGFactory를 완성할 차례입니다.


DAGFactory

DAGFactory의 동작과정은 다음과 같습니다.

  • 필요한 것 : DAG 인스턴스 정보, Task 정보 (AirflowTask)

  • 동작 과정 :

    1. DAG 인스턴스 생성
    2. Task 생성
    3. DAG 인스턴스 반환


create_dag()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
@classmethod
    def _create_dag(cls,
                    dag_id: str,
                    default_args: Dict[str, Any],
                    schedule_interval: str,
                    **kwargs,
                    ) -> DAG:
        essential_default_args = {'start_date', 'retry'}

        if essential_default_args < set(default_args.keys()):
            logging.warning(f"""
            default_args 필수 인자가 없습니다. default 값으로 생성됩니다.
            {essential_default_args.difference(set(default_args.keys()))}
            """)

        DEFAULT_ARGS = {
            'owner': 'datadev-parksw2',
            'depends_on_past': False,  # 이전 task가 성공해야만 다음 task가 실행된다.
            'start_date': datetime(2021, 1, 1, 9, 0, 0),
            'email_on_failure': False,
            'email_on_retry': False,
            'retry': 1,
            'retry_delay': timedelta(minutes=5),
        }

        DEFAULT_ARGS.update(default_args)

        DAG_ARGS = {
            'default_args': DEFAULT_ARGS,
            'schedule_interval': schedule_interval,
            "max_active_runs": 10,
            "catchup": True
        }

        DAG_ARGS.update(kwargs)

        dag = DAG(dag_id=dag_id, **DAG_ARGS)
        return dag

필수 인자는 DEFAULT_ARGS로 명시해놓고 사용자로부터 받은 정보만 업데이트합니다. 이를 기반으로 dag 인스턴스를 생성한 뒤 반환합니다.

add_task_to_dags()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
@classmethod
    def _add_tasks_to_dag(cls, dag: DAG, tasks: List[AirflowTask]) -> DAG:
        # task generate in here
        dependency_tasks = []
        n = len(dependency_tasks)
        for task in tasks:
            t = TaskFactory.generate_task(dag=dag, airflow_tasks=task)

            if dependency_tasks:
                upstream_task = dependency_tasks[-1]
                upstream_task.set_downstream(t)
                # upstream_task >> t
            dependency_tasks.append(t)

        return dag

TaskFactory에서 Task를 만든 뒤 DAG 인스턴스를 반환합니다.

generate_dags()

위 2가지 메서드가 private하게 실행되도록 해주는 public 메서드입니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
@classmethod
    def generate_dag(cls,
                     dag_id: str,
                     default_args: Dict[str, Any],
                     schedule_interval: str,
                     tasks: Union[List[AirflowTask]],
                     ) -> DAG:

        """


        """

        dag = cls._create_dag(
            dag_id=dag_id,
            default_args=default_args,
            schedule_interval=schedule_interval,
        )
        dag = cls._add_tasks_to_dag(dag=dag, tasks=tasks)
        return dag


DAGFactory 코드 전문

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from airflow import DAG

from datetime import datetime, timedelta
import logging
from typing import Any, Dict, Union, List

from components.airflow_task import AirflowTask
from commons.factory.task_factory import TaskFactory


class DAGFactory:

    @classmethod
    def _create_dag(cls,
                    dag_id: str,
                    default_args: Dict[str, Any],
                    schedule_interval: str,
                    **kwargs,
                    ) -> DAG:
        essential_default_args = {'start_date', 'retry'}

        if essential_default_args < set(default_args.keys()):
            logging.warning(f"""
            default_args 필수 인자가 없습니다. default 값으로 생성됩니다.
            {essential_default_args.difference(set(default_args.keys()))}
            """)

        DEFAULT_ARGS = {
            'owner': 'datadev-parksw2',
            'depends_on_past': False,  # 이전 task가 성공해야만 다음 task가 실행된다.
            'start_date': datetime(2021, 1, 1, 9, 0, 0),
            'email_on_failure': False,
            'email_on_retry': False,
            'retry': 1,
            'retry_delay': timedelta(minutes=5),
        }

        DEFAULT_ARGS.update(default_args)

        DAG_ARGS = {
            'default_args': DEFAULT_ARGS,
            'schedule_interval': schedule_interval,
            "max_active_runs": 10,
            "catchup": True
        }

        DAG_ARGS.update(kwargs)

        dag = DAG(dag_id=dag_id, **DAG_ARGS)
        return dag

    @classmethod
    def _add_tasks_to_dag(cls, dag: DAG, tasks: List[AirflowTask]) -> DAG:
        # task generate in here
        dependency_tasks = []
        n = len(dependency_tasks)
        for task in tasks:
            t = TaskFactory.generate_task(dag=dag, airflow_tasks=task)

            if dependency_tasks:
                upstream_task = dependency_tasks[-1]
                upstream_task.set_downstream(t)
                # upstream_task >> t
            dependency_tasks.append(t)

        return dag

    @classmethod
    def generate_dag(cls,
                     dag_id: str,
                     default_args: Dict[str, Any],
                     schedule_interval: str,
                     tasks: Union[List[AirflowTask]],
                     ) -> DAG:

        dag = cls._create_dag(
            dag_id=dag_id,
            default_args=default_args,
            schedule_interval=schedule_interval,
        )
        dag = cls._add_tasks_to_dag(dag=dag, tasks=tasks)
        return dag

이 외에도 각 인스턴스가 사용하는 기본 인자도 클래스로 구현해놓기는 했으나 생략하겠습니다.


결론

회사에 입사한 이후로 업무 담당은 혼자서 해왔기 때문에 혼자 일을 하는 것은 어느정도 적응이 되었지만, 스스로 짰던 코드를 다시 들춰보고 리팩토링하는 건 이번이 처음이었습니다. 여태까지 했던 어떤 작업보다 방향성을 잡기 어려웠고, 작동시키는데에 급급해 과거의 내가 짜놓았던 코드를 마주하자니 부끄럽기도 했습니다.

다른 팀 개발자들을 통해 비슷하게라도 코드리뷰를 많이 진행하고 싶었는데, 다른 팀 귀찮게 하는 것도 하루이틀이고.. 재택근무가 장기화되면서 코드리뷰를 거의 못했습니다. 아직은 스스로에게 무엇이 부족한지조차 파악하기 어려운 1년차 병아리 개발자라서 그 점이 제일 아쉽습니다.

참고

https://medium.com/towards-data-science/how-to-build-a-dag-factory-on-airflow-9a19ab84084c

https://towardsdatascience.com/data-engineers-shouldnt-write-airflow-dags-b885d57737ce