Musings on Tracing in PyPy
Last summer, Shriram Krishnamurthi asked on Twitter:
"I'm curious what the current state of tracing JITs is. They used to be all the rage for a while, then I thought I heard they weren't so effective, then I haven't heard of them at all. Is the latter because they are ubiquitous, or because they proved to not work so well?"
I replied with my personal (pretty subjective) opinions about the question in a lengthy Twitter thread (which also spawned an even lengthier discussion). I wanted to turn what I wrote there into a blog post to make it more widely available (Twitter is no longer easily consumable without an account), and also because I'm mostly not using Twitter anymore. The blog post is still somewhat terse; I've written a short background section and tried to at least add links to further information. Please ask in the comments if something is particularly unclear.
Background
I'll explain a few of the terms that are central to the rest of the post. JIT compilers are compilers that do their work at runtime, interleaved with (or concurrently with) the execution of the program. There are (at least) two common general styles of JIT compiler architecture. The most common one is the method-based JIT, which compiles one method or function at a time. Then there are tracing JIT compilers, which generate code by tracing the execution of the user's program. They often focus on loops as their main unit of compilation.
Then there is the distinction between a "regular" JIT compiler and a meta-JIT. A regular JIT is built to compile one specific source language to machine code. A meta-JIT is a framework for building JIT compilers for a variety of different languages, reusing as much machinery as possible between the different implementations.
Personal and Project Context
Some personal context: my perspective is informed by nearly two decades of work on PyPy. PyPy's implementation language, RPython, has support for a meta-JIT, which allows us to reuse the JIT infrastructure for the various Python versions that we support (currently we release PyPy2.7 and PyPy3.10 together). Our meta-JIT infrastructure has also been used for experimental implementations of a variety of other languages and systems, such as:
- PyPy's regular expression engine
- RPySom, a tiny Smalltalk
- Ruby
- PHP
- Prolog
- Racket
- a database (SQLite)
- Lox, the language of Crafting Interpreters
- an ARM and RISC-V emulator
- and many more
These implementations have varying degrees of maturity, and many of them are research software that is no longer maintained.
PyPy aims to be extremely compatible with all the quirks of the Python language. We don't change the Python language to make things easier to compile, and we support the introspection and debugging features of Python. We try very hard to have no opinions on language design: the CPython core developers come up with the semantics, and we somehow deal with them.
Meta-tracing
PyPy started using a tracing JIT approach not because we thought method-based just-in-time compilers are bad. Historically, we had tried to implement a method-based meta-JIT using partial evaluation (we wrote three or four method-based prototypes, none of which were as good as we had hoped). After all those experiments failed, we switched to the tracing approach, and only at that point did our meta-JIT start producing interesting performance.
In the meta-JIT context, tracing has good properties, because its behavior is relatively understandable and it's easy(ish) to tweak how things work with extra annotations in the interpreter source.
Another reason why meta-tracing often works well for PyPy is that it can often slice through the complicated layers of Python quite effectively and remove a lot of overhead. Python is often described as simple, but I think that's actually a misconception. On the implementation level it's a very big and complicated language, and it also gains new features every year (the language is quite a bit more complicated than JavaScript, for example[1]).
Truffle
Later, Truffle came along and made a method-based meta-JIT using partial evaluation work. However, Truffle (and Graal) have had significantly more people working on them and much more money invested. In addition, Truffle initially required a quite specific style of AST-based interpreter (in the last few years it has also added support for bytecode-based interpreters).
It's still my impression that getting similar results with Truffle is more work for language implementers than with RPython, and that Truffle's warmup can often be pretty bad. But Truffle is definitely an existence proof that meta-JITs don't have to be based on tracing.
Tracing, the good
Let's now actually get to the heart of Shriram's question and discuss some of the advantages of tracing that go beyond how easy it is to use tracing for a meta-JIT.
Tracing allows for very aggressive partial inlining. Following just the hot path through many layers of abstraction is often really useful for generating fast code.
It's definitely possible to achieve the same effect in a method-based context with path splitting. But it requires a lot more implementation work and is not trivial, because the path execution counts of inlined functions can often be very call-site dependent. Tracing, on the other hand, gives you call-site dependent path splitting "for free".
(The aggressive partial inlining and path splitting are even more important in the meta-tracing context of PyPy, where some of the inlined layers are part of the language runtime, and where rare corner cases that are never executed in practice are everywhere.)
Another advantage of tracing is that it makes a number of optimizations really easy to implement, because there are (to a first approximation) no control flow merges. This makes all of our optimizations (super-)local optimizations that operate on a single (very long) basic block, which allows the JIT to do them in exactly one forward and one backward pass. An example is our allocation removal/partial escape analysis pass, which is quite simple, whereas the version for general control flow is much more complex, particularly in its handling of loops.
This ease of implementation of optimizations allowed us to implement some pretty decent optimizations. Our allocation removal, the way PyPy's JIT can reason about the heap, about dictionary accesses, about properties of functions of the runtime, about the range and known bits of integer variables is all quite solid.
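To make the "no control flow merges" argument concrete, here is a toy allocation-removal pass over a straight-line trace, written in plain Python. It is a minimal sketch, not PyPy's optimizer: the trace format (tuples such as ('new', v) or ('setfield', obj, field, value)) and the name remove_allocations are invented for illustration. Because the trace is a single basic block, one forward pass with a dictionary of not-yet-allocated ("virtual") objects is enough; an object is only materialized at the point where it escapes.

```python
def remove_allocations(trace):
    """One forward pass of allocation removal over a linear (branch-free) trace.

    Hypothetical trace format, for illustration only:
      ('new', v)                  allocate an object bound to v
      ('setfield', v, name, x)    v.name = x
      ('getfield', r, v, name)    r = v.name
      ('escape', v)               v is passed somewhere the optimizer cannot see
    """
    virtuals = {}  # variable -> {field: value} for allocations not emitted yet
    aliases = {}   # variable -> the value it was replaced with
    out = []

    def resolve(x):
        return aliases.get(x, x)

    def materialize(v):
        # v escapes: emit its allocation and field writes at this point.
        fields = virtuals.pop(v)
        out.append(('new', v))
        for name, x in sorted(fields.items()):
            if x in virtuals:
                materialize(x)  # a virtual stored in an escaping object escapes too
            out.append(('setfield', v, name, x))

    for op in trace:
        kind = op[0]
        if kind == 'new':
            virtuals[op[1]] = {}
        elif kind == 'setfield':
            obj, name, value = resolve(op[1]), op[2], resolve(op[3])
            if obj in virtuals:
                virtuals[obj][name] = value
            else:
                if value in virtuals:
                    materialize(value)
                out.append(('setfield', obj, name, value))
        elif kind == 'getfield':
            result, obj, name = op[1], resolve(op[2]), op[3]
            if obj in virtuals:
                aliases[result] = virtuals[obj][name]  # read from the virtual, no load
            else:
                out.append(('getfield', result, obj, name))
        elif kind == 'escape':
            v = resolve(op[1])
            if v in virtuals:
                materialize(v)
            out.append(('escape', v))
        else:
            raise ValueError('unknown operation %r' % (kind,))
    return out


trace = [
    ('new', 'p1'),                      # box an integer
    ('setfield', 'p1', 'value', 'i0'),
    ('getfield', 'i1', 'p1', 'value'),  # unbox it again
    ('new', 'p2'),                      # box the result
    ('setfield', 'p2', 'value', 'i1'),
    ('escape', 'p2'),                   # only the second box escapes
]
print(remove_allocations(trace))
# [('new', 'p2'), ('setfield', 'p2', 'value', 'i0'), ('escape', 'p2')]
```

A version of this pass for general control flow would additionally have to decide what happens to the virtual objects where two paths meet, which is one source of the extra complexity mentioned above.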
Tracing, the bad
Tracing also comes with a significant number of downsides. Probably the biggest one is that it tends to have big performance cliffs (PyPy certainly has them, and other tracing JITs such as TraceMonkey had them too). In my experience, the 'good' cases of tracing are really good, but when something goes wrong, performance can drop sharply, which is frustrating. With a simple method JIT the performance is often much more "even".
Another set of downsides is that tracing has a number of corner cases and "weird" behaviour in certain situations. This raises questions such as:
- When do you stop inlining?
- What happens when you trace recursion?
- What happens if your traces are consistently too long, even without inlining?
- and so on...
Some of these problems can be solved by adding heuristics to the tracing JIT, but doing so loses a lot of the simplicity of tracing again.
There are also some classes of programs that tend to perform quite poorly when executed by a tracing JIT: bytecode interpreters in particular, and other extremely unpredictable, branchy code. This is because the core assumption of a tracing JIT, that loops take similar control flow paths on every iteration, is just really wrong in the case of interpreters.
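A toy model can illustrate why interpreters violate this assumption. The sketch below records the branch outcomes of the first loop iteration as the "trace" and counts how many later iterations would take a different path. The opcode names, dispatch_path and guard_failure_rate are hypothetical simplifications, not how PyPy or any real tracing JIT records traces.

```python
def dispatch_path(opcode):
    """Branch outcomes taken by an if/elif dispatch chain for one opcode."""
    path = []
    for candidate in ('PUSH', 'ADD', 'JUMP', 'PRINT'):
        taken = opcode == candidate
        path.append(taken)
        if taken:
            break
    return tuple(path)


def guard_failure_rate(iterations):
    """Treat the first iteration's path as the trace; count later divergences."""
    traced = iterations[0]
    failures = sum(1 for path in iterations[1:] if path != traced)
    return float(failures) / (len(iterations) - 1)


# A simple counting loop takes the same path on every iteration.
counting_loop = [('i < n',)] * 100
# The dispatch loop of an interpreter written in the user's language
# sees a different opcode on most iterations.
bytecode = ['PUSH', 'PUSH', 'ADD', 'PRINT', 'JUMP'] * 20
dispatch_loop = [dispatch_path(op) for op in bytecode]

print(guard_failure_rate(counting_loop))  # 0.0
print(guard_failure_rate(dispatch_loop))
```

Here 60 of the 99 later iterations of the dispatch loop leave the recorded path; in a real tracing JIT, each such divergence would be a guard failure.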
Discussion
The Twitter thread spawned quite a bit of discussion; please look at the original thread for all of the comments. Here are three that I wanted to highlight:
"This is a really great summary. Meta-tracing is probably the one biggest success story. I think it has to do with how big and branchy the bytecode implementations are for typical dynamic languages; the trace captures latent type feedback naturally.
There is an upper limit, tho."
I agree with this completely. The complexity of Python bytecodes is a big factor in why meta-tracing works well for us. But Python also has many builtin types (collection types, types that form the meta-object protocol of Python, standard library modules implemented in C/RPython), and tracing operations on them is also very important for good performance.
"I think Mozilla had a blog post talking more about the difficulty with TraceMonkey, could only find this one: https://blog.mozilla.org/nnethercote/category/jagermonkey/"
"imo doing tracing for JS is really hard mode, because the browser is so incredibly warmup-sensitive. IIRC tracemonkey used a really low loop trip count (single-digit?) to decide when to start tracing (pypy uses >1000). the JS interpreters of the time were also quite slow."
In the meantime there have been some more reminiscences about tracing in JavaScript, by Shu-Yu Guo in a panel discussion and by Jason Orendorff on Mastodon.
"There are a number of corner cases you have to deal with in a tracing JIT. It's unfortunately not as simple and easy as the initial papers would have you believe. One example is how would you deal with a loop inside a loop? Is your tracing now recursive?
There's been some research work on trace stitching to deal with trace explosion but it does add complexity. With the increase in complexity, I think most industrial VM developers would rather pick tried-and-true method-based JITs that are well understood."
Conclusion
Given access to enough developers, and in the context of "normal" JITing (i.e. not meta-JITing), it's very unclear to me that you should use tracing. It makes more sense to spend effort on a solid control-flow-graph-based baseline and then try to get some of the good properties of tracing on top (path splitting, partial inlining, partial escape analysis, etc.).
For PyPy, with its meta-JIT (and given that we have neither much funding nor many people), I still think tracing was and is a relatively pragmatic choice. When I talked with Sam Tobin-Hochstadt about this topic recently, he characterized it like this: "tracing is a labor-saving device for compiler authors".
Performance-wise, PyPy is still quite hard to beat in the cases where it works well (i.e. pure Python code that doesn't use too many C modules, which are supported but slow in PyPy). In general, there are very few JITs for Python (particularly with the constraint of not being "allowed" to change the language); the most competitive other one is GraalPy, which is also based on a meta-JIT approach. Instagram runs on Cinder, and CPython has recently grown a JIT, which shipped in the 3.13 release, though only as an off-by-default build option. I'm very excited to see how Python's performance will develop in the coming years!
[1] A side point: people who haven't worked on Python tend to underestimate its complexity and pace of development. A pet peeve of mine is C++ compiler developers, static analysis people, JavaScript people and other well-meaning communities coming along with statements like "why don't you just..." 🤷♀️
Towards PyPy3.11 - an update
We are steadily working towards a Python 3.11 interpreter, which will be part of the upcoming PyPy 7.3.18 release. Along with that, we also recently updated speed.pypy.org to compare PyPy's performance to CPython 3.11 (it used to be CPython 3.7).
Guest Post: Final Encoding in RPython Interpreters
Introduction
This post started as a quick note summarizing a recent experiment I carried out upon a small RPython interpreter by rewriting it in an uncommon style. It is written for folks who have already written some RPython and want to take a deeper look at interpreter architecture.
Some experiments are about finding solutions to problems. This experiment is about taking a solution which is already well-understood and applying it in the context of RPython to find a new approach. As we will see, there is no real change in functionality or the number of clauses in the interpreter; it's more like a comparison between endo- and exoskeletons, a different arrangement of equivalent bones and plates.
Overview
An RPython interpreter for a programming language generally does three or four things, in order:
- Read and parse input programs
- Encode concrete syntax as abstract syntax
- Optionally, optimize or reduce the abstract syntax
- Evaluate the abstract syntax: read input data, compute, print output data, etc.
Today we'll look at abstract syntax. Most programming languages admit a concrete parse tree which is readily abstracted to provide an abstract syntax tree (AST). The AST is usually encoded with the initial style of encoding. An initial encoding can be transformed into any other encoding for the same AST, looks like a hierarchy of classes, and is implemented as a static structure on the heap.
In contrast, there is also a final encoding. Any other encoding of the same AST can be transformed into a final encoding; a final encoding looks like an interface for the actions of the interpreter, and is implemented as an unwinding structure on the stack. From the RPython perspective, Python builtin modules like os or sys are final encodings for features of the operating system; the underlying implementation is different when translated or untranslated, but the interface used to access those features does not change.
In RPython, an initial encoding is built from a hierarchy of classes. Each class represents a type of tree node, corresponding to a parser production in the concrete parse tree. Each class instance therefore represents an individual tree node. The fields of a class, particularly those filled during .__init__(), store pre-computed properties of each node; methods can be used to compute node properties on demand. This seems like an obvious and simple approach; what other approaches could there be? We need an example.
Final Encoding of Brainfuck
We will consider Brainfuck, a simple Turing-complete programming language. An example Brainfuck program might be:
[-]
This program is built from a loop and a decrement, and sets a cell to zero. In an initial encoding which follows the algebraic semantics of Brainfuck, the program could be expressed by applying class constructors to build a structure on the heap:
Loop(Plus(-1))
A final encoding is similar, except that class constructors are replaced by methods, the structure is built on the stack, and we are parameterized over the choice of class:
lambda cls: cls.loop(cls.plus(-1))
In ordinary Python, transforming between these would be trivial, and mostly is a matter of passing around the appropriate class. Indeed, initial and final encodings are equivalent; we'll return to that fact later. However, in RPython, all of the types must line up, and classes must be determined before translation. We'll need to monomorphize our final encodings, using some RPython tricks later on. Before that, let's see what an actual Brainfuck interface looks like, so that we can cover all of the difficulties with final encoding.
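In ordinary Python the equivalence can be sketched in a few lines. The node classes Loop and Plus come from the example above; the Initial and AsString domains and the helper to_final are hypothetical illustrations, and the lambda's parameter is named domain because it receives a domain instance. Passing a constructor-building domain turns the final encoding into an initial one, and folding over the tree turns the initial encoding back into a final one.

```python
class Plus(object):
    def __init__(self, i):
        self.i = i


class Loop(object):
    def __init__(self, body):
        self.body = body


class Initial(object):
    """A semantic domain whose methods just build initial-encoded nodes."""
    def plus(self, i):
        return Plus(i)

    def loop(self, body):
        return Loop(body)


class AsString(object):
    """A semantic domain that pretty-prints Brainfuck."""
    def plus(self, i):
        return '+' * i if i > 0 else '-' * -i

    def loop(self, body):
        return '[' + body + ']'


# Final encoding of "[-]": a function of the semantic domain.
program = lambda domain: domain.loop(domain.plus(-1))


def to_final(node):
    """Fold an initial-encoded tree into a final-encoded program."""
    if isinstance(node, Plus):
        return lambda domain: domain.plus(node.i)
    if isinstance(node, Loop):
        body = to_final(node.body)
        return lambda domain: domain.loop(body(domain))
    raise TypeError(node)


tree = program(Initial())          # final -> initial
print(program(AsString()))         # [-]
print(to_final(tree)(AsString()))  # [-]
```

Note that only the initial-to-final direction needs isinstance(); the final encoding itself never inspects anything.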
Before we embark, please keep in mind that local code doesn't know what cls is. There's no type-safe way to inspect an arbitrary semantic domain. In the initial-encoded version, we can ask isinstance(bf, Loop) to see whether an AST node is a loop, but there simply isn't an equivalent for final-encoded ASTs. So, there is an implicit challenge to think about: how do we evaluate a program in an arbitrary semantic domain? For bonus points, how do we optimize a program without inspecting the types of its AST nodes?
What follows is a dissection of this module at the given revision. Readers may find it satisfying to read the entire interpreter top to bottom first; it is less than 300 lines.
Core Functionality
Final encoding is given as methods on an interface. These five methods correspond precisely to the summands of the algebra of Brainfuck.
```python
class BF(object):
    # Other methods elided
    def plus(self, i): pass
    def right(self, i): pass
    def input(self): pass
    def output(self): pass
    def loop(self, bfs): pass
```
Note that the .loop() method takes another program as an argument.
Initial-encoded ASTs have other initial-encoded ASTs as fields on class instances; final-encoded ASTs have other final-encoded ASTs as parameters to interface methods. RPython infers all of the types, so the reader has to know that i is usually an integer while bfs is a sequence of Brainfuck operations.
We're using a class to implement this functionality. Later, we'll treat it as a mixin, rather than a superclass, to avoid typing problems.
Monoid
In order to optimize input programs, we'll need to represent the underlying monoid of Brainfuck programs. To do this, we add the signature for a monoid:
```python
class BF(object):
    # Other methods elided
    def unit(self): pass
    def join(self, l, r): pass
```
This is technically a unital magma, since RPython doesn't support algebraic laws, but we will enforce the algebraic laws later on during optimization. We also want to make use of the folklore that free monoids are lists, allowing callers to pass a list of actions which we'll reduce with recursion:
```python
class BF(object):
    # Other methods elided
    def joinList(self, bfs):
        if not bfs:
            return self.unit()
        elif len(bfs) == 1:
            return bfs[0]
        elif len(bfs) == 2:
            return self.join(bfs[0], bfs[1])
        else:
            i = len(bfs) >> 1
            return self.join(self.joinList(bfs[:i]), self.joinList(bfs[i:]))
```
.joinList() is a little bulky to implement, but Wirth's principle applies: the interpreter is shorter with it than without it.
Idioms
Finally, our interface includes a few high-level idioms, like the zero program shown earlier, which are defined in terms of low-level behaviors. In an initial encoding, these could be defined as module-level functions; here, we define them on the mixin class BF.
```python
class BF(object):
    # Other methods elided
    def zero(self):
        return self.loop(self.plus(-1))
    def move(self, i):
        return self.scalemove(i, 1)
    def move2(self, i, j):
        return self.scalemove2(i, 1, j, 1)
    def scalemove(self, i, s):
        return self.loop(self.joinList([
            self.plus(-1), self.right(i), self.plus(s), self.right(-i)]))
    def scalemove2(self, i, s, j, t):
        return self.loop(self.joinList([
            self.plus(-1), self.right(i), self.plus(s), self.right(j - i),
            self.plus(t), self.right(-j)]))
```
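Because the idioms are defined purely in terms of the primitive operations, any domain that implements the primitives gets the idioms for free. The following plain-Python sketch assumes ordinary inheritance as a stand-in for RPython's mixins, transcribes joinList, zero, move and scalemove from the code above, and adds a hypothetical pretty-printing domain Str modelled on the post's AsStr.

```python
class BF(object):
    # Monoid helper and idioms, transcribed from the post into plain Python.
    def joinList(self, bfs):
        if not bfs:
            return self.unit()
        elif len(bfs) == 1:
            return bfs[0]
        elif len(bfs) == 2:
            return self.join(bfs[0], bfs[1])
        i = len(bfs) >> 1  # bisect, so joins nest logarithmically deep
        return self.join(self.joinList(bfs[:i]), self.joinList(bfs[i:]))

    def zero(self):
        return self.loop(self.plus(-1))

    def move(self, i):
        return self.scalemove(i, 1)

    def scalemove(self, i, s):
        return self.loop(self.joinList([
            self.plus(-1), self.right(i), self.plus(s), self.right(-i)]))


class Str(BF):
    """Hypothetical pretty-printing domain; only the primitives are defined."""
    def unit(self):
        return ''

    def join(self, l, r):
        return l + r

    def plus(self, i):
        return '+' * i if i > 0 else '-' * -i

    def right(self, i):
        return '>' * i if i > 0 else '<' * -i

    def loop(self, bfs):
        return '[' + bfs + ']'


s = Str()
print(s.zero())            # [-]
print(s.move(2))           # [->>+<<]
print(s.scalemove(-1, 3))  # [-<+++>]
```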
Interface-oriented Architecture
Applying Interfaces
Now, we hack at RPython's object model until everything translates. First, consider the task of pretty-printing. For Brainfuck, we'll simply regurgitate the input program as a Python string:
```python
class AsStr(object):
    import_from_mixin(BF)
    def unit(self): return ""
    def join(self, l, r): return l + r
    def plus(self, i): return '+' * i if i > 0 else '-' * -i
    def right(self, i): return '>' * i if i > 0 else '<' * -i
    def loop(self, bfs): return '[' + bfs + ']'
    def input(self): return ','
    def output(self): return '.'
```
Via rlib.objectmodel.import_from_mixin, we don't have to wrestle with covariance of return types. Instead, we shift from a Java-esque view of classes and objects to an OCaml-ish view of prebuilt classes and constructors. AsStr is monomorphic, and any caller of it will have to create their own covariance somehow. For example, here are the first few lines of the parsing function:
```python
@specialize.argtype(1)
def parse(s, domain):
    ops = [domain.unit()]
    # Parser elided to preserve the reader's attention
```
By invoking rlib.objectmodel.specialize.argtype, we make copies of the parsing function, up to one per call site, based on our choice of semantic domain. Oleg calls these "symantics" but I prefer "domain" in code. Also, note how the parsing stack starts with the unit of the monoid, which corresponds to the empty input string; the parser will repeatedly use the monoidal join to build up a parsed expression without inspecting it. Here's a small taste of that:
```python
while i < len(s):
    char = s[i]
    if char == '+':
        ops[-1] = domain.join(ops[-1], domain.plus(1))
    elif char == '-':
        ops[-1] = domain.join(ops[-1], domain.plus(-1))
    # and so on
```
The reader may feel justifiably mystified; what breaks if we don't add these magic annotations? Well, the translator will throw UnionError because the low-level types don't match. RPython only wants to make one copy of functions like parse() in its low-level representation, and each copy of parse() will be compiled to monomorphic machine code. In this interpreter, in order to support parsing to an optimized string and also parsing to an evaluator, we need two copies of parse(). It is okay to not fully understand this at first.
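Since the post elides its parser, the following is not that parser; it is a hypothetical plain-Python sketch of the same idea, assuming a domain that provides unit, join, plus, right, input, output and loop. It never inspects what the domain returns, and an explicit list of partial programs (one per open loop) keeps nesting depth in the input from turning into Python recursion. No specialization annotation is needed in plain Python, because nothing has to be monomorphized. The Count domain is likewise invented for the example.

```python
def parse(s, domain):
    """Parse Brainfuck source into any semantic domain, using only its methods."""
    ops = [domain.unit()]  # one accumulator per currently open loop
    for char in s:
        if char == '+':
            ops[-1] = domain.join(ops[-1], domain.plus(1))
        elif char == '-':
            ops[-1] = domain.join(ops[-1], domain.plus(-1))
        elif char == '>':
            ops[-1] = domain.join(ops[-1], domain.right(1))
        elif char == '<':
            ops[-1] = domain.join(ops[-1], domain.right(-1))
        elif char == ',':
            ops[-1] = domain.join(ops[-1], domain.input())
        elif char == '.':
            ops[-1] = domain.join(ops[-1], domain.output())
        elif char == '[':
            ops.append(domain.unit())  # start a new loop body
        elif char == ']':
            if len(ops) < 2:
                raise ValueError('unmatched ]')
            body = ops.pop()
            ops[-1] = domain.join(ops[-1], domain.loop(body))
        # any other character is a comment
    if len(ops) != 1:
        raise ValueError('unmatched [')
    return ops[0]


class Count(object):
    """Hypothetical domain: counts operations, with each loop counting as one."""
    def unit(self): return 0
    def join(self, l, r): return l + r
    def plus(self, i): return 1
    def right(self, i): return 1
    def input(self): return 1
    def output(self): return 1
    def loop(self, body): return body + 1


print(parse('+[->+<]', Count()))  # 6
```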
Composing Interfaces
Earlier, we noted that an interpreter can optionally optimize input programs after parsing. To support this, we'll precompose a peephole optimizer onto an arbitrary domain. We could also postcompose with a parser instead, but that sounds more difficult. Here are the relevant parts:
```python
def makePeephole(cls):
    domain = cls()
    def stripDomain(bfs):
        return domain.joinList([t[0] for t in bfs])
    class Peephole(object):
        import_from_mixin(BF)
        def unit(self): return []
        def join(self, l, r): return l + r
        # Actual definition elided... for now...
    return Peephole, stripDomain
```
Don't worry about the actual optimization yet. What's important here is the pattern of initialization of semantic domains. makePeephole is an SML-style functor on semantic domains: given a final encoding of Brainfuck, it produces another final encoding of Brainfuck which incorporates optimizations. The helper stripDomain is a finalizer which performs the extraction from the optimizer's domain to the underlying cls that was passed in at translation time. For example, let's optimize pretty-printing:
AsStr, finishStr = makePeephole(AsStr)
Now it takes only one line to parse and print an optimized AST without ever building it on the heap. To be pedantic, fragments of the output string will be heap-allocated, but the AST's node structure will only ever be stack-allocated. Further, to keep the stack shallow, the parser is written so that malicious input cannot cause a stack overflow, and this forces it to maintain a heap-allocated RPython list of intermediate operations inside loops.
print finishStr(parse(text, AsStr()))
Performance
But is it fast? Yes. It's faster than the prior version, which was initial-encoded, and also faster than Andrew Brown's classic version (part 1, part 2). Since Brown's interpreter does not perform much optimization, we will focus on how final encoding can outperform initial encoding.
JIT
First, why is it faster than the same interpreter with initial encoding? Well, it still has initial encoding from the JIT's perspective! There is an Op class with a hierarchy of subclasses implementing individual behaviors. A sincere tagless-final student, or anyone who remembers Stop Writing Classes (PyCon US 2012), will recognize that the following classes could be plain functions, and should think of the classes as a concession to RPython's lack of support for lambdas with closures rather than as an initial encoding. We aren't ever going to directly typecheck any Op, but the JIT will generate typechecking guards anyway, so we effectively get a fully-promoted AST inlined into each JIT trace. First, some simple behaviors:
```python
class Op(object):
    _immutable_ = True

class _Input(Op):
    _immutable_ = True
    def runOn(self, tape, position):
        tape[position] = ord(os.read(0, 1)[0])
        return position
Input = _Input()

class _Output(Op):
    _immutable_ = True
    def runOn(self, tape, position):
        os.write(1, chr(tape[position]))
        return position
Output = _Output()

class Add(Op):
    _immutable_ = True
    _immutable_fields_ = "imm",
    def __init__(self, imm):
        self.imm = imm
    def runOn(self, tape, position):
        tape[position] += self.imm
        return position
```
The JIT does technically have less information than before; it no longer knows that a sequence of immutable operations is immutable enough to be worth unrolling, but a bit of rlib.jit.unroll_safe fixes that:
```python
class Seq(Op):
    _immutable_ = True
    _immutable_fields_ = "ops[*]",
    def __init__(self, ops):
        self.ops = ops
    @unroll_safe
    def runOn(self, tape, position):
        for op in self.ops:
            position = op.runOn(tape, position)
        return position
```
Finally, the JIT entry point is at the head of each loop, just like with prior interpreters. Since Brainfuck doesn't support mid-loop jumps, there's no penalty for only allowing merge points at the head of the loop.
```python
class Loop(Op):
    _immutable_ = True
    _immutable_fields_ = "op",
    def __init__(self, op):
        self.op = op
    def runOn(self, tape, position):
        op = self.op
        while tape[position]:
            jitdriver.jit_merge_point(op=op, position=position, tape=tape)
            position = op.runOn(tape, position)
        return position
```
That's the end of the implicit challenge. There's no secret to it; just evaluate the AST. Here's part of the semantic domain for evaluation, as well as the "functor" to optimize it. AsOps.join() contains the only isinstance() calls in the entire interpreter! This is acceptable because Seq is effectively a type wrapper for an RPython list, so that a list of operations is also an operation; its list is initial-encoded and available for inspection.
```python
class AsOps(object):
    import_from_mixin(BF)
    def unit(self): return Shift(0)
    def join(self, l, r):
        if isinstance(l, Seq) and isinstance(r, Seq):
            return Seq(l.ops + r.ops)
        elif isinstance(l, Seq):
            return Seq(l.ops + [r])
        elif isinstance(r, Seq):
            return Seq([l] + r.ops)
        return Seq([l, r])
    # Other methods elided!
AsOps, finishOps = makePeephole(AsOps)
```
And finally here is the actual top-level code to evaluate the input program. As before, once everything is composed, the actual invocation only takes one line.
```python
tape = bytearray("\x00" * cells)
finishOps(parse(text, AsOps())).runOn(tape, 0)
```
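For readers who want to run the evaluation side without the RPython toolchain, here is a plain-Python sketch of an AsOps-style domain. It drops the JIT hints (_immutable_fields_, unroll_safe, jit_merge_point) and the input/output ops, and it treats Shift, whose definition the post elides, as a hypothetical op that moves the tape head. Add wraps values explicitly because a Python 3 bytearray rejects values outside 0..255. The join method flattens sequences exactly as AsOps.join() does.

```python
class Op(object):
    pass


class Add(Op):
    def __init__(self, imm):
        self.imm = imm
    def runOn(self, tape, position):
        tape[position] = (tape[position] + self.imm) % 256  # wrap like a byte
        return position


class Shift(Op):
    # Hypothetical: the post uses Shift(0) as the unit but elides its definition.
    def __init__(self, imm):
        self.imm = imm
    def runOn(self, tape, position):
        return position + self.imm


class Seq(Op):
    def __init__(self, ops):
        self.ops = ops
    def runOn(self, tape, position):
        for op in self.ops:
            position = op.runOn(tape, position)
        return position


class Loop(Op):
    def __init__(self, op):
        self.op = op
    def runOn(self, tape, position):
        while tape[position]:
            position = self.op.runOn(tape, position)
        return position


class AsOps(object):
    def unit(self): return Shift(0)
    def join(self, l, r):
        # The only isinstance() checks: flatten nested sequences.
        if isinstance(l, Seq) and isinstance(r, Seq):
            return Seq(l.ops + r.ops)
        elif isinstance(l, Seq):
            return Seq(l.ops + [r])
        elif isinstance(r, Seq):
            return Seq([l] + r.ops)
        return Seq([l, r])
    def plus(self, i): return Add(i)
    def right(self, i): return Shift(i)
    def loop(self, bfs): return Loop(bfs)


d = AsOps()
# "+++[->++<]": set cell 0 to 3, then add twice its value to cell 1.
body = d.join(d.join(d.plus(-1), d.right(1)), d.join(d.plus(2), d.right(-1)))
program = d.join(d.plus(3), d.loop(body))
tape = bytearray(4)
program.runOn(tape, 0)
print(list(tape))  # [0, 6, 0, 0]
```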
Peephole Optimization
Our peephole optimizer is an abstract interpreter with one instruction of lookahead/rewrite buffer. It implements the aforementioned algebraic laws of the Brainfuck monoid. It also implements idiom recognition for loops. First, the abstract interpreter. The abstract domain has six elements:
```python
class AbstractDomain(object): pass
meh, aLoop, aZero, theIdentity, anAdd, aRight = [AbstractDomain() for _ in range(6)]
```
We'll also tag everything with an integer, so that anAdd or aRight can be exact annotations. This is the actual Peephole.join() method:
```python
def join(self, l, r):
    if not l: return r
    rv = l[:]
    bfHead, adHead, immHead = rv.pop()
    for bf, ad, imm in r:
        if ad is theIdentity:
            continue
        elif adHead is aLoop and ad is aLoop:
            continue
        elif adHead is theIdentity:
            bfHead, adHead, immHead = bf, ad, imm
        elif adHead is anAdd and ad is aZero:
            bfHead, adHead, immHead = bf, ad, imm
        elif adHead is anAdd and ad is anAdd:
            immHead += imm
            if immHead:
                bfHead = domain.plus(immHead)
            elif rv:
                bfHead, adHead, immHead = rv.pop()
            else:
                bfHead = domain.unit()
                adHead = theIdentity
        elif adHead is aRight and ad is aRight:
            immHead += imm
            if immHead:
                bfHead = domain.right(immHead)
            elif rv:
                bfHead, adHead, immHead = rv.pop()
            else:
                bfHead = domain.unit()
                adHead = theIdentity
        else:
            rv.append((bfHead, adHead, immHead))
            bfHead, adHead, immHead = bf, ad, imm
    rv.append((bfHead, adHead, immHead))
    return rv
```
If this were to get much longer, then implementing a DSL would be worth it, but this is a short-enough method to inline. The abstract interpretation is assumed by induction for the left-hand side of the join, save for the final instruction, which is loaded into a rewrite register. Each instruction on the right-hand side is inspected exactly once. The logic for anAdd followed by anAdd is exactly the same as for aRight followed by aRight because they both have underlying Abelian groups given by the integers.
The rewrite register is carefully pushed onto and popped off from the left-hand side in order to cancel out theIdentity, which itself is merely a unifier for anAdd or aRight of 0.
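A stripped-down, plain-Python version of the rewrite-register idea may help. This sketch is a simplification, not the post's Peephole.join(): it tags each action with a string ('add', 'right' or 'other') instead of the six abstract-domain objects, keeps the underlying action as a string, uses None in place of theIdentity, and skips the loop and zero rules. The helpers atom and optimize are hypothetical. Adjacent adds (or adjacent shifts) are combined, and a sum of zero pops the previous instruction back into the register, so cancellation can cascade.

```python
def plus(i):
    return '+' * i if i > 0 else '-' * -i


def right(i):
    return '>' * i if i > 0 else '<' * -i


def join(l, r):
    """Join two already-optimized lists of (text, tag, imm) triples."""
    if not l:
        return r
    rv = l[:]
    head = rv.pop()                # the rewrite register
    for text, tag, imm in r:       # each right-hand instruction is seen once
        if head is None:           # None plays the role of theIdentity
            head = (text, tag, imm)
        elif head[1] == tag and tag in ('add', 'right'):
            total = head[2] + imm  # both kinds form groups over the integers
            if total:
                head = (plus(total) if tag == 'add' else right(total), tag, total)
            elif rv:
                head = rv.pop()    # cancelled: reload the previous instruction
            else:
                head = None        # everything so far cancelled out
        else:
            rv.append(head)
            head = (text, tag, imm)
    if head is not None:
        rv.append(head)
    return rv


def atom(char):
    """Hypothetical helper: wrap one source character as a one-item list."""
    table = {'+': ('+', 'add', 1), '-': ('-', 'add', -1),
             '>': ('>', 'right', 1), '<': ('<', 'right', -1)}
    return [table.get(char, (char, 'other', 0))]


def optimize(source):
    acc = []
    for char in source:
        acc = join(acc, atom(char))
    return ''.join(text for text, _, _ in acc)


print(optimize('+++--'))  # +
print(optimize('>+-<.'))  # .
```

In the second example, "+-" cancels and reloads ">" into the register, which then cancels with "<", leaving only the output instruction.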
Note that we generate a lot of garbage. For example, parsing a string of n '+' characters will cause the peephole optimizer to allocate n instances of the underlying domain.plus() action, from domain.plus(1) up to domain.plus(n). An older initial-encoded version of this interpreter used hash consing to avoid ever building an op more than once, even loops. It appears more efficient to generate lots of immutable garbage than to repeatedly hash inputs and search mutable hash tables, at least for optimizing Brainfuck incrementally during parsing.
Finally, let's look at idiom recognition. RPython lists are initial-encoded, so we can dispatch based on the length of the list and then inspect the abstract domains of each action.
```python
def isConstAdd(bf, i):
    return bf[1] is anAdd and bf[2] == i

def oppositeShifts(bf1, bf2):
    return bf1[1] is bf2[1] is aRight and bf1[2] == -bf2[2]

def oppositeShifts2(bf1, bf2, bf3):
    return (bf1[1] is bf2[1] is bf3[1] is aRight and
            bf1[2] + bf2[2] + bf3[2] == 0)

def loop(self, bfs):
    if len(bfs) == 1:
        bf, ad, imm = bfs[0]
        if ad is anAdd and imm in (1, -1):
            return [(domain.zero(), aZero, 0)]
    elif len(bfs) == 4:
        if (isConstAdd(bfs[0], -1) and
            bfs[2][1] is anAdd and
            oppositeShifts(bfs[1], bfs[3])):
            return [(domain.scalemove(bfs[1][2], bfs[2][2]), aLoop, 0)]
        if (isConstAdd(bfs[3], -1) and
            bfs[1][1] is anAdd and
            oppositeShifts(bfs[0], bfs[2])):
            return [(domain.scalemove(bfs[0][2], bfs[1][2]), aLoop, 0)]
    elif len(bfs) == 6:
        if (isConstAdd(bfs[0], -1) and
            bfs[2][1] is bfs[4][1] is anAdd and
            oppositeShifts2(bfs[1], bfs[3], bfs[5])):
            return [(domain.scalemove2(bfs[1][2], bfs[2][2],
                                       bfs[1][2] + bfs[3][2],
                                       bfs[4][2]), aLoop, 0)]
        if (isConstAdd(bfs[5], -1) and
            bfs[1][1] is bfs[3][1] is anAdd and
            oppositeShifts2(bfs[0], bfs[2], bfs[4])):
            return [(domain.scalemove2(bfs[0][2], bfs[1][2],
                                       bfs[0][2] + bfs[2][2],
                                       bfs[3][2]), aLoop, 0)]
    return [(domain.loop(stripDomain(bfs)), aLoop, 0)]
```
That answers the bonus question. How do we optimize an unknown semantic domain? We must maintain an abstract context which describes elements of the domain. In initial encoding, we ask an AST about itself; in final encoding, we already know everything relevant about the AST.
The careful reader will see that I didn't really answer that opening question in the JIT section. Because the JIT still ranges over the same operations as before, it can't really be slower; but why is it now faster? Because the optimizer is now slightly better in a few edge cases. It performs the same optimizations as before, but the rigor of abstract interpretation causes it to emit slightly better operations to the JIT backend.
Concretely, improving the optimizer can shorten pretty-printed programs. The Busy Beaver Gauge measures the length of programs which search for solutions to mathematical problems. After implementing and debugging the final-encoded interpreter, I found that two of my entries on the Busy Beaver Gauge for Brainfuck had become shorter by about 2%. (Most other entries are already hand-optimized according to the standard algebra and have no optimization opportunities.)
Discussion
Given that initial and final encodings are equivalent, and noting that RPython's toolchain is written to prefer initial encodings, what did we actually gain? Did we gain anything?
One obvious downside to final encoding in RPython is interpreter size. The example interpreter shown here is a rewrite of an initial-encoded interpreter which can be seen here for comparison. Final encoding adds about 20% more code in this case.
Final encoding is not necessarily more code than initial encoding, though. All AST encodings in interpreters are subject to the Expression Problem, which states that there is generally a quadratic amount of code required to implement multiple behaviors for an AST with multiple types of nodes; specifically, n behaviors for m types of nodes require n × m methods. Initial encodings improve the cost of adding new types of nodes; final encodings improve the cost of adding new behaviors. Final encoding may tend to win in large codebases for mature languages, where the language does not change often but new behaviors are added frequently and maintained for long periods.
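The final-encoding side of that trade-off can be seen in a small plain-Python sketch. Assuming a program is written once as a function of its domain, adding a new behavior means writing one new class, and no existing class changes. The Str and Depth domains and the program function are hypothetical examples, with ordinary Python classes standing in for RPython mixins.

```python
# A program written once, in final encoding: "+[->+<]".
def program(d):
    body = d.join(d.join(d.plus(-1), d.right(1)), d.join(d.plus(1), d.right(-1)))
    return d.join(d.plus(1), d.loop(body))


class Str(object):
    """Existing behavior: pretty-printing."""
    def join(self, l, r): return l + r
    def plus(self, i): return '+' * i if i > 0 else '-' * -i
    def right(self, i): return '>' * i if i > 0 else '<' * -i
    def loop(self, body): return '[' + body + ']'


class Depth(object):
    """New behavior, added as one class: maximum loop nesting depth."""
    def join(self, l, r): return max(l, r)
    def plus(self, i): return 0
    def right(self, i): return 0
    def loop(self, body): return body + 1


print(program(Str()))    # +[->+<]
print(program(Depth()))  # 1
```

Adding a new kind of node, say a hypothetical clear instruction, would instead require a new method on both Str and Depth, which is the n × m cost described above.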
Optimizations in final encoding require a bit of planning. The abstract-interpretation approach is solid but relies upon the monoid and its algebraic laws. In the worst case, an entire class hierarchy could be required to encode the abstraction.
It is remarkable to find a 2% improvement in residual program size merely by reimplementing an optimizer as an abstract interpreter respecting the algebraic laws. This could be the most important lesson for compiler engineers, if it happens to generalize.
Final encoding was popularized via the tagless-final movement in OCaml and Scala, including famously in a series of tutorials by Kiselyov et al. A "tag", in this jargon, is a runtime identifier for an object's type or class; a tagless encoding effectively doesn't allow isinstance() at all. In the above presentation, tags could be hacked in, but were not materially relevant to most steps. Tags were required for the final evaluation step, though, and the tagless-final insight is that certain type systems can express type-safe evaluation without those tags. We won't go further in this direction because tags also communicate valuable information to the JIT.
Summarizing Table
Initial Encoding | Final Encoding |
---|---|
hierarchy of classes | signature of interfaces |
class constructors | method calls |
built on the heap | built on the stack |
traversals allocate stack | traversals allocate heap |
tags are available with isinstance() | tags are only available through hacks |
cost of adding a new AST node: one class | cost of adding a new AST node: one method on every other class |
cost of adding a new behavior: one method on every other class | cost of adding a new behavior: one class |
Credits
Thanks to folks in #pypy on Libera Chat: arigato for the idea, larstiq for pushing me to write it up, and cfbolz and mattip for reviewing and finding mistakes. The original IRC discussion leading to this blog post is available here.
This interpreter is part of the rpypkgs suite, a Nix flake for RPython interpreters. Readers with Nix installed can run this interpreter directly from the flake:
```
$ nix-prefetch-url https://github.com/MG-K/pypy-tutorial-ko/raw/refs/heads/master/mandel.b
$ nix run github:rpypkgs/rpypkgs#bf -- /nix/store/ngnphbap9ncvz41d0fkvdh61n7j2bg21-mandel.b
```
A DSL for Peephole Transformation Rules of Integer Operations in the PyPy JIT
As is probably apparent from the sequence of blog posts about the topic in the last year, I have been thinking about and working on integer optimizations in the JIT compiler a lot. This work was mainly motivated by Pydrofoil, where integer operations matter a lot more than for your typical Python program.
In this post I'll describe my most recent change, which is a new small domain-specific language that I implemented to specify peephole optimizations on integer operations in the JIT. It uses pattern matching to specify how (sequences of) integer operations should be simplified and optimized. The rules are then compiled to RPython code that becomes part of the JIT's optimization passes.
To make it less likely to introduce incorrect optimizations into the JIT, the rules are automatically proven correct with Z3 as part of the build process (for a more hands-on intro to how that works you can look at the knownbits post). In this blog post I want to motivate why I introduced the DSL and give an introduction to how it works.
Motivation
This summer, after I wrote my scripts to mine JIT traces for missed optimization opportunities, I started implementing a few of the integer peephole rewrites that the script identified. Unfortunately, doing so made it obvious that the way we have expressed these rewrites so far is very imperative and verbose. Here's a snippet of RPython code that shows some rewrites for integer multiplication (look at the comments to see what the different parts actually do). You don't need to understand the code in detail, but basically it's written in a very imperative style and there's quite a lot of boilerplate.
```python
def optimize_INT_MUL(self, op):
    arg0 = get_box_replacement(op.getarg(0))
    b0 = self.getintbound(arg0)
    arg1 = get_box_replacement(op.getarg(1))
    b1 = self.getintbound(arg1)

    if b0.known_eq_const(1):
        # 1 * x == x
        self.make_equal_to(op, arg1)
    elif b1.known_eq_const(1):
        # x * 1 == x
        self.make_equal_to(op, arg0)
    elif b0.known_eq_const(0) or b1.known_eq_const(0):
        # 0 * x == x * 0 == 0
        self.make_constant_int(op, 0)
    else:
        for lhs, rhs in [(arg0, arg1), (arg1, arg0)]:
            lh_info = self.getintbound(lhs)
            if lh_info.is_constant():
                x = lh_info.get_constant_int()
                if x & (x - 1) == 0:
                    # x * (2 ** c) == x << c
                    new_rhs = ConstInt(highest_bit(lh_info.get_constant_int()))
                    op = self.replace_op_with(op, rop.INT_LSHIFT, args=[rhs, new_rhs])
                    self.optimizer.send_extra_operation(op)
                    return
                elif x == -1:
                    # x * -1 == -x
                    op = self.replace_op_with(op, rop.INT_NEG, args=[rhs])
                    self.optimizer.send_extra_operation(op)
                    return
            else:
                # x * (1 << y) == x << y
                shiftop = self.optimizer.as_operation(get_box_replacement(lhs), rop.INT_LSHIFT)
                if shiftop is None:
                    continue
                if not shiftop.getarg(0).is_constant() or shiftop.getarg(0).getint() != 1:
                    continue
                shiftvar = get_box_replacement(shiftop.getarg(1))
                shiftbound = self.getintbound(shiftvar)
                if shiftbound.known_nonnegative() and shiftbound.known_lt_const(LONG_BIT):
                    op = self.replace_op_with(
                            op, rop.INT_LSHIFT, args=[rhs, shiftvar])
                    self.optimizer.send_extra_operation(op)
                    return
        return self.emit(op)
```
Adding more rules to these functions is very tedious and gets super confusing when the functions get bigger. In addition I am always worried about making mistakes when writing this kind of code, and there is no feedback at all about which of these rules are actually applied a lot in real programs.
Therefore I decided to write a small domain specific language with the goal of expressing these rules in a more declarative way. In the rest of the post I'll describe the DSL (most of that description is adapted from the documentation about it that I wrote).
The Peephole Rule DSL
Simple transformation rules
The rules in the DSL specify how integer operations can be transformed into other, cheaper integer operations. A rule always consists of a name, a pattern, and a target. Here's a simple rule:
add_zero: int_add(x, 0) => x
The name of the rule is add_zero. It matches operations in the trace of the form int_add(x, 0), where x will match anything and 0 will match only the constant zero. After the => arrow is the target of the rewrite, i.e. what the operation is rewritten to, in this case x.
The rule language has a list of which of the operations are commutative, so add_zero will also optimize int_add(0, x) to x.
Variables in the pattern can repeat:
sub_x_x: int_sub(x, x) => 0
This rule matches against int_sub operations where the two arguments are the same (either the same box, or the same constant).
Here's a rule with a more complicated pattern:
sub_add: int_sub(int_add(x, y), y) => x
This pattern matches int_sub operations, where the first argument was produced by an int_add operation. In addition, one of the arguments of the addition has to be the same as the second argument of the subtraction.
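A naive way to picture what these patterns mean is a backtracking matcher over a made-up tuple representation of trace operations. This is only a sketch under stated assumptions: operations are tuples like ("int_add", arg0, arg1), boxes are strings, constants are Python ints, and the COMMUTATIVE set is a stand-in for the DSL's real list. The generated RPython code does not work like this; it compiles the rules into decision trees instead.

```python
# Operations whose arguments may be swapped (a stand-in for the DSL's list).
COMMUTATIVE = {"int_add", "int_mul", "int_and", "int_or", "int_xor"}


def matches(pattern, term, bindings):
    """Yield every extension of `bindings` under which `term` matches `pattern`."""
    if isinstance(pattern, str):                       # a pattern variable
        if pattern.startswith("C") and not isinstance(term, int):
            return                                     # C... matches constants only
        if pattern not in bindings:
            yield {**bindings, pattern: term}
        elif bindings[pattern] == term:                # repeated variable: same box/constant
            yield bindings
    elif isinstance(pattern, int):                     # a literal constant
        if term == pattern:
            yield bindings
    elif (isinstance(term, tuple) and len(term) == len(pattern)
          and term[0] == pattern[0]):                  # an operation with the same name
        arg_orders = [term[1:]]
        if pattern[0] in COMMUTATIVE:
            arg_orders.append(term[:0:-1])             # also try swapped arguments
        for args in arg_orders:
            yield from match_args(pattern[1:], args, bindings)


def match_args(patterns, terms, bindings):
    """Match argument patterns left to right, backtracking on failure."""
    if not patterns:
        yield bindings
        return
    for extended in matches(patterns[0], terms[0], bindings):
        yield from match_args(patterns[1:], terms[1:], extended)


RULES = [
    ("add_zero", ("int_add", "x", 0), "x"),
    ("sub_x_x", ("int_sub", "x", "x"), 0),
    ("sub_add", ("int_sub", ("int_add", "x", "y"), "y"), "x"),
]


def rewrite(term):
    """Return (rule_name, result) for the first matching rule, else (None, term)."""
    for name, pattern, target in RULES:
        for bindings in matches(pattern, term, {}):
            return name, (bindings[target] if isinstance(target, str) else target)
    return None, term


print(rewrite(("int_add", 0, "i0")))                        # ('add_zero', 'i0')
print(rewrite(("int_sub", ("int_add", "i2", "i3"), "i2")))  # ('sub_add', 'i3')
```

The backtracking matters for sub_add: the first way of matching int_add binds y to "i3", which then fails against the subtraction's second argument, so the matcher has to retry with the swapped arguments.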
The constants MININT, MAXINT and LONG_BIT (which is either 32 or 64, depending on which platform the JIT is built for) can be used in rules; they behave like writing numbers but allow bit-width-independent formulations:
is_true_and_minint: int_is_true(int_and(x, MININT)) => int_lt(x, 0)
It is also possible to have a pattern where some arguments need to be a constant, without specifying which constant. Those patterns look like this:
```
sub_add_consts:
    int_sub(int_add(x, C1), C2) # incomplete
    # more goes here
    => int_sub(x, C)
```
Variables in the pattern that start with a C match against constants only. However, in this current form the rule is incomplete, because the variable C that is being used in the target operation is not defined anywhere. We will see how to compute it in the next section.
Computing constants and other intermediate results
Sometimes it is necessary to compute intermediate results that are used in the target operation. To do that, there can be extra assignments between the rule head and the rule target:
```
sub_add_consts:
    int_sub(int_add(x, C1), C2)
    C = C2 - C1
    => int_sub(x, C)
```
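Computed constants have to be right under machine-word wraparound. As a quick sanity check of the algebra behind sub_add_consts, here is a randomized test in plain Python that models signed 64-bit integers (assuming LONG_BIT == 64; the helper names are made up). It is not a proof, which is what the Z3 step described further down is for. Since (x + C1) - C2 equals x - (C2 - C1), simply adding the two constants would be wrong, and the sketch shows that too.

```python
import random

LONG_BIT = 64
MASK = (1 << LONG_BIT) - 1


def wrap(v):
    """Reduce v to a signed 64-bit integer, like a machine word."""
    v &= MASK
    return v - (1 << LONG_BIT) if v >> (LONG_BIT - 1) else v


def holds(compute_c, trials=10000):
    """Randomly test int_sub(int_add(x, C1), C2) == int_sub(x, C)."""
    for _ in range(trials):
        x, c1, c2 = (random.randint(-2**63, 2**63 - 1) for _ in range(3))
        source = wrap(wrap(x + c1) - c2)
        target = wrap(x - compute_c(c1, c2))
        if source != target:
            return False
    return True


print(holds(lambda c1, c2: wrap(c2 - c1)))  # True
# The naive version, C = C1 + C2, finds a mismatch:
print(holds(lambda c1, c2: wrap(c1 + c2)))
```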
The right-hand side of such an assignment is a subset of Python syntax, supporting arithmetic using +, -, *, and certain helper functions. However, the syntax allows you to be explicit about unsignedness for some operations. E.g. >>u exists for unsigned right shifts (and I plan to add >u, >=u, <u, <=u for comparisons).
Here's an example of a rule that uses >>u:
```
urshift_lshift_x_c_c:
    uint_rshift(int_lshift(x, C), C)
    mask = (-1 << C) >>u C
    => int_and(x, mask)
```
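The mask in urshift_lshift_x_c_c keeps exactly the bits of x that survive a left shift by C followed by an unsigned right shift by C. A small brute-force check in plain Python, assuming 64-bit words and treating values as unsigned bit patterns (the helper names are illustrative):

```python
import random

LONG_BIT = 64
MASK = (1 << LONG_BIT) - 1


def lshift(x, c):
    """int_lshift on a 64-bit word: bits shifted past the top are lost."""
    return (x << c) & MASK


def urshift(x, c):
    """Unsigned right shift (>>u): zeros are shifted in from the top."""
    return (x & MASK) >> c


def rule_holds(x, c):
    source = urshift(lshift(x, c), c)   # uint_rshift(int_lshift(x, C), C)
    mask = urshift(lshift(-1, c), c)    # mask = (-1 << C) >>u C
    target = x & mask                   # int_and(x, mask)
    return source == target


samples = [0, 1, -1, 2**63 - 1, -2**63] + [random.getrandbits(64) for _ in range(100)]
print(all(rule_holds(x, c) for x in samples for c in range(LONG_BIT)))  # True
print(hex(urshift(lshift(-1, 60), 60)))  # 0xf: only the low 4 bits survive
```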
Checks
Some rewrites are only true under certain conditions. For example, int_eq(x, 1) can be rewritten to x, if x is known to store a boolean value. This can be expressed with checks:
```
eq_one:
    int_eq(x, 1)
    check x.is_bool()
    => x
```
A check is followed by a boolean expression. The variables from the pattern can be used as IntBound instances in checks (and also in assignments) to find out what the abstract interpretation of the JIT knows about the value of a trace variable (IntBound is the name of the abstract domain that the JIT uses for integers, despite the fact that it also stores knownbits information nowadays).
Here's another example:
```
mul_lshift:
    int_mul(x, int_lshift(1, y))
    check y.known_ge_const(0) and y.known_le_const(LONG_BIT)
    => int_lshift(x, y)
```
It expresses that x * (1 << y) can be rewritten to x << y but checks that y is known to be between 0 and LONG_BIT.
Checks and assignments can be repeated and combined with each other:
```
mul_pow2_const:
    int_mul(x, C)
    check C > 0 and C & (C - 1) == 0
    shift = highest_bit(C)
    => int_lshift(x, shift)
```
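The check C > 0 and C & (C - 1) == 0 accepts exactly the positive powers of two, and for those, multiplying is the same as shifting by the index of the single set bit. In the sketch below, highest_bit is a hypothetical stand-in written with int.bit_length(), not the RPython helper of the same name, and 64-bit words are assumed.

```python
LONG_BIT = 64
MASK = (1 << LONG_BIT) - 1


def highest_bit(c):
    """Index of the most significant set bit of a positive integer (stand-in)."""
    return c.bit_length() - 1


def check_passes(c):
    """The rule's check: C is positive and has exactly one bit set."""
    return c > 0 and (c & (c - 1)) == 0


def rewrite_is_equal(x, c):
    """int_mul(x, C) and int_lshift(x, highest_bit(C)) agree on 64-bit words."""
    return ((x * c) & MASK) == ((x << highest_bit(c)) & MASK)


print([c for c in range(1, 20) if check_passes(c)])  # [1, 2, 4, 8, 16]
print(all(rewrite_is_equal(x, 1 << k)
          for x in (-7, 0, 5, 2**62 + 3)
          for k in range(LONG_BIT - 1)))             # True
```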
In addition to calling methods on IntBound instances, it's also possible to access their attributes, like in this rule:
```
and_x_c_in_range:
    int_and(x, C)
    check x.lower >= 0 and x.upper <= C & ~(C + 1)
    => x
```
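The expression C & ~(C + 1) is a classic bit trick: it keeps only the run of trailing one-bits of C. If x is known to be non-negative and no larger than that run, every set bit of x is also set in C, so x & C == x. A brute-force check over small values in plain Python (the function names are illustrative):

```python
def trailing_ones(c):
    """C & ~(C + 1): the trailing run of one-bits of C, e.g. 0b1011 -> 0b0011."""
    return c & ~(c + 1)


def check_passes(lower, upper, c):
    """The rule's check for a variable known to lie in [lower, upper]."""
    return lower >= 0 and upper <= trailing_ones(c)


print(bin(trailing_ones(0b1011)), bin(trailing_ones(0b0111)), bin(trailing_ones(0b1010)))
# 0b11 0b111 0b0

# Whenever the check passes for bounds [0, upper], int_and(x, C) == x
# holds for every x within those bounds.
ok = all(
    (x & c) == x
    for c in range(256)
    for upper in range(256)
    if check_passes(0, upper, c)
    for x in range(upper + 1)
)
print(ok)  # True
```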
Rule Ordering and Liveness
The generated optimizer code will give preference to applying rules that produce a constant or a variable as a rewrite result. Only if none of those match do rules that produce new result operations get applied. For example, the rules sub_x_x and sub_add are tried before trying sub_add_consts, because the former two rules optimize to a constant and a variable respectively, while the latter produces a new operation as the result.
The rule sub_add_consts has a possible problem, which is that if the intermediate result of the int_add operation in the rule head is used by some other operations, then the sub_add_consts rule does not actually reduce the number of operations (and might actually make things slightly worse due to increased register pressure). However, currently it would be extremely hard to take that kind of information into account in the optimization pass of the JIT, so we optimistically apply the rules anyway.
Checking rule coverage
Every rewrite rule should have at least one unit test where it triggers. To ensure this, the unit test file that mainly checks integer optimizations in the JIT has an assert at the end of a test run that checks that every rule fired at least once.
Printing rule statistics
The JIT can print statistics about which rule fired how often in the jit-intbounds-stats logging category, using the PYPYLOG mechanism. For example, to print the category to stdout at the end of program execution, run PyPy like this:
PYPYLOG=jit-intbounds-stats:- pypy ...
The output of that will look something like this:
```
int_add
    add_reassoc_consts 2514
    add_zero 107008
int_sub
    sub_zero 31519
    sub_from_zero 523
    sub_x_x 3153
    sub_add_consts 159
    sub_add 55
    sub_sub_x_c_c 1752
    sub_sub_c_x_c 0
    sub_xor_x_y_y 0
    sub_or_x_y_y 0
int_mul
    mul_zero 0
    mul_one 110
    mul_minus_one 0
    mul_pow2_const 1456
    mul_lshift 0
...
```
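Both the coverage assert and the statistics boil down to counting rule firings. A minimal sketch of that bookkeeping in plain Python; the RuleStats class and its methods are hypothetical and not PyPy's implementation:

```python
from collections import Counter


class RuleStats:
    """Count how often each rewrite rule fires."""

    def __init__(self, rule_names):
        # Start every rule at zero so that unused rules still show up.
        self.counts = Counter({name: 0 for name in rule_names})

    def fired(self, name):
        self.counts[name] += 1

    def unused_rules(self):
        """Rules that never fired, i.e. rules the test suite failed to exercise."""
        return sorted(name for name, n in self.counts.items() if n == 0)

    def report(self):
        """One 'rule count' line per rule, similar in spirit to the log output."""
        return "\n".join("%s %d" % (name, n) for name, n in self.counts.items())


stats = RuleStats(["add_zero", "sub_x_x", "mul_lshift"])
for name in ["add_zero", "add_zero", "sub_x_x"]:
    stats.fired(name)
print(stats.report())
print(stats.unused_rules())  # ['mul_lshift']
```

A coverage check in this style would be a single assert not stats.unused_rules() at the end of the test run.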
Termination and Confluence
Right now there are unfortunately no checks that the rules actually rewrite operations towards "simpler" forms. There is no cost model, and also nothing that prevents you from writing a rule like this:
```
neg_complication:
    int_neg(x) # leads to infinite rewrites
    => int_mul(-1, x)
```
Doing this would lead to endless rewrites if there is also another rule that turns multiplication with -1 into negation.
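Here's a small simulation of that failure mode in plain Python, using a made-up tuple representation for operations. With both neg_complication and a rule that turns multiplication by -1 back into negation, rewriting never reaches a normal form, so this sketch gives up after a fixed number of steps. The step limit is an assumption of the sketch; the DSL itself currently doesn't check for this.

```python
def neg_complication(op):
    """int_neg(x) => int_mul(-1, x)"""
    if op[0] == "int_neg":
        return ("int_mul", -1, op[1])
    return None


def mul_minus_one(op):
    """int_mul(x, -1) => int_neg(x), trying both argument orders."""
    if op[0] == "int_mul" and -1 in op[1:]:
        other = op[2] if op[1] == -1 else op[1]
        return ("int_neg", other)
    return None


def rewrite_to_normal_form(op, rules, max_steps=100):
    """Repeatedly apply the first rule that fires, until none fires."""
    for steps in range(max_steps):
        for rule in rules:
            new_op = rule(op)
            if new_op is not None:
                op = new_op
                break
        else:
            return op, steps  # no rule fired: a normal form was reached
    raise RuntimeError("no normal form after %d rewrites" % max_steps)


print(rewrite_to_normal_form(("int_mul", "i0", -1), [mul_minus_one]))
# (('int_neg', 'i0'), 1)
try:
    rewrite_to_normal_form(("int_neg", "i0"), [neg_complication, mul_minus_one])
except RuntimeError as e:
    print(e)  # no normal form after 100 rewrites
```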
There is also no checking for confluence (yet?), i.e. the property that all rewrites starting from the same input trace always lead to the same output trace, no matter in which order the rules are applied.
Proofs
It is very easy to write a peephole rule that is not correct in all corner cases. Therefore, by default, all the rules are proven correct with Z3 before being compiled into actual JIT code. When the proof fails, a (hopefully minimal) counterexample is printed. The counterexample consists of values for all the inputs that fulfil the checks, values for the intermediate expressions, and then two different values for the source and the target operations.
E.g. if we try to add the incorrect rule:
mul_is_add: int_mul(a, b) => int_add(a, b)
We get the following counterexample as output:
```
Could not prove correctness of rule 'mul_is_add'
in line 1
counterexample given by Z3:
counterexample values:
a: 0
b: 1
operation int_mul(a, b) with Z3 formula a*b
has counterexample result vale: 0
BUT
target expression: int_add(a, b) with Z3 formula a + b
has counterexample value: 1
```
If we add conditions, they are taken into account and the counterexample will fulfil the conditions:
```
mul_is_add:
    int_mul(a, b)
    check a.known_gt_const(1) and b.known_gt_const(2)
    => int_add(a, b)
```
This leads to the following counterexample:
```
Could not prove correctness of rule 'mul_is_add'
in line 46
counterexample given by Z3:
counterexample values:
a: 2
b: 3
operation int_mul(a, b) with Z3 formula a*b
has counterexample result vale: 6
BUT
target expression: int_add(a, b) with Z3 formula a + b
has counterexample value: 5
```
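Z3 reasons symbolically about full 64-bit words. As a much weaker stand-in that still conveys what a counterexample is, here is an exhaustive search over all pairs of signed 8-bit values in plain Python. The bit width and all function names are assumptions of this sketch; a bug that only shows up at 64 bits could slip through a search like this, which is one reason to use a solver instead.

```python
from itertools import product

BITS = 8
MASK = (1 << BITS) - 1
VALUES = range(-(1 << (BITS - 1)), 1 << (BITS - 1))  # -128 .. 127


def to_signed(v):
    """Wrap v to a signed BITS-bit integer."""
    v &= MASK
    return v - (1 << BITS) if v >> (BITS - 1) else v


def find_counterexample(source, target, check=lambda a, b: True):
    """Return ((a, b), source_value, target_value) for the first inputs that
    satisfy `check` but make source and target differ, or None."""
    for a, b in product(VALUES, repeat=2):
        if not check(a, b):
            continue
        s = to_signed(source(a, b))
        t = to_signed(target(a, b))
        if s != t:
            return (a, b), s, t
    return None


def int_mul(a, b):
    return a * b


def int_add(a, b):
    return a + b


print(find_counterexample(int_mul, int_add, check=lambda a, b: a > 1 and b > 2))
# ((2, 3), 6, 5)
print(find_counterexample(lambda a, b: a - a, lambda a, b: 0))  # None
```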
Some IntBound methods cannot be used in Z3 proofs because their control flow is too complex. If that is the case, they can have Z3-equivalent formulations defined (every time this is done, it's a potential proof hole if the Z3-friendly reformulation and the real implementation differ from each other, therefore extra care is required to make very sure they are equivalent).
It's possible to skip the proof of individual rules entirely by adding SORRY_Z3 to their body (but we should try not to do that too often):
```
eq_different_knownbits:
    int_eq(x, y)
    SORRY_Z3
    check x.known_ne(y)
    => 0
```
Checking for satisfiability
In addition to checking whether the rule yields a correct optimization, we also check whether the rule can ever apply. This ensures that there are some runtime values that would fulfil all the checks in a rule. Here's an example of a rule violating this:
```
never_applies:
    int_is_true(x)
    check x.known_lt_const(0) and x.known_gt_const(0) # impossible condition, always False
    => x
```
Right now the error messages if this goes wrong are not completely easy to understand. I hope to be able to improve this later:
```
Rule 'never_applies' cannot ever apply
in line 1
Z3 did not manage to find values for variables x such that the following condition becomes True:
And(x <= x_upper,
    x_lower <= x,
    If(x_upper < 0, x_lower > 0, x_upper < 0))
```
Implementation Notes
The implementation of the DSL is done in a relatively ad-hoc manner. It is parsed using rply, and there's a small type checker that tries to find common problems in how the rules are written. Z3 is used via the Python API, like in the previous blog posts that are using it. The pattern matching RPython code is generated using an approach inspired by Luc Maranget's paper Compiling Pattern Matching to Good Decision Trees. See this blog post for an approachable introduction.
Conclusion
Now that I've described the DSL, here are the rules that are equivalent to the imperative code in the motivation section:
```
mul_zero:
    int_mul(x, 0)
    => 0

mul_one:
    int_mul(x, 1)
    => x

mul_minus_one:
    int_mul(x, -1)
    => int_neg(x)

mul_pow2_const:
    int_mul(x, C)
    check C > 0 and C & (C - 1) == 0
    shift = highest_bit(C)
    => int_lshift(x, shift)

mul_lshift:
    int_mul(x, int_lshift(1, y))
    check y.known_ge_const(0) and y.known_le_const(LONG_BIT)
    => int_lshift(x, y)
```
The current status of the DSL is that it got merged to PyPy's main branch. I rewrote part of the integer rewrites into the DSL, but some are still in the old imperative style (mostly for complicated reasons; the easily ported ones are all done). Since I've only been porting optimizations that had existed prior to the existence of the DSL, performance numbers of benchmarks didn't change.
There are a number of features that are still missing and some possible extensions that I plan to work on in the future:
All the integer operations that the DSL handles so far are the variants that do not check for overflow (or where overflow was proven to be impossible to happen). In regular Python code the overflow-checking variants int_add_ovf etc. are much more common, but the DSL doesn't support them yet. I plan to fix this, but don't completely understand how the correctness proofs for them should be done.
A related problem is that I don't understand what it means for a rewrite to be correct if some of the operations are only defined for a subset of the input values. E.g. division isn't defined if the divisor is zero. In theory, a division operation in the trace should always be preceded by a check that the divisor isn't zero. But sometimes other optimizations move the check around and the connection to the division gets lost or muddled. What optimizations can we still safely perform on the division? There's lots of prior work on this question, but I still don't understand what the correct approach in our context would be.
Ordering comparisons like int_lt, int_le and their unsigned variants are not ported to the DSL yet. Comparisons are an area where the JIT is not super good yet at optimizing away operations. This is a pretty big topic and I've started a project with Nico Rittinghaus to try to improve the situation a bit more generally.
A more advanced direction of work would be to implement a simplified form of e-graphs (or ae-graphs). The JIT has like half of an e-graph data structure already, and we probably can't afford a full one in terms of compile time costs, but maybe we can have two thirds or something?
Acknowledgements
Thank you to Max Bernstein and Martin Berger for super helpful feedback on drafts of the post!
Guest Post: How PortaOne uses PyPy for high-performance processing, connecting over 1B of phone calls every month
The PyPy project is always happy to hear about industrial use and deployments of PyPy. For the GC bug finding task earlier this year, we collaborated with PortaOne and we're super happy that Serhii Titov, head of the QA department at PortaOne, was up to writing this guest post to describe their use and experience with the project.
What does PortaOne do?
We at PortaOne Inc. allow telecom operators to launch new services (or provide existing services more efficiently) using our VoIP platform (PortaSIP) and our real-time charging system (PortaBilling), which provides additional features for cloud PBX, such as call transfer, queues, interactive voice response (IVR) and more. At this moment our support team manages several thousand servers with our software installed in 100 countries, through which over 500 telecommunication service providers connect millions of end users every day. The unique thing about PortaOne is that we supply the source code of our product to our customers - something unheard of in the telecom world! Thus we attract "telco innovators", who use our APIs to build around the system and the source code to create unique tweaks of functionality, which produces amazing products.
At the core of PortaSIP is the middle-ware component (the proper name for it is "B2BUA", but that probably does not say much to anyone outside of experts in VoIP), which implements the actual handling of SIP calls, messages, etc. and all added features (for instance, trying to send a call via telco operators through which the cost per minute is lower). It has to be fast (since even a small delay in establishing a call is noticed by a customer), reliable (everyone hates when a call drops or cannot be completed) and yet easily expandable with new functionality. This is why we decided to use Python as opposed to C/C++ or similar programming languages, which are often used in telecom equipment.
The B2BUA component is a batch of similar Python processes that are looped inside an asyncore.dispatcher wrapper. The load balancing between these Python processes is done by our stateless SIP proxy server written in C++. All our sockets are served by this B2BUA. We have our custom client-wrappers around pymysql, redis, cassandra-driver and requests to communicate with external services. Some of the Python processes use cffi wrappers around C-code to improve their performance (examples: an Oracle DB driver, a client to a radius server, a custom C logger).
The I/O operations that block the main thread of the Python processes are processed in sub-threads. We have custom wrappers around threading.Thread and also asyncore.dispatcher. The results of such operations are returned to the main thread.
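The general pattern of running blocking I/O in sub-threads and handing results back to the main thread can be sketched with the standard library's threading and queue modules. This is only an illustration under assumptions: PortaOne's actual wrappers around threading.Thread and asyncore.dispatcher aren't shown in the post, and run_blocking, drain_results and slow_lookup are made-up names.

```python
import queue
import threading
import time

# Completed (name, result) pairs: produced by worker threads,
# consumed by the main thread.
results = queue.Queue()


def run_blocking(name, func, *args):
    """Run a blocking call in a worker thread and post its result to `results`."""
    def worker():
        try:
            results.put((name, func(*args)))
        except Exception as exc:  # errors are handed to the main thread too
            results.put((name, exc))
    threading.Thread(target=worker, daemon=True).start()


def drain_results():
    """Called from the main thread's event loop: collect finished
    operations without blocking."""
    finished = []
    while True:
        try:
            finished.append(results.get_nowait())
        except queue.Empty:
            return finished


def slow_lookup(number):
    time.sleep(0.05)  # stands in for e.g. a database query
    return "route for %s" % number
```

In this design the main loop calls run_blocking(...) when a request needs, say, a database lookup, and calls drain_results() between event-loop iterations, so the main thread itself never waits on I/O.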
Improving our performance with PyPy
We started with CPython and then in 2014 switched to PyPy because it was faster. Here's an exact quote from our first testing notes: "PyPy gives significant performance boost, ~50%". Nowadays, after years of changes in all the software involved, PyPy still gives us a +50% boost compared to CPython.
Taking care of real time traffic for so many people around the globe is something we're really proud of. I hope the PyPy team can be proud of it as well, as the PyPy product is a part of this solution.
Finding a garbage collector bug: stage 1, the GC hooks
However, our path with PyPy wasn't perfectly smooth. There were very rare cases of crashes on PyPy that we weren't able to catch. That's because to make the coredumps useful we needed to switch to a PyPy build with debugging enabled, but we cannot let it run in that mode on a production system for an extended period of time, and we did not have any STR (steps-to-reproduce) to make PyPy crash again in our lab. That's why we kept (and still keep) both interpreters installed just in case, and we would switch to CPython if we noticed it happening.
At the time of updating PyPy from 3.5 to 3.6 our QA started noticing those crashes more often, but we still had no luck with STR or collecting proper coredumps with debug symbols. Then it became even worse after our developers experimented with the garbage collector's options to increase the performance of our middleware component. The crashes started to affect our regular performance testing (controlled by QA manager Yevhenii Bovda). At that point it was decided that we could no longer live like that, and so we started an intense investigation.
During the first stage of our investigation (following the best practice of troubleshooting) we narrowed down the issue as much as we could. It was not our code; the problem was definitely somewhere in PyPy. Eventually our SIP software engineer Yevhenii Yatchenko found out that this bug was connected to the use of our custom hooks in the GC. Yevhenii created ticket #4899 and within 2-3 days we got a fix from a member of the PyPy team, in true open-source fashion.
Finding a garbage collector bug: stage 2, the real bug
Then came stage 2. In parallel with the previous ticket, Yevhenii created #4900 for the crashes with coredumps that we were still seeing quite often, and which were not connected to the custom GC hooks. In a nutshell, it took us dozens of back-and-forth emails, three Zoom sessions and four versions of a patch to solve the issue. During the last iteration we got a new set of options to try and a new version of the patch. Surprisingly, that helped! What a relief! So, the next logical step was to remove all debug options and run PyPy only with the patch. Unfortunately, it started to fail again and we came to the obvious conclusion that what helped us was not the patch, but one of the options we were testing out. At that point we found out that PYPY_GC_MAX_PINNED=0 is a necessary and sufficient condition to solve our issue. This points to another bug in the garbage collector, somehow related to object pinning.
Here's our current state: we have to add PYPY_GC_MAX_PINNED=0, but we do not face the crashes anymore.
Conclusion and next steps
Gratitude is extended to Carl for his invaluable assistance in resolving the nasty bugs, because it seems we're the only ones who suffered from the last one and we really did not want to fall back to CPython due to its performance disadvantage.
Serhii Titov, head of the QA department at PortaOne Inc.
P.S. If you are a perfectionist and at this point you have mixed feelings and you are still bothered by the question "But there might still be a bug in the GC, what about that?" - Carl has some ideas about it and he will sort it out (we will help with the testing/verification part).
PyPy v7.3.17 release
PyPy v7.3.17: release of python 2.7 and 3.10
The PyPy team is proud to release version 7.3.17 of PyPy.
This release includes a new RISC-V JIT backend, an improved REPL based on work by the CPython team, and better JIT optimizations of integer operations. Special shout-outs to Logan Chien for the RISC-V backend work, to Nico Rittinghaus for better integer optimization in the JIT, and the CPython team that has worked on the repl.
The release includes two different interpreters:
PyPy2.7, which is an interpreter supporting the syntax and the features of Python 2.7 including the stdlib for CPython 2.7.18+ (the + is for backported security updates)
PyPy3.10, which is an interpreter supporting the syntax and the features of Python 3.10, including the stdlib for CPython 3.10.14.
The interpreters are based on much the same codebase, thus the dual release. This is a micro release; all APIs are compatible with the other 7.3 releases. It follows the 7.3.16 release on April 23, 2024.
We recommend updating. You can find links to download the releases here:
We would like to thank our donors for the continued support of the PyPy project. If PyPy is not quite good enough for your needs, we are available for direct consulting work. If PyPy is helping you out, we would love to hear about it and encourage submissions to our blog via a pull request to https://github.com/pypy/pypy.org
We would also like to thank our contributors and encourage new people to join the project. PyPy has many layers and we need help with all of them: bug fixes, PyPy and RPython documentation improvements, or general help with making RPython's JIT even better.
If you are a python library maintainer and use C-extensions, please consider making a HPy / CFFI / cppyy version of your library that would be performant on PyPy. In any case, both cibuildwheel and the multibuild system support building wheels for PyPy.
RISC-V backend for the JIT
PyPy's JIT has added support for generating 64-bit RISC-V machine code at runtime (RV64-IMAD, specifically). So far we are not releasing binaries for any RISC-V platforms, but there are instructions on how to cross-compile binaries.
REPL Improvements
The biggest user-visible changes of the release are new features in the repl of PyPy3.10. CPython 3.13 has adopted and extended PyPy's pure-Python repl, adding a number of features and fixing a number of bugs in the process. We have backported and added the following features:
Prompts and tracebacks use terminal colors, as well as terminal hyperlinks for file names.
Bracketed paste enables pasting several lines of input into the terminal without auto-indentation getting in the way.
A special interactive help browser (F1), history browser (F2), explicit paste mode (F3).
Support for Ctrl-<left/right> to jump over whole words at a time.
See the CPython documentation for further details. Thanks to Łukasz Langa, Pablo Galindo Salgado and the other CPython devs involved in this work.
Better JIT optimizations of integer operations
The optimizers of PyPy's JIT have become much better at reasoning about and optimizing integer operations. This is done with a new "knownbits" abstract domain. In many programs that do bit-manipulation of integers, some of the bits of the integer variables of the program can be statically known. Here's a simple example: after an assignment x = a | 1, a later check if x & 1: can only ever take its first branch. With the new abstract domain, the JIT can optimize that if-condition to True, because it already knows that the lowest bit of x must be set.
This optimization applies to all Python integers that fit into a machine word (PyPy optimistically picks between two different representations for int, depending on the size of the value). Unfortunately, this change has very little impact on almost all Python code, because intensive bit-manipulation is rare in Python. However, the change leads to significant performance improvements in Pydrofoil (the RPython-based RISC-V/ARM emulators that are automatically generated from high-level Sail specifications of the respective ISAs, and that use the RPython JIT to improve performance).
PyPy versions and speed.pypy.org
The keen-eyed will have noticed no mention of Python version 3.9 in the releases above. Typically we will maintain only one version of Python3, but due to PyPy3.9 support on conda-forge we maintained multiple versions from the first release of PyPy3.10 in PyPy v7.3.12 (Dec 2022). Conda-forge is sunsetting its PyPy support, which means we can drop PyPy3.9. Since that was the major driver of benchmarks at https://speed.pypy.org, we revamped the site to showcase PyPy3.9, PyPy3.10, and various versions of cpython on the home page. For historical reasons, the "baseline" for comparison is still cpython 3.7.19.
We will keep the buildbots building PyPy3.9 until the end of August; these builds will still be available on the nightly builds tab of the buildbot.
What is PyPy?
PyPy is a Python interpreter, a drop-in replacement for CPython. It's fast (PyPy and CPython performance comparison) due to its integrated tracing JIT compiler.
We also welcome developers of other dynamic languages to see what RPython can do for them.
We provide binary builds for:
x86 machines on most common operating systems (Linux 32/64 bits, Mac OS 64 bits, Windows 64 bits)
64-bit ARM machines running Linux (aarch64) and macos (macos_arm64).
PyPy supports Windows 32-bit, Linux PPC64 big- and little-endian, Linux ARM 32 bit, RISC-V RV64IMAFD Linux, and s390x Linux but does not release binaries. Please reach out to us if you wish to sponsor binary releases for those platforms. Downstream packagers provide binary builds for debian, Fedora, conda, OpenBSD, FreeBSD, Gentoo, and more.
What else is new?
For more information about the 7.3.17 release, see the full changelog.
Please update, and continue to help us make pypy better.
Cheers, The PyPy Team
Conda-forge proposes sunsetting support for PyPy
Conda-forge has kindly been providing support for PyPy since 2019. The conda-forge team has been very patient and generous with resources, but it seems the uptake of PyPy has not justified the effort. Major packages still are not available on PyPy, others find it hard to update versions. We don't get much feedback at all about people using PyPy, and even less about PyPy on conda-forge. The conda-forge team has proposed sunsetting PyPy going forward, which means current packages would remain but no new packages would be built. If you have an opinion, you can comment on that PR, or on this blog post.
Since conda-forge supports PyPy3.9 but not PyPy3.10, we have continued releasing PyPy3.9 even though we typically support only one version of PyPy3. With the sunsetting proposal, we will not release any more updates to PyPy3.9. I opened a poll about the intention to drop PyPy3.9. If you have an opinion, please chime in.
A Knownbits Abstract Domain for the Toy Optimizer, Correctly
After Max' introduction to abstract interpretation for the toy optimizer in the last post, I want to present a more complicated abstract domain in this post. This abstract domain reasons about the individual bits of a variable in a trace. Every bit can be either "known zero", "known one" or "unknown". The abstract domain is useful for optimizing integer operations, particularly the bitwise operations. The abstract domain follows quite closely the tristate abstract domain of the eBPF verifier in the Linux Kernel, as described by the paper Sound, Precise, and Fast Abstract Interpretation with Tristate Numbers by Harishankar Vishwanathan, Matan Shachnai, Srinivas Narayana, and Santosh Nagarakatte.
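One way to picture an abstract value of this domain is as a string with one character per bit: 0 for "known zero", 1 for "known one" and ? for "unknown". The set of concrete integers it stands for (its concretization) is then every way of filling in the question marks. A small plain-Python sketch of that view; the string notation and the function names are illustrative only, not the representation used later in the post or in PyPy:

```python
from itertools import product


def concretize(pattern):
    """All non-negative integers described by a pattern like '1?0'.

    '0' = known zero, '1' = known one, '?' = unknown; the leftmost
    character is the most significant bit.
    """
    choices = [("0", "1") if ch == "?" else (ch,) for ch in pattern]
    return sorted(int("".join(bits), 2) for bits in product(*choices))


def contains(pattern, value):
    """Is the non-negative `value` one of the concrete values of `pattern`?"""
    bits = format(value, "0%db" % len(pattern))
    if len(bits) != len(pattern):
        return False  # value needs more bits than the pattern has
    return all(p in ("?", b) for p, b in zip(pattern, bits))


print(concretize("1?0"))  # [4, 6]
print(concretize("?1"))   # [1, 3]: a two-bit value whose lowest bit is known one
print(contains("1?0", 6), contains("1?0", 5))  # True False
```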
The presentation in this post will still be in the context of the toy optimizer. We'll spend a significant part of the post convincing ourselves that the abstract domain transfer functions that we're writing are really correct, using both property-based testing and automated proofs (again using Z3).
PyPy has implemented and merged a more complicated version of the same abstract domain for the "real" PyPy JIT. A more thorough explanation of that real world implementation will follow.
I'd like to thank Max Bernstein and Armin Rigo for lots of great feedback on drafts of this post. The PyPy implementation was mainly done by Nico Rittinghaus and me.
Contents:
- Motivation
- The Knownbits Abstract Domain
- Transfer Functions
- Property-based Tests with Hypothesis
- When are Transfer Functions Correct? How do we test them?
- Implementing Binary Transfer Functions
- Addition and Subtraction
- Proving correctness of the transfer functions with Z3
- Cases where this style of Z3 proof doesn't work
- Making Statements about Precision
- Using the Abstract Domain in the Toy Optimizer for Generalized Constant Folding
- Using the KnownBits Domain for Conditional Peephole Rewrites
- Conclusion
Motivation
In many programs that do bit-manipulation of integers, some of the bits of the integer variables of the program can be statically known. Here's a simple example:
x = a | 1
...
if x & 1:
    ...
else:
    ...
After the assignment x = a | 1, we know that the lowest bit of x must be 1 (the other bits are unknown), and an optimizer could remove the condition x & 1 by constant-folding it to 1.
Another (more complicated) example is:
assert i & 0b111 == 0 # check that i is a multiple of 8
j = i + 16
assert j & 0b111 == 0
This kind of code could happen, e.g., in a CPU emulator, where i and j are integers that represent emulated pointers, and the asserts are alignment checks. The first assert implies that the lowest three bits of i must be 0. Adding 16 to such a number produces a result where the lowest three bits are again all 0, therefore the second assert is always true. So we would like a compiler to remove the second assert.
Both of these optimizations are possible with the help of the knownbits abstract domain that we'll discuss in the rest of the post.
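Both claims can be spot-checked by brute force over a finite range of concrete integers. This is a sketch for intuition only (the helper names are made up for it): sampling a range says nothing about all integers, which is exactly why we want an abstract domain and, later, proofs.

```python
# Brute-force sanity check of the two motivating examples over a finite range.
# This samples concrete values only; it illustrates the claims, it proves nothing.

def lowest_bit_after_or(a: int) -> int:
    x = a | 1
    return x & 1  # the condition the optimizer would like to constant-fold

def alignment_preserved(i: int) -> bool:
    assert i & 0b111 == 0  # precondition from the first assert
    j = i + 16
    return j & 0b111 == 0  # the second assert

# x & 1 is always 1 after x = a | 1, also for negative a
assert all(lowest_bit_after_or(a) == 1 for a in range(-1000, 1000))
# adding 16 to a multiple of 8 keeps the lowest three bits zero
assert all(alignment_preserved(i) for i in range(-8000, 8000, 8))
```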
The Knownbits Abstract Domain
An abstract value of the knownbits domain needs to be able to store, for every bit of an integer variable in a program, whether it is known 0, known 1, or unknown. To represent three different states, we need 2 bits, which we will call one and unknown.
Here's the encoding:
one | unknown | knownbit |
---|---|---|
0 | 0 | 0 |
1 | 0 | 1 |
0 | 1 | ? |
1 | 1 | illegal |
The unknown bit is set if we don't know the value of the bit ("?"); the one bit is set if the bit is known to be a 1. Since two bits are enough to encode four different states but we only need three, the combination of a set one bit and a set unknown bit is not allowed.
We don't just want to encode a single bit, however. Instead, we want to do this for all the bits of an integer variable. Therefore the instances of the abstract domain get two integer fields, ones and unknowns, where each pair of corresponding bits encodes the knowledge about the corresponding bit of the integer variable in the program.
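To make the encoding concrete, here is a small sketch that goes from a pattern like 1?0 to the two integers and back, using the table above for each bit. The helpers encode, decode_bit and decode are hypothetical and work on a fixed number of low bits; the real class works on unbounded Python ints.

```python
# Illustrative sketch of the (one, unknown) encoding for a fixed number of
# low bits. Helper names and the fixed width are assumptions for this example.

def encode(pattern: str) -> tuple[int, int]:
    """Turn a pattern like '1?0' (most significant bit first) into (ones, unknowns)."""
    ones = unknowns = 0
    for c in pattern:
        ones <<= 1
        unknowns <<= 1
        if c == '1':
            ones |= 1
        elif c == '?':
            unknowns |= 1
    return ones, unknowns

def decode_bit(one: int, unknown: int) -> str:
    """Decode a single (one, unknown) bit pair according to the table."""
    if one and unknown:
        raise ValueError("illegal: a bit cannot be both 1 and unknown")
    if unknown:
        return '?'
    return '1' if one else '0'

def decode(ones: int, unknowns: int, width: int) -> str:
    return "".join(decode_bit((ones >> i) & 1, (unknowns >> i) & 1)
                   for i in reversed(range(width)))

ones, unknowns = encode("1?0")
print(bin(ones), bin(unknowns))   # 0b100 0b10
print(decode(ones, unknowns, 3))  # 1?0
```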
We can start implementing a Python class that works like this:
from dataclasses import dataclass

@dataclass(eq=False)
class KnownBits:
    ones : int
    unknowns : int

    def __post_init__(self):
        if isinstance(self.ones, int):
            assert self.is_well_formed()

    def is_well_formed(self):
        # a bit cannot be both 1 and unknown
        return self.ones & self.unknowns == 0

    @staticmethod
    def from_constant(const : int):
        """ Construct a KnownBits corresponding to a constant, where all bits
        are known."""
        return KnownBits(const, 0)

    def is_constant(self):
        """ Check if the KnownBits instance represents a constant. """
        # it's a constant if there are no unknowns
        return self.unknowns == 0
We can also add some convenience properties. Sometimes it is easier to work with an integer where all the known bits are set, or one where the positions of all the known zeros have a set bit:
class KnownBits:
    ...
    @property
    def knowns(self):
        """ return an integer where the known bits are set. """
        # the knowns are just the unknowns, inverted
        return ~self.unknowns

    @property
    def zeros(self):
        """ return an integer where the places that are known zeros have a bit
        set. """
        # it's a 0 if it is known, but not 1
        return self.knowns & ~self.ones
Also, for debugging and for writing tests, we want a way to print the known bits in a human-readable form, and a way to construct a KnownBits instance from a string. It's not important to understand the details of __str__ or from_str for the rest of the post, so feel free to skip the following string conversion code:
class KnownBits:
    ...
    def __repr__(self):
        if self.is_constant():
            return f"KnownBits.from_constant({self.ones})"
        return f"KnownBits({self.ones}, {self.unknowns})"

    def __str__(self):
        res = []
        ones, unknowns = self.ones, self.unknowns
        # construct the string representation right to left
        while 1:
            if not ones and not unknowns:
                break # we leave off the leading known 0s
            if ones == -1 and not unknowns:
                # -1 has all bits set in two's complement, so the leading
                # bits are all 1
                res.append('1')
                res.append("...")
                break
            if unknowns == -1:
                # -1 has all bits set in two's complement, so the leading bits
                # are all ?
                assert not ones
                res.append("?")
                res.append("...")
                break
            if unknowns & 1:
                res.append('?')
            elif ones & 1:
                res.append('1')
            else:
                res.append('0')
            ones >>= 1
            unknowns >>= 1
        if not res:
            res.append('0')
        res.reverse()
        return "".join(res)

    @staticmethod
    def from_str(s):
        """ Construct a KnownBits instance from a string. String can start with
        ...1 to mean that all higher bits are 1, or ...? to mean that all higher
        bits are unknown. Otherwise it is assumed that the higher bits are all
        0. """
        ones, unknowns = 0, 0
        startindex = 0
        if s.startswith("...?"):
            unknowns = -1
            startindex = 4
        elif s.startswith("...1"):
            ones = -1
            startindex = 4
        for index in range(startindex, len(s)):
            ones <<= 1
            unknowns <<= 1
            c = s[index]
            if c == '1':
                ones |= 1
            elif c == '?':
                unknowns |= 1
        return KnownBits(ones, unknowns)

    @staticmethod
    def all_unknown():
        """ convenience constructor for the "all bits unknown" abstract value
        """
        return KnownBits.from_str("...?")
And here's a pytest-style unit test for str:
def test_str():
    assert str(KnownBits.from_constant(0)) == '0'
    assert str(KnownBits.from_constant(5)) == '101'
    assert str(KnownBits(5, 0b10)) == '1?1'
    assert str(KnownBits(~0b1111, 0b10)) == '...100?0'
    assert str(KnownBits(1, ~0b1)) == '...?1'
An instance of KnownBits represents a set of integers, namely those that match the known bits stored in the instance. We can write a method contains that takes a concrete int value and returns True if the value matches the pattern of the known bits:
class KnownBits:
    ...
    def contains(self, value : int):
        """ Check whether the KnownBits instance contains the concrete integer
        `value`. """
        # check whether value matches the bit pattern. in the places where we
        # know the bits, the value must agree with ones.
        return value & self.knowns == self.ones
and a test:
def test_contains():
    k1 = KnownBits.from_str('1?1')
    assert k1.contains(0b111)
    assert k1.contains(0b101)
    assert not k1.contains(0b110)
    assert not k1.contains(0b011)

    k2 = KnownBits.from_str('...?1') # all odd numbers
    for i in range(-101, 100):
        assert k2.contains(i) == (i & 1)
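Going in the other direction, when only finitely many bits are unknown we can list the set that a KnownBits value stands for by trying every assignment to its unknown bits. A sketch on plain (ones, unknowns) pairs; members is a hypothetical helper, and values with leading ...? bits describe infinite sets and are rejected.

```python
# Enumerate the concrete members of a KnownBits-style (ones, unknowns) pair.
# Assumption of this sketch: unknowns is non-negative, i.e. there are only
# finitely many unknown bits; '...?' values describe infinite sets.
from itertools import product

def members(ones: int, unknowns: int) -> list[int]:
    assert unknowns >= 0, "only finitely many unknown bits are supported"
    assert ones & unknowns == 0, "not well-formed"
    positions = [i for i in range(unknowns.bit_length()) if (unknowns >> i) & 1]
    result = []
    # try every assignment of 0/1 to the unknown bit positions
    for choice in product((0, 1), repeat=len(positions)):
        value = ones
        for pos, bit in zip(positions, choice):
            value |= bit << pos
        result.append(value)
    return sorted(result)

def contains(ones: int, unknowns: int, value: int) -> bool:
    # the same check as KnownBits.contains: known bits must agree with ones
    return value & ~unknowns == ones

print(members(0b101, 0b010))  # '1?1' -> [5, 7]
print(members(~0b11, 0b01))   # '...10?' -> [-4, -3]
print([v for v in range(-20, 20) if contains(0b101, 0b010, v)])  # [5, 7]
```

The last line cross-checks the two views: scanning a range with contains finds exactly the enumerated members.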
Transfer Functions
Now that we have implemented the basics of the KnownBits class, we need to start implementing the transfer functions. They compute what we know about the results of an operation, given the knowledge we have about the bits of the arguments.
We'll start with a simple unary operation, invert(x) (which is ~x in Python and C syntax), which flips all the bits of an integer. If we know some bits of the argument, we can compute the corresponding bits of the result. The unknown bits remain unknown.
Here's the code:
class KnownBits:
    ...
    def abstract_invert(self):
        # self.zeros has bits set where the known 0s are in self
        return KnownBits(self.zeros, self.unknowns)
And a unit-test:
def test_invert():
    k1 = KnownBits.from_str('01?01?01?')
    k2 = k1.abstract_invert()
    assert str(k2) == '...10?10?10?'

    k1 = KnownBits.from_str('...?')
    k2 = k1.abstract_invert()
    assert str(k2) == '...?'
Before we continue with further transfer functions, we'll think about the correctness of the transfer functions and build up some test infrastructure. To test transfer functions, it's quite important to move beyond simple example-style unit tests. The state space for more complicated binary transfer functions is extremely large, and it's too easy to do something wrong in a corner case. Therefore we'll look at property-based tests for KnownBits next.
Property-based Tests with Hypothesis
We want to write property-based tests for KnownBits, to make it less likely that we get a corner case of the implementation wrong. We'll use Hypothesis for that.
I can't give a decent introduction to Hypothesis here, but want to give a few hints about the API. Hypothesis is a way to run unit tests with randomly generated input. It provides strategies to describe the data that the test function expects. Hypothesis provides primitive strategies (for things like integers, strings, floats, etc.) and ways to build composite strategies out of the primitive ones.
To be able to write the tests, we need to generate random KnownBits instances, and we also want an int instance that is a member of the KnownBits instance. We generate tuples of (KnownBits, int) together, to ensure this property. We'll ask Hypothesis to generate a random concrete int as the concrete value, and then we'll also generate a second random int to use as the unknowns mask (i.e. which bits of the concrete int we don't know in the KnownBits instance). Here's a function that takes two such ints and builds the tuple:
def build_knownbits_and_contained_number(concrete_value : int, unknowns : int):
    # to construct a valid KnownBits instance, we need to mask off the unknown
    # bits
    ones = concrete_value & ~unknowns
    return KnownBits(ones, unknowns), concrete_value
We can turn this function into a Hypothesis strategy to generate input data using the strategies.builds function:
from hypothesis import strategies, given, settings

ints = strategies.integers()

random_knownbits_and_contained_number = strategies.builds(
    build_knownbits_and_contained_number,
    ints, ints
)
One important special case of KnownBits are the constants, which contain only a single concrete value. We'll also generate some of those specifically, and then combine the random_knownbits_and_contained_number strategy with it:
constant_knownbits = strategies.builds(
    lambda value: (KnownBits.from_constant(value), value),
    ints
)

knownbits_and_contained_number = constant_knownbits | random_knownbits_and_contained_number
Now we can write the first property-based test, for the KnownBits.contains method:
@given(knownbits_and_contained_number)
def test_contains(t):
    k, n = t
    assert k.contains(n)
The @given decorator tells Hypothesis which strategy to use to generate random data for the test function. Hypothesis will run the test with a number of random examples (100 by default). If it finds an error, it will try to shrink the failing input to a minimal example that still demonstrates the problem, to make it easier to understand what is going wrong. It also saves all failing cases into an example database and tries them again on subsequent runs.
This test is as much a check for whether we got the strategies right as it is for the logic in KnownBits.contains. Here's an example output of random concrete and abstract values that we are getting here:
110000011001101 ...?0???1
...1011011 ...1011011
...1001101110101000010010011111011 ...1001101110101000010010011111011
...1001101110101000010010011111011 ...100110111010100001?010?1??1??11
1000001101111101001011010011111101000011000111011001011111101 1000001101111101001011010011111101000011000111011001011111101
1000001101111101001011010011111101000011000111011001011111101 1000001101111101001011010011111101000011000111????01?11?????1
1111100000010 1111100000010
1111100000010 ...?11111?00000??
110110 110110
110110 ...?00?00????11??10
110110 ??0??0
...100010111011111 ...?100?10111??111?
...1000100000110001 ...?000?00000??000?
110000001110 ...?0?0??000?00?0?0000000?00???0000?????00???000?0?00?01?000?0??1??
110000001110 ??000000???0
1011011010000001110101001111000010001001011101010010010001000000010101010010001101110101111111010101010010101100110000011110000 1011011010000001110101001111000010001001011101010010010001000000010101010010001101110101111111010101010010101100110000011110000
...1011010010010100 ...1011010010010100
...1011111110110011 ...1011111110110011
101000011110110 101000011?10?1?
100101 ?00?0?
That looks suitably random, but we might want to bias our random numbers a little bit towards common error values like small constants, powers of two, etc. Like this:
INTEGER_WIDTH = 64
# some small integers
ints_special = set(range(100))
# powers of two
ints_special = ints_special.union(1 << i for i in range(INTEGER_WIDTH - 2))
# powers of two - 1
ints_special = ints_special.union((1 << i) - 1 for i in range(INTEGER_WIDTH - 2))
# negative versions of what we have so far
ints_special = ints_special.union(-x for x in ints_special)
# bit-flipped versions of what we have so far
ints_special = ints_special.union(~x for x in ints_special)
ints_special = list(ints_special)
# sort them (because hypothesis simplifies towards earlier elements in the list)
ints_special.sort(key=lambda element: (abs(element), element < 0))

ints = strategies.sampled_from(ints_special) | strategies.integers()
Now we get data like this:
1110 1110
...10000000000000000001 ...10000??0??0000??00?1
1 ??0??0000??00?1
1 ?
...10101100 ...10101100
110000000011001010111011111111111111011110010001001100110001011 ...?0?101?
110000000011001010111011111111111111011110010001001100110001011 ??00000000??00?0?0???0??????????????0????00?000?00??00??000?0??
...1011111111111111111111111111 ...?11?11??
...1011111111111111111111111111 ...?0??????????????????????????
0 ...?0??????????????????????????
101101 101101
111111111111111111111111111111111111111111111 111111111111111111111111111111111111111111111
10111 10111
...101100 ...1?111011?0
101000 ?001010?0
101000 ?0?000
110010 110010
...100111 ...100111
1111011010010 1111011010010
...1000000000000000000000000000000000000 ...1000000000000000000000000000000000000
We can also write a test that checks that the somewhat tricky logic in __str__ and from_str is correct, by making sure that the two functions round-trip (i.e. converting a KnownBits to a string and then back to a KnownBits instance produces the same abstract value).
@given(knownbits_and_contained_number)
def test_hypothesis_str_roundtrips(t1):
    k1, n1 = t1
    s = str(k1)
    k2 = KnownBits.from_str(s)
    assert k1.ones == k2.ones
    assert k1.unknowns == k2.unknowns
Now let's actually apply this infrastructure to test abstract_invert.
When are Transfer Functions Correct? How do we test them?
Abstract values, i.e. instances of KnownBits, represent sets of concrete values. We want the transfer functions to compute overapproximations of the concrete values. So if we have an arbitrary abstract value k, with a concrete number n that is a member of the abstract value (i.e. k.contains(n) == True), then the result of the concrete operation op(n) must be a member of the result of the abstract operation k.abstract_op() (i.e. k.abstract_op().contains(op(n)) == True).
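Written as a formula, with o and u for the ones and unknowns fields, γ(k) for the set of concrete integers that an abstract value k stands for (exactly the values for which contains returns True), and op♯ for the abstract version of op (this notation is chosen here for illustration and is not used elsewhere in the post):

```latex
\gamma(o, u) = \{\, n \in \mathbb{Z} \mid n \mathbin{\&} {\sim} u = o \,\}

\forall k.\ \forall n \in \gamma(k):\ \mathrm{op}(n) \in \gamma\bigl(\mathrm{op}^{\sharp}(k)\bigr)
\qquad\text{equivalently}\qquad
\mathrm{op}(\gamma(k)) \subseteq \gamma\bigl(\mathrm{op}^{\sharp}(k)\bigr)
```

For binary operations, the same condition has to hold for every pair of a member n1 of k1 and a member n2 of k2.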
Checking the correctness/overapproximation property is a good match for Hypothesis. Here's what the test for abstract_invert looks like:
@given(knownbits_and_contained_number)
def test_hypothesis_invert(t1):
    k1, n1 = t1
    n2 = ~n1 # compute the real result
    k2 = k1.abstract_invert() # compute the abstract result
    assert k2.contains(n2) # the abstract result must contain the real result
This is the only condition needed for abstract_invert to be correct. If abstract_invert fulfils this property for every combination of abstract and concrete value, then abstract_invert is correct. Note, however, that this test does not actually check whether abstract_invert gives us precise results. A correct (but imprecise) implementation of abstract_invert would simply return a completely unknown result, regardless of what is known about the input KnownBits.
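For small bit widths the same correctness condition can also be checked exhaustively instead of randomly, which makes the point about precision tangible: the trivially imprecise "all unknown" invert passes, while a buggy one fails. This is a sketch on plain (ones, unknowns) tuples; the helper names and the restriction to 4 low bits are assumptions, and it complements rather than replaces the Hypothesis tests.

```python
# Exhaustive soundness check for invert over every abstract value whose unknown
# bits all lie in the low WIDTH bits. Restricting to finitely many unknown bits
# is an assumption of this sketch; it keeps every concrete set finite.
from itertools import product

WIDTH = 4

def all_abstract_values(width=WIDTH):
    """Yield (ones, unknowns) for every pattern of '0', '1', '?' of length width."""
    for pattern in product("01?", repeat=width):
        ones = unknowns = 0
        for c in pattern:
            ones, unknowns = ones << 1, unknowns << 1
            if c == "1":
                ones |= 1
            elif c == "?":
                unknowns |= 1
        yield ones, unknowns

def members(ones, unknowns):
    """Yield every concrete value that agrees with ones on the known bits."""
    sub = unknowns
    while True:  # standard trick for iterating over all subsets of a bit mask
        yield ones | sub
        if sub == 0:
            break
        sub = (sub - 1) & unknowns

def contains(k, value):
    ones, unknowns = k
    return value & ~unknowns == ones

def abstract_invert(k):
    ones, unknowns = k
    zeros = ~unknowns & ~ones  # the known 0s become known 1s
    return zeros, unknowns

def abstract_invert_useless(k):
    return 0, -1  # "all bits unknown": sound, but useless

def abstract_invert_buggy(k):
    return k  # forgets to flip the known bits

def is_sound_unary(abstract_op, concrete_op):
    return all(contains(abstract_op(k), concrete_op(n))
               for k in all_abstract_values() for n in members(*k))

print(is_sound_unary(abstract_invert, lambda n: ~n))          # True
print(is_sound_unary(abstract_invert_useless, lambda n: ~n))  # True
print(is_sound_unary(abstract_invert_buggy, lambda n: ~n))    # False
```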
The "proper" CS term for this notion of correctness is soundness. The correctness condition on the transfer functions is usually formalized using a Galois connection. I won't go into any mathematical/technical details here, but wanted to at least mention the terms. I found Martin Kellogg's slides to be quite an approachable introduction to the Galois connection and how to show soundness.
Implementing Binary Transfer Functions
Now we have infrastructure in place for testing transfer functions with random inputs. With that we can start thinking about the more complicated case, that of binary operations. Let's start with the simpler ones, and and or. For and, we can know a 0 bit in the result if either of the input bits is known 0, or we can know a 1 bit in the result if both input bits are known 1. Otherwise the resulting bit is unknown. Let's look at all the combinations:
and
input1: 000111???
input2: 01?01?01?
result: 00001?0??
class KnownBits:
    ...
    def abstract_and(self, other):
        ones = self.ones & other.ones # known ones
        knowns = self.zeros | other.zeros | ones
        return KnownBits(ones, ~knowns)
Here's an example unit-test and a property-based test for and:
def test_and():
    # test all combinations of 0, 1, ? in one example
    k1 = KnownBits.from_str('01?01?01?')
    k2 = KnownBits.from_str('000111???')
    res = k1.abstract_and(k2) # should be: 0...00001?0??
    assert str(res) == "1?0??"

@given(knownbits_and_contained_number, knownbits_and_contained_number)
def test_hypothesis_and(t1, t2):
    k1, n1 = t1
    k2, n2 = t2
    k3 = k1.abstract_and(k2)
    n3 = n1 & n2
    assert k3.contains(n3)
Implementing or is pretty similar. The result is known 1 where either of the inputs is known 1. The result is known 0 where both inputs are known 0, and ? otherwise.
or
input1: 000111???
input2: 01?01?01?
result: 01?111?1?
class KnownBits:
    ...
    def abstract_or(self, other):
        ones = self.ones | other.ones
        zeros = self.zeros & other.zeros
        knowns = ones | zeros
        return KnownBits(ones, ~knowns)
Here's an example unit-test and a property-based test for or:
def test_or():
    k1 = KnownBits.from_str('01?01?01?')
    k2 = KnownBits.from_str('000111???')
    res = k1.abstract_or(k2) # should be: 0...01?111?1?
    assert str(res) == "1?111?1?"

@given(knownbits_and_contained_number, knownbits_and_contained_number)
def test_hypothesis_or(t1, t2):
    k1, n1 = t1
    k2, n2 = t2
    k3 = k1.abstract_or(k2)
    n3 = n1 | n2
    assert k3.contains(n3)
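The two tables can be regenerated mechanically with per-bit three-valued logic, and compared against the bit-parallel formulas for every pair of 3-bit patterns. The helpers and3, or3, per_bit and encode are hypothetical, written only for this cross-check; abstract_and and abstract_or are the formulas from above, rewritten on (ones, unknowns) tuples.

```python
# Cross-check of the and/or tables: per-bit three-valued logic versus the
# bit-parallel formulas from abstract_and/abstract_or, on tuples.
from itertools import product

def and3(a, b):
    if a == "0" or b == "0":
        return "0"  # a known 0 forces the result bit to 0
    if a == "1" and b == "1":
        return "1"
    return "?"

def or3(a, b):
    if a == "1" or b == "1":
        return "1"  # a known 1 forces the result bit to 1
    if a == "0" and b == "0":
        return "0"
    return "?"

def per_bit(op, s1, s2):
    return "".join(op(a, b) for a, b in zip(s1, s2))

def encode(pattern):
    ones = unknowns = 0
    for c in pattern:
        ones, unknowns = ones << 1, unknowns << 1
        if c == "1":
            ones |= 1
        elif c == "?":
            unknowns |= 1
    return ones, unknowns

def abstract_and(k1, k2):
    (o1, u1), (o2, u2) = k1, k2
    zeros1, zeros2 = ~u1 & ~o1, ~u2 & ~o2
    ones = o1 & o2
    return ones, ~(zeros1 | zeros2 | ones)

def abstract_or(k1, k2):
    (o1, u1), (o2, u2) = k1, k2
    zeros1, zeros2 = ~u1 & ~o1, ~u2 & ~o2
    ones = o1 | o2
    return ones, ~(ones | (zeros1 & zeros2))

patterns = ["".join(p) for p in product("01?", repeat=3)]
for s1 in patterns:
    for s2 in patterns:
        assert encode(per_bit(and3, s1, s2)) == abstract_and(encode(s1), encode(s2))
        assert encode(per_bit(or3, s1, s2)) == abstract_or(encode(s1), encode(s2))

print(per_bit(and3, "000111???", "01?01?01?"))  # 00001?0??
print(per_bit(or3, "000111???", "01?01?01?"))   # 01?111?1?
```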
Implementing support for abstract_xor is relatively simple, and left as an exercise :-).
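For readers who want to compare notes, here is one possible solution, sketched on plain (ones, unknowns) tuples rather than on the KnownBits class (not necessarily what PyPy does): a result bit of xor is known only where both input bits are known, and then it is the xor of the two known bits. The sketch also checks soundness exhaustively for 3-bit patterns; encode, members and contains are hypothetical helpers.

```python
# One possible abstract_xor, sketched on (ones, unknowns) tuples, plus an
# exhaustive soundness check over all pairs of 3-bit patterns.
from itertools import product

def abstract_xor(k1, k2):
    (o1, u1), (o2, u2) = k1, k2
    unknowns = u1 | u2            # unknown if either input bit is unknown
    ones = (o1 ^ o2) & ~unknowns  # where both bits are known, xor them
    return ones, unknowns

def encode(pattern):
    """'01?' style pattern (most significant bit first) -> (ones, unknowns)."""
    ones = unknowns = 0
    for c in pattern:
        ones, unknowns = ones << 1, unknowns << 1
        if c == "1":
            ones |= 1
        elif c == "?":
            unknowns |= 1
    return ones, unknowns

def members(ones, unknowns):
    sub = unknowns
    while True:  # all subsets of the unknown bits
        yield ones | sub
        if sub == 0:
            break
        sub = (sub - 1) & unknowns

def contains(k, value):
    ones, unknowns = k
    return value & ~unknowns == ones

def xor_is_sound(width=3):
    patterns = ["".join(p) for p in product("01?", repeat=width)]
    for s1 in patterns:
        for s2 in patterns:
            k1, k2 = encode(s1), encode(s2)
            k3 = abstract_xor(k1, k2)
            if not all(contains(k3, n1 ^ n2)
                       for n1 in members(*k1) for n2 in members(*k2)):
                return False
    return True

print(abstract_xor(encode("01?"), encode("011")))  # (0, 1), i.e. '?'
print(xor_is_sound())  # True
```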
Addition and Subtraction
invert, and, and or are relatively simple transfer functions to write, because they compose over the individual bits of the integers. The arithmetic functions add and sub are significantly harder, because of carries and borrows. Coming up with the formulas for them and gaining an intuitive understanding is quite tricky and involves carefully going through a few examples with pen and paper. When implementing this in PyPy, Nico and I didn't come up with the formulas ourselves, but instead took them from the Tristate Numbers paper. Here's the code, with example tests and Hypothesis tests:
class KnownBits:
    ...
    def abstract_add(self, other):
        sum_ones = self.ones + other.ones
        sum_unknowns = self.unknowns + other.unknowns
        all_carries = sum_ones + sum_unknowns
        ones_carries = all_carries ^ sum_ones
        unknowns = self.unknowns | other.unknowns | ones_carries
        ones = sum_ones & ~unknowns
        return KnownBits(ones, unknowns)

    def abstract_sub(self, other):
        diff_ones = self.ones - other.ones
        val_borrows = (diff_ones + self.unknowns) ^ (diff_ones - other.unknowns)
        unknowns = self.unknowns | other.unknowns | val_borrows
        ones = diff_ones & ~unknowns
        return KnownBits(ones, unknowns)

def test_add():
    k1 = KnownBits.from_str('0?10?10?10')
    k2 = KnownBits.from_str('0???111000')
    res = k1.abstract_add(k2)
    assert str(res) == "?????01?10"

def test_sub():
    k1 = KnownBits.from_str('0?10?10?10')
    k2 = KnownBits.from_str('0???111000')
    res = k1.abstract_sub(k2)
    assert str(res) == "...?11?10"
    k1 = KnownBits.from_str( '...1?10?10?10')
    k2 = KnownBits.from_str('...10000???111000')
    res = k1.abstract_sub(k2)
    assert str(res) == "111?????11?10"

@given(knownbits_and_contained_number, knownbits_and_contained_number)
def test_hypothesis_add(t1, t2):
    k1, n1 = t1
    k2, n2 = t2
    k3 = k1.abstract_add(k2)
    n3 = n1 + n2
    assert k3.contains(n3)

@given(knownbits_and_contained_number, knownbits_and_contained_number)
def test_hypothesis_sub(t1, t2):
    k1, n1 = t1
    k2, n2 = t2
    k3 = k1.abstract_sub(k2)
    n3 = n1 - n2
    assert k3.contains(n3)
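Because carries are where the subtle bugs hide, it can be reassuring to also run the add and sub formulas exhaustively on small inputs. A sketch on (ones, unknowns) tuples, reusing the formulas from above; restricting to abstract values whose bits above the lowest 4 are known 0 is an assumption that keeps every concrete set finite. The deliberately wrong abstract_add_no_carries, introduced only for this comparison, shows that the check does catch ignored carries.

```python
# Exhaustive soundness check of the add/sub formulas over all pairs of 4-bit
# abstract values (bits above WIDTH known 0). Not a proof for all widths.
from itertools import product

WIDTH = 4

def abstract_add(k1, k2):
    (o1, u1), (o2, u2) = k1, k2
    sum_ones = o1 + o2
    sum_unknowns = u1 + u2
    all_carries = sum_ones + sum_unknowns
    ones_carries = all_carries ^ sum_ones
    unknowns = u1 | u2 | ones_carries
    return sum_ones & ~unknowns, unknowns

def abstract_sub(k1, k2):
    (o1, u1), (o2, u2) = k1, k2
    diff_ones = o1 - o2
    val_borrows = (diff_ones + u1) ^ (diff_ones - u2)
    unknowns = u1 | u2 | val_borrows
    return diff_ones & ~unknowns, unknowns

def abstract_add_no_carries(k1, k2):
    # deliberately wrong variant: it ignores carries out of unknown bits
    (o1, u1), (o2, u2) = k1, k2
    unknowns = u1 | u2
    return (o1 + o2) & ~unknowns, unknowns

def all_abstract_values(width=WIDTH):
    for pattern in product("01?", repeat=width):
        ones = unknowns = 0
        for c in pattern:
            ones, unknowns = ones << 1, unknowns << 1
            if c == "1":
                ones |= 1
            elif c == "?":
                unknowns |= 1
        yield ones, unknowns

def members(ones, unknowns):
    sub = unknowns
    while True:  # all subsets of the unknown bits
        yield ones | sub
        if sub == 0:
            break
        sub = (sub - 1) & unknowns

def contains(k, value):
    ones, unknowns = k
    return value & ~unknowns == ones

def is_sound_binary(abstract_op, concrete_op):
    values = list(all_abstract_values())
    return all(contains(abstract_op(k1, k2), concrete_op(n1, n2))
               for k1 in values for k2 in values
               for n1 in members(*k1) for n2 in members(*k2))

print(is_sound_binary(abstract_add, lambda a, b: a + b))             # True
print(is_sound_binary(abstract_sub, lambda a, b: a - b))             # True
print(is_sound_binary(abstract_add_no_carries, lambda a, b: a + b))  # False
```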
Now we are in a pretty good situation, and have implemented abstract versions of a bunch of important arithmetic and bitwise functions. What's also surprising is that the implementations of all of the transfer functions are quite efficient. We didn't have to write loops over the individual bits at all; instead we found closed-form expressions using primitive operations on the underlying integers ones and unknowns. This means that computing the results of abstract operations is cheap, which is important when using the abstract domain in the context of a JIT compiler.
Proving correctness of the transfer functions with Z3
As one can probably tell from my recent posts, I've been thinking about compiler correctness a lot. Getting the transfer functions absolutely correct is really crucial, because a bug in them would lead to miscompilation of Python code when the abstract domain is added to the JIT. While the randomized tests are great, it's still entirely possible for them to miss bugs. The state space for the arguments of a binary transfer function is 3**64 * 3**64 (i.e. 3**128, roughly 10**61), and if only a small part of that contains wrong behaviour, it would be really unlikely for us to find it by chance with random tests. Therefore I was reluctant to merge the PyPy branch that contained the new abstract domain for a long time.
To increase our confidence in the correctness of the transfer functions further, we can use Z3 to prove their correctness, which gives us much stronger guarantees (not 100%, obviously). In this subsection I will show how to do that.
Here's an attempt to do this manually in the Python REPL:
>>>> import z3
>>>> solver = z3.Solver()
>>>> # like last blog post, proof by failing to find counterexamples
>>>> def prove(cond): assert solver.check(z3.Not(cond)) == z3.unsat
>>>>
>>>> # let's set up a z3 bitvector variable for an arbitrary concrete value
>>>> n1 = z3.BitVec('concrete_value', 64)
>>>> n1
concrete_value
>>>> # due to operator overloading we can manipulate z3 formulas
>>>> n2 = ~n1
>>>> n2
~concrete_value
>>>>
>>>> # now z3 bitvector variables for the ones and unknowns fields
>>>> ones = z3.BitVec('abstract_ones', 64)
>>>> unknowns = z3.BitVec('abstract_unknowns', 64)
>>>> # we construct a KnownBits instance with the z3 variables
>>>> k1 = KnownBits(ones, unknowns)
>>>> # due to operator overloading we can call the methods on k1:
>>>> k2 = k1.abstract_invert()
>>>> k2.ones
~abstract_unknowns & ~abstract_ones
>>>> k2.unknowns
abstract_unknowns
>>>> # here's the correctness condition that we want to prove:
>>>> k2.contains(n2)
~concrete_value & ~abstract_unknowns == ~abstract_unknowns & ~abstract_ones
>>>> # let's try
>>>> prove(k2.contains(n2))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 1, in prove
AssertionError
>>>> # it doesn't work! let's look at the counterexample to see why:
>>>> solver.model()
[abstract_unknowns = 0, abstract_ones = 0, concrete_value = 1]
>>>> # we can build a KnownBits instance with the values in the
>>>> # counterexample:
>>>> ~1 # concrete result
-2
>>>> counter_example_k1 = KnownBits(0, 0)
>>>> counter_example_k1
KnownBits.from_constant(0)
>>>> counter_example_k2 = counter_example_k1.abstract_invert()
>>>> counter_example_k2
KnownBits.from_constant(-1)
>>>> # let's check the failing condition
>>>> counter_example_k2.contains(~1)
False
What is the problem here? We didn't tell Z3 that n1 was supposed to be a member of k1. We can add this as a precondition to the solver, and then the proof works:
>>>> solver.add(k1.contains(n1))
>>>> prove(k2.contains(n2)) # works!
This is super cool! It's really a proof about the actual implementation, because we call the implementation methods directly, and due to the operator overloading that Z3 does we can be sure that we are actually checking a formula that corresponds to the Python code. This eliminates one source of errors in formal methods.
Doing the proof manually on the Python REPL is kind of annoying though, and we also would like to make sure that the proofs are re-done when we change the code. What we would really like to do is to write the proofs as unit tests that we can run while developing and in CI. Doing this is possible, and the unit tests that really perform proofs look pleasingly similar to the Hypothesis-based ones.
First we need to set up a bit of infrastructure:
import z3

INTEGER_WIDTH = 64

def BitVec(name):
    return z3.BitVec(name, INTEGER_WIDTH)

def BitVecVal(val):
    return z3.BitVecVal(val, INTEGER_WIDTH)

def z3_setup_variables():
    # instantiate a solver
    solver = z3.Solver()

    # a Z3 variable for the first concrete value
    n1 = BitVec("n1")
    # a KnownBits instance that uses Z3 variables as its ones and unknowns,
    # representing the first abstract value
    k1 = KnownBits(BitVec("n1_ones"), BitVec("n1_unkowns"))
    # add the precondition to the solver that the concrete value n1 must be a
    # member of the abstract value k1
    solver.add(k1.contains(n1))

    # a Z3 variable for the second concrete value
    n2 = BitVec("n2")
    # a KnownBits instance for the second abstract value
    k2 = KnownBits(BitVec("n2_ones"), BitVec("n2_unkowns"))
    # add the precondition linking n2 and k2 to the solver
    solver.add(k2.contains(n2))
    return solver, k1, n1, k2, n2

def prove(cond, solver):
    z3res = solver.check(z3.Not(cond))
    if z3res != z3.unsat:
        assert z3res == z3.sat # can't be timeout, we set no timeout
        # make the model with the counterexample global, to make inspecting the
        # bug easier when running pytest --pdb
        global model
        model = solver.model()
        # Z3 identifies constants by name and sort, so re-creating them here
        # refers to the same variables that z3_setup_variables created
        n1, n2 = BitVec("n1"), BitVec("n2")
        k1 = KnownBits(BitVec("n1_ones"), BitVec("n1_unkowns"))
        k2 = KnownBits(BitVec("n2_ones"), BitVec("n2_unkowns"))
        print(f"n1={model.eval(n1)}, n2={model.eval(n2)}")
        counter_example_k1 = KnownBits(model.eval(k1.ones).as_signed_long(),
                                       model.eval(k1.unknowns).as_signed_long())
        counter_example_k2 = KnownBits(model.eval(k2.ones).as_signed_long(),
                                       model.eval(k2.unknowns).as_signed_long())
        print(f"k1={counter_example_k1}, k2={counter_example_k2}")
        print(f"but {cond=} evaluates to {model.eval(cond)}")
        raise ValueError(solver.model())
And then we can write proof-unit-tests like this:
def test_z3_abstract_invert():
    solver, k1, n1, _, _ = z3_setup_variables()
    k2 = k1.abstract_invert()
    n2 = ~n1
    prove(k2.contains(n2), solver)

def test_z3_abstract_and():
    solver, k1, n1, k2, n2 = z3_setup_variables()
    k3 = k1.abstract_and(k2)
    n3 = n1 & n2
    prove(k3.contains(n3), solver)

def test_z3_abstract_or():
    solver, k1, n1, k2, n2 = z3_setup_variables()
    k3 = k1.abstract_or(k2)
    n3 = n1 | n2
    prove(k3.contains(n3), solver)

def test_z3_abstract_add():
    solver, k1, n1, k2, n2 = z3_setup_variables()
    k3 = k1.abstract_add(k2)
    n3 = n1 + n2
    prove(k3.contains(n3), solver)

def test_z3_abstract_sub():
    solver, k1, n1, k2, n2 = z3_setup_variables()
    k3 = k1.abstract_sub(k2)
    n3 = n1 - n2
    prove(k3.contains(n3), solver)
It's possible to write a bit more Python-metaprogramming-magic and unify the Hypothesis and Z3 tests into the same test definition.[1]
Cases where this style of Z3 proof doesn't work
Unfortunately the approach described in the previous section only works for a very small number of cases. It breaks down as soon as the KnownBits methods that we're calling contain any if conditions (including hidden ones like the short-circuiting and and or in Python). Let's look at an example and implement abstract_eq. eq is supposed to be an operation that compares two integers and returns 0 or 1 if they are different or equal, respectively.
Implementing this in KnownBits looks like this (with example and Hypothesis tests):
class KnownBits:
    ...
    def abstract_eq(self, other):
        # the result is a 0, 1, or ?

        # if they are both the same constant, they must be equal
        if self.is_constant() and other.is_constant() and self.ones == other.ones:
            return KnownBits.from_constant(1)
        # check whether we have known disagreeing bits, then we know the result
        # is 0
        if self._disagrees(other):
            return KnownBits.from_constant(0)
        return KnownBits(0, 1) # an unknown boolean

    def _disagrees(self, other):
        # check whether the bits disagree in any place where both are known
        both_known = self.knowns & other.knowns
        return self.ones & both_known != other.ones & both_known

def test_eq():
    k1 = KnownBits.from_str('...?')
    k2 = KnownBits.from_str('...?')
    assert str(k1.abstract_eq(k2)) == '?'
    k1 = KnownBits.from_constant(10)
    assert str(k1.abstract_eq(k1)) == '1'
    k1 = KnownBits.from_constant(10)
    k2 = KnownBits.from_constant(20)
    assert str(k1.abstract_eq(k2)) == '0'

@given(knownbits_and_contained_number, knownbits_and_contained_number)
def test_hypothesis_eq(t1, t2):
    k1, n1 = t1
    k2, n2 = t2
    k3 = k1.abstract_eq(k2)
    assert k3.contains(int(n1 == n2))
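Since abstract_eq has only three possible results, it is easy to check exhaustively on small inputs, both for soundness and for precision: whenever the result is ?, both 0 and 1 really do occur among the concrete comparisons. A sketch on plain (ones, unknowns) tuples restricted to 3 low bits (an assumption of the sketch); check_eq and the other helpers are hypothetical.

```python
# Exhaustive soundness and precision check of the abstract_eq logic over all
# pairs of 3-bit abstract values, using plain (ones, unknowns) tuples.
from itertools import product

WIDTH = 3

def disagrees(k1, k2):
    (o1, u1), (o2, u2) = k1, k2
    both_known = ~u1 & ~u2
    return o1 & both_known != o2 & both_known

def abstract_eq(k1, k2):
    (o1, u1), (o2, u2) = k1, k2
    if u1 == 0 and u2 == 0 and o1 == o2:
        return 1, 0  # the same constant on both sides: known 1
    if disagrees(k1, k2):
        return 0, 0  # a known bit differs: known 0
    return 0, 1      # unknown boolean

def all_abstract_values(width=WIDTH):
    for pattern in product("01?", repeat=width):
        ones = unknowns = 0
        for c in pattern:
            ones, unknowns = ones << 1, unknowns << 1
            if c == "1":
                ones |= 1
            elif c == "?":
                unknowns |= 1
        yield ones, unknowns

def members(ones, unknowns):
    sub = unknowns
    while True:  # all subsets of the unknown bits
        yield ones | sub
        if sub == 0:
            break
        sub = (sub - 1) & unknowns

def contains(k, value):
    ones, unknowns = k
    return value & ~unknowns == ones

def check_eq():
    values = list(all_abstract_values())
    for k1 in values:
        for k2 in values:
            k3 = abstract_eq(k1, k2)
            results = {int(n1 == n2) for n1 in members(*k1) for n2 in members(*k2)}
            # soundness: every concrete result is contained in k3
            if not all(contains(k3, r) for r in results):
                return False
            # precision: k3 is unknown only if both 0 and 1 actually occur
            if k3 == (0, 1) and results != {0, 1}:
                return False
    return True

print(check_eq())  # True
```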
Trying to do the proof in the same style as before breaks:
>>>> k3 = k1.abstract_eq(k2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "knownbits.py", line 246, in abstract_eq
    if self._disagrees(other):
  File "venv/site-packages/z3/z3.py", line 381, in __bool__
    raise Z3Exception("Symbolic expressions cannot be cast to concrete Boolean values.")
z3.z3types.Z3Exception: Symbolic expressions cannot be cast to concrete Boolean values.
We cannot call abstract_eq on a KnownBits with Z3 variables as fields, because once we hit an if statement, the whole approach of relying on the operator overloading breaks down. Z3 doesn't actually parse the Python code or anything advanced like that; rather, we build an expression only by running the code and letting the Z3 formulas build up.
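To see the mechanism without needing Z3 installed, here is a toy model of it. The Sym class is hypothetical and not Z3's API: its operators return new expression objects instead of values, and turning an expression into a Python bool raises, similar to the Z3Exception in the traceback above. Straight-line code like abstract_invert happily builds a formula, while code with an if fails as soon as the condition is evaluated.

```python
# Toy model of how operator overloading builds formulas (a hypothetical Sym
# class, not Z3's API): operators return new expressions instead of values,
# and converting an expression to a Python bool raises an exception.

def _text(x):
    return x.text if isinstance(x, Sym) else repr(x)

class Sym:
    def __init__(self, text):
        self.text = text

    def __invert__(self):
        return Sym(f"~{self.text}")

    def __and__(self, other):
        return Sym(f"({self.text} & {_text(other)})")

    def __eq__(self, other):
        # returns a symbolic comparison, not True/False
        return Sym(f"({self.text} == {_text(other)})")

    __hash__ = None  # instances are unhashable, since __eq__ is overridden

    def __bool__(self):
        raise TypeError(f"cannot turn symbolic {self.text} into a bool")

def abstract_invert(ones, unknowns):
    # straight-line code: runs fine on symbolic inputs
    zeros = ~unknowns & ~ones
    return zeros, unknowns

def is_constant_flag(ones, unknowns):
    # branching code: the `if` needs a concrete bool
    if unknowns == 0:
        return 1
    return 0

def run_branching_version():
    try:
        return is_constant_flag(Sym("ones"), Sym("unknowns"))
    except TypeError as exc:
        return f"failed: {exc}"

print(abstract_invert(Sym("ones"), Sym("unknowns"))[0].text)  # (~unknowns & ~ones)
print(is_constant_flag(3, 0))   # 1, concrete inputs are fine
print(run_branching_version())  # failed: cannot turn symbolic (unknowns == 0) into a bool
```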
To still prove the correctness of abstract_eq, we need to manually transform the control flow logic of the function into a Z3 formula that uses the z3.If expression, using a small helper function:
def z3_cond(b, trueval=1, falseval=0):
    return z3.If(b, BitVecVal(trueval), BitVecVal(falseval))

def z3_abstract_eq(k1, k2):
    # follow the *logic* of abstract_eq, we can't call it due to the ifs in it
    case1cond = z3.And(k1.is_constant(), k2.is_constant(), k1.ones == k2.ones)
    case2cond = k1._disagrees(k2)

    # ones is 1 in the first case, 0 otherwise
    ones = z3_cond(case1cond, 1, 0)

    # in the first two cases, unknowns is 0, 1 otherwise
    unknowns = z3_cond(z3.Or(case1cond, case2cond), 0, 1)
    return KnownBits(ones, unknowns)

def test_z3_abstract_eq_logic():
    solver, k1, n1, k2, n2 = z3_setup_variables()
    n3 = z3_cond(n1 == n2) # concrete result
    k3 = z3_abstract_eq(k1, k2)
    prove(k3.contains(n3), solver)
This proof works. It is a lot less satisfying than the previous ones, though, because we could have made an error in the manual transcription from Python code to Z3 formulas (there are possibly more heavy-handed approaches where we do this transformation more automatically, using e.g. the ast module to analyze the source code, but that's a much more complicated researchy project). To lessen this problem somewhat, we can factor out the parts of the logic that don't have any conditions into small helper methods (like _disagrees in this example) and use them in the manual conversion of the code to Z3 formulas.[2]
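Another cheap mitigation is to test the manual transcription against the original branching code on concrete inputs. In this sketch a plain-Python cond stands in for z3_cond/z3.If, and abstract_eq_transcribed mirrors the structure of z3_abstract_eq; the value generator, which also produces patterns with leading ...1 and ...? bits, is an assumption. Agreement on samples is evidence, not a proof that the transcription is faithful.

```python
# Cross-check a manual, If-style transcription of abstract_eq against the
# original branching logic, on concrete (ones, unknowns) tuples.
from itertools import product

def disagrees(k1, k2):
    (o1, u1), (o2, u2) = k1, k2
    both_known = ~u1 & ~u2
    return o1 & both_known != o2 & both_known

def abstract_eq(k1, k2):  # the original control flow
    (o1, u1), (o2, u2) = k1, k2
    if u1 == 0 and u2 == 0 and o1 == o2:
        return 1, 0
    if disagrees(k1, k2):
        return 0, 0
    return 0, 1

def cond(b, trueval=1, falseval=0):
    # plain-Python stand-in for z3_cond / z3.If
    return trueval if b else falseval

def abstract_eq_transcribed(k1, k2):  # mirrors the structure of z3_abstract_eq
    (o1, u1), (o2, u2) = k1, k2
    case1cond = u1 == 0 and u2 == 0 and o1 == o2
    case2cond = disagrees(k1, k2)
    ones = cond(case1cond, 1, 0)
    unknowns = cond(case1cond or case2cond, 0, 1)
    return ones, unknowns

def all_abstract_values(width):
    for prefix in "01?":  # the bits above the pattern: all 0, all 1 or all ?
        for pattern in product("01?", repeat=width):
            ones = -1 if prefix == "1" else 0
            unknowns = -1 if prefix == "?" else 0
            for c in pattern:
                ones, unknowns = ones << 1, unknowns << 1
                if c == "1":
                    ones |= 1
                elif c == "?":
                    unknowns |= 1
            yield ones, unknowns

values = list(all_abstract_values(3))
mismatches = [(k1, k2) for k1 in values for k2 in values
              if abstract_eq(k1, k2) != abstract_eq_transcribed(k1, k2)]
print(len(values), len(mismatches))  # 81 0
```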
The final condition that Z3 checks, btw, is this one:
If(n1 == n2, 1, 0) & ~If(Or(And(n1_unkowns == 0, n2_unkowns == 0, n1_ones == n2_ones), n1_ones & ~n1_unkowns & ~n2_unkowns != n2_ones & ~n1_unkowns & ~n2_unkowns), 0, 1) == If(And(n1_unkowns == 0, n2_unkowns == 0, n1_ones == n2_ones), 1, 0)
Making Statements about Precision
So far we have only used Z3 to prove statements about correctness, i.e. that our abstract operations overapproximate what can happen with concrete values. While proving this property is essential if we want to avoid miscompilation, correctness alone is not a very strong constraint on the implementation of our abstract transfer functions. We could simply return KnownBits.all_unknown() for every abstract_* method and the resulting overapproximation would be correct, but useless in practice.
It's much harder to make statements about whether the transfer functions are maximally precise. There are two aspects of precision I want to discuss in this section, however.
The first aspect is that we would really like it if the transfer functions compute the maximally precise results for singleton sets. If all abstract arguments of an operation are constants, i.e. contain only a single concrete element, then we know that the resulting set also has only a single element. We can prove that all our transfer functions have this property:
```python
def test_z3_prove_constant_folding():
    solver, k1, n1, k2, n2 = z3_setup_variables()
    k3 = k1.abstract_invert()
    prove(z3.Implies(k1.is_constant(),
                     k3.is_constant()), solver)

    k3 = k1.abstract_and(k2)
    prove(z3.Implies(z3.And(k1.is_constant(), k2.is_constant()),
                     k3.is_constant()), solver)

    k3 = k1.abstract_or(k2)
    prove(z3.Implies(z3.And(k1.is_constant(), k2.is_constant()),
                     k3.is_constant()), solver)

    k3 = k1.abstract_sub(k2)
    prove(z3.Implies(z3.And(k1.is_constant(), k2.is_constant()),
                     k3.is_constant()), solver)

    k3 = z3_abstract_eq(k1, k2)
    prove(z3.Implies(z3.And(k1.is_constant(), k2.is_constant()),
                     k3.is_constant()), solver)
```
Proving with Z3 that the transfer functions are maximally precise for non-constant arguments seems to be relatively hard. I tried a few completely rigorous approaches and failed. The paper Sound, Precise, and Fast Abstract Interpretation with Tristate Numbers contains an optimality proof for the transfer functions of addition and subtraction, so we can be certain that they are as precise as is possible.
I still want to show an approach for trying to find concrete examples of abstract values that are less precise than they could be, using a combination of Hypothesis and Z3. The idea is to use Hypothesis to pick random abstract values. Then we compute the abstract result using our transfer function. Afterwards we can ask Z3 to find us an abstract result that is better than the one our transfer function produced. If Z3 finds a better abstract result, we have a concrete example of imprecision for our transfer function. Those tests aren't strict proofs, because they rely on generating random abstract values, but they can still be valuable (though not for the transfer functions in this blog post, which are all optimal).
Here is what the code looks like (this is a little bit of bonus content; I won't explain the details and can only hope that the comments are somewhat helpful):
```python
@given(random_knownbits_and_contained_number, random_knownbits_and_contained_number)
@settings(deadline=None)
def test_check_precision(t1, t2):
    k1, n1 = t1
    k2, n2 = t2
    # apply transfer function
    k3 = k1.abstract_add(k2)
    example_res = n1 + n2

    # try to find a better version of k3 with Z3
    solver = z3.Solver()
    solver.set("timeout", 8000)

    var1 = BitVec('v1')
    var2 = BitVec('v2')

    ones = BitVec('ones')
    unknowns = BitVec('unknowns')
    better_k3 = KnownBits(ones, unknowns)
    print(k1, k2, k3)

    # we're trying to find an example for a better k3, so we use check, without
    # negation:
    res = solver.check(z3.And(
        # better_k3 should be a valid knownbits instance
        better_k3.is_well_formed(),
        # it should be better than k3, ie there are known bits in better_k3
        # that we don't have in k3
        better_k3.knowns & ~k3.knowns != 0,
        # now encode the correctness condition for better_k3 with a ForAll:
        # for all concrete values var1 and var2, it must hold that if
        # var1 is in k1 and var2 is in k2 it follows that var1 + var2 is in
        # better_k3
        z3.ForAll(
            [var1, var2],
            z3.Implies(
                z3.And(k1.contains(var1), k2.contains(var2)),
                better_k3.contains(var1 + var2)))))
    # if this query is satisfiable, we have found a better result for the
    # abstract_add
    if res == z3.sat:
        model = solver.model()
        rk3 = KnownBits(model.eval(ones).as_signed_long(), model.eval(unknowns).as_signed_long())
        print("better", rk3)
        assert 0
    if res == z3.unknown:
        print("timeout")
```
It does not actually fail for abstract_add (nor the other abstract functions). To see the test fail, we can add some imprecision to the implementation of abstract_add and watch Hypothesis and Z3 find examples of values that are not optimally precise (for example by unconditionally setting some bits of unknowns in the implementation of abstract_add).
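For very small bitwidths, the same precision question can also be answered without Hypothesis or Z3, by brute force: enumerate every abstract value, compute the tightest KnownBits that covers all concrete results, and compare it with what the transfer function returns. The sketch below is not the post's code. It assumes a hypothetical 3-bit KnownBits stand-in that only implements abstract_and, plus a deliberately imprecise (but still correct) sloppy_and that forgets the lowest bit, to show what a reported imprecision looks like. Exhaustive enumeration only scales to a handful of bits, which is why the real test samples random values instead.

```python
from itertools import product

WIDTH = 3                        # assumption: tiny fixed width, so enumeration is feasible
MASK = (1 << WIDTH) - 1


class KnownBits:
    """Hypothetical fixed-width stand-in for the post's KnownBits class."""

    def __init__(self, ones, unknowns):
        assert ones & unknowns == 0, "a bit cannot be both known-one and unknown"
        self.ones = ones & MASK
        self.unknowns = unknowns & MASK

    def __eq__(self, other):
        return (self.ones, self.unknowns) == (other.ones, other.unknowns)

    def __repr__(self):
        return f"KnownBits(ones=0b{self.ones:0{WIDTH}b}, unknowns=0b{self.unknowns:0{WIDTH}b})"

    def concretes(self):
        # every concrete value whose known bits agree with self.ones
        return [n for n in range(1 << WIDTH) if n & ~self.unknowns & MASK == self.ones]

    def abstract_and(self, other):
        ones = self.ones & other.ones
        # a result bit can only be 1 if it can be 1 in both inputs
        maybe_ones = (self.ones | self.unknowns) & (other.ones | other.unknowns)
        return KnownBits(ones, maybe_ones & ~ones)


def all_knownbits():
    # each bit is independently 0, 1 or unknown: 3**WIDTH abstract values
    return [KnownBits(o, u) for o in range(1 << WIDTH) for u in range(1 << WIDTH) if o & u == 0]


def best_abstraction(values):
    """Tightest KnownBits containing every value of a non-empty iterable."""
    all_ones, any_ones = MASK, 0
    for v in values:
        all_ones &= v
        any_ones |= v
    # set in every value -> known one; set in some but not all -> unknown
    return KnownBits(all_ones, any_ones & ~all_ones)


def find_imprecisions(transfer, concrete_op):
    """List (k1, k2, got, best) for every input pair where transfer is not optimal."""
    bad = []
    for k1, k2 in product(all_knownbits(), repeat=2):
        best = best_abstraction(concrete_op(a, b) & MASK
                                for a in k1.concretes() for b in k2.concretes())
        got = transfer(k1, k2)
        if got != best:
            bad.append((k1, k2, got, best))
    return bad


def sloppy_and(k1, k2):
    # deliberately imprecise (but still correct): forget everything about bit 0
    res = k1.abstract_and(k2)
    return KnownBits(res.ones & ~1, res.unknowns | 1)


print(len(find_imprecisions(KnownBits.abstract_and, lambda a, b: a & b)))
print(find_imprecisions(sloppy_and, lambda a, b: a & b)[0])
```

abstract_and is per-bit, so it is exactly optimal and no imprecision is reported; sloppy_and is flagged already for two constant zeros.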
Using the Abstract Domain in the Toy Optimizer for Generalized Constant Folding
Now after all this work we can finally actually use the knownbits abstract domain in the toy optimizer. The code for this follows Max's intro post about abstract interpretation quite closely.
For completeness' sake, here are the basic infrastructure classes that make up the IR again (they are identical or at least extremely close to the previous toy posts):
```python
from dataclasses import dataclass
from typing import Optional

class Value:
    def find(self):
        raise NotImplementedError("abstract")

@dataclass(eq=False)
class Operation(Value):
    name : str
    args : list[Value]

    forwarded : Optional[Value] = None

    def find(self) -> Value:
        op = self
        while isinstance(op, Operation):
            next = op.forwarded
            if next is None:
                return op
            op = next
        return op

    def arg(self, index):
        return self.args[index].find()

    def make_equal_to(self, value : Value):
        self.find().forwarded = value

@dataclass(eq=False)
class Constant(Value):
    value : object

    def find(self):
        return self

class Block(list):
    def __getattr__(self, opname):
        def wraparg(arg):
            if not isinstance(arg, Value):
                arg = Constant(arg)
            return arg
        def make_op(*args):
            op = Operation(opname,
                [wraparg(arg) for arg in args])
            self.append(op)
            return op
        return make_op

def bb_to_str(l : Block, varprefix : str = "var"):
    def arg_to_str(arg : Value):
        if isinstance(arg, Constant):
            return str(arg.value)
        else:
            return varnames[arg]

    varnames = {}
    res = []
    for index, op in enumerate(l):
        # give the operation a name used while
        # printing:
        var = f"{varprefix}{index}"
        varnames[op] = var
        arguments = ", ".join(
            arg_to_str(op.arg(i))
                for i in range(len(op.args))
        )
        strop = f"{var} = {op.name}({arguments})"
        res.append(strop)
    return "\n".join(res)
```
Now we can write some first tests. The first one simply checks constant folding:
```python
def test_constfold_two_ops():
    bb = Block()
    var0 = bb.getarg(0)
    var1 = bb.int_add(5, 4)
    var2 = bb.int_add(var1, 10)
    var3 = bb.int_add(var2, var0)

    opt_bb = simplify(bb)
    assert bb_to_str(opt_bb, "optvar") == """\
optvar0 = getarg(0)
optvar1 = int_add(19, optvar0)"""
```
Calling the transfer functions on constant KnownBits produces constant results, as we have seen. Therefore "regular" constant folding should hopefully also be achieved by optimizing with the KnownBits abstract domain.
The next two tests are slightly more complicated and can't be optimized by regular constant-folding. They follow the motivating examples from the start of this blog post, a hundred years ago:
```python
def test_constfold_via_knownbits():
    bb = Block()
    var0 = bb.getarg(0)
    var1 = bb.int_or(var0, 1)
    var2 = bb.int_and(var1, 1)
    var3 = bb.dummy(var2)

    opt_bb = simplify(bb)
    assert bb_to_str(opt_bb, "optvar") == """\
optvar0 = getarg(0)
optvar1 = int_or(optvar0, 1)
optvar2 = dummy(1)"""

def test_constfold_alignment_check():
    bb = Block()
    var0 = bb.getarg(0)
    var1 = bb.int_invert(0b111)
    # mask off the lowest three bits, thus var2 is aligned
    var2 = bb.int_and(var0, var1)
    # add 16 to aligned quantity
    var3 = bb.int_add(var2, 16)
    # check alignment of result
    var4 = bb.int_and(var3, 0b111)
    var5 = bb.int_eq(var4, 0)
    # var5 should be const-folded to 1
    var6 = bb.dummy(var5)

    opt_bb = simplify(bb)
    assert bb_to_str(opt_bb, "optvar") == """\
optvar0 = getarg(0)
optvar1 = int_and(optvar0, -8)
optvar2 = int_add(optvar1, 16)
optvar3 = dummy(1)"""
```
Here is simplify, which makes these tests pass:
```python
def unknown_transfer_functions(*abstract_args):
    return KnownBits.all_unknown()


def simplify(bb: Block) -> Block:
    abstract_values = {} # dict mapping Operation to KnownBits

    def knownbits_of(val : Value):
        if isinstance(val, Constant):
            return KnownBits.from_constant(val.value)
        return abstract_values[val]

    opt_bb = Block()
    for op in bb:
        # apply the transfer function on the abstract arguments
        name_without_prefix = op.name.removeprefix("int_")
        method_name = f"abstract_{name_without_prefix}"
        transfer_function = getattr(KnownBits, method_name, unknown_transfer_functions)
        abstract_args = [knownbits_of(arg.find()) for arg in op.args]
        abstract_res = abstract_values[op] = transfer_function(*abstract_args)
        # if the result is a constant, we optimize the operation away and make
        # it equal to the constant result
        if abstract_res.is_constant():
            op.make_equal_to(Constant(abstract_res.ones))
            continue
        # otherwise emit the op
        opt_bb.append(op)
    return opt_bb
```
The code follows the approach from the previous blog post very closely. The only difference is that we apply the transfer function first, to be able to detect whether the abstract domain can tell us that the result has to always be a constant. This code makes all three tests pass.
Using the KnownBits Domain for Conditional Peephole Rewrites
So far we have only used the KnownBits domain to find out that certain operations have to produce a constant. We can also use the KnownBits domain to check whether certain operation rewrites are correct. Let's use one of the examples from the Mining JIT traces for missing optimizations with Z3 post, where Z3 found the inefficiency (x << 4) & -0xf == x << 4 in PyPy JIT traces. We don't have shift operations, but we want to generalize this optimization anyway. The general form of this rewrite is that under some circumstances x & y == x, and we can use the KnownBits domain to detect situations where this must be true.
To understand when x & y == x is true, we can think about individual pairs of bits a and b. If a == 0, then a & b == 0 & b == 0 == a. If b == 1, then a & b == a & 1 == a. So if either a == 0 or b == 1 is true, a & b == a follows. And if one of these two conditions holds at every bit position of x and y, we know that x & y == x.
We can write a method on KnownBits to check for this condition:
```python
class KnownBits:
    ...

    def is_and_identity(self, other):
        """ Return True if n1 & n2 == n1 for all n1 in self and n2 in other.
        (or, equivalently, return True if n1 | n2 == n2)"""
        return self.zeros | other.ones == -1
```
Since my reasoning about this feels ripe for errors, let's check that our understanding is correct with Z3:
```python
def test_prove_is_and_identity():
    solver, k1, n1, k2, n2 = z3_setup_variables()
    prove(z3.Implies(k1.is_and_identity(k2), n1 & n2 == n1), solver)
```
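The Z3 proof covers 64-bit values. As an extra sanity check that needs no solver, a small exhaustive test can enumerate every pair of abstract values at a tiny width and compare is_and_identity with the brute-force answer. This is a sketch using a hypothetical 4-bit KnownBits stand-in (because the width is fixed, the -1 in the post's check becomes MASK). It asserts equality, so it also checks the converse: for this domain the condition is exact, not merely sound, because the bits of a KnownBits value vary independently.

```python
WIDTH = 4                       # assumption: small fixed width, so we can enumerate
MASK = (1 << WIDTH) - 1


class KnownBits:
    """Hypothetical fixed-width stand-in for the post's KnownBits class."""

    def __init__(self, ones, unknowns):
        self.ones, self.unknowns = ones, unknowns

    @property
    def zeros(self):
        # bits that are known to be 0
        return ~(self.ones | self.unknowns) & MASK

    def contains(self, n):
        return n & ~self.unknowns & MASK == self.ones

    def is_and_identity(self, other):
        # fixed-width version of the post's check: -1 becomes MASK
        return self.zeros | other.ones == MASK


def all_knownbits():
    return [KnownBits(o, u) for o in range(1 << WIDTH) for u in range(1 << WIDTH) if o & u == 0]


def check_and_identity():
    values = range(1 << WIDTH)
    for k1 in all_knownbits():
        for k2 in all_knownbits():
            # ground truth: does n1 & n2 == n1 hold for every pair of members?
            always = all(n1 & n2 == n1
                         for n1 in values if k1.contains(n1)
                         for n2 in values if k2.contains(n2))
            # equality checks both soundness and exactness of the shortcut
            assert k1.is_and_identity(k2) == always, (k1.ones, k1.unknowns, k2.ones, k2.unknowns)
    return True


print(check_and_identity())
```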
Now let's use this in the toy optimizer. Here are two tests for this rewrite:
```python
def test_remove_redundant_and():
    bb = Block()
    var0 = bb.getarg(0)
    var1 = bb.int_invert(0b1111)
    # mask off the lowest four bits
    var2 = bb.int_and(var0, var1)
    # applying the same mask again is redundant
    var3 = bb.int_and(var2, var1)
    var4 = bb.dummy(var3)

    opt_bb = simplify(bb)
    assert bb_to_str(opt_bb, "optvar") == """\
optvar0 = getarg(0)
optvar1 = int_and(optvar0, -16)
optvar2 = dummy(optvar1)"""

def test_remove_redundant_and_more_complex():
    bb = Block()
    var0 = bb.getarg(0)
    var1 = bb.getarg(1)
    # var2 has bit pattern ...0000????
    var2 = bb.int_and(var0, 0b1111)
    # var3 has bit pattern ...?1111
    var3 = bb.int_or(var1, 0b1111)
    # var4 is just var2
    var4 = bb.int_and(var2, var3)
    var5 = bb.dummy(var4)

    opt_bb = simplify(bb)
    assert bb_to_str(opt_bb, "optvar") == """\
optvar0 = getarg(0)
optvar1 = getarg(1)
optvar2 = int_and(optvar0, 15)
optvar3 = int_or(optvar1, 15)
optvar4 = dummy(optvar2)"""
```
The first test could also be made to pass by implementing a reassociation optimization that turns (x & c1) & c2 into x & (c1 & c2) and then constant-folds c1 & c2. But here we want to use KnownBits and conditionally rewrite int_and to its first argument. So to make the tests pass, we can change simplify like this:
```python
def simplify(bb: Block) -> Block:
    abstract_values = {} # dict mapping Operation to KnownBits

    def knownbits_of(val : Value):
        ...

    opt_bb = Block()
    for op in bb:
        # apply the transfer function on the abstract arguments
        name_without_prefix = op.name.removeprefix("int_")
        method_name = f"abstract_{name_without_prefix}"
        transfer_function = getattr(KnownBits, method_name, unknown_transfer_functions)
        abstract_args = [knownbits_of(arg.find()) for arg in op.args]
        abstract_res = abstract_values[op] = transfer_function(*abstract_args)
        # if the result is a constant, we optimize the operation away and make
        # it equal to the constant result
        if abstract_res.is_constant():
            op.make_equal_to(Constant(abstract_res.ones))
            continue

        # <<<< new code
        # conditionally rewrite int_and(x, y) to x
        if op.name == "int_and":
            k1, k2 = abstract_args
            if k1.is_and_identity(k2):
                op.make_equal_to(op.arg(0))
                continue
        # >>>> end changes

        opt_bb.append(op)
    return opt_bb
```
And with that, the new tests pass as well. A real implementation would also check the other argument order, but we leave that out for the sake of brevity.
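Checking the other argument order only needs a second call with the operands swapped. A sketch of such a check, assuming a hypothetical 64-bit KnownBits stand-in and a made-up helper and_identity_argument that returns the index of the int_and argument the result is equal to (inside simplify, the returned index would then be used as op.make_equal_to(op.arg(index))):

```python
WIDTH = 64                      # assumption: 64-bit integers
MASK = (1 << WIDTH) - 1


class KnownBits:
    """Hypothetical fixed-width stand-in with just what this sketch needs."""

    def __init__(self, ones, unknowns):
        self.ones, self.unknowns = ones & MASK, unknowns & MASK

    @classmethod
    def from_constant(cls, value):
        return cls(value, 0)

    @classmethod
    def all_unknown(cls):
        return cls(0, MASK)

    @property
    def zeros(self):
        # bits that are known to be 0
        return ~(self.ones | self.unknowns) & MASK

    def is_and_identity(self, other):
        return self.zeros | other.ones == MASK


def and_identity_argument(k1, k2):
    """Hypothetical helper: index of the int_and argument equal to the result, or None."""
    if k1.is_and_identity(k2):
        return 0          # int_and(x, y) -> x
    if k2.is_and_identity(k1):
        return 1          # int_and(x, y) -> y
    return None


x = KnownBits.all_unknown()
minus_one = KnownBits.from_constant(-1)
print(and_identity_argument(x, minus_one))   # int_and(x, -1) -> x
print(and_identity_argument(minus_one, x))   # int_and(-1, x) -> x
print(and_identity_argument(x, x))           # nothing to rewrite
```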
This rewrite also generalizes the rewrites int_and(0, x) -> 0 and int_and(x, -1) -> x. Let's add a test for those:
```python
def test_remove_and_simple():
    bb = Block()
    var0 = bb.getarg(0)
    var1 = bb.getarg(1)
    var2 = bb.int_and(0, var0) # == 0
    var3 = bb.int_invert(var2) # == -1
    var4 = bb.int_and(var1, var3) # == var1
    var5 = bb.dummy(var4)

    opt_bb = simplify(bb)
    assert bb_to_str(opt_bb, "optvar") == """\
optvar0 = getarg(0)
optvar1 = getarg(1)
optvar2 = dummy(optvar1)"""
```
This test just passes. And that's it for this post!
Conclusion
In this post we've seen the implementation, testing and proofs about a 'known bits' abstract domain, as well as its use in the toy optimizer to generalize constant folding, and to implement conditional peephole rewrites.
In the next posts I'll write about the real implementation of a knownbits domain in PyPy's JIT, its combination with the existing interval abstract domain, how to deal with gaining information from conditions in the program, and some loose ends.
Sources:
- Known bits in LLVM
- Tristate numbers for known bits in Linux eBPF
- Sound, Precise, and Fast Abstract Interpretation with Tristate Numbers
- Verifying the Verifier: eBPF Range Analysis Verification
- Bit-Twiddling: Addition with Unknown Bits is a super readable blog post by Dougall J. I've taken the ones and unknowns naming from this post, which I find significantly clearer than value and mask, which the Linux kernel uses.
- Bits, Math and Performance(?), a fantastic blog by Harold Aptroot. There are a lot of relevant posts about known bits, range analysis etc. Harold is also the author of Haroldbot, a website that can be used for bitvector calculations, and also checks bitvector identities.
- Sharpening Constraint Programming approaches for Bit-Vector Theory
- Deriving Abstract Transfer Functions for Analyzing Embedded Software
- Synthesizing Abstract Transformers
- There's a subtlety about the Z3 proofs that I'm sort of glossing over here. Python integers are of arbitrary width, and the KnownBits code is actually carefully written to work for integers of any size. This property is tested by the Hypothesis tests, which don't limit the sizes of the generated random integers. However, the Z3 proofs only check bitvectors of a fixed bitwidth of 64. There are various ways to deal with this situation. For most "real" compilers, the bitwidth of integers would be fixed anyway. Then the components ones and unknowns of the KnownBits class would use the number of bits the corresponding integer variable has, and the Z3 proofs would use the same width. This is what we do in the PyPy JIT. ↩
- The less close connection between implementation and proof for abstract_eq is one of the reasons why it makes sense to do unit-testing in addition to proofs. For a more detailed explanation of why both tests and proofs are good to have, see Jeremy Siek's blog post, as well as the Knuth quote. ↩
Abstract interpretation in the Toy Optimizer
This is a cross-post from Max Bernstein from his excellent blog where he writes about programming languages, compilers, optimizations, virtual machines. He's looking for a (dynamic language runtime or compiler related) job too.
CF Bolz-Tereick wrote some excellent posts in which they introduce a small IR and optimizer and extend it with allocation removal. We also did a live stream together in which we did some more heap optimizations.
In this blog post, I'm going to write a small abstract interpreter for the Toy IR and then show how we can use it to do some simple optimizations. It assumes that you are familiar with the little IR, which I have reproduced unchanged in a GitHub Gist.
Abstract interpretation is a general framework for efficiently computing properties that must be true for all possible executions of a program. It's a widely used approach both in compiler optimizations as well as offline static analysis for finding bugs. I'm writing this post to pave the way for CF's next post on proving abstract interpreters correct for range analysis and known bits analysis inside PyPy.
Before we begin, I want to note a couple of things:
- The Toy IR is in SSA form, which means that every variable is defined exactly once. This means that abstract properties of each variable are easy to track.
- The Toy IR represents a linear trace without control flow, meaning we won't talk about meet/join or fixpoints. They only make sense if the IR has a notion of conditional branches or back edges (loops).
Alright, let's get started.
Welcome to abstract interpretation
Abstract interpretation means a couple of different things to different people. There's rigorous mathematical formalism thanks to Patrick and Radhia Cousot, our favorite power couple, and there's also sketchy hand-wavy stuff like what will follow in this post. In the end, all people are trying to do is reason about program behavior without running it.
In particular, abstract interpretation is an over-approximation of the behavior of a program. Correctly implemented abstract interpreters never lie, but they might be a little bit pessimistic. This is because instead of using real values and running the program---which would produce a concrete result and some real-world behavior---we "run" the program with a parallel universe of abstract values. This abstract run gives us information about all possible runs of the program.1
Abstract values always represent sets of concrete values. Instead of literally storing a set (in the world of integers, for example, it could get pretty big...there are a lot of integers), we group them into a finite number of named subsets.2
Let's learn a little about abstract interpretation with an example program and example abstract domain. Here's the example program:
```
v0 = 1
v1 = 2
v2 = add(v0, v1)
```
And our abstract domain is "is the number positive" (where "positive" means nonnegative, but I wanted to keep the words distinct):
```
        top
       /   \
positive   negative
       \   /
      bottom
```
The special top value means "I don't know" and the special bottom value means "empty set" or "unreachable". The positive and negative values represent the sets of all positive and negative numbers, respectively.
We initialize all the variables v0, v1, and v2 to bottom and then walk our IR, updating our knowledge as we go.
```
# here
v0:bottom = 1
v1:bottom = 2
v2:bottom = add(v0, v1)
```
In order to do that, we have to have transfer functions for each operation. For constants, the transfer function is easy: determine if the constant is positive or negative. For other operations, we have to define a function that takes the abstract values of the operands and returns the abstract value of the result.
In order to be correct, transfer functions for operations have to be compatible with the behavior of their corresponding concrete implementations. You can think of them as having an implicit universal quantifier forall in front of them.
Let's step through the constants at least:
```
v0:positive = 1
v1:positive = 2
# here
v2:bottom = add(v0, v1)
```
Now we need to figure out the transfer function for add. It's kind of tricky right now because we haven't specified our abstract domain very well. I keep saying "numbers", but what kinds of numbers? Integers? Real numbers? Floating point? Some kind of fixed-width bit vector (int8, uint32, ...) like an actual machine "integer"?
For this post, I am going to use the mathematical definition of integer, which means that the values are not bounded in size and therefore do not overflow. Actual hardware memory constraints aside, this is kind of like a Python int.
So let's look at what happens when we add two abstract numbers:
| add | top | positive | negative | bottom |
|---|---|---|---|---|
| top | top | top | top | bottom |
| positive | top | positive | top | bottom |
| negative | top | top | negative | bottom |
| bottom | bottom | bottom | bottom | bottom |
As an example, let's try to add two numbers a and b, where a is positive and b is negative. We don't know anything about their values other than their signs. They could be 5 and -3, where the result is 2, or they could be 1 and -100, where the result is -99. This is why we can't say anything about the result of this operation and have to return top.
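The table translates directly into a transfer function. The sketch below uses plain strings for the lattice elements, and the helper names abstract, leq and check_add_sound are made up for illustration. It also spot-checks the implicit forall mentioned above by trying every pair of integers in a small range; that is evidence of soundness, not a proof, since a finite range cannot cover all integers.

```python
TOP, POSITIVE, NEGATIVE, BOTTOM = "top", "positive", "negative", "bottom"


def abstract(n):
    # abstraction function; "positive" means nonnegative in this post
    return POSITIVE if n >= 0 else NEGATIVE


def leq(a, b):
    """Lattice order: bottom is below everything, top is above everything."""
    return a == b or a == BOTTOM or b == TOP


def add(a, b):
    """Transfer function for add, encoding the table above."""
    if a == BOTTOM or b == BOTTOM:
        return BOTTOM
    if a == b and a != TOP:
        return a          # positive + positive, or negative + negative
    return TOP            # mixed signs, or anything involving top


def check_add_sound(lo=-20, hi=20):
    """The implicit forall, approximated by trying every pair in a finite range."""
    for x in range(lo, hi + 1):
        for y in range(lo, hi + 1):
            assert leq(abstract(x + y), add(abstract(x), abstract(y))), (x, y)
    return True


print(check_add_sound())
print(add(POSITIVE, NEGATIVE))
```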
In short, the table says that we only really know the result of an addition if both operands are positive or both operands are negative. Thankfully, in this example, both operands are known positive. So we can learn something about v2:
```
v0:positive = 1
v1:positive = 2
v2:positive = add(v0, v1)
# here
```
This may not seem useful in isolation, but analyzing more complex programs even with this simple domain may be able to remove checks such as if (v2 < 0) { ... }.
Let's take a look at another example using a sample absval (absolute value) IR operation:
```
v0 = getarg(0)
v1 = getarg(1)
v2 = absval(v0)
v3 = absval(v1)
v4 = add(v2, v3)
v5 = absval(v4)
```
Even though we have no constant/concrete values, we can still learn something about the states of values throughout the program. Since we know that absval always returns a positive number, we learn that v2, v3, and v4 are all positive. This means that we can optimize out the absval operation on v5:
```
v0:top = getarg(0)
v1:top = getarg(1)
v2:positive = absval(v0)
v3:positive = absval(v1)
v4:positive = add(v2, v3)
v5:positive = v4
```
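The whole analysis plus the absval rewrite fits in a few lines. This is a sketch on a hypothetical tuple-based IR ((dest, opname, args), where strings name earlier results and ints are constants), not on the Toy IR classes, and simplify_signs is a made-up name. Instead of union-find it keeps a small alias map so that later uses of a removed absval are routed to its argument.

```python
TOP, POSITIVE, NEGATIVE, BOTTOM = "top", "positive", "negative", "bottom"


def sign_of_const(n):
    return POSITIVE if n >= 0 else NEGATIVE     # "positive" means nonnegative


def add(a, b):
    if BOTTOM in (a, b):
        return BOTTOM
    return a if a == b and a != TOP else TOP


def absval(a):
    # |x| is nonnegative no matter what we know about x
    return BOTTOM if a == BOTTOM else POSITIVE


def getarg(_index):
    return TOP                                  # nothing is known about inputs


TRANSFER = {"add": add, "absval": absval, "getarg": getarg}


def simplify_signs(program):
    """program: list of (dest, opname, args); returns (optimized ops, sign of every dest)."""
    signs, alias, result = {}, {}, []
    for dest, opname, args in program:
        # route uses of removed operations to their replacement
        args = [alias.get(a, a) if isinstance(a, str) else a for a in args]
        abstract_args = [signs[a] if isinstance(a, str) else sign_of_const(a) for a in args]
        signs[dest] = TRANSFER[opname](*abstract_args)
        if opname == "absval" and abstract_args[0] == POSITIVE:
            alias[dest] = args[0]               # absval(v) == v when v >= 0
            continue
        result.append((dest, opname, args))
    return result, signs


program = [
    ("v0", "getarg", [0]),
    ("v1", "getarg", [1]),
    ("v2", "absval", ["v0"]),
    ("v3", "absval", ["v1"]),
    ("v4", "add", ["v2", "v3"]),
    ("v5", "absval", ["v4"]),
]
optimized, signs = simplify_signs(program)
for dest, opname, args in optimized:
    print(f"{dest}:{signs[dest]} = {opname}({', '.join(map(str, args))})")
print("v5 ->", signs["v5"])
```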
Other interesting lattices include:
- Constants (where the middle row is pretty wide)
- Range analysis (bounds on min and max of a number)
- Known bits (using a bitvector representation of a number, which bits are always 0 or 1)
For the rest of this blog post, we are going to do a very limited version of "known bits", called parity. This analysis only tracks the least significant bit of a number, which indicates if it is even or odd.
Parity
The lattice is pretty similar to the positive/negative lattice:
```
   top
  /   \
even   odd
  \   /
  bottom
```
Let's define a data structure to represent this in Python code:
```python
class Parity:
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name
```
And instantiate the members of the lattice:
```python
TOP = Parity("top")
EVEN = Parity("even")
ODD = Parity("odd")
BOTTOM = Parity("bottom")
```
Now let's write a forward flow analysis of a basic block using this lattice. We'll do that by assuming that a method on Parity is defined for each IR operation. For example, Parity.add, Parity.lshift, etc.
```python
def analyze(block: Block) -> None:
    parity = {v: BOTTOM for v in block}

    def parity_of(value):
        if isinstance(value, Constant):
            return Parity.const(value)
        return parity[value]

    for op in block:
        transfer = getattr(Parity, op.name)
        args = [parity_of(arg.find()) for arg in op.args]
        parity[op] = transfer(*args)
```
For every operation, we compute the abstract value---the parity---of the arguments and then call the corresponding method on Parity to get the abstract result.
We need to special case Constants due to a quirk of how the Toy IR is constructed: the constants don't appear in the instruction stream and instead are free-floating.
Let's start by looking at the abstraction function for concrete values---constants:
```python
class Parity:
    # ...
    @staticmethod
    def const(value):
        if value.value % 2 == 0:
            return EVEN
        else:
            return ODD
```
Seems reasonable enough. Let's pause on operations for a moment and consider an example program:
```
v0 = getarg(0)
v1 = getarg(1)
v2 = lshift(v0, 1)
v3 = lshift(v1, 1)
v4 = add(v2, v3)
v5 = dummy(v4)
```
This function (which is admittedly a little contrived) takes two inputs, shifts them left by one bit, adds the results, and then passes the sum into a dummy function, which you can think of as "return" or "escape".
To do some abstract interpretation on this program, we'll need to implement the transfer functions for lshift and add (dummy will just always return TOP). We'll start with add. Remember that adding two even numbers returns an even number, adding two odd numbers returns an even number, and mixing even and odd returns an odd number.
```python
class Parity:
    # ...
    def add(self, other):
        if self is BOTTOM or other is BOTTOM:
            return BOTTOM
        if self is TOP or other is TOP:
            return TOP
        if self is EVEN and other is EVEN:
            return EVEN
        if self is ODD and other is ODD:
            return EVEN
        return ODD
```
We also need to fill in the other cases where the operands are top or bottom. In this case, they are both "contagious"; if either operand is bottom, the result is as well. If neither is bottom but either operand is top, the result is as well.
Now let's look at lshift. Shifting any number left by a non-zero number of bits will always result in an even number, but we need to be careful about the zero case! Shifting by zero doesn't change the number at all. Unfortunately, since our lattice has no notion of zero, we have to over-approximate here:
```python
class Parity:
    # ...
    def lshift(self, other):
        # self << other
        if other is ODD:
            return EVEN
        return TOP
```
This means that we will miss some opportunities to optimize, but it's a tradeoff that's just part of the game. (We could also add more elements to our lattice, but that's a topic for another day.)
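The over-approximation can be spot-checked the same way as any transfer function: compare it with concrete evaluation on a range of small integers. This sketch repeats the post's add and lshift as methods; of_int, leq and check_parity_sound are hypothetical helpers (of_int takes a plain int, while the post's Parity.const takes a Constant). It also shows why the zero case matters: a naive lshift that always answered EVEN would be wrong for a shift by 0.

```python
class Parity:
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name

    def add(self, other):
        if self is BOTTOM or other is BOTTOM:
            return BOTTOM
        if self is TOP or other is TOP:
            return TOP
        if self is EVEN and other is EVEN:
            return EVEN
        if self is ODD and other is ODD:
            return EVEN
        return ODD

    def lshift(self, other):
        # only an odd (hence non-zero) shift amount guarantees an even result
        if other is ODD:
            return EVEN
        return TOP


TOP, EVEN, ODD, BOTTOM = Parity("top"), Parity("even"), Parity("odd"), Parity("bottom")


def of_int(n):
    return EVEN if n % 2 == 0 else ODD


def leq(a, b):
    return a is b or a is BOTTOM or b is TOP


def check_parity_sound(lo=-16, hi=16, max_shift=6):
    for x in range(lo, hi + 1):
        for y in range(lo, hi + 1):
            assert leq(of_int(x + y), of_int(x).add(of_int(y))), ("add", x, y)
        for s in range(max_shift + 1):
            assert leq(of_int(x << s), of_int(x).lshift(of_int(s))), ("lshift", x, s)
    return True


print(check_parity_sound())
# a naive lshift that always answered EVEN would be wrong for a shift by 0:
print(of_int(3 << 0), leq(of_int(3 << 0), EVEN))
```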
Now, if we run our abstract interpretation, we'll collect some interesting properties about the program. If we temporarily hack on the internals of bb_to_str, we can print out parity information alongside the IR operations:
```
v0:top = getarg(0)
v1:top = getarg(1)
v2:even = lshift(v0, 1)
v3:even = lshift(v1, 1)
v4:even = add(v2, v3)
v5:top = dummy(v4)
```
This is pretty awesome, because we can see that v4, the result of the addition, is always even. Maybe we can do something with that information.
Optimization
One way that a program might check if a number is odd is by checking the least significant bit. This is a common pattern in C code, where you might see code like y = x & 1. Let's introduce a bitand IR operation that acts like the & operator in C/Python. Here is an example of its use in our program:
```
v0 = getarg(0)
v1 = getarg(1)
v2 = lshift(v0, 1)
v3 = lshift(v1, 1)
v4 = add(v2, v3)
v5 = bitand(v4, 1)  # new!
v6 = dummy(v5)
```
We'll hold off on implementing the transfer function for it---that's left as an exercise for the reader---and instead do something different.
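For readers who want to check their answer to the exercise, here is one possible transfer function, a sketch rather than the post's code. It relies on the fact that the lowest bit of x & y is 0 whenever it is 0 in either operand. It is written as a standalone function with the hypothetical name parity_bitand to keep the block self-contained; in the post's setup it would be a bitand method on Parity.

```python
class Parity:
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


TOP, EVEN, ODD, BOTTOM = Parity("top"), Parity("even"), Parity("odd"), Parity("bottom")


def parity_bitand(a, b):
    """Transfer function for bitand on the parity lattice (hypothetical name)."""
    if a is BOTTOM or b is BOTTOM:
        return BOTTOM
    if a is EVEN or b is EVEN:
        return EVEN      # lowest bit is 0 in one operand, so it is 0 in the result
    if a is ODD and b is ODD:
        return ODD       # 1 & 1 == 1
    return TOP           # at least one operand has an unknown lowest bit


def of_int(n):
    return EVEN if n % 2 == 0 else ODD


# spot-check against concrete & on a small range (exact for known parities)
for x in range(-16, 17):
    for y in range(-16, 17):
        assert parity_bitand(of_int(x), of_int(y)) is of_int(x & y)

print(parity_bitand(EVEN, ODD), parity_bitand(TOP, EVEN), parity_bitand(ODD, TOP))
```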
Instead, we'll see if we can optimize operations of the form bitand(X, 1). If we statically know the parity as a result of abstract interpretation, we can replace the bitand with a constant 0 or 1.
We'll first modify the analyze function (and rename it) to return a new Block containing optimized instructions:
```python
def simplify(block: Block) -> Block:
    parity = {v: BOTTOM for v in block}

    def parity_of(value):
        if isinstance(value, Constant):
            return Parity.const(value)
        return parity[value]

    result = Block()
    for op in block:
        # TODO: Optimize op

        # Emit
        result.append(op)
        # Analyze
        transfer = getattr(Parity, op.name)
        args = [parity_of(arg.find()) for arg in op.args]
        parity[op] = transfer(*args)
    return result
```
We're approaching this the way that PyPy does things under the hood, which is all in roughly a single pass. It tries to optimize an instruction away, and if it can't, it copies it into the new block.
Now let's add in the bitand optimization. It's mostly some gross-looking pattern matching that checks if the right hand side of a bitwise and operation is 1 (TODO: the left hand side, too). CF had some neat ideas on how to make this more ergonomic, which I might save for later.3 Then, if we know the parity, optimize the bitand into a constant.
```python
def simplify(block: Block) -> Block:
    parity = {v: BOTTOM for v in block}

    def parity_of(value):
        if isinstance(value, Constant):
            return Parity.const(value)
        return parity[value]

    result = Block()
    for op in block:
        # Try to simplify
        if isinstance(op, Operation) and op.name == "bitand":
            arg = op.arg(0)
            mask = op.arg(1)
            if isinstance(mask, Constant) and mask.value == 1:
                if parity_of(arg) is EVEN:
                    op.make_equal_to(Constant(0))
                    continue
                elif parity_of(arg) is ODD:
                    op.make_equal_to(Constant(1))
                    continue
        # Emit
        result.append(op)
        # Analyze
        transfer = getattr(Parity, op.name)
        args = [parity_of(arg.find()) for arg in op.args]
        parity[op] = transfer(*args)
    return result
```
Remember: because we use union-find to rewrite instructions in the optimizer (make_equal_to), later uses of the same instruction get the new optimized version "for free" (find).
Let's see how it works on our IR:
```
v0 = getarg(0)
v1 = getarg(1)
v2 = lshift(v0, 1)
v3 = lshift(v1, 1)
v4 = add(v2, v3)
v6 = dummy(0)
```
Hey, neat! bitand disappeared and the argument to dummy is now the constant 0 because we know the lowest bit.
Wrapping up
Hopefully you have gained a little bit of an intuitive understanding of abstract interpretation. Last year, being able to write some code made me more comfortable with the math. Now being more comfortable with the math is helping me write the code. It's a nice upward spiral.
The two abstract domains we used in this post are simple and not very useful in practice, but it's possible to get very far using slightly more complicated abstract domains. Common domains include: constant propagation, type inference, range analysis, effect inference, liveness, etc. For example, here is a sample lattice for constant propagation:
It has multiple levels to indicate more and less precision. For example, you might learn that a variable is either 1 or 2 and be able to encode that as nonnegative instead of just going straight to top.
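Such intermediate levels pay off once there is control flow, where two facts have to be merged (joined) into one; this post deliberately skips joins, so the following is only a sketch. It assumes a hypothetical lattice (not necessarily the one in the figure) with bottom below every Const(n), each Const(n) below nonnegative or negative depending on its sign, and both of those below top.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Const:
    value: int


BOTTOM, NONNEGATIVE, NEGATIVE, TOP = "bottom", "nonnegative", "negative", "top"


def sign_class(v):
    """Lift a constant to its sign level; other elements stay as they are."""
    if isinstance(v, Const):
        return NONNEGATIVE if v.value >= 0 else NEGATIVE
    return v


def join(a, b):
    """Least upper bound: the most precise element that covers both a and b."""
    if a == BOTTOM:
        return b
    if b == BOTTOM:
        return a
    if a == b:
        return a
    sa, sb = sign_class(a), sign_class(b)
    return sa if sa == sb else TOP


print(join(Const(1), Const(2)))    # the variable is either 1 or 2
print(join(Const(1), Const(1)))
print(join(Const(1), Const(-5)))
```

Joining 1 and 2 gives nonnegative rather than top, which is exactly the extra precision the multi-level lattice buys.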
Check out some real-world abstract interpretation in open source projects:
- Known bits in LLVM
- Constant range in LLVM
- But I am told that the ranges don't form a lattice (see Interval Analysis and Machine Arithmetic: Why Signedness Ignorance Is Bliss)
- Tristate numbers for known bits in Linux eBPF
- Range analysis in Linux eBPF
- GDB prologue analysis of assembly to understand the stack and find frame pointers without using DWARF (some docs)
If you have some readable examples, please share them so I can add.
Acknowledgements
Thank you to CF Bolz-Tereick for the toy optimizer and helping edit this post!
- In the words of abstract interpretation researchers Vincent Laviron and Francesco Logozzo in their paper Refining Abstract Interpretation-based Static Analyses with Hints (APLAS 2009):
The three main elements of an abstract interpretation are: (i) the abstract elements ("which properties am I interested in?"); (ii) the abstract transfer functions ("which is the abstract semantics of basic statements?"); and (iii) the abstract operations ("how do I combine the abstract elements?").
We don't have any of these "abstract operations" in this post because there's no control flow but you can read about them elsewhere! ↩
- These abstract values are arranged in a lattice, a mathematical structure with several properties. The most important ones are that it has a top, a bottom, a partial order, and a meet operation, and that values can only move in one direction on the lattice.
Using abstract values from a lattice promises two things:
- The analysis will terminate
- The analysis will be correct for any run of the program, not just one sample run
-
Something about __match_args__ and @property...
Mining JIT traces for missing optimizations with Z3
In my last post I described how to use Z3 to find simple local peephole optimization patterns for the integer operations in PyPy's JIT. An example is int_and(x, 0) -> 0. In this post I want to scale up the problem of identifying possible optimizations to much bigger instruction sequences, also using Z3. For that, I am starting with the JIT traces of real benchmarks, after they have been optimized by the optimizer of PyPy's JIT. Then we can ask Z3 to find inefficient integer operations in those traces.
Starting from the optimized traces of real programs has some big advantages over the "classical" superoptimization approach of generating and then trying all possible sequences of instructions. It avoids the combinatorial explosion that happens with the latter approach. Also, starting from the traces of benchmarks or (even better) actual programs makes sure that we actually care about the missing optimizations that are found in this way. And because the traces are analyzed after they have been optimized by PyPy's optimizer, we only get reports for missing optimizations that the JIT isn't able to do (yet).
The techniques and experiments I describe in this post are again the result of a bunch of discussions with John Regehr at a conference a few weeks ago, as well as reading his blog posts and papers. Thanks John! Also thanks to Max Bernstein for super helpful feedback on the drafts of this blog post (and for poking me to write things in general).
High-Level Approach
The approach that I took works as follows:
- Run benchmarks or other interesting programs and then dump the IR of the JIT traces into a file. At that point, the traces have already been optimized by the PyPy JIT's optimizer.
- For every trace, ignore all the operations on non-integer variables.
- Translate every integer operation into a Z3 formula.
- For every operation, use Z3 to find out whether the operation is redundant (how that is done is described below).
- If the operation is redundant, the trace is less efficient than it could have been, because the optimizer could also have removed the operation. Report the inefficiency.
- Minimize the inefficient programs by removing as many operations as possible to make the problem easier to understand.
In the post I will describe the details and show some pseudocode of the approach. I'll also make the proper code public eventually (but it needs a healthy dose of cleanups first).
Dumping PyPy Traces
PyPy will write its JIT traces into the file out if the environment variable PYPYLOG is set as follows:
PYPYLOG=jit-log-opt:out pypy <program.py>
This environment variable works for PyPy, but also for other virtual machines built with RPython.
(This is really a side point for the rest of the blog post, but since the question came up I wanted to clarify it: Operations on integers in the Python program that the JIT is running don't all correspond 1-to-1 with the int_... operations in the traces. The int_... trace operations always operate on machine words. The Python int type supports arbitrarily large integers. PyPy will optimistically try to lower the operations on Python integers into machine word operations, but adds the necessary guards into the trace to make sure that overflow outside of the range of machine words is caught. In case one of these guards fails, the interpreter switches to a big integer heap-allocated representation.)
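As an illustration of what such a guard has to catch, here is a small Python model of a machine-word addition followed by an overflow check. It's a sketch assuming 64-bit words; add_with_overflow_check and wrap are hypothetical names, not PyPy or RPython functions, and the real JIT expresses the check as a guard in the trace rather than returning a flag.

```python
WORD_BITS = 64  # assumption: 64-bit machine words
MIN_WORD = -(1 << (WORD_BITS - 1))
MAX_WORD = (1 << (WORD_BITS - 1)) - 1


def wrap(value):
    """Reduce an arbitrary Python int to a signed machine word (two's complement)."""
    value &= (1 << WORD_BITS) - 1
    return value - (1 << WORD_BITS) if value > MAX_WORD else value


def add_with_overflow_check(a, b):
    """Hypothetical model of a machine-word add plus an overflow guard.

    Returns (result, overflowed). If overflowed is True, the guard would fail
    and execution would continue with a heap-allocated big integer instead.
    """
    exact = a + b         # Python ints never overflow
    result = wrap(exact)  # what the machine instruction produces
    return result, result != exact


print(add_with_overflow_check(1, 2))         # (3, False)
print(add_with_overflow_check(MAX_WORD, 1))  # (-9223372036854775808, True)
```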
Encoding Traces as Z3 formulas
The last blog post already contained the code to encode the results of individual trace operations into Z3 formulas, so we don't need to repeat that here. To encode traces of operations we introduce a Z3 variable for every operation in the trace and then call the z3_expression function for every single one of the operations in the trace.
For example, for the following trace:
[i1]
i2 = uint_rshift(i1, 32)
i3 = int_and(i2, 65535)
i4 = uint_rshift(i1, 48)
i5 = int_lshift(i4, 16)
i6 = int_or(i5, i3)
jump(i6, i2) # equal
We would get the Z3 formula:
z3.And(i2 == z3.LShR(i1, 32), i3 == i2 & 65535, i4 == z3.LShR(i1, 48), i5 == i4 << 16, i6 == i5 | i3)
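The "# equal" comment on the jump presumably marks that i6 always equals i2: the trace takes the upper 32 bits of i1 apart into two 16-bit halves and puts them back together. A quick way to convince yourself, sketched below in plain Python rather than with Z3, is to evaluate the trace on random 64-bit inputs. run_example_trace is a made-up helper, and values are kept as unsigned Python ints masked to 64 bits.

```python
import random

MASK64 = (1 << 64) - 1  # assumption: 64-bit machine words, held as unsigned ints


def run_example_trace(i1):
    """Evaluate the example trace on one concrete input; return (i2, i6)."""
    i1 &= MASK64
    i2 = i1 >> 32             # uint_rshift(i1, 32)
    i3 = i2 & 65535           # int_and(i2, 65535)
    i4 = i1 >> 48             # uint_rshift(i1, 48)
    i5 = (i4 << 16) & MASK64  # int_lshift(i4, 16)
    i6 = i5 | i3              # int_or(i5, i3)
    return i2, i6


# Random testing can only refute the claim i6 == i2; unlike Z3 it cannot
# prove it for all 2**64 inputs.
for _ in range(10_000):
    i2, i6 = run_example_trace(random.getrandbits(64))
    assert i6 == i2
```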
Usually we won't ask for the formula of the whole trace at once. Instead we go through the trace operation by operation and try to find inefficiencies in the current one we are looking at. Roughly like this (pseudo-)code:
import z3

INTEGER_WIDTH = 64  # machine word size

def newz3var(name):
    return z3.BitVec(name, INTEGER_WIDTH)

def find_inefficiencies(trace):
    solver = z3.Solver()
    var_to_z3var = {}
    for input_argument in trace.inputargs:
        var_to_z3var[input_argument] = newz3var(input_argument)
    for op in trace:
        var_to_z3var[op] = z3resultvar = newz3var(op.resultvarname)
        arg0 = op.args[0]
        z3arg0 = var_to_z3var[arg0]
        if len(op.args) == 2:
            arg1 = op.args[1]
            z3arg1 = var_to_z3var[arg1]
        else:
            z3arg1 = None
        res, valid_if = z3_expression(op.name, z3arg0, z3arg1)
        # checking for inefficiencies, see the next sections
        ...
        if ...:
            return "inefficient", op
        # not inefficient, assert op into the solver and continue with the next op
        solver.add(z3resultvar == res)
    return None # no inefficiency found
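The pseudocode assumes a trace object with inputargs that can be iterated over, yielding operations with name, args and resultvarname attributes. Here is one hypothetical way to build such objects from the simplified textual format of the example trace above. The Op/Trace classes and parse_trace are illustrative assumptions, not PyPy's API, and the real PYPYLOG output may not match this format exactly. In this sketch, arguments are variable names (strings) or integer constants, so var_to_z3var would be keyed by names.

```python
import re
from dataclasses import dataclass
from typing import List, Optional, Union

Arg = Union[str, int]  # a variable name like "i1" or an integer constant like 32


@dataclass
class Op:
    resultvarname: Optional[str]  # None for operations without a result, like jump
    name: str
    args: List[Arg]


@dataclass
class Trace:
    inputargs: List[str]
    ops: List[Op]

    def __iter__(self):
        # lets the pseudocode write `for op in trace`
        return iter(self.ops)


# matches "i2 = uint_rshift(i1, 32)" as well as "jump(i6, i2)"
OP_RE = re.compile(r"^(?:(\w+) = )?(\w+)\((.*)\)$")


def parse_arg(text: str) -> Arg:
    text = text.strip()
    return int(text) if re.fullmatch(r"-?\d+", text) else text


def parse_trace(text: str) -> Trace:
    # drop comments like "# equal" and empty lines
    lines = [line.split("#")[0].strip() for line in text.splitlines()]
    lines = [line for line in lines if line]
    # the first line lists the input arguments, e.g. "[i1]"
    inputargs = [a.strip() for a in lines[0].strip("[]").split(",") if a.strip()]
    ops = []
    for line in lines[1:]:
        match = OP_RE.match(line)
        if match is None:
            raise ValueError(f"cannot parse trace line: {line!r}")
        result, name, args = match.groups()
        ops.append(Op(result, name, [parse_arg(a) for a in args.split(",") if a.strip()]))
    return Trace(inputargs, ops)
```

Calling parse_trace on the example trace yields inputargs ["i1"] and six operations, the first being Op("i2", "uint_rshift", ["i1", 32]).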
Identifying constant booleans with Z3
To get started finding inefficiencies in a trace, we can first focus on boolean variables. For every operation in the trace that returns a bool we can ask Z3 to prove that this variable must be always True or always False. Most of the time, neither of these proofs will succeed. But if Z3 manages to prove one of them, we have found an inefficiency: instead of computing the boolean result (eg by executing a comparison), the JIT's optimizer could have replaced the operation with the corresponding boolean constant.
Here's an example of an inefficiency found that way: if x < y and y < z are both true, PyPy's JIT could conclude that x < z must also be true. However, currently the JIT cannot make that conclusion because it only reasons about the concrete ranges (lower and upper bounds) for every integer variable, but it has no way to remember anything about relationships between different variables. This kind of reasoning would quite often be useful to remove list/string bounds checks. Here's a talk about how LLVM does this (but it might be too heavyweight for a JIT setting).
Here are some more examples found that way:
- x - 1 == x is always False
- x - (x == -1) == -1 is always False. The pattern x - (x == -1) happens a lot in PyPy's hash computations: To be compatible with the CPython hashes we need to make sure that no object's hash is -1 (CPython uses -1 as an error value on the C level).
Here's pseudo-code for how to implement checking boolean operations for inefficiencies:
def find_inefficiencies(trace):
    ...
    for op in trace:
        ...
        res, valid_if = z3_expression(op.name, z3arg0, z3arg1)
        # check for boolean constant result
        if op.has_boolean_result():
            if prove(solver, res == 0):
                return "inefficient", op, 0
            if prove(solver, res == 1):
                return "inefficient", op, 1
        # checking for other inefficiencies, see the next sections
        ...
        # not inefficient, add op to the solver and continue with the next op
        solver.add(z3resultvar == res)
    return None # no inefficiency found
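Z3's proofs cover all 2⁶⁴ values of a machine word. As a cheap illustration of what the boolean check computes (not how the script does it), one can enumerate every value at a reduced width, 8 bits here, and see whether a boolean expression always gives the same answer. The helpers below are hypothetical; a result that holds at 8 bits doesn't automatically carry over to 64 bits, which is why the real approach uses Z3.

```python
WIDTH = 8  # assumption: small enough to enumerate every value


def to_signed(value, width=WIDTH):
    """Interpret the low `width` bits of value as a two's complement integer."""
    value &= (1 << width) - 1
    return value - (1 << width) if value >= 1 << (width - 1) else value


def constant_bool_result(predicate, width=WIDTH):
    """Return True or False if predicate(x) gives that same answer for every
    signed `width`-bit x, or None if both answers occur."""
    low, high = -(1 << (width - 1)), 1 << (width - 1)
    answers = {bool(predicate(x)) for x in range(low, high)}
    return answers.pop() if len(answers) == 1 else None


# both examples from the list above, with wrap-around arithmetic
print(constant_bool_result(lambda x: to_signed(x - 1) == x))           # False
print(constant_bool_result(lambda x: to_signed(x - (x == -1)) == -1))  # False
print(constant_bool_result(lambda x: x < 0))                           # None
```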
Identifying redundant operations
A more interesting class of redundancy is to try to find two operations in a trace that compute the same result. We can do that by asking Z3 to prove, for each pair of different operations in the trace, that their results are always the same. If a previous operation returns the same result, the JIT could have re-used that result instead of re-computing it, saving time. Doing this search for equivalent operations with Z3 is quadratic in the number of operations, but since traces have a maximum length it is not too bad in practice.
This is the real workhorse of my script so far; it's what finds most of the inefficiencies. Here are a few examples:
- The very first and super useful example the script found is int_eq(b, 1) == b if b is known to be a boolean (ie an integer 0 or 1). I have already implemented this optimization in the JIT.
- Similarly, int_and(b, 1) == b for booleans.
- (x << 4) & -0xf == x << 4
- (((x >> 63) << 1) << 2) >> 3 == x >> 63. In general the JIT is quite bad at optimizing repeated shifts (the infrastructure for doing better with that is already in place, so this will be a relatively easy fix).
- (x & 0xffffffff) | ((x >> 32) << 32) == x. Having the JIT optimize this would maybe require first recognizing that (x >> 32) << 32 can be expressed as a mask: (x & 0xffffffff00000000), and then using (x & c1) | (x & c2) == x & (c1 | c2).
- A commonly occurring pattern is variations of this one: ((x & 1345) ^ 2048) - 2048 == x & 1345 (with different constants, of course). xor is add without carry, and x & 1345 does not have the bit 2048 set. Therefore the ^ 2048 is equivalent to + 2048, which the - 2048 cancels. More generally, if a & b == 0, then a + b == a | b == a ^ b. I don't understand at all why this appears so often in the traces, but I see variations of it a lot. LLVM can optimize this, but GCC can't; thanks to Andrew Pinski for filing the bug!
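The reasoning behind the last pattern can be sanity-checked by brute force at reduced widths. This is not a substitute for Z3's proof at 64 bits, just a quick check. The sketch below, with made-up helper names, enumerates all pairs of 8-bit values for the general rule, and all 16-bit values for the 1345/2048 instance, which is wide enough for those constants.

```python
def disjoint_bits_counterexample(width=8):
    """Search all pairs of `width`-bit values for a counterexample to:
    if a & b == 0, then a + b == a | b == a ^ b (modulo 2**width)."""
    mask = (1 << width) - 1
    for a in range(1 << width):
        for b in range(1 << width):
            if a & b == 0 and not ((a + b) & mask == a | b == a ^ b):
                return a, b
    return None  # the identity holds for every pair


def xor_pattern_holds(c_mask, c_bit, width=16):
    """Check ((x & c_mask) ^ c_bit) - c_bit == x & c_mask for every
    `width`-bit x, with the subtraction wrapping modulo 2**width."""
    mask = (1 << width) - 1
    return all(
        (((x & c_mask) ^ c_bit) - c_bit) & mask == x & c_mask
        for x in range(1 << width)
    )


print(disjoint_bits_counterexample())  # None
print(xor_pattern_holds(1345, 2048))   # True: bit 2048 is not in 1345
print(xor_pattern_holds(0xFF, 1))      # False: 0xFF contains the bit 1
```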
And here's some implementation pseudo-code again:
def find_inefficiencies(trace):
    ...
    for op in trace:
        ...
        res, valid_if = z3_expression(op.name, z3arg0, z3arg1)
        # check for boolean constant result
        ...
        # searching for redundant operations
        for previous_op in trace:
            if previous_op is op:
                break # done, reached the current op
            previous_op_z3var = var_to_z3var[previous_op]
            if prove(solver, previous_op_z3var == res):
                return "inefficient", op, previous_op
        ...
        # more code here later
        ...
        # not inefficient, add op to the solver and continue with the next op
        solver.add(z3resultvar == res)
    return None # no inefficiency found
Synthesizing more complicated constants with exists-forall
To find out whether some integer operations always return a constant result, we can't simply use the same trick as for those operations that return boolean results, because enumerating 2⁶⁴ possible constants and checking them all would take too long. Like in the last post, we can use z3.ForAll to find out whether Z3 can synthesize a constant for the result of an operation for us. If such a constant exists, the JIT could have removed the operation and replaced it with the constant that Z3 provides.
Here are a few examples of inefficiencies found this way:
- (x ^ 1) ^ x == 1 (or, more generally: (x ^ y) ^ x == y)
- if x | y == 0, it follows that x == 0 and y == 0
- if x != MAXINT, then x + 1 > x
Implementing this is actually slightly annoying. The solver.add calls for non-inefficient ops add assertions to the solver, which now confuse the z3.ForAll query. We could remove all assertions from the solver, then do the ForAll query, then add the assertions back. What I ended up doing instead was instantiating a second solver object for the ForAll queries, which remains empty the whole time.
def find_inefficiencies(trace):
    solver = z3.Solver()
    empty_solver = z3.Solver()
    var_to_z3var = {}
    ...
    for op in trace:
        ...
        res, valid_if = z3_expression(op.name, z3arg0, z3arg1)
        # check for boolean constant result
        ...
        # searching for redundant operations
        ...
        # checking for constant results
        constvar = z3.BitVec('find_const', INTEGER_WIDTH)
        condition = z3.ForAll(
            list(var_to_z3var.values()),  # ForAll needs a list of variables
            z3.Implies(
                # Implies takes exactly two arguments, so conjoin the assertions;
                # BoolVal(True) keeps And valid when there are no assertions yet
                z3.And(z3.BoolVal(True), *solver.assertions()),
                res == constvar
            )
        )
        if empty_solver.check(condition) == z3.sat:
            model = empty_solver.model()
            const = model[constvar].as_signed_long()
            return "inefficient", op, const
        # not inefficient, add op to the solver and continue with the next op
        solver.add(z3resultvar == res)
    return None # no inefficiency found
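To see what the exists-forall query asks for, here is a brute-force analogue at a tiny width of 4 bits, where every input can be enumerated. It's a hypothetical sketch, not part of the script: instead of searching over candidate constants it collects all results of the expression (under an optional precondition, playing the role of the solver's assertions) and checks that there is exactly one. The three calls reproduce the examples listed above.

```python
import itertools

WIDTH = 4  # assumption: tiny width so that every input can be enumerated
MASK = (1 << WIDTH) - 1
MAX_SIGNED = (1 << (WIDTH - 1)) - 1


def to_signed(value):
    """Interpret the low WIDTH bits of value as a two's complement integer."""
    value &= MASK
    return value - (1 << WIDTH) if value > MAX_SIGNED else value


def synthesize_constant(expr, nvars, precondition=None):
    """Brute-force analogue of "exists c. forall xs. pre(xs) -> expr(xs) == c".

    Returns c (as an unsigned WIDTH-bit value) or None if no such c exists.
    """
    results = set()
    for xs in itertools.product(range(1 << WIDTH), repeat=nvars):
        if precondition is None or precondition(*xs):
            results.add(int(expr(*xs)) & MASK)
            if len(results) > 1:
                return None  # two different results: not a constant
    # if no input satisfies the precondition, any c would do; report None
    return results.pop() if results else None


print(synthesize_constant(lambda x: (x ^ 1) ^ x, 1))  # 1
print(synthesize_constant(lambda x, y: x, 2, precondition=lambda x, y: x | y == 0))  # 0
print(synthesize_constant(lambda x: to_signed(x + 1) > to_signed(x), 1,
                          precondition=lambda x: to_signed(x) != MAX_SIGNED))  # 1 (True)
```

Without the precondition, the last query returns None, because x + 1 wraps around for the maximum signed value.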
Minimization
Analyzing an inefficiency by hand in the context of a larger trace is quite tedious. Therefore I've implemented a (super inefficient) script to try to make the examples smaller. Here's how that works:
- First throw out all the operations that occur after the inefficient operation in the trace.
- Then we remove all "dead" operations, ie operations that don't have their results used (all the operations that we can analyze with Z3 are without side effects).
- Now we try to remove every guard in the trace one by one and check afterwards whether the resulting trace still has an inefficiency.
- We also try to replace every single operation with a new argument to the trace, to see whether the inefficiency is still present.
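The minimization steps above could be sketched roughly like this. It's a hypothetical reconstruction, not the actual script: ops are simplified to (result, name, args) tuples, guards are modelled as ops without a result, and still_inefficient stands in for re-running the Z3 checks on the smaller trace. Removing an operation whose result is still used leaves that result as a free variable, ie a new input argument of the trace, which corresponds to the last step in the list.

```python
def remove_dead_ops(ops, target):
    """Keep only `target`, the guards, and the ops they (transitively) use.

    ops is a list of (result, name, args) tuples in trace order. Guards are
    modelled as ops whose result is None; all other ops are side-effect free.
    """
    live = {target}
    kept = []
    for result, name, args in reversed(ops):
        if result is None or result in live:
            kept.append((result, name, args))
            live.update(a for a in args if isinstance(a, str))
    kept.reverse()
    return kept


def minimize(ops, target, still_inefficient):
    """Greedily shrink a trace while still_inefficient(ops) stays True."""
    # drop everything after the inefficient op, then the dead ops
    end = next(i for i, (result, _, _) in enumerate(ops) if result == target)
    ops = remove_dead_ops(ops[: end + 1], target)
    changed = True
    while changed:
        changed = False
        for i, (result, _, _) in enumerate(ops):
            if result == target:
                continue
            # Removing an op whose result is still used makes that result
            # a new input argument of the trace.
            candidate = remove_dead_ops(ops[:i] + ops[i + 1:], target)
            if still_inefficient(candidate):
                ops = candidate
                changed = True
                break
    return ops
```

Trying one removal at a time and restarting after every success is simple but quadratic or worse, which fits the "super inefficient" description; a dedicated reducer would be smarter about it.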
The minimization process is sort of inefficient and I should probably be using shrinkray or C-Reduce instead. However, it seems to work well in practice and the runtime isn't too bad.
Results
So far I am using the JIT traces of three programs: 1) Booting Linux on the Pydrofoil RISC-V emulator, 2) booting Linux on the Pydrofoil ARM emulator, and 3) running the PyPy bootstrap process on top of PyPy.
I picked these programs because most Python programs don't contain interesting amounts of integer operations, and the traces of the emulators contain a lot of them. I also used the bootstrap process because I still wanted to try a big Python program and personally care about the runtime of this program a lot.
The script identifies 94 inefficiencies in the traces; a lot of them come from repeating patterns. My next steps will be to manually inspect them all, categorize them, and implement easy optimizations identified that way. I also want a way to sort the examples by execution count in the benchmarks, to get a feeling for which of them are most important.
I didn't investigate the full set of Python benchmarks that PyPy uses yet, because I don't expect them to contain interesting amounts of integer operations, but maybe I am wrong about that? Will have to try eventually.
Conclusion
This was again much easier to do than I would have expected! Given that I had the translation of trace ops to Z3 already in place, it was a matter of about a day's worth of programming to use this infrastructure to find the first problems and minimize them.
Reusing the results of existing operations or replacing operations by constants can be seen as "zero-instruction superoptimization". I'll probably be rather busy for a while to add the missing optimizations identified by my simple script. But later extensions to actually synthesize one or several operations in the attempt to optimize the traces more and find more opportunities should be possible.
Finding inefficiencies in traces with Z3 is significantly less annoying and also less error-prone than just manually inspecting traces and trying to spot optimization opportunities.
Random Notes and Sources
Again, John's blog posts:
- Let’s Work on an LLVM Superoptimizer
- Early Superoptimizer Results
- A Few Synthesizing Superoptimizer Results
- Synthesizing Constants
and papers:
I remembered recently that I had seen the approach of optimizing the traces of a tracing JIT with Z3 a long time ago, as part of the (now long dead, I think) SPUR project. There's a workshop paper from 2010 about this. SPUR was trying to use Z3 built into the actual JIT (as opposed to using Z3 only to find places where the regular optimizers could be improved). In addition to bitvectors, SPUR also used the Z3 support for arrays to model the C# heap and remove redundant stores. This is still another future extension for all the Z3 work I've been doing in the context of the PyPy JIT.