Skip to content

Instantly share code, notes, and snippets.

@chnsh
Last active July 11, 2018 13:01
Show Gist options
  • Save chnsh/c4c8d3b513b01d95519a21e81b420355 to your computer and use it in GitHub Desktop.
Save chnsh/c4c8d3b513b01d95519a21e81b420355 to your computer and use it in GitHub Desktop.
# Common graph algorithms in Python
from collections import deque
import heapq
import unittest
import warnings
class Graph:
    """Adjacency-list graph whose vertices wrap arbitrary hashable elements.

    ``adjacency_list`` maps each ``Vertex`` to a dict of
    ``neighbour Vertex -> edge attribute dict``.  ``_nodes`` maps a raw element
    to its unique ``Vertex`` so that algorithms can hang state (colour,
    parent, weight, ...) on one shared object per element.
    """

    UNDIRECTED = 0
    DIRECTED = 1

    class Vertex:
        """Wrapper around a user element.

        Equality and hashing are delegated to the wrapped element, so a
        ``Vertex`` and its raw element are interchangeable as dict keys.
        Ordering compares the ``weight`` attribute, which is only present once
        an algorithm has assigned it.
        """

        def __init__(self, element):
            self.element = element

        def __lt__(self, other):
            # ``None`` is treated as larger than any vertex.
            return other is None or self.weight < other.weight

        def __gt__(self, other):
            # Mirror of __lt__: a vertex is never greater than ``None``.
            return other is not None and self.weight > other.weight

        def __eq__(self, other):
            # Compare against another vertex's element, or against a raw
            # element (int, str, None, ...) when ``other`` is not a vertex.
            other_element = getattr(other, 'element', other)
            try:
                return self.element == other_element
            except Exception:
                # Incomparable values are simply unequal.  ``Exception`` (not
                # a bare ``except``) keeps KeyboardInterrupt/SystemExit alive.
                return False

        def __hash__(self):
            return hash(self.element)

        def __repr__(self):
            return str(self.element)

        def __str__(self):
            return self.__repr__()

    def __init__(self, graph_type=UNDIRECTED):
        """Create an empty graph of the given type (UNDIRECTED or DIRECTED)."""
        self.graph_type = graph_type
        self.adjacency_list = {}
        self._nodes = {}

    def _make_vertex(self, vertex):
        """Return the Vertex registered for ``vertex``, creating it if needed."""
        if vertex not in self._nodes:
            self._nodes[vertex] = self.Vertex(vertex)
        return self._nodes[vertex]

    def add_edge_to_graph(self, u, v, **args):
        """Add an edge from ``u`` to ``v``.

        Both endpoints are registered in the adjacency list.  Keyword
        arguments (e.g. ``weight=3``) are stored as the edge's attribute dict.
        For an undirected graph the reverse edge is added as well and shares
        the very same attribute dict object.
        """
        u = self._make_vertex(u)
        v = self._make_vertex(v)
        self.adjacency_list.setdefault(u, {})
        self.adjacency_list.setdefault(v, {})
        self.adjacency_list[u][v] = args
        if self.graph_type == Graph.UNDIRECTED:
            self.adjacency_list[v][u] = args
class GraphTraversal:
    """Base class for traversals: owns a Graph and provides printing hooks.

    Subclasses override the ``_process_*`` hooks to collect information while
    a traversal runs; the defaults only print what happened.
    """

    def __init__(self, graph_type=Graph.UNDIRECTED):
        self.graph = Graph(graph_type)

    def _process_vertex(self, vertex):
        """Hook invoked when a traversal has finished with ``vertex``."""
        message = "Processed {}!".format(vertex)
        print(message)

    def _process_edge(self, u, v):
        """Hook invoked for an edge ``u -> v`` reported by a traversal."""
        message = "Found edge {} to {}".format(u, v)
        print(message)

    def _process_vertex_early(self, v):
        """Hook invoked when a traversal first enters ``v``."""
        message = "Traversing {}".format(v)
        print(message)

    # Disabled hook kept for reference:
    # def _vertex_already_visited(self, u, v):
    #     print("{} already visited, do not need to visit from {}".format(v, u))
class BFSTraversal(GraphTraversal):
    """Breadth-first search that records colour, distance and parent on vertices."""

    def traverse(self, source):
        """Run BFS from ``source`` (a raw element or a Vertex).

        Every other vertex is reset to WHITE / infinite distance / no parent
        before the search starts.  ``_process_edge(u, v)`` is called for each
        edge whose far end is not yet BLACK, and ``_process_vertex(u)`` once
        all of ``u``'s edges have been examined.

        Raises:
            KeyError: if ``source`` is not a vertex of the graph.
        """
        graph = self.graph
        # Validate first so that an unknown source is not registered as a
        # phantom vertex in the graph's node table.
        if source not in graph.adjacency_list:
            raise KeyError("Invalid source vertex, not in graph: {}".format(source))
        if not isinstance(source, graph.Vertex):
            source = graph._make_vertex(source)

        for vertex in graph.adjacency_list:
            if vertex != source:
                vertex.color = "WHITE"
                vertex.distance = float("inf")
                vertex.parent = None
        source.color = "GRAY"
        source.distance = 0
        source.parent = None

        queue = deque([source])
        while queue:
            u = queue.popleft()
            for v in graph.adjacency_list[u]:
                if v.color != "BLACK":
                    # If it is black, u must have been visited from v when v
                    # was processed, so the edge was already reported.
                    self._process_edge(u, v)
                if v.color == "WHITE":
                    v.color = "GRAY"
                    v.distance = u.distance + 1
                    v.parent = u
                    queue.append(v)
            self._process_vertex(u)
            u.color = "BLACK"

    def shortest_path(self, source, destination, path_list=None):
        """Walk parent pointers from ``destination`` back towards ``source``.

        ``traverse(source)`` must have been run beforehand so that the parent
        pointers exist.  The result lists the elements on the path starting at
        ``destination`` and excludes ``source`` itself (e.g. ``[3, 2]`` for
        the path 1 -> 2 -> 3).  Elements are appended to ``path_list`` when one
        is supplied; when ``source == destination`` the supplied ``path_list``
        (``None`` by default) is returned unchanged.

        Raises:
            KeyError: if an endpoint is not in the graph, or if the parent
                chain ends before reaching ``source``.
        """
        source = self._resolve_vertex(source, "Source")
        destination = self._resolve_vertex(destination, "Destination")
        while source != destination:
            if path_list is None:
                path_list = []
            path_list.append(destination.element)
            # A ``None`` parent means the chain is exhausted; resolving it
            # raises KeyError, exactly like looking up an unknown destination.
            destination = self._resolve_vertex(destination.parent, "Destination")
        return path_list

    def _resolve_vertex(self, vertex, role):
        """Return the graph Vertex for ``vertex`` or raise KeyError naming ``role``."""
        if isinstance(vertex, self.graph.Vertex):
            return vertex
        if vertex not in self.graph.adjacency_list:
            raise KeyError("{} {} not in adjacency list".format(role, vertex))
        return self.graph._make_vertex(vertex)
class BiPartiteDetectionTraversal(BFSTraversal):
    """BFS that two-colours vertices ("LEFT"/"RIGHT") to test bipartiteness.

    ``is_bipartite`` starts as ``True`` and is switched to ``False`` as soon as
    an edge joins two vertices with the same label.  It is not reset between
    calls to ``traverse``.
    """

    def __init__(self):
        super().__init__()
        self.is_bipartite = True

    def traverse(self, source):
        """Label ``source`` as "LEFT" and run BFS from it.

        Raises:
            KeyError: if ``source`` is not a vertex of the graph.
        """
        if source not in self.graph.adjacency_list:
            raise KeyError("Invalid source vertex, not in graph: {}".format(source))
        # Labels live on the shared vertex objects; drop any left over from an
        # earlier run so they cannot clash with this run's colouring.
        for vertex in self.graph.adjacency_list:
            if hasattr(vertex, 'label'):
                del vertex.label
        source = self.graph._make_vertex(source)
        source.label = "LEFT"
        super().traverse(source)

    def _process_edge(self, u, v):
        """Give ``v`` the label opposite to ``u``, or flag a same-label conflict."""
        if hasattr(v, 'label') and v.label == u.label:
            print("Not a bipartite graph due to {} and {}".format(u, v))
            self.is_bipartite = False
        else:
            v.label = "RIGHT" if u.label == "LEFT" else "LEFT"
