Last active: December 31, 2015 01:36
-
-
Save TWiStErRob/b1a97e992b2eabbaffa1 to your computer and use it in GitHub Desktop.
Programmatic version of Tushar Roy's Held-Karp TSP video (https://www.youtube.com/watch?v=-JjA4BLQyqE)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package problems.tushar; | |
import java.util.*; | |
import java.util.function.Predicate; | |
import java.util.stream.*; | |
import static problems.tushar.TravelingSalesmanHeldKarpX.Algo.*; | |
import static problems.tushar.TravelingSalesmanHeldKarpX.Utils.*; | |
/**
 * @author Tushar Roy (11/17/2015) original
 * @author Róbert (TWiStErRob) Papp (12/28/2015) generic Graph
 * @author Róbert (TWiStErRob) Papp (12/31/2015) verbose logging
 *
 * Held-Karp dynamic-programming method of finding the optimal tour of a traveling salesman.
 *
 * Time complexity - O(2^n * n^2)
 * Space complexity - O(2^n * n): one memoized cost and parent per (subset, target vertex) pair.
 *
 * https://en.wikipedia.org/wiki/Held%E2%80%93Karp_algorithm
 * https://github.com/mission-peace/interview/blob/master/src/com/interview/graph/TravelingSalesmanHeldKarp.java
 */
public class TravelingSalesmanHeldKarpX {
    /**
     * A directed graph to find a tour in.
     * Iterating the graph must yield every vertex once, including {@link #startVertex()}.
     *
     * @param <T> vertex type
     */
    public interface Graph<T> extends Iterable<T> {
        /** Cost of the edge {@code from->to}; {@link Algo#INF} is used as "no edge" (see {@code main}). */
        double distance(T from, T to);

        /** The vertex the tour starts from and returns to. */
        T startVertex();

        /** Sequential stream over the vertices of this graph. */
        default Stream<T> stream() {
            return StreamSupport.stream(spliterator(), false);
        }
    }

    /**
     * Runs Held-Karp on a {@link Graph}, logging every step to {@code System.out}.
     * The state of the last {@link #run} is kept in plain fields, so an instance must not be shared
     * between concurrent runs.
     *
     * @param <T> vertex type; its natural order should be consistent with {@code equals}
     */
    public static class Algo<T extends Comparable<T>> {
        public static final double INF = Double.POSITIVE_INFINITY;
        /** Orders {@link Index#via} sets by size, then lexicographically; one instance reused by all comparisons. */
        private final Comparator<Collection<T>> viaOrder = new LexiSizeComparator<>();
        private Graph<T> graph;
        /** [target, via] -> cheapest cost of start->{every vertex of via, in any order}->target. */
        private Map<Index, Double> minCosts;
        /** [target, via] -> vertex visited right before target on that cheapest path (start if via is empty). */
        private Map<Index, T> parents;
        private T start;

        /**
         * Computes the optimal tour of {@code graph} and prints every intermediate cost and the tour.
         *
         * @param graph the graph to tour, starting and ending at {@link Graph#startVertex()}
         */
        public void run(Graph<T> graph) {
            this.graph = graph;
            // TreeMap O(logn) is more convenient for debugging, use HashMap O(1) to reduce time complexity
            this.minCosts = new TreeMap<>();
            this.parents = new TreeMap<>();
            this.start = graph.startVertex();
            // Naturally ordered (like the TreeSet subsets), so Index keys built from it compare consistently.
            TreeSet<T> others = graph
                    .stream()
                    .filter(Predicate.isEqual(start).negate())
                    .collect(Collectors.toCollection(TreeSet::new));
            SortedSet<T> otherVertices = Collections.unmodifiableSortedSet(others);
            List<Set<T>> allSets = Utils.generateSubSets(otherVertices);
            //Collections.sort(allSets, new SizeComparator()); // good enough for algorithm
            Collections.sort(allSets, new LexiSizeComparator<>()); // better for debugging, but really slow
            System.out.println(
                    allSets.stream().map(Utils::setToString).collect(Collectors.joining(", ", "Subsets: ", "")));
            for (Set<T> set : allSets) {
                System.out.printf("Checking subset: %s%n", Utils.setToString(set));
                for (T current : otherVertices) {
                    if (set.contains(current)) {
                        System.out.printf("%s = (%s) = skip, because we already visited %s if we came through %s%n",
                                indexToString(current, set), Utils.pathToString(start, set, current), current, set);
                        continue;
                    }
                    findMin(current, set);
                }
                System.out.println();
            }
            System.out.println("Gathering solution:");
            Index solution = new Index(start, otherVertices);
            findMin(solution);
            System.out.println("All intermediate costs:");
            minCosts.forEach((i, c) -> System.out.printf("\t%s = %.0f (%s)%n", i, c, indexPathToString(start, i)));
            List<T> tour = getTour();
            if (tour.isEmpty()) {
                System.out.printf("No TSP tour with finite cost exists from %s%n", start);
                return;
            }
            String tourString = tour
                    .stream()
                    .map(Object::toString)
                    .collect(Collectors.joining("->"));
            System.out.printf("Min cost of TSP tour %s is %.0f%n", tourString, minCosts.get(solution));
        }

        /** Computes and memoizes the cheapest path start->{set}->current. */
        private void findMin(T current, Set<T> set) {
            findMin(new Index(current, set));
        }

        /**
         * Computes the cheapest path start->{index.via in any order}->index.target and stores its cost in
         * {@link #minCosts} and the vertex preceding index.target in {@link #parents}.
         * Every [prevVertex, via - prevVertex] entry it reads must already be computed, which holds when
         * subsets are processed in increasing size. An unreachable target gets cost {@link #INF} and parent start.
         */
        private void findMin(Index index) {
            System.out.printf("%s = (%s) = min { // starting from %s to reach %s via %s%n",
                    index, indexPathToString(start, index), start, index.target, Utils.setToString(index.via));
            double minCost;
            T minPrevVertex = start;
            if (index.via.isEmpty()) {
                minCost = graph.distance(start, index.target);
                System.out.printf("\tdist(%s,%s) = %.0f%n", start, index.target, minCost);
            } else {
                minCost = INF;
                // Copy so that removing from it doesn't disturb iterating index.via; reused by every iteration.
                // A TreeSet keeps the natural iteration order that Index keys are compared by.
                Set<T> prevSet = new TreeSet<>(index.via);
                for (T prevVertex : index.via) {
                    // try to finish the path to index.target with prevVertex
                    double cost = getCost(index, prevSet, prevVertex);
                    if (cost < minCost) {
                        minCost = cost;
                        minPrevVertex = prevVertex;
                    }
                }
            }
            if (minCost != INF) {
                System.out.printf("} = %.0f via %s // reaching %s from %s | finishing with %s->%s%n",
                        minCost, minPrevVertex, index.target, minPrevVertex, minPrevVertex, index.target);
            } else {
                System.out.printf("} = no path%n");
            }
            minCosts.put(index, minCost);
            parents.put(index, minPrevVertex);
        }

        /**
         * Cost of reaching index.target through index.via when prevVertex is the last vertex before it:
         * cost(start->{via - prevVertex}->prevVertex) + dist(prevVertex, index.target).
         *
         * @param prevSet mutable copy of index.via; prevVertex is removed for the lookup and restored afterwards
         * @throws IllegalStateException if the needed smaller subproblem has not been computed yet
         */
        private double getCost(Index index, Set<T> prevSet, T prevVertex) {
            try {
                prevSet.remove(prevVertex);
                Index prevIndex = new Index(prevVertex, prevSet);
                Double memo = minCosts.get(prevIndex);
                if (memo == null) {
                    throw new IllegalStateException("Cost of " + prevIndex + " is needed before it was computed");
                }
                double c = memo; // start->{prevSet}->prevVertex
                double d = graph.distance(prevVertex, index.target); // prevVertex->index.target
                System.out.printf("\tcost(%s) + cost(%s->%s) = cost(%s) + dist(%s,%s) = %.0f + %.0f = %.0f%n",
                        Utils.pathToString(start, prevSet, prevVertex), prevVertex, index.target,
                        indexToString(prevVertex, prevSet), prevVertex, index.target,
                        c, d, c + d);
                return c + d; // start->{prevSet}->prevVertex + prevVertex->index.target
            } finally {
                prevSet.add(prevVertex);
            }
        }

        /**
         * Reconstructs the optimal tour of the last {@link #run} from the memoized parents, logging each step.
         *
         * @return the tour in travel order, starting and ending at the start vertex
         *         (a single-vertex graph yields {@code [start, start]}),
         *         or an empty list if no tour with finite cost exists
         * @throws IllegalStateException if {@link #run} has not been called yet
         */
        public List<T> getTour() {
            if (graph == null) {
                throw new IllegalStateException("run(graph) must be called before getTour()");
            }
            Set<T> unvisited = new TreeSet<>();
            for (T vertex : graph) {
                unvisited.add(vertex);
            }
            unvisited.remove(start);
            Double tourCost = minCosts.get(new Index(start, unvisited));
            if (tourCost == null || tourCost == INF) {
                // Infinite entries keep start as parent, so walking back would yield a truncated, bogus tour.
                System.out.printf("%nCannot finish the tour at %s: no tour with finite cost exists.%n", start);
                return Collections.emptyList();
            }
            T vertex = start;
            System.out.printf("%nFinishing the tour at %s.%n", vertex);
            Deque<T> stack = new ArrayDeque<>();
            boolean atTourEnd = true;
            while (vertex != null) {
                stack.push(vertex);
                System.out.printf("How did we reach %s?%n", vertex);
                // The start vertex is both the last and the first vertex of the tour: meeting it again means
                // the walk back is complete. Checking this (rather than waiting for a missing parent) also stops
                // a single-vertex graph, whose [start, empty set] entry points back to start, from looping forever.
                if (!atTourEnd && vertex.equals(start)) {
                    System.out.printf("\tWe started there.%n");
                    break;
                }
                atTourEnd = false;
                unvisited.remove(vertex);
                Index parentKey = new Index(vertex, unvisited);
                T parent = parents.get(parentKey);
                if (parent != null) {
                    System.out.printf("\tWe came through vertices: %s%n", Utils.setToString(unvisited));
                    System.out.printf("\tWe finished at %s: step %s->%s costs %.0f.%n",
                            parent, parent, vertex, graph.distance(parent, vertex));
                    System.out.printf("\tMinimal cost for all paths through these to reach %s: %.0f.%n",
                            vertex, minCosts.get(parentKey));
                } else {
                    System.out.printf("\tWe started there.%n");
                }
                vertex = parent;
            }
            // The stack's head is the last vertex pushed, i.e. the beginning of the tour.
            List<T> result = new ArrayList<>(stack);
            System.out.printf("Read the steps backwards:%n");
            double sum = 0;
            for (int i = 1; i < result.size(); i++) {
                T prev = result.get(i - 1);
                T curr = result.get(i);
                double d = graph.distance(prev, curr);
                sum += d;
                System.out.printf("\t%s->%s (costs %.0f), totaling %.0f%n", prev, curr, d, sum);
            }
            return result;
        }

        /** Memoization key: reaching {@link #target} from the start vertex through every vertex of {@link #via}. */
        private class Index implements Comparable<Index> {
            final T target;
            /** Iterated in natural order, which {@link #compareTo} relies on; may alias the caller's set. */
            final Set<T> via;

            Index(T vertex, Set<T> via) {
                this.target = vertex;
                // compareTo walks both sets in iteration order, so every key must iterate the same way:
                // naturally ordered sets (all internal callers) are kept, anything else is copied into a TreeSet.
                if (via == null || via instanceof SortedSet) {
                    this.via = via;
                } else {
                    this.via = new TreeSet<>(via);
                }
            }

            /** Orders by via (size, then lexicographically), then by target. */
            @Override public int compareTo(Index o) {
                int byVia = viaOrder.compare(via, o.via);
                return byVia != 0 ? byVia : target.compareTo(o.target);
            }

            @Override public boolean equals(Object o) {
                if (this == o) {
                    return true;
                }
                if (o == null || getClass() != o.getClass()) {
                    return false;
                }
                @SuppressWarnings("unchecked") Index index = (Index)o;
                // equals(), not ==: distinct boxed/vertex objects can still be equal
                return Objects.equals(target, index.target) && Objects.equals(via, index.via);
            }

            @Override public int hashCode() {
                return 31 * Objects.hashCode(target) + Objects.hashCode(via);
            }

            @Override public String toString() {
                return indexToString(target, via);
            }
        }

        /** Formats a memo key as {@code [vertex,set]}. */
        private String indexToString(T vertex, Set<T> set) {
            return String.format("[%s,%s]", vertex, Utils.setToString(set));
        }

        /** Formats the path a memo key stands for, as {@code from->{via}->target}. */
        private String indexPathToString(T from, Index index) {
            return Utils.pathToString(from, index.via, index.target);
        }
    }

    /** Solves the two sample graphs; the expected results are noted above each. */
    public static void main(String... args) {
        // From video, solution: 0->1->3->2->0 costing 21
        new Algo<Character>().run(new CharacterGraph(new double[][] {
                {0, 1, 15, 6},
                {2, 0, 7, 3},
                {9, 6, 0, 12},
                {10, 4, 8, 0},
        }, '0'));
        // From comments, solution: A->C->E->D->B->A costing 21
        new Algo<Character>().run(new CharacterGraph(new double[][] {
                {0, 3, 3, 1, Algo.INF},
                {3, 0, 8, 5, Algo.INF},
                {3, 8, 0, 1, 6},
                {1, 5, 1, 0, 4},
                {Algo.INF, Algo.INF, 6, 4, 0},
        }, 'A'));
    }

