Skip to content

Commit

Permalink
Add support for unknown dimension in reshape (stolen from a PyCUDA patch by Thomas Unterthiner)
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Jan 27, 2015
1 parent 23c213e commit fa3a0a2
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
10 changes: 8 additions & 2 deletions pyopencl/array.py
Expand Up @@ -1296,8 +1296,14 @@ def reshape(self, *shape, **kwargs):
if isinstance(shape[0], tuple) or isinstance(shape[0], list):
shape = tuple(shape[0])

if any(s < 0 for s in shape):
raise NotImplementedError("negative/automatic shapes not supported")
if -1 in shape:
shape = list(shape)
idx = shape.index(-1)
size = -reduce(lambda x, y: x * y, shape, 1)
shape[idx] = self.size // size
if any(s < 0 for s in shape):
raise ValueError("can only specify one unknown dimension")
shape = tuple(shape)

if shape == self.shape:
return self._new_with_changes(
Expand Down
22 changes: 22 additions & 0 deletions test/test_array.py
Expand Up @@ -776,6 +776,28 @@ def test_event_management(ctx_factory):
assert len(x.events) < 100


def test_reshape(ctx_factory):
    """Exercise Array.reshape with the new shape given in several forms.

    Covers positional ints, a tuple and a list, a single ``-1`` ("unknown")
    dimension that is inferred from the array size, and checks that giving
    more than one unknown dimension raises ValueError.
    """
    context = ctx_factory()
    queue = cl.CommandQueue(context)

    a = np.arange(128).reshape(8, 16).astype(np.float32)
    a_dev = cl_array.to_device(queue, a)

    # different ways to specify the shape; all must produce the same shape
    # (previously these calls were made without checking their result)
    assert a_dev.reshape(4, 32).shape == (4, 32)
    assert a_dev.reshape((4, 32)).shape == (4, 32)
    assert a_dev.reshape([4, 32]).shape == (4, 32)

    # using -1 as unknown dimension
    assert a_dev.reshape(-1, 32).shape == (4, 32)
    assert a_dev.reshape((32, -1)).shape == (32, 4)
    assert a_dev.reshape((8, -1, 4)).shape == (8, 4, 4)
    assert a_dev.reshape(-1).shape == (128,)

    # inferring a dimension must not change the element order; compare the
    # device result against numpy's reshape of the same host data
    assert np.array_equal(a_dev.reshape(-1, 32).get(), a.reshape(-1, 32))

    import pytest
    with pytest.raises(ValueError):
        a_dev.reshape(-1, -1, 4)


if __name__ == "__main__":
# make sure that import failures get reported, instead of skipping the
# tests.
Expand Down

0 comments on commit fa3a0a2

Please sign in to comment.