class ConnectedComponentsTraversal(BFSTraversal):
    """Counts connected components by running BFS from every unvisited vertex."""

    def components(self):
        """Return one representative Vertex per connected component.

        ``traverse`` resets the colour of every vertex other than its source,
        so vertex colours cannot be used to remember earlier components.  The
        vertices reached so far are therefore tracked in a local set.
        """
        adjacency = self.graph.adjacency_list
        components = []
        visited = set()
        for vertex in adjacency:
            if vertex in visited:
                continue
            print("Found 1 component {}".format(vertex))
            self.traverse(vertex)
            # After a traversal exactly the vertices it reached are BLACK.
            visited.update(
                reached for reached in adjacency
                if getattr(reached, 'color', None) == 'BLACK'
            )
            components.append(vertex)
        return components
class DFSTraversal(GraphTraversal):
    """Recursive depth-first search with entry/exit times and edge classification.

    Traversal state (``discovered``, ``processed``, ``parent``, ``entry_time``,
    ``exit_time``) is stored on the vertex objects and is not cleared between
    calls, so already discovered vertices are not entered again.
    """

    def __init__(self, graph_type=Graph.UNDIRECTED):
        super().__init__(graph_type)
        self.time = 0

    def traverse(self, source):
        """Depth-first traverse everything reachable from ``source``.

        Hooks are called in this order for each vertex: ``_process_vertex_early``
        on entry, ``_process_edge`` for each reported edge, ``_process_vertex``
        on exit.

        Raises:
            KeyError: if ``source`` is not a vertex of the graph.
        """
        graph = self.graph
        # Validate before creating a Vertex so unknown sources leave no trace.
        if source not in graph.adjacency_list:
            raise KeyError("Invalid source vertex, not in graph: {}".format(source))
        if not isinstance(source, graph.Vertex):
            source = graph._make_vertex(source)

        source.entry_time = self.time
        self.time += 1
        source.discovered = True
        if not hasattr(source, 'parent'):
            source.parent = None
        self._process_vertex_early(source)

        directed = graph.graph_type == Graph.DIRECTED
        for v in graph.adjacency_list[source]:
            if not getattr(v, 'discovered', False):
                v.parent = source
                self._process_edge(source, v)
                self.traverse(v)
            elif directed or (not getattr(v, 'processed', False) and source.parent != v):
                # Undirected graphs store each edge twice: skip the copy that
                # leads back to the parent or to an already finished vertex.
                self._process_edge(source, v)

        self._process_vertex(source)
        source.exit_time = self.time
        self.time += 1
        source.processed = True

    def _classify_edge(self, x, y):
        """Classify edge ``x -> y`` as 'TREE', 'BACK', 'FORWARD' or 'CROSS'.

        Emits a warning and returns ``None`` when no class applies.
        """
        if x == y.parent:
            return 'TREE'
        discovered = getattr(y, 'discovered', False)
        processed = getattr(y, 'processed', False)
        directed = self.graph.graph_type == Graph.DIRECTED
        # In an undirected graph the edge back to the parent is just the tree
        # edge seen from the other side.  In a directed graph x -> parent(x)
        # is a genuine back edge (a 2-cycle) and must be reported as such.
        if discovered and not processed and (directed or x.parent != y):
            return 'BACK'
        if processed and y.entry_time > x.entry_time:
            return 'FORWARD'
        if processed and y.entry_time < x.entry_time:
            return 'CROSS'
        warnings.warn("Unclassified edge {}, {}".format(x, y))
        return None

    def _process_edge(self, u, v):
        print("Found {} edge {} to {}".format(self._classify_edge(u, v), u, v))
