Created
June 28, 2024 19:54
-
-
Save peterbe/ee26c0e6c37691aea32d9fddbea0e463 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Default tokenizer: splits `text` into word tokens.
 *
 * Every character that is not a letter, combining mark, digit, underscore or
 * whitespace is treated as punctuation and replaced by a space; the result is
 * then split on runs of whitespace.
 *
 * @param text - Raw text to tokenize.
 * @returns The non-empty tokens in order of appearance (empty for blank input).
 */
function defaultTokenizer(text: string): string[] {
  // Unicode property escapes cover Latin, Cyrillic (including Ё/ё) and other
  // scripts without hand-written ranges. Parentheses and "+" are NOT listed,
  // so they are stripped like any other punctuation.
  const rgxPunctuation = /[^\p{L}\p{M}\p{N}_\s]/gu;
  const sanitized = text.replace(rgxPunctuation, " ");
  // split() yields "" for leading/trailing whitespace; those are not words.
  return sanitized.split(/\s+/).filter((token) => token.length > 0);
}
/** Function that turns a piece of text into a list of tokens. */
type TokenizerFunction = (text: string) => string[];

/** Options accepted by the `Naivebayes` constructor. */
type Options = {
  /** Custom tokenization function; `defaultTokenizer` is used when omitted. */
  tokenizer?: TokenizerFunction;
};

/**
 * Naive-Bayes Classifier
 *
 * This is a naive-bayes classifier that uses Laplace Smoothing.
 *
 * Takes an (optional) options object containing:
 *   - `tokenizer` => custom tokenization function
 *
 * All lookup tables are created without a prototype so that tokens or
 * category names such as "constructor" or "toString" are treated as ordinary
 * keys instead of resolving to inherited `Object.prototype` members.
 */
export class Naivebayes {
  tokenizer: TokenizerFunction;
  vocabulary: { [key: string]: boolean }; // XXX change to Set
  vocabularySize: number; // XXX Shouldn't be needed
  totalDocuments: number;
  docCount: { [key: string]: number };
  wordCount: { [key: string]: number };
  wordFrequencyCount: { [key: string]: { [key: string]: number } };
  categories: { [key: string]: boolean };

  /**
   * @param options - Optional settings; `tokenizer` overrides the default
   *   tokenization function.
   */
  constructor({ tokenizer = defaultTokenizer }: Options = {}) {
    this.tokenizer = tokenizer;
    // Every distinct token seen so far, and how many there are.
    this.vocabulary = Object.create(null);
    this.vocabularySize = 0;
    // Number of documents we have learned from.
    this.totalDocuments = 0;
    // For each category, how many documents were mapped to it.
    this.docCount = Object.create(null);
    // For each category, how many words in total were mapped to it.
    this.wordCount = Object.create(null);
    // For each category, how often a given word was mapped to it.
    this.wordFrequencyCount = Object.create(null);
    // Set of known category names.
    this.categories = Object.create(null);
  }

  /**
   * Creates the bookkeeping entries for `categoryName` if it is not known yet.
   *
   * @param categoryName - Name of the category to register.
   * @returns This classifier, to allow chaining.
   */
  initializeCategory(categoryName: string) {
    if (!this.categories[categoryName]) {
      this.docCount[categoryName] = 0;
      this.wordCount[categoryName] = 0;
      this.wordFrequencyCount[categoryName] = Object.create(null);
      this.categories[categoryName] = true;
    }
    return this;
  }

  /**
   * Trains the classifier: records that `text` belongs to `category`.
   *
   * @param text - Document to learn from.
   * @param category - Category the document is mapped to.
   */
  learn(text: string, category: string) {
    this.initializeCategory(category);
    this.docCount[category]++;
    this.totalDocuments++;

    const frequencyTable = this.frequencyTable(this.tokenizer(text));
    const categoryFrequencies = this.wordFrequencyCount[category];

    for (const [token, frequencyInText] of Object.entries(frequencyTable)) {
      // Grow the vocabulary the first time a token is seen in any category.
      if (!this.vocabulary[token]) {
        this.vocabulary[token] = true;
        this.vocabularySize++;
      }
      categoryFrequencies[token] =
        (categoryFrequencies[token] ?? 0) + frequencyInText;
      this.wordCount[category] += frequencyInText;
    }
  }

  /**
   * Counts how often each token occurs in `tokens`.
   *
   * @param tokens - Tokens of a single document.
   * @returns A prototype-less table mapping each token to its occurrence count.
   */
  frequencyTable(tokens: string[]): Record<string, number> {
    const table: Record<string, number> = Object.create(null);
    for (const token of tokens) {
      table[token] = (table[token] ?? 0) + 1;
    }
    return table;
  }

  /**
   * Picks the category with the highest log-probability for `text`.
   *
   * @param text - Document to classify.
   * @returns The chosen category name, or `""` when nothing has been learned.
   */
  categorize(text: string): string {
    const frequencyTable = this.frequencyTable(this.tokenizer(text));
    let maxProbability = -Infinity;
    let chosenCategory = "";

    for (const category of Object.keys(this.categories)) {
      // Start from the prior P(category); logs are summed instead of
      // multiplying the raw probabilities.
      let logProbability = Math.log(
        this.docCount[category] / this.totalDocuments
      );
      for (const [token, frequencyInText] of Object.entries(frequencyTable)) {
        logProbability +=
          frequencyInText * Math.log(this.tokenProbability(token, category));
      }
      // Strict comparison: on a tie the category learned first wins.
      if (logProbability > maxProbability) {
        maxProbability = logProbability;
        chosenCategory = category;
      }
    }
    return chosenCategory;
  }

  /**
   * Laplace-smoothed probability of `token` within `category`.
   *
   * @param token - Token to look up.
   * @param category - A category that has already been learned.
   * @returns `(count(token, category) + 1) / (wordCount(category) + vocabularySize)`.
   */
  tokenProbability(token: string, category: string) {
    const wordFrequencyCount = this.wordFrequencyCount[category][token] ?? 0;
    const wordCount = this.wordCount[category];
    return (wordFrequencyCount + 1) / (wordCount + this.vocabularySize);
  }

  /** Serialization is not implemented yet; always throws. */
  toJSON() {
    throw new Error("Method not implemented.");
  }

  /** Deserialization is not implemented yet; always throws. */
  fromJSON(jsonStr: string) {
    throw new Error("Method not implemented.");
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 *
 * Run it like this:
 *
 *   bun run test.ts ../docs-internal-data/comments/bayes-training.json sample.json bad
 *
 * or
 *
 *   bun run test.ts ../docs-internal-data/comments/bayes-training.json sample.json good
 *
 * The sample.json file needs to be a JSON array of objects with a "comment" and "rating" field.
 * You can get these from Hydro.
 */
import { readFileSync } from "fs"; | |
import util from "node:util"; | |
import chalk from "chalk"; | |
import { Naivebayes } from "./naive-bayes"; | |
// Positional command-line arguments: <training.json> <samples.json> [good|bad]
const args = process.argv.slice(2);
const [trainedFilePath, sampleFilePath] = args as [string, string];
// Only an explicit "bad" selects the bad category; anything else means "good".
const print = args[2] !== "bad" ? "good" : "bad";

/** One labelled training example; `rating` is used as the category label. */
type Training = {
  comment: string;
  rating: string;
};

/** One sample to classify; `rating` is a numeric score. */
type Sample = {
  rating: number;
  comment: string;
};
/**
 * Trains a classifier from the training file, classifies every sample whose
 * rating is at least 0.9, and prints the comments that land in the category
 * selected on the command line, followed by a per-category tally.
 */
async function main() {
  const trainingData = JSON.parse(readFileSync(trainedFilePath, "utf8"));
  const trained = trainingData["comments"] as Training[];
  console.log("TRAINED", trained.length);

  const samples = JSON.parse(readFileSync(sampleFilePath, "utf8")) as Sample[];
  console.log("SAMPLES", samples.length);

  const goodSamples = samples.filter(({ rating }) => rating >= 0.9);
  console.log("GOOD SAMPLES", goodSamples.length);

  const bayes = new Naivebayes();

  // Feed every training comment to the classifier, tallying per rating.
  const countTraining: Record<string, number> = {};
  trained.forEach(({ comment, rating }) => {
    bayes.learn(comment, rating);
    countTraining[rating] = (countTraining[rating] || 0) + 1;
  });
  console.log("COUNT TRAINING:", countTraining);

  // Classify the highly rated samples and show the ones in the chosen category.
  const count: Record<string, number> = {};
  console.log(`Printing all the ${chalk.bold(print)} ones`);
  goodSamples.forEach(({ comment }) => {
    const category = bayes.categorize(comment);
    count[category] = (count[category] || 0) + 1;
    if (category !== print) {
      return;
    }
    console.log(chalk.grey(util.inspect(comment)));
    console.log(`CATEGORY: ${category}`);
    console.log("\n");
  });
  console.log("COUNT:", count);
}

main();
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment