Skip to content

Instantly share code, notes, and snippets.

@peterbe
Created June 28, 2024 19:54
Show Gist options
  • Save peterbe/ee26c0e6c37691aea32d9fddbea0e463 to your computer and use it in GitHub Desktop.
Save peterbe/ee26c0e6c37691aea32d9fddbea0e463 to your computer and use it in GitHub Desktop.
/**
 * Default tokenizer used by the classifier.
 *
 * Every character that is not an ASCII letter, a digit, an underscore, a basic
 * Cyrillic letter (А-я, Ё, ё) or whitespace is replaced by a space. The result
 * is then split on runs of whitespace. Empty tokens (produced by leading or
 * trailing separators) are dropped so they never enter the vocabulary.
 *
 * @param text - raw input text.
 * @returns the non-empty tokens, in order of appearance.
 */
function defaultTokenizer(text: string): string[] {
  // Character class of everything that is NOT a word char or whitespace.
  // (The old pattern wrapped the ranges in "(...)+", which inside [] made
  // "(", ")" and "+" count as word characters.)
  const rgxPunctuation = /[^a-zA-Z0-9_\u0401\u0410-\u044F\u0451\s]/g;
  const sanitized = text.replace(rgxPunctuation, " ");
  return sanitized.split(/\s+/).filter((token) => token.length > 0);
}
/** Splits raw text into the tokens the Naive-Bayes classifier counts. */
type TokenizerFunction = (text: string) => string[];

/**
 * Constructor options for the Naive-Bayes classifier (`Naivebayes`), a
 * naive-bayes classifier that uses Laplace smoothing.
 *
 * All fields are optional:
 * - `tokenizer` => custom tokenization function; when omitted the module's
 *   `defaultTokenizer` is used.
 */
type Options = {
  tokenizer?: TokenizerFunction;
};
/**
 * Naive-Bayes text classifier with Laplace (add-one) smoothing.
 *
 * Train with {@link Naivebayes.learn}, then call {@link Naivebayes.categorize}
 * to pick the most probable category for new text. The probability of a token
 * given a category is estimated as
 * `(count(token, category) + 1) / (wordCount(category) + vocabularySize)`.
 *
 * Every lookup table is a null-prototype object. Tokens or category names
 * such as `"constructor"`, `"toString"` or `"__proto__"` are therefore counted
 * like any other string instead of colliding with `Object.prototype` members.
 */
export class Naivebayes {
  /** Function used to split text into tokens. */
  tokenizer: TokenizerFunction;
  /** Every distinct token seen during training (values are always `true`). */
  vocabulary: { [key: string]: boolean }; // could become a Set; kept as a record for API compatibility
  /** Number of keys in `vocabulary`, maintained incrementally. */
  vocabularySize: number; // redundant with the vocabulary's size; kept for API compatibility
  /** Number of documents passed to `learn`. */
  totalDocuments: number;
  /** Per category: how many documents were learned under it. */
  docCount: { [key: string]: number };
  /** Per category: total number of tokens learned under it. */
  wordCount: { [key: string]: number };
  /** Per category, per token: how often the token was learned under it. */
  wordFrequencyCount: { [key: string]: { [key: string]: number } };
  /** Known category names (values are always `true`). */
  categories: { [key: string]: boolean };

  constructor({ tokenizer = defaultTokenizer }: Options = {}) {
    this.tokenizer = tokenizer;
    // Null-prototype maps: no inherited keys such as "constructor".
    this.vocabulary = Object.create(null);
    this.vocabularySize = 0;
    this.totalDocuments = 0;
    this.docCount = Object.create(null);
    this.wordCount = Object.create(null);
    this.wordFrequencyCount = Object.create(null);
    this.categories = Object.create(null);
  }

  /**
   * Registers `categoryName` with zeroed counters if it is not known yet;
   * does nothing for an existing category.
   *
   * @returns this instance, for chaining.
   */
  initializeCategory(categoryName: string) {
    if (!this.categories[categoryName]) {
      this.docCount[categoryName] = 0;
      this.wordCount[categoryName] = 0;
      this.wordFrequencyCount[categoryName] = Object.create(null);
      this.categories[categoryName] = true;
    }
    return this;
  }

  /**
   * Trains the classifier on a single document.
   *
   * @param text - raw document text, split with `this.tokenizer`.
   * @param category - label of the document; created on first use.
   */
  learn(text: string, category: string) {
    this.initializeCategory(category);
    this.docCount[category] = (this.docCount[category] ?? 0) + 1;
    this.totalDocuments++;

    const frequencyTable = this.frequencyTable(this.tokenizer(text));
    const categoryFrequencies = this.wordFrequencyCount[category];
    for (const [token, frequencyInText] of Object.entries(frequencyTable)) {
      if (!this.vocabulary[token]) {
        this.vocabulary[token] = true;
        this.vocabularySize++;
      }
      categoryFrequencies[token] =
        (categoryFrequencies[token] ?? 0) + frequencyInText;
      this.wordCount[category] =
        (this.wordCount[category] ?? 0) + frequencyInText;
    }
  }

  /**
   * Counts how often each token occurs.
   *
   * @returns a null-prototype map from token to its number of occurrences.
   */
  frequencyTable(tokens: string[]): { [key: string]: number } {
    const table: { [key: string]: number } = Object.create(null);
    for (const token of tokens) {
      table[token] = (table[token] ?? 0) + 1;
    }
    return table;
  }

  /**
   * Returns the category with the highest log-probability
   * `log P(category) + Σ count(token) · log P(token | category)` for `text`.
   *
   * Ties go to the category enumerated first. Returns `""` when nothing has
   * been learned yet.
   */
  categorize(text: string): string {
    const tokenCounts = Object.entries(this.frequencyTable(this.tokenizer(text)));
    let maxProbability = -Infinity;
    let chosenCategory = "";

    for (const category of Object.keys(this.categories)) {
      const categoryProbability =
        (this.docCount[category] ?? 0) / this.totalDocuments;
      let logProbability = Math.log(categoryProbability);
      for (const [token, frequencyInText] of tokenCounts) {
        logProbability +=
          frequencyInText * Math.log(this.tokenProbability(token, category));
      }
      if (logProbability > maxProbability) {
        maxProbability = logProbability;
        chosenCategory = category;
      }
    }
    return chosenCategory;
  }

  /**
   * Laplace-smoothed estimate of P(token | category).
   * A token never seen under the category still gets a non-zero probability.
   */
  tokenProbability(token: string, category: string): number {
    const tokenCount = this.wordFrequencyCount[category]?.[token] ?? 0;
    const categoryWordCount = this.wordCount[category] ?? 0;
    return (tokenCount + 1) / (categoryWordCount + this.vocabularySize);
  }

  /**
   * Serialization is not implemented; always throws.
   * `JSON.stringify` invokes `toJSON`, so stringifying an instance throws too.
   */
  toJSON() {
    throw new Error("Method not implemented.");
  }

  /** Deserialization is not implemented; always throws. */
  fromJSON(jsonStr: string) {
    throw new Error("Method not implemented.");
  }
}
/**
*
* Run it like this:
*
* bun run test.ts ../docs-internal-data/comments/bayes-training.json sample.json bad
*
* or
*
* bun run test.ts ../docs-internal-data/comments/bayes-training.json sample.json good
*
* The sample.json file needs to be a JSON array of objects with "comment" and "rating" fields.
* You can get these from Hydro.
*/
import { readFileSync } from "fs";
import util from "node:util";
import chalk from "chalk";
import { Naivebayes } from "./naive-bayes";
/** Prints the expected command line to stderr and exits with status 1. */
function exitWithUsage(): never {
  console.error(
    "Usage: bun run test.ts <bayes-training.json> <sample.json> [good|bad]",
  );
  process.exit(1);
}

const args = process.argv.slice(2);
// Both file paths are required; fail fast with a usage message instead of an
// obscure readFileSync error later on.
const trainedFilePath: string = args[0] ?? exitWithUsage();
const sampleFilePath: string = args[1] ?? exitWithUsage();
// Which classifier verdict to print; anything other than "bad" means "good".
const print = args[2] === "bad" ? "bad" : "good";

/** One labelled example from the training file's "comments" array. */
type Training = {
  rating: string;
  comment: string;
};

/** One entry of the samples file. */
type Sample = {
  comment: string;
  rating: number;
};
/**
 * Trains a Naivebayes classifier on every comment in the training file, using
 * each comment's `rating` as its category. It then classifies the samples whose
 * rating is >= 0.9, prints those classified as `print`, and finally prints a
 * tally of how many of those samples landed in each category.
 *
 * Kept `async` so that synchronous failures (missing file, bad JSON) surface
 * as a rejection handled by the `.catch` below.
 */
async function main() {
  const trainedData: unknown = JSON.parse(readFileSync(trainedFilePath, "utf8"));
  const trainedComments =
    typeof trainedData === "object" &&
    trainedData !== null &&
    "comments" in trainedData
      ? trainedData.comments
      : undefined;
  if (!Array.isArray(trainedComments)) {
    throw new Error(
      `${trainedFilePath}: expected a JSON object with a "comments" array`,
    );
  }
  const trained = trainedComments as Training[];
  console.log("TRAINED", trained.length);

  const sampleData: unknown = JSON.parse(readFileSync(sampleFilePath, "utf8"));
  if (!Array.isArray(sampleData)) {
    throw new Error(`${sampleFilePath}: expected a JSON array`);
  }
  const samples = sampleData as Sample[];
  console.log("SAMPLES", samples.length);
  const goodSamples = samples.filter((s) => s.rating >= 0.9);
  console.log("GOOD SAMPLES", goodSamples.length);

  const bayes = new Naivebayes();
  const countTraining: Record<string, number> = {};
  for (const { comment, rating } of trained) {
    bayes.learn(comment, rating);
    countTraining[rating] = (countTraining[rating] ?? 0) + 1;
  }
  console.log("COUNT TRAINING:", countTraining);

  const count: Record<string, number> = {};
  console.log(`Printing all the ${chalk.bold(print)} ones`);
  for (const sample of goodSamples) {
    const category = bayes.categorize(sample.comment);
    count[category] = (count[category] ?? 0) + 1;
    if (category === print) {
      console.log(chalk.grey(util.inspect(sample.comment)));
      console.log(`CATEGORY: ${category}`);
      console.log("\n");
    }
  }
  console.log("COUNT:", count);
}

// Report failures and set a non-zero exit status instead of leaving the
// promise rejection unhandled.
main().catch((error: unknown) => {
  console.error(error);
  process.exitCode = 1;
});
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment