Skip to content

Commit

Permalink
Finish interface tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Jul 15, 2015
1 parent 2b0ddac commit 83445bb
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 17 deletions.
46 changes: 33 additions & 13 deletions pyopencl/bitonic_sort.py
Expand Up @@ -57,6 +57,10 @@ class BitonicSort(object):
that is a power of 2.
.. versionadded:: 2015.2
.. seealso:: :class:`pyopencl.algorithm.RadixSort`
.. autofunction:: __call__
"""

kernels_srcs = {
Expand All @@ -70,16 +74,21 @@ class BitonicSort(object):
'PML': _tmpl.ParallelMerge_Local
}

def __init__(self, context, key_dtype, idx_dtype=None):
self.dtype = dtype_to_ctype(key_dtype)
def __init__(self, context):
self.context = context
if idx_dtype is None:
self.idx_t = 'uint' # Dummy

else:
self.idx_t = dtype_to_ctype(idx_dtype)
def __call__(self, arr, idx=None, queue=None, wait_for=None, axis=0):
"""
:arg arr: the array to be sorted. Will be overwritten with the sorted array.
:arg idx: an array of indices to be tracked along with the sorting of *arr*
:arg queue: a :class:`pyopencl.CommandQueue`, defaults to the array's queue
if None
:arg wait_for: a list of :class:`pyopencl.Event` instances or None
:arg axis: the axis of the array by which to sort
:returns: a tuple (sorted_array, event)
"""

def __call__(self, arr, idx=None, mkcpy=True, queue=None, wait_for=None, axis=0):
if queue is None:
queue = arr.queue

Expand All @@ -95,14 +104,17 @@ def __call__(self, arr, idx=None, mkcpy=True, queue=None, wait_for=None, axis=0)
if not _is_power_of_2(arr.shape[axis]):
raise ValueError("sorted array axis length must be a power of 2")

arr = arr.copy() if mkcpy else arr

if idx is None:
argsort = 0
else:
argsort = 1

run_queue = self.sort_b_prepare_wl(argsort, arr.shape, axis)
run_queue = self.sort_b_prepare_wl(
argsort,
arr.dtype,
idx.dtype if idx is not None else None, arr.shape,
axis)

knl, nt, wg, aux = run_queue[0]

if idx is not None:
Expand Down Expand Up @@ -143,7 +155,15 @@ def get_program(self, letter, argsort, params):
return prg

@memoize_method
def sort_b_prepare_wl(self, argsort, shape, axis):
def sort_b_prepare_wl(self, argsort, key_dtype, idx_dtype, shape, axis):
key_ctype = dtype_to_ctype(key_dtype)

if idx_dtype is None:
idx_ctype = 'uint' # Dummy

else:
idx_ctype = dtype_to_ctype(idx_dtype)

run_queue = []
ds = int(shape[axis])
size = reduce(mul, shape)
Expand All @@ -159,7 +179,7 @@ def sort_b_prepare_wl(self, argsort, shape, axis):
wg = min(ds, self.context.devices[0].max_work_group_size)
length = wg >> 1
prg = self.get_program(
'BLO', argsort, (1, 1, self.dtype, self.idx_t, ds, ns))
'BLO', argsort, (1, 1, key_ctype, idx_ctype, ds, ns))
run_queue.append((prg.run, size, (wg,), True))

while length < ds:
Expand All @@ -183,7 +203,7 @@ def sort_b_prepare_wl(self, argsort, shape, axis):
nthreads = size >> ninc

prg = self.get_program(letter, argsort,
(inc, direction, self.dtype, self.idx_t, ds, ns))
(inc, direction, key_ctype, idx_ctype, ds, ns))
run_queue.append((prg.run, nthreads, None, False,))
inc >>= ninc

Expand Down
8 changes: 4 additions & 4 deletions test/test_algorithm.py
Expand Up @@ -857,8 +857,8 @@ def test_bitonic_sort(ctx_factory, size, dtype):
from pyopencl.bitonic_sort import BitonicSort

s = clrandom.rand(queue, (2, size, 3,), dtype, luxury=None, a=0, b=1.0)
sorter = BitonicSort(ctx, s.dtype)
sgs, evt = sorter(s, axis=1)
sorter = BitonicSort(ctx)
sgs, evt = sorter(s.copy(), axis=1)
assert np.array_equal(np.sort(s.get(), axis=1), sgs.get())


Expand All @@ -884,9 +884,9 @@ def test_bitonic_argsort(ctx_factory, size, dtype):
index = cl_array.arange(queue, 0, size, 1, dtype=np.int32)
m = clrandom.rand(queue, (size,), np.float32, luxury=None, a=0, b=1.0)

sorterm = BitonicSort(ctx, m.dtype, idx_dtype=index.dtype)
sorterm = BitonicSort(ctx)

ms, evt = sorterm(m, idx=index, axis=0)
ms, evt = sorterm(m.copy(), idx=index, axis=0)

assert np.array_equal(np.sort(m.get()), ms.get())

Expand Down

0 comments on commit 83445bb

Please sign in to comment.