Skip to content

Commit

Permalink
Merge pull request #105 from zamorays/surfANDtexExtensions
Browse files Browse the repository at this point in the history
Surf and tex extensions
  • Loading branch information
inducer committed Mar 11, 2016
2 parents 2b16d1c + 88c4bed commit 1a7ba64
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 8 deletions.
29 changes: 28 additions & 1 deletion pycuda/cuda/pycuda-helpers.hpp
Expand Up @@ -6,7 +6,7 @@
extern "C++" {
// "double-precision" textures ------------------------------------------------
/* Thanks to Nathan Bell <nbell@nvidia.com> for help in figuring this out. */

// Element types used when binding textures for the fp_tex* helpers.
// Double and complex values are fetched as integer vectors and then
// reinterpreted back into their floating-point form by those helpers.
typedef float fp_tex_float;   // plain float; no reinterpretation is shown for it here
typedef int2 fp_tex_double;   // double split into two 32-bit words; rebuilt with __hiloint2double(.y, .x)
typedef uint2 fp_tex_cfloat;  // complex<float>: .x / .y carry the real / imaginary bit patterns (__int_as_float)
Expand Down Expand Up @@ -56,6 +56,28 @@ extern "C++" {
fp_tex_cdouble v = tex2D(tex, i, j);
return pycuda::complex<double>(__hiloint2double(v.y, v.x), __hiloint2double(v.w, v.z));
}
// 2D Layered extension

// Fetch a double from a 2D layered texture.
// The texel is stored as fp_tex_double (int2): .y holds the high 32-bit word
// and .x the low word, which are recombined into the original double.
//   tex   - layered texture whose element type is fp_tex_double
//   i, j  - coordinates, forwarded unchanged to tex2DLayered
//   layer - layer index, forwarded unchanged to tex2DLayered
template <enum cudaTextureReadMode read_mode>
__device__ double fp_tex2DLayered(texture<fp_tex_double, cudaTextureType2DLayered, read_mode> tex, float i, float j, int layer)
{
  const fp_tex_double halves = tex2DLayered(tex, i, j, layer);
  const int hi_word = halves.y;
  const int lo_word = halves.x;
  return __hiloint2double(hi_word, lo_word);
}

// Fetch a complex<float> from a 2D layered texture.
// The texel is stored as fp_tex_cfloat (uint2): .x carries the bit pattern of
// the real part and .y that of the imaginary part.
//   tex   - layered texture whose element type is fp_tex_cfloat
//   i, j  - coordinates, forwarded unchanged to tex2DLayered
//   layer - layer index, forwarded unchanged to tex2DLayered
template <enum cudaTextureReadMode read_mode>
__device__ pycuda::complex<float> fp_tex2DLayered(texture<fp_tex_cfloat, cudaTextureType2DLayered, read_mode> tex, float i, float j, int layer)
{
  const fp_tex_cfloat bits = tex2DLayered(tex, i, j, layer);
  const float re = __int_as_float(bits.x);
  const float im = __int_as_float(bits.y);
  return pycuda::complex<float>(re, im);
}

// Fetch a complex<double> from a 2D layered texture.
// fp_tex_cdouble is declared outside this view (presumably a 4-component int
// vector -- confirm). (.y, .x) are the high/low words of the real part and
// (.w, .z) the high/low words of the imaginary part.
//   tex   - layered texture whose element type is fp_tex_cdouble
//   i, j  - coordinates, forwarded unchanged to tex2DLayered
//   layer - layer index, forwarded unchanged to tex2DLayered
template <enum cudaTextureReadMode read_mode>
__device__ pycuda::complex<double> fp_tex2DLayered(texture<fp_tex_cdouble, cudaTextureType2DLayered, read_mode> tex, float i, float j, int layer)
{
  const fp_tex_cdouble words = tex2DLayered(tex, i, j, layer);
  const double re = __hiloint2double(words.y, words.x);
  const double im = __hiloint2double(words.w, words.z);
  return pycuda::complex<double>(re, im);
}

// 3D functionality

Expand Down Expand Up @@ -189,6 +211,11 @@ extern "C++" {
{ \
return tex2D(tex, i, j); \
} \
template <enum cudaTextureReadMode read_mode> \
__device__ TYPE fp_tex2DLayered(texture<TYPE, cudaTextureType2DLayered, read_mode> tex, int i, int j, int layer) \
{ \
return tex2DLayered(tex, i, j, layer); \
} \
template <enum cudaTextureReadMode read_mode> \
__device__ TYPE fp_tex3D(texture<TYPE, 3, read_mode> tex, int i, int j, int k) \
{ \
Expand Down
71 changes: 64 additions & 7 deletions test/test_driver.py
Expand Up @@ -354,6 +354,50 @@ def test_2d_fp_textures(self):
assert np.sum(np.abs(A_gpu.get()-np.transpose(A_cpu))) == np.array(0,dtype=prec)
A_gpu.gpudata.free()

@mark_cuda_test
def test_2d_fp_texturesLayered(self):
    """Round-trip a 2D array through a layered texture for each supported dtype.

    The host array is uploaded with ``drv.np_to_array(..., allowSurfaceBind=True)``
    and bound to a ``cudaTextureType2DLayered`` texture. The kernel copies it back
    through ``fp_tex2DLayered`` with the two coordinates swapped, so the result
    must equal the exact transpose of the input.
    """
    orden = "F"
    npoints = 32

    for prec in [np.int16, np.float32, np.float64, np.complex64, np.complex128]:
        prec_str = dtype_to_ctype(prec)
        if prec == np.complex64:
            fpName_str = 'fp_tex_cfloat'
        elif prec == np.complex128:
            fpName_str = 'fp_tex_cdouble'
        elif prec == np.float64:
            fpName_str = 'fp_tex_double'
        else:
            fpName_str = prec_str

        A_cpu = np.zeros([npoints, npoints], order=orden, dtype=prec)
        if np.issubdtype(prec, np.complexfloating):
            # Fill both parts so the imaginary-channel decoding is checked too.
            A_cpu.real = np.random.rand(npoints, npoints)
            A_cpu.imag = np.random.rand(npoints, npoints)
        elif np.issubdtype(prec, np.integer):
            # rand() lies in [0, 1); without scaling every integer entry is 0.
            A_cpu[:] = np.random.rand(npoints, npoints) * 100.
        else:
            A_cpu[:] = np.random.rand(npoints, npoints)
        A_gpu = gpuarray.zeros(A_cpu.shape, dtype=prec, order=orden)

        # Layer 0 is read because a single 2D host array is uploaded.
        # NOTE(review): assumes np_to_array yields a one-layer layered array;
        # confirm in pycuda.driver.
        # The grid is rounded up, so threads past the edge must not write.
        myKern = '''
        #include <pycuda-helpers.hpp>
        texture<fpName, cudaTextureType2DLayered, cudaReadModeElementType> mtx_tex;
        __global__ void copy_texture(cuPres *dest, int n)
        {
          int row = blockIdx.x*blockDim.x + threadIdx.x;
          int col = blockIdx.y*blockDim.y + threadIdx.y;
          if (row < n && col < n)
            dest[row + col*n] = fp_tex2DLayered(mtx_tex, col, row, 0);
        }
        '''
        myKern = myKern.replace('fpName', fpName_str)
        myKern = myKern.replace('cuPres', prec_str)
        mod = SourceModule(myKern)

        copy_texture = mod.get_function("copy_texture")
        mtx_tex = mod.get_texref("mtx_tex")
        cuBlock = (16, 16, 1)
        if cuBlock[0] > npoints:
            cuBlock = (npoints, npoints, 1)
        # Ceil-divide so every element is covered even when npoints is not
        # a multiple of the block size.
        cuGrid = ((npoints + cuBlock[0] - 1) // cuBlock[0],
                  (npoints + cuBlock[1] - 1) // cuBlock[1],
                  1)
        copy_texture.prepare('Pi', texrefs=[mtx_tex])
        cudaArray = drv.np_to_array(A_cpu, orden, allowSurfaceBind=True)
        mtx_tex.set_array(cudaArray)
        copy_texture.prepared_call(cuGrid, cuBlock, A_gpu.gpudata, np.int32(npoints))
        assert np.sum(np.abs(A_gpu.get() - np.transpose(A_cpu))) == np.array(0, dtype=prec)
        A_gpu.gpudata.free()

@mark_cuda_test
def test_3d_fp_textures(self):
orden = "C"
Expand Down Expand Up @@ -400,17 +444,30 @@ def test_3d_fp_textures(self):

@mark_cuda_test
def test_3d_fp_surfaces(self):
orden = "F"
orden = "C"
npoints = 32

for prec in [np.int16,np.float32,np.float64,np.complex64,np.complex128]:
prec_str = dtype_to_ctype(prec)
if prec == np.complex64: fpName_str = 'fp_tex_cfloat'
elif prec == np.complex128: fpName_str = 'fp_tex_cdouble'
elif prec == np.float64: fpName_str = 'fp_tex_double'
else: fpName_str = prec_str
A_cpu = np.zeros([npoints,npoints,npoints],order=orden,dtype=prec)
A_cpu[:] = np.random.rand(npoints,npoints,npoints)[:]
if prec == np.complex64:
fpName_str = 'fp_tex_cfloat'
A_cpu = np.zeros([npoints,npoints,npoints],order=orden,dtype=prec)
A_cpu[:].real = np.random.rand(npoints,npoints,npoints)[:]
A_cpu[:].imag = np.random.rand(npoints,npoints,npoints)[:]
elif prec == np.complex128:
fpName_str = 'fp_tex_cdouble'
A_cpu = np.zeros([npoints,npoints,npoints],order=orden,dtype=prec)
A_cpu[:].real = np.random.rand(npoints,npoints,npoints)[:]
A_cpu[:].imag = np.random.rand(npoints,npoints,npoints)[:]
elif prec == np.float64:
fpName_str = 'fp_tex_double'
A_cpu = np.zeros([npoints,npoints,npoints],order=orden,dtype=prec)
A_cpu[:] = np.random.rand(npoints,npoints,npoints)[:]
else:
fpName_str = prec_str
A_cpu = np.zeros([npoints,npoints,npoints],order=orden,dtype=prec)
A_cpu[:] = np.random.rand(npoints,npoints,npoints)[:]*100.

A_gpu = gpuarray.to_gpu(A_cpu) # Array randomized

myKernRW = '''
Expand Down

0 comments on commit 1a7ba64

Please sign in to comment.