    /**
     * Graph over the characters {@code base, base+1, ..., base+n-1}, where {@code distance[i][j]} is the cost
     * of going from {@code base+i} to {@code base+j}. The tour starts at {@code base}.
     */
    private static class CharacterGraph implements Graph<Character> {
        private final double[][] distance;
        private final char base;

        CharacterGraph(double[][] distance, char base) {
            this.distance = distance;
            this.base = base;
        }

        @Override public Character startVertex() {
            return base;
        }

        @Override public double distance(Character from, Character to) {
            return distance[getIndex(from)][getIndex(to)];
        }

        @Override public Iterator<Character> iterator() {
            return IntStream.range(0, getVertexCount()).mapToObj(this::getVertex).iterator();
        }

        private int getVertexCount() {
            return distance.length;
        }

        private char getVertex(int index) {
            return (char)(base + index);
        }

        private int getIndex(Character vertex) {
            return vertex - base;
        }
    }

    static class Utils {
        /**
         * Lists every subset of {@code items} except {@code items} itself, each as a naturally ordered TreeSet.
         * An empty input therefore yields an empty list.
         */
        @SuppressWarnings({"unchecked", "rawtypes"})
        static <T extends Comparable<T>> List<Set<T>> generateSubSets(Collection<T> items) {
            T[] input = items.toArray((T[])new Comparable[items.size()]);
            List<Set<T>> subSets = new ArrayList<>();
            T[] result = (T[])new Comparable[input.length];
            generateSubSets(input, 0, 0, subSets, result);
            return subSets;
        }

        /**
         * Recursive helper: records {@code result[0..pos)} as a subset, then extends it with each
         * {@code input[i]} for {@code i >= start}.
         */
        static <T> void generateSubSets(T[] input, int start, int pos, List<Set<T>> allSets, T[] result) {
            if (pos == input.length) {
                return; // don't include the set: {input}
            }
            Set<T> set = new TreeSet<>(Arrays.asList(result).subList(0, pos));
            allSets.add(set);
            for (int i = start; i < input.length; i++) {
                result[pos] = input[i];
                generateSubSets(input, i + 1, pos + 1, allSets, result);
            }
        }

        /** Formats a set as {@code {a,b,c}}, or as the empty-set sign when empty. */
        static <T> String setToString(Set<T> set) {
            StringBuilder builder = new StringBuilder();
            if (set.isEmpty()) {
                builder.append('\u2205'); // the empty-set sign
            } else {
                builder.append('{');
                for (T t : set) {
                    builder.append(t).append(',');
                }
                builder.setCharAt(builder.length() - 1, '}');
            }
            return builder.toString();
        }

        /** Formats a path as {@code from->{via}->to}. */
        static <T> String pathToString(T from, Set<T> via, T to) {
            return String.format("%s->%s->%s", from, setToString(via), to);
        }
    }

    /** Orders collections by size only; an alternative for sorting the subsets in {@code Algo.run}. */
    private static class SizeComparator implements Comparator<Collection<?>> {
        @Override public int compare(Collection<?> o1, Collection<?> o2) {
            return Integer.compare(o1.size(), o2.size());
        }
    }

    /** Orders collections by size, then lexicographically by their iteration order. */
    private static class LexiSizeComparator<T extends Comparable<T>> implements Comparator<Collection<T>> {
        private final Comparator<Iterable<T>> lexiComparator = new LexiComparator<T>();

        @Override public int compare(Collection<T> o1, Collection<T> o2) {
            int sizeOrder = Integer.compare(o1.size(), o2.size());
            if (sizeOrder != 0) {
                return sizeOrder;
            }
            return lexiComparator.compare(o1, o2);
        }
    }

    /** Lexicographic order of iterables by their iteration order; nulls sort last, a proper prefix sorts first. */
    private static class LexiComparator<T extends Comparable<T>> implements Comparator<Iterable<T>> {
        private final Comparator<T> nullSafe = Comparator.nullsLast(Comparator.naturalOrder());

        @Override public int compare(Iterable<T> o1, Iterable<T> o2) {
            Iterator<T> it1 = o1.iterator();
            Iterator<T> it2 = o2.iterator();
            while (it1.hasNext() && it2.hasNext()) {
                int result = nullSafe.compare(it1.next(), it2.next());
                if (result != 0) {
                    return result;
                }
            }
            // items in both iterators were equal up to this point,
            // but there could be more in one of them:
            if (it1.hasNext()) {
                return +1; // it1 has more elements -> o1 > o2
            }
            if (it2.hasNext()) {
                return -1; // it2 has more elements -> o1 < o2
            }
            return 0; // o1 and o2 have the same length and same elements
        }
    }
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment