m0leCon Teaser 2023 - Collisions


Overview

Summary

This is a writeup for the “Collisions” challenge by mr96 at m0leCon Teaser 2023. We were the first team to solve this challenge, which had a total of 4 solves.

In this challenge we had to find a second preimage for a custom hash function.

Solution by: shalaamum

Writeup by: shalaamum

Writeup first published here.

Challenge setup

def chall():
    for _ in range(10):
        mymsg = os.urandom(64)
        myhash = hash(mymsg).hex()
        print((mymsg.hex(), myhash))
        yourmsg = bytes.fromhex(input())
        yourhash = hash(yourmsg).hex()
        assert myhash == yourhash and mymsg != yourmsg
    
    print(flag)

chall()

The remote generates 64 random bytes and hashes them with a custom hash function. We are given both the input and the hash, and are challenged to provide a different input that has the same hash. If we manage to do this 10 times in a row, we receive the flag.
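Talking to the remote only requires parsing one line per round. The following is a minimal sketch, not the author's solve script. It assumes each challenge arrives as a single line in exactly the format produced by print((mymsg.hex(), myhash)), which is a Python tuple literal, so ast.literal_eval can parse it safely. The names parse_challenge and answer are hypothetical helpers.

```python
import ast


def parse_challenge(line: str) -> tuple[bytes, str]:
    """Parse one line printed by chall(), e.g. "('00ff...', 'ab12...')".

    The server prints a tuple of two hex strings, so ast.literal_eval turns
    it back into a tuple without evaluating arbitrary code.
    """
    msg_hex, hash_hex = ast.literal_eval(line.strip())
    return bytes.fromhex(msg_hex), hash_hex


def answer(forged: bytes) -> str:
    """Format a forged message the way chall() reads it: a bare hex string."""
    return forged.hex()


# Example with a made-up line in the same format as the server output.
msg, h = parse_challenge("('" + "00" * 64 + "', '" + "ab" * 16 + "')")
print(len(msg), h[:4])  # 64 abab
```

In a real solver, the line would come from the remote connection, and the answer would be sent back followed by a newline, ten times in a row.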

The hash function

blen = 16

def hash(msg):
    assert len(msg) < 256*blen
    
    m = pad(msg)
    
    assert len(m) % blen == 0
    
    s = len(m)//blen
    t = 2**ceil(log(s, 2))
    m = [m[blen*i:blen*(i+1)] for i in range(s)]

    for i in range(len(m)):
        m[i] = preprocess(m[i], 8*blen)
    
    while len(m) < 2*t-1:
        m.append(b"")
    
    for i in range(t, 2*t-1):
        l = 2*(i-t)
        r = l+1
        if m[l] == b"" and m[r] == b"":
            m[i] = b""
        elif m[r] == b"":
            m[i] = m[l]
        else:
            m[i] = xor(block(m[l], m[r]), m[r])
    
    return xor(block(m[-1], pad(bytes([s]))), pad(bytes([s])))

The message is first padded: msg is replaced with msg + bytes([n])*n, where n = blen-(len(msg)%blen). The padded message is then split into s blocks of 16 bytes each, and these blocks are preprocessed. The list is then extended by empty blocks until its length is a power of 2, denoted by t. Additionally, another t-1 empty blocks are appended. We should think of t-1 as (t/2) + (t/4) + ... + 1, and of the list m as containing the nodes of a binary tree, like in the following picture for t=4:

      6
    /   \
   /     \
  4       5
 / \     / \
0   1   2   3

The padded and preprocessed message fills the leftmost part of the lowest row. The hash function then iteratively calculates the other rows from the row directly below:

- If both child nodes are empty, so is the node.
- If only the right child node is empty, the node's value is just the left child's value.
- If both child nodes are non-empty, the node's value is block(lv, rv) ^ rv, where lv and rv are the values of the left and right child node, respectively.

At the very end of this process, the hash is obtained by carrying out this operation one last time. The root node of the tree serves as the "left child node", and pad(bytes([s])) serves as the "right child node". This is the number s of blocks of the padded message, itself padded to a full block.
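Replaying the tree construction symbolically shows which blocks interact. This sketch makes three assumptions:

- pad behaves as described above; its code is not shown in the writeup.
- t is computed with integer arithmetic instead of the float expression 2**ceil(log(s, 2)).
- block is replaced by a placeholder label C(a,b) standing for xor(block(a, b), b).

tree_layout is a hypothetical helper.

```python
BLEN = 16


def pad(msg: bytes, blen: int = BLEN) -> bytes:
    # Padding as described above: always adds 1..blen bytes,
    # each equal to the number of bytes added.
    n = blen - (len(msg) % blen)
    return msg + bytes([n]) * n


def tree_layout(s: int) -> list[str]:
    """Replay the tree loop of hash() symbolically for s leaf blocks.

    Leaves are labelled m0..m{s-1}, empty nodes are "", and C(a,b)
    stands for xor(block(a, b), b).
    """
    # Smallest power of two >= s (integer version of 2**ceil(log(s, 2))).
    t = 1 << (s - 1).bit_length()
    m = [f"m{i}" for i in range(s)] + [""] * (2 * t - 1 - s)
    for i in range(t, 2 * t - 1):
        l, r = 2 * (i - t), 2 * (i - t) + 1
        if m[l] == "" and m[r] == "":
            m[i] = ""
        elif m[r] == "":
            m[i] = m[l]
        else:
            m[i] = f"C({m[l]},{m[r]})"
    return m


padded = pad(bytes(64))
s = len(padded) // BLEN
print(s, tree_layout(s)[-1])  # 5 C(C(C(m0,m1),C(m2,m3)),m4)
```

For a 64-byte challenge message, the padded input has s = 5 blocks and t = 8. The first internal node (index 8) is C(m0,m1), so the first two blocks are always combined directly with each other.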

At first I considered whether a collision could be obtained by exploiting how empty blocks are handled along the tree. However, there does not seem to be a way to inject empty blocks in the middle. Furthermore, the padding and the use of the number of input blocks at the very end rule out simple tricks. One such trick would be a single-block input carrying the value one wishes to reproduce at the root. So we will probably have to find two blocks that preprocess to the same block, or two pairs (lv, rv) for which block(lv, rv) ^ rv is the same.

The preprocessing

This is the preprocessing function.

def preprocess(b, cnt):
    state = [int(x) for x in bin(bytes_to_long(b))[2:].rjust(cnt, '0')]
    
    for i in range(cnt):
        feedback = state[0] ^ state[47] ^ (1 - (state[70] & state[85])) ^ state[91]
        state = state[1:] + [feedback]
    
    b = long_to_bytes(int(''.join(str(x) for x in state), 2)).rjust(cnt//8, b"\x00")
    return b

This function operates on bits: the first line converts the input into a list of cnt bits, and the line before the return converts back. In each of the cnt iterations, the first bit is removed from state and a new bit is appended at the end. This new bit is the previous first bit xored with a Boolean expression in other bits (previous positions 47, 70, 85 and 91). As those other bits are still present in the new state, the previous first bit can be recovered from the new state. Hence each step, and therefore preprocess as a whole, is a bijection.

Thus preprocess is not actually relevant for finding another preimage; collisions can’t happen at this stage, so we can just consider collisions post-preprocessing and then invert preprocess on our collision to get back the original input.
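The inverse can be written down explicitly. The sketch below re-implements preprocess with int.from_bytes/int.to_bytes in place of bytes_to_long/long_to_bytes. This is assumed equivalent here, since the output is always padded to cnt//8 bytes.

The inverse undoes the steps one at a time. After a step, the new bit at index j is the old bit at index j+1, so old positions 47, 70, 85 and 91 sit at new positions 46, 69, 84 and 90. The feedback bit is the last one. unpreprocess is our own name, not part of the challenge.

```python
def _to_bits(b: bytes, cnt: int) -> list[int]:
    # MSB-first bit list, matching bin(bytes_to_long(b))[2:].rjust(cnt, '0')
    return [int(x) for x in bin(int.from_bytes(b, "big"))[2:].rjust(cnt, "0")]


def _from_bits(bits: list[int]) -> bytes:
    # Inverse of _to_bits: big-endian, always len(bits) // 8 bytes long
    return int("".join(map(str, bits)), 2).to_bytes(len(bits) // 8, "big")


def preprocess(b: bytes, cnt: int = 128) -> bytes:
    # Same shift register as the challenge's preprocess.
    state = _to_bits(b, cnt)
    for _ in range(cnt):
        feedback = state[0] ^ state[47] ^ (1 - (state[70] & state[85])) ^ state[91]
        state = state[1:] + [feedback]
    return _from_bits(state)


def unpreprocess(b: bytes, cnt: int = 128) -> bytes:
    # Undo one step at a time: solve the feedback equation for the old first bit,
    # then shift the remaining bits back into place.
    state = _to_bits(b, cnt)
    for _ in range(cnt):
        old0 = state[-1] ^ state[46] ^ (1 - (state[69] & state[84])) ^ state[90]
        state = [old0] + state[:-1]
    return _from_bits(state)
```

Applying unpreprocess after preprocess (or the other way round) returns the original 16-byte block. This is what lets us work with preprocessed blocks and map them back to message bytes at the end.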

The block function

def block(m1, m2):
    keys = ks(m1)
    l, r = m2[:blen//2], m2[blen//2:]
    
    for i in range(nrounds):
        l, r = r, xor(l, f(r, keys[i], i))
    
    return l+r

The block function generates round keys out of the first argument, and then applies a Feistel cipher to the second argument using those round keys and a function f. Let us have a look at this function first.

The function f

The function f is implemented as follows.

def f(m, k, r):
    state = (bytes_to_long(m) ^ k).to_bytes(blen//2, byteorder = "big")
    state = [sbox[x] for x in state]
    state[2] ^= (0xf0 - (r%16)*0x10 + r*0x1)%256
    
    for i in range(8):
        state[i] ^= state[(i-1)%8]
    
    state2 = [(state[i] ^ 0xff) & state[(i+1)%8] for i in range(8)]
    
    for i in range(8):
        state[i] ^= state2[(i+1)%8]

    state[0] ^= rotr(state[0], 5) ^ rotr(state[0], 6)
    state[1] ^= rotr(state[1], 2) ^ rotr(state[1], 6)
    state[2] ^= rotr(state[2], 2) ^ rotr(state[2], 1)
    state[3] ^= rotr(state[3], 4) ^ rotr(state[3], 2)
    state[4] ^= rotr(state[4], 3) ^ rotr(state[4], 5)
    state[5] ^= rotr(state[5], 1) ^ rotr(state[5], 5)
    state[6] ^= rotr(state[6], 2) ^ rotr(state[6], 4)
    state[7] ^= rotr(state[7], 3) ^ rotr(state[7], 4)
    
    return bytes(state)

This looks rather complicated, but luckily it wasn't necessary to read all of it to solve the challenge. When called from block, the first argument of f is part of the current state of the Feistel cipher, and the second argument is a round key. Searching for occurrences of m and k in the function, we find that they are only used to define the initial state

    state = (bytes_to_long(m) ^ k).to_bytes(blen//2, byteorder = "big")

where they are xored together. The code afterwards involves an S-box and various rotations and xors, so it seems difficult to predict useful relations between the outputs of f for two related initial states. The exception is when the initial states are equal, in which case f of course returns the same value for the same round index r. So the solution of this challenge will likely involve a message for which the xor of the relevant part of the Feistel state with the round key is the same as for the original message.

The roundkeys

The roundkeys are derived as follows:

def ks(k):
    k1 = bytes_to_long(k) % 2**(4*blen)
    k2 = bytes_to_long(k) >> (4*blen)
    
    rk1 = [((k1 << (i%(4*blen)))%(2**(4*blen))) | k1 % (2**(i%(4*blen))) for i in range(nrounds//2)]
    rk2 = [((k2 << (i%(4*blen)))%(2**(4*blen))) | k2 % (2**(i%(4*blen))) for i in range(nrounds//2)]
    
    round_keys = sum([[rk1[i], rk2[i]] for i in range(nrounds//2)], [])
    round_keys = [round_keys[i] ^ rc[i] for i in range(len(round_keys))]
    
    return round_keys

Each round key is thus obtained by selecting bits of the key and then xoring with the corresponding component of rc, which is a list of known constants. The selection is not quite a permutation: for shift amount i, the low i bits of k1 (or k2) appear twice and the top i bits are dropped. Nevertheless, every round key bit is a copy of exactly one key bit.

This means that if we xor the input k with some d, each round key becomes the original round key xored with an appropriately bit-rearranged version of d. As the rearrangement differs between rounds, the simplest nontrivial case to consider is one where all bits of d are 1, i.e. where we flip all the bits of the original k. Then all bits of every round key are flipped as well.
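This property is easy to check numerically. The challenge's nrounds and rc are not shown in the writeup, so the sketch below uses hypothetical placeholders: NROUNDS = 16 and random 64-bit constants. It also reimplements ks with int.from_bytes. The claim should not depend on these choices, because rc is xored in after the bit selection and therefore cancels out.

```python
import random

BLEN = 16
HALF = 4 * BLEN                 # 64: size of k1, k2 and of each round key
MASK64 = (1 << HALF) - 1        # d restricted to 64 bits

# HYPOTHETICAL placeholders: the real nrounds and rc are not shown in the
# writeup. The property checked below does not depend on their values.
NROUNDS = 16
_rng = random.Random(0)
RC = [_rng.getrandbits(HALF) for _ in range(NROUNDS)]


def ks(k: bytes) -> list[int]:
    # Same expansion as the challenge's ks, with int.from_bytes for bytes_to_long.
    k1 = int.from_bytes(k, "big") % 2**HALF
    k2 = int.from_bytes(k, "big") >> HALF
    rk1 = [((k1 << (i % HALF)) % 2**HALF) | k1 % 2**(i % HALF) for i in range(NROUNDS // 2)]
    rk2 = [((k2 << (i % HALF)) % 2**HALF) | k2 % 2**(i % HALF) for i in range(NROUNDS // 2)]
    round_keys = sum([[rk1[i], rk2[i]] for i in range(NROUNDS // 2)], [])
    return [round_keys[i] ^ RC[i] for i in range(len(round_keys))]


def flip_all(b: bytes) -> bytes:
    # xor with d: flip every bit of the 16-byte key
    return bytes(x ^ 0xFF for x in b)


k = bytes(range(16))
# Flipping all key bits flips all bits of every round key.
assert ks(flip_all(k)) == [rk ^ MASK64 for rk in ks(k)]
```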

In the following, let us denote by d the 16-byte block with all bits set to 1, so that xoring a 16-byte value with it flips all bits. We will also use d when xoring with 8-byte half-blocks and 64-bit round keys. In that case the operation should be considered modulo 2**64, i.e. d then stands for the 8-byte all-ones value.

How to obtain another preimage

We already concluded that we will most likely construct another preimage as follows: given a pair of 16-byte blocks (lv, rv), we construct another pair (lv', rv') such that block(lv', rv') ^ rv' = block(lv, rv) ^ rv. The round keys used in block(lv', rv') are generated from lv', and the previous section suggests trying lv' = lv ^ d.

The value of rv is split into two half-blocks for the Feistel cipher keyed by lv. We use the following notation:

- l_0 and r_0 are the initial half-blocks of rv, and l_i and r_i are the states after later rounds.
- l'_i and r'_i are the analogous states for the Feistel cipher applied to rv' with the key derived from lv'.
- k_i and k'_i are the round keys.

In the first round, the first two arguments of f are r_0 and k_0, and r'_0 and k'_0, respectively. As we chose lv' = lv ^ d, the previous section gives k'_i = k_i ^ d. By the section on f, we want r_0 ^ k_0 = r'_0 ^ k'_0, so that both calls to f receive the same input and hence return the same value. This means we should choose r'_0 = r_0 ^ k_0 ^ k'_0 = r_0 ^ k_0 ^ k_0 ^ d = r_0 ^ d. With this, we get

l_1 = r_0
l'_1 = r'_0 = r_0 ^ d = l_1 ^ d
r_1 = l_0 ^ f(r_0, k_0, 0)
r'_1 = l'_0 ^ f(r'_0, k'_0, 0) = l'_0 ^ f(r_0 ^ d, k_0 ^ d, 0) = l'_0 ^ f(r_0, k_0, 0) = r_1 ^ (l'_0 ^ l_0)

If we repeat this kind of reasoning for the second round, again wishing to obtain the same output for the f occurring there, then we will see that we want l'_0 ^ l_0 = d. But now note that if we make that choice, i.e. choose r'_0 = r_0 ^ d and l'_0 = l_0 ^ d, then the above shows that the same kind of relations will hold in later rounds as well, i.e. it will always be r'_i = r_i ^ d and l'_i = l_i ^ d.
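For completeness, here is the induction step written out, with \oplus denoting the xor written ^ elsewhere. Assume l'_i = l_i ^ d, r'_i = r_i ^ d and k'_i = k_i ^ d, and use the fact that f depends on its first two arguments only through their xor. Then one round gives:

```latex
\begin{aligned}
l'_{i+1} &= r'_i = r_i \oplus d = l_{i+1} \oplus d,\\
r'_{i+1} &= l'_i \oplus f(r'_i, k'_i, i)
          = l_i \oplus d \oplus f(r_i \oplus d,\, k_i \oplus d,\, i)\\
         &= l_i \oplus f(r_i, k_i, i) \oplus d
          = r_{i+1} \oplus d.
\end{aligned}
```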

The upshot is that block(lv ^ d, rv ^ d) = block(lv, rv) ^ d. But this directly implies that

block(lv ^ d, rv ^ d) ^ (rv ^ d) = block(lv, rv) ^ d ^ rv ^ d = block(lv, rv) ^ rv
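The identity can be sanity-checked on a toy instance. The writeup does not reproduce sbox, rc, nrounds or the real f, so the sketch below uses hypothetical stand-ins. The only property taken from the real f is that it sees m and k only through bytes_to_long(m) ^ k, plus the round index. The key schedule follows the challenge's ks. This illustrates the argument; it is not a test of the actual challenge code.

```python
import random

BLEN = 16
HALF = 4 * BLEN                 # round keys and half-blocks are 64 bits

# HYPOTHETICAL stand-ins for values the writeup does not show: nrounds, rc,
# and a toy round function with the one property of f the argument needs.
NROUNDS = 16
rng = random.Random(1)
RC = [rng.getrandbits(HALF) for _ in range(NROUNDS)]
SBOX = list(range(256))
rng.shuffle(SBOX)


def xor(a: bytes, b: bytes) -> bytes:
    return bytes(x ^ y for x, y in zip(a, b))


def flip(b: bytes) -> bytes:
    # xor with d (all bits set)
    return bytes(x ^ 0xFF for x in b)


def toy_f(m: bytes, k: int, r: int) -> bytes:
    # Like the real f, the input enters only via bytes_to_long(m) ^ k.
    state = (int.from_bytes(m, "big") ^ k).to_bytes(BLEN // 2, "big")
    state = [SBOX[(x + r) % 256] for x in state]      # nonlinear, round-dependent
    return bytes(state[i] ^ state[(i + 3) % 8] for i in range(8))


def ks(k: bytes) -> list[int]:
    # Key schedule as in the challenge (bytes_to_long -> int.from_bytes).
    k1 = int.from_bytes(k, "big") % 2**HALF
    k2 = int.from_bytes(k, "big") >> HALF
    rk = []
    for i in range(NROUNDS // 2):
        rk += [((kk << (i % HALF)) % 2**HALF) | kk % 2**(i % HALF) for kk in (k1, k2)]
    return [rk[i] ^ RC[i] for i in range(len(rk))]


def block(m1: bytes, m2: bytes) -> bytes:
    # Feistel structure copied from the challenge's block, with toy_f for f.
    keys = ks(m1)
    l, r = m2[:BLEN // 2], m2[BLEN // 2:]
    for i in range(NROUNDS):
        l, r = r, xor(l, toy_f(r, keys[i], i))
    return l + r


def compress(lv: bytes, rv: bytes) -> bytes:
    # Node value xor(block(lv, rv), rv) from hash()
    return xor(block(lv, rv), rv)


lv, rv = bytes(range(16)), bytes(range(100, 116))
assert block(flip(lv), flip(rv)) == flip(block(lv, rv))
assert compress(flip(lv), flip(rv)) == compress(lv, rv)
```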

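Hence, consider a message whose first two blocks are ordinary message blocks rather than padding, such as the 64-byte messages from the challenge. For such a message, we can produce a different message with the same hash. We choose the new first two blocks so that, after preprocessing, they equal the first two preprocessed blocks of the original message xored with d. As preprocess is a bijection, these blocks are obtained by inverting preprocess. The remaining blocks are left unchanged, so the length, and with it the padding and the block count s, stay the same.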
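A forged message can be built as sketched below. This is a minimal reconstruction under the assumptions above, not the author's solve script. It re-implements preprocess and its inverse with int.from_bytes. forge (a hypothetical name) flips all bits of the first two preprocessed blocks and maps them back through the inverse of preprocess.

```python
BLEN = 16
CNT = 8 * BLEN  # 128 bits per block, as in hash()'s call preprocess(m[i], 8*blen)


def _bits(b: bytes) -> list[int]:
    return [int(x) for x in bin(int.from_bytes(b, "big"))[2:].rjust(CNT, "0")]


def _bytes(bits: list[int]) -> bytes:
    return int("".join(map(str, bits)), 2).to_bytes(CNT // 8, "big")


def preprocess(b: bytes) -> bytes:
    s = _bits(b)
    for _ in range(CNT):
        s = s[1:] + [s[0] ^ s[47] ^ (1 - (s[70] & s[85])) ^ s[91]]
    return _bytes(s)


def unpreprocess(b: bytes) -> bytes:
    # Inverse step: recover the dropped first bit from the feedback bit.
    s = _bits(b)
    for _ in range(CNT):
        s = [s[-1] ^ s[46] ^ (1 - (s[69] & s[84])) ^ s[90]] + s[:-1]
    return _bytes(s)


def forge(msg: bytes) -> bytes:
    """Second preimage for a message whose first two 16-byte blocks are
    ordinary message blocks (true for the 64-byte challenge messages)."""
    assert len(msg) >= 2 * BLEN
    out = bytearray(msg)
    for j in range(2):
        blk = msg[BLEN * j:BLEN * (j + 1)]
        target = bytes(x ^ 0xFF for x in preprocess(blk))  # preprocess(blk) ^ d
        out[BLEN * j:BLEN * (j + 1)] = unpreprocess(target)
    return bytes(out)


original = bytes(range(64))
forged = forge(original)
```

Against the remote, one would answer each of the ten rounds with forge(mymsg).hex().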

Attachments

Challenge file and solve script.