Created
August 23, 2017 20:43
-
-
Save CalebFenton/a129333dabc1cc346b0874407f92b568 to your computer and use it in GitHub Desktop.
Markov Chain implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gnu.trove.map.TCharDoubleMap; | |
import gnu.trove.map.hash.TCharDoubleHashMap; | |
import java.io.Serializable; | |
import java.util.HashMap; | |
import java.util.Map; | |
import java.util.regex.Pattern; | |
public class MarkovChain implements Serializable { | |
private static final long serialVersionUID = 986958034001823764L; | |
private final Pattern splitPattern; | |
private final int nGramSize; | |
private final Map<String, TCharDoubleMap> chain; | |
private char[] acceptedChars; | |
private char[] filteredChars; | |
private double unknownProbability; | |
public MarkovChain() { | |
this("", 2); | |
} | |
public MarkovChain(String splitRegex, int nGramSize) { | |
if (splitRegex == null) { | |
splitPattern = null; | |
} else { | |
splitPattern = Pattern.compile(splitRegex); | |
} | |
this.nGramSize = nGramSize; | |
chain = new HashMap<>(); | |
acceptedChars = new char[0]; | |
filteredChars = new char[0]; | |
} | |
public MarkovChain(int nGramSize) { | |
this(null, nGramSize); | |
} | |
public void setAcceptedChars(char[] acceptedChars) { | |
this.acceptedChars = acceptedChars; | |
} | |
public void setFilteredChars(char[] filteredChars) { | |
this.filteredChars = filteredChars; | |
} | |
public void update(String input) { | |
if (splitPattern == null) { | |
updateToken(input); | |
} else { | |
String[] tokens = splitPattern.split(input); | |
update(tokens); | |
} | |
} | |
public void update(String[] tokens) { | |
for (String token : tokens) { | |
updateToken(token); | |
} | |
} | |
private void updateToken(String token) { | |
for (int i = 0; i < token.length() - nGramSize + 1; i++) { | |
String nGram = token.substring(i, i + nGramSize); | |
TCharDoubleMap unitToCount = chain.get(nGram); | |
if (unitToCount == null) { | |
unitToCount = new TCharDoubleHashMap(); | |
chain.put(nGram, unitToCount); | |
} | |
char nextUnit; | |
if (i + nGramSize + 1 > token.length()) { | |
nextUnit = '\0'; | |
} else { | |
nextUnit = token.charAt(i + nGramSize); | |
} | |
unitToCount.adjustOrPutValue(nextUnit, 1, 1); | |
} | |
} | |
public void finish() { | |
double minLogProbability = Double.MAX_VALUE; | |
for (TCharDoubleMap unitToCount : chain.values()) { | |
int countSum = 0; | |
for (double value : unitToCount.values()) { | |
countSum += value; | |
} | |
for (char unit : unitToCount.keys()) { | |
double count = unitToCount.get(unit); | |
double logProbability = Math.log(count / countSum); | |
if (logProbability < minLogProbability) { | |
minLogProbability = logProbability; | |
} | |
unitToCount.put(unit, logProbability); | |
} | |
} | |
// Use this probability when we see a new ngram -> unit pair | |
// Without this, the probability might be 0 for an input | |
unknownProbability = Math.min(minLogProbability, Math.log(0.5)); | |
} | |
/** | |
* | |
* @param input | |
* @return probability or nan if tokens are empty | |
*/ | |
public double getProbability(String input) { | |
if (splitPattern == null) { | |
return getProbabilityToken(input); | |
} else { | |
String[] tokens = splitPattern.split(input); | |
return getProbability(tokens); | |
} | |
} | |
public double getProbability(String[] tokens) { | |
double probability = 0; | |
for (String token : tokens) { | |
probability += getProbabilityToken(token); | |
} | |
return probability / tokens.length; | |
} | |
private double getProbabilityToken(String token) { | |
double logProbabilitySum = 0; | |
int transitionCount = 0; | |
for (int i = 0; i < token.length() - nGramSize + 1; i++) { | |
String nGram = token.substring(i, i + nGramSize); | |
char nextUnit; | |
if (i + nGramSize + 1 > token.length()) { | |
nextUnit = '\0'; | |
} else { | |
nextUnit = token.charAt(i + nGramSize); | |
} | |
double tokenProbability = unknownProbability; | |
TCharDoubleMap unitToCount = chain.get(nGram); | |
if (unitToCount != null && unitToCount.containsKey(nextUnit)) { | |
tokenProbability = unitToCount.get(nextUnit); | |
} | |
logProbabilitySum += tokenProbability; | |
transitionCount += 1; | |
} | |
transitionCount = Math.max(transitionCount, 1); | |
return Math.exp(logProbabilitySum / transitionCount); | |
} | |
public String asString() { | |
StringBuilder sb = new StringBuilder(); | |
for (Map.Entry<String, TCharDoubleMap> entry : chain.entrySet()) { | |
String nGram = entry.getKey(); | |
sb.append('"').append(nGram).append("\":{"); | |
TCharDoubleMap unitToProbability = entry.getValue(); | |
for (char unit : unitToProbability.keys()) { | |
double count = unitToProbability.get(unit); | |
sb.append('\'').append(unit).append("'=").append(count).append(", "); | |
} | |
if (unitToProbability.size() > 0) { | |
sb.setLength(sb.length() - 2); | |
} | |
sb.append("}\n"); | |
} | |
if (sb.length() > 1) { | |
sb.setLength(sb.length() - 1); | |
} | |
return sb.toString(); | |
} | |
public double getUnknownProbability() { | |
return unknownProbability; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment