# Proving that 1 + 1 = 10 in Rust

*"There are 10 types of people: those who understand binary, and those who don't."*

I recently read this writeup about using Rust's type system to prove that 1 + 1 = 2, and was inspired to make a version of it with a more efficient representation.
To recap (but really, read that post first if you haven't already), they used the Peano representation of the natural numbers, which is based on the *successor* function $S$:

$$\begin{aligned} 1 & = S(0) \\ 2 & = S(1) = S(S(0)) \\ 3 & = S(2) = S(S(S(0))) \\ & \vdots \end{aligned}$$

In this system, every natural number is defined as a finite number of applications of $S$ to zero. This is simple, but also tremendously inefficient: we need a representation of length $O(n)$ just to represent $n$.

Positional systems, like the familiar decimal (base 10) system, have representations of length just $O(\log n)$.
This is *exponentially* better; we can write down numbers like 1,000,000 without pages and pages of $S(S(S(\ldots$
This extra efficiency should let us do some bigger calculations with our numbers-as-types.

## Single-bit operations

We're going to use binary rather than decimal, to simplify the case work. If we were going to represent binary numbers at runtime, we could do something like this:

```
#[derive(Clone, Copy, Debug)]
enum Bit {
    Zero,
    One,
}
use Bit::*;
```

The simplest piece of arithmetic we can implement is a half adder:
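At runtime, a half adder is just a small `match` that produces a sum bit and a carry bit. This is a minimal sketch: returning the pair in the order `(sum, carry)` is my assumption, and `PartialEq` is derived only so results can be compared.

```rust
// PartialEq/Eq are added (beyond the derives above) so results can be compared.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Bit {
    Zero,
    One,
}
use Bit::*;

/// Adds two bits, returning `(sum, carry)`.
fn half_adder(a: Bit, b: Bit) -> (Bit, Bit) {
    match (a, b) {
        (Zero, Zero) => (Zero, Zero),
        (Zero, One) => (One, Zero),
        (One, Zero) => (One, Zero),
        (One, One) => (Zero, One),
    }
}

fn main() {
    // 1 + 1 = 10 in binary: sum bit Zero, carry bit One.
    println!("{:?}", half_adder(One, One)); // (Zero, One)
}
```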

But we want to do all this at compile time, not at runtime, so we have to lift a few things up to the type level. Values become types:

```
#[derive(Debug, Default)]
struct Zero;
#[derive(Debug, Default)]
struct One;
```

The type-level versions of functions like `half_adder` are traits with associated types:

```
trait HalfAdder {
    type Sum;
    type Carry;
}
```

The "parameters" to the function are passed through the `Self` type.
The "return values" are the associated types `Sum` and `Carry`.
We can "evaluate" the function by writing something like `<(One, One) as HalfAdder>::Sum`.
Note that `(One, One)` is a type here, not a value.

Finally, the type-level version of a `match` block is just different `impl`s:

```
impl HalfAdder for (Zero, Zero) {
    type Sum = Zero;
    type Carry = Zero;
}
impl HalfAdder for (Zero, One) {
    type Sum = One;
    type Carry = Zero;
}
impl HalfAdder for (One, Zero) {
    type Sum = One;
    type Carry = Zero;
}
impl HalfAdder for (One, One) {
    type Sum = Zero;
    type Carry = One;
}
```

Let's test this implementation with some type annotations:
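Something like the following would do it. Each `let` only compiles if the associated type named on the left is the same type as the value on the right, so a successful build is the test itself. The definitions above are restated so the sketch stands alone.

```rust
#[derive(Debug, Default)]
struct Zero;
#[derive(Debug, Default)]
struct One;

trait HalfAdder {
    type Sum;
    type Carry;
}
impl HalfAdder for (Zero, Zero) {
    type Sum = Zero;
    type Carry = Zero;
}
impl HalfAdder for (Zero, One) {
    type Sum = One;
    type Carry = Zero;
}
impl HalfAdder for (One, Zero) {
    type Sum = One;
    type Carry = Zero;
}
impl HalfAdder for (One, One) {
    type Sum = Zero;
    type Carry = One;
}

fn main() {
    // If either annotation were wrong, this would be a type error, not a runtime failure.
    let sum: <(One, One) as HalfAdder>::Sum = Zero;
    let carry: <(One, One) as HalfAdder>::Carry = One;
    println!("{:?} {:?}", sum, carry); // Zero One
}
```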

We can implement a full adder on top of the half adder:
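One way to do it at runtime, assuming the classic construction: add `a` and `b`, add the carry-in to that partial sum, then combine the two carries. The two carries can never both be `One` (a first carry of `One` means the partial sum is `Zero`), so the sum bit of a third half adder can stand in for an OR gate.

```rust
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Bit {
    Zero,
    One,
}
use Bit::*;

/// Adds two bits, returning `(sum, carry)`.
fn half_adder(a: Bit, b: Bit) -> (Bit, Bit) {
    match (a, b) {
        (Zero, Zero) => (Zero, Zero),
        (Zero, One) | (One, Zero) => (One, Zero),
        (One, One) => (Zero, One),
    }
}

/// Adds three bits, returning `(sum, carry)`.
fn full_adder(a: Bit, b: Bit, c: Bit) -> (Bit, Bit) {
    let (s1, c1) = half_adder(a, b);
    let (sum, c2) = half_adder(s1, c);
    // c1 and c2 are never both One (c1 == One forces s1 == Zero, hence c2 == Zero),
    // so the sum bit of a third half adder equals c1 OR c2.
    let (carry, _) = half_adder(c1, c2);
    (sum, carry)
}

fn main() {
    // 1 + 1 + 1 = 11 in binary.
    println!("{:?}", full_adder(One, One, One)); // (One, One)
}
```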

Or, at the type level:
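A sketch of the same construction lifted to types: each intermediate `let` becomes a `where` bound, and each "call" goes through a type alias. The aliases `HalfSum` and `HalfCarry` are hypothetical helpers in the same style as the `RippleSum` alias used later; the half adder is restated in compressed form so the block compiles on its own.

```rust
#[derive(Debug, Default)]
struct Zero;
#[derive(Debug, Default)]
struct One;

trait HalfAdder {
    type Sum;
    type Carry;
}
impl HalfAdder for (Zero, Zero) { type Sum = Zero; type Carry = Zero; }
impl HalfAdder for (Zero, One) { type Sum = One; type Carry = Zero; }
impl HalfAdder for (One, Zero) { type Sum = One; type Carry = Zero; }
impl HalfAdder for (One, One) { type Sum = Zero; type Carry = One; }
// Shorthand for "calling" the half adder.
type HalfSum<A, B> = <(A, B) as HalfAdder>::Sum;
type HalfCarry<A, B> = <(A, B) as HalfAdder>::Carry;

trait FullAdder {
    type Sum;
    type Carry;
}
// One blanket impl: the where clauses require every intermediate half adder to exist.
impl<A, B, C> FullAdder for (A, B, C)
where
    (A, B): HalfAdder,
    (HalfSum<A, B>, C): HalfAdder,
    (HalfCarry<A, B>, HalfCarry<HalfSum<A, B>, C>): HalfAdder,
{
    type Sum = HalfSum<HalfSum<A, B>, C>;
    // The two carries are never both One, so a half adder's Sum acts as OR.
    type Carry = HalfSum<HalfCarry<A, B>, HalfCarry<HalfSum<A, B>, C>>;
}

fn main() {
    let sum: <(One, One, One) as FullAdder>::Sum = One;
    let carry: <(One, One, One) as FullAdder>::Carry = One;
    println!("{:?} {:?}", sum, carry); // One One
}
```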

## Addition

To represent natural numbers at runtime as sequences of bits, we could do this:

```
type Natural = Vec<Bit>;
```

But there's not really a nice equivalent of `Vec` at the type level.
Instead, we can think a little more functionally and use a recursive linked list:

```
#[derive(Clone, Copy, Debug)]
enum Bit {
    Zero,
    One,
}
use Bit::*;

#[derive(Debug)]
enum Natural {
    Nil,
    Cons(Bit, Box<Natural>),
}
use Natural::*;

fn cons(bit: Bit, tail: Natural) -> Natural {
    Cons(bit, Box::new(tail))
}

fn main() {
    let zero = Nil;
    let one = cons(One, Nil);
    let two = cons(Zero, cons(One, Nil));
    let three = cons(One, cons(One, Nil));
    let four = cons(Zero, cons(Zero, cons(One, Nil)));
    println!("{:?}", four);
}
```

The bits are ordered from least to most significant. This is backwards from our usual way of writing numbers, but it makes the implementation a bit easier. Here's a ripple-carry adder:
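A sketch of one over the `Natural` list above. The `ripple_adder` signature (two borrowed numbers plus a carry-in bit, mirroring the `(A, B, C)` shape of the type-level trait) is an assumption, and the half and full adders are restated from the earlier sketches.

```rust
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Bit {
    Zero,
    One,
}
use Bit::*;

#[derive(Debug, PartialEq, Eq)]
enum Natural {
    Nil,
    Cons(Bit, Box<Natural>),
}
use Natural::*;

fn cons(bit: Bit, tail: Natural) -> Natural {
    Cons(bit, Box::new(tail))
}

fn half_adder(a: Bit, b: Bit) -> (Bit, Bit) {
    match (a, b) {
        (Zero, Zero) => (Zero, Zero),
        (Zero, One) | (One, Zero) => (One, Zero),
        (One, One) => (Zero, One),
    }
}

fn full_adder(a: Bit, b: Bit, c: Bit) -> (Bit, Bit) {
    let (s1, c1) = half_adder(a, b);
    let (sum, c2) = half_adder(s1, c);
    let (carry, _) = half_adder(c1, c2); // c1 and c2 are never both One
    (sum, carry)
}

/// Computes a + b + c, with bits stored least significant first.
fn ripple_adder(a: &Natural, b: &Natural, c: Bit) -> Natural {
    match (a, b) {
        // Both numbers exhausted: only the carry bit, if any, is left.
        (Nil, Nil) => match c {
            Zero => Nil,
            One => cons(One, Nil),
        },
        // One number exhausted: its missing bits count as Zero.
        (Nil, Cons(b0, b_rest)) => {
            let (sum, carry) = full_adder(Zero, *b0, c);
            cons(sum, ripple_adder(&Nil, b_rest, carry))
        }
        (Cons(a0, a_rest), Nil) => {
            let (sum, carry) = full_adder(*a0, Zero, c);
            cons(sum, ripple_adder(a_rest, &Nil, carry))
        }
        // Add the low bits, then recurse on the rest with the new carry.
        (Cons(a0, a_rest), Cons(b0, b_rest)) => {
            let (sum, carry) = full_adder(*a0, *b0, c);
            cons(sum, ripple_adder(a_rest, b_rest, carry))
        }
    }
}

fn main() {
    let one = cons(One, Nil);
    // 1 + 1 = 10
    println!("{:?}", ripple_adder(&one, &one, Zero)); // Cons(Zero, Cons(One, Nil))
}
```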

Cons cells are just pairs, so we can use tuple types to represent type-level lists. This is the only part that's easier to read at the type level :)

```
#[derive(Debug, Default)]
struct Nil;
type NatZero = Nil;
type NatOne = (One, Nil);
type Two = (Zero, (One, Nil));
type Three = (One, (One, Nil));
type Four = (Zero, (Zero, (One, Nil)));
```

Like before, let's make a trait for our ripple adder function:

```
/// Given natural numbers A, B, and a carry bit C, computes A + B + C.
trait RippleAdder {
    type Sum;
}
type RippleSum<A, B, C> = <(A, B, C) as RippleAdder>::Sum;
```

And translate the `match` arms into `impl`s:
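Here's one possible set of `impl`s, as a self-contained sketch with the adders restated in compressed form. Two choices here are mine rather than necessarily the post's: the full adder is built from two half adders (`FullSum` and `FullCarry` are hypothetical aliases), and the `(Nil, (B0, B1), C)` case just swaps its operands so only one "one side is exhausted" case does real work. The `main` function checks the title's claim: the annotation compiles only if 1 + 1 really is `(Zero, (One, Nil))`, i.e. binary 10.

```rust
#[derive(Debug, Default)]
struct Zero;
#[derive(Debug, Default)]
struct One;
#[derive(Debug, Default)]
struct Nil;

// Half adder, as before (one line per impl to save space).
trait HalfAdder {
    type Sum;
    type Carry;
}
impl HalfAdder for (Zero, Zero) { type Sum = Zero; type Carry = Zero; }
impl HalfAdder for (Zero, One) { type Sum = One; type Carry = Zero; }
impl HalfAdder for (One, Zero) { type Sum = One; type Carry = Zero; }
impl HalfAdder for (One, One) { type Sum = Zero; type Carry = One; }
type HalfSum<A, B> = <(A, B) as HalfAdder>::Sum;
type HalfCarry<A, B> = <(A, B) as HalfAdder>::Carry;

// Full adder built from two half adders (assumed construction).
trait FullAdder {
    type Sum;
    type Carry;
}
type FullSum<A, B, C> = <(A, B, C) as FullAdder>::Sum;
type FullCarry<A, B, C> = <(A, B, C) as FullAdder>::Carry;
impl<A, B, C> FullAdder for (A, B, C)
where
    (A, B): HalfAdder,
    (HalfSum<A, B>, C): HalfAdder,
    (HalfCarry<A, B>, HalfCarry<HalfSum<A, B>, C>): HalfAdder,
{
    type Sum = HalfSum<HalfSum<A, B>, C>;
    type Carry = HalfSum<HalfCarry<A, B>, HalfCarry<HalfSum<A, B>, C>>;
}

/// Given natural numbers A, B, and a carry bit C, computes A + B + C.
trait RippleAdder {
    type Sum;
}
type RippleSum<A, B, C> = <(A, B, C) as RippleAdder>::Sum;

// Both numbers exhausted: only the carry bit, if any, is left.
impl RippleAdder for (Nil, Nil, Zero) {
    type Sum = Nil;
}
impl RippleAdder for (Nil, Nil, One) {
    type Sum = (One, Nil);
}
// A exhausted: swap the operands and reuse the next case.
impl<B0, B1, C> RippleAdder for (Nil, (B0, B1), C)
where
    ((B0, B1), Nil, C): RippleAdder,
{
    type Sum = RippleSum<(B0, B1), Nil, C>;
}
// B exhausted: its missing bits count as Zero.
impl<A0, A1, C> RippleAdder for ((A0, A1), Nil, C)
where
    (A0, Zero, C): FullAdder,
    (A1, Nil, FullCarry<A0, Zero, C>): RippleAdder,
{
    type Sum = (FullSum<A0, Zero, C>, RippleSum<A1, Nil, FullCarry<A0, Zero, C>>);
}
// General case: add the low bits, then recurse on the rest with the new carry.
impl<A0, A1, B0, B1, C> RippleAdder for ((A0, A1), (B0, B1), C)
where
    (A0, B0, C): FullAdder,
    (A1, B1, FullCarry<A0, B0, C>): RippleAdder,
{
    type Sum = (FullSum<A0, B0, C>, RippleSum<A1, B1, FullCarry<A0, B0, C>>);
}

type NatOne = (One, Nil);

fn main() {
    // 1 + 1 = 10: this line only compiles if the sum is exactly (Zero, (One, Nil)).
    let two: RippleSum<NatOne, NatOne, Zero> = (Zero, (One, Nil));
    println!("{:?}", two); // (Zero, (One, Nil))
}
```

Because the proof is a type annotation, a wrong adder shows up as a compile error rather than a failed assertion.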

## Multiplication

Multiplication has a nice recursive definition too:
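One natural reading of that definition, as a runtime sketch over the `Natural` list: peel off the lowest bit of `b`, multiply `a` by the rest, double the result by consing a `Zero` onto it, and add `a` back if the peeled bit was `One`. In symbols, $a \cdot 0 = 0$, $a \cdot 2b = 2(a \cdot b)$, and $a \cdot (2b + 1) = 2(a \cdot b) + a$. The name `multiply` and the `from_u64`/`to_u64` helpers are hypothetical, and this full adder counts `One` bits instead of chaining half adders, just to keep the block short.

```rust
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Bit {
    Zero,
    One,
}
use Bit::*;

#[derive(Debug, PartialEq, Eq)]
enum Natural {
    Nil,
    Cons(Bit, Box<Natural>),
}
use Natural::*;

fn cons(bit: Bit, tail: Natural) -> Natural {
    Cons(bit, Box::new(tail))
}

/// Adds three bits, returning `(sum, carry)`, by counting the One bits.
fn full_adder(a: Bit, b: Bit, c: Bit) -> (Bit, Bit) {
    let ones = [a, b, c].iter().filter(|&&bit| bit == One).count();
    let bit = |on: bool| if on { One } else { Zero };
    (bit(ones % 2 == 1), bit(ones >= 2))
}

/// Computes a + b + c, with bits stored least significant first.
fn ripple_adder(a: &Natural, b: &Natural, c: Bit) -> Natural {
    match (a, b) {
        (Nil, Nil) => match c {
            Zero => Nil,
            One => cons(One, Nil),
        },
        // Swap so that the non-empty number comes first.
        (Nil, Cons(..)) => ripple_adder(b, a, c),
        (Cons(a0, a_rest), Nil) => {
            let (sum, carry) = full_adder(*a0, Zero, c);
            cons(sum, ripple_adder(a_rest, &Nil, carry))
        }
        (Cons(a0, a_rest), Cons(b0, b_rest)) => {
            let (sum, carry) = full_adder(*a0, *b0, c);
            cons(sum, ripple_adder(a_rest, b_rest, carry))
        }
    }
}

/// Computes a * b by recursing on the bits of b.
fn multiply(a: &Natural, b: &Natural) -> Natural {
    match b {
        Nil => Nil,
        Cons(b0, b_rest) => {
            // Doubling is a left shift: cons a Zero onto the front (keeping 0 as Nil).
            let doubled = match multiply(a, b_rest) {
                Nil => Nil,
                product => cons(Zero, product),
            };
            match b0 {
                Zero => doubled,
                One => ripple_adder(&doubled, a, Zero),
            }
        }
    }
}

// Hypothetical helpers for checking results against ordinary integers.
fn from_u64(n: u64) -> Natural {
    match n {
        0 => Nil,
        _ => cons(if n % 2 == 1 { One } else { Zero }, from_u64(n / 2)),
    }
}

fn to_u64(n: &Natural) -> u64 {
    match n {
        Nil => 0,
        Cons(bit, rest) => 2 * to_u64(rest) + if *bit == One { 1 } else { 0 },
    }
}

fn main() {
    println!("{}", to_u64(&multiply(&from_u64(5), &from_u64(7)))); // 35
}
```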

I hoped this would translate directly to the type level as well:

But it doesn't work (try it).
The original post also ran into a similar error.
I don't exactly understand why, but it seems like Rust's trait solver likes it when recursive `where` clauses involve obviously "smaller" structures than `Self`, so the recursion terminates more easily.
It's the `((Zero, Product<A, (B1, B2)>), Nil, Zero): RippleAdder` bound that's giving us trouble, since `(Zero, Product<A, (B1, B2)>)` can be bigger than the inputs.

So I had to think up a different multiplication algorithm with a simpler recursive structure that would pass the type checker.
The trick I used is to do all the shifting (`B -> (Zero, B)`) ahead of time so I don't have to put it in a `where` clause.
It's kind of a map-reduce style algorithm.
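To see the data flow without the type-level noise, here's a rough runtime analogue of the four stages on plain integers. It's an illustration under simplifying assumptions (A as a slice of bits, least significant first, and B as a `u64`), not a translation of the code that follows; `product` is a hypothetical name.

```rust
/// Multiplies A (given as bits, least significant first) by B in four stages.
fn product(a_bits: &[bool], b: u64) -> u64 {
    // MulRepeat: one copy of B per bit of A.
    let repeated: Vec<u64> = a_bits.iter().map(|_| b).collect();
    // MulShift: shift the i-th copy left by i bits.
    let shifted: Vec<u64> = repeated.iter().enumerate().map(|(i, &x)| x << i).collect();
    // MulMask: replace copies with 0 where A has a zero bit.
    let masked: Vec<u64> = shifted
        .iter()
        .zip(a_bits)
        .map(|(&x, &bit)| if bit { x } else { 0 })
        .collect();
    // MulReduce: add up whatever is left.
    masked.iter().sum()
}

fn main() {
    // A = 101 (5), written least significant bit first; B = 111 (7).
    println!("{}", product(&[true, false, true], 0b111)); // 35
}
```

The type-level version performs the same stages, with `Nil` playing the role of `0` in the masking step.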

```
/// Given (A, B), produces (B, (B, (B, (..., Nil)))) with the same length as A.
trait MulRepeat {
    type Result;
}
type MulRepeated<A, B> = <(A, B) as MulRepeat>::Result;
/// Base case.
impl<B> MulRepeat for (Nil, B) {
    type Result = Nil;
}
/// Recursive case.
impl<A1, A2, B> MulRepeat for ((A1, A2), B)
where
    (A2, B): MulRepeat,
{
    type Result = (B, MulRepeated<A2, B>);
}

/// Bit-shifts every number in a list:
///
///     (B, (B, (B, (..., Nil))))
///  -> ((Zero, B), ((Zero, B), ((Zero, B), (..., Nil))))
///
/// Used as a subroutine for MulShift.
trait ShiftAll {
    type Result;
}
type AllShifted<A> = <A as ShiftAll>::Result;
/// Base case.
impl ShiftAll for Nil {
    type Result = Nil;
}
/// Recursive case.
impl<A, B> ShiftAll for (A, B)
where
    B: ShiftAll,
{
    type Result = ((Zero, A), AllShifted<B>);
}

/// For each i, shifts the ith number i times.
///
///     (B, (B, (B, (..., Nil))))
///  -> (B, ((Zero, B), ((Zero, (Zero, B)), (..., Nil))))
trait MulShift {
    type Result;
}
type MulShifted<A> = <A as MulShift>::Result;
/// Base case.
impl MulShift for Nil {
    type Result = Nil;
}
/// Recursive case.
impl<A, B> MulShift for (A, B)
where
    B: MulShift,
    MulShifted<B>: ShiftAll,
{
    type Result = (A, AllShifted<MulShifted<B>>);
}

/// Given a number A and a list of numbers B with the same length, replaces each
/// B with Nil when A has a zero bit in that position.
///
///     A: (One, (Zero, (One, (..., Nil))))
///     B: (B1, (B2, (B3, (..., Nil))))
///  -> (B1, (Nil, (B3, (..., Nil))))
trait MulMask {
    type Result;
}
type MulMasked<A, B> = <(A, B) as MulMask>::Result;
/// Base case.
impl MulMask for (Nil, Nil) {
    type Result = Nil;
}
/// Recursive case.
impl<A, B1, B2> MulMask for ((Zero, A), (B1, B2))
where
    (A, B2): MulMask,
{
    type Result = (Nil, MulMasked<A, B2>);
}
/// Recursive case.
impl<A, B1, B2> MulMask for ((One, A), (B1, B2))
where
    (A, B2): MulMask,
{
    type Result = (B1, MulMasked<A, B2>);
}

/// Sums up all the numbers in a list.
trait MulReduce {
    type Result;
}
type MulReduced<A> = <A as MulReduce>::Result;
/// Base case.
impl MulReduce for Nil {
    type Result = Nil;
}
/// Recursive case.
impl<A, B> MulReduce for (A, B)
where
    B: MulReduce,
    (A, MulReduced<B>, Zero): RippleAdder,
{
    type Result = RippleSum<A, MulReduced<B>, Zero>;
}

/// Calculates A times B.
///
/// For example:
///
///     A: 101, B: 111
///     MulRepeated<A, B>: [111, 111, 111]
///     MulShifted<...>:   [111, 1110, 11100]
///     MulMasked<A, ...>: [111, 0, 11100]
///     MulReduced<...>:   111 + 0 + 11100
///                     == 100011
type Product<A, B> = MulReduced<MulMasked<A, MulShifted<MulRepeated<A, B>>>>;
```

This time, it works!

We can compute some pretty big numbers:
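As a sketch of what that can look like, the block below restates the whole pipeline in compressed form and multiplies 255 by 257. Assumptions worth flagging: the full adder is written as a truth table to save space, `ToU64` is a hypothetical helper (not part of this post) that turns a type-level number into an associated constant so the result can be printed, and the raised `recursion_limit` is a precaution that may not be needed.

```rust
#![recursion_limit = "256"] // precaution only; the default may suffice

#[derive(Debug, Default)]
struct Zero;
#[derive(Debug, Default)]
struct One;
#[derive(Debug, Default)]
struct Nil;

// Full adder as a truth table: (A, B, C) -> (Sum, Carry).
trait FullAdder { type Sum; type Carry; }
type FullSum<A, B, C> = <(A, B, C) as FullAdder>::Sum;
type FullCarry<A, B, C> = <(A, B, C) as FullAdder>::Carry;
impl FullAdder for (Zero, Zero, Zero) { type Sum = Zero; type Carry = Zero; }
impl FullAdder for (Zero, Zero, One) { type Sum = One; type Carry = Zero; }
impl FullAdder for (Zero, One, Zero) { type Sum = One; type Carry = Zero; }
impl FullAdder for (Zero, One, One) { type Sum = Zero; type Carry = One; }
impl FullAdder for (One, Zero, Zero) { type Sum = One; type Carry = Zero; }
impl FullAdder for (One, Zero, One) { type Sum = Zero; type Carry = One; }
impl FullAdder for (One, One, Zero) { type Sum = Zero; type Carry = One; }
impl FullAdder for (One, One, One) { type Sum = One; type Carry = One; }

// Ripple-carry adder: A + B + C, bits least significant first.
trait RippleAdder { type Sum; }
type RippleSum<A, B, C> = <(A, B, C) as RippleAdder>::Sum;
impl RippleAdder for (Nil, Nil, Zero) { type Sum = Nil; }
impl RippleAdder for (Nil, Nil, One) { type Sum = (One, Nil); }
impl<B0, B1, C> RippleAdder for (Nil, (B0, B1), C) where ((B0, B1), Nil, C): RippleAdder {
    type Sum = RippleSum<(B0, B1), Nil, C>;
}
impl<A0, A1, C> RippleAdder for ((A0, A1), Nil, C)
where (A0, Zero, C): FullAdder, (A1, Nil, FullCarry<A0, Zero, C>): RippleAdder {
    type Sum = (FullSum<A0, Zero, C>, RippleSum<A1, Nil, FullCarry<A0, Zero, C>>);
}
impl<A0, A1, B0, B1, C> RippleAdder for ((A0, A1), (B0, B1), C)
where (A0, B0, C): FullAdder, (A1, B1, FullCarry<A0, B0, C>): RippleAdder {
    type Sum = (FullSum<A0, B0, C>, RippleSum<A1, B1, FullCarry<A0, B0, C>>);
}

// The map-reduce multiplication, as above but compressed.
trait MulRepeat { type Result; }
type MulRepeated<A, B> = <(A, B) as MulRepeat>::Result;
impl<B> MulRepeat for (Nil, B) { type Result = Nil; }
impl<A1, A2, B> MulRepeat for ((A1, A2), B) where (A2, B): MulRepeat {
    type Result = (B, MulRepeated<A2, B>);
}
trait ShiftAll { type Result; }
type AllShifted<A> = <A as ShiftAll>::Result;
impl ShiftAll for Nil { type Result = Nil; }
impl<A, B: ShiftAll> ShiftAll for (A, B) { type Result = ((Zero, A), AllShifted<B>); }
trait MulShift { type Result; }
type MulShifted<A> = <A as MulShift>::Result;
impl MulShift for Nil { type Result = Nil; }
impl<A, B: MulShift> MulShift for (A, B) where MulShifted<B>: ShiftAll {
    type Result = (A, AllShifted<MulShifted<B>>);
}
trait MulMask { type Result; }
type MulMasked<A, B> = <(A, B) as MulMask>::Result;
impl MulMask for (Nil, Nil) { type Result = Nil; }
impl<A, B1, B2> MulMask for ((Zero, A), (B1, B2)) where (A, B2): MulMask {
    type Result = (Nil, MulMasked<A, B2>);
}
impl<A, B1, B2> MulMask for ((One, A), (B1, B2)) where (A, B2): MulMask {
    type Result = (B1, MulMasked<A, B2>);
}
trait MulReduce { type Result; }
type MulReduced<A> = <A as MulReduce>::Result;
impl MulReduce for Nil { type Result = Nil; }
impl<A, B: MulReduce> MulReduce for (A, B) where (A, MulReduced<B>, Zero): RippleAdder {
    type Result = RippleSum<A, MulReduced<B>, Zero>;
}
type Product<A, B> = MulReduced<MulMasked<A, MulShifted<MulRepeated<A, B>>>>;

// Hypothetical helper: evaluates a type-level number as a compile-time constant.
trait ToU64 { const VALUE: u64; }
impl ToU64 for Nil { const VALUE: u64 = 0; }
impl<T: ToU64> ToU64 for (Zero, T) { const VALUE: u64 = 2 * T::VALUE; }
impl<T: ToU64> ToU64 for (One, T) { const VALUE: u64 = 2 * T::VALUE + 1; }

// 255 and 257, least significant bit first.
type N255 = (One, (One, (One, (One, (One, (One, (One, (One, Nil))))))));
type N257 = (One, (Zero, (Zero, (Zero, (Zero, (Zero, (Zero, (Zero, (One, Nil)))))))));

fn main() {
    println!("{}", <Product<N255, N257> as ToU64>::VALUE); // 65535
}
```

Since 255 * 257 = 65535, the product type should come out as sixteen `One` bits, all worked out by the trait solver before the program ever runs.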

Who needs const generics anyway?