Created
September 5, 2018 14:58
-
-
Save seibert/52a204395cdc84eeeaf0ce05464a636b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def cuda_detect(lib_filenames=None):
    '''Attempt to detect the version of CUDA present in the operating system.

    On Windows and Linux, the CUDA library is installed by the NVIDIA
    driver package, and is typically found in the standard library path,
    rather than with the CUDA SDK (which is optional for running CUDA apps).
    On macOS, the CUDA library is only installed with the CUDA SDK, and
    might not be in the library path.

    Args:
        lib_filenames: optional sequence of library names/paths to try, in
            order, instead of the built-in per-platform list. Useful when the
            CUDA driver library lives in a non-standard location. When given,
            it is used even on platforms without a built-in list.

    Returns: version string (Ex: '9.2') or None if no CUDA driver library
        could be loaded, it lacks the expected entry points, or cuInit /
        cuDriverGetVersion report an error.
    '''
    import ctypes
    import platform

    system = platform.system()
    if lib_filenames is None:
        # platform specific libcuda location
        if system == 'Darwin':
            lib_filenames = [
                'libcuda.dylib',  # check library path first
                '/usr/local/cuda/lib/libcuda.dylib',
            ]
        elif system == 'Linux':
            # Driver-only installs typically provide just the versioned
            # soname 'libcuda.so.1'; the unversioned 'libcuda.so' symlink
            # usually comes from development packages. Try the soname first.
            lib_filenames = [
                'libcuda.so.1',  # check library path first
                'libcuda.so',
                '/usr/lib64/nvidia/libcuda.so.1',  # Redhat/CentOS/Fedora
                '/usr/lib64/nvidia/libcuda.so',
                '/usr/lib/x86_64-linux-gnu/libcuda.so.1',  # Ubuntu
                '/usr/lib/x86_64-linux-gnu/libcuda.so',
            ]
        elif system == 'Windows':
            lib_filenames = ['nvcuda.dll']
        else:
            return None  # CUDA not available for other operating systems

    # Windows uses the stdcall loader (windll); everything else uses cdll.
    # ctypes.windll only exists on Windows, so it is touched only there.
    dll = ctypes.windll if system == 'Windows' else ctypes.cdll

    # open library: first candidate that loads wins
    libcuda = None
    for lib_filename in lib_filenames:
        try:
            libcuda = dll.LoadLibrary(lib_filename)
        except OSError:
            continue  # not found / not loadable here; try the next candidate
        break
    if libcuda is None:
        return None

    # Get CUDA version. A missing symbol (AttributeError) means we loaded
    # something that is not the CUDA driver library; OSError covers failures
    # ctypes reports while calling into the library.
    version_int = ctypes.c_int(0)
    try:
        cuInit = libcuda.cuInit
        if cuInit(ctypes.c_uint(0)) != 0:
            return None  # non-zero result: driver initialization failed
        cuDriverGetVersion = libcuda.cuDriverGetVersion
        if cuDriverGetVersion(ctypes.byref(version_int)) != 0:
            return None
    except (AttributeError, OSError):
        return None

    # The driver reports 1000 * major + 10 * minor (e.g. 9020 -> '9.2').
    value = version_int.value
    return '%d.%d' % (value // 1000, (value % 1000) // 10)
if __name__ == '__main__':
    # When run as a script, report the detected CUDA version (None if absent).
    detected_version = cuda_detect()
    print('CUDA version:', detected_version)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment