Skip to content

Commit

Permalink
Fix capture_call for calls after set_scalar_arg_dtypes
Browse files (browse the repository at this point in the history)
  • Loading branch information
inducer committed Aug 18, 2015
1 parent 5a09475 commit 2f284bd
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
2 changes: 2 additions & 0 deletions pyopencl/__init__.py
Expand Up @@ -819,6 +819,8 @@ def kernel__generate_naive_call(self):
self._set_set_args_body(gen, num_args)

def kernel_set_scalar_arg_dtypes(self, scalar_arg_dtypes):
self._scalar_arg_dtypes = scalar_arg_dtypes

# {{{ arg counting bug handling

# For example:
Expand Down
18 changes: 16 additions & 2 deletions pyopencl/capture_call.py
Expand Up @@ -117,8 +117,22 @@ def capture_kernel_call(kernel, filename, queue, g_size, l_size, *args, **kwargs

cg("prg = cl.Program(ctx, CODE).build()")
cg("knl = prg.%s" % kernel.function_name)
if hasattr(kernel, "_arg_type_chars"):
cg("knl._arg_type_chars = %s" % repr(kernel._arg_type_chars))
if hasattr(kernel, "_scalar_arg_dtypes"):
def strify_dtype(d):
    """Return Python source text that recreates the dtype *d*.

    ``None`` maps to the literal text ``"None"``. Any other value is first
    normalized through ``np.dtype`` and its ``repr`` is used. When that repr
    begins with ``dtype`` it is qualified as ``np.dtype``, so the emitted
    text refers to numpy through the name ``np``.
    """
    if d is None:
        return "None"

    text = repr(np.dtype(d))
    return "np." + text if text.startswith("dtype") else text

# Emit the dtypes as a tuple literal. A trailing comma is appended after
# each element rather than once after the join: joining an empty sequence
# and then adding "," would produce "(,)", a syntax error in the generated
# script, while this form yields "()" for no dtypes and stays a valid
# 1-tuple for a single dtype.
cg("knl.set_scalar_arg_dtypes((%s))" % "".join(
    "%s, " % strify_dtype(dt) for dt in kernel._scalar_arg_dtypes))

cg("knl(queue, %s, %s," % (repr(g_size), repr(l_size)))
cg(" %s)" % ", ".join(kernel_args))
cg("")
Expand Down

0 comments on commit 2f284bd

Please sign in to comment.