-
-
Save Taragolis/d4ade059c1966a773a4ee7a1d29800d9 to your computer and use it in GitHub Desktop.
PoC: Consistency in boto3-based Hooks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations

import json
import warnings
from typing import Generic, TypeVar, TYPE_CHECKING, Union

import boto3
from boto3.resources.base import ServiceResource
from botocore.client import BaseClient
from botocore.config import Config

from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.utils.helpers import exactly_one
# Type variables used to parametrize the generic hooks with the concrete
# client / service-resource classes (e.g. the ``mypy_boto3_*`` stubs).
# ``boto3.client`` and ``boto3.resource`` are factory *functions*, not classes,
# so they are not valid TypeVar bounds; bind to the common base classes instead.
Boto3Client = TypeVar("Boto3Client", bound=BaseClient)
Boto3Resource = TypeVar("Boto3Resource", bound=ServiceResource)
class _BaseBoto3Hook(BaseHook):
    """
    Common base class for the boto3-backed hooks.

    Stores the connection options on the instance and exposes a cached
    ``boto3.session.Session``. Subclasses are expected to provide ``client``,
    ``resource`` and ``_legacy_mode_conn``; here they all raise
    ``NotImplementedError``.

    All constructor arguments are keyword-only.

    :param aws_conn_id: connection id, stored as ``self.aws_conn_id``.
    :param region_name: region passed to the boto3 session.
    :param api_version: stored as ``self.api_version`` for subclasses to use.
    :param endpoint_url: stored as ``self.endpoint_url`` for subclasses to use.
    :param verify: stored as ``self.verify`` for subclasses to use.
    :param use_ssl: stored as ``self.use_ssl`` for subclasses to use.
    :param config: botocore ``Config``, stored as ``self.config``.
    :param legacy_mode: when true (the default) ``conn`` returns
        ``_legacy_mode_conn`` and emits a ``FutureWarning``; when false it
        returns the boto3 session.
    """

    conn_name_attr = "aws_conn_id"
    default_conn_name = "aws_default"
    conn_type = "aws"
    hook_name = "Amazon Web Services"

    def __init__(
        self,
        *,
        aws_conn_id: str | None = default_conn_name,
        region_name: str | None = None,
        api_version: str | None = None,
        endpoint_url: str | None = None,
        verify: bool | str | None = None,
        use_ssl: bool | None = None,
        config: Config | None = None,
        legacy_mode: bool = True,
    ):
        super().__init__()
        self._legacy_mode = legacy_mode
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        self.api_version = api_version
        self.endpoint_url = endpoint_url
        self.verify = verify
        self.use_ssl = use_ssl
        self.config = config

    @cached_property
    def session(self) -> boto3.session.Session:
        """Return the cached ``boto3.session.Session`` created with ``region_name``."""
        return boto3.session.Session(region_name=self.region_name)

    @property
    def client(self) -> Boto3Client:
        """Return the underlying boto3 client; must be provided by a subclass."""
        raise NotImplementedError()

    @property
    def resource(self) -> Boto3Resource:
        """Return the underlying boto3 resource; must be provided by a subclass."""
        raise NotImplementedError()

    @cached_property
    def conn(self):
        """
        Return the hook's "connection object" (cached).

        In legacy mode a ``FutureWarning`` is emitted and the value of
        ``_legacy_mode_conn`` is returned; otherwise the boto3 session is returned.
        """
        if self._legacy_mode:
            message = (
                f"`conn` property accessed in legacy mode for {self}, "
                f"please consider use one of appropriate properties directly: `session`, `client`, `resource`."
            )
            warnings.warn(message, FutureWarning, stacklevel=3)
            return self._legacy_mode_conn
        return self.session

    @property
    def _legacy_mode_conn(self):
        """Return the object ``conn`` yields in legacy mode; must be provided by a subclass."""
        raise NotImplementedError(f"Legacy mode for `conn` property not implemented in {self}.")

    def get_conn(self):
        """Return ``self.conn``."""
        return self.conn

    def test_connection(self):
        """
        Test the AWS connection by calling AWS STS (Security Token Service) GetCallerIdentity.

        :return: ``(True, description)`` on an HTTP 200 response, otherwise
            ``(False, details)``. Exceptions are caught and reported as a failure.

        .. seealso::
            https://docs.aws.amazon.com/STS/latest/APIReference/API_GetCallerIdentity.html
        """
        try:
            identity = self.session.client("sts").get_caller_identity()
            response_meta = identity.pop("ResponseMetadata", {})
            if response_meta.get("HTTPStatusCode") == 200:
                identity["credentials_method"] = self.session.get_credentials().method
                identity["region_name"] = self.session.region_name
                details = [f"{key}={value!r}" for key, value in identity.items()]
                return True, ", ".join(details)
            # Non-200 response: report the metadata, falling back to ``str``
            # when it is not JSON serialisable.
            try:
                failure = json.dumps(response_meta)
            except TypeError:
                failure = str(response_meta)
            return False, failure
        except Exception as err:
            return False, f"{type(err).__name__!r} error occurred while testing connection: {err}"

    def __repr__(self):
        cls = type(self)
        return f"{cls.__module__}.{cls.__name__}"
class Boto3ResourceHook(_BaseBoto3Hook, Generic[Boto3Client, Boto3Resource]):
    """
    Hook for AWS services accessed through a ``boto3`` service resource.

    :param service_name: name of the AWS service passed to ``Session.resource``.

    Any other arguments are forwarded unchanged to ``_BaseBoto3Hook``.
    """

    def __init__(self, service_name: str, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.service_name = service_name

    @cached_property
    def resource(self) -> Boto3Resource:
        """Create (once) and return the service resource for ``service_name``."""
        resource_kwargs = {
            "service_name": self.service_name,
            # Each option goes to its own parameter; previously the endpoint URL
            # was passed as ``api_version`` and ``endpoint_url`` was never forwarded.
            "api_version": self.api_version,
            "endpoint_url": self.endpoint_url,
            "verify": self.verify,
            "config": self.config,
        }
        # boto3 documents ``use_ssl`` as a boolean that defaults to True; forward
        # it only when explicitly set so that the library default applies.
        if self.use_ssl is not None:
            resource_kwargs["use_ssl"] = self.use_ssl
        return self.session.resource(**resource_kwargs)  # type: ignore

    @cached_property
    def client(self) -> Boto3Client:
        """Return the low-level client that backs ``resource``."""
        return self.resource.meta.client

    # NOTE(review): this override replaces the base-class ``conn``, so neither the
    # ``legacy_mode`` switch nor the legacy-mode warning applies to this hook and
    # ``_legacy_mode_conn`` below is never consulted. Kept as-is for backward
    # compatibility -- confirm whether that is intended.
    @cached_property
    def conn(self) -> Boto3Resource:
        """Return ``resource`` (cached)."""
        return self.resource

    @property
    def _legacy_mode_conn(self):
        """Return the object used as the legacy-mode connection: ``resource``."""
        return self.resource

    def __repr__(self):
        return f"{super().__repr__()}[service_name={self.service_name!r}]"
class Boto3ClientHook(_BaseBoto3Hook, Generic[Boto3Client]):
    """
    Hook for AWS services accessed through a low-level ``boto3`` client.

    :param service_name: name of the AWS service passed to ``Session.client``.

    Any other arguments are forwarded unchanged to ``_BaseBoto3Hook``.
    """

    def __init__(self, service_name: str, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.service_name = service_name

    @cached_property
    def client(self) -> Boto3Client:
        """Create (once) and return the client for ``service_name``."""
        client_kwargs = {
            "service_name": self.service_name,
            # Each option goes to its own parameter; previously the endpoint URL
            # was passed as ``api_version`` and ``endpoint_url`` was never forwarded.
            "api_version": self.api_version,
            "endpoint_url": self.endpoint_url,
            "verify": self.verify,
            "config": self.config,
        }
        # boto3 documents ``use_ssl`` as a boolean that defaults to True; forward
        # it only when explicitly set so that the library default applies.
        if self.use_ssl is not None:
            client_kwargs["use_ssl"] = self.use_ssl
        return self.session.client(**client_kwargs)  # type: ignore

    @property
    def resource(self):
        """
        Always raise: this hook does not provide a service resource.

        :raises AirflowException: on every access. When the session lists
            ``service_name`` among its available resources, the message also
            points to ``Boto3ResourceHook``.
        """
        msg = f"{self} does not provide `boto3.resource(service_name={self.service_name!r})`."
        if self.service_name in self.session.get_available_resources():
            msg += (
                f" Please consider to use {Boto3ResourceHook.__module__}."
                f"{Boto3ResourceHook.__name__}(service_name={self.service_name!r}) instead."
            )
        raise AirflowException(msg)

    @property
    def _legacy_mode_conn(self):
        """Return the object ``conn`` yields in legacy mode: ``client``."""
        return self.client

    def __repr__(self):
        return f"{super().__repr__()}[service_name={self.service_name!r}]"
class AwsBaseHook(_BaseBoto3Hook):
    """
    Deprecated hook selecting a boto3 client or resource by type name.

    Instantiation emits a ``DeprecationWarning``. The hook always runs in legacy
    mode, so ``conn`` / ``get_conn()`` return the client or resource, while the
    ``client`` and ``resource`` properties are disabled.

    :param aws_conn_id: connection id, forwarded to ``_BaseBoto3Hook``.
    :param verify: forwarded to ``_BaseBoto3Hook``.
    :param region_name: forwarded to ``_BaseBoto3Hook``.
    :param client_type: service name for which a boto3 client is created.
    :param resource_type: service name for which a boto3 resource is created.
    :param config: botocore ``Config``, forwarded to ``_BaseBoto3Hook``.
    :raises ValueError: unless exactly one of ``client_type`` / ``resource_type``
        is provided.
    """

    def __init__(
        self,
        aws_conn_id: str | None = _BaseBoto3Hook.default_conn_name,
        verify: bool | str | None = None,
        region_name: str | None = None,
        client_type: str | None = None,
        resource_type: str | None = None,
        config: Config | None = None,
    ):
        self.client_type = client_type
        self.resource_type = resource_type

        warnings.warn(
            f"This {type(self).__module__}.{type(self).__name__} hook deprecated and will be removed "
            f"in a future releases. Please use {Boto3ClientHook.__module__}.{Boto3ClientHook.__name__} or "
            f"{Boto3ResourceHook.__module__}.{Boto3ResourceHook.__name__} instead.",
            DeprecationWarning,
            stacklevel=3,
        )

        if not exactly_one(client_type, resource_type):
            raise ValueError(
                f"Either client_type={client_type!r} or resource_type={resource_type!r} "
                f"must be provided, not both."
            )

        self._service_name: str = self.client_type or self.resource_type
        self._is_resource = bool(resource_type)
        if TYPE_CHECKING:
            assert isinstance(self._service_name, str)

        super().__init__(
            aws_conn_id=aws_conn_id, region_name=region_name, verify=verify, config=config, legacy_mode=True
        )

    @property
    def _legacy_mode_conn(self):
        """
        Build the boto3 resource or client returned by ``conn`` in legacy mode.

        The base class ``conn`` property reads ``_legacy_mode_conn``; under its
        previous name (``_legacy_conn``) this implementation was never picked up
        and ``conn`` raised ``NotImplementedError``.
        """
        factory = self.session.resource if self._is_resource else self.session.client
        factory_kwargs = {
            "service_name": self._service_name,
            # Each option goes to its own parameter; previously the endpoint URL
            # was passed as ``api_version`` and ``endpoint_url`` was never forwarded.
            "api_version": self.api_version,
            "endpoint_url": self.endpoint_url,
            "verify": self.verify,
            "config": self.config,
        }
        # boto3 documents ``use_ssl`` as a boolean that defaults to True; forward
        # it only when explicitly set so that the library default applies.
        if self.use_ssl is not None:
            factory_kwargs["use_ssl"] = self.use_ssl
        return factory(**factory_kwargs)  # type: ignore

    # Previous (misnamed) attribute, kept as an alias for backward compatibility.
    _legacy_conn = _legacy_mode_conn

    @property
    def client(self) -> Boto3Client:
        """Disabled for this hook; always raises ``AirflowException``."""
        raise AirflowException(
            f"This property disabled for {type(self).__module__}.{type(self).__name__}. "
            f"Consider to use {Boto3ClientHook.__module__}.{Boto3ClientHook.__name__} and subclasses."
        )

    @property
    def resource(self) -> Boto3Resource:
        """Disabled for this hook; always raises ``AirflowException``."""
        raise AirflowException(
            f"This property disabled for {type(self).__module__}.{type(self).__name__}. "
            f"Consider to use {Boto3ResourceHook.__module__}.{Boto3ResourceHook.__name__} and subclasses."
        )

    def __repr__(self):
        return f"{super().__repr__()}[client_type={self.client_type!r}, resource_type={self.resource_type!r}]"
class AwsGenericHook(AwsBaseHook):
    """Subclass of ``AwsBaseHook`` that adds no behaviour of its own."""
if TYPE_CHECKING:
    # Imported for static type checking only; these names are used as string
    # forward references when parametrizing the service hooks below.
    from mypy_boto3_s3.client import S3Client
    from mypy_boto3_s3.service_resource import S3ServiceResource
    from mypy_boto3_ec2.client import EC2Client
    from mypy_boto3_ec2.service_resource import EC2ServiceResource
    from mypy_boto3_batch.client import BatchClient
class S3Hook(Boto3ResourceHook["S3Client", "S3ServiceResource"]):
    """Resource hook bound to the ``"s3"`` service."""

    def __init__(self, **kwargs):
        # Remaining keyword arguments are passed through to the parent hook.
        super().__init__("s3", **kwargs)
class EC2Hook(Boto3ResourceHook["EC2Client", "EC2ServiceResource"]):
    """Resource hook bound to the ``"ec2"`` service."""

    def __init__(self, **kwargs):
        # Remaining keyword arguments are passed through to the parent hook.
        super().__init__("ec2", **kwargs)
class BatchHook(Boto3ClientHook["BatchClient"]):
    """Client hook bound to the ``"batch"`` service."""

    def __init__(self, **kwargs):
        # Remaining keyword arguments are passed through to the parent hook.
        super().__init__("batch", **kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment