Skip to content

Instantly share code, notes, and snippets.

@Taragolis
Created December 23, 2022 16:15
Show Gist options
  • Save Taragolis/d4ade059c1966a773a4ee7a1d29800d9 to your computer and use it in GitHub Desktop.
Save Taragolis/d4ade059c1966a773a4ee7a1d29800d9 to your computer and use it in GitHub Desktop.
PoC: Consistency in boto3-based Hooks
from __future__ import annotations
import json
import warnings
from typing import Generic, TypeVar, TYPE_CHECKING, Union
import boto3
from botocore.config import Config
from airflow.compat.functools import cached_property
from airflow.hooks.base import BaseHook
from airflow.utils.helpers import exactly_one
from airflow.exceptions import AirflowException
# Type variables for the concrete boto3 client / service-resource classes a hook exposes
# (for example the ``mypy_boto3_*`` stub types used in the TYPE_CHECKING block).
# They are deliberately unbounded: ``boto3.client`` and ``boto3.resource`` are factory
# functions, not types, so they cannot be used as a TypeVar ``bound``.
Boto3Client = TypeVar("Boto3Client")
Boto3Resource = TypeVar("Boto3Resource")
class _BaseBoto3Hook(BaseHook):
    """
    Base class for boto3-based hooks.

    Stores the connection options shared by boto3 hooks and exposes a cached
    ``boto3.session.Session``. Subclasses provide :attr:`client` / :attr:`resource` and,
    for legacy mode, :attr:`_legacy_mode_conn`.

    :param aws_conn_id: Airflow connection id. NOTE(review): stored but not yet used to
        build :attr:`session` in this proof of concept.
    :param region_name: AWS region name passed to the boto3 session.
    :param api_version: API version stored for subclasses creating clients/resources.
    :param endpoint_url: Endpoint URL stored for subclasses creating clients/resources.
    :param verify: SSL verification setting stored for subclasses.
    :param use_ssl: SSL usage flag stored for subclasses.
    :param config: botocore ``Config`` stored for subclasses.
    :param legacy_mode: When True, :attr:`conn` emits a ``FutureWarning`` and returns
        :attr:`_legacy_mode_conn`; when False it returns :attr:`session`.
    """

    conn_name_attr = "aws_conn_id"
    default_conn_name = "aws_default"
    conn_type = "aws"
    hook_name = "Amazon Web Services"

    def __init__(
        self,
        *,
        aws_conn_id: str | None = default_conn_name,
        region_name: str | None = None,
        api_version: str | None = None,
        endpoint_url: str | None = None,
        verify: bool | str | None = None,
        use_ssl: bool | None = None,
        config: Config | None = None,
        legacy_mode: bool = True,
    ):
        super().__init__()
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        self.api_version = api_version
        self.endpoint_url = endpoint_url
        self.verify = verify
        self.use_ssl = use_ssl
        self.config = config
        self._legacy_mode = legacy_mode

    @cached_property
    def session(self) -> boto3.session.Session:
        """Get the underlying boto3.session.Session (cached)."""
        # NOTE(review): only ``region_name`` is applied; ``aws_conn_id`` credentials are not
        # resolved here yet.
        return boto3.session.Session(region_name=self.region_name)

    @property
    def client(self) -> Boto3Client:
        """Get the underlying boto3.client"""
        raise NotImplementedError()

    @property
    def resource(self) -> Boto3Resource:
        """Get the underlying boto3.resource"""
        raise NotImplementedError()

    @cached_property
    def conn(self):
        """
        Get the "connection object" based on hook mode and previous implementation (cached).

        If ``legacy_mode`` is False, this property returns :attr:`session`.
        Otherwise it emits a ``FutureWarning`` and returns :attr:`_legacy_mode_conn`.
        """
        if not self._legacy_mode:
            return self.session
        warnings.warn(
            f"`conn` property accessed in legacy mode for {self}, "
            f"please consider use one of appropriate properties directly: `session`, `client`, `resource`.",
            FutureWarning,
            # Point the warning at the code that accessed ``conn``, not at this property.
            stacklevel=3,
        )
        return self._legacy_mode_conn

    @property
    def _legacy_mode_conn(self):
        """
        Object returned by :attr:`conn` in legacy mode.

        Subclasses that support legacy mode must override this property.

        :raises NotImplementedError: when the subclass does not implement legacy mode.
        """
        raise NotImplementedError(f"Legacy mode for `conn` property not implemented in {self}.")

    def get_conn(self):
        """Return :attr:`conn` (kept for compatibility with the ``BaseHook`` API)."""
        return self.conn

    def test_connection(self):
        """
        Test the AWS connection by calling the AWS STS (Security Token Service) GetCallerIdentity API.

        .. seealso::
            https://docs.aws.amazon.com/STS/latest/APIReference/API_GetCallerIdentity.html

        :return: ``(True, "key=value, ...")`` describing the caller identity, credentials
            method and region on success; otherwise ``(False, <error description>)``.
        """
        try:
            conn_info = self.session.client("sts").get_caller_identity()
            metadata = conn_info.pop("ResponseMetadata", {})
            if metadata.get("HTTPStatusCode") != 200:
                try:
                    return False, json.dumps(metadata)
                except TypeError:
                    # Fall back to str() when the metadata is not JSON-serializable.
                    return False, str(metadata)
            conn_info["credentials_method"] = self.session.get_credentials().method
            conn_info["region_name"] = self.session.region_name
            return True, ", ".join(f"{k}={v!r}" for k, v in conn_info.items())
        except Exception as e:
            # Report any failure as a (False, message) result instead of propagating it.
            return False, f"{type(e).__name__!r} error occurred while testing connection: {e}"

    def __repr__(self):
        return f"{type(self).__module__}.{type(self).__name__}"