class CycleDetectionTraversal(DFSTraversal):
    """DFS that sets ``cycle_exists`` once a back edge has been seen."""

    def __init__(self):
        super().__init__()
        self.cycle_exists = False

    def _process_edge(self, u, v):
        """Record a cycle when the edge ``u -> v`` is classified as 'BACK'."""
        edge_kind = self._classify_edge(u, v)
        if edge_kind != 'BACK':
            return
        self.cycle_exists = True
class ArticulationVertexTraversal(DFSTraversal):
    """DFS that flags cut vertices by setting ``articulation_vertex = True``.

    The attribute is only ever set to ``True``; vertices that are not flagged
    do not receive it.  Each vertex tracks ``out_degree`` (number of tree
    edges leaving it) and ``oldest_ancestor`` (the earliest-entered vertex
    known to be reachable from its subtree).
    """

    def _process_edge(self, u, v):
        """Count tree edges; for any other edge try to improve u's oldest ancestor."""
        if self._classify_edge(u, v) == 'TREE':
            u.out_degree += 1
            return
        if v.entry_time < u.oldest_ancestor.entry_time:
            u.oldest_ancestor = v

    def _process_vertex(self, v):
        """Apply the cut-vertex rules once ``v``'s subtree has been explored."""
        parent = v.parent
        if parent is None:
            # Root cut node: the root is a cut vertex when it has more than
            # one tree child.
            if v.out_degree > 1:
                v.articulation_vertex = True
                print("Found {} as root cut node!".format(v))
            return

        # Parent cut node: v's subtree reaches no further up than the parent.
        # Skipped when the parent is the root (its own parent is None).
        if parent.parent is not None and v.oldest_ancestor == parent:
            print("Found {} as parent cut node!".format(parent))
            parent.articulation_vertex = True

        # Bridge: nothing in v's subtree reaches above v itself, so the edge
        # parent-v is a bridge.  The parent is flagged (this branch does not
        # check whether the parent is the root), and v too if it is not a leaf.
        if v.oldest_ancestor == v:
            print("Found {} bridge cut node parent!".format(parent))
            parent.articulation_vertex = True
            if v.out_degree > 0:
                print("Found {} bridge cut node!".format(v))
                v.articulation_vertex = True

        # Whatever the child can reach, the parent can reach through it.
        # Nothing in _process_edge propagates this upwards.
        if v.oldest_ancestor.entry_time < parent.oldest_ancestor.entry_time:
            parent.oldest_ancestor = v.oldest_ancestor

    def _process_vertex_early(self, v):
        """Initialise the per-vertex bookkeeping when ``v`` is first entered."""
        v.oldest_ancestor = v
        v.out_degree = 0
class TopologicalSortTraversal(DFSTraversal):
    """Topological sort of a directed graph via DFS finishing order."""

    def __init__(self):
        super().__init__(graph_type=Graph.DIRECTED)
        self.stack = []

    def disconnected_traversal(self):
        """Run DFS from every undiscovered vertex and return a topological order.

        Returns:
            The vertex elements in reverse finishing order.

        Raises:
            AssertionError: if the graph contains a cycle.
        """
        for v in self.graph.adjacency_list:
            if not getattr(v, 'discovered', False):
                self.traverse(v)
        # Elements (not Vertex objects) are returned merely to ease testing.
        return [v.element for v in reversed(self.stack)]

    def _process_edge(self, u, v):
        """Report the edge, then reject it if it closes a cycle."""
        super()._process_edge(u, v)
        # In a directed DFS an edge into a vertex that has been entered but
        # not yet finished points at an ancestor on the recursion stack.
        on_stack = getattr(v, 'discovered', False) and not getattr(v, 'processed', False)
        if on_stack:
            # Raised explicitly (rather than via ``assert``) so the check is
            # not stripped when Python runs with -O.
            raise AssertionError(
                "Graph has a cycle through edge {} -> {}".format(u, v))

    def _process_vertex(self, v):
        """Push ``v`` once all of its descendants have been finished."""
        self.stack.append(v)
