Skip to content

Commit

Permalink
Merge pull request #60 from grlee77/struct_demo_fix
Browse files Browse the repository at this point in the history
BUG: fix to demo_struct.py for recent CUDA 4.0+
  • Loading branch information
inducer committed Dec 19, 2014
2 parents 0758692 + c489bb5 commit 9c03e9e
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 26 deletions.
36 changes: 19 additions & 17 deletions doc/source/tutorial.rst
Expand Up @@ -20,7 +20,7 @@ Transferring Data
The next step in most programs is to transfer data onto the device.
In PyCuda, you will mostly transfer data from :mod:`numpy` arrays
on the host. (But indeed, everything that satisfies the Python buffer
interface will work, even a :class:`str`.) Let's make a 4x4 array
interface will work, even a :class:`str`.) Let's make a 4x4 array
of random numbers::

import numpy
Expand All @@ -31,7 +31,7 @@ devices only support single precision::

a = a.astype(numpy.float32)

Finally, we need somewhere to transfer data to, so we need to
Finally, we need somewhere to transfer data to, so we need to
allocate memory on the device::

a_gpu = cuda.mem_alloc(a.nbytes)
Expand All @@ -56,8 +56,8 @@ code, and feed it into the constructor of a
}
""")

If there aren't any errors, the code is now compiled and loaded onto the
device. We find a reference to our :class:`pycuda.driver.Function` and call
If there aren't any errors, the code is now compiled and loaded onto the
device. We find a reference to our :class:`pycuda.driver.Function` and call
it, specifying *a_gpu* as the argument, and a block size of 4x4::

func = mod.get_function("doublify")
Expand All @@ -81,9 +81,9 @@ This will print something like this::
[-0.37920788 -0.59378809 1.36134958 1.56078029]
[ 0.14413041 -1.46224082 0.60812396 1.43176913]
[ 0.78825873 0.31750482 1.10785341 -0.22268796]]
It worked! That completes our walkthrough. Thankfully, PyCuda takes
over from here and does all the cleanup for you, so you're done.

It worked! That completes our walkthrough. Thankfully, PyCuda takes
over from here and does all the cleanup for you, so you're done.
Stick around for some bonus material in the next section, though.

(You can find the code for this demo as :file:`examples/demo.py` in the PyCuda
Expand All @@ -109,13 +109,15 @@ to argument types (as designated by Python's standard library :mod:`struct`
module), and then called. This also avoids having to assign explicit argument
sizes using the :class:`numpy.number` classes::

func.prepare("P", block=(4,4,1))
func.prepared_call((1, 1), a_gpu)
grid = (1, 1)
block = (4, 4, 1)
func.prepare("P")
func.prepared_call(grid, block, a_gpu)

Bonus: Abstracting Away the Complications
-----------------------------------------
Using a :class:`pycuda.gpuarray.GPUArray`, the same effect can be

Using a :class:`pycuda.gpuarray.GPUArray`, the same effect can be
achieved with much less writing::

import pycuda.gpuarray as gpuarray
Expand Down Expand Up @@ -144,7 +146,7 @@ length arrays::
int datalen, __padding; // so 64-bit ptrs can be aligned
float *ptr;
};

__global__ void double_array(DoubleOperation *a) {
a = &a[blockIdx.x];
for (int idx = threadIdx.x; idx < a->datalen; idx += blockDim.x) {
Expand All @@ -164,14 +166,14 @@ two arrays are instantiated::
def __init__(self, array, struct_arr_ptr):
self.data = cuda.to_device(array)
self.shape, self.dtype = array.shape, array.dtype
cuda.memcpy_htod(int(struct_arr_ptr), numpy.int32(array.size))
cuda.memcpy_htod(int(struct_arr_ptr) + 8, numpy.intp(int(self.data)))
cuda.memcpy_htod(int(struct_arr_ptr), numpy.getbuffer(numpy.int32(array.size)))
cuda.memcpy_htod(int(struct_arr_ptr) + 8, numpy.getbuffer(numpy.intp(int(self.data))))
def __str__(self):
return str(cuda.from_device(self.data, self.shape, self.dtype))

struct_arr = cuda.mem_alloc(2 * DoubleOpStruct.mem_size)
do2_ptr = int(struct_arr) + DoubleOpStruct.mem_size

array1 = DoubleOpStruct(numpy.array([1, 2, 3], dtype=numpy.float32), struct_arr)
array2 = DoubleOpStruct(numpy.array([0, 4], dtype=numpy.float32), do2_ptr)
print("original arrays", array1, array2)
Expand All @@ -185,7 +187,7 @@ only the second::
func = mod.get_function("double_array")
func(struct_arr, block = (32, 1, 1), grid=(2, 1))
print("doubled arrays", array1, array2)

func(numpy.intp(do2_ptr), block = (32, 1, 1), grid=(1, 1))
print("doubled second only", array1, array2, "\n")

Expand Down
36 changes: 27 additions & 9 deletions examples/demo_struct.py
Expand Up @@ -10,8 +10,16 @@ class DoubleOpStruct:
def __init__(self, array, struct_arr_ptr):
self.data = cuda.to_device(array)
self.shape, self.dtype = array.shape, array.dtype
cuda.memcpy_htod(int(struct_arr_ptr), numpy.int32(array.size))
cuda.memcpy_htod(int(struct_arr_ptr) + 8, numpy.intp(int(self.data)))
"""
numpy.getbuffer() needed due to lack of new-style buffer interface for
scalar numpy arrays as of numpy version 1.9.1
see: https://github.com/inducer/pycuda/pull/60
"""
cuda.memcpy_htod(int(struct_arr_ptr),
numpy.getbuffer(numpy.int32(array.size)))
cuda.memcpy_htod(int(struct_arr_ptr) + 8,
numpy.getbuffer(numpy.intp(int(self.data))))

def __str__(self):
return str(cuda.from_device(self.data, self.shape, self.dtype))
Expand All @@ -33,36 +41,46 @@ def __str__(self):
};
__global__ void double_array(DoubleOperation *a)
__global__ void double_array(DoubleOperation *a)
{
a = a + blockIdx.x;
for (int idx = threadIdx.x; idx < a->datalen; idx += blockDim.x)
for (int idx = threadIdx.x; idx < a->datalen; idx += blockDim.x)
{
float *a_ptr = a->ptr;
a_ptr[idx] *= 2;
}
}
""")
func = mod.get_function("double_array")
func(struct_arr, block = (32, 1, 1), grid=(2, 1))
func(struct_arr, block=(32, 1, 1), grid=(2, 1))

print "doubled arrays"
print array1
print array2

func(numpy.intp(do2_ptr), block = (32, 1, 1), grid=(1, 1))
func(numpy.intp(do2_ptr), block=(32, 1, 1), grid=(1, 1))
print "doubled second only"
print array1
print array2

func.prepare("P", block=(32, 1, 1))
func.prepared_call((2, 1), struct_arr)
if cuda.get_version() < (4, ):
func.prepare("P", block=(32, 1, 1))
func.prepared_call((2, 1), struct_arr)
else:
func.prepare("P")
block = (32, 1, 1)
func.prepared_call((2, 1), block, struct_arr)


print "doubled again"
print array1
print array2

func.prepared_call((1, 1), do2_ptr)
if cuda.get_version() < (4, ):
func.prepared_call((1, 1), do2_ptr)
else:
func.prepared_call((1, 1), block, do2_ptr)


print "doubled second only again"
print array1
Expand Down

0 comments on commit 9c03e9e

Please sign in to comment.