# m0leCon Teaser 2023 - babyPQ

## Summary

This is a writeup for the “babyPQ” challenge by mr96 at m0leCon Teaser 2023. We were the first team to solve this challenge, which ended with a total of 7 solves.

This challenge involved a lot of probability. It features a Feistel cipher built from S-boxes whose output bits satisfy a linear relation with high (but not certain) probability. One first has to figure out what that linear relation is, and then use it to leak the last round key. Basic counting statistics were insufficient for the latter, so one has to actually calculate probability updates based on the evidence provided by plaintext-ciphertext pairs.

Solution by: shalaamum

Writeup by: shalaamum

This writeup was first published here.

## Key generation

The code run by the challenge is shown in the snippet below:

```
public_key, private_key = gen_keypair(8, 48, 0.97)
flag = open("flag.txt", "rb").read()
enc = encrypt(flag, public_key)
print(f"{public_key = }")
print(f"{enc = }")
```

and we are given the output of this as a text file.
This means we have a public key and the encrypted flag. The public key seems to come together with a private key, so as a first step it might be good to recover that. So let us look at `gen_keypair`.

```
def gen_keypair(m, n, p):
    private_key = [randint(0,1) for _ in range(n-1)]
    private_key.append(1)
    SBoxes = []
    for _ in range(6):
        SBoxes.append(create_sbox(m, n, p, private_key))
    return SBoxes, private_key
```

From this we see that the private key consists of 48 bits, of which the first 47 are random, and the last one is 1. The public key that we obtain consists of 6 S-boxes, which are somehow constructed using this private key.

```
def create_sbox(m, n, p, priv):
    Bs = []
    for _ in range(n-1):
        B = random_boolean_function(m)
        Bs.append(B)
    last_bf = []
    for i in range(2^m):
        val = sum([priv[j] * int(Bs[j](i)) for j in range(n-1)]) % 2
        if random() < p:
            last_bf.append(val)
        else:
            last_bf.append(1-val)
    Bs.append(BooleanFunction(last_bf))
    Bs = matrix([[int(x) for x in B.truth_table()] for B in Bs]).transpose()
    S = []
    for B in Bs:
        S.append(sum([B[i]*2^i for i in range(n)]))
    return S
```

To get our bearings, let us first note that each S-box here takes an 8-bit input but produces a 48-bit output, which is slightly unusual. The S-box is constructed so that the first 47 output bits are random, and the last bit (which is the most significant one) is chosen such that the scalar product of the output bit vector with the private key^{1} is 0 with a probability of 97%.

### Recovering the private key

To recover the 47 unknown bits of the private key, a system of 47 linearly independent linear equations suffices. We have 6 S-boxes with 256 values each, which gives us 1536 linear equations, though only about 97% of them are correct. However, if we just take a random 47 of those, then with probability `0.97^47 ≈ 24%` *all* of them will be correct. Note that we also have an easy way of checking whether a guess for the private key is correct: for a wrong guess the scalar product with the S-box outputs will be 1 and 0 with a probability of roughly 50% each, so the correct private key is a very noticeable outlier. Thus we have an efficient and very simple way to recover the private key. Take a random 47 outputs of the S-boxes. If the corresponding equations are not linearly independent, throw them away and start over. Otherwise, solve the system of linear equations to obtain a guess for the private key. Check the guess by looking at the distribution of its scalar products with all S-box outputs. If the distribution is more like 50-50 than 97-3, reject this guess and start over. After a few attempts we will have obtained the correct private key.
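The recovery loop described above can be sketched in plain Python as follows. This is only an illustration under stated assumptions, not the original solve script: S-box outputs are taken to be 48-bit integers, and the names `solve_gf2`, `recover_private_key` and `toy_sbox` are hypothetical. `toy_sbox` merely imitates `create_sbox` with Python's `random` module so that the sketch has data to run on.

```python
import random

def parity(x):
    """Sum of the bits of x in F_2."""
    return bin(x).count("1") & 1

def solve_gf2(rows, nvars):
    """Gauss-Jordan elimination over F_2.
    Bits 0..nvars-1 of each row are coefficients, bit nvars is the right-hand side.
    Returns the solution packed into an int, or None if the system is singular."""
    rows = list(rows)
    for col in range(nvars):
        pivot = next((i for i in range(col, len(rows)) if (rows[i] >> col) & 1), None)
        if pivot is None:
            return None
        rows[col], rows[pivot] = rows[pivot], rows[col]
        for i in range(len(rows)):
            if i != col and (rows[i] >> col) & 1:
                rows[i] ^= rows[col]
    return sum(((rows[c] >> nvars) & 1) << c for c in range(nvars))

def fraction_zero(key, outputs):
    """Fraction of S-box outputs whose scalar product with key is 0."""
    return sum(parity(s & key) == 0 for s in outputs) / len(outputs)

def recover_private_key(outputs, n=48, rng=random, max_tries=2000):
    for _ in range(max_tries):
        # Because the last key bit is 1, an output s encodes its own equation:
        # bits 0..n-2 are the coefficients and bit n-1 is the right-hand side.
        low = solve_gf2(rng.sample(outputs, n - 1), n - 1)
        if low is None:
            continue  # equations not linearly independent, start over
        key = low | (1 << (n - 1))
        if fraction_zero(key, outputs) > 0.9:  # 97-3 rather than 50-50
            return key
    return None

def toy_sbox(key, n=48, p=0.97, rng=random):
    """Hypothetical stand-in for create_sbox: random low bits, biased top bit."""
    mask_low = key & ((1 << (n - 1)) - 1)
    table = []
    for _ in range(256):
        low = rng.getrandbits(n - 1)
        last = parity(low & mask_low)
        if rng.random() >= p:
            last ^= 1
        table.append(low | (last << (n - 1)))
    return table
```

In this sketch a sample is also discarded whenever it is singular, which happens often for random square systems over `F_2`, so expect noticeably more iterations than the 24% figure alone suggests. Each iteration is cheap, so this does not matter in practice.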

## Encryption

Let us now have a look at encryption. We are provided with the return value of `encrypt(flag, public_key)`, and this function is defined as follows:

```
def encrypt(pt, public_key):
    tmp_key = os.urandom(8)
    ad = []
    for _ in range(1024):
        tmp_pt = os.urandom(12)
        tmp_ct = feistel_encrypt(tmp_pt, tmp_key, public_key)
        ad.append((tmp_pt.hex(), tmp_ct.hex()))
    pt = pad(pt, 12)
    pt = [pt[i:i+12] for i in range(0, len(pt), 12)]
    c = b''.join([feistel_encrypt(p, tmp_key, public_key) for p in pt]).hex()
    return (c, ad)
```

So actual encryption happens with `feistel_encrypt`, which operates on 12-byte blocks and depends on an 8-byte encryption key we do not know (yet), and we obtain both the ciphertext of the (padded) flag and 1024 random plaintext-ciphertext pairs. So we will presumably need to do a known-plaintext attack on the cipher to recover the encryption key before we can decrypt the flag. `public_key` here consists of the S-boxes, which have a statistical anomaly in their output that we already know, so we will need to somehow use this to carry out our attack.

### Feistel cipher

As the name of the function suggests, this is a Feistel cipher.

```
def feistel_encrypt(pt, key, sboxes):
    ks = expand_key(key, 8)
    l, r = bytes_to_long(pt[:6]), bytes_to_long(pt[6:])
    for i in range(8):
        l, r = r, l ^^ f(r ^^ ks[i], sboxes)
    return long_to_bytes(l) + long_to_bytes(r)
```

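There are 8 rounds. In each round, the right half-block gets xored with a round key, then the function `f` is applied, and finally the left half-block is xored into the result. This becomes the new right half-block, while the old right half-block becomes the new left one.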

We can already make a couple of observations here. To start, let us look at the first two rounds.

```
l_1 = r_0
r_1 = l_0 ^ f(r_0 ^ ks_0, sboxes)
l_2 = r_1 = l_0 ^ f(r_0 ^ ks_0, sboxes)
r_2 = l_1 ^ f(r_1 ^ ks_1, sboxes) = r_0 ^ f(r_1 ^ ks_1, sboxes)
```

We can see that `l_2` and `r_2` are given by xoring `l_0` and `r_0`, respectively, with the output of one evaluation of the function `f`. Repeating this for 8 rounds, we get that, in particular, the following holds.

```
l_8 = l_0 ^ f(r_0 ^ ks_0, sboxes) ^ f(r_2 ^ ks_2, sboxes) ^ f(r_4 ^ ks_4, sboxes) ^ f(r_6 ^ ks_6, sboxes)
r_8 = r_0 ^ f(r_1 ^ ks_1, sboxes) ^ f(r_3 ^ ks_3, sboxes) ^ f(r_5 ^ ks_5, sboxes) ^ f(r_7 ^ ks_7, sboxes)
```

So each of the two half-blocks of the ciphertext is obtained by xoring the corresponding half-block of the plaintext with four outputs of the function `f`. The arguments to these evaluations of `f` are other half-blocks xored with a round key, and in each of the two expressions there is one evaluation of `f` for which we know that half-block if we have a plaintext-ciphertext pair: for `l_8` it is the first one, as we know `r_0`, and for `r_8` it is the last one, as we know `r_7 = l_8`. In the following sections we will describe how to use the second equality above to recover the last round key. The first round key can be recovered completely analogously.
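The two identities for `l_8` and `r_8` only depend on the Feistel structure, not on what `f` does. As a sanity check, here is a small sketch that records all intermediate states and verifies them. The round function `toy_f` and the round keys are arbitrary stand-ins chosen for illustration, not the challenge's `f` or key schedule.

```python
from functools import reduce

MASK48 = (1 << 48) - 1

def toy_f(x):
    """Arbitrary stand-in round function on 48-bit values (not the challenge's f)."""
    return ((x * 0x9E3779B97F4A7C15 + 0x1234567) >> 11) & MASK48

def feistel_states(l, r, round_keys, f):
    """Return [(l_0, r_0), (l_1, r_1), ..., (l_n, r_n)] for the given round keys."""
    states = [(l, r)]
    for k in round_keys:
        l, r = r, l ^ f(r ^ k)
        states.append((l, r))
    return states

def xor_all(values):
    return reduce(lambda a, b: a ^ b, values, 0)

# Made-up round keys and plaintext halves, purely for illustration.
round_keys = [(0x0123456789AB * (i + 1)) & MASK48 for i in range(8)]
states = feistel_states(0x112233445566, 0xAABBCCDDEEFF, round_keys, toy_f)
(l0, r0), (l8, r8) = states[0], states[8]

# f_out[i] is f(r_i ^ ks_i), the output of the round function in round i.
f_out = [toy_f(states[i][1] ^ round_keys[i]) for i in range(8)]
```

With these definitions, `l8 ^ l0` equals the xor of `f_out[0]`, `f_out[2]`, `f_out[4]`, `f_out[6]`, and `r8 ^ r0` equals the xor of the odd-indexed entries, whatever round function is plugged in.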

### Round key expansion

The round keys are derived from the key as follows:

```
def expand_key(key, rounds):
    ks = [key]
    for _ in range(rounds-1):
        cipher = ChaCha20.new(key = bytes([_])*32, nonce = b"\x00"*8)
        ks.append(cipher.encrypt(ks[-1]))
    return [bytes_to_long(k) for k in ks]
```

So the first round key is the original key, and each further round key is obtained by encrypting the previous one with ChaCha20, using a key and nonce that we know. In particular, knowledge of any one round key allows us to obtain the other round keys and the original key. Furthermore, ChaCha20 is a stream cipher: it produces a keystream that is xored with the plaintext to produce the ciphertext. Hence, to know a particular byte of the original key we only need to know the byte at the same position in one of the round keys. This will be very helpful because, while the key is 8 bytes long, only 6 bytes of each round key are used (`f` only reads the 48 least significant bits of its argument). Because of this behavior of ChaCha20 encryption, the remaining 2 bytes don’t matter and we do not have to bruteforce over them.
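To make the byte-wise structure concrete, here is a minimal sketch of the same kind of key schedule and its inversion. To stay free of external dependencies it uses a hypothetical `toy_keystream` derived from SHA-256 in place of the real ChaCha20 keystream, so the numbers differ from the challenge. Only the structure (each round key is the previous one xored with a known, position-wise keystream) is the same.

```python
import hashlib

def toy_keystream(round_index, length=8):
    """Stand-in for the known keystream of round `round_index`. NOT ChaCha20."""
    return hashlib.sha256(bytes([round_index])).digest()[:length]

def xor_bytes(a, b):
    return bytes(x ^ y for x, y in zip(a, b))

def expand_key_toy(key, rounds=8):
    """Same shape as expand_key: ks[i+1] is ks[i] xored with a known keystream."""
    ks = [key]
    for i in range(rounds - 1):
        ks.append(xor_bytes(ks[-1], toy_keystream(i)))
    return ks

def master_from_round_key(round_key, index):
    """Undo the xors of rounds index-1, ..., 0 to get back the original key."""
    key = round_key
    for i in reversed(range(index)):
        key = xor_bytes(key, toy_keystream(i))
    return key

key = bytes.fromhex("00112233aabbccdd")
ks = expand_key_toy(key)
recovered = master_from_round_key(ks[7], 7)
```

Because everything is a position-wise xor, getting two byte positions of the last round key wrong only affects the same two positions of the recovered key. For the real cipher, the keystreams would be computed with ChaCha20 using the known keys `bytes([i])*32` and the all-zero nonce.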

### The function `f`

This function is defined as follows:

```
def f(p, sboxes):
    p = [(p>>(8*i)) & 0xFF for i in range(6)]
    res = 0
    for i in range(6):
        res ^^= sboxes[i][p[i]]
    return res
```

So we split the input half-block into 6 individual bytes, then apply the corresponding S-box to each of them (yielding a number 6 bytes wide), then xor all of the results together. We can write this as follows:

```
f(x_0, ..., x_5) = S_0(x_0) ^ ... ^ S_5(x_5)
```

where `S_i` is the `i`-th S-box and `x_i` is the `i`-th byte of the input.

### Relation to the private key

Putting the results of the previous subsections together we can write

```
r_8
  = r_0 ^ f(r_1 ^ ks_1, sboxes) ^ f(r_3 ^ ks_3, sboxes) ^ f(r_5 ^ ks_5, sboxes) ^ f(r_7 ^ ks_7, sboxes)
  = r_0 ^ (XOR_{0 <= i < 6, j in {1, 3, 5}} S_i((r_j)_i ^ (ks_j)_i) )
        ^ (XOR_{0 <= i < 6} S_i((l_8)_i ^ (ks_7)_i) )
```

We can rewrite this as:

```
r_8 ^ r_0 = (XOR_{0 <= i < 6, j in {1, 3, 5}} S_i((r_j)_i ^ (ks_j)_i) )
          ^ (XOR_{0 <= i < 6} S_i((l_8)_i ^ (ks_7)_i) )
```

As mentioned before, if we have a plaintext-ciphertext pair, then we know `r_8`, `r_0`, and `l_8`. What we would like is to obtain information about the last round key `ks_7`. For this we will have to use the statistical anomaly the S-boxes exhibit: the scalar product of their output with the private key (not to be confused with the encryption key or the round keys derived from the encryption key) is `0` with high likelihood, rather than with the `50%` one would expect if they were random. As the scalar product is bilinear, we obtain the following equation by taking the scalar product of the previous one with the private key `p`^{2}.

```
(r_8 ^ r_0) ⋅ p = (Sum_{0 <= i < 6, j in {1, 3, 5}} S_i((r_j)_i ^ (ks_j)_i) ⋅ p)
                + (Sum_{0 <= i < 6} S_i((l_8)_i ^ (ks_7)_i) ⋅ p)
```

To shorten notation, let us from now on denote the expression on the right by `R = R_p + R_l` (with `p` standing for “previous rounds” and `l` standing for “last round”).
We know the left-hand side of the above equation. Let us assume for the moment that it is `0`. Then we know that `R` must be `0` as well, which implies that an even number of the scalar products of S-box outputs with `p` involved in the sum are `1`. This is the information that we will use to learn about `ks_7`, by doing a Bayesian update of our probabilities on each observation of a plaintext-ciphertext pair.
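The step from xors to sums relies on the scalar product with `p` being linear over `F_2`. A small check of this, using a plain Python port of `f` and random tables as stand-ins for the real S-boxes, could look as follows. The helper name `dot` is hypothetical, and bit vectors are packed into integers.

```python
import random

def dot(x, p):
    """Scalar product over F_2 of two bit vectors packed into ints."""
    return bin(x & p).count("1") & 1

def f(x, sboxes):
    """Python port of the challenge's f: xor of six S-box lookups, one per byte."""
    res = 0
    for i in range(6):
        res ^= sboxes[i][(x >> (8 * i)) & 0xFF]
    return res

rng = random.Random(0)
# Random stand-ins: these tables do NOT have the 97% bias of the real S-boxes.
sboxes = [[rng.getrandbits(48) for _ in range(256)] for _ in range(6)]
p = rng.getrandbits(48)
x = rng.getrandbits(48)

# Scalar product of the xor equals the sum (mod 2) of the scalar products.
lhs = dot(f(x, sboxes), p)
rhs = sum(dot(sboxes[i][(x >> (8 * i)) & 0xFF], p) for i in range(6)) % 2
```

For a plaintext-ciphertext pair, the observed bit on the left-hand side of the equation above would then be `dot(r_8 ^ r_0, p)`.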

### Calculating the probabilities

To shorten notation, let us denote the last round key, so far denoted by `ks_7`, by `k` instead, with `k_j` being the `j`-th byte of the last round key.
Initially, before observing anything, each value for each byte of the last round key is equally likely. Thus we initialize our probabilities with `P(k_j = v) = 1/256` for `v` in `{0,...,255}`.

Now suppose we observe a plaintext-ciphertext pair with `(r_8 ^ r_0) ⋅ p = a` for some value `a` in `F_2` and with `l_8 = l` for some `l`. Then we *replace* the previous value of `P(k_j = v)` with the probability conditional on the observation we made:

```
P(k_j = v | l_8=l and R=a)
```

To calculate this we can use Bayes’ theorem:

```
P(k_j=v | l_8=l and R=a) = P(l_8=l and R=a | k_j=v) * P(k_j=v) / P(l_8=l and R=a)
```

Rewriting this using the definition of conditional probability, we obtain:

```
= ( P(R=a | l_8=l and k_j=v) * P(l_8=l | k_j=v) * P(k_j=v) ) / ( P(R=a | l_8=l) * P(l_8=l) )
```

Using that the probability of `l_8=l` is independent of the key, we get:

```
= P(R=a | l_8=l and k_j=v) * P(k_j=v) / P(R=a | l_8=l)
```

There are three probabilities here that we need. `P(k_j=v)` is just our prior. The other two we need to calculate.

Let us consider `P(R=a | l_8=l)` first. The event `R=a` means that, of the 24 S-box evaluations involved, the number whose output has scalar product `1` with `p` is congruent to `a` modulo `2`. So we can divide up the probability `P(R=a | l_8=l)` into subcases, indexed by tuples
`(a_{1,0}, ..., a_{1,5}, a_{3,0}, ..., a_{3,5}, a_{5,0}, ..., a_{5,5}, a_{7,0}, ..., a_{7,5}) in F_2^24` such that the sum of the components (modulo 2) is `a`. Let us call the set of these tuples `A`. Then, treating the individual S-box evaluations as independent, the following holds.

```
P(R=a | l_8=l)
    = Sum_{(a_{1,0}, ..., a_{7,5}) in A}
        (Prod_{0<=j<6, i in {1,3,5}} P(S_j((r_i)_j ^ (ks_i)_j) ⋅ p = a_{i,j} | l_8=l))
        * (Prod_{0<=j<6} P(S_j((l_8)_j ^ (ks_7)_j) ⋅ p = a_{7,j} | l_8=l) )
```

Let `W_j` be the number of inputs `x` for which `S_j(x) ⋅ p = 1`. The S-box evaluations of the three earlier rounds depend on neither `l_8` nor the last round key, and we model their inputs as uniformly random, so each of them has scalar product `1` with `p` with probability `W_j/256`. We obtain:

```
    = Sum_{(a_{1,0}, ..., a_{7,5}) in A}
        (Prod_{0<=j<6, i in {1,3,5}} (W_j/256)**a_{i,j} * (1 - W_j/256)**(1 - a_{i,j}) )
        * (Prod_{0<=j<6} P(S_j((l_8)_j ^ (ks_7)_j) ⋅ p = a_{7,j} | l_8=l) )
```

For the S-boxes of the last round, we will use our prior for the values of the last round key. Writing `l_j` for the `j`-th byte of `l`, we have:

```
P(S_j((l_8)_j ^ (ks_7)_j) ⋅ p = a_{7,j} | l_8=l)
    = P(S_j(l_j ^ k_j) ⋅ p = a_{7,j})
    = Sum_{0 <= v < 256} P(k_j=v) * P(S_j(l_j ^ v) ⋅ p = a_{7,j})
```

Here `S_j(l_j ^ v)` can be calculated directly, because `l_j` and `v` are concrete values. Hence the second factor in each summand can be calculated and will be either `1` or `0`. Calculating all these values and plugging them into the previous formula, we thus have a recipe for calculating `P(R=a | l_8=l)`.
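Summing over all tuples in `A` literally would mean about `2**23` terms per observation. Under the same independence assumption, the sum can be evaluated by folding in one S-box evaluation at a time while only tracking the probabilities of "even so far" and "odd so far". The following sketch (with the hypothetical names `parity_dist` and `parity_dist_bruteforce`) shows this shortcut next to the explicit sum over tuples, so that the two can be compared on small inputs. Here `qs[t]` is the probability that evaluation `t` has scalar product `1` with `p`.

```python
from itertools import product

def parity_dist(qs):
    """Return (P(sum even), P(sum odd)) for independent bits with P(bit_t = 1) = qs[t]."""
    even, odd = 1.0, 0.0
    for q in qs:
        even, odd = even * (1 - q) + odd * q, even * q + odd * (1 - q)
    return even, odd

def parity_dist_bruteforce(qs):
    """The same quantity as an explicit sum over all bit tuples, grouped by parity."""
    out = [0.0, 0.0]
    for bits in product((0, 1), repeat=len(qs)):
        prob = 1.0
        for b, q in zip(bits, qs):
            prob *= q if b else 1 - q
        out[sum(bits) % 2] += prob
    return tuple(out)

# Illustration only: if all 24 evaluations had probability 0.03 of giving 1,
# the sum would be even with probability about 0.61.
even_24, odd_24 = parity_dist([0.03] * 24)
```

This is why simple counting is not enough: even with a 97% bias per S-box, the combined bit over 24 evaluations is only moderately biased, so each pair carries little information on its own.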

The calculation of `P(R=a | l_8=l and k_j=v)` is very similar. The only difference is that in the factor for byte `j` of the last round we do not sum over all 256 possible values weighted by their prior probability, but instead take only the single summand for the given `v`, with weight `1`.
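Putting the formulas of this section together, one update step could be sketched as below. This is a minimal sketch under stated assumptions rather than the original solve script: `dots[i][x]` is a hypothetical precomputed table holding `S_i(x) ⋅ p`, `l_bytes[i]` is the `i`-th byte of `l_8`, the earlier rounds are modelled as independent evaluations with probability `W_i/256`, and all six bytes are updated from the same priors.

```python
def parity_dist(qs):
    """(P(sum even), P(sum odd)) for independent bits with P(bit_t = 1) = qs[t]."""
    even, odd = 1.0, 0.0
    for q in qs:
        even, odd = even * (1 - q) + odd * q, even * q + odd * (1 - q)
    return even, odd

def bayes_update(priors, dots, l_bytes, a):
    """Return the posteriors P(k_j = v | l_8 = l, R = a) for all bytes j.

    priors[j][v] is the current P(k_j = v), a is the observed bit (r_8 ^ r_0) . p."""
    n = len(priors)
    # Rounds 1, 3, 5: every S-box i gives 1 with probability W_i / 256.
    q_prev = [sum(dots[i]) / 256 for i in range(n)] * 3
    # Last round: average over the prior of the corresponding key byte.
    q_last = [sum(priors[i][v] * dots[i][l_bytes[i] ^ v] for v in range(256))
              for i in range(n)]
    posteriors = []
    for j in range(n):
        # Parity distribution of the other 23 evaluations.
        rest = parity_dist(q_prev + q_last[:j] + q_last[j + 1:])
        # Given k_j = v, the bit contributed by S-box j in the last round is known.
        weights = [priors[j][v] * rest[a ^ dots[j][l_bytes[j] ^ v]]
                   for v in range(256)]
        total = sum(weights)  # equals P(R = a | l_8 = l)
        posteriors.append([w / total for w in weights])
    return posteriors
```

One would call `bayes_update` once per plaintext-ciphertext pair, feeding the returned posteriors back in as the new priors.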

### Putting everything together

So we begin with `P(k_j = v) = 1/256` for `v` in `{0,...,255}`, and then update these probabilities as described above after observing each of the 1024 plaintext-ciphertext pairs we are given. At the end, some bytes of the key will have their probability concentrated on very few possible values. For example, we arrive at byte 2 having value `108` with a probability of `99.999%`. Some other bytes have their probability spread out a bit more, so we can choose a cutoff like `95%` and, for each byte, consider the most likely values whose probabilities sum up to at least `0.95`. If for each of the 6 bytes we have a probability of `95%` of having the correct value among those we consider, then the probability of finding the right combination among them is `0.95**6 ≈ 73.5%`. In actuality the probability is higher in our case, because e.g. for byte 2 a single value already has a probability much higher than `95%`. With the cutoff at `95%` we get a total of 172368 combinations that we have to search through to find the correct last round key. This keyspace is smaller by a factor of over `10**9` compared to bruteforcing the 6 bytes directly: a 48-bit bruteforce became roughly an 18-bit bruteforce (`log2(172368) ≈ 17.4`), with the plaintext-ciphertext pairs leaking roughly 30 bits. The bruteforce over the remaining keyspace can easily be done by deducing, for each candidate last round key, the corresponding encryption key and testing whether it encrypts the given plaintexts to the given ciphertexts.
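The cutoff step can be sketched as follows. The function names `candidate_set` and `candidate_keys` are hypothetical, and the toy distributions in the example are made up and much smaller than the real 256-value ones.

```python
from itertools import product
from math import log2, prod

def candidate_set(probs, cutoff=0.95):
    """Most likely values first, until their probabilities sum to at least cutoff."""
    order = sorted(range(len(probs)), key=lambda v: probs[v], reverse=True)
    chosen, total = [], 0.0
    for v in order:
        chosen.append(v)
        total += probs[v]
        if total >= cutoff:
            break
    return chosen

def candidate_keys(posteriors, cutoff=0.95):
    """Return (number of combinations, iterator over candidate byte tuples)."""
    sets = [candidate_set(p, cutoff) for p in posteriors]
    return prod(len(s) for s in sets), product(*sets)

# Toy example with two "bytes" that can only take a handful of values.
count, keys = candidate_keys([[0.99, 0.01], [0.04, 0.6, 0.06, 0.3]])

# Size of the search space reported in the text, in bits.
remaining_bits = log2(172368)
```

Each tuple produced by `candidate_keys` would then be turned into a candidate encryption key and checked against one or two of the known plaintext-ciphertext pairs.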

## Attachments

Solve script and the original challenge as well as original output, but renamed to be importable in python.

^{1} We consider both to be elements of `(F_2)^48`. So the private key `p` and an output `s` of an S-box are given by `p = (p_1, ..., p_48)` and `s = (s_1, ..., s_48)`, with the components being elements of `F_2`, so either `0` or `1`. Their scalar product is `p ⋅ s = p_1 * s_1 + p_2 * s_2 + ... + p_48 * s_48`, with the sum of products being taken in `F_2`, i.e. modulo `2`.

^{2} After taking the scalar product, we have values in `F_2`, where addition, being addition mod `2`, is the same as xor, so we can rewrite the xors as sums.