Last active
February 19, 2023 07:45
-
-
Save DavidKatz-il/e2caf17285f8ef2d4dd6e70beb8186b0 to your computer and use it in GitHub Desktop.
API client for airflow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass | |
from typing import Any, List, Optional, Dict, Tuple, Union | |
from urllib.parse import urlencode | |
import urllib3 | |
import json | |
LIMIT = 10_000 # the default on airflow api is 100 | |
@dataclass
class Configuration:
    """
    Connection settings for an airflow api client.

    Attributes:
        host: base URL of the airflow REST API (e.g. "http://host:8080/api/v1")
        username: basic-auth username
        password: basic-auth password
    """

    host: str
    username: str
    password: str

    def auth_settings(self) -> dict:
        """
        Build the auth settings dict used by the airflow api client.

        :return: mapping of auth scheme name to header description; contains a
            "Basic" entry when both username and password are set
        """
        if self.username is None or self.password is None:
            return {}
        basic_header = urllib3.util.make_headers(
            basic_auth=self.username + ":" + self.password
        ).get("authorization")
        return {
            "Basic": {
                "type": "basic",
                "in": "header",
                "key": "Authorization",
                "value": basic_header,
            }
        }
class AirflowClientAPI:
    """
    API client for the airflow REST API.

    Sends JSON requests with basic-auth headers built from the supplied
    ``Configuration``.
    """

    def __init__(self, configuration: Configuration):
        """
        Initialize an AirflowClientAPI instance
        :param configuration: airflow configurations (host, username, password)
        """
        self.configuration = configuration
        self.pool_manager = urllib3.PoolManager()
        self.default_headers = {
            "Content-Type": "application/json",
            "User-Agent": "OpenAPI-Generator/2.4.0/python",
        }

    def _request(
        self,
        method,
        url,
        headers=None,
        query_params=None,
        fields=None,
        body=None,
        auth="Basic",
    ) -> Any:
        """
        Send a request to the airflow api and return the decoded JSON response.

        :param method: HTTP method, e.g. "GET"
        :param url: path relative to the configured host, e.g. "dags"
        :param headers: extra headers; take precedence over the defaults
        :param query_params: mapping of query parameters (values may be lists)
        :param fields: form fields forwarded to urllib3
        :param body: JSON-serializable request body
        :param auth: name of the auth scheme to apply (default "Basic")
        :return: the JSON-decoded response data
        :raises Exception: when the response status is not a 2xx success
        """
        headers = headers or {}
        fields = fields or {}
        if query_params:
            query_params = self.__parameters_to_tuples(query_params)
            url += "?" + urlencode(query_params)
        body = json.dumps(body) if body else body
        # Fix: guard against an unknown auth scheme; the original crashed with
        # ``TypeError: 'NoneType' object is not subscriptable``.
        auth_setting = self.configuration.auth_settings().get(auth)
        if auth_setting:
            headers[auth_setting["key"]] = auth_setting["value"]
        # Fix: defaults first so caller-supplied headers win; the original
        # ``headers.update(self.default_headers)`` clobbered caller headers.
        headers = {**self.default_headers, **headers}
        response = self.pool_manager.request(
            method,
            f"{self.configuration.host}/{url}",
            fields=fields,
            headers=headers,
            body=body,
        )
        # Fix: accept any 2xx success status, not only exactly 200.
        if not 200 <= response.status < 300:
            raise Exception(f"request failed with status: {response.status}")
        return json.loads(response.data)

    @staticmethod
    def __parameters_to_tuples(
        params: Dict[str, Union[int, str, list, tuple]]
    ) -> List[Tuple]:
        """
        Flatten a params mapping into (key, value) tuples for urlencode;
        list/tuple values become one tuple per element.

        :param params: query parameter mapping
        :return: list of (key, value) tuples
        :raises Exception: when a value has an unsupported type
        """
        new_params = []
        for k, v in params.items():
            if isinstance(v, (list, tuple)):
                new_params.extend((k, value) for value in v)
            elif isinstance(v, (int, str)):
                new_params.append((k, v))
            else:
                raise Exception(
                    f"key: {k} has a type: {type(v)} that is not supported."
                )
        return new_params

    def _get_dags(
        self,
        dag_id_pattern: Optional[str] = None,
        tags: Optional[List[str]] = None,
        only_active: Optional[bool] = True,
    ) -> List[Dict]:
        """
        Get all dags info
        :param dag_id_pattern: the dag_id pattern
        :param tags: tags to filter on
        :param only_active: filter on only_active or not
        :return: all dags info
        """
        # Fix: request up to LIMIT dags; the airflow api defaults to a limit
        # of 100, which silently truncated the "all dags" result.
        query_params = {"limit": LIMIT, "only_active": only_active}
        if dag_id_pattern:
            query_params["dag_id_pattern"] = dag_id_pattern
        if tags:
            query_params["tags"] = tags
        return self._request("GET", "dags", query_params=query_params)["dags"]

    def get_all_dag_ids(
        self,
        dag_id_pattern: Optional[str] = None,
        tags: Optional[List[str]] = None,
        only_active: Optional[bool] = True,
    ) -> List[str]:
        """
        Get all dag_id's
        :param dag_id_pattern: the dag_id pattern
        :param tags: tags to filter on
        :param only_active: filter on only_active or not
        :return: all matched dag_id's
        """
        dags = self._get_dags(dag_id_pattern, tags, only_active)
        return [dag["dag_id"] for dag in dags]

    def _get_dag_runs(self, dag_id: str) -> List[Dict]:
        """
        Get all dag runs of a given dag_id
        :param dag_id: the dag_id
        :return: all dag runs
        """
        # Fix: same truncation issue as _get_dags — ask for up to LIMIT runs.
        return self._request(
            "GET", f"dags/{dag_id}/dagRuns", query_params={"limit": LIMIT}
        )["dag_runs"]

    def get_last_dag_run_id(self, dag_id: str) -> str:
        """
        Get the last dag_run_id of a given dag_id
        :param dag_id: the dag_id
        :return: the dag_run_id of the most recently started run
        :raises ValueError: if the dag has no runs with a start_date
        """
        dag_runs = self._get_dag_runs(dag_id)
        # Fix: queued runs have start_date=None, which breaks the max()
        # comparison; also give a clear error instead of a bare ValueError
        # from max() on an empty sequence. ISO-8601 start_date strings
        # compare correctly lexicographically (as the original relied on).
        started = [run for run in dag_runs if run.get("start_date")]
        if not started:
            raise ValueError(f"no started dag runs found for dag_id: {dag_id}")
        return max(started, key=lambda dag: dag["start_date"])["dag_run_id"]

    def get_task_instances(
        self, dag_id: str, dag_run_id: str, state: Optional[List[str]] = None
    ) -> List[Dict]:
        """
        Get a list of task instances
        :param dag_id: the dag_id
        :param dag_run_id: the dag_run_id
        :param state: list of state to filter
        :return: list of task instances
        """
        query_params = {"limit": LIMIT}
        if state:
            query_params["state"] = state
        response_data = self._request(
            "GET",
            f"dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances",
            query_params=query_params,
        )
        return response_data["task_instances"]

    def unpause_dags(
        self, dag_id_pattern: str, tags: Optional[List[str]] = None
    ) -> dict:
        """
        Unpause all dags matching a dag_id pattern (and optional tags)
        :param dag_id_pattern: dag_id pattern
        :param tags: list of tags to filter
        :return: the request response data
        """
        body = {"is_paused": False}
        query_params = {"update_mask": "is_paused", "dag_id_pattern": dag_id_pattern}
        if tags:
            query_params["tags"] = tags
        return self._request("PATCH", "dags", query_params=query_params, body=body)

    def trigger_dag(
        self, dag_id: str, conf: Optional[Dict] = None
    ) -> Dict[str, Union[int, List[Dict]]]:
        """
        Trigger a dag
        :param dag_id: the dag_id
        :param conf: conf to run the dag with
        :return: the request response data
        """
        # NOTE(review): some airflow versions reject a dagRuns POST with no
        # body at all; if that happens, send an empty JSON object instead —
        # confirm against the target server before changing.
        body = None
        if conf:
            body = {"conf": conf}
        return self._request("POST", f"dags/{dag_id}/dagRuns", body=body)

    def clear_tasks(
        self,
        dag_id: str,
        dag_run_id: str,
        task_ids: List[str],
        include_downstream: Optional[bool] = True,
    ) -> Dict[str, List[Dict]]:
        """
        Clear the state of all the tasks for a specific dag run
        :param dag_id: the dag_id
        :param dag_run_id: the dag_run_id
        :param task_ids: list of task_id to clear
        :param include_downstream: to include downstream default is True
        :return: the request response data
        """
        body = {
            "dry_run": False,
            "dag_run_id": dag_run_id,
            "task_ids": task_ids,
            "only_failed": True,
            "include_downstream": include_downstream,
        }
        return self._request("POST", f"dags/{dag_id}/clearTaskInstances", body=body)

    def extract_task_id(self, ld: List[Dict]) -> List[str]:
        """
        Extract the task ids from a list of dicts with the key "task_id"
        :param ld: list of dicts
        :return: list of task_id
        """
        return self.__get_values_from_list_of_dicts(ld, "task_id")

    @staticmethod
    def __get_values_from_list_of_dicts(ld: List[Dict], key: str) -> List[str]:
        """
        Get all key values from a list of dicts
        :param ld: list of dicts
        :param key: key name
        :return: list of the values
        """
        # Idiom fix: list comprehension instead of list(generator) (C400).
        return [d[key] for d in ld]
def main():
    """Example usage: unpause a dag, trigger it, then clear its failed tasks."""
    config = Configuration(
        host="http://<<URL>>:8080/api/v1",
        username="<<USERNAME>>",
        password="<<PASSWORD>>",
    )
    client = AirflowClientAPI(config)

    dag_id = "<<DAG-ID>>"
    dag_run_id = "<<DAG-RUN-ID>>"

    # unpause a dag
    client.unpause_dags(dag_id)

    # trigger a dag with conf
    client.trigger_dag(dag_id, conf={"key": "value"})

    # clear failed tasks with downstream
    failed_tasks = client.get_task_instances(dag_id, dag_run_id, state=["failed"])
    failed_task_ids = client.extract_task_id(failed_tasks)
    client.clear_tasks(dag_id, dag_run_id, failed_task_ids)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment