Skip to content

Commit

Permalink
Fix non-contiguous reshape
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Nov 28, 2014
1 parent d1c9724 commit 8553186
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 12 deletions.
98 changes: 95 additions & 3 deletions pyopencl/array.py
Expand Up @@ -1288,20 +1288,112 @@ def reshape(self, *shape, **kwargs):
raise TypeError("unexpected keyword arguments: %s"
% kwargs.keys())

if order not in "CF":
raise ValueError("order must be either 'C' or 'F'")

# TODO: add more error-checking, perhaps

if isinstance(shape[0], tuple) or isinstance(shape[0], list):
shape = tuple(shape[0])

if any(s < 0 for s in shape):
raise NotImplementedError("negative/automatic shapes not supported")

if shape == self.shape:
return self
return self._new_with_changes(
data=self.base_data, offset=self.offset, shape=shape,
strides=self.strides)

size = reduce(lambda x, y: x * y, shape, 1)
import operator
size = reduce(operator.mul, shape, 1)
if size != self.size:
raise ValueError("total size of new array must be unchanged")

# {{{ determine reshaped strides

# copied and translated from
# https://github.com/numpy/numpy/blob/4083883228d61a3b571dec640185b5a5d983bf59/numpy/core/src/multiarray/shape.c # noqa

newdims = shape
newnd = len(newdims)

# Remove axes with dimension 1 from the old array. They have no effect
# but would need special cases since their strides do not matter.

olddims = []
oldstrides = []
for oi in range(len(self.shape)):
s = self.shape[oi]
if s != 1:
olddims.append(s)
oldstrides.append(self.strides[oi])

oldnd = len(olddims)

newstrides = [-1]*len(newdims)

# oi to oj and ni to nj give the axis ranges currently worked with
oi = 0
oj = 1
ni = 0
nj = 1
while ni < newnd and oi < oldnd:
np = newdims[ni]
op = olddims[oi]

while np != op:
if np < op:
# Misses trailing 1s, these are handled later
np *= newdims[nj]
nj += 1
else:
op *= olddims[oj]
oj += 1

# Check whether the original axes can be combined
for ok in range(oi, oj-1):
if order == "F":
if oldstrides[ok+1] != olddims[ok]*oldstrides[ok]:
raise ValueError("cannot reshape without copy")
else:
# C order
if (oldstrides[ok] != olddims[ok+1]*oldstrides[ok+1]):
raise ValueError("cannot reshape without copy")

# Calculate new strides for all axes currently worked with
if order == "F":
newstrides[ni] = oldstrides[oi]
for nk in xrange(ni+1, nj):
newstrides[nk] = newstrides[nk - 1]*newdims[nk - 1]
else:
# C order
newstrides[nj - 1] = oldstrides[oj - 1]
for nk in range(nj-1, ni, -1):
newstrides[nk - 1] = newstrides[nk]*newdims[nk]

ni = nj
nj += 1

oi = oj
oj += 1

# Set strides corresponding to trailing 1s of the new shape.
if ni >= 1:
last_stride = newstrides[ni - 1]
else:
last_stride = self.dtype.itemsize

if order == "F":
last_stride *= newdims[ni - 1]

for nk in range(ni, len(shape)):
newstrides[nk] = last_stride

# }}}

return self._new_with_changes(
data=self.base_data, offset=self.offset, shape=shape,
strides=_make_strides(self.dtype.itemsize, shape, order))
strides=tuple(newstrides))

def ravel(self):
"""Returns flattened array containing the same data."""
Expand Down
17 changes: 8 additions & 9 deletions pyopencl/ipython_ext.py
Expand Up @@ -3,6 +3,7 @@
from IPython.core.magic import (magics_class, Magics, cell_magic, line_magic)

import pyopencl as cl
import sys


def _try_to_utf8(text):
Expand All @@ -14,8 +15,10 @@ def _try_to_utf8(text):
@magics_class
class PyOpenCLMagics(Magics):
def _run_kernel(self, kernel, options):
kernel = _try_to_utf8(kernel)
options = _try_to_utf8(options).strip()
if sys.version_info < (3,):
kernel = _try_to_utf8(kernel)
options = _try_to_utf8(options).strip()

try:
ctx = self.shell.user_ns["cl_ctx"]
except KeyError:
Expand All @@ -34,37 +37,33 @@ def _run_kernel(self, kernel, options):
raise RuntimeError("unable to locate cl context, which must be "
"present in namespace as 'cl_ctx' or 'ctx'")

prg = cl.Program(ctx, kernel).build(options=options)
prg = cl.Program(ctx, kernel).build(options=options.split())

for knl in prg.all_kernels():
self.shell.user_ns[knl.function_name] = knl


@cell_magic
def cl_kernel(self, line, cell):
    """Cell magic: hand the cell body to ``_run_kernel`` as kernel source.

    *line* is the magic's argument line; the value of an optional ``-o``
    option is passed along as the build options (empty string if absent).
    *cell* is the cell contents and is used unchanged as the kernel source.
    """
    # Parse the argument line exactly once.
    opts, args = self.parse_options(line, 'o:')
    build_options = opts.get('o', '')

    self._run_kernel(cell, build_options)


def _load_kernel_and_options(self, line):
opts, args = self.parse_options(line,'o:f:')
opts, args = self.parse_options(line, 'o:f:')

build_options = opts.get('o')
kernel = self.shell.find_user_code(opts.get('f') or args)

return kernel, build_options


@line_magic
def cl_kernel_from_file(self, line):
    """Line magic: load the kernel named on *line* and pass it to ``_run_kernel``.

    The argument line is interpreted by ``_load_kernel_and_options``.
    """
    source, options = self._load_kernel_and_options(line)
    self._run_kernel(source, options)


@line_magic
def cl_load_edit_kernel(self, line):
kernel, build_options = self._load_kernel_and_options(line)
Expand Down
18 changes: 18 additions & 0 deletions test/test_array.py
Expand Up @@ -719,6 +719,24 @@ def test_view_and_strides(ctx_factory):
assert (y.get() == X.get()[:3, :5]).all()


def test_meshmode_view(ctx_factory):
    """Check that filling reshaped slices of an array writes through.

    The last ``3*n`` columns of a ``(2, 6*n)`` array are viewed with shape
    ``(2, n, 3)``. After each row of that view is filled with ones, the
    same view of the array returned by ``get()`` must be all ones.
    """
    cl_ctx = ctx_factory()
    cmd_queue = cl.CommandQueue(cl_ctx)

    n = 2
    dev_ary = cl.array.empty(cmd_queue, (2, n*6), np.float64)

    def view(ary):
        # Slice off the trailing columns first, then reshape the slice.
        tail = ary[..., n*3:n*6]
        new_shape = ary.shape[:-1] + (n, 3)
        return tail.reshape(new_shape)

    dev_ary = dev_ary.with_queue(cmd_queue)
    dev_ary.fill(0)
    for row in range(2):
        view(dev_ary)[row].fill(1)

    host_ary = dev_ary.get()
    assert (view(host_ary) == 1).all()


def test_event_management(ctx_factory):
context = ctx_factory()
queue = cl.CommandQueue(context)
Expand Down

0 comments on commit 8553186

Please sign in to comment.