Skip to content

Commit

Permalink
Code generation for kernel enqueue and work around pocl complex arg split issue
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Jul 5, 2015
1 parent 27381fe commit 602db4a
Show file tree
Hide file tree
Showing 6 changed files with 305 additions and 93 deletions.
342 changes: 283 additions & 59 deletions pyopencl/__init__.py
Expand Up @@ -42,6 +42,8 @@
"its source directory. This likely won't work.")
raise

# Alias of the cffi layer's _cl._CPY2 flag; presumably true under Python 2
# (it gates the np.getbuffer workaround in the argument code generator).
_CPY2 = _cl._CPY2
# True only when _CPY2 holds and the interpreter is older than 2.7; the
# argument code generator uses it to route 'I'/'L' scalars through long()
# before struct packing.
_CPY26 = _cl._CPY2 and sys.version_info < (2, 7)

import numpy as np

Expand Down Expand Up @@ -167,6 +169,8 @@
and name[0].islower() and name not in ["zip", "map", "range"]]


# {{{ diagnostics

# Warning category (a UserWarning subclass); presumably used to surface
# compiler/build-log output, e.g. from compiler_output() -- confirm.
class CompilerWarning(UserWarning):
pass

Expand All @@ -185,6 +189,25 @@ def compiler_output(text):
# _Record subclass for error details; its fields are not visible here.
class _ErrorRecord(_Record):
pass

# }}}


# {{{ arg packing helpers

# struct format character for an unsigned integer as wide as the platform's
# size_t (width queried through cffi). Indexing raises KeyError when this
# runs if sizeof(size_t) is not 1, 2, 4 or 8.
_size_t_char = ({
8: 'Q',
4: 'L',
2: 'H',
1: 'B',
})[_cl._ffi.sizeof('size_t')]
# Translation applied to scalar dtype chars before struct packing: 'n' maps
# to the signed (lower-case) size_t-width char and 'N' to the unsigned one;
# any other char is passed through unchanged by the caller's .get(c, c).
_type_char_map = {
'n': _size_t_char.lower(),
'N': _size_t_char
}
# Only needed to build the map above.
del _size_t_char

# }}}


# {{{ find pyopencl shipped source code

Expand Down Expand Up @@ -609,72 +632,260 @@ def kernel_init(self, prg, name):
kernel_old_init(self, prg, name)
self._source = getattr(prg, "_source", None)

# NOTE(review): this span is a flattened GitHub diff -- removed and added
# lines are interleaved and the +/- markers and indentation were lost.
# Removed runs are marked below (inferred from the diff layout); all other
# lines are post-commit code. The removed lines belonged to the old
# kernel_call and the head of the old kernel_set_scalar_arg_dtypes.
# (removed by this commit: next 4 lines)
def kernel_call(self, queue, global_size, local_size, *args, **kwargs):
global_offset = kwargs.pop("global_offset", None)
g_times_l = kwargs.pop("g_times_l", False)
wait_for = kwargs.pop("wait_for", None)
# (added by this commit, apparently as the last statement of kernel_init:
# build and install the generic argument-handling call path)
self._generate_naive_call()

# (removed by this commit: next 4 lines)
if kwargs:
raise TypeError(
"Kernel.__call__ recived unexpected keyword arguments: %s"
% ", ".join(list(kwargs.keys())))
# {{{ code generation for __call__, set_args

# (removed by this commit: next 1 line)
self.set_args(*args)
# Compile the argument-handling code in `body` (a code generator) into two
# functions stored on this kernel: self._enqueue (set args, then enqueue)
# and self._set_args (set args only). num_passed_args is the number of
# user-visible arguments, named arg0..arg<n-1> in the generated code.
def kernel__set_set_args_body(self, body, num_passed_args):
from pytools.py_codegen import (
PythonFunctionGenerator,
PythonCodeGenerator,
Indentation)

# (removed by this commit: next 2 lines)
return enqueue_nd_range_kernel(queue, self, global_size, local_size,
global_offset, wait_for, g_times_l=g_times_l)
# NOTE(review): xrange exists only on Python 2 unless the module binds a
# compatible alias (not visible here) -- confirm this runs on Python 3.
arg_names = ["arg%d" % i for i in xrange(num_passed_args)]

# (removed by this commit: next 5 lines)
def kernel_set_scalar_arg_dtypes(self, arg_dtypes):
assert len(arg_dtypes) == self.num_args, (
"length of argument type array (%d) and "
"CL-generated number of arguments (%d) do not agree"
% (len(arg_dtypes), self.num_args))
# {{{ wrap in error handler

# (removed by this commit: next 1 line)
arg_type_chars = []
err_gen = PythonCodeGenerator()

# (removed by this commit: next 3 lines)
for arg_dtype in arg_dtypes:
if arg_dtype is None:
arg_type_chars.append(None)
# Wrap `body` so that a TypeError raised while processing argument
# `current_arg` (assigned by the body before each argument) is re-raised
# as LogicError naming the 1-based argument, with a hint when an Array was
# passed instead of its .data; otherwise the TypeError is re-raised as is.
# NOTE(review): if `body` emits no lines (zero arguments), the generated
# `try:` suite is empty, which presumably fails to compile -- confirm
# whether a `pass` needs to be emitted for that case.
err_gen("try:")
with Indentation(err_gen):
err_gen.extend(body)
err_gen("except TypeError as e:")
with Indentation(err_gen):
err_gen("""
if current_arg is not None:
args = [{args}]
advice = ""
from pyopencl.array import Array
if isinstance(args[current_arg], Array):
advice = " (perhaps you meant to pass 'array.data' " \
"instead of the array itself?)"
raise _cl.LogicError(
"when processing argument #%d (1-based): %s%s"
% (current_arg+1, str(e), advice))
else:
raise
"""
.format(args=", ".join(arg_names)))
err_gen("")

# }}}

# Add to the generated module's preamble the imports its code relies on
# (numpy, the cffi layer, status_code, struct.pack).
def add_preamble(gen):
gen.add_to_preamble(
"import numpy as np")
gen.add_to_preamble(
"import pyopencl.cffi_cl as _cl")
gen.add_to_preamble(
"from pyopencl.cffi_cl import _lib, "
"_ffi, _handle_error, _CLKernelArg")
gen.add_to_preamble("from pyopencl import status_code")
gen.add_to_preamble("from struct import pack")
gen.add_to_preamble("")

# {{{ generate _enqueue

# NOTE(review): the removed kernel_call defaulted g_times_l to False; None
# here is presumably equivalent if the callee only tests truthiness.
gen = PythonFunctionGenerator("enqueue_knl_%s" % self.function_name,
["self", "queue", "global_size", "local_size"]
+ arg_names
+ ["global_offset=None", "g_times_l=None", "wait_for=None"])

add_preamble(gen)
gen.extend(err_gen)

gen("""
return _cl.enqueue_nd_range_kernel(queue, self, global_size, local_size,
global_offset, wait_for, g_times_l=g_times_l)
""")

# Stored as an instance attribute (not bound as a method), which is why
# kernel_call passes `self` explicitly.
self._enqueue = gen.get_function()

# }}}

# {{{ generate set_args

# The generated _set_args takes exactly num_passed_args positional
# arguments after self; a wrong count now fails with Python's own
# call-signature TypeError (the removed set_args asserted the count).
gen = PythonFunctionGenerator("_set_args", ["self"] + arg_names)

add_preamble(gen)
gen.extend(err_gen)

self._set_args = gen.get_function()

# }}}

# Emit into `gen` code that sets CL argument number `arg_idx` from the
# buffer-like object held in the generated-code variable named `buf_var`
# (a variable *name*, not a value). A non-NULL status returned by
# _lib.kernel__set_arg_buf is passed to _handle_error.
def kernel__generate_buffer_arg_setter(self, gen, arg_idx, buf_var):
from pytools.py_codegen import Indentation

# When _CPY2 is set (presumably Python 2), numpy scalars are first
# converted with np.getbuffer -- workaround for the numpy issue below.
if _CPY2:
# https://github.com/numpy/numpy/issues/5381
gen("if isinstance({buf_var}, np.generic):".format(buf_var=buf_var))
with Indentation(gen):
gen("{buf_var} = np.getbuffer({buf_var})".format(buf_var=buf_var))

gen("""
c_buf, sz, _ = _cl._c_buffer_from_obj({buf_var})
status = _lib.kernel__set_arg_buf(self.ptr, {arg_idx}, c_buf, sz)
if status != _ffi.NULL:
_handle_error(status)
"""
.format(arg_idx=arg_idx, buf_var=buf_var))

# Emit into `gen` a runtime type dispatch that sets CL argument `arg_idx`
# from the generated-code variable `arg_var`:
#   None          -> null argument via _lib.kernel__set_arg_null
#   _CLKernelArg  -> self.set_arg
#   anything else -> treated as a buffer (_generate_buffer_arg_setter)
def kernel__generate_generic_arg_handler(self, gen, arg_idx, arg_var):
from pytools.py_codegen import Indentation

gen("""
if {arg_var} is None:
status = _lib.kernel__set_arg_null(self.ptr, {arg_idx})
if status != _ffi.NULL:
_handle_error(status)
elif isinstance({arg_var}, _CLKernelArg):
self.set_arg({arg_idx}, {arg_var})
"""
.format(arg_idx=arg_idx, arg_var=arg_var))

gen("else:")
with Indentation(gen):
self._generate_buffer_arg_setter(gen, arg_idx, arg_var)

# Install the type-agnostic call path: one generic handler per CL argument
# (self.num_args of them), each preceded by `current_arg = i` so the error
# wrapper can name the failing argument, then compiled into
# _enqueue/_set_args by _set_set_args_body.
# NOTE(review): with num_args == 0 the loop emits nothing, so the body
# handed to _set_set_args_body is empty -- confirm that case still compiles.
def kernel__generate_naive_call(self):
num_args = self.num_args

from pytools.py_codegen import PythonCodeGenerator
gen = PythonCodeGenerator()

for i in range(num_args):
gen("# process argument {arg_idx}".format(arg_idx=i))
gen("")
gen("current_arg = {arg_idx}".format(arg_idx=i))
self._generate_generic_arg_handler(gen, i, "arg%d" % i)
gen("")

self._set_set_args_body(gen, num_args)

# Regenerate _enqueue/_set_args specialized to `scalar_arg_dtypes`, one
# entry per user-visible argument. None means "not a scalar" and is handled
# generically; any other entry is converted with np.dtype and packed with
# struct.pack in the generated code. Complex arguments may be split into two
# CL arguments (see the arg-count bug handling), so the resulting CL argument
# count is checked against self.num_args; a mismatch raises TypeError.
# NOTE(review): this span is a flattened diff; lines removed by this commit
# (from the old kernel_set_scalar_arg_dtypes and old kernel_set_args) are
# marked below, inferred from the diff layout.
def kernel_set_scalar_arg_dtypes(self, scalar_arg_dtypes):
# {{{ arg counting bug handling

# For example:
# https://github.com/pocl/pocl/issues/197
# (but Apple CPU has a similar bug)

work_around_arg_count_bug = False
warn_about_arg_count_bug = False

from pyopencl.characterize import has_struct_arg_count_bug

count_bug_per_dev = [
has_struct_arg_count_bug(dev)
for dev in self.context.devices]

# Apply the workaround only when every device in the context is affected;
# if only some are, warn instead (presumably because a single generated
# body serves all devices).
if any(count_bug_per_dev):
if all(count_bug_per_dev):
work_around_arg_count_bug = True
else:
# (removed by this commit: next 1 line)
arg_type_chars.append(np.dtype(arg_dtype).char)
warn_about_arg_count_bug = True

# }}}

# Index of the next CL-level argument; runs ahead of arg_idx once a
# complex argument has been split in two.
cl_arg_idx = 0

from pytools.py_codegen import PythonCodeGenerator
gen = PythonCodeGenerator()

for arg_idx, arg_dtype in enumerate(scalar_arg_dtypes):
gen("# process argument {arg_idx}".format(arg_idx=arg_idx))
gen("")
gen("current_arg = {arg_idx}".format(arg_idx=arg_idx))
arg_var = "arg%d" % arg_idx

if arg_dtype is None:
self._generate_generic_arg_handler(gen, cl_arg_idx, arg_var)
cl_arg_idx += 1
gen("")
continue

arg_dtype = np.dtype(arg_dtype)

# (removed by this commit: next 1 line)
self._arg_type_chars = arg_type_chars
# Void ('V') dtypes go through the generic handler.
if arg_dtype.char == "V":
self._generate_generic_arg_handler(gen, cl_arg_idx, arg_var)
cl_arg_idx += 1

# (removed by this commit: next 2 lines)
def kernel_set_args(self, *args):
assert len(args) == self.num_args, (
elif arg_dtype.kind == "c":
if warn_about_arg_count_bug:
# NOTE(review): stacklevel=2 is passed to str.format(), which ignores
# unused keyword arguments, instead of to warn(), so it has no effect;
# the message also never closes the "(" it opens.
warn("{knl_name}: arguments include complex numbers, and "
"some (but not all) of the target devices mishandle "
"struct kernel arguments (hence the workaround is "
"disabled".format(
knl_name=self.function_name, stacklevel=2))

if arg_dtype == np.complex64:
arg_char = "f"
elif arg_dtype == np.complex128:
arg_char = "d"
else:
raise TypeError("unexpected complex type: %s" % arg_dtype)

# With the workaround active, a complex128 is passed as two separate
# double arguments (real, then imaginary). Otherwise -- and always for
# complex64 -- both parts are packed into a single buffer argument.
if work_around_arg_count_bug and arg_dtype == np.complex128:
gen(
"buf = pack('{arg_char}', {arg_var}.real)"
.format(arg_char=arg_char, arg_var=arg_var))
self._generate_buffer_arg_setter(gen, cl_arg_idx, "buf")
cl_arg_idx += 1
gen(
"buf = pack('{arg_char}', {arg_var}.imag)"
.format(arg_char=arg_char, arg_var=arg_var))
self._generate_buffer_arg_setter(gen, cl_arg_idx, "buf")
cl_arg_idx += 1
else:
gen(
"buf = pack('{arg_char}{arg_char}', "
"{arg_var}.real, {arg_var}.imag)"
.format(arg_char=arg_char, arg_var=arg_var))
self._generate_buffer_arg_setter(gen, cl_arg_idx, "buf")
cl_arg_idx += 1

elif arg_dtype.char in "IL" and _CPY26:
# Prevent SystemError: ../Objects/longobject.c:336: bad
# argument to internal function

# NOTE(review): the generated line is missing the closing parenthesis of
# pack(...) -- it renders as "buf = pack('I', long(arg0)" -- so the
# generated function presumably fails to compile on this path; the
# template should end in "long({arg_var}))".
gen(
"buf = pack('{arg_char}', long({arg_var})"
.format(arg_char=arg_dtype.char, arg_var=arg_var))
self._generate_buffer_arg_setter(gen, cl_arg_idx, "buf")
cl_arg_idx += 1

else:
# Remaining scalars: pack with the dtype char, translated through
# _type_char_map ('n'/'N' -> size_t-width chars).
arg_char = arg_dtype.char
arg_char = _type_char_map.get(arg_char, arg_char)
gen(
"buf = pack('{arg_char}', {arg_var})"
.format(
arg_char=arg_char,
arg_var=arg_var))
self._generate_buffer_arg_setter(gen, cl_arg_idx, "buf")
cl_arg_idx += 1

gen("")

# Complex splitting means the CL-level count can differ from the number
# of dtypes passed in, so validate against the kernel's real arg count.
if cl_arg_idx != self.num_args:
raise TypeError(
"length of argument list (%d) and "
"CL-generated number of arguments (%d) do not agree"
# (removed by this commit: next 1 line)
% (len(args), self.num_args))
% (cl_arg_idx, self.num_args))

# (removed by this commit: next 9 lines)
i = None
try:
try:
arg_type_chars = self.__dict__["_arg_type_chars"]
except KeyError:
for i, arg in enumerate(args):
self.set_arg(i, arg)
else:
from pyopencl._pvt_struct import pack
# Compile the specialized body; the generated functions take one
# parameter per entry of scalar_arg_dtypes.
self._set_set_args_body(gen, len(scalar_arg_dtypes))

# (removed by this commit: next 19 lines)
for i, (arg, arg_type_char) in enumerate(
zip(args, arg_type_chars)):
if arg_type_char and arg_type_char != "V":
self.set_arg(i, pack(arg_type_char, arg))
else:
self.set_arg(i, arg)
except TypeError as e:
if i is not None:
advice = ""
from pyopencl.array import Array
if isinstance(args[i], Array):
advice = " (perhaps you meant to pass 'array.data' " \
"instead of the array itself?)"

raise LogicError(
"when processing argument #%d (1-based): %s%s"
% (i+1, str(e), advice))
else:
raise
# }}}

# Set all kernel arguments at once through the generated _set_args.
def kernel_set_args(self, *args, **kwargs):
# Need to duplicate the 'self' argument for dynamically generated method
# (_set_args is a plain function stored on the instance, so it is not bound).
return self._set_args(self, *args, **kwargs)

# Set the arguments and enqueue the kernel through the generated _enqueue,
# returning its result (presumably the enqueue event -- confirm). Extra
# keyword arguments must match the generated signature (global_offset,
# g_times_l, wait_for); anything else fails with a TypeError.
def kernel_call(self, queue, global_size, local_size, *args, **kwargs):
# __call__ can't be overridden directly, so we need this
# trampoline hack.
return self._enqueue(self, queue, global_size, local_size, *args, **kwargs)

def kernel_capture_call(self, filename, queue, global_size, local_size,
*args, **kwargs):
Expand All @@ -683,9 +894,13 @@ def kernel_capture_call(self, filename, queue, global_size, local_size,
*args, **kwargs)

# Attach the kernel_* implementations above to Kernel as methods.
Kernel.__init__ = kernel_init
# (removed by this commit: next 1 line; the same assignment is re-added
# further down, after set_args)
Kernel.__call__ = kernel_call
Kernel._set_set_args_body = kernel__set_set_args_body
Kernel._generate_buffer_arg_setter = kernel__generate_buffer_arg_setter
Kernel._generate_generic_arg_handler = kernel__generate_generic_arg_handler
Kernel._generate_naive_call = kernel__generate_naive_call
Kernel.set_scalar_arg_dtypes = kernel_set_scalar_arg_dtypes
Kernel.set_args = kernel_set_args
Kernel.__call__ = kernel_call
Kernel.capture_call = kernel_capture_call

# }}}
Expand Down Expand Up @@ -842,11 +1057,20 @@ def error_str(self):
except AttributeError:
return str(val)
else:
result = "%s failed: %s" % (val.routine(),
status_code.to_string(val.code(), "<unknown error %d>")
.lower().replace("_", " "))
if val.what():
result += " - " + val.what()
result = ""
if val.code() != status_code.SUCCESS:
result = status_code.to_string(
val.code(), "<unknown error %d>")
routine = val.routine()
if routine:
result = "%s failed: %s" % (
routine.lower().replace("_", " "),
result)
what = val.what()
if what:
if result:
result += " - "
result += what
return result

def error_code(self):
Expand Down

0 comments on commit 602db4a

Please sign in to comment.