Last active
July 11, 2018 13:01
-
-
Save chnsh/c4c8d3b513b01d95519a21e81b420355 to your computer and use it in GitHub Desktop.
Common Graph algorithms in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import deque | |
import heapq | |
import unittest | |
import warnings | |
class Graph:
    """Adjacency-list graph whose vertices wrap arbitrary hashable elements.

    ``adjacency_list`` maps every ``Vertex`` to a dict of
    ``neighbour Vertex -> edge attribute dict``.  ``_nodes`` maps the raw
    element to its unique ``Vertex`` wrapper so that per-vertex attributes
    set by one lookup are visible through every other lookup.
    """

    UNDIRECTED = 0
    DIRECTED = 1

    class Vertex:
        """Wrapper around a user-supplied element.

        A vertex hashes and compares equal to its raw element, so either a
        ``Vertex`` or the raw element can be used to index the graph's
        dictionaries.
        """

        def __init__(self, element):
            self.element = element

        def __lt__(self, other):
            # ``weight`` is not set here; it is attached by algorithms that
            # order vertices (PrimMST does so before using the heap).
            # Any vertex sorts before ``None``.
            return other is None or self.weight < other.weight

        def __gt__(self, other):
            return other is not None and self.weight > other.weight

        def __eq__(self, other):
            # First try vertex-to-vertex comparison.
            try:
                return self.element == other.element
            except AttributeError:
                pass
            # ``other`` is a raw element (e.g. an int) or None.
            try:
                return self.element == other
            except Exception:
                # Comparison itself failed: treat as "not equal" rather than
                # propagating, but never swallow KeyboardInterrupt/SystemExit.
                return False

        def __hash__(self):
            # Must agree with __eq__: equal to the raw element => same hash.
            return hash(self.element)

        def __repr__(self):
            return str(self.element)

        def __str__(self):
            return self.__repr__()

    def __init__(self, graph_type=UNDIRECTED):
        """Create an empty graph.

        :param graph_type: ``Graph.UNDIRECTED`` (default) or ``Graph.DIRECTED``.
        """
        self.graph_type = graph_type
        self.adjacency_list = dict()
        self._nodes = dict()

    def _make_vertex(self, vertex):
        """Return the unique ``Vertex`` for ``vertex``, creating it if needed.

        Note that a newly created vertex is only registered in ``_nodes``;
        it is not added to ``adjacency_list``.
        """
        if vertex not in self._nodes:
            self._nodes[vertex] = self.Vertex(vertex)
        return self._nodes[vertex]

    def add_edge_to_graph(self, u, v, **args):
        """Add an edge from ``u`` to ``v``.

        :param u: source element (or vertex).
        :param v: target element (or vertex).
        :param args: arbitrary edge attributes (e.g. ``weight=5``), stored as
            the edge's attribute dict.

        Both endpoints always get an adjacency entry.  For an undirected
        graph the reverse edge is stored too and shares the *same* attribute
        dict object as the forward edge.
        """
        u = self._make_vertex(u)
        v = self._make_vertex(v)
        if u not in self.adjacency_list:
            self.adjacency_list[u] = dict()
        if v not in self.adjacency_list:
            self.adjacency_list[v] = dict()
        self.adjacency_list[u][v] = args
        if self.graph_type == Graph.UNDIRECTED:
            self.adjacency_list[v][u] = args
class GraphTraversal:
    """Base class for traversals: owns a ``Graph`` and the visitor hooks.

    Subclasses override the ``_process_*`` hooks; the defaults only print.
    """

    def __init__(self, graph_type=Graph.UNDIRECTED):
        self.graph = Graph(graph_type)

    def _process_vertex(self, vertex):
        """Hook: called with a vertex; the default prints a message."""
        message = "Processed {}!".format(vertex)
        print(message)

    def _process_edge(self, u, v):
        """Hook: called with the two endpoints of an edge."""
        message = "Found edge {} to {}".format(u, v)
        print(message)

    def _process_vertex_early(self, v):
        """Hook: called with a vertex; the default prints a message."""
        message = "Traversing {}".format(v)
        print(message)

    # def _vertex_already_visited(self, u, v):
    #     print("{} already visited, do not need to visit from {}".format(v, u))
class BFSTraversal(GraphTraversal):
    """Breadth-first search that records colour, distance and parent."""

    def traverse(self, source):
        """Run BFS from ``source``.

        Every vertex in the graph gets ``color`` ("WHITE", "GRAY" or
        "BLACK"), ``distance`` (number of edges from the source, ``inf`` if
        not reached) and ``parent`` (``None`` for the source and for
        unreached vertices).

        :param source: raw element or ``Vertex`` to start from.
        :raises KeyError: if ``source`` is not a vertex of the graph.
        """
        # Validate *before* asking the graph for a vertex object so that an
        # unknown source does not get registered as a side effect.
        if source not in self.graph.adjacency_list:
            raise KeyError("Invalid source vertex, not in graph: {}".format(source))
        source = self.graph._make_vertex(source)

        for vertex in self.graph.adjacency_list:
            if vertex != source:
                vertex.color = "WHITE"
                vertex.distance = float("inf")
                vertex.parent = None
        source.color = "GRAY"
        source.distance = 0
        source.parent = None

        q = deque()
        q.append(source)
        while q:
            u = q.popleft()
            for v in self.graph.adjacency_list[u]:
                if v.color != "BLACK":
                    # If it is black, u must have been visited from v when
                    # v was processed, so the edge was already reported.
                    self._process_edge(u, v)
                if v.color == "WHITE":
                    v.color = "GRAY"
                    v.distance = u.distance + 1
                    v.parent = u
                    q.append(v)
            self._process_vertex(u)
            u.color = "BLACK"

    def shortest_path(self, source, destination, path_list=None):
        """Walk the BFS parent pointers from ``destination`` back to ``source``.

        ``traverse`` must have been run first so that ``parent`` is set.

        :param source: raw element or ``Vertex`` the walk stops at.
        :param destination: raw element or ``Vertex`` the walk starts from.
        :param path_list: optional list to append to; created when needed.
        :return: the elements visited, starting with ``destination`` and
            excluding ``source`` (e.g. ``[3, 2]`` for the path 1 -> 2 -> 3).
            When ``source == destination`` the given ``path_list`` is
            returned untouched (``None`` if none was given).
        :raises KeyError: if either endpoint is not in the graph, or the
            parent chain ends before ``source`` is reached.
        """
        if not isinstance(source, self.graph.Vertex):
            if source not in self.graph.adjacency_list:
                raise KeyError("Source {} not in adjacency list".format(source))
            source = self.graph._make_vertex(source)
        if not isinstance(destination, self.graph.Vertex):
            if destination not in self.graph.adjacency_list:
                raise KeyError("Destination {} not in adjacency list".format(destination))
            destination = self.graph._make_vertex(destination)

        # Iterative walk: no recursion-depth limit on long paths.
        current = destination
        while not source == current:
            if path_list is None:
                path_list = []
            path_list.append(current.element)
            current = current.parent
            if current is None:
                # Ran off the top of the BFS tree without meeting source.
                raise KeyError(
                    "No path from {} to {} in the last traversal".format(source, destination)
                )
        return path_list
class BiPartiteDetectionTraversal(BFSTraversal):
    """BFS that two-colours vertices ("LEFT"/"RIGHT") to test bipartiteness.

    ``is_bipartite`` starts as ``True`` and is cleared as soon as an edge is
    found whose two endpoints carry the same label.
    """

    def __init__(self):
        super().__init__()
        self.is_bipartite = True

    def traverse(self, source):
        """Label the start vertex "LEFT" and run the inherited BFS."""
        start = self.graph._make_vertex(source)
        start.label = "LEFT"
        super().traverse(start)

    def _process_edge(self, u, v):
        """Check the edge for a label clash, otherwise label ``v``."""
        clash = hasattr(v, 'label') and v.label == u.label
        if clash:
            print("Not a bipartite graph due to {} and {}".format(u, v))
            self.is_bipartite = False
            return
        # v gets the side opposite to u.
        if u.label == "LEFT":
            v.label = "RIGHT"
        else:
            v.label = "LEFT"
class ConnectedComponentsTraversal(BFSTraversal):
    """Find connected components by running BFS from every unvisited vertex."""

    def components(self):
        """Return one representative vertex per connected component.

        :return: list containing the vertex from which each component was
            first reached, in adjacency-list iteration order.
        """
        components = []
        # Track visited vertices locally.  The inherited ``traverse`` resets
        # every other vertex to WHITE on each call, so vertex colours cannot
        # be used to remember earlier components (and colours left over from
        # a previous traversal must not hide vertices either).
        visited = set()
        for vertex in self.graph.adjacency_list:
            if vertex in visited:
                continue
            print("Found 1 component {}".format(vertex))
            self.traverse(vertex)
            components.append(vertex)
            # After traverse(), exactly the vertices reached from ``vertex``
            # are BLACK.
            visited.update(
                reached
                for reached in self.graph.adjacency_list
                if reached.color == "BLACK"
            )
        return components
class DFSTraversal(GraphTraversal):
    """Recursive depth-first search with entry/exit times and edge classes.

    Per-vertex state is stored on the vertex objects: ``entry_time``,
    ``exit_time``, ``discovered``, ``processed`` and ``parent``.
    """

    def __init__(self, graph_type=Graph.UNDIRECTED):
        super().__init__(graph_type)
        self.time = 0

    def traverse(self, source):
        """Run DFS from ``source``.

        Hooks are called in this order for each vertex:
        ``_process_vertex_early``, then ``_process_edge`` for its edges
        (recursing into undiscovered neighbours), then ``_process_vertex``.

        :param source: raw element or ``Vertex`` to start from.
        :raises KeyError: if ``source`` is not a vertex of the graph.
        """
        # Validate first so an unknown source is not registered in the graph
        # as a side effect of the lookup.
        if source not in self.graph.adjacency_list:
            raise KeyError("Invalid source vertex, not in graph: {}".format(source))
        source = self.graph._make_vertex(source)

        source.entry_time = self.time
        self.time += 1
        source.discovered = True
        # Keep a parent assigned by the caller (the recursive step below).
        source.parent = getattr(source, 'parent', None)
        self._process_vertex_early(source)

        directed = self.graph.graph_type == Graph.DIRECTED
        for v in self.graph.adjacency_list[source]:
            if not getattr(v, 'discovered', False):
                v.parent = source
                self._process_edge(source, v)
                self.traverse(v)
            elif directed or (not getattr(v, 'processed', False) and source.parent != v):
                # Directed: every remaining edge is reported.
                # Undirected: skip the edge back to the parent and edges to
                # finished vertices, which were already seen from the other
                # end.
                self._process_edge(source, v)

        self._process_vertex(source)
        source.exit_time = self.time
        self.time += 1
        source.processed = True

    def _classify_edge(self, x, y):
        """Classify edge ``x -> y`` as 'TREE', 'BACK', 'FORWARD' or 'CROSS'.

        :return: the class name, or ``None`` (after issuing a warning) when
            the edge matches none of the rules.
        """
        if x == y.parent:
            return 'TREE'

        discovered = getattr(y, 'discovered', False)
        processed = getattr(y, 'processed', False)
        # In an undirected graph the edge child -> parent is just the tree
        # edge seen from the other side.  In a directed graph it is a
        # distinct edge that closes a 2-cycle and must count as BACK.
        is_distinct_edge = self.graph.graph_type == Graph.DIRECTED or x.parent != y
        if discovered and not processed and is_distinct_edge:
            return 'BACK'
        if processed and y.entry_time > x.entry_time:
            return 'FORWARD'
        if processed and y.entry_time < x.entry_time:
            return 'CROSS'
        warnings.warn("Unclassified edge {}, {}".format(x, y))
        return None

    def _process_edge(self, u, v):
        """Hook: print the edge together with its classification."""
        print("Found {} edge {} to {}".format(self._classify_edge(u, v), u, v))
class CycleDetectionTraversal(DFSTraversal):
    """DFS that sets ``cycle_exists`` once any back edge is encountered."""

    def __init__(self):
        super().__init__()
        self.cycle_exists = False

    def _process_edge(self, u, v):
        """Flag a cycle when the edge ``u -> v`` is classified as 'BACK'."""
        edge_kind = self._classify_edge(u, v)
        if edge_kind != 'BACK':
            return
        self.cycle_exists = True
class ArticulationVertexTraversal(DFSTraversal):
    """DFS that marks cut vertices by setting ``articulation_vertex = True``.

    Each vertex tracks ``out_degree`` (number of tree edges leaving it) and
    ``oldest_ancestor`` (the earliest-entered vertex recorded for it so far).
    The ``articulation_vertex`` attribute is only ever set to ``True``; it is
    not created on vertices that are never flagged.
    """

    def _process_edge(self, u, v):
        """Count tree edges; any other edge may improve u's oldest ancestor."""
        if self._classify_edge(u, v) == 'TREE':
            u.out_degree += 1
            return
        if v.entry_time < u.oldest_ancestor.entry_time:
            u.oldest_ancestor = v

    def _process_vertex(self, v):
        """Apply the root / parent / bridge cut-node tests on leaving ``v``."""
        parent = v.parent

        if parent is None:
            # Root cut node: the DFS root is a cut vertex when it has more
            # than one tree child.
            if v.out_degree > 1:
                v.articulation_vertex = True
                print("Found {} as root cut node!".format(v))
            return

        if parent.parent is not None and v.oldest_ancestor == parent:
            # Parent cut node.  The root is excluded here: with a single
            # child whose oldest reachable vertex is the root, removing the
            # root disconnects nothing; the multi-child case is the root
            # test above.
            print("Found {} as parent cut node!".format(parent))
            parent.articulation_vertex = True

        if v.oldest_ancestor == v:
            # Bridge: v is only reachable through its parent, so the parent
            # is flagged, and v itself is flagged unless it is a leaf.
            print("Found {} bridge cut node parent!".format(parent))
            parent.articulation_vertex = True
            if v.out_degree > 0:
                print("Found {} bridge cut node!".format(v))
                v.articulation_vertex = True

        if v.oldest_ancestor.entry_time < parent.oldest_ancestor.entry_time:
            # Whatever the child can reach, the parent can reach through the
            # child.  This propagation is what makes the parent-cut-node test
            # work; _process_edge alone never produces it.
            parent.oldest_ancestor = v.oldest_ancestor

    def _process_vertex_early(self, v):
        """Initialise the per-vertex bookkeeping when ``v`` is entered."""
        v.out_degree = 0
        v.oldest_ancestor = v
class TopologicalSortTraversal(DFSTraversal):
    """DFS-based topological sort of a directed graph."""

    def __init__(self):
        super().__init__(graph_type=Graph.DIRECTED)
        self.stack = list()

    def disconnected_traversal(self):
        """Run DFS from every undiscovered vertex and return the ordering.

        :return: list of vertex elements in reverse order of completion.
        :raises AssertionError: if a back edge (i.e. a cycle) is found.
        """
        for v in self.graph.adjacency_list:
            if not getattr(v, 'discovered', False):
                self.traverse(v)
        # Elements rather than Vertex objects: this is merely to run the test.
        return [vertex.element for vertex in reversed(self.stack)]

    def _process_edge(self, u, v):
        """Report the edge and reject back edges."""
        super()._process_edge(u, v)
        edge = self._classify_edge(u, v)
        # Raised explicitly instead of using ``assert`` so that the check
        # still runs under ``python -O``; the exception type is unchanged.
        if edge == 'BACK':
            raise AssertionError(
                "Back edge {} to {}: graph is not a DAG".format(u, v)
            )

    def _process_vertex(self, v):
        """Record ``v`` once all of its descendants are finished."""
        self.stack.append(v)
class KahnTopologicalSort:
    """Topological sort of a directed graph using Kahn's algorithm."""

    def __init__(self):
        self.graph = Graph(graph_type=Graph.DIRECTED)

    def topological_sort(self):
        """Return the vertex elements in a topological order.

        The idea is that we pick a node with in-degree 0 and add it to the
        topological order, then reduce each of its neighbours' in-degree by 1.

        The in-degree is stored on each vertex as ``in_count`` and is
        recomputed from scratch on every call, so repeated calls (including
        calls after a failed run on a cyclic graph) give consistent results.

        :return: list of elements, or ``None`` (after printing a message)
            when the graph contains a cycle.
        """
        adjacency = self.graph.adjacency_list

        # Reset first: counts left over from an earlier call must not leak
        # into this one.
        for vertex, neighbours in adjacency.items():
            vertex.in_count = 0
            for neighbour in neighbours:
                neighbour.in_count = 0
        for neighbours in adjacency.values():
            for neighbour in neighbours:
                neighbour.in_count += 1

        stack = [vertex for vertex in adjacency if not vertex.in_count]
        topological_order = list()
        while stack:
            vertex = stack.pop()
            topological_order.append(vertex)
            for neighbour in adjacency[vertex]:
                neighbour.in_count -= 1
                if not neighbour.in_count:
                    stack.append(neighbour)

        if len(topological_order) != len(adjacency):
            # Some vertex never reached in-degree 0: there is a cycle.
            print("Topological order not possible")
            return None
        # Elements rather than Vertex objects: this is merely to run the test.
        return [vertex.element for vertex in topological_order]
class StronglyConnectedComponentsTraversal(DFSTraversal):
    """DFS over a directed graph that numbers strongly connected components.

    Each vertex ends up with ``strongly_connected_component`` set to the
    number of its component; ``scc_number`` holds the count of components
    found so far.
    """

    def __init__(self):
        super().__init__(graph_type=Graph.DIRECTED)
        self.stack = list()
        self.scc_number = 0

    def strongly_connected_components(self):
        """Start a DFS from every vertex that has not been discovered yet."""
        for vertex in self.graph.adjacency_list:
            already_seen = hasattr(vertex, 'discovered') and vertex.discovered
            if not already_seen:
                self.traverse(vertex)

    def _process_vertex_early(self, v):
        """Initialise bookkeeping for ``v`` and push it on the stack."""
        super()._process_vertex_early(v)
        v.oldest_ancestor = v
        v.strongly_connected_component = None
        self.stack.append(v)

    def _process_edge(self, x, y):
        """Let back edges, and cross edges into unassigned vertices, lower
        ``x.oldest_ancestor``."""
        super()._process_edge(x, y)
        kind = self._classify_edge(x, y)
        relevant = kind == 'BACK' or (
            kind == 'CROSS' and not y.strongly_connected_component
        )
        if relevant and y.entry_time < x.oldest_ancestor.entry_time:
            x.oldest_ancestor = y

    def _process_vertex(self, v):
        """Close a component rooted at ``v`` or pass its ancestor upwards."""
        super()._process_vertex(v)
        if v.oldest_ancestor == v:
            self._pop_components(v)
            return
        if v.oldest_ancestor.entry_time < v.parent.oldest_ancestor.entry_time:
            v.parent.oldest_ancestor = v.oldest_ancestor

    def _pop_components(self, v):
        """Pop the stack down to and including ``v`` into a new component."""
        print("Component found for {}".format(v))
        self.scc_number += 1
        finished = False
        while not finished:
            member = self.stack.pop()
            member.strongly_connected_component = self.scc_number
            finished = member == v
class PrimMST(GraphTraversal):
    """Prim's minimum spanning tree using a heap with lazy deletion.

    Edges must carry a ``weight`` attribute.  After ``traverse`` each vertex
    has ``parent`` (its tree neighbour, ``None`` for the start vertex) and
    ``weight`` (the weight of the edge joining it to the tree), and
    ``total_weight`` holds the sum of those weights.
    """

    def __init__(self):
        super().__init__()
        self.heap = list()
        self.entries = dict()
        self.total_weight = 0

    def traverse(self, v):
        """Build the MST starting from ``v`` (CLRS-style priority queue).

        :param v: raw element or ``Vertex`` to start from.
        :raises KeyError: if ``v`` is not a vertex of the graph, or an edge
            has no ``weight`` attribute.

        Vertices that cannot be reached from ``v`` keep ``weight == inf``
        and are still popped, so ``total_weight`` becomes ``inf`` for a
        disconnected graph.
        """
        # Fail before touching any state if the start vertex is unknown.
        if v not in self.graph.adjacency_list:
            raise KeyError("Invalid source vertex, not in graph: {}".format(v))
        v = self.graph._make_vertex(v)

        # Start from a clean slate so repeated calls do not accumulate.
        self.heap = list()
        self.entries = dict()
        self.total_weight = 0

        for vertex in self.graph.adjacency_list:
            vertex.parent = None
            vertex.weight = float("inf")
            vertex.in_heap = True
            self._add_vertex(vertex)
        v.weight = 0
        self._add_vertex(v)

        while self.heap:
            try:
                min_vertex = self._pop_vertex()
            except KeyError:
                # Only invalidated entries were left: the queue is empty now.
                return
            min_vertex.in_heap = False
            self.total_weight += min_vertex.weight
            for neighbour, attrs in self.graph.adjacency_list[min_vertex].items():
                weight = attrs['weight']
                if neighbour.in_heap and weight < neighbour.weight:
                    neighbour.parent = min_vertex
                    neighbour.weight = weight
                    # Re-adding invalidates the old entry ("decrease key").
                    self._add_vertex(neighbour)

    def _add_vertex(self, vertex):
        """Push ``vertex`` with its current weight, replacing any old entry."""
        if vertex in self.entries:
            self._remove_entry(vertex)
        entry = [vertex.weight, vertex]
        self.entries[vertex] = entry
        heapq.heappush(self.heap, entry)

    def _remove_entry(self, vertex):
        """Invalidate the heap entry of ``vertex`` without restructuring."""
        entry = self.entries.pop(vertex)
        # The entry stays in the heap; ``None`` marks it as removed.
        entry[1] = None

    def _pop_vertex(self):
        """Pop and return the live vertex with the smallest weight.

        :raises KeyError: if the heap holds no live entries.
        """
        while self.heap:
            weight, vertex = heapq.heappop(self.heap)
            if vertex is not None:
                del self.entries[vertex]
                return vertex
        raise KeyError("pop from an empty priority queue")
class GraphTest(unittest.TestCase):
    """Unit tests for the graph algorithms in this module."""

    @staticmethod
    def _add_edges(graph, edges):
        """Add unweighted ``(u, v)`` edges to ``graph`` in the given order."""
        for u, v in edges:
            graph.add_edge_to_graph(u, v)

    def setUp(self):
        self.bfs = BFSTraversal()
        # Todo refactor tests to merely have graph here and not bfs
        self._add_edges(
            self.bfs.graph,
            [(1, 5), (1, 2), (2, 5), (2, 4), (2, 3), (3, 4), (4, 5)],
        )

    def test_shortest_path(self):
        self.bfs.traverse(1)
        expectations = [((1, 3), [3, 2]), ((2, 3), [3]), ((1, 4), [4, 5])]
        for (start, end), expected in expectations:
            self.assertEqual(self.bfs.shortest_path(start, end), expected)

    def test_is_biPartite(self):
        bipartite = BiPartiteDetectionTraversal()
        bipartite.graph = self.bfs.graph
        bipartite.traverse(1)
        self.assertFalse(bipartite.is_bipartite)

        bipartite = BiPartiteDetectionTraversal()
        # Same graph as in setUp but without the edges (2, 5) and (2, 3).
        self._add_edges(
            bipartite.graph,
            [(1, 5), (1, 2), (2, 4), (3, 4), (4, 5)],
        )
        bipartite.traverse(1)
        self.assertTrue(bipartite.is_bipartite)

    def test_connected_components(self):
        disconnected = ConnectedComponentsTraversal()
        disconnected.graph = self.bfs.graph
        self.assertEqual(len(disconnected.components()), 1)

    def test_disconnected_components(self):
        disconnected = ConnectedComponentsTraversal()
        disconnected.graph = self.bfs.graph
        disconnected.graph.add_edge_to_graph(7, 9)
        self.assertEqual(2, len(disconnected.components()))

    def test_dfs(self):
        dfs = DFSTraversal()
        dfs.graph = self.bfs.graph
        dfs.traverse(1)

    def test_articulation_vertex_traversal(self):
        av = ArticulationVertexTraversal()
        g = Graph()
        self._add_edges(g, [(1, 5), (1, 2), (2, 5), (2, 4), (2, 3), (4, 5)])
        av.graph = g
        av.traverse(1)
        self.assertTrue(g._make_vertex(2).articulation_vertex)

    def test_articulation_vertex_traversal_2(self):
        av = ArticulationVertexTraversal()
        g = Graph()
        g.add_edge_to_graph(1, 2)
        av.graph = g
        av.traverse(1)
        self.assertTrue(g._make_vertex(1).articulation_vertex)

    def test_cycle(self):
        c = CycleDetectionTraversal()
        c.graph = self.bfs.graph
        c.traverse(1)
        self.assertTrue(c.cycle_exists)

    def test_topological_sort(self):
        t = TopologicalSortTraversal()
        self._add_edges(
            t.graph,
            [(1, 2), (5, 1), (5, 0), (2, 3), (4, 0), (4, 3)],
        )
        topological_traversal = t.disconnected_traversal()
        for before, after in [(5, 0), (4, 0), (4, 3), (2, 3)]:
            self.assertTrue(
                topological_traversal.index(before) < topological_traversal.index(after)
            )

    def test_kahn_topological_sort(self):
        k = KahnTopologicalSort()
        self._add_edges(
            k.graph,
            [(1, 2), (5, 1), (5, 0), (2, 3), (4, 0), (4, 3)],
        )
        topological_traversal = k.topological_sort()
        for before, after in [(5, 0), (4, 0), (4, 3), (2, 3)]:
            self.assertTrue(
                topological_traversal.index(before) < topological_traversal.index(after)
            )

    def test_kahn_topological_sort_on_non_dag(self):
        k = KahnTopologicalSort()
        self._add_edges(k.graph, [(1, 2), (2, 3), (3, 1)])
        topological_traversal = k.topological_sort()
        self.assertFalse(topological_traversal)

    def test_dfs_topological_sort_on_non_dag(self):
        t = TopologicalSortTraversal()
        self._add_edges(t.graph, [(1, 2), (2, 3), (3, 1)])
        self.assertRaises(AssertionError, t.disconnected_traversal)

    def test_strongly_connected_components(self):
        s = StronglyConnectedComponentsTraversal()
        self._add_edges(
            s.graph,
            [(1, 2), (2, 3), (3, 4), (4, 1), (1, 5), (5, 4)],
        )
        s.strongly_connected_components()
        self.assertEqual(1, s.scc_number)

        s = StronglyConnectedComponentsTraversal()
        self._add_edges(
            s.graph,
            [
                (1, 2), (2, 3), (2, 4), (3, 1), (4, 1), (2, 5),
                (4, 6), (4, 8), (8, 6), (6, 7), (7, 5), (5, 6),
            ],
        )
        s.strongly_connected_components()
        self.assertEqual(3, s.scc_number)

    def test_prim_mst(self):
        p = PrimMST()
        g = p.graph
        weighted_edges = [
            (1, 2, 5), (2, 3, 7), (3, 4, 5), (4, 5, 2), (5, 3, 2), (3, 7, 4),
            (7, 5, 3), (5, 6, 7), (6, 7, 4), (6, 1, 12), (1, 7, 7), (2, 7, 9),
        ]
        for u, v, w in weighted_edges:
            g.add_edge_to_graph(u, v, weight=w)
        p.traverse(1)
        self.assertEqual(23, p.total_weight)
# Run the unit tests when this file is executed as a script.
if '__main__' == __name__:
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment