Abstract interpretation in the Toy Optimizer
This is a cross-post by Max Bernstein from his excellent blog, where he writes about programming languages, compilers, optimizations, and virtual machines. He's looking for a (dynamic language runtime or compiler related) job too.
CF Bolz-Tereick wrote some excellent posts in which they introduce a small IR and optimizer and extend it with allocation removal. We also did a live stream together in which we did some more heap optimizations.
In this blog post, I'm going to write a small abstract interpreter for the Toy IR and then show how we can use it to do some simple optimizations. It assumes that you are familiar with the little IR, which I have reproduced unchanged in a GitHub Gist.
Abstract interpretation is a general framework for efficiently computing properties that must be true for all possible executions of a program. It's a widely used approach both in compiler optimizations and in offline static analysis for finding bugs. I'm writing this post to pave the way for CF's next post on proving abstract interpreters correct for range analysis and known bits analysis inside PyPy.
Before we begin, I want to note a couple of things:
- The Toy IR is in SSA form, which means that every variable is defined exactly once. This means that abstract properties of each variable are easy to track.
- The Toy IR represents a linear trace without control flow, meaning we won't talk about meet/join or fixpoints. They only make sense if the IR has a notion of conditional branches or back edges (loops).
Alright, let's get started.
Welcome to abstract interpretation
Abstract interpretation means a couple different things to different people. There's rigorous mathematical formalism thanks to Patrick and Radhia Cousot, our favorite power couple, and there's also sketchy hand-wavy stuff like what will follow in this post. In the end, all anyone is trying to do is reason about program behavior without running it.
In particular, abstract interpretation is an over-approximation of the behavior of a program. Correctly implemented abstract interpreters never lie, but they might be a little bit pessimistic. This is because instead of using real values and running the program---which would produce a concrete result and some real-world behavior---we "run" the program with a parallel universe of abstract values. This abstract run gives us information about all possible runs of the program.1
Abstract values always represent sets of concrete values. Instead of literally storing a set (in the world of integers, for example, it could get pretty big...there are a lot of integers), we group them into a finite number of named subsets.2
Let's learn a little about abstract interpretation with an example program and example abstract domain. Here's the example program:
```
v0 = 1
v1 = 2
v2 = add(v0, v1)
```
And our abstract domain is "is the number positive" (where "positive" means nonnegative, but I wanted to keep the words distinct):
```
        top
       /   \
positive   negative
       \   /
      bottom
```
The special top value means "I don't know" and the special bottom value means "empty set" or "unreachable". The positive and negative values represent the sets of all positive and negative numbers, respectively.
We initialize all the variables `v0`, `v1`, and `v2` to bottom and then walk our IR, updating our knowledge as we go.
```
# here
v0:bottom = 1
v1:bottom = 2
v2:bottom = add(v0, v1)
```
In order to do that, we have to have transfer functions for each operation. For constants, the transfer function is easy: determine if the constant is positive or negative. For other operations, we have to define a function that takes the abstract values of the operands and returns the abstract value of the result.
In order to be correct, transfer functions for operations have to be compatible with the behavior of their corresponding concrete implementations. You can think of them as having an implicit universal quantifier, forall, in front of them.
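To make that quantifier explicit: writing γ for the concretization function that maps an abstract value to the set of concrete values it stands for (standard notation from the abstract interpretation literature, not something in the Toy IR), a sound transfer function for addition has to satisfy:

```
for all abstract values a and b,
for all concrete x in γ(a) and y in γ(b):
    x + y is in γ(add#(a, b))        # add# is the abstract add
```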
Let's step through the constants at least:
```
v0:positive = 1
v1:positive = 2
# here
v2:bottom = add(v0, v1)
```
Now we need to figure out the transfer function for `add`. It's kind of tricky right now because we haven't specified our abstract domain very well. I keep saying "numbers", but what kinds of numbers? Integers? Real numbers? Floating point? Some kind of fixed-width bit vector (`int8`, `uint32`, ...) like an actual machine "integer"?

For this post, I am going to use the mathematical definition of integer, which means that the values are not bounded in size and therefore do not overflow. Actual hardware memory constraints aside, this is kind of like a Python `int`.
So let's look at what happens when we add two abstract numbers:
|          | top    | positive | negative | bottom |
|----------|--------|----------|----------|--------|
| top      | top    | top      | top      | bottom |
| positive | top    | positive | top      | bottom |
| negative | top    | top      | negative | bottom |
| bottom   | bottom | bottom   | bottom   | bottom |
As an example, let's try to add two numbers `a` and `b`, where `a` is positive and `b` is negative. We don't know anything about their values other than their signs. They could be `5` and `-3`, where the result is `2`, or they could be `1` and `-100`, where the result is `-99`. This is why we can't say anything about the result of this operation and have to return top.
The short of this table is that we only really know the result of an addition if both operands are positive or both operands are negative. Thankfully, in this example, both operands are known positive. So we can learn something about `v2`:
```
v0:positive = 1
v1:positive = 2
v2:positive = add(v0, v1)
# here
```
This may not seem useful in isolation, but analyzing more complex programs even with this simple domain may be able to remove checks such as `if (v2 < 0) { ... }`.
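As an aside, the whole addition table can be written down as a transfer function in a handful of lines. Here's a sketch in Python; this sign domain isn't part of the Toy IR code, so the names below are made up for illustration (the real implementation in this post uses the parity domain defined later):

```python
# Hypothetical transfer function for add over the positive/negative domain.
# It encodes the table above; remember that "positive" means nonnegative here.
SIGN_TOP, SIGN_POSITIVE, SIGN_NEGATIVE, SIGN_BOTTOM = "top", "positive", "negative", "bottom"

def sign_add(a, b):
    if a == SIGN_BOTTOM or b == SIGN_BOTTOM:
        return SIGN_BOTTOM  # bottom is contagious: one operand has no possible values
    if a == b:
        return a            # positive+positive, negative+negative, or top+top
    return SIGN_TOP         # mixed or partly unknown: we can't say anything

# Spot-checks against the table:
assert sign_add(SIGN_POSITIVE, SIGN_POSITIVE) == SIGN_POSITIVE
assert sign_add(SIGN_POSITIVE, SIGN_NEGATIVE) == SIGN_TOP
assert sign_add(SIGN_TOP, SIGN_NEGATIVE) == SIGN_TOP
```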
Let's take a look at another example using a sample `absval` (absolute value) IR operation:
```
v0 = getarg(0)
v1 = getarg(1)
v2 = absval(v0)
v3 = absval(v1)
v4 = add(v2, v3)
v5 = absval(v4)
```
Even though we have no constant/concrete values, we can still learn something about the states of values throughout the program. Since we know that `absval` always returns a positive number, we learn that `v2`, `v3`, and `v4` are all positive. This means that we can optimize out the `absval` operation on `v5`:
```
v0:top = getarg(0)
v1:top = getarg(1)
v2:positive = absval(v0)
v3:positive = absval(v1)
v4:positive = add(v2, v3)
v5:positive = v4
```
Other interesting lattices include:
- Constants (where the middle row is pretty wide)
- Range analysis (bounds on min and max of a number)
- Known bits (using a bitvector representation of a number, which bits are always 0 or 1)
For the rest of this blog post, we are going to do a very limited version of "known bits", called parity. This analysis only tracks the least significant bit of a number, which indicates if it is even or odd.
Parity
The lattice is pretty similar to the positive/negative lattice:
```
      top
     /   \
  even   odd
     \   /
    bottom
```
Let's define a data structure to represent this in Python code:
```python
class Parity:
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name
```
And instantiate the members of the lattice:
```python
TOP = Parity("top")
EVEN = Parity("even")
ODD = Parity("odd")
BOTTOM = Parity("bottom")
```
Now let's write a forward flow analysis of a basic block using this lattice.
We'll do that by assuming that a method on `Parity` is defined for each IR operation. For example, `Parity.add`, `Parity.lshift`, etc.
```python
def analyze(block: Block) -> None:
    parity = {v: BOTTOM for v in block}

    def parity_of(value):
        if isinstance(value, Constant):
            return Parity.const(value)
        return parity[value]

    for op in block:
        transfer = getattr(Parity, op.name)
        args = [parity_of(arg.find()) for arg in op.args]
        parity[op] = transfer(*args)
```
For every operation, we compute the abstract value---the parity---of the arguments and then call the corresponding method on `Parity` to get the abstract result.
We need to special-case `Constant`s due to a quirk of how the Toy IR is constructed: the constants don't appear in the instruction stream and instead are free-floating.
Let's start by looking at the abstraction function for concrete values---constants:
```python
class Parity:
    # ...
    @staticmethod
    def const(value):
        if value.value % 2 == 0:
            return EVEN
        else:
            return ODD
```
Seems reasonable enough. Let's pause on operations for a moment and consider an example program:
```
v0 = getarg(0)
v1 = getarg(1)
v2 = lshift(v0, 1)
v3 = lshift(v1, 1)
v4 = add(v2, v3)
v5 = dummy(v4)
```
This function (which is admittedly a little contrived) takes two inputs, shifts each left by one bit, adds the results, and then passes the sum into a `dummy` function, which you can think of as "return" or "escape".
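If you want to follow along at home, this program can be constructed with the `Block` builder from the gist, roughly like this (assuming `getarg`, `lshift`, `add`, and `dummy` are all registered as opbuilders there):

```python
# Sketch: building the example program with the toy IR's Block builder.
bb = Block()
v0 = bb.getarg(0)
v1 = bb.getarg(1)
v2 = bb.lshift(v0, 1)  # the builder wraps plain ints like 1 in a Constant
v3 = bb.lshift(v1, 1)
v4 = bb.add(v2, v3)
v5 = bb.dummy(v4)
```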
To do some abstract interpretation on this program, we'll need to implement the transfer functions for `lshift` and `add` (`dummy` will just always return `TOP`). We'll start with `add`. Remember that adding two even numbers returns an even number, adding two odd numbers returns an even number, and mixing even and odd returns an odd number.
```python
class Parity:
    # ...
    def add(self, other):
        if self is BOTTOM or other is BOTTOM:
            return BOTTOM
        if self is TOP or other is TOP:
            return TOP
        if self is EVEN and other is EVEN:
            return EVEN
        if self is ODD and other is ODD:
            return EVEN
        return ODD
```
We also need to fill in the other cases where the operands are top or bottom. In this case, they are both "contagious"; if either operand is bottom, the result is as well. If neither is bottom but either operand is top, the result is as well.
Now let's look at `lshift`. Shifting any number left by a non-zero number of bits will always result in an even number, but we need to be careful about the zero case! Shifting by zero doesn't change the number at all. Unfortunately, since our lattice has no notion of zero, we have to over-approximate here:
```python
class Parity:
    # ...
    def lshift(self, other):
        # self << other
        if other is ODD:
            return EVEN
        return TOP
```
This means that we will miss some opportunities to optimize, but it's a tradeoff that's just part of the game. (We could also add more elements to our lattice, but that's a topic for another day.)
Now, if we run our abstract interpretation, we'll collect some interesting properties about the program. If we temporarily hack on the internals of `bb_to_str`, we can print out parity information alongside the IR operations:
```
v0:top = getarg(0)
v1:top = getarg(1)
v2:even = lshift(v0, 1)
v3:even = lshift(v1, 1)
v4:even = add(v2, v3)
v5:top = dummy(v4)
```
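If you'd rather not touch `bb_to_str`, a small standalone printer works too. Here's a sketch; it assumes the `Operation`/`Constant` classes from the gist and that `analyze` is tweaked to return its `parity` dictionary:

```python
def print_with_parity(block: Block, parity: dict) -> None:
    # Name operations v0, v1, ... by position in the block, matching the dumps above.
    names = {op: f"v{i}" for i, op in enumerate(block)}

    def arg_to_str(arg):
        arg = arg.find()
        return str(arg.value) if isinstance(arg, Constant) else names[arg]

    for op in block:
        args = ", ".join(arg_to_str(arg) for arg in op.args)
        print(f"{names[op]}:{parity[op]} = {op.name}({args})")
```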
This is pretty awesome, because we can see that `v4`, the result of the addition, is always even. Maybe we can do something with that information.
Optimization
One way that a program might check if a number is odd is by checking the least significant bit. This is a common pattern in C code, where you might see code like `y = x & 1`. Let's introduce a `bitand` IR operation that acts like the `&` operator in C/Python. Here is an example of its use in our program:
```
v0 = getarg(0)
v1 = getarg(1)
v2 = lshift(v0, 1)
v3 = lshift(v1, 1)
v4 = add(v2, v3)
v5 = bitand(v4, 1)  # new!
v6 = dummy(v5)
```
We'll hold off on implementing the transfer function for it---that's left as an exercise for the reader---and instead do something different.
Instead, we'll see if we can optimize operations of the form `bitand(X, 1)`. If we statically know the parity as a result of abstract interpretation, we can replace the `bitand` with a constant `0` or `1`.
We'll first modify the `analyze` function (and rename it) to return a new `Block` containing optimized instructions:
```python
def simplify(block: Block) -> Block:
    parity = {v: BOTTOM for v in block}

    def parity_of(value):
        if isinstance(value, Constant):
            return Parity.const(value)
        return parity[value]

    result = Block()
    for op in block:
        # TODO: Optimize op
        # Emit
        result.append(op)
        # Analyze
        transfer = getattr(Parity, op.name)
        args = [parity_of(arg.find()) for arg in op.args]
        parity[op] = transfer(*args)
    return result
```
We're approaching this the way that PyPy does things under the hood, which is all in roughly a single pass. It tries to optimize an instruction away, and if it can't, it copies it into the new block.
Now let's add in the `bitand` optimization. It's mostly some gross-looking pattern matching that checks if the right hand side of a bitwise and operation is `1` (TODO: the left hand side, too). CF had some neat ideas on how to make this more ergonomic, which I might save for later.3 Then, if we know the parity, optimize the `bitand` into a constant.
```python
def simplify(block: Block) -> Block:
    parity = {v: BOTTOM for v in block}

    def parity_of(value):
        if isinstance(value, Constant):
            return Parity.const(value)
        return parity[value]

    result = Block()
    for op in block:
        # Try to simplify
        if isinstance(op, Operation) and op.name == "bitand":
            arg = op.arg(0)
            mask = op.arg(1)
            if isinstance(mask, Constant) and mask.value == 1:
                if parity_of(arg) is EVEN:
                    op.make_equal_to(Constant(0))
                    continue
                elif parity_of(arg) is ODD:
                    op.make_equal_to(Constant(1))
                    continue
        # Emit
        result.append(op)
        # Analyze
        transfer = getattr(Parity, op.name)
        args = [parity_of(arg.find()) for arg in op.args]
        parity[op] = transfer(*args)
    return result
```
Remember: because we use union-find to rewrite instructions in the optimizer (`make_equal_to`), later uses of the same instruction get the new optimized version "for free" (`find`).
Let's see how it works on our IR:
```
v0 = getarg(0)
v1 = getarg(1)
v2 = lshift(v0, 1)
v3 = lshift(v1, 1)
v4 = add(v2, v3)
v6 = dummy(0)
```
Hey, neat! `bitand` disappeared and the argument to `dummy` is now the constant `0` because we know the lowest bit.
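Putting the pieces together, an end-to-end run looks something like this. It assumes the gist's `Block` builder (with `bitand` registered as an opbuilder) and its `bb_to_str` printer, plus `Parity` methods for `getarg` and `dummy` that just return `TOP`, as the analysis output above implies:

```python
bb = Block()
v0 = bb.getarg(0)
v1 = bb.getarg(1)
v2 = bb.lshift(v0, 1)
v3 = bb.lshift(v1, 1)
v4 = bb.add(v2, v3)
v5 = bb.bitand(v4, 1)
v6 = bb.dummy(v5)

opt = simplify(bb)
print(bb_to_str(opt))  # the bitand is gone; dummy now takes the constant 0
```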
Wrapping up
Hopefully you have gained a little bit of an intuitive understanding of abstract interpretation. Last year, being able to write some code made me more comfortable with the math. Now being more comfortable with the math is helping me write the code. It's a nice upward spiral.
The two abstract domains we used in this post are simple and not very useful in practice, but it's possible to get very far using slightly more complicated abstract domains. Common domains include: constant propagation, type inference, range analysis, effect inference, liveness, etc. For example, here is a sample lattice for constant propagation:
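Roughly, it's shaped like this (the dots stand in for all the other individual constants):

```
                 top
              /      \
       negative      nonnegative
      /    |         /    |    \
... -2    -1        0     1     2 ...
      \    |         \    |    /
               bottom
```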
It has multiple levels to indicate more and less precision. For example, you might learn that a variable is either `1` or `2` and be able to encode that as `nonnegative` instead of just going straight to `top`.
Check out some real-world abstract interpretation in open source projects:
- Known bits in LLVM
- Constant range in LLVM
  - But I am told that the ranges don't form a lattice (see Interval Analysis and Machine Arithmetic: Why Signedness Ignorance Is Bliss)
- Tristate numbers for known bits in Linux eBPF
- Range analysis in Linux eBPF
- GDB prologue analysis of assembly to understand the stack and find frame pointers without using DWARF (some docs)
If you have some readable examples, please share them so I can add them here.
Acknowledgements
Thank you to CF Bolz-Tereick for the toy optimizer and helping edit this post!
1. In the words of abstract interpretation researchers Vincent Laviron and Francesco Logozzo in their paper Refining Abstract Interpretation-based Static Analyses with Hints (APLAS 2009):

   > The three main elements of an abstract interpretation are: (i) the abstract elements ("which properties am I interested in?"); (ii) the abstract transfer functions ("which is the abstract semantics of basic statements?"); and (iii) the abstract operations ("how do I combine the abstract elements?").

   We don't have any of these "abstract operations" in this post because there's no control flow, but you can read about them elsewhere!
2. These abstract values are arranged in a lattice, which is a mathematical structure with some properties, but the most important ones are that it has a top, a bottom, a partial order, a meet operation, and values can only move in one direction on the lattice.

   Using abstract values from a lattice promises two things:

   - The analysis will terminate
   - The analysis will be correct for any run of the program, not just one sample run
3. Something about `__match_args__` and `@property`...