Last active: April 30, 2024 07:54
-
-
Save catwell/191f589d66927ea340a2c6636bc52738 to your computer and use it in GitHub Desktop.
Lua nested coroutines
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- mini scheduler

-- Scheduler state: each pending task's coroutine maps to the number of
-- ticks it must still wait before it is resumed.
local tasks = {}

-- Private metatable whose only job is to tag tables as delay objects.
local delay_mt = {}

--- Tell whether a value is a delay object built by make_delay.
-- @param x any value
-- @treturn boolean
local function is_delay(x)
  if type(x) ~= "table" then
    return false
  end
  return getmetatable(x) == delay_mt
end

--- Wrap a tick count in a delay object.
-- @tparam number t number of ticks
-- @treturn table delay object; its field `t` holds the tick count
local function make_delay(t)
  local delay = { t = t }
  return setmetatable(delay, delay_mt)
end
--- Suspend the calling task for `delay` scheduler ticks.
-- Yields a delay object to whoever resumed the current coroutine, so it
-- must be called from inside a coroutine.
-- @tparam number delay ticks to wait; floored, then clamped to at least 0
local function sleep(delay)
  local kind = type(delay)
  if kind ~= "number" then
    error("bad argument #1 for 'sleep' expected number, got: " .. kind, 2)
  end
  local ticks = math.max(0, math.floor(delay))
  coroutine.yield(make_delay(ticks))
end
-- Keep a handle on the stock resume: the scheduler drives tasks with it
-- directly, while user code gets the delay-aware wrapper below.
local _old_resume = coroutine.resume

--- Delay-aware replacement for coroutine.resume.
-- Resumes `co`. Whenever `co` yields a delay object (i.e. it called `sleep`),
-- the delay is discarded and `co` is resumed again immediately. The caller
-- therefore only ever sees ordinary yields, returns, or errors. As a
-- consequence, `sleep` inside a coroutine driven by this function never
-- reaches the scheduler.
-- @param co coroutine to resume
-- @param ... values passed in on the first resume
-- @return exactly what coroutine.resume returns (`true, ...` or `false, err`)
local function _resume(co, ...)
  -- table.pack records the real result count in `n`; `#` on a table with
  -- nil holes is unreliable and used to drop (trailing) nil results.
  local r = table.pack(_old_resume(co, ...))
  -- Iterate rather than recurse. Stop when `co` finished by *returning* a
  -- delay: resuming a dead coroutine would only produce an error.
  while r[1] and is_delay(r[2]) and coroutine.status(co) ~= "dead" do
    -- Hand the yielded values back as the results of the coroutine.yield
    -- inside `sleep` (which ignores them).
    r = table.pack(_old_resume(co, table.unpack(r, 2, r.n)))
  end
  return table.unpack(r, 1, r.n)
end

coroutine.resume = _resume
--- Register a function as a new scheduler task.
-- The function is wrapped in its own coroutine and stored in `tasks`.
-- @tparam function func body of the task
-- @tparam[opt] number delay ticks to wait before the first run; nil means 0,
--   fractions are floored and negatives clamped to 0
-- @treturn boolean always true
local function add_task(func, delay)
  local func_type, delay_type = type(func), type(delay)
  if func_type ~= "function" then
    error("bad argument #1 for 'add_task' expected function, got: " .. func_type, 2)
  end
  if not (delay_type == "number" or delay_type == "nil") then
    error("bad argument #2 for 'add_task' expected number, got: " .. delay_type, 2)
  end
  local start_in = math.max(0, math.floor(delay or 0))
  tasks[coroutine.create(func)] = start_in
  return true
end
--- Run the scheduler until every task has finished.
-- `init_func` runs first, outside any task. It is expected to register tasks
-- with `add_task` and must return a truthy value (it is asserted).
-- Then, once per tick, every task whose remaining delay is 0 is resumed and
-- every other task has its delay decremented. Rescheduling rules:
--   * a task that yields a delay (via `sleep`) waits that many ticks;
--   * a task that yields anything else is run again on the next tick;
--   * a task that finishes or raises an error is dropped (errors are printed).
-- @tparam function init_func setup callback
-- @treturn boolean true once no tasks remain
function run_scheduler(init_func)
  print "starting scheduler initialization"
  assert(init_func())
  print "starting scheduler main loop"
  while next(tasks) do
    -- Swap in a fresh table so tasks added or rescheduled during this tick
    -- never modify the table being traversed.
    local old_tasks = tasks
    tasks = {}
    for coro, delay in pairs(old_tasks) do
      if delay > 0 then
        -- we're not up to run yet
        tasks[coro] = delay - 1
      else
        -- We're up! Use the raw resume: the patched coroutine.resume would
        -- swallow the delay this task yields.
        local success, yielded = _old_resume(coro)
        if not success then
          print("error resuming coroutine: ", yielded)
        else
          if yielded ~= nil and not is_delay(yielded) then
            print("not a delay: ", yielded)
          end
          if coroutine.status(coro) ~= "dead" then
            if is_delay(yielded) then
              tasks[coro] = yielded.t
            else
              -- A plain coroutine.yield(...) used to crash here by indexing
              -- a non-table; treat it as "run again next tick".
              tasks[coro] = 0
            end
          end
        end
      end
    end
  end
  print "exiting"
  return true
end
-- end of mini scheduler code | |
-- Try and run the scheduler with 2 interleaving tasks
-- (no nested coroutines: each task sleeps itself).
print("\nStarting first test using non-nested coroutines\n")
run_scheduler(function()
  local number_ticks = 0

  -- Build a coroutine that yields the first n Fibonacci numbers, one per
  -- resume. Declared local so it does not leak a global.
  local function fibonacciProducer(n)
    local function fibCoroutine()
      local a, b = 0, 1
      for _ = 1, n do
        coroutine.yield(a) -- yield the current Fibonacci number
        a, b = b, a + b -- advance to the next Fibonacci number
      end
    end
    return coroutine.create(fibCoroutine)
  end

  -- Task counting 1 to 10; each step prints the next Fibonacci number, then
  -- sleeps number_ticks (0) ticks, i.e. it runs once per tick.
  add_task(function()
    local i = 0
    local fibo = fibonacciProducer(100)
    while i < 10 do
      i = i + 1
      print("fibonacci:", coroutine.resume(fibo)) -- do work
      sleep(number_ticks) -- yield to allow other threads to run
    end
  end, number_ticks)

  -- Task counting A to C. It first runs after char_ticks (2) ticks, then
  -- every third tick, because sleep(2) skips two ticks.
  local char_ticks = 2
  add_task(function()
    local i = 0
    while i < 3 do
      i = i + 1
      print("character:", string.char(64 + i)) -- do work
      sleep(char_ticks) -- wait till we're up again
    end
  end, char_ticks)
  return true
end)
--[[ Result is (lines printed during the same tick may appear in either
order, since the scheduler walks its task table with `pairs`):

Starting first test using non-nested coroutines

starting scheduler initialization
starting scheduler main loop
fibonacci: true 0
fibonacci: true 1
character: A
fibonacci: true 1
fibonacci: true 2
fibonacci: true 3
character: B
fibonacci: true 5
fibonacci: true 8
fibonacci: true 13
character: C
fibonacci: true 21
fibonacci: true 34
exiting
]]
-- Now try again, but make the coroutines nested: sleep is called from inside
-- the Fibonacci coroutine instead of from the task itself.
print("\nStarting second test using nested coroutines\n")
run_scheduler(function()
  local number_ticks = 0

  -- Build a coroutine that yields the first n Fibonacci numbers, one per
  -- resume, calling sleep(number_ticks) after each one.
  -- Declared local so it does not leak a global.
  local function fibonacciProducer(n)
    local function fibCoroutine()
      local a, b = 0, 1
      for _ = 1, n do
        coroutine.yield(a) -- yield the current Fibonacci number
        a, b = b, a + b -- advance to the next Fibonacci number
        sleep(number_ticks) -- yield to allow other threads to run <-- SLEEP WAS MOVED HERE
      end
    end
    return coroutine.create(fibCoroutine)
  end

  -- Task counting 1 to 10, intended to print one Fibonacci number per tick.
  -- In practice the patched coroutine.resume discards the delay that sleep
  -- yields inside fibCoroutine. This task therefore never yields back to the
  -- scheduler and prints all ten numbers in the first tick (see result).
  add_task(function()
    local i = 0
    local fibo = fibonacciProducer(100)
    while i < 10 do
      i = i + 1
      print("fibonacci:", coroutine.resume(fibo)) -- do work
    end
  end, number_ticks)

  -- Task counting A to C. It first runs after char_ticks (2) ticks, then
  -- every third tick, because sleep(2) skips two ticks.
  local char_ticks = 2
  add_task(function()
    local i = 0
    while i < 3 do
      i = i + 1
      print("character:", string.char(64 + i)) -- do work
      sleep(char_ticks) -- wait till we're up again
    end
  end, char_ticks)
  return true
end)
--[[ Result is:

Starting second test using nested coroutines

starting scheduler initialization
starting scheduler main loop
fibonacci: true 0
fibonacci: true 1
fibonacci: true 1
fibonacci: true 2
fibonacci: true 3
fibonacci: true 5
fibonacci: true 8
fibonacci: true 13
fibonacci: true 21
fibonacci: true 34
character: A
character: B
character: C
exiting
]]
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.