Skip to content

Commit

Permalink
Fix creation of sub-buffers
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Nov 19, 2015
1 parent 23ba40d commit 00f8fda
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 35 deletions.
19 changes: 12 additions & 7 deletions pyopencl/cffi_cl.py
Expand Up @@ -945,16 +945,21 @@ def get_sub_region(self, origin, size, flags=0):
size, flags))
sub_buf = self._create(_sub_buf[0])
MemoryObject.__init__(sub_buf, None)
return sub_buf

def __getitem__(self, idx):
    """Return a sub-buffer for the contiguous slice ``buf[start:stop]``.

    Only :class:`slice` subscripts with a step of 1 (or no step) are
    accepted. Omitted or negative bounds are resolved against
    ``self.size`` via :meth:`slice.indices`, exactly as for Python
    sequences. The result is created with :meth:`get_sub_region` using
    its default flags.

    :raises TypeError: if *idx* is not a slice object.
    :raises ValueError: if the slice step is not 1, or if the resolved
        stop lies before the resolved start.
    """
    if not isinstance(idx, slice):
        raise TypeError("buffer subscript must be a slice object")

    start, stop, stride = idx.indices(self.size)
    if stride != 1:
        raise ValueError("Buffer slice must have stride 1",
                status_code.INVALID_VALUE, "Buffer.__getitem__")

    # Explicit check instead of ``assert``: asserts disappear under
    # ``python -O``, and a reversed slice such as buf[10:5] is a caller
    # error rather than an internal invariant.
    if stop < start:
        raise ValueError("Buffer slice must have start <= stop",
                status_code.INVALID_VALUE, "Buffer.__getitem__")

    size = stop - start
    return self.get_sub_region(start, size)

# }}}

Expand Down
27 changes: 0 additions & 27 deletions src/c_wrapper/buffer.cpp
Expand Up @@ -23,33 +23,6 @@ buffer::get_sub_region(size_t orig, size_t size, cl_mem_flags flags) const
});
return new_buffer(mem);
}

// Return a sub-buffer for the byte range [start, end) of this buffer,
// interpreting the bounds like a Python slice with step 1:
//   - end == 0 (the "no stop given" sentinel) or end > length selects
//     through the end of the buffer;
//   - a negative start or end counts back from the end of the buffer.
// Throws clerror(CL_INVALID_VALUE) if the buffer size cannot be obtained,
// or if the resolved range is empty or begins before offset 0.
PYOPENCL_USE_RESULT buffer*
buffer::getitem(ssize_t start, ssize_t end) const
{
    ssize_t length = 0;
    pyopencl_call_guarded(clGetMemObjectInfo, this, CL_MEM_SIZE,
                          size_arg(length), nullptr);
    if (PYOPENCL_UNLIKELY(length <= 0))
        throw clerror("Buffer.__getitem__", CL_INVALID_VALUE,
                      "Cannot get the length of the buffer.");

    // Resolve the slice bounds against the buffer length.
    if (end == 0 || end > length) {
        end = length;
    } else if (end < 0) {
        end += length;
    }
    if (start < 0) {
        start += length;
    }
    if (end <= start || start < 0)
        throw clerror("Buffer.__getitem__", CL_INVALID_VALUE,
                      "Buffer slice should have end > start >= 0");

    // The sub-buffer reuses the parent's flags. get_sub_region presumably
    // forwards them to clCreateSubBuffer, which rejects CL_MEM_USE_HOST_PTR,
    // CL_MEM_ALLOC_HOST_PTR and CL_MEM_COPY_HOST_PTR with CL_INVALID_VALUE
    // (those properties are inherited from the parent). Clearing only
    // COPY_HOST_PTR made slicing USE/ALLOC_HOST_PTR buffers fail.
    cl_mem_flags flags = 0;
    pyopencl_call_guarded(clGetMemObjectInfo, this, CL_MEM_FLAGS,
                          size_arg(flags), nullptr);
    flags &= ~static_cast<cl_mem_flags>(CL_MEM_USE_HOST_PTR |
                                        CL_MEM_ALLOC_HOST_PTR |
                                        CL_MEM_COPY_HOST_PTR);
    return get_sub_region(static_cast<size_t>(start),
                          static_cast<size_t>(end - start), flags);
}
#endif

// c wrapper
Expand Down
1 change: 0 additions & 1 deletion src/c_wrapper/buffer.h
Expand Up @@ -17,7 +17,6 @@ class buffer : public memory_object {
#if PYOPENCL_CL_VERSION >= 0x1010
// Sub-buffer support; only declared when PYOPENCL_CL_VERSION >= 0x1010
// (presumably OpenCL 1.1, which introduced sub-buffers -- confirm).

// Creates a new buffer object for the region of this buffer given by
// origin `orig` and `size`, created with `flags`.
PYOPENCL_USE_RESULT buffer *get_sub_region(size_t orig, size_t size,
cl_mem_flags flags) const;
// Python-style slicing helper: negative `start`/`end` count from the end of
// the buffer, and `end == 0` (or past the end) means "through the end".
// Copies the parent's CL_MEM_FLAGS minus CL_MEM_COPY_HOST_PTR into the
// get_sub_region call; throws clerror(CL_INVALID_VALUE) if the resolved
// range is empty or starts before 0.
PYOPENCL_USE_RESULT buffer *getitem(ssize_t start, ssize_t end) const;
#endif
};

Expand Down
28 changes: 28 additions & 0 deletions test/test_wrapper.py
Expand Up @@ -863,6 +863,34 @@ def test_global_offset(ctx_factory):
assert (a_2 == 2*a).all()


def test_sub_buffers(ctx_factory):
    """Slicing a Buffer must yield a sub-buffer holding the matching bytes.

    Uploads a random uint8 array, slices an aligned range out of the device
    buffer, copies that sub-buffer back to the host and compares it with the
    same slice of the host array.
    """
    ctx = ctx_factory()
    # Sub-buffers require OpenCL 1.1 from BOTH the runtime and the headers
    # pyopencl was built against; skip if either one is older.
    if (ctx._get_cl_version() < (1, 1)
            or cl.get_cl_header_version() < (1, 1)):
        from pytest import skip
        skip("sub-buffers are only available in OpenCL 1.1")

    # NOTE: CL_DEVICE_MEM_BASE_ADDR_ALIGN is specified in bits. Using it
    # directly as a byte multiple over-aligns by 8x, which is still a valid
    # sub-buffer origin.
    alignment = ctx.devices[0].mem_base_addr_align

    queue = cl.CommandQueue(ctx)

    n = 30000
    a = (np.random.rand(n) * 100).astype(np.uint8)

    mf = cl.mem_flags
    a_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=a)

    start = (5000 // alignment) * alignment
    stop = start + 20 * alignment

    a_sub_ref = a[start:stop]

    a_sub = np.empty_like(a_sub_ref)
    cl.enqueue_copy(queue, a_sub, a_buf[start:stop])

    assert np.array_equal(a_sub, a_sub_ref)


if __name__ == "__main__":
# make sure that import failures get reported, instead of skipping the tests.
import pyopencl # noqa
Expand Down

0 comments on commit 00f8fda

Please sign in to comment.