Skip to content

Instantly share code, notes, and snippets.

@neubig
Created February 3, 2012 02:24
Show Gist options
  • Save neubig/1727256 to your computer and use it in GitHub Desktop.
Save neubig/1727256 to your computer and use it in GitHub Desktop.
A program to test a multinomial sampler
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <vector>
using namespace std;
// Draws one index from a discrete (multinomial) distribution.
//
// distribution[i] is the probability of returning i; entries are expected to
// be non-negative and to sum to 1. Sampling uses the inverse-CDF method: a
// uniform value u in [0, 1) is drawn, and the first index whose cumulative
// probability exceeds u is returned.
//
// Returns -1 (after logging to cerr) when u is not covered by the
// distribution, e.g. for an empty vector or probabilities summing to < 1.
// Uses the global rand() generator, so results depend on the srand() seed.
int SampleMultinomial(const vector<double> & distribution) {
  // Divide by RAND_MAX + 1 so u lies in [0, 1). Dividing by RAND_MAX allows
  // u == 1.0 exactly, which can overrun a distribution whose entries sum to
  // slightly less than 1 because of floating-point rounding.
  double remaining = rand() / (RAND_MAX + 1.0);
  for (size_t i = 0; i < distribution.size(); i++) {
    remaining -= distribution[i];
    // Strict '<' keeps each outcome's interval half-open, so an entry with
    // probability 0 is never returned. With '<=', u == 0 would select
    // index 0 even when distribution[0] == 0.
    if (remaining < 0)
      return static_cast<int>(i);
  }
  cerr << "Overflow in SampleMultinomial()" << endl;
  return -1;
}
// Statistical check of SampleMultinomial() based on the law of large numbers.
//
// Draws num_samples values from `distribution` and verifies that the
// empirical frequency of every outcome is within `confidence` (an absolute
// tolerance, not a statistical confidence level) of its true probability.
// Every out-of-tolerance outcome is reported on cerr.
//
// distribution: probability of each outcome (expected to sum to 1).
// num_samples:  number of draws; must be positive (default 1,000,000).
// confidence:   maximum allowed absolute error per outcome (default 0.005).
// Returns true if every outcome is within tolerance. Returns false otherwise,
// including when the sampler returns an index outside the distribution.
bool TestLawOfLargeNumbers(const vector<double> & distribution,
                           int num_samples = 1000000,
                           double confidence = 0.005) {
  if (num_samples <= 0) {
    cerr << "FAILED: num_samples must be positive" << endl;
    return false;
  }
  // Call SampleMultinomial num_samples times and count how often each
  // result is produced.
  vector<int> counts(distribution.size(), 0);
  for (int i = 0; i < num_samples; i++) {
    int x = SampleMultinomial(distribution);
    // SampleMultinomial() signals failure with -1. Indexing counts with it
    // would be out of bounds, so treat any invalid index as a failure.
    if (x < 0 || static_cast<size_t>(x) >= counts.size()) {
      cerr << "FAILED: SampleMultinomial() returned invalid index " << x
           << endl;
      return false;
    }
    counts[x]++;
  }
  // Check that the estimated probabilities match the true probabilities of
  // the distribution within the tolerance.
  bool passed = true;
  for (size_t i = 0; i < distribution.size(); i++) {
    double estimated_probability =
        static_cast<double>(counts[i]) / num_samples;
    // fabs() always selects the floating-point overload, whichever abs()
    // declarations happen to be visible.
    double difference = fabs(distribution[i] - estimated_probability);
    if (difference > confidence) {
      cerr << "FAILED: Difference between at " << i << " ("
           << difference
           << ") is more than confidence " << confidence << endl;
      passed = false;
    }
  }
  return passed;
}
// Regression check that SampleMultinomial() reproduces a fixed, previously
// recorded sequence of indices after seeding with srand(1234).
//
// Side effect: reseeds the global rand() generator with 1234, which affects
// any later rand() users.
//
// NOTE(review): the expected indices were presumably recorded with main()'s
// distribution {0.5, 0.3, 0.2}. The rand() algorithm is implementation-defined,
// so this sequence only reproduces with the C library it was generated on.
//
// Returns true if every draw matches; each mismatch is reported on cerr.
bool TestPseudoRandom(const vector<double> & distribution) {
  srand( 1234 );
  static const int kExpected[] = { 0, 0, 0, 0, 1, 0, 0, 0, 2, 2 };
  // Derive the sample count from the table so the two cannot drift apart.
  const int kNumSamples = sizeof(kExpected) / sizeof(kExpected[0]);
  bool passed = true;
  for (int i = 0; i < kNumSamples; i++) {
    int x = SampleMultinomial(distribution);
    if (x != kExpected[i]) {
      cerr << "FAILED: Actual value at " << i << " (" << x
           << ") not equal to expected (" << kExpected[i] << ")" << endl;
      passed = false;
    }
  }
  return passed;
}
// Runs both sampler tests on a fixed three-outcome distribution and reports
// each result on cerr. Exits with EXIT_SUCCESS only if both tests pass, so
// the program can be used in automated checks.
int main() {
  // Seed from the clock so the statistical test sees different draws on each
  // run; TestPseudoRandom() reseeds with its own fixed value.
  srand(static_cast<unsigned>(time(NULL)));
  vector<double> my_distribution(3);
  my_distribution[0] = 0.5;
  my_distribution[1] = 0.3;
  my_distribution[2] = 0.2;
  // Run each test before printing its label. Before C++17 the evaluation
  // order inside `cerr << "label" << Test()` is unspecified, so a test's own
  // FAILED messages could otherwise appear before or after the label.
  bool lln_passed = TestLawOfLargeNumbers(my_distribution);
  cerr << "TestLawOfLargeNumbers: "
       << (lln_passed ? "passed" : "FAILED") << endl;
  bool pseudo_passed = TestPseudoRandom(my_distribution);
  cerr << "TestPseudoRandom: "
       << (pseudo_passed ? "passed" : "FAILED") << endl;
  return (lln_passed && pseudo_passed) ? EXIT_SUCCESS : EXIT_FAILURE;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment