Created
September 2, 2022 21:45
-
-
Save Seanny123/7a3066a84698ee097c31a53f2525bd55 to your computer and use it in GitHub Desktop.
An example of nested fanout using Ray Workflows.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
An example of nested fanout using Ray Workflows.

The full workflow creates several batches from a single input, then each of those batches is fanned out and evaluated.
"""
import time | |
import ray | |
from ray import workflow | |
@ray.remote
def get_range(start: int) -> list[int]:
    """
    Produce one batch: the five consecutive integers beginning at ``start``.
    """
    # Two-second pause before the batch is produced (presumably a stand-in
    # for real work in this example).
    time.sleep(2)
    batch_end = start + 5
    return [*range(start, batch_end)]
@ray.remote
def get_square(data: float) -> float:
    """
    Square a single value; the small per-item task used to exercise fanout.
    """
    # Two-second pause before computing (presumably a stand-in for real work).
    time.sleep(2)
    squared = pow(data, 2)
    return squared
@ray.remote
def range_and_square(data):
    """
    Starting node of the whole operation, as invoked by the `workflow.run` call
    in the `__main__` section of this file.

    Binds one `get_range` batch node per start value in `data`, then continues
    into `expand_ranges` so that each batch can be fanned out.
    """
    # Batch operation: one bound node per start value
    batch_nodes = []
    for start in data:
        batch_nodes.append(get_range.bind(start))

    # Fanout to the simple operations happens further down the chain
    next_step = expand_ranges.bind(batch_nodes)
    return workflow.continuation(next_step)
@ray.remote
def expand_ranges(my_ranges):
    """
    Fetch every batch in `my_ranges` with `ray.get` and continue into
    `finalize_range` with the fetched batches.

    Ray Workflows do not allow multiple `.bind` operations to occur
    simultaneously, so the functions are chained in order to let the
    `workflow.continuation` call in `range_and_square` construct the full DAG
    before operation.
    """
    fetched_batches = []
    for batch_ref in my_ranges:
        # Printed once per batch. The original author uses this output as
        # evidence that `ray.get` does not block here, which is why this
        # function does not become a bottleneck.
        # NOTE(review): confirm against the Ray Workflows documentation.
        print("get range!")
        batch = ray.get(batch_ref)
        fetched_batches.append(batch)

    next_step = finalize_range.bind(fetched_batches)
    return workflow.continuation(next_step)
@ray.remote
def finalize_range(my_ranges):
    """
    Fan out over the batches: bind one `get_square` node for every item of
    every batch in `my_ranges`, then continue into `finalize_results` with the
    resulting list of nodes.
    """
    # Flattened in batch order, then item order within each batch.
    square_nodes = [
        get_square.bind(value)
        for batch in my_ranges
        for value in batch
    ]
    next_step = finalize_results.bind(square_nodes)
    return workflow.continuation(next_step)
@ray.remote
def finalize_results(squares):
    """
    Final step of the chain: call `ray.get` on each entry of `squares` and
    return the fetched values as a list, in the same order as the input.
    """
    fetched = map(ray.get, squares)
    return list(fetched)
if __name__ == "__main__":
    # Ten start values (0 through 9), one batch per value.
    start_values = [*range(10)]
    workflow_dag = range_and_square.bind(start_values)
    result = workflow.run(workflow_dag)

    # Pause briefly so Ray's stderr messages can flow past
    # before the output is printed.
    time.sleep(1)
    # The variable must stay named `result`: the `=` specifier prints the name.
    print(f"{result=}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment