Multi-armed bandits problem solved with naive strategies, an epsilon-greedy bandit, UCB1, and a Bayesian (Thompson sampling) bandit.
import numpy as np
from scipy.stats import bernoulli

N_EXPS = 200  # number of experiments to conduct TODO test 10k or 20k EXPS
N_BAGS = 10  # number of bags (arms)
N_DRAWS = 100  # number of draws (total budget per experiment)
SMOOTHER = 1.E-2  # additive smoothing pseudo-count
EPSILON_BANDIT = 0.1  # epsilon-greedy bandit's exploration rate
EPSILON_NUM = 1.E-9  # numerical epsilon
bags_p = np.random.random((N_EXPS, N_BAGS))  # Bernoulli probabilities,
                                             # one row of N_BAGS per experiment
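# Protocol: every strategy gets the same budget of N_DRAWS pulls per
# experiment; arm j of experiment i pays out Bernoulli(bags_p[i, j]), with
# p drawn uniformly in [0, 1). Totals are averaged over the N_EXPS runs:
# pulling arms blindly should average ~N_DRAWS/2 = 50, while an oracle that
# always pulls the best arm would get ~N_DRAWS * E[max p] = 100 * 10/11 ~= 91.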
# running reward totals, one per strategy (averaged over N_EXPS at the end)
mean_s_1 = 0.
mean_s_2_m = 0.
mean_s_2_s = 0.
mean_s_3_m = 0.
mean_s_3_s = 0.
mean_s_4_m = 0.
mean_s_4_s = 0.
mean_s_5 = 0.
mean_s_5_e0 = 0.
mean_s_6 = 0.
mean_s_7 = 0.
for i in xrange(bags_p.shape[0]):
    # strategy 1 = spread the budget evenly over the bags:
    # N_DRAWS/N_BAGS (here 10) draws from each of the N_BAGS bags
    b_1 = [bernoulli.rvs(p, size=N_DRAWS // N_BAGS) for p in bags_p[i]]
    mean_s_1 += np.sum(b_1)
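    # Strategies 2-4 are "estimate then commit": spend k pulls per bag to
    # estimate each p, then spend the remaining N_DRAWS - k*N_BAGS pulls
    # either all on the argmax bag (_m) or by sampling bags in proportion
    # to their normalized estimates (_s). The +SMOOTHER pseudo-count keeps
    # the normalized estimator a valid probability vector even when every
    # initial draw comes up 0.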
    # strategy 2 = Laplace law of succession estimation with one sample (stupid)
    b_2 = [bernoulli.rvs(p) for p in bags_p[i]]
    estimator = np.array([e + SMOOTHER for e in b_2])
    estimator /= np.sum(estimator)
    mean_s_2_m += np.sum(b_2)
    mean_s_2_s += np.sum(b_2)
    # still 90 draws left:
    # maximal picking (_m)
    mean_s_2_m += np.sum(bernoulli.rvs(bags_p[i, np.argmax(estimator)],
                                       size=N_DRAWS - N_BAGS))
    # sampling (_s)
    mean_s_2_s += np.sum([bernoulli.rvs(np.random.choice(bags_p[i],
                          p=estimator)) for _ in xrange(N_DRAWS - N_BAGS)])
    # strategy 3 = Laplace law of succession estimation with three samples
    b_3 = [bernoulli.rvs(p, size=3) for p in bags_p[i]]
    estimator = np.array([np.sum(e) + SMOOTHER for e in b_3])
    estimator /= np.sum(estimator)
    mean_s_3_m += np.sum(b_3)
    mean_s_3_s += np.sum(b_3)
    # still 70 draws left:
    # maximal picking (_m)
    mean_s_3_m += np.sum(bernoulli.rvs(bags_p[i, np.argmax(estimator)],
                                       size=N_DRAWS - 3 * N_BAGS))
    # sampling (_s)
    mean_s_3_s += np.sum([bernoulli.rvs(np.random.choice(bags_p[i],
                          p=estimator)) for _ in xrange(N_DRAWS - 3 * N_BAGS)])
    # strategy 4 = Laplace law of succession estimation with five samples
    b_4 = [bernoulli.rvs(p, size=5) for p in bags_p[i]]
    estimator = np.array([np.sum(e) + SMOOTHER for e in b_4])  # np.array, else /= fails on a list
    estimator /= np.sum(estimator)
    mean_s_4_m += np.sum(b_4)
    mean_s_4_s += np.sum(b_4)
    # still 50 draws left:
    # maximal picking (_m)
    mean_s_4_m += np.sum(bernoulli.rvs(bags_p[i, np.argmax(estimator)],
                                       size=N_DRAWS - 5 * N_BAGS))
    # sampling (_s)
    mean_s_4_s += np.sum([bernoulli.rvs(np.random.choice(bags_p[i],
                          p=estimator)) for _ in xrange(N_DRAWS - 5 * N_BAGS)])
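    # Strategies 5-7 are adaptive bandits: after one initial pull per bag
    # (reusing b_2), every remaining pull is chosen using all rewards seen
    # so far, trading off exploration against exploitation.
    # Epsilon-greedy: with probability EPSILON_BANDIT pull a uniformly
    # random arm (explore); otherwise pull the arm with the highest
    # empirical success rate (exploit).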
    # strategy 5 = epsilon-greedy bandit
    b_5 = b_2
    mean_s_5 += np.sum(b_5)
    successes = np.array([e + EPSILON_NUM for e in b_5])
    counts = np.ones(N_BAGS)  # every arm was already pulled once (b_5)
    for _ in xrange(N_DRAWS - N_BAGS):
        if np.random.random() < EPSILON_BANDIT:
            ind = np.random.randint(0, N_BAGS)  # explore: uniform random arm
        else:
            ind = np.argmax(successes / counts)  # exploit: best empirical rate
        sample = bernoulli.rvs(bags_p[i, ind])
        successes[ind] += sample
        counts[ind] += 1
        mean_s_5 += sample
    # strategy 5 bis = epsilon-greedy bandit with epsilon = 0 (pure greedy)
    b_5 = b_2
    mean_s_5_e0 += np.sum(b_5)
    successes = np.array([e + EPSILON_NUM for e in b_5])
    counts = np.ones(N_BAGS)
    for _ in xrange(N_DRAWS - N_BAGS):
        ind = np.argmax(successes / counts)
        sample = bernoulli.rvs(bags_p[i, ind])
        successes[ind] += sample
        counts[ind] += 1
        mean_s_5_e0 += sample
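    # UCB1 (Auer et al., 2002) pulls the arm maximizing
    #     mean_i + sqrt(2 * ln(t) / n_i)
    # where t is the total number of pulls so far and n_i the number of
    # pulls of arm i: rarely-tried arms carry a large confidence bonus,
    # so no arm is starved until its upper confidence bound drops.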
    # strategy 6 = UCB1 bandit
    b_6 = b_2
    mean_s_6 += np.sum(b_6)
    successes = np.array([e + EPSILON_NUM for e in b_6])
    counts = np.ones(N_BAGS)
    for c in xrange(N_DRAWS - N_BAGS):
        t = c + N_BAGS + 1.  # total pulls so far: N_BAGS initial ones + c
        ucb = [successes[ii] / counts[ii] + np.sqrt(2 * np.log(t) / counts[ii])
               for ii in xrange(N_BAGS)]
        max_ind = np.argmax(ucb)
        max_sample = bernoulli.rvs(bags_p[i, max_ind])
        successes[max_ind] += max_sample
        counts[max_ind] += 1
        mean_s_6 += max_sample
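    # Bayesian bandit: keep a Beta(a, b) posterior per arm, starting from
    # the uniform prior Beta(1, 1); each Bernoulli observation x updates it
    # conjugately to Beta(a + x, b + 1 - x), so the posterior mean is
    # a / (a + b). Thompson sampling draws one value from each posterior
    # and pulls the argmax, which explores each arm in proportion to the
    # posterior probability that it is the best one.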
    # strategy 7 = Bayesian bandit (Thompson sampling)
    b_7 = b_2
    mean_s_7 += np.sum(b_7)  # we could even start without pulling each once!
    beta_a = np.ones(N_BAGS)
    beta_b = np.ones(N_BAGS)
    for ii, val in enumerate(b_7):
        beta_a[ii] += val
        beta_b[ii] += 1 - val
    for c in xrange(N_DRAWS - N_BAGS):
        # one draw from each Beta posterior; taking the argmax of the
        # posterior means beta_a/(beta_a+beta_b) would be purely greedy
        best_estimator = [np.random.beta(beta_a[ii], beta_b[ii])
                          for ii in xrange(N_BAGS)]
        max_ind = np.argmax(best_estimator)
        max_sample = bernoulli.rvs(bags_p[i, max_ind])
        beta_a[max_ind] += max_sample
        beta_b[max_ind] += 1 - max_sample
        mean_s_7 += max_sample
print "strategy 1 (random/avg):", mean_s_1 / bags_p.shape[0] | |
print "strategy 2 (estimate (1 sample) and exploit) max:", mean_s_2_m / bags_p.shape[0] | |
print "strategy 2 (estimate (1 sample) and exploit) sample:", mean_s_2_s / bags_p.shape[0] | |
print "strategy 3 (estimate (3 samples) and exploit) max:", mean_s_3_m / bags_p.shape[0] | |
print "strategy 3 (estimate (3 samples) and exploit) sample:", mean_s_3_s / bags_p.shape[0] | |
print "strategy 4 (estimate (5 samples) and exploit) max:", mean_s_4_m / bags_p.shape[0] | |
print "strategy 4 (estimate (5 samples) and exploit) sample:", mean_s_4_s / bags_p.shape[0] | |
print "strategy 5 (bandit epsilon-greedy with epsilon", EPSILON_BANDIT, "):", mean_s_5 / bags_p.shape[0] | |
print "strategy 5 (bandit epsilon-greedy with epsilon 0 ):", mean_s_5_e0 / bags_p.shape[0] | |
print "strategy 6 (bandit UCB1):", mean_s_6 / bags_p.shape[0] | |
print "strategy 7 (Bayesian bandit):", mean_s_7 / bags_p.shape[0] |