Skip to content

Instantly share code, notes, and snippets.

@njsmith
Last active August 6, 2021 17:54
Show Gist options
  • Save njsmith/60114a31cfa03bc63be6dd57e421b5f3 to your computer and use it in GitHub Desktop.
Save njsmith/60114a31cfa03bc63be6dd57e421b5f3 to your computer and use it in GitHub Desktop.
import trio
import functools
class ReplicatedBroadcastFailed(Exception):
    """Raised to a ReplicatedBroadcast subscriber when the shared source failed.

    The exception originally raised by the underlying async iterator is
    attached as ``__cause__`` (the broadcast raises this with ``raise ... from``).
    """
class ReplicatedBroadcast:
    """Share one async iterator among many subscribers, replaying history.

    Every item pulled from the underlying iterator is kept in ``_history``.
    Each subscriber, including one that joins late, therefore sees the full
    sequence from the first item. At most one pull from the underlying
    iterator is in flight at any time; it runs as a background task started
    via ``nursery.start_soon``.

    Args:
        nursery: object with a ``start_soon(async_fn)`` method (e.g. a trio
            nursery) used to run the background pull task.
        async_iterable: the shared source; ``__aiter__()`` is called once,
            here.
    """

    def __init__(self, nursery, async_iterable):
        self._nursery = nursery
        self._async_iter = async_iterable.__aiter__()
        # True while a _pull_next_item task has been started and not finished.
        self._actively_pulling = False
        # Every item received from the source so far, in order.
        self._history = []
        # Exception raised by the source, if any (set together with _complete).
        self._failed = None
        # True once the source is exhausted or has failed.
        self._complete = False
        # Set after each pull attempt finishes, then replaced with a fresh
        # Event, so waiting subscribers wake up and re-check the state.
        self._new_data = trio.Event()

    async def subscribe(self):
        """Asynchronously yield every item of the source, from the first.

        If no unread item is available and the source is not finished, this
        starts a background pull (unless one is already running) and waits
        for it to finish.

        Raises:
            ReplicatedBroadcastFailed: after all successfully received items
                have been yielded, if the source raised an exception; that
                exception is chained as ``__cause__``.
        """
        next_idx = 0
        # Continue while there is unread history OR more items may still
        # arrive; with ``and`` a fresh (empty) broadcast would yield nothing.
        while next_idx < len(self._history) or not self._complete:
            if next_idx < len(self._history):
                yield self._history[next_idx]
                next_idx += 1
            else:
                if not self._actively_pulling:
                    self._actively_pulling = True
                    try:
                        self._nursery.start_soon(self._pull_next_item)
                    except BaseException:
                        # No pull task was started; clear the flag so other
                        # subscribers don't wait forever on a pull that will
                        # never happen.
                        self._actively_pulling = False
                        raise
                # ``self._new_data`` is read before blocking, so we wait on the
                # Event that the pull task will set (it swaps in a new one
                # afterwards).
                await self._new_data.wait()
        if self._failed is not None:
            raise ReplicatedBroadcastFailed from self._failed

    async def _pull_next_item(self):
        """Fetch one item from the source into ``_history`` (background task).

        This must be ``async`` because it awaits ``__anext__``.
        - On ``StopAsyncIteration``, the broadcast is marked complete.
        - On any other ``Exception``, the exception is stored in ``_failed``
          and the broadcast is marked complete.
        In every case, waiting subscribers are woken.
        """
        assert self._actively_pulling
        assert not self._complete
        try:
            item = await self._async_iter.__anext__()
        except StopAsyncIteration:
            self._complete = True
        except Exception as exc:
            self._failed = exc
            self._complete = True
        else:
            self._history.append(item)
        finally:
            self._actively_pulling = False
            # trio Events are set-once, so wake the current waiters and then
            # install a fresh Event for the next round.
            self._new_data.set()
            self._new_data = trio.Event()
def cached_async_gen(inner, *, nursery=None):
    """Decorate an async generator function so equal calls share one source.

    The first call with a given ``(args, kwargs)`` combination creates a
    ``ReplicatedBroadcast`` over ``inner(*args, **kwargs)``. Every call
    (including that first one) returns a new subscription to the broadcast,
    which replays all items from the beginning.

    Broadcasts are memoized with ``functools.lru_cache(maxsize=128)``:
    - Arguments must be hashable.
    - Keyword arguments given in a different order count as a different key.
    - Least-recently-used broadcasts may be evicted and recreated.

    NOTE(review): the cache outlives a single ``trio.run``; reusing a
    broadcast created in a previous run is presumably unsafe — confirm.

    Args:
        inner: async generator function (or callable returning an async
            iterable) to wrap.
        nursery: optional object with a ``start_soon(async_fn)`` method in
            which background pulls run. If omitted, each pull is spawned as a
            trio system task via ``trio.lowlevel.spawn_system_task``.

    Returns:
        The wrapper function; calling it returns an async generator.
    """
    if nursery is None:
        class _SystemTaskSpawner:
            # Minimal nursery stand-in: ReplicatedBroadcast only calls
            # ``start_soon`` on its nursery.
            @staticmethod
            def start_soon(async_fn, *args):
                trio.lowlevel.spawn_system_task(async_fn, *args)

        nursery = _SystemTaskSpawner()

    @functools.lru_cache(maxsize=128)
    def broadcast_factory(*args, **kwargs):
        return ReplicatedBroadcast(nursery, inner(*args, **kwargs))

    @functools.wraps(inner)
    def wrapper(*args, **kwargs):
        return broadcast_factory(*args, **kwargs).subscribe()

    # The original never returned the wrapper, so decorating produced None.
    return wrapper
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment