Skip to content

Instantly share code, notes, and snippets.

@Ankowa
Last active April 26, 2024 12:15
Show Gist options
  • Save Ankowa/4a0cfa8ce48d7fb5ca15241b294930a6 to your computer and use it in GitHub Desktop.
Save Ankowa/4a0cfa8ce48d7fb5ca15241b294930a6 to your computer and use it in GitHub Desktop.
MIA_example
import torch
from torch.utils.data import Dataset
from typing import Tuple
import numpy as np
import requests
import pandas as pd
#### LOADING THE MODEL
from torchvision.models import resnet18
# Build an untrained ResNet-18. Calling resnet18() with no arguments gives
# randomly initialised weights on both the legacy `pretrained=` API and the
# newer `weights=` API, and avoids the deprecation warning that
# `pretrained=False` raises on torchvision >= 0.13.
model = resnet18()
# Replace the classification head with a 44-output layer; the input width is
# read from the existing head instead of being hard-coded.
model.fc = torch.nn.Linear(model.fc.in_features, 44)
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
ckpt = torch.load("out/models/01_MIA_67.pt", map_location="cpu")
# The loaded object is used directly as the model's state dict.
model.load_state_dict(ckpt)
# NOTE(review): the model is left in training mode here; callers that run
# inference presumably want model.eval() first -- confirm against the task.
#### DATASETS
class TaskDataset(Dataset):
    """In-memory dataset of ``(id, image, label)`` samples.

    The parallel lists ``ids``, ``imgs`` and ``labels`` start empty and are
    filled in from outside the class; entry ``i`` of each list belongs to the
    same sample.

    Args:
        transform: Optional callable applied to the image every time a sample
            is fetched. With ``None`` (the default) the stored image is
            returned unchanged.
    """

    # NOTE(review): instances of a subclass are restored with torch.load
    # elsewhere in this file, so the attribute names below should stay as they
    # are to remain compatible with the pickled data.
    def __init__(self, transform=None):
        self.ids = []
        self.imgs = []
        self.labels = []
        self.transform = transform

    def __getitem__(self, index) -> Tuple[int, torch.Tensor, int]:
        """Return the ``(id, image, label)`` triple stored at ``index``.

        The image is passed through ``self.transform`` first when a transform
        was supplied.
        """
        id_ = self.ids[index]
        img = self.imgs[index]
        if self.transform is not None:
            img = self.transform(img)
        label = self.labels[index]
        return id_, img, label

    def __len__(self) -> int:
        """Return the number of samples, taken from the length of ``ids``."""
        return len(self.ids)
class MembershipDataset(TaskDataset):
    """TaskDataset variant that also carries one membership entry per sample.

    ``membership`` is a list parallel to the lists kept by the base class and
    starts out empty.
    NOTE(review): the entries presumably flag whether a sample was part of the
    model's training set -- confirm against the data provider.
    """

    def __init__(self, transform=None):
        super().__init__(transform)
        self.membership = []

    def __getitem__(self, index) -> Tuple[int, torch.Tensor, int, int]:
        """Return ``(id, image, label, membership)`` for the sample at ``index``."""
        # Fetch the base triple first (this also applies the transform), then
        # append the matching membership entry.
        base_fields = super().__getitem__(index)
        member_flag = self.membership[index]
        return (*base_fields, member_flag)
# Restore the pickled dataset object. The file holds a full Python object
# (used below as a MembershipDataset), not a tensor-only state dict, so the
# restricted unpickler that PyTorch >= 2.6 applies by default
# (weights_only=True) rejects it; full unpickling must be requested explicitly.
# NOTE(review): full unpickling can execute arbitrary code -- only load files
# obtained from a trusted source.
data: MembershipDataset = torch.load("out/data/01/priv_out.pt", weights_only=False)
#### EXAMPLE SUBMISSION
# Build a placeholder submission: one random (standard-normal) score per
# sample id. Replace the random scores with real predictions.
df = pd.DataFrame(
    {
        "ids": data.ids,
        "score": np.random.randn(len(data.ids)),
    }
)
# index=False keeps the DataFrame's row index out of the CSV.
df.to_csv("test.csv", index=False)
# Upload the CSV inside a context manager so the file handle is closed once
# the request has finished, instead of being leaked.
# NOTE(review): "TOKEN" looks like a placeholder for a personal token.
with open("test.csv", "rb") as submission_file:
    response = requests.post(
        "http://35.184.239.3:9090/mia",
        files={"file": submission_file},
        headers={"token": "TOKEN"},
    )
print(response.json())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment