Skip to content

Instantly share code, notes, and snippets.

@Staticity
Created December 4, 2022 17:42
Show Gist options
  • Save Staticity/1e879f315864e1d07e41980c602aaaea to your computer and use it in GitHub Desktop.
Save Staticity/1e879f315864e1d07e41980c602aaaea to your computer and use it in GitHub Desktop.
Quick script that simulates noisy Gaussian measurements and plots their joint likelihood and negative log-likelihood.
import numpy as np
import random
import plotly.graph_objects as go
from plotly.subplots import make_subplots
class GaussianDistribution:
    """One-dimensional Gaussian described by a mean and a standard deviation.

    The mean is stored as ``self.u`` and the standard deviation as ``self.o``.
    """

    def __init__(self, mean, sigma):
        """Store the parameters of the distribution.

        Args:
            mean: centre of the distribution.
            sigma: standard deviation; must be strictly positive.

        Raises:
            ValueError: if ``sigma`` is zero, negative, or NaN.
        """
        # A non-positive sigma would otherwise surface much later as a
        # negative "density", a division by zero, or a NaN log-likelihood.
        if not np.all(np.asarray(sigma) > 0):
            raise ValueError(f"sigma must be strictly positive, got {sigma!r}")
        self.u = mean
        self.o = sigma

    def sample(self):
        """Draw one value and wrap it as a measurement distribution.

        Returns:
            A new ``GaussianDistribution`` centred on the drawn value, with
            the same standard deviation as this distribution.
        """
        xp = np.random.normal(self.u, self.o)
        return GaussianDistribution(xp, self.o)

    def likelihood(self, x):
        """Return the Gaussian probability density evaluated at ``x``."""
        s = 1 / (self.o * np.sqrt(2 * np.pi))
        e = np.exp(-(1 / 2) * ((x - self.u) / self.o) ** 2)
        return s * e

    def negative_log_likelihood(self, x):
        """Return ``-log(likelihood(x))`` computed in closed form.

        This is written out rather than taking the log of ``likelihood(x)``
        so the exponential is never evaluated.
        """
        # -log(1 / (sigma * sqrt(2 pi))) == log(sigma * sqrt(2 pi))
        log_norm = np.log(self.o * np.sqrt(2 * np.pi))
        half_sq_z = (1 / 2) * ((x - self.u) / self.o) ** 2
        return log_norm + half_sq_z
def likelihood(measurements, x):
    """Return the joint likelihood of ``x`` under all measurements.

    Each element of ``measurements`` must provide a ``likelihood(x)``
    method; the individual values are multiplied together. An empty
    collection yields ``1.0``.
    """
    per_measurement = (m.likelihood(x) for m in measurements)
    joint = 1.0
    for density in per_measurement:
        joint = joint * density
    return joint
def negative_log_likelihood(measurements, x):
    """Return the joint negative log-likelihood of ``x``.

    Each element of ``measurements`` must provide a
    ``negative_log_likelihood(x)`` method; the individual values are
    added together. An empty collection yields ``0``.
    """
    return sum(m.negative_log_likelihood(x) for m in measurements)
def main():
    """Simulate measurements and plot their likelihood and NLL curves.

    Draws a fixed number of samples from a known Gaussian, evaluates the
    joint likelihood and joint negative log-likelihood on an evenly spaced
    grid of candidate values, plots both curves on a shared x-axis with
    separate y-axes, and marks the grid point with the highest likelihood.
    """
    # Ground-truth distribution the simulated measurements are drawn from.
    true_mean = 16.573  # cm
    sigma = 1  # cm
    true_distribution = GaussianDistribution(true_mean, sigma)

    # Simulate N measurements (N is our choice).
    num_measurements = 10
    measurements = [true_distribution.sample() for _ in range(num_measurements)]

    # Candidate values at which both curves are evaluated.
    xs = np.linspace(0, 20, num=100)
    ls = [likelihood(measurements, x) for x in xs]
    nlls = [negative_log_likelihood(measurements, x) for x in xs]

    # The likelihood goes on the primary y-axis and the negative
    # log-likelihood on the secondary one.
    fig = make_subplots(specs=[[{"secondary_y": True}]])
    fig.add_trace(go.Scatter(x=xs, y=ls, name='likelihood'))
    fig.add_trace(
        go.Scatter(x=xs, y=nlls, name='neg-log likelihood'),
        secondary_y=True,
    )

    # Mark the predicted MLE: the grid point with the largest likelihood.
    _, mle_x = max(zip(ls, xs))
    fig.add_vline(x=mle_x, annotation_text="MLE")
    fig.show()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment