Skip to content

Commit

Permalink
Merge branch 'listoflistsbuilder-add-back-memoryobject-support' into 'master'
Browse files (browse the repository at this point in the history)

ListOfListsBuilder: Add back support for MemoryObject arguments

See merge request inducer/pyopencl!63
  • Loading branch information
inducer committed Nov 26, 2018
2 parents f8005bb + 9f00c62 commit 7c35c3a
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 20 deletions.
37 changes: 31 additions & 6 deletions pyopencl/algorithm.py
Expand Up @@ -1042,8 +1042,11 @@ def get_write_kernel(self, index_dtype):
def __call__(self, queue, n_objects, *args, **kwargs):
"""
:arg args: arguments corresponding to arg_decls in the constructor.
:class:`pyopencl.array.Array` are not allowed directly and should
be passed as their :attr:`pyopencl.array.Array.data` attribute instead.
Array-like arguments must be either
1D :class:`pyopencl.array.Array` objects or
:class:`pyopencl.MemoryObject` objects, of which the latter
can be obtained from a :class:`pyopencl.array.Array` using the
:attr:`pyopencl.array.Array.data` attribute.
:arg allocator: optionally, the allocator to use to allocate new
arrays.
:arg omit_lists: An iterable of list names that should *not* be built
Expand Down Expand Up @@ -1111,8 +1114,30 @@ def __call__(self, queue, n_objects, *args, **kwargs):
if self.eliminate_empty_output_lists:
compress_kernel = self.get_compress_kernel(index_dtype)

from pyopencl.tools import expand_runtime_arg_list
args = expand_runtime_arg_list(self.arg_decls, args)
data_args = []
for i, (arg_descr, arg_val) in enumerate(zip(self.arg_decls, args)):
from pyopencl.tools import VectorArg
if isinstance(arg_descr, VectorArg):
from pyopencl import MemoryObject
if isinstance(arg_val, MemoryObject):
data_args.append(arg_val)
if arg_descr.with_offset:
raise ValueError(
"with_offset=True specified for argument %d "
"but the argument is not an array" % i)
continue

if arg_val.ndim != 1:
raise ValueError("argument %d is a multidimensional array" % i)

data_args.append(arg_val.base_data)
if arg_descr.with_offset:
data_args.append(arg_val.offset)
else:
data_args.append(arg_val)

del args
data_args = tuple(data_args)

# {{{ allocate memory for counts

Expand Down Expand Up @@ -1151,7 +1176,7 @@ def __call__(self, queue, n_objects, *args, **kwargs):
gsize, lsize = splay(queue, n_objects)

count_event = count_kernel(queue, gsize, lsize,
*(tuple(count_list_args) + args + (n_objects,)),
*(tuple(count_list_args) + data_args + (n_objects,)),
**dict(wait_for=wait_for))

compress_events = {}
Expand Down Expand Up @@ -1257,7 +1282,7 @@ def __call__(self, queue, n_objects, *args, **kwargs):
# }}}

evt = write_kernel(queue, gsize, lsize,
*(tuple(write_list_args) + args + (n_objects,)),
*(tuple(write_list_args) + data_args + (n_objects,)),
**dict(wait_for=scan_events))

return result, evt
Expand Down
11 changes: 9 additions & 2 deletions pyopencl/scan.py
Expand Up @@ -1480,8 +1480,15 @@ def __call__(self, *args, **kwargs):
# We're done here. (But pretend to return an event.)
return cl.enqueue_marker(queue, wait_for=wait_for)

from pyopencl.tools import expand_runtime_arg_list
data_args = list(expand_runtime_arg_list(self.parsed_args, args))
data_args = []
for arg_descr, arg_val in zip(self.parsed_args, args):
from pyopencl.tools import VectorArg
if isinstance(arg_descr, VectorArg):
data_args.append(arg_val.base_data)
if arg_descr.with_offset:
data_args.append(arg_val.offset)
else:
data_args.append(arg_val)

# }}}

Expand Down
12 changes: 0 additions & 12 deletions pyopencl/tools.py
Expand Up @@ -400,18 +400,6 @@ def get_arg_offset_adjuster_code(arg_types):

return "\n".join(result)


def expand_runtime_arg_list(args, user_args):
    """Flatten user-supplied kernel arguments into positional kernel args.

    :arg args: argument descriptors, paired positionally with *user_args*.
        Pairing uses :func:`zip`, so surplus entries in the longer sequence
        are ignored.
    :arg user_args: the argument values supplied by the caller.
    :returns: a :class:`tuple` of kernel arguments. A value whose descriptor
        is a :class:`VectorArg` contributes its ``base_data`` attribute,
        followed by its ``offset`` attribute when the descriptor's
        ``with_offset`` flag is set. Every other value is passed through
        unchanged.
    """
    # NOTE(review): values paired with a VectorArg must expose ``base_data``
    # (and ``offset`` if with_offset is set); a bare buffer object presumably
    # lacks these attributes -- confirm before passing one here.
    expanded = []
    for descr, value in zip(args, user_args):
        if not isinstance(descr, VectorArg):
            expanded.append(value)
            continue

        expanded.append(value.base_data)
        if descr.with_offset:
            expanded.append(value.offset)

    return tuple(expanded)

# }}}


Expand Down
25 changes: 25 additions & 0 deletions test/test_algorithm.py
Expand Up @@ -880,6 +880,31 @@ def test_list_builder(ctx_factory):
assert (inf.lists.get()[-6:] == [1, 2, 2, 3, 3, 3]).all()


def test_list_builder_with_memoryobject(ctx_factory):
    """Check that ListOfListsBuilder accepts a bare MemoryObject argument.

    The input array is handed to the builder via its ``.data`` attribute
    rather than as a :class:`pyopencl.array.Array`. That attribute is what
    the builder's documentation describes as the :class:`pyopencl.MemoryObject`
    form, and it is the argument path this test guards. Each of the ``n``
    work items appends one element of the zero-filled input, so the
    resulting list must hold exactly ``n`` zeros.
    """
    from pytest import importorskip
    importorskip("mako")

    context = ctx_factory()
    queue = cl.CommandQueue(context)

    from pyopencl.algorithm import ListOfListsBuilder
    from pyopencl.tools import VectorArg

    # float32 rather than float (float64): the element type is irrelevant to
    # what is tested here, and a double-typed kernel argument cannot be built
    # on devices that lack cl_khr_fp64 support.
    builder = ListOfListsBuilder(context, [("mylist", np.int32)], """//CL//
        void generate(LIST_ARG_DECL USER_ARG_DECL index_type i)
        {
            APPEND_mylist(input_list[i]);
        }
        """, arg_decls=[VectorArg(np.float32, "input_list")])

    n = 10000
    input_list = cl.array.zeros(queue, (n,), np.float32)
    # Pass the underlying buffer, not the Array, to exercise MemoryObject
    # argument support.
    result, evt = builder(queue, n, input_list.data)

    inf = result["mylist"]
    assert inf.count == n
    assert (inf.lists.get() == 0).all()


def test_list_builder_with_offset(ctx_factory):
from pytest import importorskip
importorskip("mako")
Expand Down

0 comments on commit 7c35c3a

Please sign in to comment.