# ångstromCTF 2023 - snap circuits

## Overview

## Summary

This challenge used garbled circuits. Their cryptographic application is irrelevant for understanding the challenge and its solution. Essentially, we get the bits of the flag xored with a random key we do not know. We can also define logical gates that take the flag bits as inputs, and we get garbled (encrypted) truth tables for these gates. The encryption of these tables xors with the same keystream whenever the inputs are the same. So we can xor the outputs of two gates with the same inputs to remove one layer of encryption, after which we can recover the key the flag is xored with.

Writeup by: shalaamum

Solved by: shalaamum

This writeup was first published here.

## Circuits

The challenge involves a garbled circuit that already has wires corresponding to the flag bits as inputs. On top of these we can define `and` and `xor` gates, each gate being defined by its type (`and` / `xor`) and its two input wires. The circuit is then garbled, and we get some information from this garbled circuit. So let us begin by having a look at what wires and gates are and how they are garbled.

### Wires

The wires that correspond to the input bits of the circuit are represented by a tuple with the single component `id`:

```
circuit = [('id',) for _ in flag_bits]
```

In `garble`, these are replaced by a `Wire` obtained with `Wire.new()`:

```
case ('id',):
wire = Wire.new()
wires.append(wire)
in_wires.append(wire)
```

So let us look at the `Wire` class.

```
class Wire:
Label = namedtuple('Label', 'key ptr')
def __init__(self, zero, one):
self.zero = zero
self.one = one
@classmethod
def new(cls):
bit = rand_bit()
zero = Wire.Label(rand_buf(), bit)
one = Wire.Label(rand_buf(), bit ^ 1)
return cls(zero, one)
def get_label(self, value):
match value:
case 0:
return self.zero
case 1:
return self.one
case _:
raise ValueError('cannot translate value to label')
```

So a `Wire` has two members, `zero` and `one`. Each is a `Label` consisting of a random 16-byte `key` and an encrypted bit `ptr`. There is a randomly chosen bit `k` such that `zero` gets the encrypted bit `0 ^ k` and `one` gets the encrypted bit `1 ^ k`.
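Here is a minimal standalone sketch that makes the `ptr` relation concrete. It does not import the challenge code. `new_wire` is a hypothetical helper. `secrets.randbits(1)` and `secrets.token_bytes(16)` are assumed stand-ins for the challenge's `rand_bit` and `rand_buf`, whose definitions are not shown.

```python
import secrets
from collections import namedtuple

BUF_LEN = 16  # assumed length of rand_buf() output
Label = namedtuple('Label', 'key ptr')


def new_wire():
    """Hypothetical mirror of Wire.new(): returns (zero, one, k)."""
    k = secrets.randbits(1)  # hidden per-wire bit
    zero = Label(secrets.token_bytes(BUF_LEN), 0 ^ k)
    one = Label(secrets.token_bytes(BUF_LEN), 1 ^ k)
    return zero, one, k


for _ in range(100):
    zero, one, k = new_wire()
    # The ptr of the label for value v is v ^ k, so the two ptrs always differ
    # and a single ptr says nothing about v without k.
    assert zero.ptr == 0 ^ k and one.ptr == 1 ^ k
    assert zero.ptr != one.ptr
```

The published `ptr` of a label is `v ^ k`. This is why the wire outputs only reveal the flag bits once `k` is known.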

### Gates

Let `op` denote `and` or `xor`. Let us assume we already have two `Wire`s `W_a` and `W_b` with associated bits `k_a` and `k_b`, as well as keys `W_{a,0}`, `W_{a,1}`, `W_{b,0}`, and `W_{b,1}`. For example, `W_{a,1}` is the `key` component of the tuple given by `W_a.one`. Then the gate with operation `op` and the two input wires `W_a` and `W_b` is garbled as follows in `garble`:

```
case (op, a, b):
wire, table = garble_gate(op, wires[a], wires[b])
wires.append(wire)
tables.append(table)
```

Here is the `garble_gate` function:

```
def garble_gate(op, wa, wb):
wc = Wire.new()
table = [[None, None], [None, None]]
for va, vb, vc in get_truth_table(op):
la = wa.get_label(va)
lb = wb.get_label(vb)
lc = wc.get_label(vc)
table[la.ptr][lb.ptr] = Cipher(la, lb).encrypt(lc)
return wc, table
```

`get_truth_table` just returns the truth table for `op`, i.e. the triples with `va op vb = vc` for the 4 possibilities for `va, vb`. Write `W_c` for the new `Wire` (called `wc` in the code) and use the notation above. We then obtain the following `table`:

```
table[v_a ^ k_a][v_b ^ k_b] = Cipher((W_{a,v_a}, v_a ^ k_a), (W_{b,v_b}, v_b ^ k_b)).encrypt((W_{c,v_a op v_b}, (v_a op v_b) ^ k_c))
```

To understand the right-hand side, we need to look at `Cipher`:

```
class Cipher:
def __init__(self, *labels):
key = b''.join([l.key for l in labels])
self.shake = SHAKE128.new(key)
def xor_buf(self, buf):
return strxor(buf, self.shake.read(BUF_LEN))
def xor_bit(self, bit):
return bit ^ self.shake.read(1)[0] & 1
def encrypt(self, label):
return self.xor_buf(label.key), self.xor_bit(label.ptr)
def decrypt(self, row):
return Wire.Label(self.xor_buf(row[0]), self.xor_bit(row[1]))
```

So for `Cipher((W_{a,v_a}, v_a ^ k_a), (W_{b,v_b}, v_b ^ k_b))`, the key of the `Cipher` is the concatenation of `W_{a,v_a}` and `W_{b,v_b}`. SHAKE128 then generates a keystream. Both the `key` and the bit of the label we are encrypting are xored with parts of that keystream: first `BUF_LEN` bytes for the `key`, then the lowest bit of the next byte for the bit. We do not need the `key`, so let us ignore it. We just note that there is a function `f` that takes `W_{a,v_a}`, `W_{b,v_b}` as input and produces a bit `f(W_{a,v_a}, W_{b,v_b})`. This bit is xored into the bit we are encrypting, which in our case is `(v_a op v_b) ^ k_c`.
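The mask bit `f` can be written down directly. The sketch below is a reconstruction under several assumptions:

- It uses Python's `hashlib.shake_128` in place of PyCryptodome's `SHAKE128`.
- It assumes `BUF_LEN = 16`.
- It relies on `Cipher` reading the `BUF_LEN` key bytes before the single byte whose lowest bit masks `ptr`.

The helper names `f` and `encrypt_ptr` are hypothetical.

```python
import hashlib
import secrets

BUF_LEN = 16  # assumed label key length


def f(key_a: bytes, key_b: bytes) -> int:
    """Mask bit that Cipher(la, lb) xors into the encrypted ptr.

    Cipher reads BUF_LEN bytes of the SHAKE128 stream for the label key
    first, then one more byte whose lowest bit masks the ptr.
    """
    stream = hashlib.shake_128(key_a + key_b).digest(BUF_LEN + 1)
    return stream[BUF_LEN] & 1


def encrypt_ptr(key_a: bytes, key_b: bytes, ptr: int) -> int:
    # Only the bit part of Cipher.encrypt; the 16-byte key part is ignored.
    return ptr ^ f(key_a, key_b)


ka, kb = secrets.token_bytes(BUF_LEN), secrets.token_bytes(BUF_LEN)
# Same input labels give the same mask, so xoring two encrypted bits
# made with these labels cancels f and leaves the xor of the plaintext bits.
assert encrypt_ptr(ka, kb, 0) ^ encrypt_ptr(ka, kb, 1) == 1
```

Because `f` depends only on the two input label keys, any two gates that share input wires use the same mask for the same table row.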

The upshot is that the table of the garbled gate has the following form. Here `?` stands for the 16-byte `key` part of the label, whose value we do not care about:

```
table[v_a ^ k_a][v_b ^ k_b] = (?, f(W_{a,v_a}, W_{b,v_b}) ^ (v_a op v_b) ^ k_c)
```

## The challenge

### Setup

`flag_bits` contains the individual bits of the flag. Let us denote them by `b_0` to `b_{n-1}`, with `n` the total number of bits.

First, the challenge adds a wire to the circuit for each bit:

```
circuit = [('id',) for _ in flag_bits]
```

Let us denote the garbled versions of these wires by `W_0` to `W_{n-1}`, using the notation for the components of wires introduced above.

Then we get to define gates, as long as the circuit (flag wires included) stays within 1000 entries:

```
while True:
gate = input('gate: ')
tokens = gate.split(' ')
try:
assert len(tokens) == 3
assert tokens[0] in ('and', 'xor')
assert tokens[1].isdecimal()
assert tokens[2].isdecimal()
op = tokens[0]
a = int(tokens[1])
b = int(tokens[2])
assert 0 <= a < len(circuit)
assert 0 <= b < len(circuit)
assert len(circuit) < 1000
circuit.append((op, a, b))
except:
print('moving on...')
break
```
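Here is one possible way to drive this prompt. It assumes we pair up flag wires and give each pair one `and` gate and one `xor` gate. This pairing strategy and the helper name `build_gate_lines` are our own illustration.

```python
def build_gate_lines(n: int) -> list[str]:
    """Hypothetical helper: lines to send at the 'gate: ' prompt.

    Each pair of flag wires (a, b) gets one `and` and one `xor` gate with
    identical inputs; if n is odd, the last wire is paired with wire 0.
    The final empty line fails `len(tokens) == 3` and ends the loop.
    """
    lines = []
    for a in range(0, n, 2):
        b = a + 1 if a + 1 < n else 0
        lines.append(f'and {a} {b}')
        lines.append(f'xor {a} {b}')
    lines.append('')  # '' -> tokens == [''] -> 'moving on...'
    return lines
```

The length check runs before each append, so the circuit holds at most 1000 entries in total. This scheme adds one gate per flag bit, rounded up to an even count. It therefore fits as long as the flag has at most 500 bits.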

The circuit is then garbled.

```
in_wires, _, tables = garble(circuit)
```

### What we get

First, we get output corresponding to the flag bit wires:

```
for i, b in enumerate(flag_bits):
label = in_wires[i].get_label(b)
print(f'wire {i}: {label.key.hex()} {label.ptr}')
```

So the information we get is `W_{i,b_i}` and `b_i ^ k_i`, for every `0 <= i < n`.

We then also get some output from the gates we provided:

```
print('table data:')
for table in tables:
for i in range(2):
for j in range(2):
row = table[i][j]
print(f'{row[0].hex()} {row[1]}')
```
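Here is a sketch of parsing this output. It assumes the captured text puts each `wire i: ...` line and each table row on its own line, with nothing else after the table data. `parse_output` is a hypothetical helper. Only the `ptr` bits are kept, since the attack never uses the 16-byte keys.

```python
def parse_output(text: str):
    """Hypothetical parser for the server output shown above.

    Returns (ptrs, tables): ptrs[i] is b_i ^ k_i from 'wire i: <hex> <ptr>',
    and tables[t] is the 2x2 list of encrypted ptr bits of table t, read in
    the printed order table[0][0], [0][1], [1][0], [1][1].
    """
    ptrs, bits = {}, []
    in_tables = False
    for line in text.splitlines():
        line = line.strip()
        if line == 'table data:':
            in_tables = True
        elif in_tables and line:
            _key_hex, bit = line.split()  # the 16-byte key part is unused
            bits.append(int(bit))
        elif line.startswith('wire '):
            head, rest = line.split(':', 1)
            _key_hex, ptr = rest.split()
            ptrs[int(head.split()[1])] = int(ptr)
    tables = [[bits[t:t + 2], bits[t + 2:t + 4]] for t in range(0, len(bits), 4)]
    return [ptrs[i] for i in range(len(ptrs))], tables
```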

If we defined a gate with operation `op` that had input wires `W_a` and `W_b`, we get as output the values of `table[v_a][v_b]` for each of the four possible values of `(v_a, v_b)`. Above, we found that `table` has the form

```
table[v_a ^ k_a][v_b ^ k_b] = (?, f(W_{a,v_a}, W_{b,v_b}) ^ (v_a op v_b) ^ k_c)
```

so that we obtain (by replacing `v_a` with `v_a ^ k_a` and `v_b` with `v_b ^ k_b`)

```
table[v_a][v_b] = (?, f(W_{a,v_a ^ k_a}, W_{b,v_b ^ k_b}) ^ ((v_a ^ k_a) op (v_b ^ k_b)) ^ k_c)
```

## Solution

For each gate we define, we get four bits

```
f(W_{a,v_a ^ k_a}, W_{b,v_b ^ k_b}) ^ ((v_a ^ k_a) op (v_b ^ k_b)) ^ k_c
```

Unfortunately, the arguments to `f` depend on `v_a` and `v_b`. The value of `f` is therefore essentially a random bit that only occurs once for that gate, so we gain no information on `((v_a ^ k_a) op (v_b ^ k_b)) ^ k_c`.
However, suppose we have two gates with operations `op_1` and `op_2`, both with the same input wires. The first gate has output wire `W_c` and the second has `W_d`. Then we get both

```
f(W_{a,v_a ^ k_a}, W_{b,v_b ^ k_b}) ^ ((v_a ^ k_a) op_1 (v_b ^ k_b)) ^ k_c
```

and

```
f(W_{a,v_a ^ k_a}, W_{b,v_b ^ k_b}) ^ ((v_a ^ k_a) op_2 (v_b ^ k_b)) ^ k_d
```

with the same value produced by `f`. We can thus cancel this contribution by xoring the two values. For each of the four possibilities for `(v_a, v_b)`, we obtain

```
((v_a ^ k_a) op_1 (v_b ^ k_b)) ^ k_c ^ ((v_a ^ k_a) op_2 (v_b ^ k_b)) ^ k_d
= ((v_a ^ k_a) op_1 (v_b ^ k_b)) ^ ((v_a ^ k_a) op_2 (v_b ^ k_b)) ^ (k_c ^ k_d)
```
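The cancellation can be checked with a small local simulation. This is not the challenge code:

- It models only the `ptr` bits.
- It uses `hashlib.shake_128` as a stand-in for PyCryptodome's `SHAKE128`, with an assumed `BUF_LEN = 16`.
- `new_wire` and `garbled_bits` are hypothetical helpers.

It garbles an `and` gate and a `xor` gate on the same two wires. It then compares the xor of their published bits with the formula above.

```python
import hashlib
import secrets

BUF_LEN = 16  # assumed label key length
OPS = {'and': lambda x, y: x & y, 'xor': lambda x, y: x ^ y}


def f(key_a, key_b):
    # Lowest bit of the SHAKE128 byte that follows the BUF_LEN key bytes.
    return hashlib.shake_128(key_a + key_b).digest(BUF_LEN + 1)[BUF_LEN] & 1


def new_wire():
    # ([key for value 0, key for value 1], hidden bit k)
    return [secrets.token_bytes(BUF_LEN) for _ in range(2)], secrets.randbits(1)


def garbled_bits(op, wa, wb, k_c):
    """Encrypted ptr bits of a garbled gate, indexed as in the printed table."""
    (keys_a, k_a), (keys_b, k_b) = wa, wb
    table = [[None, None], [None, None]]
    for v_a in range(2):
        for v_b in range(2):
            v_c = OPS[op](v_a, v_b)
            table[v_a ^ k_a][v_b ^ k_b] = f(keys_a[v_a], keys_b[v_b]) ^ v_c ^ k_c
    return table


wa, wb = new_wire(), new_wire()
k_a, k_b = wa[1], wb[1]
k_c, k_d = secrets.randbits(1), secrets.randbits(1)
t_and = garbled_bits('and', wa, wb, k_c)
t_xor = garbled_bits('xor', wa, wb, k_d)
for i in range(2):
    for j in range(2):
        x, y = i ^ k_a, j ^ k_b  # v'_a, v'_b: the wire values behind row (i, j)
        # The f terms cancel; only (x & y) ^ (x ^ y) ^ (k_c ^ k_d) is left.
        assert t_and[i][j] ^ t_xor[i][j] == (x & y) ^ (x ^ y) ^ k_c ^ k_d
```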

The operations `op_1` and `op_2` can each be chosen from `and` and `xor`. If we use the same one for both, then the first two terms cancel. That leaves us with just `k_c ^ k_d`, which tells us nothing about `k_a` or `k_b`. If instead one of the operations, say `op_1`, is `xor` (`^`) and the other is `and` (`&`), then we obtain

```
= ((v_a ^ k_a) ^ (v_b ^ k_b)) ^ ((v_a ^ k_a) & (v_b ^ k_b)) ^ (k_c ^ k_d)
```

To better evaluate this, let us rewrite it in terms of `v'_a = v_a ^ k_a`, `v'_b = v_b ^ k_b`, and `r = k_c ^ k_d`. Then we get

```
= v'_a ^ v'_b ^ (v'_a & v'_b) ^ r
```

What kind of values can this take? We get:

```
v'_a = 0, v'_b = 0 --> r
v'_a = 1, v'_b = 0 --> 1 ^ r
v'_a = 0, v'_b = 1 --> 1 ^ r
v'_a = 1, v'_b = 1 --> 1 ^ r
```
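Here is a quick exhaustive check of this table in Python. The helper `combined` is just for illustration:

```python
def combined(va: int, vb: int, r: int) -> int:
    """v'_a ^ v'_b ^ (v'_a & v'_b) ^ r, the xor of the two gates' bits."""
    return va ^ vb ^ (va & vb) ^ r


for r in (0, 1):
    rows = {(va, vb): combined(va, vb, r) for va in (0, 1) for vb in (0, 1)}
    # The expression equals (v'_a | v'_b) ^ r ...
    assert all(v == (va | vb) ^ r for (va, vb), v in rows.items())
    # ... so only the (0, 0) row takes the value r.
    assert [pos for pos, v in rows.items() if v == r] == [(0, 0)]
```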

So one value occurs three times, and one value is unique (the expression is just `(v'_a | v'_b) ^ r`). The unique value occurs when `v'_a = 0` and `v'_b = 0`, corresponding to `v_a = k_a` and `v_b = k_b`. Thus, by setting up an `and` and a `xor` gate with the same two input wires, we can recover `k_a` and `k_b` as the table indices of the odd entry out.
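This step can be sketched as a small function. It assumes the `ptr` bits of the two tables have already been parsed into 2x2 lists indexed `[i][j]` in the printed order. `recover_keys` is a hypothetical name.

```python
def recover_keys(and_table, xor_table):
    """Return (k_a, k_b) from the ptr bits of an `and` gate and a `xor` gate
    that share the input wires a and b (tables indexed [i][j] as printed).

    Xoring the tables cancels f and leaves (v'_a | v'_b) ^ r; the one entry
    that differs from the other three sits at index (k_a, k_b).
    """
    diff = {(i, j): and_table[i][j] ^ xor_table[i][j]
            for i in range(2) for j in range(2)}
    for pos, bit in diff.items():
        if list(diff.values()).count(bit) == 1:
            return pos
    raise ValueError('no unique entry: not an and/xor pair on the same wires')
```

If no entry is unique, the two tables cannot have come from an `and`/`xor` pair on the same wires. The function raises an error in that case rather than guessing.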

Once we have used this to recover `k_i` for all `0 <= i < n`, we can xor it with the value `b_i ^ k_i` we got from the wires to obtain the bits `b_i` of the flag.
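Here is a sketch of this last step. It assumes `flag_bits` lists each byte most significant bit first, and that the flag length is a multiple of 8 bits. Neither is confirmed by the code shown above, so the packing may need to be reversed. `decode_flag` is a hypothetical helper.

```python
def decode_flag(ptrs, keys):
    """Undo b_i ^ k_i and pack the bits into bytes.

    Assumes MSB-first bit order within each byte (not confirmed by the
    challenge code shown); reverse each 8-bit chunk if the output is garbled.
    """
    bits = [p ^ k for p, k in zip(ptrs, keys)]  # b_i = (b_i ^ k_i) ^ k_i
    out = bytearray()
    for start in range(0, len(bits), 8):
        byte = 0
        for bit in bits[start:start + 8]:
            byte = (byte << 1) | bit
        out.append(byte)
    return bytes(out)
```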

## Attachments

Challenge and solution script.