Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Allow passing an allocator to ReductionKernel (patch by Simon Perkins)
  • Loading branch information
inducer committed Oct 16, 2014
1 parent e82424c commit c0fad50
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 17 deletions.
2 changes: 1 addition & 1 deletion doc/source/array.rst
Expand Up @@ -1079,7 +1079,7 @@ Custom Reductions

.. module:: pycuda.reduction

.. class:: ReductionKernel(dtype_out, neutral, reduce_expr, map_expr=None, arguments=None, name="reduce_kernel", keep=False, options=[], preamble="")
.. class:: ReductionKernel(dtype_out, neutral, reduce_expr, map_expr=None, arguments=None, name="reduce_kernel", keep=False, options=[], preamble="", allocator=None)

Generate a kernel that takes a number of scalar or vector *arguments*
(at least one vector argument), performs the *map_expr* on each entry of
Expand Down
16 changes: 8 additions & 8 deletions pycuda/gpuarray.py
Expand Up @@ -98,7 +98,7 @@ def _splay_backend(n, dev):
block_count = max_blocks
threads_per_block = max_threads

#print "n:%d bc:%d tpb:%d" % (n, block_count, threads_per_block)
# print "n:%d bc:%d tpb:%d" % (n, block_count, threads_per_block)
return (block_count, 1), (threads_per_block, 1, 1)


Expand Down Expand Up @@ -1278,30 +1278,30 @@ def f(a, b, out=None, stream=None):

# {{{ reductions

def sum(a, dtype=None, stream=None, allocator=None):
    """Sum all entries of the GPUArray *a* using a reduction kernel.

    :arg dtype: output dtype requested from ``get_sum_kernel``.
    :arg stream: stream the reduction is launched on (passed through).
    :arg allocator: passed through as the kernel's ``allocator`` keyword;
        it selects how the result array is allocated. ``None`` keeps the
        kernel's default.
    """
    from pycuda.reduction import get_sum_kernel

    krnl = get_sum_kernel(dtype, a.dtype)
    return krnl(a, stream=stream, allocator=allocator)


def subset_sum(subset, a, dtype=None, stream=None, allocator=None):
    """Sum the entries of *a* picked out by *subset* using a reduction kernel.

    *subset* is presumably an index array into *a* (TODO confirm against
    ``get_subset_sum_kernel``).

    :arg dtype: output dtype requested from ``get_subset_sum_kernel``.
    :arg stream: stream the reduction is launched on (passed through).
    :arg allocator: passed through as the kernel's ``allocator`` keyword;
        it selects how the result array is allocated. ``None`` keeps the
        kernel's default.
    """
    from pycuda.reduction import get_subset_sum_kernel

    krnl = get_subset_sum_kernel(dtype, subset.dtype, a.dtype)
    # Forward *allocator* as sum/dot/subset_dot do; without this the
    # parameter would be accepted but silently ignored.
    return krnl(subset, a, stream=stream, allocator=allocator)


def dot(a, b, dtype=None, stream=None, allocator=None):
    """Compute the dot product of the GPUArrays *a* and *b* on the device.

    :arg dtype: output dtype; if ``None``, the common dtype of *a* and *b*
        (as computed by ``_get_common_dtype``) is used.
    :arg stream: stream the reduction is launched on (passed through).
    :arg allocator: passed through as the kernel's ``allocator`` keyword;
        it selects how the result array is allocated. ``None`` keeps the
        kernel's default.
    """
    from pycuda.reduction import get_dot_kernel

    if dtype is None:
        dtype = _get_common_dtype(a, b)
    krnl = get_dot_kernel(dtype, a.dtype, b.dtype)
    return krnl(a, b, stream=stream, allocator=allocator)


def subset_dot(subset, a, b, dtype=None, stream=None, allocator=None):
    """Dot product of *a* and *b* restricted to the entries picked by *subset*.

    *subset* is presumably an index array shared by *a* and *b*
    (TODO confirm against ``get_subset_dot_kernel``).

    :arg dtype: output dtype requested from ``get_subset_dot_kernel``.
    :arg stream: stream the reduction is launched on (passed through).
    :arg allocator: passed through as the kernel's ``allocator`` keyword;
        it selects how the result array is allocated. ``None`` keeps the
        kernel's default.
    """
    from pycuda.reduction import get_subset_dot_kernel

    krnl = get_subset_dot_kernel(dtype, subset.dtype, a.dtype, b.dtype)
    return krnl(subset, a, b, stream=stream, allocator=allocator)


def _make_minmax_kernel(what):
Expand Down
15 changes: 7 additions & 8 deletions pycuda/reduction.py
Expand Up @@ -58,16 +58,11 @@
source code with only those rights set forth herein.
"""




from pycuda.tools import context_dependent_memoize
from pycuda.tools import dtype_to_ctype
import numpy as np




def get_reduction_module(out_type, block_size,
neutral, reduce_expr, map_expr, arguments,
name="reduce_kernel", keep=False, options=None, preamble=""):
Expand Down Expand Up @@ -257,6 +252,10 @@ def __call__(self, *args, **kwargs):
repr_vec = vectors[0]
sz = repr_vec.size

allocator = kwargs.get("allocator", None)
if allocator is None:
allocator = repr_vec.allocator

if sz <= self.block_size*SMALL_SEQ_COUNT*MAX_BLOCK_COUNT:
total_block_size = SMALL_SEQ_COUNT*self.block_size
block_count = (sz + total_block_size - 1) // total_block_size
Expand All @@ -267,13 +266,13 @@ def __call__(self, *args, **kwargs):
seq_count = (sz + macroblock_size - 1) // macroblock_size

if block_count == 1:
result = empty((), self.dtype_out, repr_vec.allocator)
result = empty((), self.dtype_out, allocator=allocator)
else:
result = empty((block_count,), self.dtype_out, repr_vec.allocator)
result = empty((block_count,), self.dtype_out, allocator=allocator)

kwargs = dict(shared_size=self.block_size*self.dtype_out.itemsize)

#print block_count, seq_count, self.block_size, sz
# print block_count, seq_count, self.block_size, sz
f((block_count, 1), (self.block_size, 1, 1), stream,
*([result.gpudata]+invocation_args+[seq_count, sz]),
**kwargs)
Expand Down
46 changes: 46 additions & 0 deletions test/test_gpuarray.py
Expand Up @@ -840,6 +840,52 @@ def test_struct_reduce(self):
assert minmax["cur_min"] == np.min(a)
assert minmax["cur_max"] == np.max(a)

@mark_cuda_test
def test_sum_allocator(self):
    """gpuarray.sum honours an explicit allocator and otherwise uses a's."""
    import pycuda.tools

    pool = pycuda.tools.DeviceMemoryPool()
    n = np.random.randint(low=512, high=1024)

    a = gpuarray.arange(n, dtype=np.int32)
    default_result = gpuarray.sum(a)
    pooled_result = gpuarray.sum(a, allocator=pool.allocate)

    # Closed form of 0 + 1 + ... + (n - 1).
    expected = n * (n - 1) // 2
    assert default_result.get() == expected
    assert pooled_result.get() == expected

    # Each result array must come from the allocator that was requested.
    assert default_result.allocator == a.allocator
    assert pooled_result.allocator == pool.allocate

@mark_cuda_test
def test_dot_allocator(self):
    """gpuarray.dot matches numpy.dot and honours an explicit allocator."""
    import pycuda.tools

    pool = pycuda.tools.DeviceMemoryPool()

    a_host = np.random.randint(low=512, high=1024, size=1024)
    b_host = np.random.randint(low=512, high=1024, size=1024)
    expected = np.dot(a_host, b_host)  # CPU reference value

    a_dev = gpuarray.to_gpu(a_host)
    b_dev = gpuarray.to_gpu(b_host)

    # Same reduction twice: default allocator, then the memory pool.
    default_result = gpuarray.dot(a_dev, b_dev)
    pooled_result = gpuarray.dot(a_dev, b_dev, allocator=pool.allocate)

    assert expected == default_result.get()
    assert expected == pooled_result.get()

    # Each result array must come from the allocator that was requested.
    assert default_result.allocator == a_dev.allocator
    assert pooled_result.allocator == pool.allocate


@mark_cuda_test
def test_view_and_strides(self):
from pycuda.curandom import rand as curand
Expand Down

0 comments on commit c0fad50

Please sign in to comment.