Easy parallel processing in Python with progress bar
""" | |
Simple mofo'n parallelism with progress bar. Born of frustration with p_tqdm. | |
""" | |
def pmap_do(inp):
    # unpack the (function, object, args, kwargs) tuple and apply it
    func, obj, args, kwargs = inp
    return func(obj, *args, **kwargs)
def pmap_iter(func, objs, args=[], kwargs={}, num_proc=4, use_threads=False, progress=True, desc=None, **y):
    """
    Yields results of func(obj, *args, **kwargs) for each obj in objs.
    Uses multiprocessing.Pool(num_proc) for parallelism.
    If use_threads, uses a ThreadPool instead of a process Pool.
    Results are yielded in input order (imap preserves order).
    """
    # imports
    from tqdm import tqdm
    # progress-bar description
    if not desc: desc = f'Mapping {func.__name__}()'
    desc = f'{desc} [x{num_proc}]'
    if num_proc > 1 and len(objs) > 1:
        # package each object with the function and its arguments
        objects = [(func, obj, args, kwargs) for obj in objs]
        # create pool (ThreadPool must be imported from multiprocessing.pool)
        from multiprocessing import Pool
        from multiprocessing.pool import ThreadPool
        pool = ThreadPool(num_proc) if use_threads else Pool(num_proc)
        # yield results as they arrive
        iterr = pool.imap(pmap_do, objects)
        for res in tqdm(iterr, total=len(objs), desc=desc) if progress else iterr:
            yield res
        # close the pool once the iterator is exhausted
        pool.close()
        pool.join()
    else:
        # serial fallback
        for obj in tqdm(objs, desc=desc) if progress else objs:
            yield func(obj, *args, **kwargs)
def pmap(*x, **y):
    """
    Non-iterator version of pmap_iter: returns the results as a list.
    """
    return list(pmap_iter(*x, **y))
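# --- Usage sketch (not part of the original gist): a hypothetical demo of
# pmap(). Assumes tqdm is installed; _square is illustrative and defined at
# module level so worker processes can pickle it.
def _square(x):
    return x * x

if __name__ == '__main__':
    # map _square over 0..9 across 2 worker processes, with a progress bar
    print(pmap(_square, list(range(10)), num_proc=2))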
def do_pmap_group(obj):
    # unpack
    func, group_df, group_key, group_name = obj
    # load the group from its pickle cache if a path was passed
    if type(group_df) == str:
        import pandas as pd
        group_df = pd.read_pickle(group_df)
    # run func
    outdf = func(group_df)
    # annotate output with the group name(s) on the way out
    if type(group_name) not in {list, tuple}: group_name = [group_name]
    for x, y in zip(group_key, group_name):
        outdf[x] = y
    return outdf
def pmap_groups(func, df_grouped, use_cache=True, **attrs):
    import os, tempfile
    import pandas as pd
    from tqdm import tqdm
    # get index/groupby column name(s)
    group_key = df_grouped.grouper.names
    if not use_cache:
        # pass each group DataFrame directly to the workers
        objs = [
            (func, group_df, group_key, group_name)
            for group_name, group_df in df_grouped
        ]
    else:
        # pickle each group to a temp file and pass only the path,
        # keeping the payload sent to each worker small
        objs = []
        tmpdir = tempfile.mkdtemp()
        for i, (group_name, group_df) in enumerate(tqdm(list(df_grouped), desc='Preparing input')):
            tmp_path = os.path.join(tmpdir, f'{i}.pkl')
            group_df.to_pickle(tmp_path)
            objs += [(func, tmp_path, group_key, group_name)]
    # default progress-bar description
    if not attrs.get('desc'): attrs['desc'] = f'Mapping {func.__name__}'
    return pd.concat(
        pmap(
            do_pmap_group,
            objs,
            **attrs
        )
    ).set_index(group_key)
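# --- Usage sketch (not part of the original gist): a hypothetical demo of
# pmap_groups(). _group_mean and the toy DataFrame are illustrative, and this
# targets the pandas 1.x API the gist was written against (where
# DataFrameGroupBy.grouper.names is available).
def _group_mean(group_df):
    # reduce each group to a one-row frame holding its mean 'val'
    return group_df[['val']].mean().to_frame().T

if __name__ == '__main__':
    import pandas as pd
    df = pd.DataFrame({'grp': ['a', 'a', 'b', 'b'], 'val': [1, 2, 3, 4]})
    # one row per group, indexed by the groupby key 'grp'
    print(pmap_groups(_group_mean, df.groupby('grp'), num_proc=2))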
def pmap_df(df, func, num_proc=1):
    # split the DataFrame into num_proc chunks, map func over the chunks
    # in parallel, and stitch the results back together
    import numpy as np
    import pandas as pd
    df_split = np.array_split(df, num_proc)
    return pd.concat(pmap(func, df_split, num_proc=num_proc))
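# --- Usage sketch (not part of the original gist): a hypothetical demo of
# pmap_df(), doubling a column across two chunks. _double_vals is illustrative.
def _double_vals(chunk):
    # operate on a copy so the original chunk is left untouched
    chunk = chunk.copy()
    chunk['val'] = chunk['val'] * 2
    return chunk

if __name__ == '__main__':
    import pandas as pd
    df = pd.DataFrame({'val': range(8)})
    print(pmap_df(df, _double_vals, num_proc=2))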