Created January 28, 2020 11:46
-
-
Save simeneide/56755565c8b70bc13ea7b83f4242c5d2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#%% IMPORTS | |
import torch | |
import pytorch_lightning as pl | |
import matplotlib.pyplot as plt | |
from pytorch_lightning import Trainer | |
from torch.nn import functional as F | |
import pyro | |
import pyro.distributions as dist | |
# %% | |
class CoolSystem(pl.LightningModule):
    """Toy Lightning module fitting ``y ~ w * x + b`` with one linear layer.

    Trained with an L1 (mean absolute error) loss on synthetic data generated
    as ``y = 2 + x + 0.2 * eps`` with ``eps ~ Normal(0, 1)``.
    """

    def __init__(self):
        super().__init__()
        # not the best model... (1 input feature -> 1 output)
        self.l1 = torch.nn.Linear(1, 1)

    def forward(self, x):
        """Apply the linear layer to ``x`` (built below with shape ``(batch, 1)``)."""
        return self.l1(x)

    def _l1_loss(self, batch):
        """Return the mean absolute error of the model on one ``(x, y)`` batch."""
        x, y = batch
        yhat = self.forward(x)
        return (yhat - y).abs().mean()

    def training_step(self, batch, batch_idx):
        loss = self._l1_loss(batch)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        return {'val_loss': self._l1_loss(batch)}

    def validation_end(self, outputs):
        # OPTIONAL: aggregate the per-batch validation losses.
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        # Log under 'val_loss' so the validation curve is not confused with
        # the training loss in TensorBoard.
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        # REQUIRED
        # Can return multiple optimizers and learning-rate schedulers
        # (LBFGS is supported automatically; no closure function needed).
        return torch.optim.Adam(self.parameters(), lr=0.02)

    @staticmethod
    def _make_loader(num_points, batch_size=2):
        """Build a DataLoader over ``x = 0..num_points-1`` and ``y = 2 + x + 0.2 * noise``."""
        x = torch.arange(num_points).float().view(-1, 1)
        noise = torch.distributions.Normal(0, 1).sample((len(x),)).view(-1, 1)
        y = 2 + x + noise * 0.2
        ds = torch.utils.data.TensorDataset(x, y)
        return torch.utils.data.DataLoader(dataset=ds, batch_size=batch_size)

    @pl.data_loader
    def train_dataloader(self):
        return self._make_loader(100)

    @pl.data_loader
    def val_dataloader(self):
        return self._make_loader(10)
# %%
system = CoolSystem()
# most basic trainer, uses good defaults
trainer = Trainer(min_epochs=1)
trainer.fit(system)
# RESULTS: show the learned weight and bias. A bare expression is only
# displayed in an interactive cell, so print explicitly for script runs too.
for name, param in system.named_parameters():
    print(name, param.data)
# %% PYRO LIGHTNING!! | |
#%% | |
import torch | |
import pytorch_lightning as pl | |
import matplotlib.pyplot as plt | |
from pytorch_lightning import Trainer | |
from torch.nn import functional as F | |
import pyro | |
import pyro.distributions as dist | |
class PyroOptWrap(pyro.infer.SVI):
    """``pyro.infer.SVI`` with just enough optimizer API to be returned from
    ``configure_optimizers``.

    Checkpointing hooks are stubbed out: nothing optimizer-related is saved or
    restored. NOTE(review): Pyro's learnable values live in the global param
    store, not in the Lightning checkpoint; save them separately (e.g.
    ``pyro.get_param_store().save(path)``) if resuming training matters.
    """

    def state_dict(self):
        """Return an empty state; SVI/optimizer state is not checkpointed."""
        return {}

    def load_state_dict(self, state_dict):
        """Accept and ignore a state produced by ``state_dict``.

        No-op counterpart to ``state_dict`` so that restoring a checkpoint does
        not fail on a missing method.
        """
        return None
class PyroCoolSystem(pl.LightningModule):
    """Bayesian linear regression ``y = alpha + beta * x + noise`` fitted with
    Pyro SVI inside a Lightning training loop.

    Lightning drives the epoch/batch loop. The actual parameter update is done
    by ``self.svi.step`` in ``training_step``, so Lightning's own ``backward``
    and ``optimizer_step`` are turned into no-ops below.

    Args:
        num_data: number of synthetic training points.
        lr: learning rate for Pyro's SGD optimizer.
    """

    def __init__(self, num_data=100, lr=1e-3):
        super().__init__()
        self.lr = lr
        self.num_data = num_data

    def model(self, batch):
        """Generative model: priors on ``alpha``/``beta`` plus Normal(yhat, 0.2) likelihood."""
        x, y = batch
        yhat = self.forward(x)  # shape (B, 1): x is (B, 1), alpha/beta scalars
        # Observations are conditionally independent given (alpha, beta).
        # The plate declares the batch dim, and .to_event(1) moves the trailing
        # size-1 column into the event shape. This gives the site a valid shape
        # under Pyro's shape validation; the summed log-likelihood is unchanged.
        # NOTE(review): the likelihood is not rescaled by num_data / batch_size,
        # so with minibatches the prior is weighted more than in the full-data
        # posterior. Use plate(size=..., subsample_size=...) if that matters.
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", dist.Normal(yhat, 0.2).to_event(1), obs=y)
        return yhat

    def guide(self, batch):
        """Variational family: independent Normals with learnable means, fixed scale 0.1.

        ``batch`` is unused; SVI passes the same arguments to model and guide.
        """
        b_m = pyro.param("b-mean", torch.tensor(0.1))
        a_m = pyro.param("a-mean", torch.tensor(0.1))
        pyro.sample("beta", dist.Normal(b_m, 0.1))
        pyro.sample("alpha", dist.Normal(a_m, 0.1))

    def forward(self, x):
        """Sample ``beta`` and ``alpha`` from Normal(0, 1) priors; return ``alpha + x * beta``."""
        b = pyro.sample("beta", dist.Normal(0, 1))
        a = pyro.sample("alpha", dist.Normal(0, 1))
        return a + x * b

    def training_step(self, batch, batch_idx):
        # svi.step runs model and guide, computes ELBO gradients and updates params.
        loss = self.svi.step(batch)
        # svi.step returns a plain float. Wrap it as a grad-requiring tensor,
        # presumably so Lightning's loss handling accepts it. backward() is a
        # no-op below, so this tensor is only used for progress/logging.
        loss = torch.tensor(loss).requires_grad_(True)
        tensorboard_logs = {
            'running/loss': loss,
            'param/a-mean': pyro.param("a-mean"),
            'param/b-mean': pyro.param("b-mean"),
        }
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        # evaluate_loss estimates the loss without taking a gradient step.
        loss = self.svi.evaluate_loss(batch)
        loss = torch.tensor(loss).requires_grad_(True)
        return {'val_loss': loss}

    def validation_end(self, outputs):
        # OPTIONAL: aggregate the per-batch validation losses.
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        # REQUIRED. The object handed to Lightning as the "optimizer" is the SVI
        # wrapper; its step() (called in training_step) applies Pyro's SGD update.
        self.svi = PyroOptWrap(model=self.model,
                               guide=self.guide,
                               optim=pyro.optim.SGD({"lr": self.lr, "momentum": 0.0}),
                               loss=pyro.infer.Trace_ELBO())
        return [self.svi]

    @staticmethod
    def _make_loader(num_points, batch_size):
        """Build a DataLoader over ``x ~ U(0, 1)`` and ``y = 2 + x + 0.2 * noise``."""
        x = torch.rand((num_points,)).float().view(-1, 1)
        noise = torch.distributions.Normal(0, 1).sample((len(x),)).view(-1, 1)
        y = 2 + x + noise * 0.2
        ds = torch.utils.data.TensorDataset(x, y)
        return torch.utils.data.DataLoader(dataset=ds, batch_size=batch_size)

    @pl.data_loader
    def train_dataloader(self):
        return self._make_loader(self.num_data, batch_size=2)

    @pl.data_loader
    def val_dataloader(self):
        return self._make_loader(100, batch_size=10)

    def optimizer_step(self, *args, **kwargs):
        # Disabled: the update already happened inside svi.step in training_step.
        pass

    def backward(self, *args, **kwargs):
        # Disabled: SVI computes its own gradients; the logged loss tensor is
        # not part of any autograd graph.
        pass
# %%
# Pyro keeps its parameters in a global store. Reset it so "a-mean" and
# "b-mean" start again from the guide's initial values on every run of this cell.
pyro.clear_param_store()
system = PyroCoolSystem(num_data=2)
# Train for at least 1 and at most 100 epochs.
trainer = Trainer(max_epochs=100, min_epochs=1)
trainer.fit(system)
# %%
# %%
OK, I don't know about that. Unfortunately, both Pyro and PyTorch Lightning have probably had major updates since I created this gist.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Copy-pasted this into a notebook with `pytorch-lightning` and `pyro` installed, and got the following error: