Skip to content

Commit

Permalink
Merge pull request #72 from rutsky/compile_ptx
Browse files Browse the repository at this point in the history
add option to compile device-independent PTX files with pycuda.compiler.compile()
  • Loading branch information
inducer committed Jun 16, 2015
2 parents f693197 + 364e1e9 commit f5958d4
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions pycuda/compiler.py
Expand Up @@ -66,9 +66,11 @@ def preprocess_source(source, options, nvcc):
return stdout.decode("utf-8", "replace")


def compile_plain(source, options, keep, nvcc, cache_dir):
def compile_plain(source, options, keep, nvcc, cache_dir, target="cubin"):
from os.path import join

assert target in ["cubin", "ptx"]

if cache_dir:
checksum = _new_md5()

Expand All @@ -84,7 +86,7 @@ def compile_plain(source, options, keep, nvcc, cache_dir):
checksum.update(str(platform_bits()).encode("utf-8"))

cache_file = checksum.hexdigest()
cache_path = join(cache_dir, cache_file + ".cubin")
cache_path = join(cache_dir, cache_file + "." + target)

try:
cache_file = open(cache_path, "rb")
Expand Down Expand Up @@ -113,12 +115,12 @@ def compile_plain(source, options, keep, nvcc, cache_dir):

print("*** compiler output in %s" % file_dir)

cmdline = [nvcc, "--cubin"] + options + [cu_file_name]
cmdline = [nvcc, "--" + target] + options + [cu_file_name]
result, stdout, stderr = call_capture_output(cmdline,
cwd=file_dir, error_on_nonzero=False)

try:
cubin_f = open(join(file_dir, file_root + ".cubin"), "rb")
result_f = open(join(file_dir, file_root + "." + target), "rb")
except IOError:
no_output = True
else:
Expand All @@ -144,12 +146,12 @@ def compile_plain(source, options, keep, nvcc, cache_dir):
warn("The CUDA compiler succeeded, but said the following:\n"
+ (stdout+stderr).decode("utf-8", "replace"), stacklevel=4)

cubin = cubin_f.read()
cubin_f.close()
result_data = result_f.read()
result_f.close()

if cache_dir:
outf = open(cache_path, "wb")
outf.write(cubin)
outf.write(result_data)
outf.close()

if not keep:
Expand All @@ -158,7 +160,7 @@ def compile_plain(source, options, keep, nvcc, cache_dir):
unlink(join(file_dir, name))
rmdir(file_dir)

return cubin
return result_data


def _get_per_user_string():
Expand Down Expand Up @@ -187,7 +189,9 @@ def _find_pycuda_include_path():

def compile(source, nvcc="nvcc", options=None, keep=False,
no_extern_c=False, arch=None, code=None, cache_dir=None,
include_dirs=[]):
include_dirs=[], target="cubin"):

assert target in ["cubin", "ptx"]

if not no_extern_c:
source = 'extern "C" {\n%s\n}\n' % source
Expand Down Expand Up @@ -241,7 +245,7 @@ def compile(source, nvcc="nvcc", options=None, keep=False,
for i in include_dirs:
options.append("-I"+i)

return compile_plain(source, options, keep, nvcc, cache_dir)
return compile_plain(source, options, keep, nvcc, cache_dir, target)


class SourceModule(object):
Expand Down

0 comments on commit f5958d4

Please sign in to comment.