Airflow operator to create Kubernetes Jobs in a separate GKE cluster
# # Installation
#
# 1. Copy this gkejoboperator.py into your dag folder.
#
# 2. For this custom operator to be used in the composer env, we need to install these python modules:
#
#    // requirements.txt
#    kubernetes==11.0.0
#    pyyaml==5.3.1
#
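#    For example, on Cloud Composer these can typically be installed straight from the requirements
#    file with the gcloud CLI (the environment name and location below are placeholders):
#
#       gcloud composer environments update my-composer-env \
#           --location europe-west1 \
#           --update-pypi-packages-from-file requirements.txt
#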
# 3. Add a connection in airflow for the operator to use when connecting to the gke cluster.
#
#    In Airflow:
#    Go to Admin -> Connections -> Create
#
#    Enter the following details:
#
#    Conn Id: my-gke-connection
#    Conn Type: google cloud platform
#    Project Id: mygcpproject
#    Keyfile Json: Contents of the service-account.json with access to the target k8s cluster
#
#
# -- Now you are ready to use the operator in your dags
#
# # -----------------------------------
# # Usage in DAG files
# # -----------------------------------
#
# from gkejoboperator import GKEJobOperator
#
# dag = DAG(
#     'my_dag',
#     default_args=default_args,
#     catchup=False,
#     max_active_runs=1,
#     schedule_interval="0 15 * * 1-5",
# )
#
# my_job = GKEJobOperator(
#     dag=dag,
#     task_id='my_task_id',
#     gcp_conn_id='my-gke-connection',
#     cluster_name='mygkeclustername',
#     project_id='mygcpproject',
#     location='europe-west1',
#     job_yaml='''
#     apiVersion: batch/v1
#     kind: Job
#     metadata:
#       labels:
#         app: myapp
#         variant: somejob
#     spec:
#       ttlSecondsAfterFinished: 120
#       backoffLimit: 0
#       activeDeadlineSeconds: 3600
#       template:
#         metadata:
#           labels:
#             app: myapp
#             variant: somejob
#         spec:
#           restartPolicy: "Never"
#           containers:
#             - name: my-pod
#               image: eu.gcr.io/mygcpproject/myapp:latest
#               imagePullPolicy: Always
#               command:
#                 - "node"
#                 - "/app/dist/scripts/somescript.js"
#               envFrom:
#                 - secretRef:
#                     name: my-secrets
#                 - configMapRef:
#                     name: my-config
#     ''',
# )
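#
# Since `job_yaml` is a template field, it is rendered with Jinja before the job is created, so
# Airflow macros can be embedded in it. A minimal sketch, assuming a hypothetical script that
# accepts the execution date as an argument:
#
#     job_yaml='''
#     apiVersion: batch/v1
#     kind: Job
#     spec:
#       template:
#         spec:
#           containers:
#             - name: my-pod
#               image: eu.gcr.io/mygcpproject/myapp:latest
#               command: ["node", "/app/dist/scripts/somescript.js", "--date={{ ds }}"]
#     ''',
#
# The remaining Job fields (namespace, labels, restartPolicy, etc.) are filled in with defaults by
# `_parse_job_yaml` below.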
import logging
import os
import json
import re
import subprocess
import tempfile
import time
import unicodedata
from typing import Optional

import yaml
from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
from kubernetes import client, config
from kubernetes.client.rest import ApiException as K8sApiException
from urllib3.exceptions import ProtocolError

def slugify(text):
    """
    Converts to lowercase and replaces runs of non-word characters and
    underscores with hyphens.
    """
    text = unicodedata.normalize("NFKD", text).lower()
    return re.sub(r"[\W_]+", "-", text)
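# Illustrative only (the input value is hypothetical):
#   slugify("job-My_DAG-Some Task")  ->  "job-my-dag-some-task"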

def serialize_labels(label_obj):
    labels = []
    for key in label_obj:
        labels.append(f"{key}={label_obj[key]}")
    return ",".join(labels)

class GKEJobOperator(BaseOperator):
    """
    Executes a task as a Kubernetes Job in the specified Google Kubernetes Engine cluster.

    This operator assumes that the system has gcloud installed and that an Airflow connection id
    is configured with a service account that has access to the GKE cluster you want to connect to.

    The **minimum** required arguments are
    ``task_id``, ``gcp_conn_id``, ``cluster_name``, ``job_yaml``.
    The **optional** arguments are
    ``project_id``, ``location``.
    Note that ``project_id`` is required if it is not specified in ``gcp_conn_id``.

    :param gcp_conn_id: The Google Cloud connection id to use. This allows users to specify a service account.
    :type gcp_conn_id: str
    :param cluster_name: The name of the Google Kubernetes Engine cluster the job should be spawned in
    :type cluster_name: str
    :param job_yaml: The complete job yaml to deploy on the cluster.
    :type job_yaml: str
    :param project_id: GCP project id
    :type project_id: str
    :param location: The location of the cluster, e.g. europe-west1-b
    :type location: str
    """

    template_fields = ["job_yaml"]
    template_fields_renderers = {"job_yaml": "yaml"}

    @apply_defaults
    def __init__(self,
                 gcp_conn_id: str,
                 cluster_name: str,
                 job_yaml: str,
                 project_id: Optional[str] = None,
                 location: Optional[str] = None,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        # internal
        self.k8s_job_metadata = None  # Assigned after the job is created in k8s
        self.k8s_job_labels = {"dag_id": self.dag_id, "task_id": self.task_id}
        self.k8s_client = None  # Assigned after `execute` is called
        self.k8s_namespace: Optional[str] = None  # Assigned after `execute` is called
        # external
        self.gcp_conn_id = gcp_conn_id
        self.cluster_name = cluster_name
        self.job_yaml = job_yaml
        self.project_id = project_id
        self.location = "europe-west1-b" if location is None else location

    def execute(self, context):
        job_definition = self._parse_job_yaml(self.job_yaml)
        self.k8s_namespace = job_definition["metadata"]["namespace"]
        self.k8s_client = self._get_k8s_client()
        k8s_batch_client = self.k8s_client.BatchV1Api()

        job_response = k8s_batch_client.create_namespaced_job(self.k8s_namespace, job_definition)
        logging.info("Job created in K8s:")
        logging.info(f"{job_response.metadata}")
        self.k8s_job_metadata = job_response.metadata

        job_status = self._wait_for_job_to_end()
        self._get_job_logs()
        self._cleanup()

        if job_status is None or job_status.succeeded is None:
            # Mark the task as failed
            raise Exception("The job failed to complete")

    def on_kill(self):
        """
        Called when the task is killed, e.g. by manually marking it as failed or successful.
        """
        logging.info("The DAG job was killed!")
        self._cleanup()
        return super().on_kill()

    def _cleanup(self):
        logging.info("Cleaning up...")
        try:
            k8s_batch_client = self.k8s_client.BatchV1Api()
            job_name = self.k8s_job_metadata.name
            delete_options_current_job = self.k8s_client.V1DeleteOptions(
                propagation_policy="Foreground",
                grace_period_seconds=0
            )
            # Delete the current job
            k8s_batch_client.delete_namespaced_job(name=job_name, body=delete_options_current_job,
                                                   namespace=self.k8s_namespace)
            # Also clean up any stale, successfully finished jobs that might exist for this dag+task.
            # Failed jobs have to be cleaned up manually!
            label_selector = serialize_labels(self.k8s_job_labels)
            field_selector = "status.successful=1"
            k8s_batch_client.delete_collection_namespaced_job(namespace=self.k8s_namespace,
                                                              field_selector=field_selector,
                                                              grace_period_seconds=0,
                                                              propagation_policy="Foreground",
                                                              label_selector=label_selector)
        except K8sApiException as e:
            logging.warning(f"Error while cleaning up {e}")

    def _parse_job_yaml(self, raw_yaml: str):
        job_yaml = yaml.safe_load(raw_yaml)

        job_yaml.setdefault("metadata", {})
        # We do not want to set the job name,
        # since it may cause conflicts when this operator is used across multiple dags / tasks.
        # So we use generateName instead.
        job_name_prefix = job_yaml["metadata"].pop("name",
                                                   slugify(f"job-{self.dag_id}-{self.task_id}"))
        job_yaml["metadata"].setdefault("generateName", f"{job_name_prefix}-")
        job_yaml["metadata"].setdefault("labels", {})
        job_yaml["metadata"]["labels"].update(self.k8s_job_labels)
        job_yaml["metadata"].setdefault("namespace", "default")
        job_yaml["metadata"].setdefault("finalizers", [])
        if "foregroundDeletion" not in set(job_yaml["metadata"]["finalizers"]):
            job_yaml["metadata"]["finalizers"].append("foregroundDeletion")

        job_yaml.setdefault("spec", {})
        # ttlSecondsAfterFinished is still an alpha feature, hence unavailable in GKE clusters by default at the moment
        job_yaml["spec"].setdefault("ttlSecondsAfterFinished", 120)
        job_yaml["spec"].setdefault("backoffLimit", 0)
        # The default deadline for the job is 5 hours
        job_yaml["spec"].setdefault("activeDeadlineSeconds", 60 * 60 * 5)
        job_yaml["spec"].setdefault("template", {})
        job_yaml["spec"]["template"].setdefault("metadata", {})
        job_yaml["spec"]["template"]["metadata"].setdefault("labels", {})
        job_yaml["spec"]["template"]["metadata"]["labels"].update(self.k8s_job_labels)
        job_yaml["spec"]["template"].setdefault("spec", {})
        job_yaml["spec"]["template"]["spec"].setdefault("restartPolicy", "Never")

        logging.info("Job YAML:")
        logging.info(json.dumps(job_yaml, indent=2))
        return job_yaml
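
    # A minimal sketch of the defaults _parse_job_yaml applies (dag_id/task_id values are hypothetical).
    # Given only:
    #   apiVersion: batch/v1
    #   kind: Job
    #   spec:
    #     template:
    #       spec:
    #         containers:
    #           - name: my-pod
    #             image: eu.gcr.io/mygcpproject/myapp:latest
    # the parsed job gains, among other fields:
    #   metadata.generateName: "job-my-dag-my-task-id-"
    #   metadata.namespace: "default"
    #   metadata.finalizers: ["foregroundDeletion"]
    #   metadata.labels and spec.template.metadata.labels: {dag_id: my_dag, task_id: my_task_id}
    #   spec.ttlSecondsAfterFinished: 120, spec.backoffLimit: 0, spec.activeDeadlineSeconds: 18000
    #   spec.template.spec.restartPolicy: "Never"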

    def _get_job_logs(self):
        job_name = self.k8s_job_metadata.name
        k8s_core_client = self.k8s_client.CoreV1Api()
        job_label_selector = f"job-name={job_name}"
        try:
            pod_response = k8s_core_client.list_namespaced_pod(namespace=self.k8s_namespace,
                                                               label_selector=job_label_selector)
            for item in pod_response.items:
                pod_name = item.metadata.name
                try:
                    # For whatever reason the response contains only the first few characters unless
                    # the call is made with `_return_http_data_only=True, _preload_content=False`
                    pod_log_response = k8s_core_client.read_namespaced_pod_log(
                        name=pod_name,
                        namespace=self.k8s_namespace,
                        _return_http_data_only=True,
                        _preload_content=False,
                        timestamps=True
                    )
                    pod_log = pod_log_response.data.decode("utf-8")
                    logging.info(f"Logs for {pod_name}:")
                    logging.info(pod_log)
                except K8sApiException:
                    logging.warning(f"Exception when reading log for {pod_name}")
        except K8sApiException as e:
            logging.warning(f"Found exception while listing pods for the job {e}")

    def _wait_for_job_to_end(self):
        k8s_batch_client = self.k8s_client.BatchV1Api()
        job_name = self.k8s_job_metadata.name
        job_status = None
        logging.info("Waiting for the job to finish...")
        try:
            while True:
                try:
                    job = k8s_batch_client.read_namespaced_job(namespace=self.k8s_namespace,
                                                               name=job_name)
                    job_status = job.status
                    if job.status.active is None and job.status.start_time is not None:
                        logging.info(f"Job status for K8s job {job_name}: {job.status}")
                        break
                except ProtocolError:
                    logging.warning("Ignoring ProtocolError and Continuing...")
                time.sleep(5)
        except K8sApiException as e:
            logging.warning(f"Error while reading status {e}")
        return job_status

    def _get_k8s_client(self):
        gcp = GoogleCloudBaseHook(gcp_conn_id=self.gcp_conn_id)
        gcp_service_account_path = gcp._get_field("key_path", False)
        gcp_service_account_json = ""
        if gcp_service_account_path:
            with open(gcp_service_account_path) as f:
                gcp_service_account_json = f.read()
        else:
            gcp_service_account_json = gcp._get_field("keyfile_dict", False)

        self.project_id = gcp.project_id if self.project_id is None else self.project_id

        with tempfile.NamedTemporaryFile("w+", suffix=".json",
                                         encoding="utf8") as gcloud_service_account_file:
            gcloud_service_account_file.write(gcp_service_account_json)
            gcloud_service_account_file.seek(0)
            with tempfile.NamedTemporaryFile("w+") as kube_config_file:
                kube_config_file.seek(0)
                custom_env = os.environ.copy()
                custom_env["KUBECONFIG"] = kube_config_file.name
                custom_env["GOOGLE_APPLICATION_CREDENTIALS"] = gcloud_service_account_file.name
                subprocess.check_call(
                    ["gcloud", "auth", "activate-service-account",
                     "--key-file", gcloud_service_account_file.name]
                )
                subprocess.check_call(
                    ["gcloud", "container", "clusters", "get-credentials",
                     self.cluster_name,
                     "--region", self.location,
                     "--project", self.project_id
                     ],
                    env=custom_env
                )
                # Tell the kubernetes client where the generated kubeconfig file is located
                config.load_kube_config(kube_config_file.name)
                return client
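
# To try the operator outside of the scheduler, the standard Airflow CLI test command can be used
# (dag/task ids and the execution date are the hypothetical ones from the usage example above):
#   airflow test my_dag my_task_id 2021-01-01         (Airflow 1.x)
#   airflow tasks test my_dag my_task_id 2021-01-01   (Airflow 2.x)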