Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
davidweichiang committed Jul 11, 2015
2 parents 110c574 + 09e61e6 commit a6551ff
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 2 deletions.
69 changes: 67 additions & 2 deletions pycuda/gpuarray.py
Expand Up @@ -23,7 +23,7 @@ def _get_common_dtype(obj1, obj2):

# {{{ vector types

class vec:
class vec: # noqa
    # Empty placeholder class serving as a namespace; vector dtypes are
    # presumably attached to it elsewhere in this file — TODO confirm.
    pass


Expand Down Expand Up @@ -192,7 +192,7 @@ def __init__(self, shape, dtype, allocator=drv.mem_alloc,

strides = tuple(strides)

self.shape = shape
self.shape = tuple(shape)
self.dtype = dtype
self.strides = strides
self.mem_size = self.size = s
Expand All @@ -213,6 +213,10 @@ def __init__(self, shape, dtype, allocator=drv.mem_alloc,

self._grid, self._block = splay(self.mem_size)

@property
def ndim(self):
    """Number of array dimensions, i.e. the length of :attr:`shape`."""
    dims = self.shape
    return len(dims)

@property
@memoize_method
def flags(self):
Expand Down Expand Up @@ -707,7 +711,13 @@ def astype(self, dtype, stream=None):
return result

def reshape(self, *shape):
"""Gives a new shape to an array without changing its data."""

# TODO: add more error-checking, perhaps
if not self.flags.forc:
raise RuntimeError("only contiguous arrays may "
"be used as arguments to this operation")

if isinstance(shape[0], tuple) or isinstance(shape[0], list):
shape = tuple(shape[0])

Expand Down Expand Up @@ -769,6 +779,33 @@ def view(self, dtype=None):
base=self,
gpudata=int(self.gpudata))

def transpose(self, axes=None):
    """Permute the dimensions of an array.

    :arg axes: list of ints, optional.
        By default, reverse the dimensions, otherwise permute the axes
        according to the values given.
    :returns: :class:`GPUArray` A view of the array with its axes permuted.
    """
    if axes is None:
        # Default: reverse all axes, matching numpy.ndarray.transpose().
        axes = range(self.ndim-1, -1, -1)

    if len(axes) != len(self.shape):
        raise ValueError("axes don't match array")

    # xrange is Python-2-only; iterate the permutation directly instead of
    # indexing it, which works on both Python 2 and 3.
    new_shape = [self.shape[ax] for ax in axes]
    new_strides = [self.strides[ax] for ax in axes]

    # Permuting shape and strides together yields a zero-copy view over
    # the same device allocation.
    return GPUArray(shape=tuple(new_shape),
                    dtype=self.dtype,
                    allocator=self.allocator,
                    base=self.base or self,
                    gpudata=self.gpudata,
                    strides=tuple(new_strides))

@property
def T(self):  # noqa
    """Shorthand for :meth:`transpose` with reversed axes, as in numpy."""
    return self.transpose(None)

# {{{ slicing

def __getitem__(self, index):
Expand Down Expand Up @@ -836,6 +873,11 @@ def __getitem__(self, index):
"more than one ellipsis not allowed in index")
seen_ellipsis = True

elif index_entry is np.newaxis:
new_shape.append(1)
new_strides.append(0)
index_axis += 1

else:
raise IndexError("invalid subindex in axis %d" % index_axis)

Expand Down Expand Up @@ -1354,6 +1396,29 @@ def make_func_for_chunk_size(chunk_size):
# }}}


# {{{ shape manipulation

def transpose(a, axes=None):
    """Permute the dimensions of an array.

    :arg a: :class:`GPUArray`
    :arg axes: list of ints, optional.
        By default, reverse the dimensions, otherwise permute the axes
        according to the values given.
    :returns: :class:`GPUArray` A view of the array with its axes permuted.
    """
    # Free-function convenience wrapper: delegate to the array's own method.
    result = a.transpose(axes)
    return result


def reshape(a, shape):
    """Give a new shape to an array without changing its data.

    :arg a: :class:`GPUArray`
    :arg shape: the new shape.
    """
    # Free-function convenience wrapper: delegate to the array's own method.
    reshaped = a.reshape(shape)
    return reshaped

# }}}


# {{{ conditionals

def if_positive(criterion, then_, else_, out=None, stream=None):
Expand Down
24 changes: 24 additions & 0 deletions test/test_gpuarray.py
Expand Up @@ -953,6 +953,30 @@ def test_minimum_maximum_scalar(self):
assert la.norm(max_a0_gpu.get() - np.maximum(a, 0)) == 0
assert la.norm(min_a0_gpu.get() - np.minimum(0, a)) == 0

@mark_cuda_test
def test_transpose(self):
    """GPUArray.T should match numpy's reversed-axis transpose."""
    # Removed unused "import pycuda.gpuarray as gpuarray" — only curand
    # is needed here.
    from pycuda.curandom import rand as curand

    a_gpu = curand((10, 20, 30))
    a = a_gpu.get()

    # A general permutation such as (1, 2, 0) yields a non-contiguous
    # view, which is not supported yet:
    #assert np.allclose(a_gpu.transpose((1,2,0)).get(), a.transpose((1,2,0)))
    assert np.allclose(a_gpu.T.get(), a.T)

@mark_cuda_test
def test_newaxis(self):
    """Indexing with np.newaxis should insert a length-1 axis, as in numpy."""
    # Removed unused "import pycuda.gpuarray as gpuarray" — only curand
    # is needed here.
    from pycuda.curandom import rand as curand

    a_gpu = curand((10, 20, 30))
    a = a_gpu.get()

    b_gpu = a_gpu[:, np.newaxis]
    b = a[:, np.newaxis]

    # Shape and strides must both agree with numpy's view semantics.
    assert b_gpu.shape == b.shape
    assert b_gpu.strides == b.strides

if __name__ == "__main__":
# make sure that import failures get reported, instead of skipping the tests.
Expand Down

0 comments on commit a6551ff

Please sign in to comment.