Skip to content

Instantly share code, notes, and snippets.

@TWiStErRob
Last active December 31, 2015 01:36
Show Gist options
Save TWiStErRob/b1a97e992b2eabbaffa1 to your computer and use it in GitHub Desktop.
Programmatic version of Tushar Roy's Held-Karp TSP video (https://www.youtube.com/watch?v=-JjA4BLQyqE)
package problems.tushar;
import java.util.*;
import java.util.function.Predicate;
import java.util.stream.*;
import static problems.tushar.TravelingSalesmanHeldKarpX.Algo.*;
import static problems.tushar.TravelingSalesmanHeldKarpX.Utils.*;
/**
 * Held-Karp method of finding the minimal tour of the traveling salesman, with verbose logging of every step.
 *
 * Time complexity - O(2^n * n^2), plus a log factor while the memo tables are {@link TreeMap}s
 * Space complexity - O(2^n * n): one memoized cost per (target vertex, subset of the other vertices) pair
 *
 * https://en.wikipedia.org/wiki/Held%E2%80%93Karp_algorithm
 * https://github.com/mission-peace/interview/blob/master/src/com/interview/graph/TravelingSalesmanHeldKarp.java
 *
 * @author Tushar Roy (11/17/2015) original
 * @author Róbert (TWiStErRob) Papp (12/28/2015) generic Graph
 * @author Róbert (TWiStErRob) Papp (12/31/2015) verbose logging
 */
public class TravelingSalesmanHeldKarpX {
    /**
     * A graph whose vertices are enumerated by {@link #iterator()}; the solver queries the distance of every
     * ordered vertex pair, so missing edges should be modelled with {@link Algo#INF}.
     *
     * @param <T> vertex type
     */
    public interface Graph<T> extends Iterable<T> {
        /** Cost of the directed edge {@code from->to}. */
        double distance(T from, T to);

        /** The vertex where the tour starts and ends; expected to be one of the iterated vertices. */
        T startVertex();

        /** Sequential stream over the vertices, in {@link #iterator()} order. */
        default Stream<T> stream() {
            return StreamSupport.stream(spliterator(), false);
        }
    }

    /**
     * Held-Karp dynamic programming solver that prints every step of the computation to {@link System#out}.
     *
     * @param <T> vertex type; its natural ordering keeps every vertex subset sorted, which {@link Index} relies on
     */
    public static class Algo<T extends Comparable<T>> {
        /** Distance of a missing edge, and the cost of an unreachable state. */
        public static final double INF = Double.POSITIVE_INFINITY;
        /** Orders vertex subsets by size, then element-by-element; used to sort subsets and to order {@link Index}. */
        private final Comparator<Collection<T>> viaOrder = new LexiSizeComparator<>();
        private Graph<T> graph;
        /** Minimal cost of start->{via}->target, keyed by [target, via]. */
        private Map<Index, Double> minCosts;
        /** Vertex preceding target on the cheapest start->{via}->target path; absent when that cost is infinite. */
        private Map<Index, T> parents;
        private T start;

        /**
         * Solves the TSP for {@code graph}, printing the subsets, every intermediate cost and the optimal tour.
         * Afterwards {@link #getTour()} can be called to reconstruct the tour again.
         *
         * @param graph the graph to solve; the tour starts and ends at {@link Graph#startVertex()}
         */
        public void run(Graph<T> graph) {
            this.graph = graph;
            // TreeMap O(logn) is more convenient for debugging, use HashMap O(1) to reduce time complexity
            this.minCosts = new TreeMap<>();
            this.parents = new TreeMap<>();
            this.start = graph.startVertex();
            // TreeSet, not HashSet: Index.compareTo compares the via sets in iteration order
            TreeSet<T> sortedOthers = graph
                    .stream()
                    .filter(Predicate.isEqual(start).negate())
                    .collect(Collectors.toCollection(TreeSet::new));
            Set<T> otherVertices = Collections.unmodifiableSet(sortedOthers);
            List<Set<T>> allSets = Utils.generateSubSets(otherVertices);
            // new SizeComparator() is good enough for the algorithm (only sizes must ascend),
            // the lexicographic order is better for debugging, but really slow
            allSets.sort(viaOrder);
            System.out.println(
                    allSets.stream().map(Utils::setToString).collect(Collectors.joining(", ", "Subsets: ", "")));
            for (Set<T> set : allSets) {
                System.out.printf("Checking subset: %s%n", Utils.setToString(set));
                for (T current : otherVertices) {
                    if (set.contains(current)) {
                        System.out.printf("%s = (%s) = skip, because we already visited %s if we came through %s%n",
                                indexToString(current, set), Utils.pathToString(start, set, current), current, set);
                        continue;
                    }
                    findMin(current, set);
                }
                System.out.println();
            }
            System.out.println("Gathering solution:");
            Index solution = new Index(start, otherVertices);
            findMin(solution);
            System.out.println("All intermediate costs:");
            minCosts.forEach((i, c) -> System.out.printf("\t%s = %.0f (%s)%n", i, c, indexPathToString(start, i)));
            List<T> tour = getTour();
            if (tour.isEmpty()) {
                System.out.printf("There is no finite-cost TSP tour visiting every vertex%n");
                return;
            }
            String tourString = tour
                    .stream()
                    .map(Object::toString)
                    .collect(Collectors.joining("->"));
            System.out.printf("Min cost of TSP tour %s is %.0f%n", tourString, minCosts.get(solution));
        }

        private void findMin(T current, Set<T> set) {
            findMin(new Index(current, set));
        }

        /**
         * Computes and memoizes the cheapest start->{index.via}->index.target path.
         * Every [prevVertex, index.via - prevVertex] entry must already be memoized, which holds because
         * subsets are processed in ascending size order.
         */
        private void findMin(Index index) {
            System.out.printf("%s = (%s) = min { // starting from %s to reach %s via %s%n",
                    index, indexPathToString(start, index), start, index.target, Utils.setToString(index.via));
            double minCost;
            T minPrevVertex;
            if (index.via.isEmpty()) {
                minCost = graph.distance(start, index.target);
                minPrevVertex = start;
                System.out.printf("\tdist(%s,%s) = %.0f%n", start, index.target, minCost);
            } else {
                minCost = INF;
                minPrevVertex = null; // stays null when every candidate is unreachable
                // Copy so getCost can mutate it while we iterate index.via; sorted so it works as an Index lookup key.
                Set<T> prevSet = new TreeSet<>(index.via);
                for (T prevVertex : index.via) {
                    // try to finish the path to index.target with prevVertex
                    double cost = getCost(index, prevSet, prevVertex);
                    if (cost < minCost) {
                        minCost = cost;
                        minPrevVertex = prevVertex;
                    }
                }
            }
            if (minCost < INF) {
                System.out.printf("} = %.0f via %s // reaching %s from %s | finishing with %s->%s%n",
                        minCost, minPrevVertex, index.target, minPrevVertex, minPrevVertex, index.target);
                minCosts.put(index, minCost);
                parents.put(index, minPrevVertex);
            } else {
                System.out.printf("} = no path%n");
                // No parent is recorded, so getTour cannot follow a non-existent path (it used to loop forever).
                minCosts.put(index, minCost);
            }
        }

        /**
         * Cost of start->{index.via}->prevVertex->index.target: the memoized cost of reaching prevVertex through the
         * rest of index.via, plus the final edge. {@code prevSet} (a sorted copy of index.via) temporarily loses
         * prevVertex and is restored before returning.
         *
         * @throws IllegalStateException if the needed sub-problem has not been memoized yet
         */
        private double getCost(Index index, Set<T> prevSet, T prevVertex) {
            try {
                prevSet.remove(prevVertex);
                Double memo = minCosts.get(new Index(prevVertex, prevSet)); // start->{index.via}->prevVertex
                if (memo == null) {
                    throw new IllegalStateException(
                            "Sub-problem " + indexToString(prevVertex, prevSet) + " has not been computed yet");
                }
                double c = memo;
                double d = graph.distance(prevVertex, index.target); // prevVertex->index.target
                System.out.printf("\tcost(%s) + cost(%s->%s) = cost(%s) + dist(%s,%s) = %.0f + %.0f = %.0f%n",
                        Utils.pathToString(start, prevSet, prevVertex), prevVertex, index.target,
                        indexToString(prevVertex, prevSet), prevVertex, index.target,
                        c, d, c + d);
                return c + d; // 0->{index.via}->prevVertex + prevVertex->index.target
            } finally {
                prevSet.add(prevVertex);
            }
        }

