Navigation Menu

Skip to content

Commit

Permalink
ListOfListsBuilder: add support for omit_lists
Browse files (browse the repository at this point in the history)
  • Loading branch information
inducer committed Aug 23, 2016
1 parent 5711c29 commit 0d57440
Showing 1 changed file with 34 additions and 3 deletions.
37 changes: 34 additions & 3 deletions pyopencl/algorithm.py
@@ -645,8 +645,11 @@ def __call__(self, *args, **kwargs):
%else:
%for name, dtype in list_names_and_dtypes:
%if name not in count_sharing:
index_type plb_${name}_index =
plb_${name}_start_index[i];
index_type plb_${name}_index;
if (plb_${name}_start_index)
plb_${name}_index = plb_${name}_start_index[i];
else
plb_${name}_index = 0;
%endif
%endfor
%endif
@@ -656,7 +659,8 @@ def __call__(self, *args, **kwargs):
%if is_count_stage:
%for name, dtype in list_names_and_dtypes:
%if name not in count_sharing:
plb_${name}_count[i] = plb_loc_${name}_count;
if (plb_${name}_count)
plb_${name}_count[i] = plb_loc_${name}_count;
%endif
%endfor
%endif
@@ -725,6 +729,8 @@ class ListOfListsBuilder:
List entries are generated by calls to `APPEND_<list name>(value)`.
Multiple lists may be generated at once.
.. automethod:: __init__
.. automethod:: __call__
"""
def __init__(self, context, list_names_and_dtypes, generate_template,
arg_decls, count_sharing=None, devices=None,
@@ -956,6 +962,11 @@ def __call__(self, queue, n_objects, *args, **kwargs):
be passed as their :attr:`pyopencl.array.Array.data` attribute instead.
:arg allocator: optionally, the allocator to use to allocate new
arrays.
:arg omit_lists: An iterable of list names that should *not* be built
with this invocation. The kernel code may *not* call ``APPEND_name``
for these omitted lists. If it does, undefined behavior will result.
The returned *lists* dictionary will not contain an entry for names
in *omit_lists*.
:arg wait_for: |explain-waitfor|
:returns: a tuple ``(lists, event)``, where
*lists* a mapping from (built) list names to objects which
@@ -972,6 +983,10 @@ def __call__(self, queue, n_objects, *args, **kwargs):
This implies that all lists are contiguous.
*event* is a :class:`pyopencl.Event` for dependency management.
.. versionchanged:: 2016.2
Added omit_lists.
"""
if n_objects >= int(np.iinfo(np.int32).max):
index_dtype = np.int64
@@ -980,10 +995,15 @@ def __call__(self, queue, n_objects, *args, **kwargs):
index_dtype = np.dtype(index_dtype)

allocator = kwargs.pop("allocator", None)
omit_lists = kwargs.pop("omit_lists", [])
wait_for = kwargs.pop("wait_for", None)
if kwargs:
raise TypeError("invalid keyword arguments: '%s'" % ", ".join(kwargs))

for l in omit_lists:
if not any(l == name for name, _ in self.list_names_and_dtypes):
raise ValueError("invalid list name '%s' in omit_lists")

result = {}
count_list_args = []

@@ -999,6 +1019,9 @@ def __call__(self, queue, n_objects, *args, **kwargs):
for name, dtype in self.list_names_and_dtypes:
if name in self.count_sharing:
continue
if name in omit_lists:
count_list_args.append(None)
continue

counts = cl.array.empty(queue,
(n_objects + 1), index_dtype, allocator=allocator)
@@ -1033,6 +1056,8 @@ def __call__(self, queue, n_objects, *args, **kwargs):
for name, dtype in self.list_names_and_dtypes:
if name in self.count_sharing:
continue
if name in omit_lists:
continue

info_record = result[name]
starts_ary = info_record.starts
@@ -1051,6 +1076,12 @@ def __call__(self, queue, n_objects, *args, **kwargs):

write_list_args = []
for name, dtype in self.list_names_and_dtypes:
if name in omit_lists:
write_list_args.append(None)
if name not in self.count_sharing:
write_list_args.append(None)
continue

if name in self.count_sharing:
sharing_from = self.count_sharing[name]

Expand Down

0 comments on commit 0d57440

Please sign in to comment.