Testing GIL load for PyTorch, NumPy and BERT
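The script below uses gil_load to measure GIL contention for four workloads: a NumPy FFT loop (numpy_preprocess), torchvision image preprocessing (torch_preprocess), BERT inference with transformers (predict), and preprocessing followed by inference (preprocess_predict). The --server option fans each workload out over either a thread pool or a process pool.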
import numpy as np
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import concurrent.futures as futures
import threading
import sys
import traceback

import click
import gil_load
from pynvml import *  # only needed for the commented-out GPU-info helpers below

N_THREADS = 4
NPTS = 4096
# The test image can be downloaded from this link into the working directory:
image_link = "https://upload.wikimedia.org/wikipedia/commons/f/ff/Pizigani_1367_Chart_10MB.jpg"
def numpy_preprocess():
    """CPU-bound NumPy workload: two rounds of large 2D FFTs."""
    print("running numpy preprocess")
    for i in range(2):
        x = np.random.randn(NPTS, NPTS)
        x[:] = np.fft.fft2(x).real
def torch_preprocess(image_name="Pizigani_1367_Chart_10MB.jpg", batch_size=10):
    """Torchvision workload: decode, resize, crop and normalize a batch of images."""
    try:
        print(f"running torch preprocess with image_name={image_name} and batch_size={batch_size}")
        import torch
        from PIL import Image
        from torchvision import transforms

        # Build the transform pipeline once instead of once per image.
        image_processing = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        images = []
        for i in range(batch_size):
            image = Image.open(image_name)
            images.append(image_processing(image))
        torch.stack(images)
    except Exception as e:
        traceback.print_exc(file=sys.stdout)
        print(e)
    print("done torch preprocess")
def predict(compute_unit=0, model_name='bert-base-uncased'):
    """BERT workload: load the model once, then run ten inference passes."""
    try:
        import torch
        import os
        from transformers import BertModel, BertTokenizer

        # Model load; fall back to CPU so the loop also works without a GPU.
        model = BertModel.from_pretrained(model_name)
        device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
        model.to(device)
        print(f"Loaded model {model_name} in GPU#{compute_unit} - {os.getpid()}:{threading.current_thread().ident}")

        # Inference (tokenizer is loaded once, outside the loop)
        tokenizer = BertTokenizer.from_pretrained(model_name)
        for i in range(10):
            text = "Replace me by any text you'd like."
            encoded_input = tokenizer(text, return_tensors='pt').to(device)
            model(**encoded_input)
            # dump_device_info(0)
            # cpu_intensive_method()
            print(f"Inference number {i}")
    except Exception as e:
        traceback.print_exc(file=sys.stdout)
        print(e)
def preprocess_predict(image_name="Pizigani_1367_Chart_10MB.jpg", batch_size=10, compute_unit=0):
    """Combined workload: image preprocessing followed by BERT inference."""
    try:
        torch_preprocess(image_name=image_name, batch_size=batch_size)
        predict(compute_unit=compute_unit)
    except Exception as e:
        traceback.print_exc(file=sys.stdout)
        print(e)
@click.command()
@click.option('--server', default="thread", type=str)
@click.option('--instances', default=5, type=int)
@click.option('--image_name', default="Pizigani_1367_Chart_10MB.jpg", type=str)
@click.option('--batch_size', default=10, type=int)
@click.option('--target_method', default="torch_preprocess", type=str)
def run_benchmark(server, instances, image_name, batch_size, target_method):
    gil_load.init()
    gil_load.start()
    print(f"=======server={server},instances={instances},target_method={target_method}, image_name={image_name}, batch_size={batch_size}===================")
    wait_for = []
    executor = ThreadPoolExecutor(max_workers=instances) if server == "thread" \
        else ProcessPoolExecutor(max_workers=instances)
    with executor as e:
        for compute_unit in range(instances):
            # Collect the futures so results (and exceptions) can be awaited below.
            if target_method == 'numpy_preprocess':
                wait_for.append(e.submit(numpy_preprocess))
            elif target_method == 'torch_preprocess':
                wait_for.append(e.submit(torch_preprocess, image_name, batch_size))
            elif target_method == 'predict':
                wait_for.append(e.submit(predict, compute_unit))
            elif target_method == 'preprocess_predict':
                wait_for.append(e.submit(preprocess_predict, image_name, batch_size, compute_unit))
        for f in futures.as_completed(wait_for):
            print('main: result: {}'.format(f.result()))
    gil_load.stop()
    stats = gil_load.get()
    print(gil_load.format(stats))
# Alternative: run the workload on raw threads instead of an executor.
# gil_load.start()
#
# threads = []
# for i in range(N_THREADS):
#     thread = threading.Thread(target=predict, daemon=True)
#     threads.append(thread)
#     thread.start()
#
# for thread in threads:
#     thread.join()
#
# gil_load.stop()
#
# stats = gil_load.get()
# print(gil_load.format(stats))
if __name__ == "__main__":
    run_benchmark()
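Example usage (the file name gil_bench.py is hypothetical; the workload functions expect the Wikimedia test image above to be downloaded into the working directory first):

python gil_bench.py --server thread --instances 5 --target_method predict
python gil_bench.py --server process --instances 4 --target_method numpy_preprocess

gil_load's summary reports how much of the time the GIL was held and how long threads spent waiting for it; consistently high figures with --server thread, compared against a --server process run of the same workload, suggest the workload is GIL-bound.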