class Boto3ResourceHook(_BaseBoto3Hook, Generic[Boto3Client, Boto3Resource]):
    """
    Hook for a boto3 service that provides a ``boto3.resource`` (and its client).

    The resource is created from :attr:`session`; :attr:`client` is the low-level client
    attached to that resource (``resource.meta.client``).

    :param service_name: boto3 service name, e.g. ``"s3"``.

    Remaining arguments are forwarded to the base hook constructor.
    """

    def __init__(self, service_name: str, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.service_name = service_name

    @cached_property
    def resource(self) -> Boto3Resource:
        """Get the underlying ``boto3.resource`` for :attr:`service_name` (cached)."""
        resource_kwargs = {
            "service_name": self.service_name,
            "api_version": self.api_version,
            "endpoint_url": self.endpoint_url,
            "verify": self.verify,
            "config": self.config,
        }
        # boto3 defaults ``use_ssl`` to True. An explicit None is falsy to botocore and could
        # select a plain-HTTP endpoint, so only forward it when it was actually set.
        if self.use_ssl is not None:
            resource_kwargs["use_ssl"] = self.use_ssl
        return self.session.resource(**resource_kwargs)  # type: ignore

    @cached_property
    def client(self) -> Boto3Client:
        """Get the ``boto3.client`` attached to :attr:`resource` (cached)."""
        return self.resource.meta.client

    @property
    def _legacy_mode_conn(self):
        """Object returned by ``conn`` in legacy mode: the service resource."""
        return self.resource

    def __repr__(self):
        return f"{super().__repr__()}[service_name={self.service_name!r}]"
class Boto3ClientHook(_BaseBoto3Hook, Generic[Boto3Client]):
    """
    Hook for a boto3 service accessed through a low-level ``boto3.client``.

    :param service_name: boto3 service name, e.g. ``"batch"``.

    Remaining arguments are forwarded to the base hook constructor.
    """

    def __init__(self, service_name: str, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.service_name = service_name

    @cached_property
    def client(self) -> Boto3Client:
        """Get the underlying ``boto3.client`` for :attr:`service_name` (cached)."""
        client_kwargs = {
            "service_name": self.service_name,
            "api_version": self.api_version,
            "endpoint_url": self.endpoint_url,
            "verify": self.verify,
            "config": self.config,
        }
        # boto3 defaults ``use_ssl`` to True. An explicit None is falsy to botocore and could
        # select a plain-HTTP endpoint, so only forward it when it was actually set.
        if self.use_ssl is not None:
            client_kwargs["use_ssl"] = self.use_ssl
        return self.session.client(**client_kwargs)  # type: ignore

    @property
    def resource(self):
        """
        Always raise: client hooks do not provide ``boto3.resource``.

        :raises AirflowException: always; the message suggests ``Boto3ResourceHook`` when
            the session reports an available resource for :attr:`service_name`.
        """
        msg = f"{self} does not provide `boto3.resource(service_name={self.service_name!r})`."
        if self.service_name in self.session.get_available_resources():
            msg += (
                f" Please consider to use {Boto3ResourceHook.__module__}."
                f"{Boto3ResourceHook.__name__}(service_name={self.service_name!r}) instead."
            )
        raise AirflowException(msg)

    @property
    def _legacy_mode_conn(self):
        """Object returned by ``conn`` in legacy mode: the service client."""
        return self.client

    def __repr__(self):
        return f"{super().__repr__()}[service_name={self.service_name!r}]"
class AwsBaseHook(_BaseBoto3Hook):
    """
    Deprecated hook kept for compatibility with the previous ``AwsBaseHook`` API.

    Exactly one of ``client_type`` / ``resource_type`` selects the boto3 service. The hook
    always runs in legacy mode, so ``conn`` / ``get_conn()`` return a ``boto3.client`` or a
    ``boto3.resource`` for that service. The :attr:`client` and :attr:`resource`
    properties are intentionally disabled.

    :param aws_conn_id: Airflow connection id.
    :param verify: SSL verification setting forwarded to boto3.
    :param region_name: AWS region name.
    :param client_type: boto3 service name to create a client for.
    :param resource_type: boto3 service name to create a resource for.
    :param config: botocore ``Config`` forwarded to boto3.
    :raises ValueError: unless exactly one of ``client_type`` / ``resource_type`` is given.
    """

    def __init__(
        self,
        aws_conn_id: str | None = _BaseBoto3Hook.default_conn_name,
        verify: bool | str | None = None,
        region_name: str | None = None,
        client_type: str | None = None,
        resource_type: str | None = None,
        config: Config | None = None,
    ):
        self.client_type = client_type
        self.resource_type = resource_type
        warnings.warn(
            f"This {type(self).__module__}.{type(self).__name__} hook deprecated and will be removed "
            f"in a future releases. Please use {Boto3ClientHook.__module__}.{Boto3ClientHook.__name__} or "
            f"{Boto3ResourceHook.__module__}.{Boto3ResourceHook.__name__} instead.",
            DeprecationWarning,
            stacklevel=3
        )
        if not exactly_one(client_type, resource_type):
            raise ValueError(
                f"Either client_type={client_type!r} or resource_type={resource_type!r} must be provided, not both."
            )
        self._service_name: str = self.client_type or self.resource_type
        self._is_resource = bool(resource_type)
        if TYPE_CHECKING:
            assert isinstance(self._service_name, str)
        super().__init__(
            aws_conn_id=aws_conn_id, region_name=region_name, verify=verify, config=config, legacy_mode=True
        )

    @property
    def _legacy_mode_conn(self):
        """
        Create the object returned by ``conn`` in legacy mode.

        Returns a ``boto3.resource`` when ``resource_type`` was given, otherwise a
        ``boto3.client``.
        """
        factory = self.session.resource if self._is_resource else self.session.client
        factory_kwargs = {
            "service_name": self._service_name,
            "api_version": self.api_version,
            "endpoint_url": self.endpoint_url,
            "verify": self.verify,
            "config": self.config,
        }
        # boto3 defaults ``use_ssl`` to True. An explicit None is falsy to botocore and could
        # select a plain-HTTP endpoint, so only forward it when it was actually set.
        if self.use_ssl is not None:
            factory_kwargs["use_ssl"] = self.use_ssl
        return factory(**factory_kwargs)  # type: ignore

    # Backward-compatible alias: the base class only consults ``_legacy_mode_conn``.
    _legacy_conn = _legacy_mode_conn

    @property
    def client(self) -> Boto3Client:
        """Disabled for this deprecated hook; always raises AirflowException."""
        raise AirflowException(
            f"This property disabled for {type(self).__module__}.{type(self).__name__}. "
            f"Consider to use {Boto3ClientHook.__module__}.{Boto3ClientHook.__name__} and subclasses."
        )

    @property
    def resource(self) -> Boto3Resource:
        """Disabled for this deprecated hook; always raises AirflowException."""
        raise AirflowException(
            f"This property disabled for {type(self).__module__}.{type(self).__name__}. "
            f"Consider to use {Boto3ResourceHook.__module__}.{Boto3ResourceHook.__name__} and subclasses."
        )

    def __repr__(self):
        return f"{super().__repr__()}[client_type={self.client_type!r}, resource_type={self.resource_type!r}]"
class AwsGenericHook(AwsBaseHook):
    """
    Generic variant of ``AwsBaseHook``.

    Currently adds no behaviour of its own. The constructor (``client_type`` /
    ``resource_type``), the deprecation warning and legacy ``conn`` handling are all inherited.
    """
if TYPE_CHECKING:
from mypy_boto3_s3.client import S3Client
from mypy_boto3_s3.service_resource import S3ServiceResource
from mypy_boto3_ec2.client import EC2Client
from mypy_boto3_ec2.service_resource import EC2ServiceResource
from mypy_boto3_batch.client import BatchClient
class S3Hook(Boto3ResourceHook["S3Client", "S3ServiceResource"]):
    """Resource hook bound to the ``s3`` boto3 service."""

    def __init__(self, **hook_kwargs):
        # Only the service is fixed here; every other option goes to the resource hook as-is.
        super().__init__(service_name="s3", **hook_kwargs)
class EC2Hook(Boto3ResourceHook["EC2Client", "EC2ServiceResource"]):
    """Resource hook bound to the ``ec2`` boto3 service."""

    def __init__(self, **hook_kwargs):
        # Only the service is fixed here; every other option goes to the resource hook as-is.
        super().__init__(service_name="ec2", **hook_kwargs)
class BatchHook(Boto3ClientHook["BatchClient"]):
    """Client hook bound to the ``batch`` boto3 service."""

    def __init__(self, **hook_kwargs):
        # Only the service is fixed here; every other option goes to the client hook as-is.
        super().__init__(service_name="batch", **hook_kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment