ångstromCTF 2023 - snap circuits
Overview
Writeup by: shalaamum
Solved by: shalaamum
Summary
This challenge used garbled circuits. They have a cryptographic application, which is, however, irrelevant for understanding this challenge and its solution. Essentially, we get the bits of the flag xored with a random key we do not know. We can also define some logic gates that take the flag bits as inputs, and we get garbled/encrypted truth tables for these gates. The encryption of these tables xors with the same keystream whenever the gate inputs are the same. We can therefore xor the outputs of two gates together to remove one layer of encryption, after which we can recover the key the flag was xored with.
Circuits
The challenge involves a garbled circuit that already has wires corresponding to the flag bits as inputs. On top of these we can define `and` and `xor` gates, with a gate being defined by its type (`and` / `xor`) and its two input wires. The circuit is then garbled, and we get some information from this garbled circuit. So let us begin by having a look at what wires and gates are and how they are garbled.
Wires
The wires that correspond to the input bits of the circuit are represented by a tuple with the single component `'id'`:

```python
circuit = [('id',) for _ in flag_bits]
```
In `garble`, each of these is replaced by a `Wire` obtained with `Wire.new()`:

```python
case ('id',):
    wire = Wire.new()
    wires.append(wire)
    in_wires.append(wire)
```
So let us look at the `Wire` class.

```python
class Wire:
    Label = namedtuple('Label', 'key ptr')

    def __init__(self, zero, one):
        self.zero = zero
        self.one = one

    @classmethod
    def new(cls):
        bit = rand_bit()
        zero = Wire.Label(rand_buf(), bit)
        one = Wire.Label(rand_buf(), bit ^ 1)
        return cls(zero, one)

    def get_label(self, value):
        match value:
            case 0:
                return self.zero
            case 1:
                return self.one
            case _:
                raise ValueError('cannot translate value to label')
```
So a `Wire` has two members, `zero` and `one`, each of which consists of a random 16-byte `key` and an encrypted bit `ptr`. There is a randomly chosen bit `k` such that `zero` is associated with the encrypted bit `0 ^ k` and `one` with the encrypted bit `1 ^ k`.
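To make the relation between values and pointer bits concrete, here is a minimal standalone re-creation of `Wire`. It is a simplified sketch, not the challenge code: `Label` is moved out of the class, and the helpers `rand_bit` and `rand_buf` (not shown in the excerpt) are assumed to return one uniform random bit and `BUF_LEN = 16` random bytes, implemented here with Python's `secrets` module.

```python
import secrets
from collections import namedtuple

BUF_LEN = 16  # assumption: length of a label key in bytes

Label = namedtuple('Label', 'key ptr')


def rand_bit():
    # Assumed behaviour of the challenge helper: one uniformly random bit.
    return secrets.randbits(1)


def rand_buf():
    # Assumed behaviour of the challenge helper: BUF_LEN random bytes.
    return secrets.token_bytes(BUF_LEN)


class Wire:
    def __init__(self, zero, one):
        self.zero = zero
        self.one = one

    @classmethod
    def new(cls):
        k = rand_bit()  # hidden mask bit of this wire
        return cls(Label(rand_buf(), k), Label(rand_buf(), k ^ 1))

    def get_label(self, value):
        if value not in (0, 1):
            raise ValueError('cannot translate value to label')
        return self.one if value else self.zero


w = Wire.new()
k = w.zero.ptr  # the label for value 0 has ptr 0 ^ k = k
for v in (0, 1):
    # Seeing a label's ptr alone reveals v ^ k, not v itself.
    assert w.get_label(v).ptr == v ^ k
```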
Gates
Let `op` denote `and` or `xor`, and let us assume we already have two `Wire`s `W_a` and `W_b` with associated bits `k_a` and `k_b`, as well as `key`s `W_{a,0}`, `W_{a,1}`, `W_{b,0}`, and `W_{b,1}` (with, e.g., `W_{a,1}` being the `key` component of the tuple `W_a.one`). Then the gate with operation `op` and the two input wires `W_a` and `W_b` will be garbled as follows in `garble`:
```python
case (op, a, b):
    wire, table = garble_gate(op, wires[a], wires[b])
    wires.append(wire)
    tables.append(table)
```

Here is the `garble_gate` function:

```python
def garble_gate(op, wa, wb):
    wc = Wire.new()
    table = [[None, None], [None, None]]
    for va, vb, vc in get_truth_table(op):
        la = wa.get_label(va)
        lb = wb.get_label(vb)
        lc = wc.get_label(vc)
        table[la.ptr][lb.ptr] = Cipher(la, lb).encrypt(lc)
    return wc, table
```
`get_truth_table` just returns the truth table for `op`, i.e. the triples `(va, vb, vc)` with `va op vb = vc` for the 4 possibilities for `(va, vb)`. Thus, writing `W_c` for the new `Wire` (`wc` in the code) and using the notation from above, we obtain a `table` with
table[v_a ^ k_a][v_b ^ k_b] = Cipher((W_{a,v_a}, v_a ^ k_a), (W_{b,v_b}, v_b ^ k_b)).encrypt((W_{c,v_a op v_b}, (v_a op v_b) ^ k_c))
To understand the right-hand side, we need to look at `Cipher`:

```python
class Cipher:
    def __init__(self, *labels):
        key = b''.join([l.key for l in labels])
        self.shake = SHAKE128.new(key)

    def xor_buf(self, buf):
        return strxor(buf, self.shake.read(BUF_LEN))

    def xor_bit(self, bit):
        # & binds tighter than ^, so this is bit ^ (lowest bit of the next byte)
        return bit ^ self.shake.read(1)[0] & 1

    def encrypt(self, label):
        return self.xor_buf(label.key), self.xor_bit(label.ptr)

    def decrypt(self, row):
        return Wire.Label(self.xor_buf(row[0]), self.xor_bit(row[1]))
```
So for `Cipher((W_{a,v_a}, v_a ^ k_a), (W_{b,v_b}, v_b ^ k_b))`, the key of the `Cipher` (the seed passed to `SHAKE128`) is the concatenation of `W_{a,v_a}` and `W_{b,v_b}`. `SHAKE128` then generates a keystream, and both the `key` and the bit of the label we are encrypting are xored with parts of this keystream: the first `BUF_LEN` bytes go to the `key`, and the lowest bit of the following byte goes to the bit. We won't need the `key`, so let us ignore it and just note that there is a function `f` that takes `W_{a,v_a}` and `W_{b,v_b}` as input and produces a bit `f(W_{a,v_a}, W_{b,v_b})` that is xored into the bit we are encrypting, which in our case is `(v_a op v_b) ^ k_c`.
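The function `f` can be written down explicitly: it is the lowest bit of byte number `BUF_LEN` of the SHAKE128 output. This sketch uses Python's built-in `hashlib.shake_128` in place of PyCryptodome's `SHAKE128` (the same XOF, different API: `digest(n)` returns the first `n` bytes of the stream, which equals successive `read` calls). `BUF_LEN = 16` is an assumption matching the 16-byte keys.

```python
import hashlib
import secrets

BUF_LEN = 16  # assumption: matches the 16-byte label keys


def f(key_a: bytes, key_b: bytes) -> int:
    """Keystream bit that Cipher(la, lb) xors into the encrypted ptr bit."""
    stream = hashlib.shake_128(key_a + key_b).digest(BUF_LEN + 1)
    # Bytes 0..BUF_LEN-1 encrypt the label key; byte BUF_LEN encrypts the bit.
    return stream[BUF_LEN] & 1


def encrypt_bit(key_a: bytes, key_b: bytes, bit: int) -> int:
    # Mirrors Cipher.encrypt, restricted to the ptr bit.
    return bit ^ f(key_a, key_b)


ka, kb = secrets.token_bytes(BUF_LEN), secrets.token_bytes(BUF_LEN)
# The same pair of input keys always yields the same mask bit.
assert f(ka, kb) == f(ka, kb)
assert encrypt_bit(ka, kb, encrypt_bit(ka, kb, 1)) == 1
```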
The upshot is that the table of the garbled gate has the following form, where `?` stands for the encrypted 16-byte `key` of the output label, whose value we do not care about:
table[v_a ^ k_a][v_b ^ k_b] = (?, f(W_{a,v_a}, W_{b,v_b}) ^ (v_a op v_b) ^ k_c)
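The following sketch simulates only the bit column of a garbled table, which is all the attack uses. It is a simplified model, not the challenge code: a wire is a hypothetical `SimWire` holding its mask bit `k` and its two label keys, `f` is computed from SHAKE128 as described above (via `hashlib`, with an assumed `BUF_LEN = 16`), and the table is filled exactly according to the formula for `table[v_a ^ k_a][v_b ^ k_b]`.

```python
import hashlib
import secrets
from dataclasses import dataclass

BUF_LEN = 16  # assumed label key length
OPS = {'and': lambda x, y: x & y, 'xor': lambda x, y: x ^ y}


@dataclass
class SimWire:
    k: int        # mask bit: the label for value v has ptr v ^ k
    keys: tuple   # (key of the label for 0, key of the label for 1)

    @classmethod
    def new(cls):
        keys = (secrets.token_bytes(BUF_LEN), secrets.token_bytes(BUF_LEN))
        return cls(secrets.randbits(1), keys)


def f(key_a, key_b):
    # Lowest bit of byte BUF_LEN of SHAKE128(key_a || key_b).
    return hashlib.shake_128(key_a + key_b).digest(BUF_LEN + 1)[BUF_LEN] & 1


def garble_bits(op, wa, wb, wc):
    """Bit column of garble_gate(op, wa, wb) with output wire wc."""
    table = [[None, None], [None, None]]
    for va in (0, 1):
        for vb in (0, 1):
            vc = OPS[op](va, vb)
            # Rows are indexed by the input labels' ptr bits, v ^ k.
            table[va ^ wa.k][vb ^ wb.k] = f(wa.keys[va], wb.keys[vb]) ^ vc ^ wc.k
    return table


wa, wb, wc = SimWire.new(), SimWire.new(), SimWire.new()
print(garble_bits('and', wa, wb, wc))
```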
The challenge
Setup
`flag_bits` contains the individual bits of the flag. Let us denote them by `b_0` to `b_{n-1}`, with `n` the total number of bits.
First the challenge adds a wire to the circuit for each bit:

```python
circuit = [('id',) for _ in flag_bits]
```
Let us denote the garbled versions of these wires by `W_0` to `W_{n-1}`, with notation for their components as above.
Then we get to define gates, as long as the whole circuit, including the `n` input wires, has at most 1000 entries (so at most `1000 - n` gates):
```python
while True:
    gate = input('gate: ')
    tokens = gate.split(' ')
    try:
        assert len(tokens) == 3
        assert tokens[0] in ('and', 'xor')
        assert tokens[1].isdecimal()
        assert tokens[2].isdecimal()
        op = tokens[0]
        a = int(tokens[1])
        b = int(tokens[2])
        assert 0 <= a < len(circuit)
        assert 0 <= b < len(circuit)
        assert len(circuit) < 1000
        circuit.append((op, a, b))
    except:
        print('moving on...')
        break
```
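For the attack in the Solution section, we need an `and` and an `xor` gate on the same pair of input wires. A hypothetical helper producing the gate lines could look like this. It pairs adjacent wires `(2p, 2p + 1)` (one possible choice), takes `n` as a parameter (assumed known when the gates are sent), assumes `n` is even, and ends with a malformed line to trigger the `moving on...` branch.

```python
def gate_lines(n):
    """Gate definitions: 'and a b' and 'xor a b' for each pair of flag wires."""
    assert n % 2 == 0, 'sketch assumes an even number of flag bits'
    assert 2 * n <= 1000, 'the circuit may hold at most 1000 entries'
    lines = []
    for a in range(0, n, 2):
        # These become gates n + a and n + a + 1, sharing inputs (a, a + 1).
        lines.append(f'and {a} {a + 1}')
        lines.append(f'xor {a} {a + 1}')
    lines.append('done')  # fails the length check, ending the input loop
    return lines


print(gate_lines(4))
# ['and 0 1', 'xor 0 1', 'and 2 3', 'xor 2 3', 'done']
```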
The circuit is then garbled:

```python
in_wires, _, tables = garble(circuit)
```
What we get
We then first get output corresponding to the flag bit wires:
```python
for i, b in enumerate(flag_bits):
    label = in_wires[i].get_label(b)
    print(f'wire {i}: {label.key.hex()} {label.ptr}')
```
So the information we get is `W_{i,b_i}` and `b_i ^ k_i` for every `0 <= i < n`.
We then also get some output from the gates we provided:
```python
print('table data:')
for table in tables:
    for i in range(2):
        for j in range(2):
            row = table[i][j]
            print(f'{row[0].hex()} {row[1]}')
```
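On the solver side, both outputs can be parsed into the pointer bits `b_i ^ k_i` and into one 2×2 bit table per gate. This is a sketch under the assumption that the output has been captured as a list of text lines in exactly the formats printed above; the name `parse_output` is hypothetical, and the label keys are discarded since the attack does not use them.

```python
def parse_output(lines, n):
    """Hypothetical parser for the challenge output.

    Returns (wire_ptrs, tables) with wire_ptrs[i] = b_i ^ k_i and
    tables[g][i][j] = bit printed for row table[i][j] of gate g.
    """
    wire_ptrs = []
    it = iter(lines)
    for line in it:
        line = line.strip()
        if line.startswith('wire '):
            # Format: 'wire {i}: {key hex} {ptr}'; the key is not needed.
            _, _, _key_hex, ptr = line.split()
            wire_ptrs.append(int(ptr))
        elif line == 'table data:':
            break
    assert len(wire_ptrs) == n
    # Remaining lines: '{key hex} {bit}', four per gate, in the order
    # (0, 0), (0, 1), (1, 0), (1, 1).
    bits = [int(line.split()[1]) for line in it if len(line.split()) == 2]
    assert len(bits) % 4 == 0
    tables = [[[bits[g], bits[g + 1]], [bits[g + 2], bits[g + 3]]]
              for g in range(0, len(bits), 4)]
    return wire_ptrs, tables
```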
If we defined a gate with operation `op` and input wires `W_a` and `W_b`, we get as output the values `table[v_a][v_b]` for each of the four possible values of `(v_a, v_b)`.
Above we arrived at `table` being of the form
table[v_a ^ k_a][v_b ^ k_b] = (?, f(W_{a,v_a}, W_{b,v_b}) ^ (v_a op v_b) ^ k_c)
so that we obtain (by replacing `v_a` with `v_a ^ k_a` and `v_b` with `v_b ^ k_b`)
table[v_a][v_b] = (?, f(W_{a,v_a ^ k_a}, W_{b,v_b ^ k_b}) ^ ((v_a ^ k_a) op (v_b ^ k_b)) ^ k_c)
Solution
For each gate we define, we get four bits
f(W_{a,v_a ^ k_a}, W_{b,v_b ^ k_b}) ^ ((v_a ^ k_a) op (v_b ^ k_b)) ^ k_c
but unfortunately, as the arguments to `f` depend on `v_a` and `v_b`, the value of `f` is essentially a random bit that occurs only once for that gate, so we gain no information on `((v_a ^ k_a) op (v_b ^ k_b)) ^ k_c`.
However, if we have two gates with operations `op_1` and `op_2`, both with the same inputs, with the first gate's output wire indexed by `c` and the second's by `d`, then we get both
f(W_{a,v_a ^ k_a}, W_{b,v_b ^ k_b}) ^ ((v_a ^ k_a) op_1 (v_b ^ k_b)) ^ k_c
and
f(W_{a,v_a ^ k_a}, W_{b,v_b ^ k_b}) ^ ((v_a ^ k_a) op_2 (v_b ^ k_b)) ^ k_d
with the same value produced by `f`. We can thus cancel this contribution by xoring the two values, so that for each of the four possibilities for `(v_a, v_b)` we obtain
((v_a ^ k_a) op_1 (v_b ^ k_b)) ^ k_c ^ ((v_a ^ k_a) op_2 (v_b ^ k_b)) ^ k_d
= ((v_a ^ k_a) op_1 (v_b ^ k_b)) ^ ((v_a ^ k_a) op_2 (v_b ^ k_b)) ^ (k_c ^ k_d)
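The cancellation can be checked with a small simulation. As before, this is a simplified model rather than the challenge code: `SimWire` is a hypothetical stand-in holding a mask bit and two label keys, `f` uses `hashlib.shake_128` with an assumed `BUF_LEN = 16`, and `printed_bits` returns the bits in the order the challenge prints them (`[v_a][v_b]` in the notation above).

```python
import hashlib
import secrets
from dataclasses import dataclass

BUF_LEN = 16  # assumed label key length
OPS = {'and': lambda x, y: x & y, 'xor': lambda x, y: x ^ y}


@dataclass
class SimWire:
    k: int        # mask bit: the label for value v has ptr v ^ k
    keys: tuple   # (key of the label for 0, key of the label for 1)

    @classmethod
    def new(cls):
        keys = (secrets.token_bytes(BUF_LEN), secrets.token_bytes(BUF_LEN))
        return cls(secrets.randbits(1), keys)


def f(key_a, key_b):
    return hashlib.shake_128(key_a + key_b).digest(BUF_LEN + 1)[BUF_LEN] & 1


def printed_bits(op, wa, wb, wc):
    """Bits as printed: entry [i][j] encrypts (i ^ k_a) op (j ^ k_b)."""
    return [[f(wa.keys[i ^ wa.k], wb.keys[j ^ wb.k])
             ^ OPS[op](i ^ wa.k, j ^ wb.k) ^ wc.k
             for j in (0, 1)] for i in (0, 1)]


wa, wb = SimWire.new(), SimWire.new()   # shared inputs
wc, wd = SimWire.new(), SimWire.new()   # outputs of the two gates
for op1 in OPS:
    for op2 in OPS:
        t1, t2 = printed_bits(op1, wa, wb, wc), printed_bits(op2, wa, wb, wd)
        for i in (0, 1):
            for j in (0, 1):
                x, y = i ^ wa.k, j ^ wb.k
                # The f terms are identical and cancel in the xor.
                assert t1[i][j] ^ t2[i][j] == OPS[op1](x, y) ^ OPS[op2](x, y) ^ wc.k ^ wd.k
```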
The operations `op_1` and `op_2` can be chosen from `and` and `xor`. If we use the same one for both, then the first two terms cancel, leaving us with `k_c ^ k_d`, which is the same for all four entries and thus tells us nothing about `k_a` or `k_b`. If we instead let one of the operations, say `op_1`, be `^`, and the other `&`, then we obtain
= ((v_a ^ k_a) ^ (v_b ^ k_b)) ^ ((v_a ^ k_a) & (v_b ^ k_b)) ^ (k_c ^ k_d)
To better evaluate this, let us rewrite it in terms of `v'_a = v_a ^ k_a`, `v'_b = v_b ^ k_b`, and `r = k_c ^ k_d`. Then we get
= v'_a ^ v'_b ^ (v'_a & v'_b) ^ r
What kind of values can this take? We get:
v'_a = 0, v'_b = 0 --> r
v'_a = 1, v'_b = 0 --> 1 ^ r
v'_a = 0, v'_b = 1 --> 1 ^ r
v'_a = 1, v'_b = 1 --> 1 ^ r
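The case analysis can be checked mechanically for both values of `r`. The check also confirms that `v'_a ^ v'_b ^ (v'_a & v'_b)` is just `v'_a | v'_b`, which is why only the case `v'_a = v'_b = 0` escapes the flip to `1 ^ r`.

```python
from itertools import product


def combined(xa, xb, r):
    """v'_a ^ v'_b ^ (v'_a & v'_b) ^ r: the xor of the xor- and and-gate bits."""
    return xa ^ xb ^ (xa & xb) ^ r


for r in (0, 1):
    for xa, xb in product((0, 1), repeat=2):
        value = combined(xa, xb, r)
        # x ^ y ^ (x & y) equals x | y
        assert value == (xa | xb) ^ r
        print(f"r={r}: v'_a={xa}, v'_b={xb} --> {value}")
```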
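So we see that one value occurs three times, and there is a unique value, which occurs when `v'_a = 0` and `v'_b = 0`, corresponding to `v_a = k_a` and `v_b = k_b`. Since we know at which position `(v_a, v_b)` of the printed table each bit appears, the position of the unique value is exactly `(k_a, k_b)`.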
Thus, by setting up an `and` and a `xor` gate with the same two input wires, we can recover `k_a` and `k_b`.
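A sketch of this recovery step, assuming the printed bits of the two gates are given as 2×2 lists indexed `[v_a][v_b]` in print order; `recover_masks` is a hypothetical name. It xors the two tables and returns the position of the value that occurs exactly once.

```python
def recover_masks(and_table, xor_table):
    """Return (k_a, k_b) from the printed bits of an 'and' and an 'xor' gate
    that share the same two input wires."""
    diff = [[and_table[i][j] ^ xor_table[i][j] for j in (0, 1)] for i in (0, 1)]
    entries = [diff[i][j] for i in (0, 1) for j in (0, 1)]
    for i in (0, 1):
        for j in (0, 1):
            # diff[i][j] = ((i ^ k_a) | (j ^ k_b)) ^ r, so the entry at
            # (i, j) = (k_a, k_b) is the one that differs from the other three.
            if entries.count(diff[i][j]) == 1:
                return i, j
    raise ValueError('no unique entry; do the gates share their inputs?')
```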
Once we have used this to recover `k_i` for all `0 <= i < n`, we can xor it with the value `b_i ^ k_i` we got from the wires to obtain the flag bits `b_i`.
Attachments
Challenge and solution script.