Skip to content

Commit

Permalink
More interface restructuring
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Jul 15, 2015
1 parent d60e3c1 commit 2b0ddac
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 46 deletions.
79 changes: 37 additions & 42 deletions pyopencl/bitonic_sort.py
Expand Up @@ -42,6 +42,8 @@
from pytools import memoize_method
from mako.template import Template

import pyopencl.bitonic_sort_templates as _tmpl


def _is_power_of_2(n):
from pyopencl.tools import bitlog2
Expand All @@ -56,36 +58,28 @@ class BitonicSort(object):
.. versionadded:: 2015.2
"""
def __init__(self, context, shape, key_dtype, idx_dtype=None, axis=0):
import pyopencl.bitonic_sort_templates as tmpl

self.cached_defs = {}
self.kernels_srcs = {
'B2': tmpl.ParallelBitonic_B2,
'B4': tmpl.ParallelBitonic_B4,
'B8': tmpl.ParallelBitonic_B8,
'B16': tmpl.ParallelBitonic_B16,
'C4': tmpl.ParallelBitonic_C4,
'BL': tmpl.ParallelBitonic_Local,
'BLO': tmpl.ParallelBitonic_Local_Optim,
'PML': tmpl.ParallelMerge_Local
}

kernels_srcs = {
'B2': _tmpl.ParallelBitonic_B2,
'B4': _tmpl.ParallelBitonic_B4,
'B8': _tmpl.ParallelBitonic_B8,
'B16': _tmpl.ParallelBitonic_B16,
'C4': _tmpl.ParallelBitonic_C4,
'BL': _tmpl.ParallelBitonic_Local,
'BLO': _tmpl.ParallelBitonic_Local_Optim,
'PML': _tmpl.ParallelMerge_Local
}

def __init__(self, context, key_dtype, idx_dtype=None):
self.dtype = dtype_to_ctype(key_dtype)
self.context = context
self.axis = axis
if idx_dtype is None:
self.argsort = 0
self.idx_t = 'uint' # Dummy

else:
self.argsort = 1
self.idx_t = dtype_to_ctype(idx_dtype)

self.defstpl = Template(tmpl.defines)
self.run_queue = self.sort_b_prepare_wl(shape, self.axis)

def __call__(self, arr, idx=None, mkcpy=True, queue=None, wait_for=None):
def __call__(self, arr, idx=None, mkcpy=True, queue=None, wait_for=None, axis=0):
if queue is None:
queue = arr.queue

Expand All @@ -95,18 +89,23 @@ def __call__(self, arr, idx=None, mkcpy=True, queue=None, wait_for=None):

last_evt = cl.enqueue_marker(queue, wait_for=wait_for)

if arr.shape[self.axis] == 0:
if arr.shape[axis] == 0:
return arr, last_evt

if not _is_power_of_2(arr.shape[self.axis]):
if not _is_power_of_2(arr.shape[axis]):
raise ValueError("sorted array axis length must be a power of 2")

arr = arr.copy() if mkcpy else arr

run_queue = self.run_queue
if idx is None:
argsort = 0
else:
argsort = 1

run_queue = self.sort_b_prepare_wl(argsort, arr.shape, axis)
knl, nt, wg, aux = run_queue[0]

if self.argsort and idx is not None:
if idx is not None:
if aux:
last_evt = knl(
queue, (nt,), wg, arr.data, idx.data,
Expand All @@ -118,7 +117,7 @@ def __call__(self, arr, idx=None, mkcpy=True, queue=None, wait_for=None):
queue, (nt,), wg, arr.data, idx.data,
wait_for=[last_evt])

elif not self.argsort:
else:
if aux:
last_evt = knl(
queue, (nt,), wg, arr.data,
Expand All @@ -127,29 +126,24 @@ def __call__(self, arr, idx=None, mkcpy=True, queue=None, wait_for=None):
for knl, nt, wg, _ in run_queue[1:]:
last_evt = knl(queue, (nt,), wg, arr.data, wait_for=[last_evt])

else:
raise ValueError("Array of indexes required for this sorter. If argsort is not needed,\
recreate sorter witout index datatype provided.")
return arr, last_evt

@memoize_method
def get_program(self, letter, params):
if params in self.cached_defs.keys():
defs = self.cached_defs[params]
else:
defs = self.defstpl.render(
NS="\\", argsort=self.argsort, inc=params[0], dir=params[1],
dtype=params[2], idxtype=params[3],
dsize=params[4], nsize=params[5])
def get_program(self, letter, argsort, params):
defstpl = Template(_tmpl.defines)

self.cached_defs[params] = defs
defs = defstpl.render(
NS="\\", argsort=argsort, inc=params[0], dir=params[1],
dtype=params[2], idxtype=params[3],
dsize=params[4], nsize=params[5])

kid = Template(self.kernels_srcs[letter]).render(argsort=self.argsort)
kid = Template(self.kernels_srcs[letter]).render(argsort=argsort)

prg = cl.Program(self.context, defs + kid).build()
return prg

def sort_b_prepare_wl(self, shape, axis):
@memoize_method
def sort_b_prepare_wl(self, argsort, shape, axis):
run_queue = []
ds = int(shape[axis])
size = reduce(mul, shape)
Expand All @@ -164,7 +158,8 @@ def sort_b_prepare_wl(self, shape, axis):

wg = min(ds, self.context.devices[0].max_work_group_size)
length = wg >> 1
prg = self.get_program('BLO', (1, 1, self.dtype, self.idx_t, ds, ns))
prg = self.get_program(
'BLO', argsort, (1, 1, self.dtype, self.idx_t, ds, ns))
run_queue.append((prg.run, size, (wg,), True))

while length < ds:
Expand All @@ -187,7 +182,7 @@ def sort_b_prepare_wl(self, shape, axis):

nthreads = size >> ninc

prg = self.get_program(letter,
prg = self.get_program(letter, argsort,
(inc, direction, self.dtype, self.idx_t, ds, ns))
run_queue.append((prg.run, nthreads, None, False,))
inc >>= ninc
Expand Down
10 changes: 6 additions & 4 deletions test/test_algorithm.py
Expand Up @@ -848,6 +848,7 @@ def test_key_value_sorter(ctx_factory):
np.float32,
# np.float64
])
@pytest.mark.bitonic
def test_bitonic_sort(ctx_factory, size, dtype):
ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx)
Expand All @@ -856,8 +857,8 @@ def test_bitonic_sort(ctx_factory, size, dtype):
from pyopencl.bitonic_sort import BitonicSort

s = clrandom.rand(queue, (2, size, 3,), dtype, luxury=None, a=0, b=1.0)
sorter = BitonicSort(ctx, s.shape, s.dtype, axis=1)
sgs, evt = sorter(s)
sorter = BitonicSort(ctx, s.dtype)
sgs, evt = sorter(s, axis=1)
assert np.array_equal(np.sort(s.get(), axis=1), sgs.get())


Expand All @@ -872,6 +873,7 @@ def test_bitonic_sort(ctx_factory, size, dtype):
np.float32,
# np.float64
])
@pytest.mark.bitonic
def test_bitonic_argsort(ctx_factory, size, dtype):
ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx)
Expand All @@ -882,9 +884,9 @@ def test_bitonic_argsort(ctx_factory, size, dtype):
index = cl_array.arange(queue, 0, size, 1, dtype=np.int32)
m = clrandom.rand(queue, (size,), np.float32, luxury=None, a=0, b=1.0)

sorterm = BitonicSort(ctx, m.shape, m.dtype, idx_dtype=index.dtype, axis=0)
sorterm = BitonicSort(ctx, m.dtype, idx_dtype=index.dtype)

ms, evt = sorterm(m, idx=index)
ms, evt = sorterm(m, idx=index, axis=0)

assert np.array_equal(np.sort(m.get()), ms.get())

Expand Down

0 comments on commit 2b0ddac

Please sign in to comment.