Skip to content

Commit

Permalink
Merge pull request #84 from sjperkins/gpudata-squeeze
Browse files Browse the repository at this point in the history
Add a GPUArray.squeeze() method.
  • Loading branch information
inducer committed Aug 18, 2015
2 parents 84904de + d8fd1fa commit 54a9075
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
17 changes: 17 additions & 0 deletions pycuda/gpuarray.py
Expand Up @@ -742,6 +742,23 @@ def view(self, dtype=None):
base=self,
gpudata=int(self.gpudata))

def squeeze(self):
    """
    Returns a view of the array with dimensions of
    length 1 removed.

    Only axes whose length is exactly 1 are dropped (matching
    :func:`numpy.squeeze` with no ``axis`` argument). Zero-length axes
    are kept, so squeezing an empty array still yields an empty view
    rather than a view that claims elements it does not own.

    The returned array shares device memory with ``self`` (it is built
    on ``int(self.gpudata)`` with ``base=self``); no data is copied.
    """
    # Keep each surviving axis together with its stride so the view
    # addresses exactly the same elements as the original array.
    # Use ``!= 1`` rather than ``> 1``: a length-0 axis must not be removed.
    kept = [(dim, stride)
            for dim, stride in zip(self.shape, self.strides)
            if dim != 1]
    new_shape = tuple(dim for dim, _ in kept)
    new_strides = tuple(stride for _, stride in kept)

    return GPUArray(
        shape=new_shape,
        dtype=self.dtype,
        allocator=self.allocator,
        strides=new_strides,
        base=self,
        gpudata=int(self.gpudata))

def transpose(self, axes=None):
"""Permute the dimensions of an array.
Expand Down
32 changes: 32 additions & 0 deletions test/test_gpuarray.py
Expand Up @@ -790,6 +790,38 @@ def test_view(self):
view = a_gpu.view(np.int16)
assert view.shape == (8, 32) and view.dtype == np.int16

@mark_cuda_test
def test_squeeze(self):
    """Check that GPUArray.squeeze() drops length-1 axes and keeps data."""
    shape = (40, 2, 5, 100)
    a_cpu = np.random.random(size=shape)
    a_gpu = gpuarray.to_gpu(a_cpu)

    # Slice with length 1 on dimensions 0 and 1
    a_gpu_slice = a_gpu[0:1, 1:2, :, :]
    assert a_gpu_slice.shape == (1, 1, shape[2], shape[3])
    assert a_gpu_slice.flags.c_contiguous is False

    # Squeeze it and obtain contiguity: once the two length-1 axes are
    # gone, the remaining (5, 100) block is laid out in C order
    a_gpu_squeezed_slice = a_gpu_slice.squeeze()
    assert a_gpu_squeezed_slice.shape == (shape[2], shape[3])
    assert a_gpu_squeezed_slice.flags.c_contiguous is True

    # Check that we get the original values out, both relative to the
    # unsqueezed slice and against NumPy's squeeze of the host data
    assert np.all(
        a_gpu_slice.get().ravel() == a_gpu_squeezed_slice.get().ravel())
    assert np.array_equal(
        a_gpu_squeezed_slice.get(), a_cpu[0:1, 1:2, :, :].squeeze())

    # Slice with length 1 on dimension 2
    a_gpu_slice = a_gpu[:, :, 2:3, :]
    assert a_gpu_slice.shape == (shape[0], shape[1], 1, shape[3])
    assert a_gpu_slice.flags.c_contiguous is False

    # Squeeze it, but no contiguity here: the row stride still skips
    # over the other entries of the removed axis
    a_gpu_squeezed_slice = a_gpu_slice.squeeze()
    assert a_gpu_squeezed_slice.shape == (shape[0], shape[1], shape[3])
    assert a_gpu_squeezed_slice.flags.c_contiguous is False

    # Check that we get the original values out, both relative to the
    # unsqueezed slice and against NumPy's squeeze of the host data
    assert np.all(
        a_gpu_slice.get().ravel() == a_gpu_squeezed_slice.get().ravel())
    assert np.array_equal(
        a_gpu_squeezed_slice.get(), a_cpu[:, :, 2:3, :].squeeze())

@mark_cuda_test
def test_struct_reduce(self):
preamble = """
Expand Down

0 comments on commit 54a9075

Please sign in to comment.