Skip to content

Commit

Permalink
add option to compile device-independent PTX files with pycuda.compiler.compile()
Browse files (browse the repository at this point in the history)
  • Loading branch information
rutsky committed Jun 16, 2015
1 parent 6010df6 commit 364e1e9
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions pycuda/compiler.py
Expand Up @@ -63,9 +63,11 @@ def preprocess_source(source, options, nvcc):
return stdout.decode("utf-8", "replace")


def compile_plain(source, options, keep, nvcc, cache_dir):
def compile_plain(source, options, keep, nvcc, cache_dir, target="cubin"):
from os.path import join

assert target in ["cubin", "ptx"]

if cache_dir:
checksum = _new_md5()

Expand All @@ -81,7 +83,7 @@ def compile_plain(source, options, keep, nvcc, cache_dir):
checksum.update(str(platform_bits()).encode("utf-8"))

cache_file = checksum.hexdigest()
cache_path = join(cache_dir, cache_file + ".cubin")
cache_path = join(cache_dir, cache_file + "." + target)

try:
cache_file = open(cache_path, "rb")
Expand Down Expand Up @@ -110,12 +112,12 @@ def compile_plain(source, options, keep, nvcc, cache_dir):

print "*** compiler output in %s" % file_dir

cmdline = [nvcc, "--cubin"] + options + [cu_file_name]
cmdline = [nvcc, "--" + target] + options + [cu_file_name]
result, stdout, stderr = call_capture_output(cmdline,
cwd=file_dir, error_on_nonzero=False)

try:
cubin_f = open(join(file_dir, file_root + ".cubin"), "rb")
result_f = open(join(file_dir, file_root + "." + target), "rb")
except IOError:
no_output = True
else:
Expand All @@ -141,12 +143,12 @@ def compile_plain(source, options, keep, nvcc, cache_dir):
warn("The CUDA compiler succeeded, but said the following:\n"
+ (stdout+stderr).decode("utf-8", "replace"), stacklevel=4)

cubin = cubin_f.read()
cubin_f.close()
result_data = result_f.read()
result_f.close()

if cache_dir:
outf = open(cache_path, "wb")
outf.write(cubin)
outf.write(result_data)
outf.close()

if not keep:
Expand All @@ -155,7 +157,7 @@ def compile_plain(source, options, keep, nvcc, cache_dir):
unlink(join(file_dir, name))
rmdir(file_dir)

return cubin
return result_data


def _get_per_user_string():
Expand Down Expand Up @@ -184,7 +186,9 @@ def _find_pycuda_include_path():

def compile(source, nvcc="nvcc", options=None, keep=False,
no_extern_c=False, arch=None, code=None, cache_dir=None,
include_dirs=[]):
include_dirs=[], target="cubin"):

assert target in ["cubin", "ptx"]

if not no_extern_c:
source = 'extern "C" {\n%s\n}\n' % source
Expand Down Expand Up @@ -238,7 +242,7 @@ def compile(source, nvcc="nvcc", options=None, keep=False,
for i in include_dirs:
options.append("-I"+i)

return compile_plain(source, options, keep, nvcc, cache_dir)
return compile_plain(source, options, keep, nvcc, cache_dir, target)


class SourceModule(object):
Expand Down

0 comments on commit 364e1e9

Please sign in to comment.