After failing once more to explain to someone on IRC what the idea behind
PyPy's current JIT generator work is, I decided to just write a blog post to
explain it. Here it is :-). The post turned out to be a bit long, so please bear
with me.
The goal of the post is to give an understanding of how PyPy's JIT generator is
going to work. To do this, I will look at what happens when you write an
interpreter in Java and apply a completely normal tracing JIT to it (for this
reason all the code examples will be in some sort of pseudo-Java). The
resulting machine code is inefficient, so I will explain a way to fix the
problem that occurs.
The techniques I describe here are conceptually similar to what we are doing in
PyPy. The details (as usual) are different. The reason I am explaining things
this way is that I can start from tracing JITs, which are a known, existing
technique.
To understand the following, it is helpful to know a bit about how a normal
tracing JIT works. I will give a short reminder, but there are also a couple
of more thorough introductions on the web already.
I will leave out many details of the inner workings of tracing JITs and only
explain the things that are relevant to what I am trying to get to here.
Tracing JITs are an idea explored by the Dynamo project in the context of
dynamic optimization of machine code at runtime. The techniques were then
successfully applied to Java VMs and are now being used by Mozilla's
TraceMonkey JavaScript VM. They are built on some basic assumptions:
- programs spend most of their runtime in loops
- several iterations of the same loop are likely to take similar code paths
- the best way to gain information about the behaviour of a program is to
observe it
The basic approach of a tracing JIT is to only generate machine code for
commonly executed loops and to interpret the rest of the program. The code for
those common loops, however, should be highly optimized, including aggressive
inlining.
The generation of loops works as follows: At first, everything is interpreted.
The interpreter does a bit of lightweight profiling to figure out which loops
are run often. When a common loop is identified, the interpreter enters a
special mode (called tracing mode). When in tracing mode, the interpreter
records a history (the trace) of all the operations it executes, in addition
to actually performing them. During tracing, the tracer repeatedly checks
whether the interpreter has reached a position in the program that it has
already seen earlier in the trace. If so, the recorded trace corresponds to a
loop in the program that the tracing interpreter is running. At this point, this loop
is turned into machine code by taking the trace and making machine code versions
of all the operations in it.
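To make this process concrete, here is a minimal sketch of such a profiling
and tracing driver, in the same sort of pseudo-Java as the rest of the post.
All names in it (HOT_THRESHOLD, Tracer, Trace, CompiledLoop,
compileToMachineCode and so on) are invented for illustration; real tracing
JITs structure this quite differently:

    // Hypothetical sketch of the profiling/tracing logic described above.
    class TracingDriver {
        static final int HOT_THRESHOLD = 1000; // arbitrary hotness cutoff

        Map<Object, Integer> counters = new HashMap<Object, Integer>();
        Map<Object, CompiledLoop> compiledLoops =
            new HashMap<Object, CompiledLoop>();

        // Called by the interpreter whenever execution reaches a position
        // that could be the header of a loop.
        void onPossibleLoopHeader(Object position) {
            CompiledLoop loop = compiledLoops.get(position);
            if (loop != null) {
                loop.execute(); // already compiled: run the machine code
                return;
            }
            Integer count = counters.get(position);
            int newCount = count == null ? 1 : count + 1;
            counters.put(position, newCount);
            if (newCount >= HOT_THRESHOLD) {
                // The loop is hot: enter tracing mode. The tracer records
                // all operations until the same position is reached again.
                Trace trace = new Tracer().traceUntilPositionRepeats(position);
                compiledLoops.put(position, compileToMachineCode(trace));
            }
        }
    }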
This process assumes that the path through the loop that was traced is a
"typical" example of possible paths (which is statistically likely). Of course
it is possible that later another path through the loop is taken, therefore the
machine code will contain guards, which check that the path is still the same.
If during execution of the machine code a guard fails, the machine code is left
and execution falls back to using interpretation (there are more complex
mechanisms in place to still produce more code for the cases of guard failures,
but they are of no importance for this post).
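Conceptually, a guard can be modelled in pseudo-Java as follows. This is a
simplification: GuardFailure is an invented stand-in, and real machine code
implements a guard as a conditional jump to a side exit rather than an
exception:

    // Sketch of what a guard conceptually does.
    class GuardFailure extends RuntimeException {}

    static void guard_true(boolean condition) {
        if (!condition)
            throw new GuardFailure(); // leave machine code, resume interpreting
    }

    static void guard_false(boolean condition) {
        guard_true(!condition);
    }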
It is important to understand when the tracer considers a loop in the trace to
be closed. This happens when the position key is the same as at an earlier
point. The position key describes the current position of the program's
execution; it usually contains things like the function currently being
executed and the program counter of the tracing interpreter.
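A position key can be thought of as a small value object, something like the
following sketch (the exact fields depend on the system being traced):

    // Sketch: two points during tracing are "the same position" if their
    // keys compare equal. The fields are just the typical examples above.
    class PositionKey {
        final Object function; // the function currently being executed
        final int pc;          // the program counter of the tracing interpreter

        PositionKey(Object function, int pc) {
            this.function = function;
            this.pc = pc;
        }

        public boolean equals(Object other) {
            if (!(other instanceof PositionKey))
                return false;
            PositionKey key = (PositionKey) other;
            return function.equals(key.function) && pc == key.pc;
        }

        public int hashCode() {
            return function.hashCode() * 31 + pc;
        }
    }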
Let's look at a small example. Take the following code:
int sum_1_to_n(int n) {
    int result = 0;
    while (n >= 0) {
        result += n;
        n -= 1;
    }
    return result;
}
The tracing JIT will at one point trace the execution of the while loop in
sum_1_to_n. The trace might look as follows:
guard_true(n >= 0);
result += n;
n -= 1;
<loop_back>
This trace will then be turned into machine code. Note that the machine code
loop is by itself infinite and can only be left via a guard failure.
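In pseudo-Java, the compiled loop behaves roughly like this (using the
guard_true sketch from above):

    // What the machine code conceptually does: loop forever, with the
    // failing guard as the only exit back to the interpreter.
    while (true) {
        guard_true(n >= 0); // fails once n becomes negative
        result += n;
        n -= 1;
    }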
A slightly more complex example:
int f(int a, int b) {
    if (b % 46 == 41)
        return a - b;
    else
        return a + b;
}

int strange_sum(int n) {
    int result = 0;
    while (n >= 0) {
        result = f(result, n);
        n -= 1;
    }
    return result;
}
The trace of the loop in strange_sum might look like this:
guard_true(n >= 0);
a = result;
b = n;
guard_false(b % 46 == 41);
result = a + b;
n -= 1;
<loop_back>
This would then be turned into machine code. Note how f was inlined into the
loop and how the common else case was turned into machine code, while the
other one is implemented via a guard failure.
In the rest of the post we will explore what happens when the program that is
being executed/compiled by the tracing JIT is itself a (bytecode) interpreter
for another language.
A stylized bytecode interpreter for a simple programming language could look as
follows:
W_Object interpret(String bytecode, ...) {
    Stack<W_Object> stack = new Stack<W_Object>();
    int pc = 0;
    while (true) { // bytecode dispatch loop
        char instruction = bytecode.charAt(pc);
        pc += 1;
        switch (instruction) {
            case ADD: {
                W_Object arg2 = stack.pop();
                W_Object arg1 = stack.pop();
                stack.push(do_addition(arg1, arg2));
                break;
            }
            case SUB: {
                W_Object arg2 = stack.pop();
                W_Object arg1 = stack.pop();
                stack.push(do_subtraction(arg1, arg2));
                break;
            }
            case RETURN:
                return stack.pop();
            case JUMP_BACKWARD:
                pc -= (int)bytecode.charAt(pc);
                break;
            case LOAD_INTEGER: {
                int value = (int)bytecode.charAt(pc);
                pc += 1;
                stack.push(new W_Integer(value));
                break;
            }
            case PRINT:
                do_print(stack.pop());
                break;
            case DUP:
                stack.push(stack.peek());
                break;
            case JUMP_IF_TRUE:
                ...
            ...
        }
    }
}
If we apply a tracing JIT to this function, it will trace and compile the
execution of one bytecode, because after one bytecode the bytecode dispatch loop
is closed. E.g. it might trace and produce machine code for the execution of a
SUB. (Sidenote: this interpret function is an example where one of the
assumptions of a tracing JIT breaks down: two iterations of the bytecode dispatch
loop are rarely going to follow the same code path, because usually two
consecutive bytecodes encode different instructions).
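For concreteness, the trace of one iteration of the dispatch loop executing a
SUB might look roughly like this (a sketch, in the same notation as the
traces above):
instruction = bytecode.charAt(pc);
pc += 1;
guard_true(instruction == SUB);
arg2 = stack.pop();
arg1 = stack.pop();
stack.push(do_subtraction(arg1, arg2));
<loop_back>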
The important bit to remember here is that the tracing JIT will produce a
machine code loop that corresponds to the bytecode dispatch loop in the
interpret function. Let's see how we can change that.
If we want to make use of the fact that the program that is being jitted is
itself an interpreter, we need to change the tracing JIT a bit. To be more
precise, we add a way for the user of the tracing JIT to add information to the
position key that the tracing JIT uses to decide when a loop is closed. This is
done by a call to a magic function add_to_position_key. This allows the
program writer to influence the tracing JIT's behaviour.
The semantics of add_to_position_key are as follows: The method itself does
not do anything. It has an effect only when it is seen during tracing. If it is
seen during tracing, the tracer adds the argument of the call to the position
key that the tracer is using to find out whether a loop was closed or not.
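Inside the tracer this could be handled along the following lines; a sketch
only, where Operation, trace and currentKey are invented names:

    // Sketch: the tracer special-cases calls to add_to_position_key.
    // Instead of recording the call as an operation, it extends the
    // position key used for the "have we been here before?" check.
    void traceOperation(Operation op) {
        if (op.isCallTo("add_to_position_key")) {
            currentKey.add(op.getArgumentValue());
            return; // the call itself leaves no operations in the trace
        }
        trace.record(op);
    }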
In the example of the interpret function above, we would add a call to this
function into the while loop as follows:
W_Object interpret(String bytecode, ...) {
    Stack<W_Object> stack = new Stack<W_Object>();
    int pc = 0;
    while (true) { // bytecode dispatch loop
        add_to_position_key(pc);
        add_to_position_key(bytecode);
        char instruction = bytecode.charAt(pc);
        pc += 1;
        switch (instruction) {
            case ADD:
                ...
Now when the modified tracing JIT traces the interpret function executing a
SUB, something interesting happens. When the bytecode dispatch loop comes
around again, the modified tracing JIT does not consider the trace to be a
loop, because the value of pc has increased by one, so the position key
differs. Instead it
continues to trace, effectively unrolling the bytecode dispatch loop of
interpret.
The only way for a loop to be considered closed is if the pc variable has
the same value a second time. This can only happen after a JUMP_BACKWARD
instruction has been executed. A JUMP_BACKWARD instruction will only be in
the bytecode when the bytecode represents a loop. This means that the modified
tracing JIT will trace the interpret function and will only consider that
the trace represents a loop when the bytecode itself represents a loop! Thus, a
machine code loop will eventually be created that corresponds to the loop in the
bytecode.
Let's look at an example. If we have a bytecode that corresponds to the
following instructions:
pc | instruction
---+---------------------
0 | LOAD_INTEGER 0
2 | DUP
3 | PRINT
4 | LOAD_INTEGER 1
6 | ADD
7 | JUMP_BACKWARD 6
This loop will print integers, counting up from 0. The modified tracing JIT
will unroll the bytecode dispatch loop until it sees the JUMP_BACKWARD
bytecode. After that bytecode the pc will be 2 again (the jump subtracts its
operand, 6, from the pc, which is 8 once the instruction byte has been read).
Thus the earlier position key is repeated, which means that the loop will be
closed.
The produced machine code will do the equivalent of the following Java code:
...
guard_true(pc == 2);
guard_true(bytecode == "... correct bytecode string ...");
while (true) {
    instruction = bytecode.charAt(pc);
    pc += 1;
    guard_true(instruction == DUP);
    stack.push(stack.peek());

    instruction = bytecode.charAt(pc);
    pc += 1;
    guard_true(instruction == PRINT);
    do_print(stack.pop());

    instruction = bytecode.charAt(pc);
    pc += 1;
    guard_true(instruction == LOAD_INTEGER);
    value = (int)bytecode.charAt(pc);
    pc += 1;
    stack.push(new W_Integer(value));

    instruction = bytecode.charAt(pc);
    pc += 1;
    guard_true(instruction == ADD);
    arg2 = stack.pop();
    arg1 = stack.pop();
    stack.push(do_addition(arg1, arg2));

    instruction = bytecode.charAt(pc);
    pc += 1;
    guard_true(instruction == JUMP_BACKWARD);
    pc -= (int)bytecode.charAt(pc);
}
This machine code essentially does what the bytecode above does. Of course
the code still contains some remnants of the interpreter (like the program
counter manipulations and the stack handling), which would have to be removed
by a sufficiently clever optimization step. If this were done, the result
would look a lot more natural.
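As a purely hypothetical illustration (not the output of any actual
optimizer), the loop for the example bytecode might then boil down to
something like this:

    // Sketch: the loop after the program counter and stack manipulation
    // have been removed; only the operations of the interpreted program
    // remain. W_Integer, do_print and do_addition are the interpreter's
    // helpers from above.
    W_Object x = new W_Integer(0); // from the LOAD_INTEGER 0 before the loop
    while (true) {
        do_print(x);
        x = do_addition(x, new W_Integer(1));
    }

Note that even this version allocates a new W_Integer on every iteration;
overhead of that kind is what the allocation-removal optimizations mentioned
below are meant to eliminate.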
If a tracing JIT is enhanced by a way to influence its loop-closing behaviour, we
can significantly improve its performance when the jitted program is itself an
interpreter. The result is that in such a case the produced machine code
will correspond to the functions that are being interpreted, not to the code of
the interpreter itself.
Now, what does all this have to do with PyPy? What we have been working on
for a while is a sort of tracing JIT for RPython that can be customized with
a function very similar to the add_to_position_key described above. This will
make it possible for the tracing JIT to generate code that corresponds to the
code that the interpreter interprets. For example, we would add a call to
add_to_position_key to SPy, PyPy's Smalltalk VM. Then the tracing JIT would
produce machine code for Smalltalk-level loops, with all the usual benefits of a
tracing JIT (like inlining of intermediate methods, constant-folding, ...).
This JIT differs from normal tracing JITs in that it also supports very powerful
constant-folding and allocation-removal optimizations. Those optimizations will
(hopefully) be the content of a later blog post.
The basics of this process have been working fine for quite a while. The
current work focuses on improving the optimizers to remove not only the
bytecode manipulation code but also the stack handling and a large number of
other inefficiencies.