class KahnTopologicalSort:
    """Topological sort of a directed graph using in-degree counting."""

    def __init__(self):
        self.graph = Graph(graph_type=Graph.DIRECTED)

    def topological_sort(self):
        """Return the vertex elements in a topological order.

        The idea is to pick a vertex with in-degree 0, append it to the order
        and reduce the in-degree of each of its neighbours by 1, repeating
        until no such vertex remains.  In-degrees are kept in a local map, so
        the method can be called repeatedly (also after the graph changed)
        without being affected by counts left over from a previous run.

        Returns:
            A list of elements, or ``None`` (after printing a message) when
            the graph has a cycle and no topological order exists.
        """
        adjacency = self.graph.adjacency_list
        in_degree = {vertex: 0 for vertex in adjacency}
        for neighbours in adjacency.values():
            for neighbour in neighbours:
                in_degree[neighbour] = in_degree.get(neighbour, 0) + 1

        stack = [vertex for vertex in adjacency if not in_degree[vertex]]
        topological_order = []
        while stack:
            vertex = stack.pop()
            topological_order.append(vertex)
            for neighbour in adjacency[vertex]:
                in_degree[neighbour] -= 1
                if not in_degree[neighbour]:
                    stack.append(neighbour)

        if len(topological_order) != len(adjacency):
            print("Topological order not possible")
            return None
        # Elements (not Vertex objects) are returned merely to ease testing.
        return [vertex.element for vertex in topological_order]
class StronglyConnectedComponentsTraversal(DFSTraversal):
    """DFS over a directed graph that numbers its strongly connected components.

    Entered vertices are kept on ``self.stack``.  Each vertex records the
    earliest-entered vertex it is known to reach (``oldest_ancestor``); a
    vertex whose oldest ancestor is itself closes a component, and everything
    above it on the stack is popped and given the next ``scc_number``.
    """

    def __init__(self):
        super().__init__(graph_type=Graph.DIRECTED)
        self.stack = list()
        self.scc_number = 0

    def strongly_connected_components(self):
        """Traverse from every vertex that has not been discovered yet."""
        for vertex in self.graph.adjacency_list:
            if getattr(vertex, 'discovered', False):
                continue
            self.traverse(vertex)

    def _process_vertex_early(self, v):
        """Initialise bookkeeping for ``v`` and push it onto the stack."""
        super()._process_vertex_early(v)
        v.oldest_ancestor = v
        v.strongly_connected_component = None
        self.stack.append(v)

    def _process_edge(self, x, y):
        """Let back edges, and cross edges into unassigned vertices, lower x's ancestor."""
        super()._process_edge(x, y)
        kind = self._classify_edge(x, y)
        is_back = kind == 'BACK'
        is_open_cross = kind == 'CROSS' and not y.strongly_connected_component
        if (is_back or is_open_cross) and y.entry_time < x.oldest_ancestor.entry_time:
            x.oldest_ancestor = y

    def _process_vertex(self, v):
        """Close a component at ``v`` or hand its oldest ancestor to the parent."""
        super()._process_vertex(v)
        oldest = v.oldest_ancestor
        if oldest == v:
            self._pop_components(v)
            return
        parent = v.parent
        if oldest.entry_time < parent.oldest_ancestor.entry_time:
            parent.oldest_ancestor = oldest

    def _pop_components(self, v):
        """Pop the stack down to and including ``v``, numbering the popped vertices."""
        print("Component found for {}".format(v))
        self.scc_number += 1
        member = self.stack.pop()
        member.strongly_connected_component = self.scc_number
        while member != v:
            member = self.stack.pop()
            member.strongly_connected_component = self.scc_number
class PrimMST(GraphTraversal):
    """Prim's minimum spanning tree using a heap with lazy deletion.

    After ``traverse`` each reached vertex carries ``parent`` (its tree
    neighbour) and ``weight`` (the weight of the edge to that parent), and
    ``total_weight`` holds the sum of the chosen edge weights.
    """

    def __init__(self):
        super().__init__()
        self.heap = []
        self.entries = {}
        self.total_weight = 0
        self._sequence = 0

    def traverse(self, v):
        """Grow a minimum spanning tree from ``v`` (CLRS-style priority queue).

        Every edge must carry a ``weight`` attribute.  All state is reset at
        the start, so repeated calls do not accumulate ``total_weight``.
        Vertices that cannot be reached from ``v`` are left with infinite
        weight and no parent and do not contribute to ``total_weight``.

        Raises:
            KeyError: if ``v`` is not a vertex of the graph.
        """
        adjacency = self.graph.adjacency_list
        if v not in adjacency:
            raise KeyError("Invalid source vertex, not in graph: {}".format(v))
        v = self.graph._make_vertex(v)

        self.heap = []
        self.entries = {}
        self.total_weight = 0
        self._sequence = 0
        for vertex in adjacency:
            vertex.parent = None
            vertex.weight = float("inf")
            vertex.in_heap = True
            self._add_vertex(vertex)
        v.weight = 0
        self._add_vertex(v)

        while self.heap:
            try:
                min_vertex = self._pop_vertex()
            except KeyError:
                # Only invalidated entries were left: the queue is empty now.
                return
            if min_vertex.weight == float("inf"):
                # The cheapest remaining vertex was never relaxed, so nothing
                # that is left can be reached from the source.
                return
            min_vertex.in_heap = False
            self.total_weight += min_vertex.weight
            for neighbour, attrs in adjacency[min_vertex].items():
                weight = attrs['weight']
                if neighbour.in_heap and weight < neighbour.weight:
                    neighbour.parent = min_vertex
                    neighbour.weight = weight
                    self._add_vertex(neighbour)

    def _add_vertex(self, vertex):
        """Push ``vertex`` with its current weight, replacing any older entry."""
        if vertex in self.entries:
            self._remove_entry(vertex)
        # The sequence number breaks ties between equal weights so that heap
        # comparisons never fall through to the vertex (or a None marker).
        self._sequence += 1
        entry = [vertex.weight, self._sequence, vertex]
        self.entries[vertex] = entry
        heapq.heappush(self.heap, entry)

    def _remove_entry(self, vertex):
        """Invalidate ``vertex``'s heap entry in place; it is skipped when popped."""
        entry = self.entries.pop(vertex)
        entry[-1] = None

    def _pop_vertex(self):
        """Pop and return the live vertex with the smallest weight.

        Raises:
            KeyError: if the heap holds no live entry.
        """
        while self.heap:
            _, _, vertex = heapq.heappop(self.heap)
            if vertex is not None:
                del self.entries[vertex]
                return vertex
        raise KeyError("pop from an empty priority queue")
class GraphTest(unittest.TestCase):
    """Unit tests for the graph algorithms defined in this file."""

    @staticmethod
    def _add_edges(graph, edges):
        """Add each ``(u, v)`` pair of ``edges`` to ``graph`` in order."""
        for u, v in edges:
            graph.add_edge_to_graph(u, v)

    def _assert_precedes(self, order, pairs):
        """Check that ``first`` appears before ``second`` for every pair."""
        for first, second in pairs:
            self.assertTrue(order.index(first) < order.index(second))

    def setUp(self):
        self.bfs = BFSTraversal()
        # Todo refactor tests to merely have graph here and not bfs
        self._add_edges(
            self.bfs.graph,
            [(1, 5), (1, 2), (2, 5), (2, 4), (2, 3), (3, 4), (4, 5)],
        )

    def test_shortest_path(self):
        self.bfs.traverse(1)
        expectations = [
            ((1, 3), [3, 2]),
            ((2, 3), [3]),
            ((1, 4), [4, 5]),
        ]
        for (source, destination), expected in expectations:
            self.assertEqual(self.bfs.shortest_path(source, destination), expected)

    def test_is_biPartite(self):
        odd_cycle = BiPartiteDetectionTraversal()
        odd_cycle.graph = self.bfs.graph
        odd_cycle.traverse(1)
        self.assertFalse(odd_cycle.is_bipartite)

        # Same graph as in setUp but without the edges (2, 5) and (2, 3).
        even_cycle = BiPartiteDetectionTraversal()
        self._add_edges(
            even_cycle.graph,
            [(1, 5), (1, 2), (2, 4), (3, 4), (4, 5)],
        )
        even_cycle.traverse(1)
        self.assertTrue(even_cycle.is_bipartite)

    def test_connected_components(self):
        traversal = ConnectedComponentsTraversal()
        traversal.graph = self.bfs.graph
        found = traversal.components()
        self.assertEqual(len(found), 1)

    def test_disconnected_components(self):
        traversal = ConnectedComponentsTraversal()
        traversal.graph = self.bfs.graph
        traversal.graph.add_edge_to_graph(7, 9)
        found = traversal.components()
        self.assertEqual(2, len(found))

    def test_dfs(self):
        # Smoke test: only checks that the traversal runs without raising.
        dfs = DFSTraversal()
        dfs.graph = self.bfs.graph
        dfs.traverse(1)

    def test_articulation_vertex_traversal(self):
        av = ArticulationVertexTraversal()
        g = Graph()
        self._add_edges(g, [(1, 5), (1, 2), (2, 5), (2, 4), (2, 3), (4, 5)])
        av.graph = g
        av.traverse(1)
        self.assertTrue(g._make_vertex(2).articulation_vertex)

    def test_articulation_vertex_traversal_2(self):
        av = ArticulationVertexTraversal()
        g = Graph()
        self._add_edges(g, [(1, 2)])
        av.graph = g
        av.traverse(1)
        self.assertTrue(g._make_vertex(1).articulation_vertex)

    def test_cycle(self):
        detector = CycleDetectionTraversal()
        detector.graph = self.bfs.graph
        detector.traverse(1)
        self.assertTrue(detector.cycle_exists)

    def test_topological_sort(self):
        t = TopologicalSortTraversal()
        self._add_edges(
            t.graph,
            [(1, 2), (5, 1), (5, 0), (2, 3), (4, 0), (4, 3)],
        )
        order = t.disconnected_traversal()
        self._assert_precedes(order, [(5, 0), (4, 0), (4, 3), (2, 3)])

    def test_kahn_topological_sort(self):
        k = KahnTopologicalSort()
        self._add_edges(
            k.graph,
            [(1, 2), (5, 1), (5, 0), (2, 3), (4, 0), (4, 3)],
        )
        order = k.topological_sort()
        self._assert_precedes(order, [(5, 0), (4, 0), (4, 3), (2, 3)])

    def test_kahn_topological_sort_on_non_dag(self):
        k = KahnTopologicalSort()
        self._add_edges(k.graph, [(1, 2), (2, 3), (3, 1)])
        order = k.topological_sort()
        self.assertFalse(order)

    def test_dfs_topological_sort_on_non_dag(self):
        t = TopologicalSortTraversal()
        self._add_edges(t.graph, [(1, 2), (2, 3), (3, 1)])
        self.assertRaises(AssertionError, t.disconnected_traversal)

    def test_strongly_connected_components(self):
        single = StronglyConnectedComponentsTraversal()
        self._add_edges(
            single.graph,
            [(1, 2), (2, 3), (3, 4), (4, 1), (1, 5), (5, 4)],
        )
        single.strongly_connected_components()
        self.assertEqual(1, single.scc_number)

        several = StronglyConnectedComponentsTraversal()
        self._add_edges(
            several.graph,
            [
                (1, 2), (2, 3), (2, 4), (3, 1), (4, 1), (2, 5),
                (4, 6), (4, 8), (8, 6), (6, 7), (7, 5), (5, 6),
            ],
        )
        several.strongly_connected_components()
        self.assertEqual(3, several.scc_number)

    def test_prim_mst(self):
        p = PrimMST()
        weighted_edges = [
            (1, 2, 5), (2, 3, 7), (3, 4, 5), (4, 5, 2),
            (5, 3, 2), (3, 7, 4), (7, 5, 3), (5, 6, 7),
            (6, 7, 4), (6, 1, 12), (1, 7, 7), (2, 7, 9),
        ]
        for u, v, w in weighted_edges:
            p.graph.add_edge_to_graph(u, v, weight=w)
        p.traverse(1)
        self.assertEqual(23, p.total_weight)
if __name__ == "__main__":
    # Run the test suite when the file is executed as a script.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment