Re-fix constant type to int64_t
[pyopencl.git] / examples / demo_mandelbrot.py
1 from __future__ import absolute_import
2 from __future__ import print_function
3 # I found this example for PyCuda here:
4 # http://wiki.tiker.net/PyCuda/Examples/Mandelbrot
5 #
6 # An improved sequential/pure Python code was contributed
7 # by CRVSADER//KY <crusaderky@gmail.com>.
8 #
9 # I adapted it for PyOpenCL. Hopefully it is useful to someone.
10 # July 2010, HolgerRapp@gmx.net
11 #
12 # Original readme below these lines.
13
14 # Mandelbrot calculate using GPU, Serial numpy and faster numpy
15 # Use to show the speed difference between CPU and GPU calculations
16 # ian@ianozsvald.com March 2010
17
18 # Based on vegaseat's TKinter/numpy example code from 2006
19 # http://www.daniweb.com/code/snippet216851.html#
20 # with minor changes to move to numpy from the obsolete Numeric
21
22 import time
23
24 import numpy as np
25
26 import pyopencl as cl
27 from six.moves import range
28
29 # You can choose a calculation routine below (calc_fractal), uncomment
30 # one of the three lines to test the three variations
31 # Speed notes are listed in the same place
32
33 # set width and height of window, more pixels take longer to calculate
34 w = 2048
35 h = 2048
36
37
38 def calc_fractal_opencl(q, maxiter):
39     ctx = cl.create_some_context()
40     queue = cl.CommandQueue(ctx)
41
42     output = np.empty(q.shape, dtype=np.uint16)
43
44     mf = cl.mem_flags
45     q_opencl = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=q)
46     output_opencl = cl.Buffer(ctx, mf.WRITE_ONLY, output.nbytes)
47
48     prg = cl.Program(ctx, """
49     #pragma OPENCL EXTENSION cl_khr_byte_addressable_store : enable
50     __kernel void mandelbrot(__global float2 *q,
51                      __global ushort *output, ushort const maxiter)
52     {
53         int gid = get_global_id(0);
54         float nreal, real = 0;
55         float imag = 0;
56
57         output[gid] = 0;
58
59         for(int curiter = 0; curiter < maxiter; curiter++) {
60             nreal = real*real - imag*imag + q[gid].x;
61             imag = 2* real*imag + q[gid].y;
62             real = nreal;
63
64             if (real*real + imag*imag > 4.0f)
65                  output[gid] = curiter;
66         }
67     }
68     """).build()
69
70     prg.mandelbrot(queue, output.shape, None, q_opencl,
71                    output_opencl, np.uint16(maxiter))
72
73     cl.enqueue_copy(queue, output, output_opencl).wait()
74
75     return output
76
77
78 def calc_fractal_serial(q, maxiter):
79     # calculate z using pure python on a numpy array
80     # note that, unlike the other two implementations,
81     # the number of iterations per point is NOT constant
82     z = np.zeros(q.shape, complex)
83     output = np.resize(np.array(0,), q.shape)
84     for i in range(len(q)):
85         for iter in range(maxiter):
86             z[i] = z[i]*z[i] + q[i]
87             if abs(z[i]) > 2.0:
88                 output[i] = iter
89                 break
90     return output
91
92
93 def calc_fractal_numpy(q, maxiter):
94     # calculate z using numpy, this is the original
95     # routine from vegaseat's URL
96     output = np.resize(np.array(0,), q.shape)
97     z = np.zeros(q.shape, np.complex64)
98
99     for it in range(maxiter):
100         z = z*z + q
101         done = np.greater(abs(z), 2.0)
102         q = np.where(done, 0+0j, q)
103         z = np.where(done, 0+0j, z)
104         output = np.where(done, it, output)
105     return output
106
107 # choose your calculation routine here by uncommenting one of the options
108 calc_fractal = calc_fractal_opencl
109 # calc_fractal = calc_fractal_serial
110 # calc_fractal = calc_fractal_numpy
111
112 if __name__ == '__main__':
113     try:
114         import six.moves.tkinter as tk
115     except ImportError:
116         # Python 3
117         import tkinter as tk
118     from PIL import Image, ImageTk
119
120     class Mandelbrot(object):
121         def __init__(self):
122             # create window
123             self.root = tk.Tk()
124             self.root.title("Mandelbrot Set")
125             self.create_image()
126             self.create_label()
127             # start event loop
128             self.root.mainloop()
129
130         def draw(self, x1, x2, y1, y2, maxiter=30):
131             # draw the Mandelbrot set, from numpy example
132             xx = np.arange(x1, x2, (x2-x1)/w)
133             yy = np.arange(y2, y1, (y1-y2)/h) * 1j
134             q = np.ravel(xx+yy[:, np.newaxis]).astype(np.complex64)
135
136             start_main = time.time()
137             output = calc_fractal(q, maxiter)
138             end_main = time.time()
139
140             secs = end_main - start_main
141             print("Main took", secs)
142
143             self.mandel = (output.reshape((h, w)) /
144                            float(output.max()) * 255.).astype(np.uint8)
145
146         def create_image(self):
147             """"
148             create the image from the draw() string
149             """
150             # you can experiment with these x and y ranges
151             self.draw(-2.13, 0.77, -1.3, 1.3)
152             self.im = Image.fromarray(self.mandel)
153             self.im.putpalette([i for rgb in ((j, 0, 0) for j in range(255))
154                                 for i in rgb])
155
156         def create_label(self):
157             # put the image on a label widget
158             self.image = ImageTk.PhotoImage(self.im)
159             self.label = tk.Label(self.root, image=self.image)
160             self.label.pack()
161
162     # test the class
163     test = Mandelbrot()