Wednesday, April 6, 2011

Tutorial Part 2: Adding a JIT

This is the second part of a tutorial written by Andrew Brown. The first part described how to write an interpreter with PyPy.

Adding JIT

Translating RPython to C is pretty cool, but one of the best features of PyPy is its ability to generate just-in-time compilers for your interpreter. That's right, from just a couple of hints on how your interpreter is structured, PyPy will generate and include a JIT compiler that will, at runtime, translate the interpreted code of our BF language to machine code!

So what do we need to tell PyPy to make this happen? First it needs to know where the start of your bytecode evaluation loop is. This lets it keep track of instructions being executed in the target language (BF).

We also need to let it know what defines a particular execution frame. Since our language doesn't really have stack frames, this boils down to what's constant for the execution of a particular instruction, and what's not. These are called "green" and "red" variables, respectively.

Refer back to example2.py for the following.

In our main loop, there are four variables used: pc, program, bracket_map, and tape. Of those, pc, program, and bracket_map are all green variables. They define the execution of a particular instruction. If the JIT routines see the same combination of green variables as before, they know execution has jumped back and must be running a loop. The variable "tape" is our red variable: it's what's being manipulated by the execution.

So let's tell PyPy this info. Start by importing the JitDriver class and making an instance:

from pypy.rlib.jit import JitDriver
jitdriver = JitDriver(greens=['pc', 'program', 'bracket_map'],
        reds=['tape'])

And we add this line to the very top of the while loop in the mainloop function:

jitdriver.jit_merge_point(pc=pc, tape=tape, program=program,
        bracket_map=bracket_map)

We also need to define a JitPolicy. We're not doing anything fancy, so this is all we need somewhere in the file:

def jitpolicy(driver):
    from pypy.jit.codewriter.policy import JitPolicy
    return JitPolicy()

See this example at example3.py
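For readers without example3.py at hand, here is a rough sketch of how the three pieces fit together. It is not the actual file: the Tape class and parse function are simplified stand-ins for the ones from part one, the "." and "," instructions are left out, and the fallback JitDriver class is an assumption added only so the sketch also runs on a plain Python interpreter without a PyPy checkout.

```python
try:
    from pypy.rlib.jit import JitDriver
except ImportError:
    # Assumption: no PyPy source tree available, so use a do-nothing
    # stand-in. It lets the interpreter run untranslated on plain Python.
    class JitDriver(object):
        def __init__(self, **kwargs):
            pass

        def jit_merge_point(self, **kwargs):
            pass

jitdriver = JitDriver(greens=['pc', 'program', 'bracket_map'],
                      reds=['tape'])


def jitpolicy(driver):
    # Only called by the translator, so the import stays inside the function.
    from pypy.jit.codewriter.policy import JitPolicy
    return JitPolicy()


class Tape(object):
    """Simplified tape: a growable list of cells plus a position."""
    def __init__(self):
        self.thetape = [0]
        self.position = 0

    def get(self):
        return self.thetape[self.position]

    def inc(self):
        self.thetape[self.position] += 1

    def dec(self):
        self.thetape[self.position] -= 1

    def advance(self):
        self.position += 1
        if len(self.thetape) <= self.position:
            self.thetape.append(0)

    def devance(self):
        self.position -= 1


def parse(source):
    """Strip non-BF characters and record matching bracket positions."""
    program = []
    bracket_map = {}
    stack = []
    pc = 0
    for char in source:
        if char in "[]<>+-,.":
            program.append(char)
            if char == "[":
                stack.append(pc)
            elif char == "]":
                left = stack.pop()
                bracket_map[left] = pc
                bracket_map[pc] = left
            pc += 1
    return "".join(program), bracket_map


def mainloop(program, bracket_map):
    pc = 0
    tape = Tape()
    while pc < len(program):
        # The hint goes at the very top of the bytecode dispatch loop.
        jitdriver.jit_merge_point(pc=pc, tape=tape, program=program,
                                  bracket_map=bracket_map)
        code = program[pc]
        if code == ">":
            tape.advance()
        elif code == "<":
            tape.devance()
        elif code == "+":
            tape.inc()
        elif code == "-":
            tape.dec()
        elif code == "[" and tape.get() == 0:
            pc = bracket_map[pc]
        elif code == "]" and tape.get() != 0:
            pc = bracket_map[pc]
        # "." and "," are ignored in this sketch.
        pc += 1
    return tape


tape = mainloop(*parse("++[>+++<-]"))
print(tape.thetape)  # [0, 6]
```

Returning the tape from mainloop is a convenience for checking the result here; the real interpreter writes its output with the "." instruction instead.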

Now try translating again, but with the flag --opt=jit:

$ python ./pypy/pypy/translator/goal/translate.py --opt=jit example3.py

It will take significantly longer to translate with JIT enabled, almost 8 minutes on my machine, and the resulting binary will be much larger. When it's done, try having it run the mandelbrot program again. A world of difference: 12 seconds, compared to 45 seconds before!

Interestingly enough, you can see when the JIT compiler switches from interpreted to machine code with the mandelbrot example. The first few lines of output come out pretty fast, and then the program gets a boost of speed and gets even faster.

A bit about Tracing JIT Compilers

It's worth it at this point to read up on how tracing JIT compilers work. Here's a brief explanation: The interpreter is usually running your interpreter code as written. When it detects that a loop of code in the target language (BF) is executed often, that loop is considered "hot" and marked to be traced. The next time that loop is entered, the interpreter gets put in tracing mode where every executed instruction is logged.

When the loop is finished, tracing stops. The trace of the loop is sent to an optimizer, and then to an assembler which outputs machine code. That machine code is then used for subsequent loop iterations.

This machine code is often optimized for the most common case, and depends on several assumptions about the code. Therefore, the machine code will contain guards, to validate those assumptions. If a guard check fails, the runtime falls back to regular interpreted mode.
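The "hot loop" detection described above can be imitated in a few lines of plain Python. This is a toy model, not PyPy's implementation: the threshold value, the function names, and the idea of counting at backward jumps are all assumptions made for illustration. Since program and bracket_map stay fixed during one run, the sketch uses pc alone as the key where PyPy would use the full set of green variables.

```python
from collections import defaultdict

HOT_THRESHOLD = 3  # made-up value, chosen only to keep the example small


def build_bracket_map(program):
    """Map each bracket position to the position of its partner."""
    bracket_map = {}
    stack = []
    for pc, char in enumerate(program):
        if char == "[":
            stack.append(pc)
        elif char == "]":
            left = stack.pop()
            bracket_map[left] = pc
            bracket_map[pc] = left
    return bracket_map


def find_hot_loops(program):
    """Interpret a BF program (no I/O) and return the positions of the
    "[" instructions whose loops were jumped back to HOT_THRESHOLD times."""
    bracket_map = build_bracket_map(program)
    counts = defaultdict(int)
    hot = []
    tape, pos, pc = [0], 0, 0
    while pc < len(program):
        code = program[pc]
        if code == ">":
            pos += 1
            if pos == len(tape):
                tape.append(0)
        elif code == "<":
            pos -= 1
        elif code == "+":
            tape[pos] += 1
        elif code == "-":
            tape[pos] -= 1
        elif code == "[" and tape[pos] == 0:
            pc = bracket_map[pc]
        elif code == "]" and tape[pos] != 0:
            pc = bracket_map[pc]
            # A backward jump: we have seen this loop header before.
            counts[pc] += 1
            if counts[pc] == HOT_THRESHOLD:
                hot.append(pc)  # a real tracing JIT would start tracing here
        pc += 1
    return hot


print(find_hot_loops("+++++[>+<-]"))  # [5]
print(find_hot_loops("+[-]"))         # []
```

The first program jumps back to the "[" at position 5 four times, so it crosses the threshold; the second loop body runs only once and never jumps back.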

A good place to start for more information is http://en.wikipedia.org/wiki/Just-in-time_compilation

Debugging and Trace Logs

Can we do any better? How can we see what the JIT is doing? Let's do two things.

First, let's add a get_printable_location function, which is used during debug trace logging:

def get_location(pc, program, bracket_map):
    return "%s_%s_%s" % (
            program[:pc], program[pc], program[pc+1:]
            )
jitdriver = JitDriver(greens=['pc', 'program', 'bracket_map'], reds=['tape'],
        get_printable_location=get_location)

This function is passed in the green variables, and should return a string. Here, we're printing out the BF code, surrounding the currently executing instruction with underscores so we can see where it is.
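To see what the function returns, it can be called by hand on plain Python. The program string below is the (trimmed) one that appears in the debug_merge_point lines later in this post, and the empty dictionary is a placeholder, since get_location never looks at bracket_map.

```python
def get_location(pc, program, bracket_map):
    return "%s_%s_%s" % (
            program[:pc], program[pc], program[pc+1:]
            )

program = "+<[>[>+<-]>.[<+>-]<<-]++++++++++."
# bracket_map is unused by get_location, so an empty dict is enough here.
print(get_location(5, program, {}))
# +<[>[_>_+<-]>.[<+>-]<<-]++++++++++.
```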

Download this as example4.py and translate it the same as example3.py.

Now let's run a test program (test.b, which just prints the letter "A" 15 or so times in a loop) with trace logging:

$ PYPYLOG=jit-log-opt:logfile ./example4-c test.b

Now take a look at the file "logfile". This file is quite hard to read, so here's my best shot at explaining it.

The file contains a log of every trace that was performed, and is essentially a glimpse at what instructions it's compiling to machine code for you. It's useful to see if there are unnecessary instructions or room for optimization.

Each trace starts with a line that looks like this:

[3c091099e7a4a7] {jit-log-opt-loop

and ends with a line like this:

[3c091099eae17d] jit-log-opt-loop}

The line after the opening marker tells you which loop number it is, and how many ops are in it. In my case, the first trace looks like this:

 1  [3c167c92b9118f] {jit-log-opt-loop
 2  # Loop 0 : loop with 26 ops
 3  [p0, p1, i2, i3]
 4  debug_merge_point('+<[>[_>_+<-]>.[<+>-]<<-]++++++++++.', 0)
 5  debug_merge_point('+<[>[>_+_<-]>.[<+>-]<<-]++++++++++.', 0)
 6  i4 = getarrayitem_gc(p1, i2, descr=<SignedArrayDescr>)
 7  i6 = int_add(i4, 1)
 8  setarrayitem_gc(p1, i2, i6, descr=<SignedArrayDescr>)
 9  debug_merge_point('+<[>[>+_<_-]>.[<+>-]<<-]++++++++++.', 0)
10  debug_merge_point('+<[>[>+<_-_]>.[<+>-]<<-]++++++++++.', 0)
11  i7 = getarrayitem_gc(p1, i3, descr=<SignedArrayDescr>)
12  i9 = int_sub(i7, 1)
13  setarrayitem_gc(p1, i3, i9, descr=<SignedArrayDescr>)
14  debug_merge_point('+<[>[>+<-_]_>.[<+>-]<<-]++++++++++.', 0)
15  i10 = int_is_true(i9)
16  guard_true(i10, descr=<Guard2>) [p0]
17  i14 = call(ConstClass(ll_dict_lookup__dicttablePtr_Signed_Signed), ConstPtr(ptr12), 90, 90, descr=<SignedCallDescr>)
18  guard_no_exception(, descr=<Guard3>) [i14, p0]
19  i16 = int_and(i14, -9223372036854775808)
20  i17 = int_is_true(i16)
21  guard_false(i17, descr=<Guard4>) [i14, p0]
22  i19 = call(ConstClass(ll_get_value__dicttablePtr_Signed), ConstPtr(ptr12), i14, descr=<SignedCallDescr>)
23  guard_no_exception(, descr=<Guard5>) [i19, p0]
24  i21 = int_add(i19, 1)
25  i23 = int_lt(i21, 114)
26  guard_true(i23, descr=<Guard6>) [i21, p0]
27  guard_value(i21, 86, descr=<Guard7>) [i21, p0]
28  debug_merge_point('+<[>[_>_+<-]>.[<+>-]<<-]++++++++++.', 0)
29  jump(p0, p1, i2, i3, descr=<Loop0>)
30  [3c167c92bc6a15] jit-log-opt-loop}

I've trimmed the debug_merge_point lines a bit, since they were really long.
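Because every trace sits between an opening and a closing marker, the log is easy to split mechanically. The helper below is a small sketch of that idea. It assumes the log contains only the line shapes shown in this post (a hexadecimal timestamp in square brackets, the jit-log-opt-loop markers, and a "# Loop N : loop with K ops" header); the function names are my own and not part of PyPy.

```python
import re

# Patterns based only on the log lines shown in this post.
START = re.compile(r"^\[[0-9a-f]+\] \{jit-log-opt-loop$")
END = re.compile(r"^\[[0-9a-f]+\] jit-log-opt-loop\}$")
HEADER = re.compile(r"^# Loop (\d+) : loop with (\d+) ops$")


def split_traces(lines):
    """Return a list of traces, each a list of the lines between markers."""
    traces = []
    current = None
    for line in lines:
        line = line.strip()
        if START.match(line):
            current = []
        elif END.match(line):
            if current is not None:
                traces.append(current)
            current = None
        elif current is not None:
            current.append(line)
    return traces


def summarize(trace):
    """Return (loop number, ops claimed by the header, ops that are not
    debug_merge_point markers) for one trace."""
    match = HEADER.match(trace[0])
    if match is None:
        raise ValueError("unexpected trace header: %r" % trace[0])
    # trace[1] is the argument list, e.g. [p0, p1, i2, i3]
    real_ops = [op for op in trace[2:]
                if not op.startswith("debug_merge_point")]
    return int(match.group(1)), int(match.group(2)), len(real_ops)
```

Something like `[summarize(t) for t in split_traces(open("logfile"))]` then gives a quick overview of how much real work each trace contains once the debug_merge_point markers are set aside.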

So let's see what this does. This trace takes 4 parameters: 2 object pointers (p0 and p1) and 2 integers (i2 and i3). Looking at the debug lines, it seems to be tracing one iteration of this loop: "[>+<-]"

It starts executing the first operation on line 4, a ">", but immediately starts executing the next operation. The ">" had no instructions, and looks like it was optimized out completely. This loop must always act on the same part of the tape, so the tape pointer is constant for this trace and an explicit advance operation is unnecessary.

Lines 5 to 8 are the instructions for the "+" operation. First it gets the array item from the array in pointer p1 at index i2 (line 6), adds 1 to it and stores it in i6 (line 7), and stores it back in the array (line 8).

Line 9 starts the "<" instruction, but it is another no-op. It seems that i2 and i3, passed into this routine, are the two tape positions used in this loop, already calculated. We can also deduce that p1 is the tape array. It's not clear what p0 is.

Lines 10 through 13 perform the "-" operation: get the array value (line 11), subtract (line 12) and set the array value (line 13).

Next, on line 14, we come to the "]" operation. Lines 15 and 16 check whether i9 is true (non-zero). Looking up, i9 is the array value that we just decremented and stored, now being checked as the loop condition, as expected (remember the definition of "]"). Line 16 is a guard: if the condition is not met, execution jumps somewhere else, in this case to the routine called <Guard2>, and is passed one parameter: p0.

Assuming we pass the guard, lines 17 through 23 are doing the dictionary lookup to bracket_map to find where the program counter should jump to. I'm not too familiar with what the instructions are actually doing, but it looks like there are two external calls and 3 guards. This seems quite expensive, especially since we know bracket_map will never change (PyPy doesn't know that). We'll see below how to optimize this.

Line 24 increments the newly acquired instruction pointer. Lines 25 and 26 make sure it's less than the program's length.

Additionally, line 27 guards that i21, the incremented instruction pointer, is exactly 86. This is because it's about to jump to the beginning (line 29) and the instruction pointer being 86 is a precondition to this block.

Finally, the loop closes up at line 28 so the JIT can jump to loop body <Loop0> to handle that case (line 29), which is the beginning of the loop again. It passes in parameters (p0, p1, i2, i3).

Optimizing

As mentioned, every loop iteration does a dictionary lookup to find the corresponding matching bracket for the final jump. This is terribly inefficient, because the jump target is not going to change from one loop to the next. This information is constant and should be compiled in as such.

The problem is that the lookups are coming from a dictionary, and PyPy is treating it as opaque. It doesn't know the dictionary isn't being modified or isn't going to return something different on each query.

What we need to do is provide another hint to the translation to say that the dictionary query is a pure function, that is, its output depends only on its inputs and the same inputs should always return the same output.

To do this, we use a provided function decorator pypy.rlib.jit.purefunction, and wrap the dictionary call in a decorated function:

from pypy.rlib.jit import purefunction

@purefunction
def get_matching_bracket(bracket_map, pc):
    return bracket_map[pc]

This version can be found at example5.py
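The snippet above does not show the call site, so here is a sketch of how the bracket branches of the main loop might use the decorated function. The step_brackets helper is hypothetical and exists only to keep the example short; in the interpreter the same two branches would sit inside mainloop. The fallback decorator is also an assumption, there so the sketch runs on plain Python, where it simply leaves the function unchanged.

```python
try:
    from pypy.rlib.jit import purefunction
except ImportError:
    # Assumption: without PyPy available, fall back to a decorator that
    # does nothing. The hint only matters to the JIT generator.
    def purefunction(func):
        return func


@purefunction
def get_matching_bracket(bracket_map, pc):
    # Promise to the JIT: same arguments, same result. This holds because
    # bracket_map is never modified after parsing.
    return bracket_map[pc]


def step_brackets(code, cell, pc, bracket_map):
    """Hypothetical helper: return the next pc after a bracket instruction,
    given the value of the current tape cell."""
    if code == "[" and cell == 0:
        pc = get_matching_bracket(bracket_map, pc)
    elif code == "]" and cell != 0:
        pc = get_matching_bracket(bracket_map, pc)
    return pc + 1


bracket_map = {0: 2, 2: 0}                      # for the program "[-]"
print(step_brackets("]", 1, 2, bracket_map))    # 1: jump back into the loop
print(step_brackets("]", 0, 2, bracket_map))    # 3: fall out of the loop
```

The promise is only safe because bracket_map really is read-only once the program starts; decorating a lookup into a dictionary that changes would give the JIT wrong information.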

Translate again with the JIT option and observe the speedup. Mandelbrot now only takes 6 seconds! (from 12 seconds before this optimization)

Let's take a look at the trace from the same function:

[3c29fad7b792b0] {jit-log-opt-loop
# Loop 0 : loop with 15 ops
[p0, p1, i2, i3]
debug_merge_point('+<[>[_>_+<-]>.[<+>-]<<-]++++++++++.', 0)
debug_merge_point('+<[>[>_+_<-]>.[<+>-]<<-]++++++++++.', 0)
i4 = getarrayitem_gc(p1, i2, descr=<SignedArrayDescr>)
i6 = int_add(i4, 1)
setarrayitem_gc(p1, i2, i6, descr=<SignedArrayDescr>)
debug_merge_point('+<[>[>+_<_-]>.[<+>-]<<-]++++++++++.', 0)
debug_merge_point('+<[>[>+<_-_]>.[<+>-]<<-]++++++++++.', 0)
i7 = getarrayitem_gc(p1, i3, descr=<SignedArrayDescr>)
i9 = int_sub(i7, 1)
setarrayitem_gc(p1, i3, i9, descr=<SignedArrayDescr>)
debug_merge_point('+<[>[>+<-_]_>.[<+>-]<<-]++++++++++.', 0)
i10 = int_is_true(i9)
guard_true(i10, descr=<Guard2>) [p0]
debug_merge_point('+<[>[_>_+<-]>.[<+>-]<<-]++++++++++.', 0)
jump(p0, p1, i2, i3, descr=<Loop0>)
[3c29fad7ba32ec] jit-log-opt-loop}

Much better! Each loop iteration is an add, a subtract, two array loads, two array stores, and a guard on the exit condition. That's it! This code doesn't require any program counter manipulation.
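Read as ordinary Python, the 15-op trace amounts to roughly the following. This is a hand translation for illustration, not something PyPy produces: p0 is left out because its role is unclear, and a failing guard is modelled as a plain return where the real code would hand control back to the interpreter.

```python
def loop0(p1, i2, i3):
    """Hand-written rendering of the optimized trace for the BF loop
    "[>+<-]". p1 is the tape array, i2 and i3 are the two tape positions."""
    while True:
        i4 = p1[i2]          # getarrayitem_gc
        i6 = i4 + 1          # int_add
        p1[i2] = i6          # setarrayitem_gc
        i7 = p1[i3]          # getarrayitem_gc
        i9 = i7 - 1          # int_sub
        p1[i3] = i9          # setarrayitem_gc
        if not i9:           # guard_true(int_is_true(i9)) fails
            return           # stand-in for leaving the compiled code
        # jump(p0, p1, i2, i3, descr=<Loop0>): back to the top


tape = [3, 0]
loop0(tape, 1, 0)
print(tape)  # [0, 3]
```

Like the trace itself, the function assumes it is entered with a non-zero value at position i3, since the "[" check has already happened by the time the loop body starts.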

I'm no expert on optimizations; this tip was suggested by Armin Rigo on the pypy-dev list. Carl Friedrich has a series of posts on how to optimize your interpreter that are also very useful: http://bit.ly/bundles/cfbolz/1

Final Words

I hope this has shown some of you what PyPy is all about other than a faster implementation of Python.

For those that would like to know more about how the process works, there are several academic papers explaining the process in detail that I recommend. In particular: Tracing the Meta-Level: PyPy's Tracing JIT Compiler.

See http://readthedocs.org/docs/pypy/en/latest/extradoc.html

9 comments:

Winston Ewert said...

Some interpreters are written to evaluate directly from the AST. i.e. they never generate bytecode, instead each node in the ast simply has the code to execute it as a "virtual" function. Could PyPy JIT such an interpreter? Or does it essentially assume a bytecode based interpreter?

Anonymous said...

In theory it should be able to, if it's written in RPython. Perhaps it would be harder to place the hints for the jit engine?

As far as I understand it, it still traces some kind of bytecode (generated from the RPython code), but uses the can_enter_jit hints to determine what to trace and the length of a trace.

If it'll be fast is another question though. Why not give it a try? (E.g. one could implement the LLVM kaleidoscope language in RPython.)

Maciej Fijalkowski said...

@Winston in theory nothing prevents JIT from working on AST-based interpreters. In practice however, it would require a bit of engineering to convince the JIT that the green (constant) argument is a complex object structure. That's however just engineering.

Carl Friedrich Bolz said...

It's actually not a problem at all to have an AST-based interpreter. In fact, the Prolog uses "ASTs" (Prolog is homoiconic, so the ASTs are just Prolog's normal data structures).

Maciej: that's not a problem if your ASTs are actually immutable. If they aren't you have a problem which indeed requires some engineering.

Quiz said...

The effect of the loop "[>+<-]" is

tape[position+1] += tape[position]
tape[position] = 0

We saw that PyPy can optimize the program counter away in this loop--but this loop could be executed in constant time. Will PyPy ever be able to optimize it to that degree?

Winston Ewert said...

Well, you finally motivated me to give it a try. I optimized the BF example and managed to get some pretty nice speed boosts all without dipping into the low level (aside from reading the log)

Anonymous said...

Great article, man! Many thanks and keep on rocking!

Anonymous said...

Great tutorial, but where can I find the 'test.b' file (mentioned for the tracing JIT) for a try?

Anonymous said...

hi guys. can jit merge points not be put inside methods? Going off example3.py, if I take the body of the while loop and move it into a method of the Tape class (along with the jitdriver), all the speed gains go away. can anyone explain why this happens? Thanks!