Last active
June 9, 2021 03:01
-
-
Save hanslovsky/f52d006118c468d3277d6f240f784949 to your computer and use it in GitHub Desktop.
Example use case: using Jep to generate numpy arrays in an imglib2 CacheLoader
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env kscript | |
// requires kscript: https://github.com/holgerbrandl/kscript | |
// install jep native libraries with | |
// python -m pip install jep | |
// When using a Python interpreter in a non-standard location, set PYTHONHOME appropriately. | |
@file:MavenRepository("scijava", "https://maven.scijava.org/content/groups/public") | |
@file:DependsOn("net.imglib2:imglib2-cache:1.0.0-beta-16") | |
@file:DependsOn("net.imglib2:imglib2:5.12.0") | |
@file:DependsOn("black.ninia:jep:3.9.1") | |
@file:DependsOn("sc.fiji:bigdataviewer-vistools:1.0.0-beta-28") | |
@file:DependsOn("sc.fiji:bigdataviewer-core:10.2.0") | |
import bdv.util.BdvFunctions | |
import bdv.util.BdvOptions | |
import bdv.util.volatiles.SharedQueue | |
import bdv.util.volatiles.VolatileViews | |
import java.lang.System | |
import jep.DirectNDArray | |
import jep.SharedInterpreter | |
import net.imglib2.cache.CacheLoader | |
import net.imglib2.cache.img.CachedCellImg | |
import net.imglib2.cache.ref.GuardedStrongRefLoaderCache | |
import net.imglib2.cache.ref.SoftRefLoaderCache | |
import net.imglib2.img.basictypeaccess.volatiles.VolatileDoubleAccess | |
import net.imglib2.img.cell.Cell | |
import net.imglib2.img.cell.CellGrid | |
import net.imglib2.type.numeric.real.DoubleType | |
import net.imglib2.util.Intervals | |
import java.nio.ByteBuffer | |
import java.nio.DoubleBuffer | |
import java.util.concurrent.BlockingQueue | |
import java.util.concurrent.CountDownLatch | |
import java.util.concurrent.LinkedBlockingDeque | |
import java.util.concurrent.TimeUnit | |
/**
 * [VolatileDoubleAccess] whose storage is a [DoubleBuffer].
 *
 * Reads and writes use absolute buffer indices (the buffer position is never moved), and the data is
 * always reported as valid.
 */
class DoubleBufferAccess(private val buf: DoubleBuffer) : VolatileDoubleAccess {

    override fun getValue(index: Int): Double = buf.get(index)

    override fun setValue(index: Int, value: Double) {
        buf.put(index, value)
    }

    override fun isValid(): Boolean = true

    companion object {
        /** A newly created zero-length access on every call (used as the access-type placeholder of the cached image). */
        val empty: DoubleBufferAccess
            get() {
                val zeroLength = ByteBuffer.allocate(0).asDoubleBuffer()
                return DoubleBufferAccess(zeroLength)
            }
    }
}
/**
 * One unit of work for a Python worker: the target buffer, the cell's index/min/max/dim, the Python
 * code to run, and the Python variable name for the block (the worker falls back to `block` when null).
 *
 * A single-use latch lets the submitting thread block until a worker has finished with the task.
 */
class Task(
    val buf: DoubleBuffer,
    val index: Long,
    val min: LongArray,
    val max: LongArray,
    val dim: IntArray,
    val code: String,
    val blockName: String? = null) {

    private val done = CountDownLatch(1)

    /** Marks the task as processed, releasing every thread waiting in [awaitCompletion]. */
    fun complete() {
        done.countDown()
    }

    /** Blocks the calling thread until [complete] has been called. */
    fun awaitCompletion() {
        done.await()
    }
}
/**
 * A daemon thread that owns one jep [SharedInterpreter] and executes [Task]s polled from a queue.
 *
 * The constructor blocks until the interpreter has been created and the Python `Block` dataclass has
 * been defined on it. If that fails, the constructor throws [IllegalStateException] instead of returning
 * a worker whose thread is already dead. [init] is then executed on the worker thread before tasks are taken.
 *
 * @param queue source of tasks; several workers may poll the same queue.
 * @param init optional Python code executed once on this worker's interpreter.
 * @param name optional name for the worker thread.
 */
class Worker(
    queue: BlockingQueue<Task>,
    init: String? = null,
    name: String? = null) {

    // Written by close() on the caller's thread and read by the worker loop: must be volatile so the
    // worker is guaranteed to observe the shutdown request.
    @Volatile
    private var closed = false
    private val pythonReady = CountDownLatch(1)
    // Assigned on the worker thread before pythonReady.countDown(); the latch's happens-before edge
    // makes it visible to the constructor after await().
    private var startupError: Throwable? = null
    private val workerThread = Thread { serve(queue, init) }

    init {
        workerThread.isDaemon = true
        name?.let { workerThread.name = it }
        workerThread.start()
        pythonReady.await()
        startupError?.let {
            throw IllegalStateException("Failed to start the Python interpreter of worker thread ${workerThread.name}", it)
        }
    }

    /** Requests shutdown; the worker leaves its loop within one poll interval and closes its interpreter. */
    fun close() {
        closed = true
    }

    /** Worker-thread body: start the interpreter, run [initCode], then process tasks until [close]. */
    private fun serve(queue: BlockingQueue<Task>, initCode: String?) {
        val python = try {
            openInterpreter()
        } catch (e: Throwable) {
            // Includes Errors such as UnsatisfiedLinkError when the jep native library is missing.
            startupError = e
            return
        } finally {
            pythonReady.countDown()
        }
        try {
            initCode?.let { python.exec(it) }
            while (!closed) {
                queue.poll(10, TimeUnit.MILLISECONDS)?.let { task -> process(python, task) }
            }
        } finally {
            // Closed here, on the thread that created it (jep interpreters are thread-confined).
            python.close()
        }
    }

    /** Creates an interpreter and defines `Block` on it, closing the interpreter again if that fails. */
    private fun openInterpreter(): SharedInterpreter {
        val interpreter = SharedInterpreter()
        try {
            interpreter.initialize()
        } catch (e: Throwable) {
            interpreter.close()
            throw e
        }
        return interpreter
    }

    /**
     * Binds `task`'s buffer to a Python `Block` named `task.blockName` (default `block`) and runs `task.code`.
     *
     * Exceptions are printed rather than rethrown so the worker keeps serving. The task is always
     * completed so the submitter never waits forever. Note that on failure the buffer keeps whatever
     * content it had.
     */
    private fun process(python: SharedInterpreter, task: Task) {
        try {
            require(task.buf.isDirect) { "Task buffer must be a direct buffer to be shared with numpy" }
            // Axes are passed in reverse order, presumably because imglib2 varies dimension 0 fastest
            // whereas a C-ordered numpy array varies its last axis fastest.
            python.set("_buf", DirectNDArray(task.buf, *task.dim.reversedArray()))
            python.set("_index", task.index)
            python.set("_min", task.min.reversedArray())
            python.set("_max", task.max.reversedArray())
            python.set("_dim", task.dim.reversedArray())
            python.exec("${task.blockName ?: "block"} = Block(_buf, _index, _min, _max, _dim)")
            python.exec(task.code)
        } catch (e: Exception) {
            e.printStackTrace()
        } finally {
            task.complete()
        }
    }

    companion object {
        /** Imports numpy and defines the `Block` dataclass that wraps each task's data on this interpreter. */
        fun SharedInterpreter.initialize() {
            exec(
                """
                from dataclasses import dataclass
                import numpy as np

                @dataclass
                class Block:
                    data: np.ndarray
                    index: int
                    min: tuple
                    max: tuple
                    dim: tuple
                """.trimIndent())
        }
    }
}
/**
 * A pool of [numWorkers] Python workers draining one shared task queue; each worker runs [init] once.
 *
 * @throws IllegalArgumentException if [numWorkers] is not positive (no worker would ever take a task,
 *   so every submission would block forever).
 */
class WorkerQueue(numWorkers: Int, init: String? = null) {

    init {
        require(numWorkers > 0) { "numWorkers must be positive, got $numWorkers" }
    }

    private val queue = LinkedBlockingDeque<Task>()
    private val workers = Array(numWorkers) { Worker(queue, init, "Python-$it") }

    // Set by close() and read by submitting threads.
    @Volatile
    private var closed = false

    /**
     * Enqueues a task for one cell and blocks until a worker has processed it.
     *
     * @throws IllegalStateException if [close] has already been called (the task would never run).
     */
    fun submitAndAwaitCompletion(buf: DoubleBuffer, index: Long, min: LongArray, max: LongArray, dim: IntArray, code: String, blockName: String? = null) {
        check(!closed) { "WorkerQueue has been closed" }
        val task = Task(buf, index, min, max, dim, code, blockName)
        queue.add(task)
        task.awaitCompletion()
    }

    /**
     * Stops all workers and rejects further submissions.
     *
     * A submission racing with this call, or tasks still queued when the workers stop, may never complete.
     */
    fun close() {
        closed = true
        workers.forEach { it.close() }
    }
}
/**
 * [CacheLoader] that generates each cell of [grid] by running Python [code] on a pool of jep workers.
 *
 * For every requested cell, a direct buffer is allocated and exposed to Python as a `Block` (named
 * [blockName], or `block` when null). [code] runs against it, and the buffer is returned wrapped in a [Cell].
 *
 * @param grid cell grid whose cells are loaded.
 * @param numWorkers number of Python worker threads.
 * @param code Python code run once per cell; it is expected to write the cell's values into `block.data`.
 * @param init optional Python code run once per worker interpreter (imports, model loading, ...).
 * @param blockName Python variable name under which each cell's block is exposed.
 */
class JepyterCacheLoader(
    private val grid: CellGrid,
    numWorkers: Int,
    private val code: String,
    init: String? = null,
    private val blockName: String? = null) : CacheLoader<Long, Cell<DoubleBufferAccess>> {

    private val workerQueue = WorkerQueue(numWorkers, init)

    override fun get(key: Long): Cell<DoubleBufferAccess> {
        val min = LongArray(grid.nDim)
        val dim = IntArray(grid.nDim)
        // Fills both the cell's min corner and its dimensions.
        grid.getCellDimensions(key, min, dim)
        val max = LongArray(grid.nDim) { min[it] + dim[it] - 1 }
        val buf = allocateNativeDoubleBuffer(Intervals.numElements(*dim))
        workerQueue.submitAndAwaitCompletion(buf, key, min, max, dim, code, blockName)
        return Cell(dim, min, DoubleBufferAccess(buf))
    }

    /** Stops the Python workers; subsequent [get] calls fail instead of blocking. */
    fun close() = workerQueue.close()

    companion object {
        private val CellGrid.nDim get() = numDimensions()

        /**
         * Allocates a direct buffer of [numElements] doubles in the platform's native byte order.
         *
         * The memory is shared with numpy through jep. A plain `asDoubleBuffer()` view is big-endian,
         * whereas numpy presumably interprets the memory as native-endian float64. Native order keeps
         * both views consistent.
         *
         * @throws ArithmeticException if the byte size exceeds [Int.MAX_VALUE] (instead of silently overflowing).
         */
        private fun allocateNativeDoubleBuffer(numElements: Long): DoubleBuffer =
            ByteBuffer
                .allocateDirect(Math.toIntExact(Double.SIZE_BYTES * numElements))
                .order(java.nio.ByteOrder.nativeOrder())
                .asDoubleBuffer()
    }
}
// A 300 x 400 x 500 volume split into 30 x 40 x 50 cells; each cell's values are computed by numpy when loaded.
val dims = longArrayOf(300, 400, 500)
val bs = intArrayOf(30, 40, 50)
val grid = CellGrid(dims, bs)

val loader = JepyterCacheLoader(
    grid = grid,
    numWorkers = 3,
    code = """
        block.data[...] = np.mod(np.arange(block.data.size), 255).reshape(block.data.shape)
        # add 50 milliseconds delay to visualize how blocks are generated on demand
        time.sleep(0.05)
        """.trimIndent(),
    init = "import numpy as np; import time")

// A soft-reference cache is unsuitable here: the cells' data lives in native (direct) memory, which is
// not counted towards the heap. A cache with a hard size limit makes sure unused cells are released.
val cache = GuardedStrongRefLoaderCache<Long, Cell<DoubleBufferAccess>>(30).withLoader(loader)
val img = CachedCellImg(grid, DoubleType(), cache, DoubleBufferAccess.empty)

// Display a volatile view of the image so cells can be loaded in the background and shown as they arrive.
val bdv = BdvFunctions.show(
    VolatileViews.wrapAsVolatile(img, SharedQueue(10, 1)),
    "numpy",
    BdvOptions.options().numRenderingThreads(10))
bdv.setDisplayRange(0.0, 255.0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env kscript | |
// requires kscript: https://github.com/holgerbrandl/kscript | |
// install the jep native libraries and the stardist and tensorflow dependencies with | |
// python -m pip install jep stardist tensorflow | |
// When using a Python interpreter in a non-standard location, set PYTHONHOME appropriately. | |
@file:MavenRepository("scijava", "https://maven.scijava.org/content/groups/public") | |
@file:DependsOn("net.imglib2:imglib2-cache:1.0.0-beta-16") | |
@file:DependsOn("net.imglib2:imglib2:5.12.0") | |
@file:DependsOn("black.ninia:jep:3.9.1") | |
@file:DependsOn("sc.fiji:bigdataviewer-vistools:1.0.0-beta-28") | |
@file:DependsOn("sc.fiji:bigdataviewer-core:10.2.0") | |
import bdv.util.BdvFunctions | |
import bdv.util.BdvOptions | |
import bdv.util.volatiles.SharedQueue | |
import bdv.util.volatiles.VolatileViews | |
import java.lang.System | |
import jep.DirectNDArray | |
import jep.SharedInterpreter | |
import net.imglib2.cache.CacheLoader | |
import net.imglib2.cache.img.CachedCellImg | |
import net.imglib2.cache.ref.GuardedStrongRefLoaderCache | |
import net.imglib2.cache.ref.SoftRefLoaderCache | |
import net.imglib2.img.basictypeaccess.volatiles.VolatileDoubleAccess | |
import net.imglib2.img.cell.Cell | |
import net.imglib2.img.cell.CellGrid | |
import net.imglib2.type.numeric.real.DoubleType | |
import net.imglib2.util.Intervals | |
import java.nio.ByteBuffer | |
import java.nio.DoubleBuffer | |
import java.util.concurrent.BlockingQueue | |
import java.util.concurrent.CountDownLatch | |
import java.util.concurrent.LinkedBlockingDeque | |
import java.util.concurrent.TimeUnit | |
/**
 * [VolatileDoubleAccess] whose storage is a [DoubleBuffer].
 *
 * Reads and writes use absolute buffer indices (the buffer position is never moved), and the data is
 * always reported as valid.
 */
class DoubleBufferAccess(private val buf: DoubleBuffer) : VolatileDoubleAccess {

    override fun getValue(index: Int): Double = buf.get(index)

    override fun setValue(index: Int, value: Double) {
        buf.put(index, value)
    }

    override fun isValid(): Boolean = true

    companion object {
        /** A newly created zero-length access on every call (used as the access-type placeholder of the cached image). */
        val empty: DoubleBufferAccess
            get() {
                val zeroLength = ByteBuffer.allocate(0).asDoubleBuffer()
                return DoubleBufferAccess(zeroLength)
            }
    }
}
/**
 * One unit of work for a Python worker: the target buffer, the cell's index/min/max/dim, the Python
 * code to run, and the Python variable name for the block (the worker falls back to `block` when null).
 *
 * A single-use latch lets the submitting thread block until a worker has finished with the task.
 */
class Task(
    val buf: DoubleBuffer,
    val index: Long,
    val min: LongArray,
    val max: LongArray,
    val dim: IntArray,
    val code: String,
    val blockName: String? = null) {

    private val done = CountDownLatch(1)

    /** Marks the task as processed, releasing every thread waiting in [awaitCompletion]. */
    fun complete() {
        done.countDown()
    }

    /** Blocks the calling thread until [complete] has been called. */
    fun awaitCompletion() {
        done.await()
    }
}
/**
 * A daemon thread that owns one jep [SharedInterpreter] and executes [Task]s polled from a queue.
 *
 * The constructor blocks until the interpreter has been created and the Python `Block` dataclass has
 * been defined on it. If that fails, the constructor throws [IllegalStateException] instead of returning
 * a worker whose thread is already dead. [init] is then executed on the worker thread before tasks are taken.
 *
 * @param queue source of tasks; several workers may poll the same queue.
 * @param init optional Python code executed once on this worker's interpreter.
 * @param name optional name for the worker thread.
 */
class Worker(
    queue: BlockingQueue<Task>,
    init: String? = null,
    name: String? = null) {

    // Written by close() on the caller's thread and read by the worker loop: must be volatile so the
    // worker is guaranteed to observe the shutdown request.
    @Volatile
    private var closed = false
    private val pythonReady = CountDownLatch(1)
    // Assigned on the worker thread before pythonReady.countDown(); the latch's happens-before edge
    // makes it visible to the constructor after await().
    private var startupError: Throwable? = null
    private val workerThread = Thread { serve(queue, init) }

    init {
        workerThread.isDaemon = true
        name?.let { workerThread.name = it }
        workerThread.start()
        pythonReady.await()
        startupError?.let {
            throw IllegalStateException("Failed to start the Python interpreter of worker thread ${workerThread.name}", it)
        }
    }

    /** Requests shutdown; the worker leaves its loop within one poll interval and closes its interpreter. */
    fun close() {
        closed = true
    }

    /** Worker-thread body: start the interpreter, run [initCode], then process tasks until [close]. */
    private fun serve(queue: BlockingQueue<Task>, initCode: String?) {
        val python = try {
            openInterpreter()
        } catch (e: Throwable) {
            // Includes Errors such as UnsatisfiedLinkError when the jep native library is missing.
            startupError = e
            return
        } finally {
            pythonReady.countDown()
        }
        try {
            initCode?.let { python.exec(it) }
            while (!closed) {
                queue.poll(10, TimeUnit.MILLISECONDS)?.let { task -> process(python, task) }
            }
        } finally {
            // Closed here, on the thread that created it (jep interpreters are thread-confined).
            python.close()
        }
    }

    /** Creates an interpreter and defines `Block` on it, closing the interpreter again if that fails. */
    private fun openInterpreter(): SharedInterpreter {
        val interpreter = SharedInterpreter()
        try {
            interpreter.initialize()
        } catch (e: Throwable) {
            interpreter.close()
            throw e
        }
        return interpreter
    }

    /**
     * Binds `task`'s buffer to a Python `Block` named `task.blockName` (default `block`) and runs `task.code`.
     *
     * Exceptions are printed rather than rethrown so the worker keeps serving. The task is always
     * completed so the submitter never waits forever. Note that on failure the buffer keeps whatever
     * content it had.
     */
    private fun process(python: SharedInterpreter, task: Task) {
        try {
            require(task.buf.isDirect) { "Task buffer must be a direct buffer to be shared with numpy" }
            // Axes are passed in reverse order, presumably because imglib2 varies dimension 0 fastest
            // whereas a C-ordered numpy array varies its last axis fastest.
            python.set("_buf", DirectNDArray(task.buf, *task.dim.reversedArray()))
            python.set("_index", task.index)
            python.set("_min", task.min.reversedArray())
            python.set("_max", task.max.reversedArray())
            python.set("_dim", task.dim.reversedArray())
            python.exec("${task.blockName ?: "block"} = Block(_buf, _index, _min, _max, _dim)")
            python.exec(task.code)
        } catch (e: Exception) {
            e.printStackTrace()
        } finally {
            task.complete()
        }
    }

    companion object {
        /** Imports numpy and defines the `Block` dataclass that wraps each task's data on this interpreter. */
        fun SharedInterpreter.initialize() {
            exec(
                """
                from dataclasses import dataclass
                import numpy as np

                @dataclass
                class Block:
                    data: np.ndarray
                    index: int
                    min: tuple
                    max: tuple
                    dim: tuple
                """.trimIndent())
        }
    }
}
/**
 * A pool of [numWorkers] Python workers draining one shared task queue; each worker runs [init] once.
 *
 * @throws IllegalArgumentException if [numWorkers] is not positive (no worker would ever take a task,
 *   so every submission would block forever).
 */
class WorkerQueue(numWorkers: Int, init: String? = null) {

    init {
        require(numWorkers > 0) { "numWorkers must be positive, got $numWorkers" }
    }

    private val queue = LinkedBlockingDeque<Task>()
    private val workers = Array(numWorkers) { Worker(queue, init, "Python-$it") }

    // Set by close() and read by submitting threads.
    @Volatile
    private var closed = false

    /**
     * Enqueues a task for one cell and blocks until a worker has processed it.
     *
     * @throws IllegalStateException if [close] has already been called (the task would never run).
     */
    fun submitAndAwaitCompletion(buf: DoubleBuffer, index: Long, min: LongArray, max: LongArray, dim: IntArray, code: String, blockName: String? = null) {
        check(!closed) { "WorkerQueue has been closed" }
        val task = Task(buf, index, min, max, dim, code, blockName)
        queue.add(task)
        task.awaitCompletion()
    }

    /**
     * Stops all workers and rejects further submissions.
     *
     * A submission racing with this call, or tasks still queued when the workers stop, may never complete.
     */
    fun close() {
        closed = true
        workers.forEach { it.close() }
    }
}
/**
 * [CacheLoader] that generates each cell of [grid] by running Python [code] on a pool of jep workers.
 *
 * For every requested cell, a direct buffer is allocated and exposed to Python as a `Block` (named
 * [blockName], or `block` when null). [code] runs against it, and the buffer is returned wrapped in a [Cell].
 *
 * @param grid cell grid whose cells are loaded.
 * @param numWorkers number of Python worker threads.
 * @param code Python code run once per cell; it is expected to write the cell's values into `block.data`.
 * @param init optional Python code run once per worker interpreter (imports, model loading, ...).
 * @param blockName Python variable name under which each cell's block is exposed.
 */
class JepyterCacheLoader(
    private val grid: CellGrid,
    numWorkers: Int,
    private val code: String,
    init: String? = null,
    private val blockName: String? = null) : CacheLoader<Long, Cell<DoubleBufferAccess>> {

    private val workerQueue = WorkerQueue(numWorkers, init)

    override fun get(key: Long): Cell<DoubleBufferAccess> {
        val min = LongArray(grid.nDim)
        val dim = IntArray(grid.nDim)
        // Fills both the cell's min corner and its dimensions.
        grid.getCellDimensions(key, min, dim)
        val max = LongArray(grid.nDim) { min[it] + dim[it] - 1 }
        val buf = allocateNativeDoubleBuffer(Intervals.numElements(*dim))
        workerQueue.submitAndAwaitCompletion(buf, key, min, max, dim, code, blockName)
        return Cell(dim, min, DoubleBufferAccess(buf))
    }

    /** Stops the Python workers; subsequent [get] calls fail instead of blocking. */
    fun close() = workerQueue.close()

    companion object {
        private val CellGrid.nDim get() = numDimensions()

        /**
         * Allocates a direct buffer of [numElements] doubles in the platform's native byte order.
         *
         * The memory is shared with numpy through jep. A plain `asDoubleBuffer()` view is big-endian,
         * whereas numpy presumably interprets the memory as native-endian float64. Native order keeps
         * both views consistent.
         *
         * @throws ArithmeticException if the byte size exceeds [Int.MAX_VALUE] (instead of silently overflowing).
         */
        private fun allocateNativeDoubleBuffer(numElements: Long): DoubleBuffer =
            ByteBuffer
                .allocateDirect(Math.toIntExact(Double.SIZE_BYTES * numElements))
                .order(java.nio.ByteOrder.nativeOrder())
                .asDoubleBuffer()
    }
}
// Python run once per worker interpreter: load the StarDist 2D nuclei test image and the pretrained model.
val initBlock = """
    from stardist.data import test_image_nuclei_2d
    from stardist.models import StarDist2D
    from stardist.plot import render_label
    from csbdeep.utils import normalize
    img = test_image_nuclei_2d()
    model = StarDist2D.from_pretrained('2D_versatile_fluo')
    """.trimIndent()

// Python run for every cell. StarDist predicts instance labels on the cell enlarged by up to `halo` pixels
// on each side (clipped at the image origin via `offsets`, and at the far edge by numpy slicing). Only the
// cell's own region of the labels is then copied into block.data, presumably so that objects crossing the
// cell border are segmented with context. (A leftover per-cell debug print was removed.)
val code = """
    halo = 10
    offsets = tuple(min(m, halo) for m in block.min)
    slicing = tuple(slice(m - o, M+1 + halo) for o, m, M in zip(offsets, block.min, block.max))
    labels, _ = model.predict_instances(normalize(img[slicing]))
    block.data[...] = labels[tuple(slice(o, o + s) for o, s in zip(offsets, block.data.shape))]
    """.trimIndent()

val dims = longArrayOf(512, 512)
val bs = intArrayOf(80, 90)
val grid = CellGrid(dims, bs)
val loader = JepyterCacheLoader(
    grid = grid,
    numWorkers = 3,
    code = code,
    init = initBlock)

// A soft-reference cache is unsuitable here: the cells' data lives in native (direct) memory, which is
// not counted towards the heap. A cache with a hard size limit makes sure unused cells are released.
val cache = GuardedStrongRefLoaderCache<Long, Cell<DoubleBufferAccess>>(30).withLoader(loader)
val img = CachedCellImg(grid, DoubleType(), cache, DoubleBufferAccess.empty)
val bdv = BdvFunctions.show(
    VolatileViews.wrapAsVolatile(img, SharedQueue(10, 1)),
    "numpy",
    BdvOptions.options().numRenderingThreads(10).is2D())
bdv.setDisplayRange(0.0, 10.0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment