This is the second part of a tutorial written by Andrew Brown. The first
part described how to write an interpreter with PyPy.
Adding JIT
Translating RPython to C is pretty cool, but one of the best features of PyPy
is its ability to generate just-in-time compilers for your interpreter.
That's right: from just a couple of hints about how your interpreter is structured,
PyPy will generate and include a JIT compiler that will, at runtime, translate
the interpreted code of our BF language to machine code!
So what do we need to tell PyPy to make this happen? First it needs to know
where the start of your bytecode evaluation loop is. This lets it keep track of
instructions being executed in the target language (BF).
We also need to let it know what defines a particular execution frame. Since
our language doesn't really have stack frames, this boils down to what's
constant for the execution of a particular instruction, and what's not. These
are called "green" and "red" variables, respectively.
Refer back to example2.py for the following.
In our main loop, there are four variables used: pc, program, bracket_map, and
tape. Of those, pc, program, and bracket_map are all green variables. They
define the execution of a particular instruction. If the JIT routines see the
same combination of green variables as before, they know execution has jumped
back and must be in a loop. The variable "tape" is our red variable: it's
what's being manipulated by the execution.
So let's tell PyPy this info. Start by importing the JitDriver class and making
an instance:
from pypy.rlib.jit import JitDriver
jitdriver = JitDriver(greens=['pc', 'program', 'bracket_map'],
                      reds=['tape'])
And we add this line to the very top of the while loop in the mainloop
function:
jitdriver.jit_merge_point(pc=pc, tape=tape, program=program,
                          bracket_map=bracket_map)
We also need to define a JitPolicy. We're not doing anything fancy, so this is
all we need somewhere in the file:
def jitpolicy(driver):
    from pypy.jit.codewriter.policy import JitPolicy
    return JitPolicy()
See this example at example3.py
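Once the file imports from pypy.rlib.jit, it only runs where PyPy's source tree is importable. Here is a minimal sketch of one way to keep it runnable on a plain Python interpreter too. The do-nothing JitDriver stand-in is my assumption for illustration, not necessarily what example3.py does, and the Tape class, parse() and the trimmed-down mainloop() are simplified, hypothetical versions of the code from the first part (the "." and "," instructions are left out, and mainloop() returns the tape so the result can be inspected).

```python
try:
    from pypy.rlib.jit import JitDriver
except ImportError:
    # Assumption: PyPy's source tree is not importable, so substitute a
    # do-nothing stand-in that accepts the same calls used in this tutorial.
    class JitDriver(object):
        def __init__(self, **kw):
            pass

        def jit_merge_point(self, **kw):
            pass

jitdriver = JitDriver(greens=['pc', 'program', 'bracket_map'],
                      reds=['tape'])


class Tape(object):
    """Simplified, hypothetical stand-in for the tape object."""
    def __init__(self):
        self.thetape = [0]
        self.position = 0

    def get(self):
        return self.thetape[self.position]

    def inc(self):
        self.thetape[self.position] += 1

    def dec(self):
        self.thetape[self.position] -= 1

    def advance(self):
        self.position += 1
        if self.position == len(self.thetape):
            self.thetape.append(0)

    def devance(self):
        self.position -= 1


def parse(source):
    """Drop non-BF characters and pair up matching brackets."""
    program, bracket_map, stack = [], {}, []
    for char in source:
        if char not in '[]<>+-,.':
            continue
        if char == '[':
            stack.append(len(program))
        elif char == ']':
            left = stack.pop()
            bracket_map[left] = len(program)
            bracket_map[len(program)] = left
        program.append(char)
    return ''.join(program), bracket_map


def mainloop(program, bracket_map):
    pc = 0
    tape = Tape()
    while pc < len(program):
        # Very top of the loop: one call per BF instruction executed.
        jitdriver.jit_merge_point(pc=pc, tape=tape, program=program,
                                  bracket_map=bracket_map)
        code = program[pc]
        if code == '>':
            tape.advance()
        elif code == '<':
            tape.devance()
        elif code == '+':
            tape.inc()
        elif code == '-':
            tape.dec()
        elif code == '[' and tape.get() == 0:
            pc = bracket_map[pc]   # skip forward past the loop
        elif code == ']' and tape.get() != 0:
            pc = bracket_map[pc]   # jump back to the loop start
        pc += 1                    # "." and "," are left out of this sketch
    return tape
```

With the stand-in in place, mainloop(*parse("++[>+++<-]")).thetape gives [0, 6] on an ordinary Python interpreter. Note that every variable alive at the merge point (pc, program, bracket_map, tape) is listed as either green or red.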
Now try translating again, but with the flag --opt=jit:
$ python ./pypy/pypy/translator/goal/translate.py --opt=jit example3.py
It will take significantly longer to translate with JIT enabled, almost 8
minutes on my machine, and the resulting binary will be much larger. When it's
done, try having it run the mandelbrot program again. A world of difference:
12 seconds, compared to 45 seconds before!
Interestingly enough, you can see when the JIT compiler switches from
interpreted to machine code with the mandelbrot example. The first few lines of
output come out pretty fast, and then the program gets a noticeable boost of
speed.
A bit about Tracing JIT Compilers
It's worth it at this point to read up on how tracing JIT compilers work.
Here's a brief explanation: The interpreter is usually running your interpreter
code as written. When it detects a loop of code in the target language (BF) is
executed often, that loop is considered "hot" and marked to be traced. The next
time that loop is entered, the interpreter gets put in tracing mode where every
executed instruction is logged.
When the loop is finished, tracing stops. The trace of the loop is sent to an
optimizer, and then to an assembler which outputs machine code. That machine
code is then used for subsequent loop iterations.
This machine code is often optimized for the most common case, and depends on
several assumptions about the code. Therefore, the machine code will contain
guards, to validate those assumptions. If a guard check fails, the runtime
falls back to regular interpreted mode.
A good place to start for more information is
http://en.wikipedia.org/wiki/Just-in-time_compilation
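To make the hot-loop and tracing-mode steps concrete, here is a toy sketch in plain Python. It is not how PyPy does it: the real JIT records the low-level operations of your interpreter rather than BF characters, and the threshold of 3, the function names and the bookkeeping below are all made up for illustration. The sketch counts backward jumps per loop, switches to tracing mode once a loop is hot, and logs every instruction until execution arrives back at the loop start.

```python
def match_brackets(program):
    """Map each bracket position to the position of its partner."""
    bracket_map, stack = {}, []
    for pc, op in enumerate(program):
        if op == '[':
            stack.append(pc)
        elif op == ']':
            left = stack.pop()
            bracket_map[left], bracket_map[pc] = pc, left
    return bracket_map


def run_with_tracing(program, hot_threshold=3):
    """Interpret `program`; return (tape, first completed trace or None)."""
    bracket_map = match_brackets(program)
    tape, pos, pc = [0], 0, 0
    jumps = {}            # loop head -> backward jumps seen so far
    tracing_head = None   # head of the loop being traced, if any
    recording = []        # instructions logged while in tracing mode
    trace = None          # the first trace that was completed

    while pc < len(program):
        op = program[pc]
        if tracing_head is not None:
            recording.append(op)          # tracing mode: log everything
        if op == '>':
            pos += 1
            if pos == len(tape):
                tape.append(0)
        elif op == '<':
            pos -= 1
        elif op == '+':
            tape[pos] += 1
        elif op == '-':
            tape[pos] -= 1
        elif op == '[' and tape[pos] == 0:
            pc = bracket_map[pc]
        elif op == ']':
            head = bracket_map[pc] + 1    # first instruction of the loop body
            if tape[pos] != 0:
                jumps[head] = jumps.get(head, 0) + 1
                if head == tracing_head:
                    # Back at the start of the traced loop: trace is complete.
                    trace, tracing_head = recording, None
                elif (trace is None and tracing_head is None
                        and jumps[head] >= hot_threshold):
                    # The loop is hot: trace its next iteration.
                    tracing_head, recording = head, []
                pc = bracket_map[pc]
            elif head == tracing_head:
                tracing_head = None       # loop exited mid-trace: give up
        pc += 1
    return tape, trace
```

Running run_with_tracing("+++++[>+<-]") leaves the tape as [0, 5] and returns the trace ['>', '+', '<', '-', ']'], one full iteration of the loop body. A real tracing JIT would then optimize that trace and turn it into machine code protected by guards.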
Debugging and Trace Logs
Can we do any better? How can we see what the JIT is doing? Let's do two
things.
First, let's add a function to pass as the JitDriver's get_printable_location
argument, which is used during debug trace logging:
def get_location(pc, program, bracket_map):
    return "%s_%s_%s" % (
        program[:pc], program[pc], program[pc+1:]
    )
jitdriver = JitDriver(greens=['pc', 'program', 'bracket_map'], reds=['tape'],
                      get_printable_location=get_location)
This function is passed in the green variables, and should return a string.
Here, we're printing out the BF code, surrounding the currently executing
instruction with underscores so we can see where it is.
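To see what these strings look like without translating anything, the function can be called directly in plain Python. This is just a quick check; the sample program below is made up, and bracket_map is not used by the function, so an empty dict is passed.

```python
def get_location(pc, program, bracket_map):
    return "%s_%s_%s" % (
        program[:pc], program[pc], program[pc+1:]
    )

# Hypothetical sample program, only for illustration.
program = "+[>+<-]"
for pc in range(len(program)):
    print(get_location(pc, program, {}))
```

For pc = 2 this gives "+[_>_+<-]": the ">" at position 2 is the instruction about to execute.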
Download this as example4.py and translate it the same as example3.py.
Second, let's run a test program (test.b, which just prints the letter "A" 15 or
so times in a loop) with trace logging:
$ PYPYLOG=jit-log-opt:logfile ./example4-c test.b
Now take a look at the file "logfile". This file is quite hard to read, so
here's my best shot at explaining it.
The file contains a log of every trace that was performed, and is essentially a
glimpse at what instructions it's compiling to machine code for you. It's
useful to see if there are unnecessary instructions or room for optimization.
Each trace starts with a line that looks like this:
[3c091099e7a4a7] {jit-log-opt-loop
and ends with a line like this:
[3c091099eae17d] jit-log-opt-loop}
The line after the opening marker tells you which loop number it is, and how
many ops are in it.
In my case, the first trace looks like this:
 1  [3c167c92b9118f] {jit-log-opt-loop
 2  # Loop 0 : loop with 26 ops
 3  [p0, p1, i2, i3]
 4  debug_merge_point('+<[>[_>_+<-]>.[<+>-]<<-]++++++++++.', 0)
 5  debug_merge_point('+<[>[>_+_<-]>.[<+>-]<<-]++++++++++.', 0)
 6  i4 = getarrayitem_gc(p1, i2, descr=<SignedArrayDescr>)
 7  i6 = int_add(i4, 1)
 8  setarrayitem_gc(p1, i2, i6, descr=<SignedArrayDescr>)
 9  debug_merge_point('+<[>[>+_<_-]>.[<+>-]<<-]++++++++++.', 0)
10  debug_merge_point('+<[>[>+<_-_]>.[<+>-]<<-]++++++++++.', 0)
11  i7 = getarrayitem_gc(p1, i3, descr=<SignedArrayDescr>)
12  i9 = int_sub(i7, 1)
13  setarrayitem_gc(p1, i3, i9, descr=<SignedArrayDescr>)
14  debug_merge_point('+<[>[>+<-_]_>.[<+>-]<<-]++++++++++.', 0)
15  i10 = int_is_true(i9)
16  guard_true(i10, descr=<Guard2>) [p0]
17  i14 = call(ConstClass(ll_dict_lookup__dicttablePtr_Signed_Signed), ConstPtr(ptr12), 90, 90, descr=<SignedCallDescr>)
18  guard_no_exception(, descr=<Guard3>) [i14, p0]
19  i16 = int_and(i14, -9223372036854775808)
20  i17 = int_is_true(i16)
21  guard_false(i17, descr=<Guard4>) [i14, p0]
22  i19 = call(ConstClass(ll_get_value__dicttablePtr_Signed), ConstPtr(ptr12), i14, descr=<SignedCallDescr>)
23  guard_no_exception(, descr=<Guard5>) [i19, p0]
24  i21 = int_add(i19, 1)
25  i23 = int_lt(i21, 114)
26  guard_true(i23, descr=<Guard6>) [i21, p0]
27  guard_value(i21, 86, descr=<Guard7>) [i21, p0]
28  debug_merge_point('+<[>[_>_+<-]>.[<+>-]<<-]++++++++++.', 0)
29  jump(p0, p1, i2, i3, descr=<Loop0>)
30  [3c167c92bc6a15] jit-log-opt-loop}
I've trimmed the debug_merge_point lines a bit; they were really long.
So let's see what this does. This trace takes 4 parameters: 2 object pointers
(p0 and p1) and 2 integers (i2 and i3). Looking at the debug lines, it seems to
be tracing one iteration of this loop: "[>+<-]"
It starts executing the first operation on line 4, a ">", but immediately
starts executing the next operation. The ">" had no instructions, and looks
like it was optimized out completely. This loop must always act on the same
part of the tape, so the tape pointer is constant for this trace and an
explicit advance operation is unnecessary.
Lines 5 to 8 are the instructions for the "+" operation. First it gets the
array item from the array in pointer p1 at index i2 (line 6), adds 1 to it and
stores it in i6 (line 7), and stores it back in the array (line 8).
Line 9 starts the "<" instruction, but it is another no-op. It seems that i2
and i3, the values passed into this routine, are the two tape positions used
in this loop, already calculated. It also follows that p1 is the tape array.
It's not clear what p0 is.
Lines 10 through 13 perform the "-" operation: get the array value (line 11),
subtract (line 12) and set the array value (line 13).
Next, on line 14, we come to the "]" operation. Lines 15 and 16 check whether
i9 is true (non-zero). Looking up, i9 is the array value that we just
decremented and stored, now being checked as the loop condition, as expected
(remember the definition of "]"). Line 16 is a guard: if the condition is not
met, execution jumps somewhere else, in this case to the routine called
<Guard2>, which is passed one parameter: p0.
Assuming we pass the guard, lines 17 through 23 are doing the dictionary lookup
in bracket_map to find where the program counter should jump to. I'm not too
familiar with what the instructions are actually doing, but it looks like there
are two external calls and three guards. This seems quite expensive, especially
since we know bracket_map will never change (PyPy doesn't know that). We'll
see below how to optimize this.
Line 24 increments the newly acquired instruction pointer. Lines 25 and 26 make
sure it's less than the program's length.
Additionally, line 27 guards that i21, the incremented instruction pointer, is
exactly 86. This is because it's about to jump to the beginning (line 29) and
the instruction pointer being 86 is a precondition to this block.
Finally, the loop closes up at line 28 so the JIT can jump to loop body <Loop0>
to handle that case (line 29), which is the beginning of the loop again. It
passes in parameters (p0, p1, i2, i3).
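Since the log file gets long quickly, it can help to get an overview before reading individual traces. Here is a small sketch in plain Python that relies only on the markers described above: the opening "{jit-log-opt-loop" line, the "# Loop N : loop with M ops" header, and the closing "jit-log-opt-loop}" line. The function name and the idea of counting op lines are mine, and real log files may contain other kinds of sections that this sketch simply ignores.

```python
import re

# Matches the header line that follows the opening marker.
LOOP_HEADER = re.compile(r"# Loop (\d+) : loop with (\d+) ops")


def summarize_traces(log_text):
    """Return a list of (loop_number, declared_ops, counted_ops) tuples."""
    summaries = []
    header = None
    counted = 0
    inside = False
    for raw in log_text.splitlines():
        line = raw.strip()
        if line.endswith("{jit-log-opt-loop"):
            inside, header, counted = True, None, 0
        elif line.endswith("jit-log-opt-loop}"):
            if inside and header is not None:
                summaries.append((header[0], header[1], counted))
            inside = False
        elif inside:
            match = LOOP_HEADER.match(line)
            if match:
                header = (int(match.group(1)), int(match.group(2)))
            elif line and not line.startswith("["):
                # Everything except the "[p0, p1, ...]" argument list
                # is counted as one operation.
                counted += 1
    return summaries
```

Calling summarize_traces(open("logfile").read()) returns one tuple per loop. For the trace above, the 26 lines between the argument list and the closing marker match the "26 ops" in the header, debug_merge_point lines included.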
Optimizing
As mentioned, every loop iteration does a dictionary lookup to find the
matching bracket for the final jump. This is terribly inefficient: the jump
target is not going to change from one iteration to the next. This information
is constant and should be compiled in as such.
The problem is that the lookups are coming from a dictionary, and PyPy is
treating it as opaque. It doesn't know the dictionary isn't being modified or
isn't going to return something different on each query.
What we need to do is provide another hint to the translation to say that the
dictionary query is a pure function, that is, its output depends only on its
inputs and the same inputs should always return the same output.
To do this, we use a provided function decorator pypy.rlib.jit.purefunction,
and wrap the dictionary call in a decorated function:
from pypy.rlib.jit import purefunction

@purefunction
def get_matching_bracket(bracket_map, pc):
    return bracket_map[pc]
This version can be found at example5.py
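The text above doesn't show where the decorated function gets called, so here is a sketch of how the two bracket branches of the main loop might use it. The jump_target helper is hypothetical (in the real main loop the call would replace the two bracket_map[pc] lookups inline), and the pass-through purefunction fallback is my assumption so the snippet also runs where PyPy's source tree isn't importable.

```python
try:
    from pypy.rlib.jit import purefunction
except ImportError:
    # Assumption: outside of PyPy the hint has no effect, so a
    # pass-through decorator is enough to keep the code runnable.
    def purefunction(func):
        return func


@purefunction
def get_matching_bracket(bracket_map, pc):
    return bracket_map[pc]


def jump_target(code, pc, cell, bracket_map):
    """Hypothetical helper: where pc should point after a bracket at `pc`.

    `cell` is the value currently under the tape pointer.
    """
    if (code == '[' and cell == 0) or (code == ']' and cell != 0):
        # Same lookup as before, but now routed through the pure function.
        return get_matching_bracket(bracket_map, pc)
    return pc
```

The lookup can be treated as a constant because both arguments, bracket_map and pc, are green variables, which are fixed for any given position in the trace.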
Translate again with the JIT option and observe the speedup. Mandelbrot now
takes only 6 seconds, down from 12 seconds before this optimization!
Let's take a look at the trace from the same function:
[3c29fad7b792b0] {jit-log-opt-loop
# Loop 0 : loop with 15 ops
[p0, p1, i2, i3]
debug_merge_point('+<[>[_>_+<-]>.[<+>-]<<-]++++++++++.', 0)
debug_merge_point('+<[>[>_+_<-]>.[<+>-]<<-]++++++++++.', 0)
i4 = getarrayitem_gc(p1, i2, descr=<SignedArrayDescr>)
i6 = int_add(i4, 1)
setarrayitem_gc(p1, i2, i6, descr=<SignedArrayDescr>)
debug_merge_point('+<[>[>+_<_-]>.[<+>-]<<-]++++++++++.', 0)
debug_merge_point('+<[>[>+<_-_]>.[<+>-]<<-]++++++++++.', 0)
i7 = getarrayitem_gc(p1, i3, descr=<SignedArrayDescr>)
i9 = int_sub(i7, 1)
setarrayitem_gc(p1, i3, i9, descr=<SignedArrayDescr>)
debug_merge_point('+<[>[>+<-_]_>.[<+>-]<<-]++++++++++.', 0)
i10 = int_is_true(i9)
guard_true(i10, descr=<Guard2>) [p0]
debug_merge_point('+<[>[_>_+<-]>.[<+>-]<<-]++++++++++.', 0)
jump(p0, p1, i2, i3, descr=<Loop0>)
[3c29fad7ba32ec] jit-log-opt-loop}
Much better! Each loop iteration is an add, a subtract, two array loads, two
array stores, and a guard on the exit condition. That's it! This code doesn't
require any program counter manipulation.
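Read as ordinary Python, the optimized trace amounts to something like the sketch below. This is only my reading of the ops, not output from PyPy: p1 is taken to be a plain list, p0 is left out because it's not clear what it is, and a failed guard is modelled as simply returning to the caller.

```python
def loop0(tape, i2, i3):
    """Plain-Python rendering of the optimized trace for "[>+<-]"."""
    while True:
        i4 = tape[i2]      # getarrayitem_gc(p1, i2)
        i6 = i4 + 1        # int_add(i4, 1)
        tape[i2] = i6      # setarrayitem_gc(p1, i2, i6)
        i7 = tape[i3]      # getarrayitem_gc(p1, i3)
        i9 = i7 - 1        # int_sub(i7, 1)
        tape[i3] = i9      # setarrayitem_gc(p1, i3, i9)
        if not i9:         # guard_true(i10) fails: leave the compiled loop
            return
        # jump(p0, p1, i2, i3): go around again with the same arguments
```

For example, with tape = [3, 0], calling loop0(tape, 1, 0) leaves the tape as [0, 3]. Neither the program string nor bracket_map appears anywhere in the loop.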
I'm no expert on optimizations; this tip was suggested by Armin Rigo on the
pypy-dev list. Carl Friedrich has a series of posts on how to optimize your
interpreter that are also very useful: http://bit.ly/bundles/cfbolz/1
Final Words
I hope this has shown some of you what PyPy is all about other than a faster
implementation of Python.
For those that would like to know more about how the process works, there are
several academic papers explaining the process in detail that I recommend. In
particular: Tracing the Meta-Level: PyPy's Tracing JIT Compiler.
See http://readthedocs.org/docs/pypy/en/latest/extradoc.html