ångstromCTF 2023 - snap circuits



Writeup by: shalaamum

Solved by: shalaamum


This challenge used garbled circuits. They have a cryptographic application (secure two-party computation), which is, however, irrelevant for understanding this challenge and its solution. Essentially, we get the bits of the flag xored with a random key we do not know, but we can also define some logical gates that take the flag bits as inputs, and we get garbled (encrypted) truth tables for these gates. Since the encryption of these tables xors with the same keystream whenever the gate inputs are the same, we can xor the outputs of two gates together to remove one layer of encryption, after which we can recover the key the flag is xored with.


The challenge involves a garbled circuit that already has wires corresponding to the flag bits as inputs, on top of which we can define AND and XOR gates, each given by its type (and / xor) and its two input wires. The circuit is then garbled, and we get some information about the garbled circuit. So let us begin by looking at what wires and gates are and how they are garbled.


The wires corresponding to the input bits are represented in the circuit by a tuple with the single component 'id':

circuit = [('id',) for _ in flag_bits]

In garble, each of these is replaced by a Wire obtained from Wire.new():

case ('id',):
	wire = Wire.new()

So let us look at the Wire class.

from collections import namedtuple

class Wire:
	Label = namedtuple('Label', 'key ptr')

	def __init__(self, zero, one):
		self.zero = zero
		self.one = one

	@classmethod
	def new(cls):
		# rand_bit() and rand_buf() are random helpers defined elsewhere in the challenge.
		bit = rand_bit()
		zero = Wire.Label(rand_buf(), bit)
		one = Wire.Label(rand_buf(), bit ^ 1)
		return cls(zero, one)

	def get_label(self, value):
		match value:
			case 0:
				return self.zero
			case 1:
				return self.one
			case _:
				raise ValueError('cannot translate value to label')

So a Wire has two members, zero and one, each of which consists of a random 16-byte key and an encrypted bit (the ptr field): there is a randomly chosen bit k such that zero is associated with the encrypted bit 0 ^ k and one with the encrypted bit 1 ^ k.
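A minimal stand-in for Wire.new() can make the relation between values and encrypted bits concrete. This is a sketch, not the challenge code: it assumes 16-byte keys and uses Python's secrets module in place of the challenge's rand_bit() and rand_buf() helpers; new_wire is a hypothetical name.

```python
import secrets
from collections import namedtuple

BUF_LEN = 16  # assumption: label keys are 16 bytes long
Label = namedtuple('Label', 'key ptr')


def new_wire():
    """Hypothetical stand-in for Wire.new(); returns (k, zero, one)."""
    k = secrets.randbits(1)  # plays the role of rand_bit()
    zero = Label(secrets.token_bytes(BUF_LEN), k)     # value 0 -> ptr 0 ^ k
    one = Label(secrets.token_bytes(BUF_LEN), k ^ 1)  # value 1 -> ptr 1 ^ k
    return k, zero, one


k, zero, one = new_wire()
for value, label in ((0, zero), (1, one)):
    # The ptr of the label for a value is that value masked with k.
    assert label.ptr == value ^ k
```

So anyone who knows k can read the underlying value off a ptr, and anyone who knows both the value and the ptr learns k.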


Let op denote either AND or XOR, and assume we already have two Wires W_a and W_b with associated bits k_a and k_b, as well as keys W_{a,0}, W_{a,1}, W_{b,0}, and W_{b,1} (where, e.g., W_{a,1} is the key component of the tuple W_a.one). Then the gate with operation op and the two input wires W_a and W_b is garbled as follows in garble:

case (op, a, b):
	wire, table = garble_gate(op, wires[a], wires[b])

Here is the garble_gate function:

def garble_gate(op, wa, wb):
	wc = Wire.new()
	table = [[None, None], [None, None]]
	for va, vb, vc in get_truth_table(op):
		la = wa.get_label(va)
		lb = wb.get_label(vb)
		lc = wc.get_label(vc)
		table[la.ptr][lb.ptr] = Cipher(la, lb).encrypt(lc)
	return wc, table

get_truth_table just returns the truth table for op, i.e. the triples (va, vb, vc) with vc = va op vb for the 4 possibilities of (va, vb). Thus, writing W_c for the new Wire (wc in the code) and k_c for its associated bit, and using the notation above, we obtain a table such that

table[v_a ^ k_a][v_b ^ k_b] = Cipher((W_{a,v_a}, v_a ^ k_a), (W_{b,v_b}, v_b ^ k_b)).encrypt((W_{c,v_a op v_b}, (v_a op v_b) ^ k_c))
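The placement of the rows can be checked with a small sketch. get_truth_table below is a hypothetical re-implementation (the writeup only describes what it returns), and cell_contents is an illustrative helper that records which plaintext values end up in which cell for given k_a and k_b, ignoring the encryption.

```python
import operator

OPS = {'and': operator.and_, 'xor': operator.xor}


def get_truth_table(op):
    """Hypothetical re-implementation: triples (va, vb, vc) with vc = va op vb."""
    fn = OPS[op]
    return [(va, vb, fn(va, vb)) for va in (0, 1) for vb in (0, 1)]


def cell_contents(op, k_a, k_b):
    """Mimic garble_gate's indexing: cell [va ^ k_a][vb ^ k_b] gets row (va, vb, vc)."""
    table = [[None, None], [None, None]]
    for va, vb, vc in get_truth_table(op):
        table[va ^ k_a][vb ^ k_b] = (va, vb, vc)
    return table


# With k_a = 1 and k_b = 0, the row for inputs (0, 0) lands in cell [1][0].
t = cell_contents('and', 1, 0)
assert t[1][0] == (0, 0, 0)
```

Since va -> va ^ k_a is a bijection on {0, 1}, each of the four cells is filled exactly once, and the position of a cell says nothing about the inputs unless k_a and k_b are known.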

To understand the right hand side, we need to look at Cipher:

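from Crypto.Hash import SHAKE128  # PyCryptodome
from Crypto.Util.strxor import strxor

# BUF_LEN, the length of a label key (16 bytes), is a constant defined elsewhere in the challenge.

class Cipher:
	def __init__(self, *labels):
		# The SHAKE128 input is the concatenation of the keys of the input labels.
		key = b''.join([l.key for l in labels])
		self.shake = SHAKE128.new(key)

	def xor_buf(self, buf):
		return strxor(buf, self.shake.read(BUF_LEN))

	def xor_bit(self, bit):
		# & binds tighter than ^, so this is bit ^ (low bit of the next keystream byte).
		return bit ^ self.shake.read(1)[0] & 1

	def encrypt(self, label):
		# The key is masked first, so the bit uses the keystream byte at index BUF_LEN.
		return self.xor_buf(label.key), self.xor_bit(label.ptr)

	def decrypt(self, row):
		return Wire.Label(self.xor_buf(row[0]), self.xor_bit(row[1]))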

So for Cipher((W_{a,v_a}, v_a ^ k_a), (W_{b,v_b}, v_b ^ k_b)), the key of the Cipher is the concatenation of W_{a,v_a} and W_{b,v_b}. SHAKE128 then generates a keystream from it, and the key and the bit of the label we are encrypting are xored with consecutive parts of that keystream (first the key, then the bit). We won't need the key, so let us ignore it and just note that there is a function f that takes W_{a,v_a} and W_{b,v_b} as input and produces a bit f(W_{a,v_a}, W_{b,v_b}) that is xored into the bit we are encrypting, which in our case is (v_a op v_b) ^ k_c.

The upshot is that the table of the garbled gate has the following form, where ? stands for the encrypted 16-byte key of the output label, whose value we do not care about:

table[v_a ^ k_a][v_b ^ k_b] = (?, f(W_{a,v_a}, W_{b,v_b}) ^ (v_a op v_b) ^ k_c)
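Assuming BUF_LEN is 16 (matching the 16-byte keys), f can be written down explicitly: encrypt first consumes BUF_LEN bytes of the SHAKE128 stream for the key and then one byte for the bit, so f is the lowest bit of byte BUF_LEN of the stream. The sketch below uses Python's hashlib.shake_128 in place of PyCryptodome's SHAKE128 (both compute the same standard XOF); f and encrypted_ptr are illustrative names.

```python
import hashlib

BUF_LEN = 16  # assumption: the label key length used by the challenge


def f(key_a, key_b):
    """Keystream bit xored into the encrypted ptr for input keys key_a, key_b."""
    stream = hashlib.shake_128(key_a + key_b).digest(BUF_LEN + 1)
    # Bytes 0..BUF_LEN-1 mask the output key; byte BUF_LEN masks the ptr bit.
    return stream[BUF_LEN] & 1


def encrypted_ptr(key_a, key_b, ptr):
    """Second component of a table row, with ptr = (v_a op v_b) ^ k_c."""
    return f(key_a, key_b) ^ ptr
```

Nothing about the gate itself enters the SHAKE128 input, only the two input keys.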

The challenge


flag_bits contains the individual bits of the flag. Let us denote the individual bits by b_0 to b_{n-1}, with n the total number of bits.

First, the challenge adds a wire to the circuit for each bit:

circuit = [('id',) for _ in flag_bits]

Let us denote the garbled versions of these wires by W_0 to W_{n-1}, with associated bits k_0 to k_{n-1} and the notation for the other components as above.

Then we get to define up to 1000 - n gates (the check len(circuit) < 1000 also counts the n input wires):

while True:
	gate = input('gate: ')
	tokens = gate.split(' ')
	try:
		assert len(tokens) == 3
		assert tokens[0] in ('and', 'xor')
		assert tokens[1].isdecimal()
		assert tokens[2].isdecimal()
		op = tokens[0]
		a = int(tokens[1])
		b = int(tokens[2])
		assert 0 <= a < len(circuit)
		assert 0 <= b < len(circuit)
		assert len(circuit) < 1000
		circuit.append((op, a, b))
		print('moving on...')
	except AssertionError:
		# any malformed line (such as an empty one) ends the gate input
		break

The circuit is then garbled.

in_wires, _, tables = garble(circuit)
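Looking ahead to the attack described below, which needs an AND gate and an XOR gate on the same two inputs, the gate definitions a solver sends could be generated as follows. gate_lines is a hypothetical helper; it pairs wire i with wire i + 1, assumes the number of flag bits n is even (true if flag_bits comes from whole bytes), and checks the size limit enforced by the loop above.

```python
def gate_lines(n):
    """Hypothetical helper: an 'and' and a 'xor' gate on wires (i, i + 1)
    for every pair of flag wires."""
    if n % 2:
        raise ValueError('this sketch assumes an even number of flag bits')
    lines = []
    for i in range(0, n, 2):
        lines.append(f'and {i} {i + 1}')
        lines.append(f'xor {i} {i + 1}')
    # A gate is only accepted while len(circuit) < 1000,
    # and circuit already holds the n input wires.
    if n + len(lines) > 1000:
        raise ValueError('too many gates for the 1000-entry limit')
    return lines
```

The gates get circuit indices n, n + 1, and so on, so their tables presumably come back in the same order: AND then XOR for each pair. This layout supports flags of up to 500 bits.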

What we get

First, we get the labels of the flag-bit wires for the actual flag bits:

for i, b in enumerate(flag_bits):
	label = in_wires[i].get_label(b)
	print(f'wire {i}: {label.key.hex()} {label.ptr}')

So the information we get is W_{i,b_i} and b_i ^ k_i, for every 0 <= i < n.
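On the solver side, the masked bits b_i ^ k_i can be read off these lines. A hypothetical parser, assuming the exact output format of the print statement above:

```python
import re

# Matches lines printed as f'wire {i}: {label.key.hex()} {label.ptr}'.
WIRE_RE = re.compile(r'wire (\d+): ([0-9a-f]+) ([01])')


def parse_masked_bits(lines):
    """Return [b_0 ^ k_0, ..., b_{n-1} ^ k_{n-1}] from the wire lines."""
    masked = {}
    for line in lines:
        m = WIRE_RE.fullmatch(line.strip())
        if m:
            masked[int(m.group(1))] = int(m.group(3))
    return [masked[i] for i in range(len(masked))]
```

Lines that do not match, such as the 'table data:' header, are simply skipped.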

We then also get some output from the gates we provided:

print('table data:')
for table in tables:
	for i in range(2):
		for j in range(2):
			row = table[i][j]
			print(f'{row[0].hex()} {row[1]}')
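The loops print table[0][0], table[0][1], table[1][0], and table[1][1] for each gate in turn, so the rows can be regrouped into one 2x2 grid per gate. Only the second component, the encrypted bit, matters for the attack. parse_tables is a hypothetical helper:

```python
def parse_tables(rows):
    """Group '<key hex> <bit>' rows (4 per gate, row-major) into 2x2 bit grids."""
    bits = [int(row.split()[1]) for row in rows]
    if len(bits) % 4:
        raise ValueError('expected 4 rows per gate')
    return [[[bits[g], bits[g + 1]], [bits[g + 2], bits[g + 3]]]
            for g in range(0, len(bits), 4)]
```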

For a gate with operation op and input wires W_a and W_b, we get the values table[v_a][v_b] for all four possible values of (v_a, v_b). Above, we found that table has the form

table[v_a ^ k_a][v_b ^ k_b] = (?, f(W_{a,v_a}, W_{b,v_b}) ^ (v_a op v_b) ^ k_c)

so that we obtain (by replacing v_a with v_a ^ k_a and v_b with v_b ^ k_b)

table[v_a][v_b] = (?, f(W_{a,v_a ^ k_a}, W_{b,v_b ^ k_b}) ^ ((v_a ^ k_a) op (v_b ^ k_b)) ^ k_c)


For each gate we define, we thus get the four bits

f(W_{a,v_a ^ k_a}, W_{b,v_b ^ k_b}) ^ ((v_a ^ k_a) op (v_b ^ k_b)) ^ k_c

but unfortunately, as the arguments to f depend on v_a and v_b, the value of f is essentially a random bit that is used only once in that gate, so we gain no information on ((v_a ^ k_a) op (v_b ^ k_b)) ^ k_c. However, f depends only on the two input labels, not on the gate. So if we have two gates with operations op_1 and op_2, both with the same input wires W_a and W_b, and with output wires W_c and W_d respectively, then we get both

f(W_{a,v_a ^ k_a}, W_{b,v_b ^ k_b}) ^ ((v_a ^ k_a) op_1 (v_b ^ k_b)) ^ k_c


and

f(W_{a,v_a ^ k_a}, W_{b,v_b ^ k_b}) ^ ((v_a ^ k_a) op_2 (v_b ^ k_b)) ^ k_d

with the same value produced by f. We can thus cancel this contribution by xoring the two values, so that for each of the four possibilities for (v_a, v_b) we obtain

((v_a ^ k_a) op_1 (v_b ^ k_b)) ^ k_c ^ ((v_a ^ k_a) op_2 (v_b ^ k_b)) ^ k_d
= ((v_a ^ k_a) op_1 (v_b ^ k_b)) ^ ((v_a ^ k_a) op_2 (v_b ^ k_b)) ^ (k_c ^ k_d)
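The cancellation can be checked numerically. The sketch below uses illustrative names, with random bits standing in for the values of f and for k_a, k_b, k_c, and k_d; it builds the encrypted bits of both tables from the formula for table[v_a][v_b] above and confirms that their xor no longer involves f.

```python
import operator
import secrets

OPS = {'and': operator.and_, 'xor': operator.xor}


def gate_bits(op, fvals, k_a, k_b, k_out):
    """table[v_a][v_b] bits: f(...) ^ ((v_a ^ k_a) op (v_b ^ k_b)) ^ k_out.

    fvals[v_a][v_b] stands for f(W_{a,v_a ^ k_a}, W_{b,v_b ^ k_b})."""
    fn = OPS[op]
    return [[fvals[va][vb] ^ fn(va ^ k_a, vb ^ k_b) ^ k_out for vb in (0, 1)]
            for va in (0, 1)]


def xor_tables(t1, t2):
    """Cell-wise xor of two 2x2 bit tables."""
    return [[t1[i][j] ^ t2[i][j] for j in (0, 1)] for i in (0, 1)]


# Random values of f, shared by both gates since they have the same inputs.
fvals = [[secrets.randbits(1) for _ in (0, 1)] for _ in (0, 1)]
k_a, k_b, k_c, k_d = (secrets.randbits(1) for _ in range(4))
combined = xor_tables(gate_bits('xor', fvals, k_a, k_b, k_c),
                      gate_bits('and', fvals, k_a, k_b, k_d))
for va in (0, 1):
    for vb in (0, 1):
        x, y = va ^ k_a, vb ^ k_b
        assert combined[va][vb] == (x ^ y) ^ (x & y) ^ (k_c ^ k_d)
```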

The operations op_1 and op_2 can each be chosen to be AND (&) or XOR (^). If we use the same one for both, then the first two terms cancel, leaving us with k_c ^ k_d in every cell, which tells us nothing about k_a or k_b. If we instead let one of the operations, say op_1, be ^ and the other &, then we obtain

= ((v_a ^ k_a) ^ (v_b ^ k_b)) ^ ((v_a ^ k_a) & (v_b ^ k_b)) ^ (k_c ^ k_d)

To evaluate this more easily, let us rewrite it in terms of v'_a = v_a ^ k_a, v'_b = v_b ^ k_b, and r = k_c ^ k_d. Then we get

= v'_a ^ v'_b ^ (v'_a & v'_b) ^ r

What kind of values can this take? We get:

v'_a = 0, v'_b = 0  -->  r
v'_a = 1, v'_b = 0  -->  1 ^ r
v'_a = 0, v'_b = 1  -->  1 ^ r
v'_a = 1, v'_b = 1  -->  1 ^ r
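The table can also be confirmed by brute force. In fact, v'_a ^ v'_b ^ (v'_a & v'_b) is just v'_a | v'_b, which is 0 only for v'_a = v'_b = 0. A short check, with combined_value as an illustrative name:

```python
def combined_value(va_p, vb_p, r):
    """v'_a ^ v'_b ^ (v'_a & v'_b) ^ r, the xor of the XOR-gate and AND-gate bits."""
    return va_p ^ vb_p ^ (va_p & vb_p) ^ r


for r in (0, 1):
    rows = [combined_value(a, b, r) for a, b in ((0, 0), (1, 0), (0, 1), (1, 1))]
    assert rows == [r, 1 ^ r, 1 ^ r, 1 ^ r]  # same order as the table above

for a in (0, 1):
    for b in (0, 1):
        assert combined_value(a, b, 0) == a | b  # the expression is an OR
```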

So one value occurs three times, and the other occurs exactly once, namely when v'_a = 0 and v'_b = 0, which corresponds to v_a = k_a and v_b = k_b. Thus, by setting up an AND gate and an XOR gate with the same two input wires and locating the position (v_a, v_b) of the unique value in the xor of their tables, we can recover k_a and k_b.
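Extracting k_a and k_b from the xored table amounts to finding the cell whose value occurs only once. A hypothetical helper, where combined[v_a][v_b] holds the xor of the two gates' encrypted bits at table position (v_a, v_b):

```python
def recover_keys(combined):
    """Return (k_a, k_b): the position of the value that occurs only once."""
    cells = [(i, j) for i in (0, 1) for j in (0, 1)]
    for i, j in cells:
        count = sum(combined[x][y] == combined[i][j] for x, y in cells)
        if count == 1:
            return i, j
    raise ValueError('no unique cell; were an AND and a XOR gate used on the same inputs?')


# Example: k_a = 1, k_b = 0, r = 0, so only cell [1][0] holds 0.
assert recover_keys([[1, 1], [0, 1]]) == (1, 0)
```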

Once we have used this to recover k_i for all 0 <= i < n, we can xor it with the value b_i ^ k_i we got from the wires to obtain the flag bits b_i.
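Putting the pieces together, the whole attack can be tested offline against a mock of the garbling. This is a sketch under stated assumptions, not the challenge or solution script: labels are random 16-byte strings, hashlib.shake_128 stands in for PyCryptodome's SHAKE128, only the encrypted bits of the tables are modelled, and wires are paired as (0, 1), (2, 3), and so on, with the AND gate defined before the XOR gate. All function names are hypothetical.

```python
import hashlib
import operator
import secrets

BUF_LEN = 16  # assumption: 16-byte label keys
OPS = {'and': operator.and_, 'xor': operator.xor}


def f(key_a, key_b):
    # Low bit of byte BUF_LEN of the SHAKE128 stream (the byte that masks the ptr).
    return hashlib.shake_128(key_a + key_b).digest(BUF_LEN + 1)[BUF_LEN] & 1


def new_wire():
    # (k, (key for value 0, key for value 1)), mirroring Wire.new()
    return secrets.randbits(1), (secrets.token_bytes(BUF_LEN), secrets.token_bytes(BUF_LEN))


def garble_ptr_bits(op, wire_a, wire_b):
    """Mock of garble_gate keeping only the encrypted ptr bit of each row."""
    (k_a, keys_a), (k_b, keys_b) = wire_a, wire_b
    k_c, _ = new_wire()  # fresh output wire; only its bit k_c matters here
    table = [[None, None], [None, None]]
    for va in (0, 1):
        for vb in (0, 1):
            vc = OPS[op](va, vb)
            table[va ^ k_a][vb ^ k_b] = f(keys_a[va], keys_b[vb]) ^ vc ^ k_c
    return table


def serve(flag_bits):
    """What the mock server reveals: masked bits b_i ^ k_i and one table per gate."""
    wires = [new_wire() for _ in flag_bits]
    masked = [b ^ k for b, (k, _) in zip(flag_bits, wires)]
    tables = [garble_ptr_bits(op, wires[i], wires[i + 1])
              for i in range(0, len(flag_bits), 2) for op in ('and', 'xor')]
    return masked, tables


def recover_keys(combined):
    cells = [(i, j) for i in (0, 1) for j in (0, 1)]
    for i, j in cells:
        if sum(combined[x][y] == combined[i][j] for x, y in cells) == 1:
            return i, j
    raise ValueError('no unique cell')


def solve(masked, tables):
    bits = []
    for p, (t_and, t_xor) in enumerate(zip(tables[0::2], tables[1::2])):
        # xoring the AND and XOR tables cancels f, leaving (v'_a | v'_b) ^ r
        combined = [[t_and[i][j] ^ t_xor[i][j] for j in (0, 1)] for i in (0, 1)]
        k_a, k_b = recover_keys(combined)
        bits += [masked[2 * p] ^ k_a, masked[2 * p + 1] ^ k_b]
    return bits


flag = b'actf{not_the_real_flag}'  # placeholder
flag_bits = [(byte >> (7 - i)) & 1 for byte in flag for i in range(8)]  # MSB first
masked, tables = serve(flag_bits)
assert solve(masked, tables) == flag_bits
```

Converting the recovered bits back into bytes depends on how the challenge built flag_bits; the MSB-first order used here is an assumption of the sketch.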


Challenge and solution script.