m0leCon Teaser 2023 - babyPQ

Overview

Summary

This is a writeup for the “babyPQ” challenge by mr96 at m0leCon Teaser 2023. We were the first team to solve this challenge, which had a total of 7 solves.

This challenge involved lots of probabilities. There was a Feistel cipher whose S-boxes have output bits that satisfy a linear relation with high (but not certain) probability. One first had to figure out what this linear relation is, and then use it to leak the last round key. Basic counting statistics were insufficient for that second step, so one had to actually calculate probability updates based on the evidence provided by plaintext-ciphertext pairs.

Solution by: shalaamum

Writeup by: shalaamum

This writeup was first published here.

Key generation

The code run by the challenge is shown in the snippet below:

public_key, private_key = gen_keypair(8, 48, 0.97)
flag = open("flag.txt", "rb").read()
enc = encrypt(flag, public_key)
print(f"{public_key = }")
print(f"{enc = }")

We are given the output of this code as a text file, which means we have a public key and the encrypted flag. The public key is generated together with a private key, so a good first step might be to recover the latter. So let us look at gen_keypair.

def gen_keypair(m, n, p):
    private_key = [randint(0,1) for _ in range(n-1)]
    private_key.append(1)
    SBoxes = []

    for _ in range(6):
        SBoxes.append(create_sbox(m, n, p, private_key))
    
    return SBoxes, private_key

From this we see that the private key consists of 48 bits, of which the first 47 are random and the last one is 1. The public key that we obtain consists of 6 S-boxes, which are somehow constructed using this private key.

def create_sbox(m, n, p, priv):
    Bs = []

    for _ in range(n-1):
        B = random_boolean_function(m)
        Bs.append(B)
    
    last_bf = []

    for i in range(2^m):
        val = sum([priv[j] * int(Bs[j](i)) for j in range(n-1)]) % 2
        if random() < p:
            last_bf.append(val)
        else:
            last_bf.append(1-val)

    Bs.append(BooleanFunction(last_bf))

    Bs = matrix([[int(x) for x in B.truth_table()] for B in Bs]).transpose()
    S = []

    for B in Bs:
        S.append(sum([B[i]*2^i for i in range(n)]))
    
    return S

To get our bearings, let us first note that each S-box here takes an 8-bit input but produces a 48-bit output, which is slightly unusual. In this construction the first 47 bits of each output are random, and the last bit (which is the most significant one) is chosen such that the scalar product of the output bit vector with the private key (see footnote 1) is 0 with a probability of 97%.
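The construction can be mirrored in plain Python to see the bias directly. This is a minimal sketch, not the challenge code: `dot` and `make_biased_sbox` are hypothetical names, bit vectors are stored as Python ints (component i is bit i), and the 47 random Boolean functions are modelled as independent, uniformly random bits, which is an assumption about Sage's `random_boolean_function`.

```python
import random

N = 48  # output width of each S-box, as in gen_keypair(8, 48, 0.97)


def dot(a: int, b: int) -> int:
    """Scalar product over F_2 of two bit vectors stored as ints."""
    return bin(a & b).count("1") & 1


def make_biased_sbox(priv: int, p: float, rng: random.Random, m: int = 8) -> list[int]:
    """Hypothetical plain-Python analogue of create_sbox.

    Bits 0..N-2 of every entry are uniformly random (standing in for the
    random Boolean functions); bit N-1 is chosen so that dot(entry, priv)
    is 0 with probability p. priv must have bit N-1 set, as in gen_keypair.
    """
    sbox = []
    for _ in range(2 ** m):
        low = rng.getrandbits(N - 1)
        val = dot(low, priv)  # sum_j priv[j] * B_j(x) mod 2
        if rng.random() >= p:  # with probability 1 - p, flip the last bit
            val ^= 1
        sbox.append(low | (val << (N - 1)))
    return sbox


if __name__ == "__main__":
    rng = random.Random(1)
    priv = rng.getrandbits(N - 1) | (1 << (N - 1))
    sboxes = [make_biased_sbox(priv, 0.97, rng) for _ in range(6)]
    zeros = sum(dot(s, priv) == 0 for S in sboxes for s in S)
    print(zeros / 1536)  # roughly 0.97
```

Since bit 47 of the private key is 1, the last bit simply cancels the parity of the other 47 bits, except in the roughly 3% of entries where it was flipped.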

Recovering the private key

To recover the 47 unknown bits of the private key, a system of 47 linearly independent linear equations suffices. We have 6 S-boxes with 256 values each, which gives us 1536 linear equations, though only about 97% of them are correct. However, if we just take 47 of them at random, then with probability 0.97^47 ≈ 24% all of them will be correct. We also have an easy way of checking whether a guess for the private key is correct: for a wrong guess, the scalar product with the outputs of the S-boxes will be 1 and 0 with a probability of roughly 50% each, so the correct private key is a very noticeable outlier. Thus we have an efficient and very simple way to recover the private key:

  1. Take 47 random outputs of the S-boxes. If the corresponding equations are not linearly independent, throw them away and start over.
  2. Solve the system of linear equations to obtain a guess for the private key.
  3. Check whether the guess is correct by looking at the distribution of its scalar products with the outputs of the S-boxes. If the distribution is more like 50-50 than 97-3, reject this guess and start over.

After a few attempts we will have obtained the correct private key.
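The procedure could be sketched in plain Python as follows, assuming the S-box outputs have been converted to Python ints; `solve_gf2` and `recover_private_key` are hypothetical helpers. A convenient observation: since the last bit of the private key is 1, the condition S(x) ⋅ p = 0 reads sum_{j<47} s_j p_j = s_47, so a 48-bit output s can be used directly as an augmented row (47 coefficient bits, right-hand side in bit 47). The acceptance threshold of 90% zeros is an arbitrary choice between the expected 97% and 50%.

```python
import random
from typing import Optional

N = 48


def dot(a: int, b: int) -> int:
    """Scalar product over F_2 of two bit vectors stored as ints."""
    return bin(a & b).count("1") & 1


def solve_gf2(rows: list[int], n: int) -> Optional[int]:
    """Solve n equations in n unknowns over F_2 (Gauss-Jordan elimination).

    Bits 0..n-1 of each row are the coefficients, bit n is the right-hand
    side. Returns the solution as an int, or None if the system is singular.
    """
    rows = list(rows)
    for col in range(n):
        pivot = next((i for i in range(col, n) if rows[i] >> col & 1), None)
        if pivot is None:
            return None
        rows[col], rows[pivot] = rows[pivot], rows[col]
        for i in range(n):
            if i != col and rows[i] >> col & 1:
                rows[i] ^= rows[col]
    # Now row i reads "x_i = rhs_i".
    return sum(((rows[i] >> n) & 1) << i for i in range(n))


def recover_private_key(sboxes: list[list[int]], rng: random.Random,
                        max_tries: int = 10_000) -> Optional[int]:
    outputs = [int(s) for S in sboxes for s in S]
    n = N - 1
    for _ in range(max_tries):
        # With priv_47 = 1, dot(s, priv) == 0 reads sum_{j<47} s_j x_j = s_47,
        # so each output s already is an augmented row.
        x = solve_gf2(rng.sample(outputs, n), n)
        if x is None:
            continue  # equations not linearly independent
        priv = x | (1 << n)
        zeros = sum(dot(s, priv) == 0 for s in outputs)
        if zeros > 0.9 * len(outputs):  # ~97% for the right key, ~50% otherwise
            return priv
    return None
```

Both failure modes (singular system, wrong equation among the 47) are handled the same way, by simply drawing a new sample.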

Encryption

Let us now have a look at encryption. We are provided with the return value of encrypt(flag, public_key), and this function is defined as follows:

def encrypt(pt, public_key):
    tmp_key = os.urandom(8)
    ad = []
    
    for _ in range(1024):
        tmp_pt = os.urandom(12)
        tmp_ct = feistel_encrypt(tmp_pt, tmp_key, public_key)
        ad.append((tmp_pt.hex(), tmp_ct.hex()))
    
    pt = pad(pt, 12)
    pt = [pt[i:i+12] for i in range(0, len(pt), 12)]

    c = b''.join([feistel_encrypt(p, tmp_key, public_key) for p in pt]).hex()

    return (c, ad)

So the actual encryption is done by feistel_encrypt, which operates on 12-byte blocks and depends on an 8-byte encryption key that we do not know (yet). We obtain both the encryption of the (padded) flag and 1024 random plaintext-ciphertext pairs, so we will presumably need to mount a known-plaintext attack on the cipher to recover the encryption key before we can decrypt the flag. Here public_key consists of the S-boxes, whose output has a statistical anomaly that we already know about, so we will need to use this anomaly somehow to carry out our attack.

Feistel cipher

As the name of the function suggests, this is a Feistel cipher.

def feistel_encrypt(pt, key, sboxes):
    ks = expand_key(key, 8)
    l, r = bytes_to_long(pt[:6]), bytes_to_long(pt[6:])
    
    for i in range(8):
        l, r = r, l ^^ f(r ^^ ks[i], sboxes)
    
    return long_to_bytes(l) + long_to_bytes(r)

There are 8 rounds. In each round, the right half-block gets xored with a round key, then the function f is applied, and finally the left half-block is xored into the result; the two halves are then swapped.

We can already make a couple of observations here. To start, let us look at the first two rounds.

l_1 = r_0
r_1 = l_0 ^ f(r_0 ^ ks_0, sboxes)
l_2 = r_1 = l_0 ^ f(r_0 ^ ks_0, sboxes)
r_2 = l_1 ^ f(r_1 ^ ks_1, sboxes) = r_0 ^ f(r_1 ^ ks_1, sboxes)

We can see that l_2 and r_2 are given by xoring l_0 and r_0, respectively, with the output of one evaluation of the function f. Repeating this for 8 rounds, we will then get that, in particular, the following holds.

l_8 = l_0 ^ f(r_0 ^ ks_0, sboxes) ^ f(r_2 ^ ks_2, sboxes) ^ f(r_4 ^ ks_4, sboxes) ^ f(r_6 ^ ks_6, sboxes)
r_8 = r_0 ^ f(r_1 ^ ks_1, sboxes) ^ f(r_3 ^ ks_3, sboxes) ^ f(r_5 ^ ks_5, sboxes) ^ f(r_7 ^ ks_7, sboxes)

So each of the two half-blocks of the ciphertext is obtained by xoring the corresponding half-block of the plaintext with four outputs of the function f. The arguments to these evaluations of f are other half-blocks xored with a round key. In each of the two expressions there is one evaluation of f whose input half-block we know if we have a plaintext-ciphertext pair: for l_8 it is the first one, as we know r_0, and for r_8 it is the last one, as we know r_7 = l_8. In the following sections we describe how to use the second equality above to recover the last round key; the first round key can be recovered in a completely analogous way.
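The unrolled expressions for l_8 and r_8 can be checked numerically. The sketch below uses hypothetical names and an arbitrary toy round function in place of f; it runs the same round structure as feistel_encrypt (with Python's ^ in place of Sage's ^^) and compares the final state with the unrolled formulas. The identities hold for any round function, so the choice of toy function is irrelevant.

```python
import random

MASK48 = (1 << 48) - 1


def feistel_states(l0: int, r0: int, ks: list[int], f) -> list[tuple[int, int]]:
    """All intermediate (l_i, r_i), i = 0..len(ks), of the round structure
    used by feistel_encrypt, for an arbitrary round function f."""
    states = [(l0, r0)]
    l, r = l0, r0
    for k in ks:
        l, r = r, l ^ f(r ^ k)
        states.append((l, r))
    return states


def unrolled(states: list[tuple[int, int]], ks: list[int], f) -> tuple[int, int]:
    """Right-hand sides of the unrolled formulas for l_8 and r_8."""
    l0, r0 = states[0]
    r = [s[1] for s in states]  # r_0, ..., r_8
    l8 = l0 ^ f(r[0] ^ ks[0]) ^ f(r[2] ^ ks[2]) ^ f(r[4] ^ ks[4]) ^ f(r[6] ^ ks[6])
    r8 = r0 ^ f(r[1] ^ ks[1]) ^ f(r[3] ^ ks[3]) ^ f(r[5] ^ ks[5]) ^ f(r[7] ^ ks[7])
    return l8, r8


def toy_f(x: int) -> int:
    """Arbitrary stand-in for f; any deterministic function works."""
    return (x * 0x9E3779B97F4A7C15 >> 13) & MASK48


if __name__ == "__main__":
    rng = random.Random(0)
    ks = [rng.getrandbits(64) for _ in range(8)]
    states = feistel_states(rng.getrandbits(48), rng.getrandbits(48), ks, toy_f)
    assert unrolled(states, ks, toy_f) == states[8]
    assert states[7][1] == states[8][0]  # r_7 = l_8
```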

Round key expansion

The round keys are derived from the key as follows:

def expand_key(key, rounds):
    ks = [key]
    
    for _ in range(rounds-1):
        cipher = ChaCha20.new(key = bytes([_])*32, nonce = b"\x00"*8)
        ks.append(cipher.encrypt(ks[-1]))
    
    return [bytes_to_long(k) for k in ks]

So the first round key is the original key, and each further round key is obtained by encrypting the previous one with ChaCha20, using a key and nonce that we know. In particular, knowledge of one of the round keys allows us to obtain the other round keys and the original key. Furthermore, as ChaCha20 encryption just xors the plaintext with a keystream, to know a particular byte of the original key we only need to know that same byte in one of the round keys. This will be very helpful: while the key is 8 bytes long, f only looks at the 48 least significant bits of its input, so only the 6 least significant bytes of each round key are used. Because of this bytewise behavior of ChaCha20 encryption, the remaining 2 bytes do not matter, and we do not have to bruteforce over them.
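Inverting the key schedule therefore takes only a few xors. A sketch under these assumptions: round keys are handled as ints (big-endian, as bytes_to_long does), the keystream of each step is obtained by encrypting 8 zero bytes with the same ChaCha20 parameters as expand_key (via pycryptodome, which the challenge imports), and the keystream source is a parameter so that the arithmetic can also be exercised without pycryptodome. All function names are hypothetical.

```python
MASK48 = (1 << 48) - 1  # f only ever sees the 6 least significant bytes


def chacha_keystream(i: int) -> int:
    """First 8 keystream bytes of the ChaCha20 instance used in step i of
    expand_key (key bytes([i]) * 32, nonce of 8 zero bytes), as a
    big-endian int. Requires pycryptodome, which the challenge also uses."""
    from Crypto.Cipher import ChaCha20
    cipher = ChaCha20.new(key=bytes([i]) * 32, nonce=b"\x00" * 8)
    return int.from_bytes(cipher.encrypt(b"\x00" * 8), "big")


def expand_key_int(key: int, rounds: int, keystream=chacha_keystream) -> list[int]:
    """expand_key on ints: round key i+1 is round key i xored with keystream(i)."""
    ks = [key]
    for i in range(rounds - 1):
        ks.append(ks[-1] ^ keystream(i))
    return ks


def key_from_round_key(rk: int, index: int, keystream=chacha_keystream) -> int:
    """Low 48 bits of the key from (the low 48 bits of) round key `index`."""
    key = rk
    for i in range(index):
        key ^= keystream(i)
    return key & MASK48
```

For the last round key, key_from_round_key(rk, 7) gives the 6 relevant key bytes; the 2 most significant bytes can be set to anything (for example zero), since they never reach f.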

The function f

This function is defined as follows:

def f(p, sboxes):
    p = [(p>>(8*i)) & 0xFF for i in range(6)]
    res = 0
    for i in range(6):
        res ^^= sboxes[i][p[i]]
    return res

So we split the input half-block into 6 individual bytes, apply the i-th S-box to the i-th byte (each yielding a number 6 bytes wide), and xor all the results together. We can write this as follows:

f(x_0, ..., x_5) = S_0(x_0) ^ ... ^ S_5(x_5)

where S_i is the i-th S-box and x_i is the i-th byte of the input, counting from the least significant byte.
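For reference, f translates to plain Python as below (Sage's ^^ becomes ^). The docstring records the byte order, which matters when matching bytes of the input to bytes of the round key: x_0 is the least significant byte, and anything above bit 47, such as the top two bytes of a 64-bit round key, is ignored.

```python
def f(p: int, sboxes: list[list[int]]) -> int:
    """Round function of the challenge in plain Python.

    Byte x_i is taken from bits 8*i..8*i+7 of p, so x_0 is the least
    significant byte and bits above 47 never influence the result.
    """
    res = 0
    for i in range(6):
        x_i = (p >> (8 * i)) & 0xFF
        res ^= sboxes[i][x_i]
    return res
```

With toy S-boxes that just move byte x to position i (`[[x << (8 * i) for x in range(256)] for i in range(6)]`), f reduces to taking the low 48 bits of its input, which makes the byte order easy to see.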

Relation to the private key

Putting the results of the previous subsections together, we can write:

r_8
= r_0 ^ f(r_1 ^ ks_1, sboxes) ^ f(r_3 ^ ks_3, sboxes) ^ f(r_5 ^ ks_5, sboxes) ^ f(r_7 ^ ks_7, sboxes)
= r_0 ^ (XOR_{0 <= i < 6, j in {1, 3, 5}} S_i((r_j)_i ^ (ks_j)_i) )
      ^ (XOR_{0 <= i < 6} S_i((l_8)_i ^ (ks_7)_i) )

We can rewrite this as:

r_8 ^ r_0 = (XOR_{0 <= i < 6, j in {1, 3, 5}} S_i((r_j)_i ^ (ks_j)_i) )
            ^ (XOR_{0 <= i < 6} S_i((l_8)_i ^ (ks_7)_i) )

As mentioned before, if we have a plaintext-ciphertext pair, then we know r_8, r_0, and l_8. What we would like is to obtain information about the last round key ks_7. For this we have to use the statistical anomaly of the S-boxes: the scalar product of their output with the private key (not to be confused with the encryption key or the round keys derived from it) is 0 with high probability, rather than with the probability of 50% we would expect if they were random. As the scalar product is bilinear, taking the scalar product of the previous equation with the private key p yields the following equation (see footnote 2).

(r_8 ^ r_0) ⋅ p = (Sum_{0 <= i < 6, j in {1, 3, 5}} S_i((r_j)_i ^ (ks_j)_i) ⋅ p)
                + (Sum_{0 <= i < 6} S_i((l_8)_i ^ (ks_7)_i) ⋅ p)

To shorten notation, let us from now on denote the expression on the right by R = R_p + R_l (with p standing for “previous rounds” and l standing for “last round”). We know the left-hand side of the above equation. Let us assume for the moment that it is 0. Then R must be 0 as well, which implies that an even number of the scalar products of S-box outputs with p involved in the sum are 1. This is what we will use to learn about ks_7, by doing a Bayesian update of our probabilities for each observed plaintext-ciphertext pair.
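Since the attack only ever uses the scalar products S_i(x) ⋅ p, it is natural to precompute them once per S-box. The sketch below (hypothetical helper names, S-box outputs as Python ints) builds these parity tables together with the number of inputs of each S-box that have parity 1, and computes f(x) ⋅ p from the tables alone; that this agrees with dot(f(x), p) is exactly the linearity used above. The demo uses unbiased random S-boxes, because linearity does not depend on the bias.

```python
import random


def dot(a: int, b: int) -> int:
    """Scalar product over F_2 of two bit vectors stored as ints."""
    return bin(a & b).count("1") & 1


def parity_tables(sboxes: list[list[int]], priv: int) -> tuple[list[list[int]], list[int]]:
    """T[i][x] = S_i(x) . p for every S-box i and input byte x, and
    W[i] = number of inputs x with T[i][x] == 1."""
    T = [[dot(s, priv) for s in S] for S in sboxes]
    W = [sum(row) for row in T]
    return T, W


def f_parity(x: int, T: list[list[int]]) -> int:
    """f(x) . p from the tables alone: by linearity it is the XOR of
    S_i(x_i) . p over the 6 input bytes x_i (x_0 least significant)."""
    out = 0
    for i in range(6):
        out ^= T[i][(x >> (8 * i)) & 0xFF]
    return out


if __name__ == "__main__":
    rng = random.Random(0)
    priv = rng.getrandbits(47) | (1 << 47)
    sboxes = [[rng.getrandbits(48) for _ in range(256)] for _ in range(6)]
    T, W = parity_tables(sboxes, priv)
    x = rng.getrandbits(48)
    f_x = 0
    for i in range(6):
        f_x ^= sboxes[i][(x >> (8 * i)) & 0xFF]
    assert f_parity(x, T) == dot(f_x, priv)
```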

Calculating the probabilities

To shorten notation, let us from now on write k instead of ks_7 for the last round key, with k_j being the j-th byte of k. Initially, before observing anything, each value of each byte of the last round key is equally likely. Thus we initialize our probabilities with P(k_j = v) = 1/256 for v in {0,...,255}.

Now we are observing a plaintext-ciphertext pair with (r_8 ^ r_0) ⋅ p = a for some value a in F_2 and with l_8 = l for some l. Then we replace the previous value of P(k_j = v) with the probability conditional on the observation we made:

P(k_j = v | l_8=l and R=a)

To calculate this we can use Bayes’ theorem:

P(k_j=v | l_8=l and R=a) = P(l_8=l and R=a | k_j=v) * P(k_j=v) / P(l_8=l and R=a)

Rewriting this with the definition of conditional probability, we obtain:

= ( P(R=a | l_8=l and k_j=v) * P(l_8=l | k_j=v) * P(k_j=v) ) / ( P(R=a | l_8=l) * P(l_8=l) )

As the plaintexts are random, the distribution of l_8 does not depend on the key, so P(l_8=l | k_j=v) = P(l_8=l), and we get:

= P(R=a | l_8=l and k_j=v) * P(k_j=v) / P(R=a | l_8=l)

There are three probabilities here that we need. P(k_j=v) is just our prior. The other two we need to calculate.

Let us consider P(R=a | l_8=l) first. The event R=a means that, among the 24 involved S-box evaluations, the number of those whose output has scalar product 1 with p is congruent to a modulo 2. So we can divide up the probability P(R=a | l_8=l) into subcases, indexed by tuples (a_{1,0}, ..., a_{1,5}, a_{3,0}, ..., a_{3,5}, a_{5,0}, ..., a_{5,5}, a_{7,0}, ..., a_{7,5}) in F_2^24 whose components sum (modulo 2) to a. Let us call the set of these tuples A. Treating the 24 S-box evaluations as independent, the following holds.

P(R=a | l_8=l)
= Sum_{(a_{1,0}, ..., a_{7,5}) in A}
        (Prod_{0<=j<6, i in {1,3,5}}  P(S_j((r_i)_j ^ (ks_i)_j) ⋅ p = a_{i,j} | l_8=l))
        * (Prod_{0<=j<6} P(S_j((l_8)_j ^ (ks_7)_j) ⋅ p = a_{7,j} | l_8=l) )

Let W_j be the number of inputs x for which S_j(x) ⋅ p = 1. Treating the inputs to the S-boxes of rounds 1, 3 and 5 as uniformly random and independent of both l_8 and the last round key, we obtain:

= Sum_{(a_{1,0}, ..., a_{7,5}) in A}
        (Prod_{0<=j<6, i in {1,3,5}} ((W_j/256)**a_{i,j}) * ((1 - W_j/256)**(1 - a_{i,j})) )
        * (Prod_{0<=j<6} P(S_j((l_8)_j ^ (ks_7)_j) ⋅ p = a_{7,j} | l_8=l) )
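Summing over all 2^23 tuples in A is not actually necessary. The summand is a product of independent per-S-box probabilities, so the whole sum is the probability that the XOR of 24 independent bits equals a. For bits with P(b_k = 1) = q_k this is (1 + s * Prod_k (1 - 2 q_k)) / 2, with s = +1 for a = 0 and s = -1 for a = 1 (the piling-up lemma). The sketch below compares this closed form with the literal sum on a small example; it is a simplification introduced here, not necessarily what the original solve script does, and the function names are hypothetical.

```python
from itertools import product
from math import prod


def parity_prob(qs: list[float], a: int) -> float:
    """P(b_1 ^ ... ^ b_n == a) for independent bits with P(b_k = 1) = qs[k]
    (piling-up lemma): the closed form of the sum over the tuple set A."""
    bias = prod(1 - 2 * q for q in qs)
    return (1 + bias) / 2 if a == 0 else (1 - bias) / 2


def parity_prob_bruteforce(qs: list[float], a: int) -> float:
    """The literal sum over all tuples whose components sum to a (mod 2)."""
    total = 0.0
    for bits in product((0, 1), repeat=len(qs)):
        if sum(bits) % 2 == a:
            total += prod(q if b else 1 - q for q, b in zip(qs, bits))
    return total


if __name__ == "__main__":
    # Illustrative values only: 18 earlier-round calls with a hypothetical
    # W_j = 8, plus 6 last-round probabilities.
    qs = [8 / 256] * 18 + [0.5, 0.1, 0.02, 0.0, 1.0, 0.3]
    print(parity_prob(qs, 0))  # no enumeration of 2^23 tuples needed
```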

For the S-boxes of the last round, we use our prior for the values of the last round key. Writing l_j for the j-th byte of l, we have:

P(S_j((l_8)_j ^ (ks_7)_j) ⋅ p = a_{7,j} | l_8=l)
= P(S_j(l_j ^ k_j) ⋅ p = a_{7,j})
= Sum_{0 <= v < 256} P(k_j=v) * P(S_j(l_j ^ v) ⋅ p = a_{7,j})

Here S_j(l_j ^ v) can be calculated directly, because l_j and v are concrete values, so the second factor is either 1 or 0. Calculating all these values and plugging them into the previous formula, we thus have a recipe for calculating P(R=a | l_8=l).

The calculation of P(R=a | l_8=l and k_j=v) is very similar. The difference is that in the factor for byte j of the last round we do not sum over all 256 possible values v weighted by their prior probabilities, but take only the single summand for the given v.
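Putting these pieces together, one observation could be processed as in the following sketch. Assumptions beyond those already made in the text: the parity tables T[j][x] = S_j(x) ⋅ p and the counts W_j are precomputed, each key byte is updated separately while the other bytes enter through their current priors (as in the formulas above), and sums over A are evaluated with the closed form (1 + s * Prod(1 - 2q)) / 2 for the XOR of independent bits. The division by P(R=a | l_8=l) is done by normalizing, which gives the same result because that probability is the prior-weighted sum of the individual likelihoods. The name bayes_update is hypothetical.

```python
from math import prod


def bayes_update(P: list[list[float]], T: list[list[int]], W: list[int],
                 l8: int, a: int) -> list[list[float]]:
    """One update of the per-byte distributions P[j][v] = P(k_j = v) of the
    last round key after observing l_8 = l8 and (r_8 ^ r_0) . p = a.

    T[j][x] = S_j(x) . p (0 or 1) and W[j] = sum(T[j]).
    """
    l_bytes = [(l8 >> (8 * j)) & 0xFF for j in range(6)]
    # Rounds 1, 3, 5: three S-box calls per byte position, each with
    # P(parity 1) = W[j] / 256; each contributes a factor (1 - 2q).
    bias_prev = prod((1 - 2 * W[j] / 256) ** 3 for j in range(6))
    # Last round: P(S_j(l_j ^ k_j) . p = 1) under the current prior.
    q7 = [sum(P[j][v] * T[j][l_bytes[j] ^ v] for v in range(256)) for j in range(6)]
    sign = 1 if a == 0 else -1
    new_P = []
    for j in range(6):
        # Everything except the last-round S-box of byte j.
        bias_rest = bias_prev * prod(1 - 2 * q7[k] for k in range(6) if k != j)
        # P(R = a | l_8 = l, k_j = v) * P(k_j = v): byte j's bit is known exactly.
        post = [P[j][v] * (1 + sign * bias_rest * (1 - 2 * T[j][l_bytes[j] ^ v])) / 2
                for v in range(256)]
        z = sum(post)  # equals P(R = a | l_8 = l)
        new_P.append([x / z for x in post])
    return new_P
```

Starting from P = [[1 / 256] * 256 for _ in range(6)], one would call bayes_update once per plaintext-ciphertext pair, with l8 the left half of the ciphertext and a the parity of (r_8 ^ r_0) ⋅ p.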

Putting everything together

So we begin with P(k_j = v) = 1/256 for v in {0,...,255}, and then update these probabilities as described above after observing each of the 1024 plaintext-ciphertext pairs we are given. At the end, some bytes of the key have their probability concentrated on very few possible values; for example, we arrive at byte 2 having value 108 with a probability of 99.999%. For other bytes the probability is spread out a bit more, so we can choose a cutoff like 95% and, for each byte, consider the most likely values whose probabilities sum up to at least 0.95. If for each of the 6 bytes the correct value is among those considered with a probability of 95%, then the probability of finding the right combination among them is 0.95**6 ≈ 73.5%. In our case the probability is actually higher because, e.g., for byte 2 a single value already carries far more than 95% of the probability. With the cutoff at 95% we get a total of 172368 combinations that we have to search through to find the correct last round key. This keyspace is smaller by a factor of over 10**9 than that of bruteforcing the 6 bytes directly: a 48-bit bruteforce became roughly an 18-bit bruteforce, with the plaintext-ciphertext pairs leaking roughly 30 bits. The bruteforce over the remaining keyspace for the last round key can easily be done by deducing the corresponding encryption key and testing whether, with that key, the plaintexts encrypt to the ciphertexts we are given.
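The cutoff-based enumeration might look as follows. This is a sketch with hypothetical names, where P[j][v] holds the final probability of byte j of the last round key having value v, and byte j sits at bits 8j to 8j+7 (the byte order used by f). Candidates are generated starting with the combination of individually most likely values; each one would then be turned into an encryption key and checked against the given plaintext-ciphertext pairs.

```python
from itertools import product


def candidate_sets(P: list[list[float]], cutoff: float = 0.95) -> list[list[int]]:
    """For each key byte, the most likely values whose probabilities
    sum to at least `cutoff` (most likely first)."""
    sets = []
    for probs in P:
        order = sorted(range(len(probs)), key=lambda v: probs[v], reverse=True)
        chosen, mass = [], 0.0
        for v in order:
            chosen.append(v)
            mass += probs[v]
            if mass >= cutoff:
                break
        sets.append(chosen)
    return sets


def candidate_round_keys(P: list[list[float]], cutoff: float = 0.95):
    """Yield candidate values for the low 48 bits of the last round key,
    with byte j at bits 8j..8j+7."""
    sets = candidate_sets(P, cutoff)
    for combo in product(*sets):
        yield sum(v << (8 * j) for j, v in enumerate(combo))
```

The number of candidates is the product of the set sizes, so a byte whose probability is concentrated on one value (like byte 2 above) costs nothing in the search.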

Attachments

The solve script, the original challenge, and the original output, renamed so that they can be imported in Python.


  1. We consider both to be elements of (F_2)^48. So the private key p and an output s of an S-box are given by p = (p_1, ..., p_48) and s = (s_1, ..., s_48), with the components elements of F_2, so either 0 or 1. Their scalar product is p ⋅ s = p_1 * s_1 + p_2 * s_2 + ... + p_48 * s_48, with the sum of products being taken in F_2, i.e. modulo 2.

  2. After taking the scalar product, we have values in F_2, where addition, being addition mod 2, is the same as xor, so we can rewrite the xors as sums.