Some test fixes for Beignet
[pyopencl.git] / examples / demo_mandelbrot.py
1 # I found this example for PyCuda here:
2 # http://wiki.tiker.net/PyCuda/Examples/Mandelbrot
3 #
4 # An improved sequential/pure Python code was contributed
5 # by CRVSADER//KY <crusaderky@gmail.com>.
6 #
7 # I adapted it for PyOpenCL. Hopefully it is useful to someone.
8 # July 2010, HolgerRapp@gmx.net
9 #
10 # Original readme below these lines.
11
12 # Mandelbrot calculate using GPU, Serial numpy and faster numpy
13 # Use to show the speed difference between CPU and GPU calculations
14 # ian@ianozsvald.com March 2010
15
16 # Based on vegaseat's TKinter/numpy example code from 2006
17 # http://www.daniweb.com/code/snippet216851.html#
18 # with minor changes to move to numpy from the obsolete Numeric
19
20 import time
21
22 import numpy as np
23
24 import pyopencl as cl
25
26 # You can choose a calculation routine below (calc_fractal), uncomment
27 # one of the three lines to test the three variations
28 # Speed notes are listed in the same place
29
30 # set width and height of window, more pixels take longer to calculate
31 w = 2048
32 h = 2048
33
34
35 def calc_fractal_opencl(q, maxiter):
36     ctx = cl.create_some_context()
37     queue = cl.CommandQueue(ctx)
38
39     output = np.empty(q.shape, dtype=np.uint16)
40
41     mf = cl.mem_flags
42     q_opencl = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=q)
43     output_opencl = cl.Buffer(ctx, mf.WRITE_ONLY, output.nbytes)
44
45     prg = cl.Program(ctx, """
46     #pragma OPENCL EXTENSION cl_khr_byte_addressable_store : enable
47     __kernel void mandelbrot(__global float2 *q,
48                      __global ushort *output, ushort const maxiter)
49     {
50         int gid = get_global_id(0);
51         float nreal, real = 0;
52         float imag = 0;
53
54         output[gid] = 0;
55
56         for(int curiter = 0; curiter < maxiter; curiter++) {
57             nreal = real*real - imag*imag + q[gid].x;
58             imag = 2* real*imag + q[gid].y;
59             real = nreal;
60
61             if (real*real + imag*imag > 4.0f)
62                  output[gid] = curiter;
63         }
64     }
65     """).build()
66
67     prg.mandelbrot(queue, output.shape, None, q_opencl,
68                    output_opencl, np.uint16(maxiter))
69
70     cl.enqueue_copy(queue, output, output_opencl).wait()
71
72     return output
73
74
75 def calc_fractal_serial(q, maxiter):
76     # calculate z using pure python on a numpy array
77     # note that, unlike the other two implementations,
78     # the number of iterations per point is NOT constant
79     z = np.zeros(q.shape, complex)
80     output = np.resize(np.array(0,), q.shape)
81     for i in range(len(q)):
82         for iter in range(maxiter):
83             z[i] = z[i]*z[i] + q[i]
84             if abs(z[i]) > 2.0:
85                 output[i] = iter
86                 break
87     return output
88
89
90 def calc_fractal_numpy(q, maxiter):
91     # calculate z using numpy, this is the original
92     # routine from vegaseat's URL
93     output = np.resize(np.array(0,), q.shape)
94     z = np.zeros(q.shape, np.complex64)
95
96     for it in range(maxiter):
97         z = z*z + q
98         done = np.greater(abs(z), 2.0)
99         q = np.where(done, 0+0j, q)
100         z = np.where(done, 0+0j, z)
101         output = np.where(done, it, output)
102     return output
103
104 # choose your calculation routine here by uncommenting one of the options
105 calc_fractal = calc_fractal_opencl
106 # calc_fractal = calc_fractal_serial
107 # calc_fractal = calc_fractal_numpy
108
109 if __name__ == '__main__':
110     try:
111         import Tkinter as tk
112     except ImportError:
113         # Python 3
114         import tkinter as tk
115     from PIL import Image, ImageTk
116
117     class Mandelbrot(object):
118         def __init__(self):
119             # create window
120             self.root = tk.Tk()
121             self.root.title("Mandelbrot Set")
122             self.create_image()
123             self.create_label()
124             # start event loop
125             self.root.mainloop()
126
127         def draw(self, x1, x2, y1, y2, maxiter=30):
128             # draw the Mandelbrot set, from numpy example
129             xx = np.arange(x1, x2, (x2-x1)/w)
130             yy = np.arange(y2, y1, (y1-y2)/h) * 1j
131             q = np.ravel(xx+yy[:, np.newaxis]).astype(np.complex64)
132
133             start_main = time.time()
134             output = calc_fractal(q, maxiter)
135             end_main = time.time()
136
137             secs = end_main - start_main
138             print("Main took", secs)
139
140             self.mandel = (output.reshape((h, w)) /
141                            float(output.max()) * 255.).astype(np.uint8)
142
143         def create_image(self):
144             """"
145             create the image from the draw() string
146             """
147             # you can experiment with these x and y ranges
148             self.draw(-2.13, 0.77, -1.3, 1.3)
149             self.im = Image.fromarray(self.mandel)
150             self.im.putpalette([i for rgb in ((j, 0, 0) for j in range(255))
151                                 for i in rgb])
152
153         def create_label(self):
154             # put the image on a label widget
155             self.image = ImageTk.PhotoImage(self.im)
156             self.label = tk.Label(self.root, image=self.image)
157             self.label.pack()
158
159     # test the class
160     test = Mandelbrot()