Skip to content

Instantly share code, notes, and snippets.

@mathnathan
Created March 15, 2016 20:39
Show Gist options
  • Save mathnathan/e77414e67bd47439c57c to your computer and use it in GitHub Desktop.
Save mathnathan/e77414e67bd47439c57c to your computer and use it in GitHub Desktop.
polygroups visualization: activate
license: gpl-3.0

Prototyping Neural Network algorithm visualization

simulation.json

{
	"polygroup": [{
		"5": [ 					// timestep this node was activated
      							// each tuple describes incoming activations
			[2, 3],				// [node index, timestep when this node fired]
			[1, 4]				// 
		]
	}, {}, {					
  									// this node (index 2) fired at times 2 and 6
		"2": [
			[0, 2],
			[2, 2]
		],
		"6": [
			[0, 6],
			[1, 4]
		]
	}],
	"init": [1, 3, 1],
	"numTimeSteps": 6,
	"network": [
		[0, 2, 3],
		[1, 4, 2],
		[1, 3, 1]
	]
}

NOTES: When a node fires, it would be super cool if the connection between two nodes could increase or decrease in size, and flash a different color depending on whether plasticity is increasing or decreasing.

Also using @sxywu's trick of adding links as invisible nodes and centering them between source and target; this keeps the graph from forming straight lines in some configurations.

Using @erkal's polygon links.


forked from mbostock's block: Collision Detection

forked from enjalot's block: rigid links

forked from enjalot's block: rigid links

forked from mathnathan's block: polygroups visualization

forked from enjalot's block: polygroups visualization

forked from enjalot's block: polygroups visualization: activate

<!DOCTYPE html>
<meta charset="utf-8">
<head>
<script src="//d3js.org/d3.v3.min.js"></script>
<style>
#force {
width: 48%;
height: 100%;
float: left;
border: 1px solid #eee;
box-shadow: 3px 3px 5px #d6d6d6;
}
#raster {
width: 50%;
height: 100%;
float: right;
border: 1px solid #eee;
box-shadow: 3px 3px 5px #d6d6d6;
}
svg {
width: 100%;
height: 100%;
}
/* RASTER PLOT */
rect.step {
stroke: #444;
stroke-width: 1;
fill: #eee;
opacity: 0.2;
}
text.step {
font-family: Courier new, fixed-width;
font-size: 10px;
}
line.activation {
opacity: 0.1;
stroke: #111;
}
line.active {
opacity: 1;
stroke-width: 4px;
stroke: #444;
}
circle.activation {
fill: #444;
}
.legend {
opacity: 0.5;
}
/* FORCE LAYOUT */
.node {
stroke: #fff;
stroke-width: 4;
}
.link {
stroke: #fff;
stroke-width: 2;
pointer-events: none;
}
</style>
<body>
<div id="force"></div>
<div id="raster"></div>
</body>
<script>
var width = 500;
var height = 500;
var color = d3.scale.category10();
d3.json("simulation.json", function(err, data) {
// network is the array describing the delays between neurons
var N = data.network.length;
var nt = data.numTimeSteps + 4; // hardcoded, but should be max delay
var startRaster = 50;
var endRaster = 450;
var headLength = 20;
var headWidth = 14;
var linkWidth = 8;
// fill in the plasticity for time steps when it doesn't change
// so its easier to look up by time step
var plasticity = []
d3.range(nt).forEach(function(t) {
if(data.plasticity[t]) {
plasticity[t] = data.plasticity[t];
} else {
plasticity[t] = plasticity[t-1];
}
})
// One force-layout node per neuron. Every node starts within 10px of the
// canvas centre, with its previous position (px, py) equal to its current one.
var nodes = d3.range(N).map(function(neuron) {
  var startX = width / 2 + Math.random() * 10;
  var startY = height / 2 + Math.random() * 10;
  return { index: neuron, radius: 15, x: startX, y: startY, px: startX, py: startY };
});
// Directed links for the force layout. data.network[target][source] holds the
// delay of the connection, and the drawn width follows the plasticity at
// time step 0.
var links = [];
data.network.forEach(function(incoming, targetIdx) {
  incoming.forEach(function(delay, sourceIdx) {
    if (targetIdx === sourceIdx) {
      // self-referential links are not represented yet
      return;
    }
    var weight = plasticity[0][sourceIdx][targetIdx];
    links.push({
      source: nodes[sourceIdx],
      target: nodes[targetIdx],
      delay: delay,
      type: "normal",
      linkWidth: linkWidth * weight,
      headLength: headLength,
      headWidth: headWidth * weight
    });
  });
});
// Collections that describe neural activity over time:
//   spikes         - one entry each time a neuron fires
//   activations    - only the activations that led to a spike
//   allActivations - every activation (a spiking neuron sends one to each
//                    linked neuron)
var spikes = [];
var activations = [];
var allActivations = [];
// data.init[i] is the time at which neuron i fires initially
data.init.forEach(function(firedAt, neuron) {
  spikes.push({ index: neuron, time: firedAt, initial: true });
  addActivationsForSpike(firedAt, neuron);
});
// Build one entry per time step. As a side effect, every spike listed in
// data.polygroup is recorded together with the activations that caused it.
var steps = d3.range(nt).map(function(time) {
  data.polygroup.forEach(function(firings, neuron) {
    var causes = firings[time];
    if (!causes || !causes.length) {
      return;
    }
    spikes.push({ time: time, index: neuron });
    // each cause is [source neuron, time at which that source fired]
    causes.forEach(function(cause) {
      activations.push({
        target: neuron,
        source: cause[0],
        sourceTime: cause[1],
        time: time
      });
    });
    addActivationsForSpike(time, neuron);
  });
  return { time: time };
});
// Records, in allActivations, one activation from neuron `index` (spiking at
// `time`) to every other neuron. The delay is read from
// data.network[target][index]; the neuron's own row is skipped because
// self-referential links are not represented yet.
function addActivationsForSpike(time, index) {
  var network = data.network;
  for (var target = 0; target < network.length; target++) {
    if (target === index) {
      continue;
    }
    var delay = network[target][index];
    allActivations.push({
      sourceTime: time,
      targetTime: time + delay,
      source: index,
      target: target, // could reference the nodes instead
      delay: delay,
      type: "normal"
    });
  }
}
// Horizontal raster scale: one band per time step between startRaster and endRaster.
var xscale = d3.scale.ordinal().domain(d3.range(nt)).rangeBands([startRaster, endRaster]);
// Vertical raster scale: one evenly spaced row per neuron between y = 30 and y = 300.
var yscale = d3.scale.ordinal().domain(d3.range(N)).rangePoints([30, 300]);
//---------------------------------------------------------------
// RASTER PLOT
//---------------------------------------------------------------
renderRaster();
// Draws the static raster plot inside #raster: one column per time step and
// one row per neuron, plus the (initially hidden or collapsed) elements that
// controller() animates later. Elements are appended in the order below, so
// later groups are painted on top of earlier ones.
function renderRaster() {
var rastersvg = d3.select("#raster").append("svg")
var radius = 10;
// debugging output of the derived activity data
console.log("spikes", spikes)
console.log("activations", activations)
// background column for every time step; these rects also receive the
// mouseover handler registered in controller()
var stepRects = rastersvg.selectAll("rect.step")
.data(steps)
stepRects.enter().append("rect").classed("step", true)
stepRects.attr({
x: function(d,i) { return xscale(d.time) },
y: 10,
width: xscale.rangeBand(),
height: 400
})
// time step number printed near the bottom of each column
var stepLabels = rastersvg.selectAll("text.step")
.data(steps)
stepLabels.enter().append("text").classed("step", true)
.text(function(d) { return d.time })
.attr({
x: function(d,i) { return xscale(d.time) + 5 },
y: 405
})
// legend: a colored circle, a horizontal guide line and the neuron index
// for every neuron row
var legendCircles = rastersvg.selectAll("circle.legend")
.data(d3.range(N))
.enter()
.append("circle").classed("legend", true)
.attr({
cx: startRaster - 2 * radius,
cy: function(d) { return yscale(d) },
fill: function(d) { return color(d) },
r: radius
})
var legendLines = rastersvg.selectAll("line.legend")
.data(d3.range(N))
.enter()
.append("line").classed("legend", true)
.attr({
x1: startRaster - radius,
y1: function(d) { return yscale(d) },
x2: endRaster,
y2: function(d) { return yscale(d) },
stroke: function(d) { return color(d) },
})
var legendText = rastersvg.selectAll("text.legend")
.data(d3.range(N))
.enter()
.append("text").classed("legend", true)
.text(function(d) { return d })
.attr({
x: startRaster - 2 * radius,
y: function(d) { return yscale(d) + 20 },
fill: function(d) { return color(d) },
"alignment-baseline": "middle",
"text-anchor":"middle"
})
// full-length line for every activation that led to a spike, running from
// (source time, source row) to (spike time, target row)
var activationLines = rastersvg.selectAll("line.activation")
.data(activations)
activationLines.enter().append("line").classed("activation", true)
activationLines.attr({
x1: function(d) { return xscale(d.sourceTime)},
y1: function(d) { return yscale(d.source)},
x2: function(d) { return xscale(d.time)},
y2: function(d) { return yscale(d.target)},
})
// second line per activation, bound to the same data; it starts with zero
// length at the source and its x2/y2 are updated by timeTravel() in controller()
var activeLines = rastersvg.selectAll("line.active")
.data(activations)
activeLines.enter().append("line").classed("active", true)
activeLines.attr({
x1: function(d) { return xscale(d.sourceTime)},
y1: function(d) { return yscale(d.source)},
x2: function(d) { return xscale(d.sourceTime)},
y2: function(d) { return yscale(d.source)},
})
// travelling dot per activation; created invisible (opacity 0) at the source
var activationDots = rastersvg.selectAll("circle.activation")
.data(activations)
activationDots.enter().append("circle").classed("activation", true)
activationDots.attr({
cx: function(d) { return xscale(d.sourceTime)},
cy: function(d) { return yscale(d.source)},
r: 5,
opacity: 0
})
// one circle per spike, drawn pale here; timeTravel() in controller()
// restyles it with the neuron color
var spikeCircles = rastersvg.selectAll("circle.spike")
.data(spikes)
spikeCircles.enter().append("circle").classed("spike", true)
spikeCircles.attr({
cx: function(d) { return xscale(d.time) },
cy: function(d) { return yscale(d.index) },
r: radius,
//fill: function(d) { return color(d.index) },
fill: "#eee",
opacity: 0.1,
stroke: function(d) {
//return d.initial ? "#111": "#eee"
return "#444"
}
})
}
//---------------------------------------------------------------
// FORCE DIAGRAM
//---------------------------------------------------------------
renderForce();
// Builds the force-directed view of the network inside #force.
//
// The layout simulates the neuron nodes and, in addition, every link as an
// extra body (the "links as invisible nodes" trick): on each tick a link's
// x/y is pinned to the midpoint between its source and target, which keeps
// the graph from collapsing into straight lines. Links are drawn as polygons
// via calculatePolygon, neurons as circles, and one rect per entry of
// allActivations is created for controller() to animate.
function renderForce() {
  var force = d3.layout.force()
    .gravity(0.05)
    .charge(function(d) {
      if (d.type === "rigid") {
        return -100;
      }
      return -30;
    })
    .nodes(nodes.concat(links))
    .size([width, height]);
  force.start();

  var forcesvg = d3.select("#force").append("svg");
  // links are appended first so that the neuron circles are painted above them
  var svgLinks = forcesvg.selectAll(".link")
    .data(links)
    .enter().append("polygon").classed("link", true);
  forcesvg.selectAll("circle")
    .data(nodes)
    .enter().append("circle").classed("node", true)
    .attr("r", function(d) { return d.radius; })
    .style("fill", function(d, i) { return color(d.index); })
    .call(force.drag);
  var activeRects = forcesvg.selectAll("rect.activation")
    .data(allActivations);
  activeRects.enter().append("rect").classed("activation", true);

  force.on("tick", function(e) {
    var bodies = force.nodes();
    // re-arm the simulation on every tick so the layout never cools down
    force.alpha(0.1);

    // collision detection
    var q = d3.geom.quadtree(bodies);
    bodies.forEach(function(body) {
      q.visit(collide(body));
    });

    // Move the invisible link bodies to the midpoint of their end points.
    // The midpoint is averaged directly: deriving y through the slope
    // (dy / dx) yields NaN whenever source and target share an x coordinate.
    links.forEach(function(link) {
      link.x = (link.source.x + link.target.x) / 2;
      link.y = (link.source.y + link.target.y) / 2;
    });

    forcesvg.selectAll("circle.node")
      .attr({
        cx: function(d) { return d.x; },
        cy: function(d) { return d.y; },
        r: function(d) { return d.radius; }
      });
    svgLinks.attr("points", calculatePolygon);
  });
}
controller();
// Wires up the mouse interaction. It re-selects the elements created by
// renderRaster() and renderForce(), then:
//   - hovering a time step column pulses the neurons that spiked at that step
//     (in both views), and
//   - moving the mouse over the raster svg "scrubs" through time, updating
//     spikes, travelling activations, link widths and activation pulses.
function controller() {
var rastersvg = d3.select("#raster svg")
// the travelling dots and lines that show an activation en route to a neuron spike
var activationDots = rastersvg.selectAll("circle.activation")
var activeLines = rastersvg.selectAll("line.active")
// colored circles signifying when a neuron spikes
var spikeCircles = rastersvg.selectAll("circle.spike")
var forcesvg = d3.select("#force svg")
var neuronCircles = forcesvg.selectAll("circle.node")
var neuronLinks = forcesvg.selectAll(".link")
var activationRects = forcesvg.selectAll("rect.activation")
var stepRects = rastersvg.selectAll("rect.step")
stepRects.on("mouseover", function(d) {
//console.log("step", d);
rasterSpike(d);
forceSpike(d);
})
.on("mouseout", function(d) {
})
// Pulses the raster spike circles of the hovered step: radius 10 -> 15 -> 10.
function rasterSpike(step) {
spikeCircles.filter(function(f) { return f.time === step.time })
.transition()
.attr({
r: 15
})
.each("end", function(c) {
d3.select(this).transition()
.attr({
r: 10
})
})
}
// Pulses the force-graph circles of every neuron that spiked at the hovered
// step: radius grows to 1.5x and then returns to the node's own radius.
function forceSpike(step) {
// indices of the neurons that spiked at this step
var spiked = [];
spikes.forEach(function(d) {
if(d.time === step.time) {
spiked.push(d.index)
}
})
neuronCircles.filter(function(f) { return spiked.indexOf(f.index) >= 0 })
.transition()
.attr({
r: function(d) { return d.radius * 1.5}
})
.each("end", function(c) {
d3.select(this).transition()
.attr({
r: function(d) { return d.radius }
})
})
}
rastersvg.on("mousemove", function() {
var x = d3.mouse(this)[0];
timeTravel(x);
})
// Updates both views for the point in time under the mouse.
// x is the mouse x coordinate relative to the raster svg; positions outside
// the open interval (startRaster, endRaster) are ignored.
function timeTravel(x) {
var inRaster = x - startRaster;
var rasterWidth = endRaster - startRaster;
if(inRaster < rasterWidth && inRaster > 0) {
// continuous is the fractional time under the mouse, step the whole time step
var continuous = nt * inRaster / rasterWidth;
var step = Math.floor(continuous);
//console.log("step", step, continuous, inRaster, x)
// spikes at or before the current time get the neuron color, later ones stay pale
spikes.forEach(function(d) {
var circles = spikeCircles.filter(function(f) { return f == d})
if(continuous >= d.time) {
circles.style({
opacity: 1,
fill: function(d) { return color(d.index )}
})
} else {
circles.style({
opacity: 0.1,
fill: "#eee"
})
}
})
activations.forEach(function(d) {
var dots = activationDots.filter(function(f) { return f == d })
var lines = activeLines.filter(function(f) { return f == d })
// source x and y
var sx = xscale(d.sourceTime);
var sy = yscale(d.source);
// a number between 0 and 1, how far along we are in this time step, in continuous time
var fraction = (continuous - d.sourceTime) / (d.time - d.sourceTime);
if(fraction > 1) fraction = 1;
// target x and y
var tx = sx + (xscale(d.time) - sx) * fraction;
var ty = sy + (yscale(d.target) - sy) * fraction;
if(d.sourceTime < continuous && continuous < d.time) {
// in flight: show the dot and stretch the line to the current position
dots.attr({
cx: tx,
cy: ty
})
.style("opacity", 1)
lines.attr({
x2: tx,
y2: ty
})
} else {
// not in flight: hide the dot; the line is either complete (fraction was
// clamped to 1 above) or collapsed back onto its source
dots.style("opacity", 0)
if(d.sourceTime < continuous) {
lines.attr({
x2: tx,
y2: ty
})
} else {
lines.attr({
x2: sx,
y2: sy
})
}
}
})
// update the widths of the links based on plasticity over time
neuronLinks.attr({
points: function(d) {
var p = plasticity[step][d.source.index][d.target.index];
// stored on the link datum, so later calculatePolygon calls made from the
// force layout tick use the same sizes
d.headWidth = headWidth * p
d.linkWidth = linkWidth * p
return calculatePolygon(d)
}
})
// update the activation pulses on the force graph
allActivations.forEach(function(d) {
var rects = activationRects.filter(function(f) { return f == d })
var fraction = (continuous - d.sourceTime) / (d.targetTime - d.sourceTime);
if(fraction > 1) fraction = 1;
// the pulse size follows the plasticity at the time the source fired
var p = plasticity[d.sourceTime][d.source][d.target];
// interpolate between the current positions of the source and target nodes
var sx = nodes[d.source].x;
var sy = nodes[d.source].y;
var tx = sx + (nodes[d.target].x - sx) * fraction;
var ty = sy + (nodes[d.target].y - sy) * fraction;
if(d.sourceTime < continuous && continuous < d.targetTime) {
//console.log(d, sx, sy, tx, ty, nodes[d.target])
rects.attr({
x: tx,
y: ty,
width: headWidth * p,
height: headWidth * p
})
.style("opacity", 1)
} else {
rects.style("opacity", 0)
}
})
}
}
}
// Builds a quadtree visitor that resolves overlaps between `node` and the
// points stored in the tree.
//
// node: object with numeric x, y and radius; it is moved in place.
// Returns a function (quad, x1, y1, x2, y2) for quadtree.visit(). When the
// visited quad holds a different point whose circle overlaps node's circle,
// both are pushed apart by half of the overlap each. The visitor returns true
// when the quad's bounds lie entirely outside node's search box (radius + 16
// around the position node had when collide() was called), which tells the
// quadtree not to descend into that quad.
function collide(node) {
  var reach = node.radius + 16,
      nx1 = node.x - reach,
      nx2 = node.x + reach,
      ny1 = node.y - reach,
      ny2 = node.y + reach;
  return function(quad, x1, y1, x2, y2) {
    var other = quad.point;
    if (other && (other !== node)) {
      var x = node.x - other.x,
          y = node.y - other.y,
          dist = Math.sqrt(x * x + y * y),
          minDist = node.radius + other.radius;
      // dist > 0: two points at exactly the same position have no separation
      // direction, and dividing by a zero distance would write NaN into both
      // positions. Such pairs are left untouched.
      if (dist > 0 && dist < minDist) {
        // negative fraction of the offset vector each point has to move
        var shift = (dist - minDist) / dist * .5;
        x *= shift;
        y *= shift;
        node.x -= x;
        node.y -= y;
        other.x += x;
        other.y += y;
      }
    }
    return x1 > nx2 || x2 < nx1 || y1 > ny2 || y2 < ny1;
  };
}
// taken from http://blockbuilder.org/erkal/9746652
// Computes the "points" attribute of the arrow-shaped polygon for a link.
//
// d: link datum with source {x, y}, target {x, y, radius}, linkWidth,
//    headLength and headWidth.
// Returns five "x,y" pairs separated by spaces:
//   p1 - tip of the arrow, on the edge of the target circle
//   p2 - the source position
//   p3 - p2 shifted sideways by linkWidth
//   p4 - p3 moved towards the target, stopping headLength before the tip
//   p5 - p4 shifted sideways by a further headWidth
// The sideways direction is the source->target vector rotated by rotate90.
// Degenerate input never produces NaN: a zero linkWidth/headWidth simply
// collapses the corresponding edge, a zero-length shaft keeps p4 at p3, and
// coinciding source and target yield a polygon collapsed onto the source.
function calculatePolygon(d) {
  function length(v) {return Math.sqrt(v.x * v.x + v.y * v.y)}
  function diff(v, w) {return {x: v.x - w.x, y: v.y - w.y}}
  function sum(v, w) {return {x: v.x + w.x, y: v.y + w.y}}
  function scale(v, f) {return {x: f * v.x, y: f * v.y}}
  function rotate90(v) {return {x: v.y, y: -v.x}} // clockwise
  function pr(v) {return v.x +","+ v.y}

  var p2 = d.source,
      w = diff(d.target, p2),
      wl = length(w);
  if (!(wl > 0)) {
    // source and target coincide: there is no direction to draw along
    var origin = pr(p2);
    return origin +" "+ origin +" "+ origin +" "+ origin +" "+ origin;
  }
  var v1 = scale(w, (wl - d.target.radius) / wl),
      p1 = sum(p2, v1),
      // unit vector perpendicular to the link; scaling it directly avoids
      // dividing by the length of a possibly zero-width offset vector
      side = scale(rotate90(w), 1 / wl),
      v2 = scale(side, d.linkWidth),
      p3 = sum(p2, v2),
      v1l = length(v1),
      v3 = v1l > 0 ? scale(v1, (v1l - d.headLength) / v1l) : {x: 0, y: 0},
      p4 = sum(p3, v3),
      v4 = scale(side, d.headWidth),
      p5 = sum(p4, v4);
  return pr(p1) +" "+ pr(p2) +" "+ pr(p3) +" "+ pr(p4) +" "+ pr(p5);
}
})
</script>
{"polygroup": [{"8": [[4, 7], [2, 6], [3, 4]], "3": [[0, 2], [4, 2]]}, {"4": [[1, 2], [3, 3]]}, {"9": [[3, 6], [4, 5]], "6": [[3, 3], [4, 2]]}, {"4": [[0, 2], [1, 2]], "6": [[4, 5], [1, 4]]}, {"5": [[1, 4], [3, 4]], "7": [[3, 6], [2, 3]]}], "init": [2, 2, 3, 3, 2], "numTimeSteps": 9, "network": [[1, 3, 2, 4, 1], [1, 2, 0, 1, 0], [0, 0, 0, 3, 4], [2, 2, 4, 0, 1], [0, 1, 4, 1, 0]], "plasticity": {"0": [[0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5, 0.5]], "9": [[0.5, 0.30000000000000004, 0.5, 0.5, 0.7], [0.4, 0.6, 0.4, 0.6, 0.4], [0.30000000000000004, 0.30000000000000004, 0.30000000000000004, 0.7, 0.7], [0.5, 0.7, 0.30000000000000004, 0.30000000000000004, 0.5], [0.30000000000000004, 0.5, 0.5, 0.7, 0.30000000000000004]], "6": [[0.6, 0.4, 0.4, 0.4, 0.6], [0.4, 0.6, 0.4, 0.6, 0.4], [0.4, 0.4, 0.4, 0.6, 0.6], [0.5, 0.7, 0.30000000000000004, 0.30000000000000004, 0.5], [0.4, 0.6, 0.4, 0.6, 0.4]]}}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment