Skip to content

Instantly share code, notes, and snippets.

@HacKanCuBa
Last active March 17, 2024 07:04
Show Gist options
  • Save HacKanCuBa/9fceabaeb6417ed6280c2e7a48981420 to your computer and use it in GitHub Desktop.
Save HacKanCuBa/9fceabaeb6417ed6280c2e7a48981420 to your computer and use it in GitHub Desktop.
Cache handy helpers
"""Handy cache helpers.
These are not yet production ready, as I haven't thoroughly tested them, but they are close.
---
Mandatory license blah blah:
Copyright (C) 2023 HacKan (https://hackan.net)
This Source Code Form is subject to the terms of the Mozilla Public
License, v. 2.0. If a copy of the MPL was not distributed with this
file, You can obtain one at https://mozilla.org/MPL/2.0/.
"""
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Sequence

import redis.asyncio as redis
class AsyncCacheBackend(ABC):
@abstractmethod
async def get(self, key: str | bytes) -> bytes:
...
@abstractmethod
async def set(self, key: str | bytes, value: bytes, ttl: int | None = None) -> None:
...
@abstractmethod
async def delete(self, key: str | bytes) -> None:
...
@abstractmethod
async def keys(self) -> tuple[bytes, ...]:
...
class RedisAsyncCache(AsyncCacheBackend):
    """Cache backend that stores entries in Redis through a ``redis.asyncio`` client.

    The client is supplied by the caller and is never closed by this class.
    """

    def __init__(self, conn: redis.Redis):
        self._redis = conn

    async def get(self, key: str | bytes) -> bytes | None:
        """Return the value stored under *key*, or ``None`` when Redis has no such key."""
        return await self._redis.get(key)

    async def set(self, key: str | bytes, value: bytes, ttl: int | None = None) -> None:
        """Store *value* under *key*, expiring after *ttl* seconds when *ttl* is given."""
        await self._redis.set(key, value, ex=ttl)

    async def delete(self, key: str | bytes) -> None:
        """Remove *key*; Redis treats deleting a missing key as a no-op."""
        await self._redis.delete(key)

    async def keys(self) -> tuple[bytes, ...]:
        """Return every key in the selected database.

        Iterates with the incremental ``SCAN`` command rather than ``KEYS`` so that
        a large keyspace does not block the Redis server for the whole enumeration.
        ``SCAN`` may report the same key more than once, so results are
        de-duplicated while keeping first-seen order.

        NOTE(review): keys are ``bytes`` only if the client was created without
        ``decode_responses=True`` — confirm callers never enable it.
        """
        seen: dict[bytes, None] = {}
        async for key in self._redis.scan_iter():
            seen[key] = None
        return tuple(seen)
class InMemoryAsyncCache(AsyncCacheBackend):
    """Process-local cache backend that keeps entries in a plain dict.

    Keys are normalized to ``bytes``, so ``"a"`` and ``b"a"`` address the same
    entry. Expiration (``ttl``) is enforced lazily: an expired entry is dropped
    the next time it is read or when :meth:`keys` is called.
    """

    def __init__(self):
        self._cache: dict[bytes, bytes] = {}
        # Absolute deadlines (in time.monotonic() seconds) for keys stored with a TTL.
        self._expires: dict[bytes, float] = {}

    @staticmethod
    def _build_key(key: str | bytes) -> bytes:
        """Normalize *key* to bytes (``str`` keys are encoded with ``str.encode()``'s UTF-8 default)."""
        if isinstance(key, str):
            return key.encode()
        return key

    def _is_expired(self, key: bytes, now: float) -> bool:
        """Return True if *key* has a deadline that is at or before *now*."""
        deadline = self._expires.get(key)
        return deadline is not None and now >= deadline

    def _evict(self, key: bytes) -> None:
        """Drop *key* and its deadline, if any; missing keys are ignored."""
        self._cache.pop(key, None)
        self._expires.pop(key, None)

    async def get(self, key: str | bytes) -> bytes | None:
        """Return the value for *key*, or ``None`` if it is absent or expired."""
        normalized = self._build_key(key)
        if self._is_expired(normalized, time.monotonic()):
            self._evict(normalized)
            return None
        return self._cache.get(normalized)

    async def set(self, key: str | bytes, value: bytes, ttl: int | None = None) -> None:
        """Store *value* under *key*.

        Args:
            key: Cache key (``str`` or ``bytes``).
            value: Raw bytes to store.
            ttl: Optional lifetime in seconds. ``None`` stores without expiry and
                clears any previous deadline for the key, matching a plain
                Redis ``SET`` without ``EX``.
        """
        normalized = self._build_key(key)
        self._cache[normalized] = value
        if ttl is None:
            self._expires.pop(normalized, None)
        else:
            self._expires[normalized] = time.monotonic() + ttl

    async def delete(self, key: str | bytes) -> None:
        """Remove *key* if present; deleting a missing key is a no-op."""
        self._evict(self._build_key(key))

    async def keys(self) -> tuple[bytes, ...]:
        """Purge expired entries, then return the remaining keys."""
        now = time.monotonic()
        # Materialize first: _evict mutates self._expires while we would be iterating it.
        expired = [k for k in self._expires if self._is_expired(k, now)]
        for k in expired:
            self._evict(k)
        return tuple(self._cache)
class AsyncCache:
    """Facade over a single cache backend with dict-like miss semantics.

    Backends report a miss by returning ``None``; this wrapper instead raises
    ``KeyError`` from :meth:`get`. All other operations delegate unchanged.
    """

    def __init__(self, backend: AsyncCacheBackend, /) -> None:
        self._cache = backend

    async def get(self, key: str | bytes) -> bytes:
        """Return the value stored under *key*.

        Raises:
            KeyError: if the backend has no value for *key*.
        """
        found = await self._cache.get(key)
        if found is not None:
            return found
        raise KeyError(key)

    async def set(self, key: str | bytes, value: bytes, ttl: int | None = None) -> None:
        """Store *value* under *key*, forwarding *ttl* to the backend as-is."""
        backend = self._cache
        await backend.set(key, value, ttl)

    async def delete(self, key: str | bytes) -> None:
        """Remove *key* from the backend."""
        backend = self._cache
        await backend.delete(key)

    async def keys(self) -> tuple[bytes, ...]:
        """Return the backend's keys."""
        snapshot = await self._cache.keys()
        return snapshot
class AsyncCachePassthrough:
    """Layered cache that reads through an ordered sequence of backends.

    Backends are consulted in the order given, so put the fastest one first
    (e.g. in-memory) and the authoritative/shared one last (e.g. Redis).
    Writes and deletes are applied to every backend.
    """

    def __init__(self, backends: Sequence[AsyncCacheBackend], /) -> None:
        self._caches = tuple(backends)

    async def get(self, key: str | bytes) -> bytes:
        """Return the value for *key* from the first backend that holds it.

        Backends consulted before the hit (i.e. those that missed) are
        backfilled with the value so later reads are served earlier in the
        chain. Backfilled entries are stored without a TTL.

        Raises:
            KeyError: if no backend holds *key*.
        """
        value: bytes | None = None
        missed: list[AsyncCacheBackend] = []
        for cache in self._caches:
            value = await cache.get(key)
            if value is not None:
                # Backends after this one are assumed to already have the value;
                # checking them too would be slow.
                break
            missed.append(cache)
        else:
            # No backend had it (also covers an empty backend list).
            raise KeyError(key)

        for cache in missed:
            await cache.set(key, value)
        return value

    async def set(self, key: str | bytes, value: bytes, ttl: int | None = None) -> None:
        """Store *value* under *key* in every backend, forwarding *ttl* to each."""
        for cache in self._caches:
            await cache.set(key, value, ttl)

    async def delete(self, key: str | bytes) -> None:
        """Remove *key* from every backend."""
        for cache in self._caches:
            await cache.delete(key)

    async def synchronize(self) -> None:
        """Copy each key into every backend that lacks it.

        Assumes that backends holding the same key hold the same value. TTLs are
        not propagated. A key listed by ``keys()`` may vanish before it is read
        (expiry or concurrent delete); such keys are skipped rather than
        treated as an error.
        """
        if len(self._caches) <= 1:
            return

        # key -> backends that reported it, in chain order (deterministic source choice).
        holders: defaultdict[bytes, list[AsyncCacheBackend]] = defaultdict(list)
        for cache in self._caches:
            for key in await cache.keys():
                holders[key].append(cache)

        for cache in self._caches:
            for key, sources in holders.items():
                if cache in sources:
                    continue
                for source in sources:
                    value = await source.get(key)
                    if value is not None:
                        await cache.set(key, value)
                        break
@asynccontextmanager
async def async_redis_connection(host: str, *, port: int = 6379, db: int = 0, **kwargs: Any) -> AsyncGenerator[redis.Redis, None]:
    """Provide a ``redis.asyncio.Redis`` client for the duration of an ``async with`` block.

    Args:
        host: Redis server hostname.
        port: Redis server port.
        db: Database index to select.
        **kwargs: Extra client options passed straight to ``redis.Redis``. They
            may override the default ``auto_close_connection_pool=True``.

    The client's ``close()`` is awaited on exit, even if the body raises.
    """
    options = {"auto_close_connection_pool": True, **kwargs}
    client: redis.Redis = redis.Redis(host=host, port=port, db=db, **options)  # type: ignore[arg-type]
    try:
        yield client
    finally:
        await client.close()
@asynccontextmanager
async def async_cache(host: str | None = None, **kwargs: Any) -> AsyncGenerator[AsyncCache, None]:
    """Yield an ``AsyncCache`` backed by Redis, or by process memory when *host* is falsy.

    Args:
        host: Redis hostname; ``None`` or ``""`` selects the in-memory backend.
        **kwargs: Forwarded to ``async_redis_connection`` (ignored for the
            in-memory backend).

    For the Redis backend, the connection stays open for the whole ``async with``
    body and is closed afterwards.
    """
    if host:
        async with async_redis_connection(host, **kwargs) as conn:
            yield AsyncCache(RedisAsyncCache(conn))
    else:
        yield AsyncCache(InMemoryAsyncCache())
@asynccontextmanager
async def async_cache_passthrough(
    host: str | None = None,
    *,
    use_in_memory: bool = True,
    **kwargs: Any,
) -> AsyncGenerator[AsyncCachePassthrough | None, None]:
    """Yield a layered passthrough cache, or ``None`` if no backend is configured.

    Args:
        host: Redis hostname. When falsy, no Redis backend is used.
        use_in_memory: Put a process-local in-memory cache in front of Redis
            (or use it alone when there is no *host*).
        **kwargs: Forwarded to ``async_redis_connection``.

    Yields ``None`` when *host* is falsy and *use_in_memory* is False.
    """
    if not host:
        yield AsyncCachePassthrough([InMemoryAsyncCache()]) if use_in_memory else None
        return

    # Yield from INSIDE the connection context: yielding after the `async with`
    # exits would hand the caller a cache whose Redis connection is already closed.
    async with async_redis_connection(host, **kwargs) as conn:
        if use_in_memory:
            backends = [InMemoryAsyncCache(), RedisAsyncCache(conn)]
        else:
            backends = [RedisAsyncCache(conn)]
        yield AsyncCachePassthrough(backends)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment