Merge pull request #77 from davidweichiang/shape
shape-related changes: transpose, newaxis
inducer committed Jul 11, 2015
2 parents e55c48f + 53bcda2 commit 0dfc4a7
Showing 2 changed files with 95 additions and 1 deletion.
pycuda/gpuarray.py: 72 changes (71 additions, 1 deletion)
@@ -192,7 +192,7 @@ def __init__(self, shape, dtype, allocator=drv.mem_alloc,

        strides = tuple(strides)

        self.shape = shape
        self.shape = tuple(shape)
        self.dtype = dtype
        self.strides = strides
        self.mem_size = self.size = s
@@ -213,6 +213,10 @@ def __init__(self, shape, dtype, allocator=drv.mem_alloc,

        self._grid, self._block = splay(self.mem_size)

    @property
    def ndim(self):
        return len(self.shape)

    @property
    @memoize_method
    def flags(self):
@@ -695,7 +699,13 @@ def astype(self, dtype, stream=None):
        return result

    def reshape(self, *shape):
        """Gives a new shape to an array without changing its data."""

        # TODO: add more error-checking, perhaps
        if not self.flags.forc:
            raise RuntimeError("only contiguous arrays may "
                    "be used as arguments to this operation")

        if isinstance(shape[0], tuple) or isinstance(shape[0], list):
            shape = tuple(shape[0])

@@ -757,6 +767,36 @@ def view(self, dtype=None):
                base=self,
                gpudata=int(self.gpudata))

    def transpose(self, axes=None):
        """Permute the dimensions of an array.
        Parameters
        ----------
        axes : list of ints, optional
            By default, reverse the dimensions, otherwise permute the axes
            according to the values given.
        Returns
        -------
        p : GPUArray
            A view of the array with its axes permuted.
        """

        if axes is None:
            axes = range(self.ndim-1, -1, -1)
        if len(axes) != len(self.shape):
            raise ValueError("axes don't match array")
        new_shape = [self.shape[axes[i]] for i in xrange(len(axes))]
        new_strides = [self.strides[axes[i]] for i in xrange(len(axes))]
        return GPUArray(shape=tuple(new_shape),
                        dtype=self.dtype,
                        allocator=self.allocator,
                        base=self.base or self,
                        gpudata=self.gpudata,
                        strides=tuple(new_strides))

    @property
    def T(self): return self.transpose()

    # {{{ slicing

    def __getitem__(self, index):
Expand Down Expand Up @@ -824,6 +864,11 @@ def __getitem__(self, index):
"more than one ellipsis not allowed in index")
seen_ellipsis = True

elif index_entry is np.newaxis:
new_shape.append(1)
new_strides.append(0)
index_axis += 1

else:
raise IndexError("invalid subindex in axis %d" % index_axis)

@@ -1232,6 +1277,31 @@ def make_func_for_chunk_size(chunk_size):

# }}}

# {{{ shape manipulation

def transpose(a, axes=None):
    """Permute the dimensions of an array.
    Parameters
    ----------
    a : GPUArray
    axes : list of ints, optional
        By default, reverse the dimensions, otherwise permute the axes
        according to the values given.
    Returns
    -------
    p : GPUArray
        A view of the array with its axes permuted.
    """
    return a.transpose(axes)

def reshape(a, shape):
    """Gives a new shape to an array without changing its data."""

    return a.reshape(shape)

# }}}

# {{{ conditionals

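Taken together, the additions above enable the following usage. This is an illustrative sketch only, not part of the commit; it assumes a CUDA context created via pycuda.autoinit and uses pycuda.curandom.rand for sample data, as the new tests below do.

import numpy as np

import pycuda.autoinit  # noqa: F401 -- creates a CUDA context
import pycuda.gpuarray as gpuarray
from pycuda.curandom import rand as curand

a_gpu = curand((10, 20, 30))   # random float32 GPUArray
assert a_gpu.ndim == 3         # new ndim property

# transpose()/.T return views over the same device buffer: only shape and
# strides are permuted, no data is copied.
t_gpu = a_gpu.transpose((2, 0, 1))
assert t_gpu.shape == (30, 10, 20)
assert t_gpu.strides == (a_gpu.strides[2], a_gpu.strides[0], a_gpu.strides[1])
assert np.allclose(a_gpu.T.get(), a_gpu.get().T)

# The module-level helpers mirror the numpy functions of the same name.
r_gpu = gpuarray.reshape(a_gpu, (20, 300))
assert r_gpu.shape == (20, 300)

# np.newaxis in a subscript inserts a length-1 axis with stride 0.
b_gpu = a_gpu[:, np.newaxis]
assert b_gpu.shape == (10, 1, 20, 30)
assert b_gpu.strides[1] == 0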
test/test_gpuarray.py: 24 changes (24 additions, 0 deletions)
@@ -953,6 +953,30 @@ def test_minimum_maximum_scalar(self):
        assert la.norm(max_a0_gpu.get() - np.maximum(a, 0)) == 0
        assert la.norm(min_a0_gpu.get() - np.minimum(0, a)) == 0

    @mark_cuda_test
    def test_transpose(self):
        import pycuda.gpuarray as gpuarray
        from pycuda.curandom import rand as curand

        a_gpu = curand((10,20,30))
        a = a_gpu.get()

        #assert np.allclose(a_gpu.transpose((1,2,0)).get(), a.transpose((1,2,0))) # not contiguous
        assert np.allclose(a_gpu.T.get(), a.T)

    @mark_cuda_test
    def test_newaxis(self):
        import pycuda.gpuarray as gpuarray
        from pycuda.curandom import rand as curand

        a_gpu = curand((10,20,30))
        a = a_gpu.get()

        b_gpu = a_gpu[:,np.newaxis]
        b = a[:,np.newaxis]

        assert b_gpu.shape == b.shape
        assert b_gpu.strides == b.strides

if __name__ == "__main__":
    # make sure that import failures get reported, instead of skipping the tests.
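The commented-out assertion in test_transpose marks a case the transfer path does not yet cover: a general axis permutation yields a view that is neither C- nor Fortran-contiguous, so copying it back with .get() presumably fails there. An illustrative workaround sketch (not part of the commit, same pycuda.autoinit/curand assumptions as above) is to fetch the contiguous base and apply the permutation on the host:

import pycuda.autoinit  # noqa: F401 -- creates a CUDA context
from pycuda.curandom import rand as curand

a_gpu = curand((10, 20, 30))
axes = (1, 2, 0)

# The device-side view is cheap, but may not be directly transferable.
view_gpu = a_gpu.transpose(axes)

# Fetch the contiguous base once, then permute the host copy instead.
p_host = a_gpu.get().transpose(axes)
assert p_host.shape == view_gpu.shape == (20, 30, 10)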
