Skip to content

Commit

Permalink
Make radix sort work with debug scan kernel
Browse files (browse the repository at this point in the history)
  • Loading branch information
inducer committed May 4, 2016
1 parent a1ff700 commit cdf69f6
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 12 deletions.
33 changes: 26 additions & 7 deletions pyopencl/scan.py
Expand Up @@ -764,7 +764,7 @@ def _round_down_to_power_of_2(val):
group_base seg_end my_val DEBUG ARGS
ints_to_store ints_per_wg scan_types_per_int linear_index
linear_scan_data_idx dest src store_base wrapped_scan_type
dummy
dummy scan_tmp
LID_2 LID_1 LID_0
LDIM_0 LDIM_1 LDIM_2
Expand Down Expand Up @@ -1419,11 +1419,12 @@ def __call__(self, *args, **kwargs):
KERNEL
REQD_WG_SIZE(1, 1, 1)
void ${name_prefix}_debug_scan(
__global scan_type *scan_tmp,
${argument_signature},
const index_type N)
{
scan_type item = ${neutral};
scan_type last_item;
scan_type current = ${neutral};
scan_type prev;
for (index_type i = 0; i < N; ++i)
{
Expand All @@ -1439,18 +1440,31 @@ def __call__(self, *args, **kwargs):
scan_type my_val = INPUT_EXPR(i);
last_item = item;
prev = current;
%if is_segmented:
bool is_seg_start = IS_SEG_START(i, my_val);
%endif
item = SCAN_EXPR(last_item, my_val,
current = SCAN_EXPR(prev, my_val,
%if is_segmented:
is_seg_start
%else:
false
%endif
);
scan_tmp[i] = current;
}
scan_type last_item = scan_tmp[N-1];
for (index_type i = 0; i < N; ++i)
{
scan_type item = scan_tmp[i];
scan_type prev_item;
if (i)
prev_item = scan_tmp[i-1];
else
prev_item = ${neutral};
{
${output_statement};
Expand All @@ -1477,7 +1491,8 @@ def finish_setup(self):
self.kernel = getattr(
scan_prg, self.name_prefix+"_debug_scan")
scalar_arg_dtypes = (
get_arg_list_scalar_arg_dtypes(self.parsed_args)
[None]
+ get_arg_list_scalar_arg_dtypes(self.parsed_args)
+ [self.index_dtype])
self.kernel.set_scalar_arg_dtypes(scalar_arg_dtypes)

Expand All @@ -1500,7 +1515,11 @@ def __call__(self, *args, **kwargs):
if n is None:
n, = first_array.shape

data_args = []
scan_tmp = cl.array.empty(queue,
n, dtype=self.dtype,
allocator=allocator)

data_args = [scan_tmp.data]
from pyopencl.tools import VectorArg
for arg_descr, arg_val in zip(self.parsed_args, args):
if isinstance(arg_descr, VectorArg):
Expand Down
14 changes: 9 additions & 5 deletions test/test_algorithm.py
Expand Up @@ -38,7 +38,8 @@
from pyopencl.tools import ( # noqa
pytest_generate_tests_for_pyopencl as pytest_generate_tests)
from pyopencl.characterize import has_double_support, has_struct_arg_count_bug
from pyopencl.scan import InclusiveScanKernel, ExclusiveScanKernel
from pyopencl.scan import (InclusiveScanKernel, ExclusiveScanKernel,
GenericScanKernel, GenericDebugScanKernel)


# {{{ elementwise
Expand Down Expand Up @@ -668,7 +669,6 @@ def test_segmented_scan(ctx_factory):
else:
output_statement = "out[i] = item"

from pyopencl.scan import GenericScanKernel
knl = GenericScanKernel(context, dtype,
arguments="__global %s *ary, __global char *segflags, "
"__global %s *out" % (ctype, ctype),
Expand Down Expand Up @@ -748,7 +748,8 @@ def test_segmented_scan(ctx_factory):
print("%d excl:%s done" % (n, is_exclusive))


def test_sort(ctx_factory):
@pytest.mark.parametrize("scan_kernel", [GenericScanKernel, GenericDebugScanKernel])
def test_sort(ctx_factory, scan_kernel):
from pytest import importorskip
importorskip("mako")

Expand All @@ -759,7 +760,7 @@ def test_sort(ctx_factory):

from pyopencl.algorithm import RadixSort
sort = RadixSort(context, "int *ary", key_expr="ary[i]",
sort_arg_names=["ary"])
sort_arg_names=["ary"], scan_kernel=scan_kernel)

from pyopencl.clrandom import RanluxGenerator
rng = RanluxGenerator(queue, seed=15)
Expand All @@ -768,6 +769,9 @@ def test_sort(ctx_factory):

# intermediate arrays for largest size cause out-of-memory on low-end GPUs
for n in scan_test_counts[:-1]:
if n >= 2000 and isinstance(scan_kernel, GenericDebugScanKernel):
continue

print(n)

print(" rng")
Expand All @@ -785,7 +789,7 @@ def test_sort(ctx_factory):

numpy_elapsed = numpy_end-dev_end
dev_elapsed = dev_end-dev_start
print (" dev: %.2f MKeys/s numpy: %.2f MKeys/s ratio: %.2fx" % (
print(" dev: %.2f MKeys/s numpy: %.2f MKeys/s ratio: %.2fx" % (
1e-6*n/dev_elapsed, 1e-6*n/numpy_elapsed, numpy_elapsed/dev_elapsed))
assert (a_dev_sorted.get() == a_sorted).all()

Expand Down

0 comments on commit cdf69f6

Please sign in to comment.