Monday, March 2, 2009

Applying a Tracing JIT to an Interpreter

After I had failed once more to explain to someone on IRC what the idea behind the current JIT generator work of PyPy is, I decided to just write a blog post to explain it. Here it is :-). The post turned out to be a bit long, so please bear with me.

The goal of the post is to give an understanding of how PyPy's JIT generator is going to work. To do this, I will look at what happens when you write an interpreter in Java and apply a completely normal tracing JIT to it (for this reason all the code examples will be in some sort of pseudo-Java). The machine code generated this way is poor, so I will then explain a way to fix the problem.

The techniques I describe here are conceptually similar to what we are doing in PyPy. The details (as usual) are different. The reason I am explaining things this way is that I can start from tracing JITs, which are a known, existing technique.

To understand the following, it helps to already know a bit about how a normal tracing JIT works. I will give a reminder of how one works, but a couple of more thorough introductions already exist on the web. I will also leave out many details of the inner workings of tracing JITs and only explain the things that are relevant to what I am trying to get to here.

Tracing JITs

Tracing JITs are an idea explored by the Dynamo project in the context of dynamic optimization of machine code at runtime. The techniques were then successfully applied to Java VMs and are now being used by Mozilla's TraceMonkey JavaScript VM. They are built on some basic assumptions:

  • programs spend most of their runtime in loops
  • several iterations of the same loop are likely to take similar code paths
  • the best way to gain information about the behaviour of a program is to observe it

The basic approach of a tracing JIT is to only generate machine code for commonly executed loops and to interpret the rest of the program. The code for those common loops however should be highly optimized, including aggressive inlining.

The generation of loops works as follows: At first, everything is interpreted. The interpreter does a bit of lightweight profiling to figure out which loops are run often. When a common loop is identified, the interpreter enters a special mode (called tracing mode). When in tracing mode, the interpreter records a history (the trace) of all the operations it executes, in addition to actually performing the operations. During tracing, the tracer repeatedly checks whether the interpreter has reached a position in the program that it had already seen earlier in the trace. If this happens, the recorded trace corresponds to a loop in the program that the tracing interpreter is running. At this point, the loop is turned into machine code by taking the trace and emitting a machine code version of every operation in it.
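The "lightweight profiling" step can be pictured as a counter per loop header. The following is only a minimal sketch under assumptions that are not from this post: the class name, the method name and the threshold value are all made up for illustration, and real tracing JITs use their own heuristics.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch of counter-based hot-loop detection.
final class HotLoopProfilerSketch {
    // Assumed threshold, chosen arbitrarily for the example.
    static final int THRESHOLD = 3;

    // How often each loop header (identified by its pc) has been reached.
    private final Map<Integer, Integer> counts = new HashMap<>();

    // Meant to be called by the interpreter on every backward jump.
    // Returns true exactly once per loop: at the moment it becomes "hot".
    boolean shouldStartTracing(int loopHeaderPc) {
        int count = counts.merge(loopHeaderPc, 1, Integer::sum);
        return count == THRESHOLD;
    }
}
```

With this sketch, the third backward jump to the same loop header would switch the interpreter into tracing mode; all other jumps are just counted.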

This process assumes that the path through the loop that was traced is a "typical" example of possible paths (which is statistically likely). Of course, a later execution may take another path through the loop, so the machine code contains guards, which check that the path is still the same. If a guard fails during execution of the machine code, the machine code is left and execution falls back to interpretation (there are more complex mechanisms in place to still produce more code for the cases of guard failures, but they are of no importance for this post).

It is important to understand when the tracer considers a loop in the trace to be closed. This happens when the position key is the same as at an earlier point. The position key describes the position of the execution of the program; it usually contains things like the function currently being executed and the program counter position of the tracing interpreter.
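The loop-closing check can be sketched as a lookup of position keys that were already seen. This is an illustration only: the class and its methods are hypothetical, and the position key is simplified to a string built from a function name and a program counter.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of how a tracer could detect that a loop is closed.
final class LoopClosingSketch {
    // The recorded operations, in execution order.
    private final List<String> trace = new ArrayList<>();
    // Maps each position key to the trace index at which it was first seen.
    private final Map<String, Integer> firstSeenAt = new HashMap<>();

    // Called before each traced operation. Returns the trace index where
    // the loop starts if the position key repeats (the loop is closed),
    // or -1 if tracing should simply continue.
    int step(String function, int pc, String operation) {
        String positionKey = function + "@" + pc;
        Integer loopStart = firstSeenAt.get(positionKey);
        if (loopStart != null) {
            return loopStart;
        }
        firstSeenAt.put(positionKey, trace.size());
        trace.add(operation);
        return -1;
    }

    List<String> trace() {
        return trace;
    }
}
```

Everything that follows in this post is about what goes into that key, because the key alone decides where a trace ends.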

Let's look at a small example. Take the following code:

int sum_1_to_n(int n) {
    int result = 0;
    while (n >= 0) {
        result += n;
        n -= 1;
    }
    return result;
}

The tracing JIT will at one point trace the execution of the while loop in sum_1_to_n. The trace might look as follows:

guard_true(n >= 0);
result += n;
n -= 1;
<loop_back>

This trace will then be turned into machine code. Note that the machine code loop is by itself infinite and can only be left via a guard failure.
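To make the "infinite loop that is only left via a guard failure" concrete, here is a rough Java stand-in for the machine code of the trace. The State class and the runTrace method are hypothetical helpers introduced for this sketch; in a real tracing JIT the guard failure would hand control back to the interpreter instead of returning from a method.

```java
// Hypothetical Java stand-in for the compiled trace of sum_1_to_n.
final class SumTraceSketch {
    // The live variables of the loop, handed back when a guard fails.
    static final class State {
        int n;
        int result;

        State(int n, int result) {
            this.n = n;
            this.result = result;
        }
    }

    // The "machine code": an infinite loop with a single guard as its exit.
    static State runTrace(State s) {
        while (true) {
            if (!(s.n >= 0)) {
                return s; // guard_true(n >= 0) failed: leave the machine code
            }
            s.result += s.n;
            s.n -= 1;
            // <loop_back>
        }
    }

    static int sum_1_to_n(int n) {
        State s = runTrace(new State(n, 0));
        // The fallback would resume after the while loop, which only returns.
        return s.result;
    }
}
```

For example, SumTraceSketch.sum_1_to_n(10) leaves the loop when n reaches -1 and returns 55.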

A slightly more complex example:

int f(int a, int b) {
    if (b % 46 == 41)
        return a - b;
    else
        return a + b;
}

int strange_sum(int n) {
    int result = 0;
    while (n >= 0) {
        result = f(result, n);
        n -= 1;
    }
    return result;
}

The trace of the loop in strange_sum would maybe look like this:

guard_true(n >= 0);
a = result;
b = n;
guard_false(b % 46 == 41);
result = a + b;
n -= 1;
<loop_back>

This would then be turned into machine code. Note how f was inlined into the loop and how the common else case was turned into machine code, while the other one is implemented via a guard failure.
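The interplay between the compiled trace and the fallback can again be sketched in plain Java. This is an assumption-laden illustration: strangeSumTraced is a made-up name, the inner loop plays the role of the machine code, and the two lines after it play the role of the interpreter, which executes the rare case and then re-enters the trace.

```java
// Hypothetical sketch: the strange_sum trace plus its fallback path.
final class StrangeSumSketch {
    static int f(int a, int b) {
        if (b % 46 == 41)
            return a - b;
        else
            return a + b;
    }

    // Reference version, as an ordinary interpreter would execute it.
    static int strangeSumInterpreted(int n) {
        int result = 0;
        while (n >= 0) {
            result = f(result, n);
            n -= 1;
        }
        return result;
    }

    static int strangeSumTraced(int n) {
        int result = 0;
        while (true) {
            // Stand-in for the machine code of the trace, with f inlined.
            while (true) {
                if (!(n >= 0)) {
                    return result; // guard_true(n >= 0) failed: loop is done
                }
                int a = result;
                int b = n;
                if (b % 46 == 41) {
                    break; // guard_false(b % 46 == 41) failed
                }
                result = a + b;
                n -= 1;
                // <loop_back>
            }
            // Fallback: run this one iteration the slow way, then re-enter.
            result = f(result, n);
            n -= 1;
        }
    }
}
```

Both versions compute the same result; the traced one only takes the slow path for the occasional n with n % 46 == 41.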

Applying a Tracing JIT to an Interpreter

In the rest of the post we will explore what happens when the program that is being executed/compiled by the tracing JIT is itself a (bytecode) interpreter for another language.

A stylized bytecode interpreter for a simple programming language could look as follows:

W_Object interpret(String bytecode, ...) {
    Stack<W_Object> stack = new Stack<W_Object>();
    int pc = 0;
    while (true) { // bytecode dispatch loop
        char instruction = bytecode.charAt(pc);
        pc += 1;
        switch (instruction) {
            case ADD: {
                W_Object arg2 = stack.pop();
                W_Object arg1 = stack.pop();
                stack.push(do_addition(arg1, arg2));
                break;
            }
            case SUB: {
                W_Object arg2 = stack.pop();
                W_Object arg1 = stack.pop();
                stack.push(do_subtraction(arg1, arg2));
                break;
            }
            case RETURN:
                return stack.pop();
            case JUMP_BACKWARD:
                pc -= (int)bytecode.charAt(pc);
                break;
            case LOAD_INTEGER:
                int value = (int)bytecode.charAt(pc);
                pc += 1;
                stack.push(new W_Integer(value));
                break;
            case PRINT:
                do_print(stack.pop());
                break;
            case DUP:
                stack.push(stack.peek());
                break;
            case JUMP_IF_TRUE:
                ...
            ...
        }
    }
}

If we apply a tracing JIT to this function, it will trace and compile the execution of one bytecode, because after one bytecode the bytecode dispatch loop is closed. E.g. it might trace and produce machine code for the execution of a SUB. (Sidenote: this interpret function is an example where one of the assumptions of a tracing JIT breaks down: two iterations of the bytecode dispatch loop are rarely going to follow the same code path, because usually two consecutive bytecodes encode different instructions).

The important bit to remember here is that the tracing JIT will produce a machine code loop that corresponds to the bytecode dispatch loop in the interpret function. Let's see how we can change that.

Improving the Generated Code

If we want to make use of the fact that the program that is being jitted is itself an interpreter, we need to change the tracing JIT a bit. To be more precise we add a way for the user of the tracing JIT to add information to the position key that the tracing JIT uses to decide when a loop is closed. This is done by a call to a magic function add_to_position_key. This allows the program writer to influence the tracing JIT's behaviour.

The semantics of add_to_position_key are as follows: The method itself does not do anything. It has an effect only when it is seen during tracing. If it is seen during tracing, the tracer adds the argument of the call to the position key that the tracer is using to find out whether a loop was closed or not.
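These semantics can be approximated in a few lines. The sketch below is a simplification and not how a real tracer would do it: a real tracer would notice the call while recording operations, whereas here a hypothetical static field currentKey stands in for "the tracer is active".

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of the add_to_position_key semantics.
final class PositionKeyHookSketch {
    // Non-null only while the (imagined) tracer is recording.
    static List<Object> currentKey = null;

    // Does nothing during normal execution. During tracing, the argument
    // becomes part of the position key of the current loop iteration.
    static void add_to_position_key(Object value) {
        if (currentKey != null) {
            currentKey.add(value);
        }
    }
}
```

Outside of tracing, a call such as PositionKeyHookSketch.add_to_position_key(pc) has no observable effect, which is what lets it be sprinkled into an interpreter without changing the interpreter's behaviour.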

In the example of the interpret function above, we would add a call to this function into the while loop as follows:

W_Object interpret(String bytecode, ...) {
    Stack<W_Object> stack = new Stack<W_Object>();
    int pc = 0;
    while (true) { // bytecode dispatch loop
        add_to_position_key(pc);
        add_to_position_key(bytecode);
        char instruction = bytecode.charAt(pc);
        pc += 1;
        switch (instruction) {
            case ADD:
    ...

When the modified tracing JIT now traces the interpret function executing a SUB, something interesting happens. When the bytecode loop is closed, the modified tracing JIT does not consider the trace to be a loop, because the value of pc has been increased by one, so the position key differs. Instead it continues to trace, effectively unrolling the bytecode dispatch loop of interpret.

The only way for a loop to be considered closed is if the pc variable has the same value a second time. This can only happen after a JUMP_BACKWARD instruction has been executed. A JUMP_BACKWARD instruction will only be in the bytecode when the bytecode represents a loop. This means that the modified tracing JIT will trace the interpret function and will only consider that the trace represents a loop when the bytecode itself represents a loop! Thus, a machine code loop will eventually be created that corresponds to the loop in the bytecode.
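The difference between the two position keys can be simulated without a real JIT. The sketch below follows only the control flow of the dispatch loop (no stack) and counts how many bytecodes are traced before a position key repeats. The opcode values, the class name and the small LOOP bytecode are all invented for this illustration.

```java
import java.util.HashSet;
import java.util.Set;

// Hypothetical simulation of loop closing with and without pc in the key.
final class UnrollingSketch {
    // Made-up opcode values; the post does not define any.
    static final char LOAD_INTEGER = 1, DUP = 2, PRINT = 3, JUMP_BACKWARD = 4;

    // pc 0: LOAD_INTEGER 0, pc 2: DUP, pc 3: PRINT, pc 4: JUMP_BACKWARD 3
    static final String LOOP = new String(new char[] {
        LOAD_INTEGER, 0, DUP, PRINT, JUMP_BACKWARD, 3
    });

    // Returns the number of bytecodes traced until the tracer would
    // consider the loop closed.
    static int bytecodesUntilLoopClosed(String bytecode, int startPc, boolean pcInKey) {
        Set<String> seenKeys = new HashSet<>();
        int pc = startPc;
        int traced = 0;
        while (true) { // bytecode dispatch loop, control flow only
            // A plain tracing JIT only knows it is at the head of this loop.
            String key = "interpret:dispatch-loop";
            if (pcInKey) {
                // Effect of add_to_position_key(pc) and add_to_position_key(bytecode).
                key += ":pc=" + pc + ":bytecode=" + bytecode.hashCode();
            }
            if (!seenKeys.add(key)) {
                return traced; // position key repeated: loop closed
            }
            char instruction = bytecode.charAt(pc);
            pc += 1;
            switch (instruction) {
                case LOAD_INTEGER:
                    pc += 1; // skip the argument
                    break;
                case JUMP_BACKWARD:
                    pc -= (int) bytecode.charAt(pc);
                    break;
                default:
                    break; // one-byte instructions need no pc adjustment
            }
            traced += 1;
        }
    }
}
```

Without pc in the key, the count is 1: the trace covers a single bytecode, as in the previous section. With pc in the key and tracing started at pc 2, the count is 3, which is exactly the DUP, PRINT, JUMP_BACKWARD body of the bytecode-level loop.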

Let's look at an example. If we have a bytecode that corresponds to the following instructions:

pc |   instruction
---+---------------------
0  |  LOAD_INTEGER 0
2  |  DUP
3  |  PRINT
4  |  LOAD_INTEGER 1
6  |  ADD
7  |  JUMP_BACKWARD 6

This loop will print integers starting from 0 and going on from there. The modified tracing JIT will unroll the bytecode dispatch until it sees the JUMP_BACKWARD bytecode. After that bytecode the pc will be 2 again. Thus the earlier position key is repeated, which means that the loop will be closed. The produced machine code will do the equivalent of the following Java code:

...
guard_true(pc == 2);
guard_true(bytecode == "... correct bytecode string ...");
while (true) {
    instruction = bytecode.charAt(pc);
    pc += 1;
    guard_true(instruction == DUP);
    stack.push(stack.peek());

    instruction = bytecode.charAt(pc);
    pc += 1;
    guard_true(instruction == PRINT);
    do_print(stack.pop());

    instruction = bytecode.charAt(pc);
    pc += 1;
    guard_true(instruction == LOAD_INTEGER);
    value = (int)bytecode.charAt(pc);
    pc += 1;
    stack.push(new W_Integer(value));

    instruction = bytecode.charAt(pc);
    pc += 1;
    guard_true(instruction == ADD);
    arg2 = stack.pop();
    arg1 = stack.pop();
    stack.push(do_addition(arg1, arg2));

    instruction = bytecode.charAt(pc);
    pc += 1;
    guard_true(instruction == JUMP_BACKWARD);
    pc -= (int)bytecode.charAt(pc);
}

This is machine code that essentially does what the bytecode above did. Of course the code still contains some remnants of the interpreter (like the program counter manipulations, the stack handling, etc), which would have to be removed by some clever enough optimization step. If this were done, the result would look a lot more natural.
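As a guess at what "more natural" could mean, here is a hand-written version of the same loop with the program counter and the stack removed. This is not output of any optimizer and not shown in the original material: it ignores the W_Integer boxing and the dispatch inside do_addition, collects the values in a list instead of calling do_print, and takes a hypothetical limit parameter so that the demo terminates.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, hand-optimized equivalent of the unrolled bytecode loop.
final class OptimizedLoopSketch {
    static List<Integer> run(int limit) {
        List<Integer> printed = new ArrayList<>();
        int top = 0; // LOAD_INTEGER 0, kept in a local instead of on the stack
        while (printed.size() < limit) { // the real loop would be while (true)
            printed.add(top); // DUP + PRINT
            top = top + 1;    // LOAD_INTEGER 1 + ADD
        }                     // JUMP_BACKWARD
        return printed;
    }
}
```

Calling OptimizedLoopSketch.run(3) yields the list [0, 1, 2], the first three values the bytecode loop would print.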

Summary

If a tracing JIT is enhanced by a way to influence its loop-closing behaviour we can significantly improve its performance when the jitted program is itself an interpreter. The result is that in such a case the produced machine code will correspond to the functions that are being interpreted, not to the code of the interpreter itself.

Now, what does all this have to do with PyPy? What we have been working on for a while is a sort of tracing JIT for RPython that can be customized with a function very similar to the add_to_position_key described above. This will make it possible to make the tracing JIT generate code that corresponds to the code that the interpreter interprets. For example, we would add a call to add_to_position_key to SPy, PyPy's Smalltalk VM. Then the tracing JIT will produce machine code for Smalltalk-level loops, with all the usual benefits of a tracing JIT (like inlining of intermediate methods, constant-folding, ...). This JIT differs from normal tracing JITs in that it also supports very powerful constant-folding and allocation-removal optimizations. Those optimizations will (hopefully) be the content of a later blog post.

The basics of this process have been working fine for quite a while. The current work focuses on improving the optimizers so that they remove not only the bytecode manipulation code, but also the stack handling and a large number of other inefficiencies.

4 comments:

Benjamin Peterson said...

Wow, that's very cool! PyPy is an amazing novel project, but how did you guys ever think of this?

Unknown said...

great explanation.. thanks!

Anonymous said...

Very nice. You might want to have a look at Sullivan, et al (2003) Dynamic Native Optimization of Interpreters. They also identify the need to record not only the native PC but also the interpreter's virtual PC to identify useful trace heads. They provide three intrinsic functions (compared to your single add_to_position_key) to achieve this. Further, they provide three more intrinsics to support constant propagation.

Carl Friedrich Bolz-Tereick said...

Wow, that's extremely interesting! It's indeed very similar to what I describe in the blog post (apparently Armin knew of the paper, but I didn't). Of course they are severely hampered by the fact that the system is working on assembler level, so they don't really have enough information available to do really interesting optimizations.