        /**
         * Reconstructs the optimal tour of the last {@link #run(Graph)} by following the recorded parents backwards
         * from the start vertex, logging each step.
         *
         * @return the tour in travel order, beginning and ending with the start vertex;
         *         empty if no finite-cost tour exists
         */
        public List<T> getTour() {
            T startVertex = graph.startVertex();
            // sorted, because Index.compareTo compares the via sets in iteration order
            Set<T> set = new TreeSet<>();
            for (T vertex : graph) {
                set.add(vertex);
            }
            T vertex = startVertex;
            System.out.printf("%nFinishing the tour at %s.%n", vertex);
            Deque<T> stack = new ArrayDeque<>();
            while (vertex != null) {
                stack.push(vertex);
                System.out.printf("How did we reach %s?%n", vertex);
                set.remove(vertex);
                // Arriving at the start vertex again (after the first step) means the whole tour was walked.
                // Checked explicitly: a single-vertex graph memoizes [start,∅] -> start, which would loop forever.
                boolean backAtStart = stack.size() > 1 && vertex.equals(startVertex);
                Index parentKey = new Index(vertex, set);
                T parent = backAtStart ? null : parents.get(parentKey);
                if (parent != null) {
                    System.out.printf("\tWe came through vertices: %s%n", Utils.setToString(set));
                    System.out.printf("\tWe finished at %s: step %s->%s costs %.0f.%n",
                            parent, parent, vertex, graph.distance(parent, vertex));
                    System.out.printf("\tMinimal cost for all paths through these to reach %s: %.0f.%n",
                            vertex, minCosts.get(parentKey));
                } else if (backAtStart) {
                    System.out.printf("\tWe started there.%n");
                } else {
                    System.out.printf("\tThere is no finite-cost path to %s through %s.%n",
                            vertex, Utils.setToString(set));
                }
                vertex = parent;
            }
            List<T> result = new ArrayList<>(stack);
            // A complete walk pushes the start vertex last, so the forward-ordered list must begin with it.
            if (result.size() < 2 || !result.get(0).equals(startVertex)) {
                return Collections.emptyList();
            }
            System.out.printf("Read the steps backwards:%n");
            double sum = 0;
            for (int i = 1; i < result.size(); i++) {
                T prev = result.get(i - 1);
                T curr = result.get(i);
                double d = graph.distance(prev, curr);
                sum += d;
                System.out.printf("\t%s->%s (costs %.0f), totaling %.0f%n", prev, curr, d, sum);
            }
            return result;
        }

        /**
         * Memo key [target, via]: paths that leave the start vertex, visit every vertex of {@code via} in some order
         * and end at {@code target}. Ordering compares {@code via} in iteration order, so every {@code via} set used
         * as a key or lookup must iterate in sorted order (e.g. a {@link TreeSet}).
         */
        private class Index implements Comparable<Index> {
            final T target;
            final Set<T> via;

            Index(T vertex, Set<T> via) {
                this.target = vertex;
                this.via = via;
            }

            @Override public int compareTo(Index o) {
                int viaComparison = viaOrder.compare(via, o.via);
                if (viaComparison != 0) {
                    return viaComparison;
                }
                return this.target.compareTo(o.target);
            }

            @Override public boolean equals(Object o) {
                if (this == o) {
                    return true;
                }
                if (o == null || getClass() != o.getClass()) {
                    return false;
                }
                @SuppressWarnings("unchecked") Index index = (Index)o;
                // value equality: boxed vertices (e.g. Integer > 127) are not guaranteed to be the same instance
                return Objects.equals(target, index.target) && Objects.equals(via, index.via);
            }

            @Override public int hashCode() {
                return 31 * Objects.hashCode(target) + Objects.hashCode(via);
            }

            @Override public String toString() {
                return indexToString(target, via);
            }
        }

        /** Formats a memo key as {@code [vertex,{set}]}. */
        private String indexToString(T vertex, Set<T> set) {
            return String.format("[%s,%s]", vertex, Utils.setToString(set));
        }

        /** Formats the path described by {@code index} as {@code from->{via}->target}. */
        private String indexPathToString(T from, Index index) {
            return Utils.pathToString(from, index.via, index.target);
        }
    }

    /** Runs the solver on two sample graphs, printing the full trace of each. */
    public static void main(String... args) {
        // From video, solution: 0->1->3->2->0 costing 21
        new Algo<Character>().run(new CharacterGraph(new double[][] {
                {0, 1, 15, 6},
                {2, 0, 7, 3},
                {9, 6, 0, 12},
                {10, 4, 8, 0},
        }, '0'));
        // From comments, solution: A->C->E->D->B->A (or its reverse, the matrix is symmetric) costing 21
        new Algo<Character>().run(new CharacterGraph(new double[][] {
                {0, 3, 3, 1, Algo.INF},
                {3, 0, 8, 5, Algo.INF},
                {3, 8, 0, 1, 6},
                {1, 5, 1, 0, 4},
                {Algo.INF, Algo.INF, 6, 4, 0},
        }, 'A'));
    }

    /**
     * Graph over the consecutive characters {@code base, base+1, ...}; row/column {@code i} of the distance matrix
     * belongs to vertex {@code base + i}, and {@code base} is the start vertex.
     */
    private static class CharacterGraph implements Graph<Character> {
        private final double[][] distance;
        private final char base;

        /** @throws IllegalArgumentException if {@code distance} is not a square matrix */
        CharacterGraph(double[][] distance, char base) {
            for (double[] row : distance) {
                if (row.length != distance.length) {
                    throw new IllegalArgumentException(String.format(
                            "Distance matrix must be square: %d rows, but found a row with %d columns",
                            distance.length, row.length));
                }
            }
            this.distance = distance;
            this.base = base;
        }

        @Override public Character startVertex() {
            return base;
        }

        @Override public double distance(Character from, Character to) {
            return distance[getIndex(from)][getIndex(to)];
        }

        @Override public Iterator<Character> iterator() {
            return IntStream.range(0, getVertexCount()).mapToObj(this::getVertex).iterator();
        }

        private int getVertexCount() {
            return distance.length;
        }

        private char getVertex(int index) {
            return (char)(base + index);
        }

        private int getIndex(Character vertex) {
            return vertex - base;
        }
    }

    /** Static helpers for subset generation and human-readable formatting. */
    static class Utils {
        private Utils() {
            // no instances
        }

        /**
         * Generates every subset of {@code items} except {@code items} itself (the empty set included when
         * {@code items} is non-empty). Each subset is a {@link TreeSet}, so it iterates in natural order.
         */
        @SuppressWarnings({"unchecked", "rawtypes"})
        static <T extends Comparable<T>> List<Set<T>> generateSubSets(Collection<T> items) {
            // T is erased to Comparable, so a Comparable[] is a valid T[] at runtime
            T[] input = items.toArray((T[])new Comparable[items.size()]);
            List<Set<T>> subSets = new ArrayList<>();
            T[] result = (T[])new Comparable[input.length];
            generateSubSets(input, 0, 0, subSets, result);
            return subSets;
        }

        /**
         * Recursive helper: records the subset {@code result[0..pos)} and then extends it with each
         * {@code input[i]}, {@code i >= start}, so every combination is produced exactly once.
         */
        static <T> void generateSubSets(T[] input, int start, int pos, List<Set<T>> allSets, T[] result) {
            if (pos == input.length) {
                return; // don't include the set: {input}
            }
            Set<T> set = new TreeSet<>(Arrays.asList(result).subList(0, pos));
            allSets.add(set);
            for (int i = start; i < input.length; i++) {
                result[pos] = input[i];
                generateSubSets(input, i + 1, pos + 1, allSets, result);
            }
        }

        /** Formats a set as {@code ∅} or {@code {a,b,c}}, in the set's iteration order. */
        static <T> String setToString(Set<T> set) {
            StringBuilder builder = new StringBuilder();
            if (set.isEmpty()) {
                builder.append('∅');
            } else {
                builder.append('{');
                for (T t : set) {
                    builder.append(t).append(',');
                }
                builder.setCharAt(builder.length() - 1, '}'); // replace the trailing comma
            }
            return builder.toString();
        }

        /** Formats a path as {@code from->{via}->to}. */
        static <T> String pathToString(T from, Set<T> via, T to) {
            return String.format("%s->%s->%s", from, setToString(via), to);
        }
    }

    /** Orders collections by size only. */
    private static class SizeComparator implements Comparator<Collection<?>> {
        @Override public int compare(Collection<?> o1, Collection<?> o2) {
            return Integer.compare(o1.size(), o2.size());
        }
    }

    /** Orders collections by size, then element-by-element in iteration order. */
    private static class LexiSizeComparator<T extends Comparable<T>> implements Comparator<Collection<T>> {
        private final Comparator<Iterable<T>> lexiComparator = new LexiComparator<T>();

        @Override public int compare(Collection<T> o1, Collection<T> o2) {
            int sizeOrder = Integer.compare(o1.size(), o2.size());
            if (sizeOrder != 0) {
                return sizeOrder;
            }
            return lexiComparator.compare(o1, o2);
        }
    }

    /**
     * Lexicographic order of iterables: elements are compared pairwise in iteration order (nulls last),
     * and when one is a prefix of the other the shorter one comes first.
     */
    private static class LexiComparator<T extends Comparable<T>> implements Comparator<Iterable<T>> {
        private final Comparator<T> nullSafe = Comparator.nullsLast(Comparator.naturalOrder());

        @Override public int compare(Iterable<T> o1, Iterable<T> o2) {
            Iterator<T> it1 = o1.iterator();
            Iterator<T> it2 = o2.iterator();
            while (it1.hasNext() && it2.hasNext()) {
                int result = nullSafe.compare(it1.next(), it2.next());
                if (result != 0) {
                    return result;
                }
            }
            // items in both iterators were equal up to this point,
            // but there could be more in one of them:
            if (it1.hasNext()) {
                return +1; // it1 has more elements -> o1 > o2
            }
            if (it2.hasNext()) {
                return -1; // it2 has more elements -> o1 < o2
            }
            return 0; // o1 and o2 have the same length and same elements
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment