LakeCTF 2022 Writeup - Chaindle

The following writeup is a collaboration between me and Marc (@S7uXN37), who greatly contributed by rewriting the awful code I wrote during the CTF and by making it way more efficient and likely to succeed. I will be writing it from my point of view.

A few weeks ago we played LakeCTF as team Flagbot, the academic team of ETH Zurich. We managed to fully clear all the cryptography challenges, although one of them proved particularly hostile to us. The challenge was called "chaindle" and it is essentially an AES-CBC-based Wordle (from which the challenge got its name).

The reason why chaindle turned out to be particularly challenging for me was that I didn't take a serious look at it until towards the end of the competition. This meant that we were scrambling to make our solution work in the final minutes. It turns out that our solution usually would not work. However, one of our attempts got extremely lucky and we managed to flag just one minute and thirty seconds before the end. It goes without saying that I was almost incredulous at the fact that our solution worked, since the odds of it doing so seemed to be paper thin.

Premises aside, let's start discussing the challenge!

The Challenge

We are given a Python file containing the code that is being run remotely. The server begins by taking the alphabet of base64 characters and shuffling it to obtain a string which we will call the "answer". The player is given 256 Wordle-style queries in order to gain information about the answer. In each query, the player can submit an AES-CBC ciphertext which will be decrypted and checked against the correct answer. Note that we don't have the key for AES-CBC, which means that guesses must be made "blindly".

def get_color(answer: bytes, guess: bytes) -> list[Color]:
    assert len(answer) == len(guess), (
        f"Wrong guess length, " f"answer_len={len(answer)}, guess_len={len(guess)}"
    )

    n = len(answer)
    matched = [False] * n
    color = [Color.BLACK] * n

    for i in range(n):
        if answer[i] == guess[i]:
            matched[i] = True
            color[i] = Color.GREEN

    for i in range(n):
        if color[i] == Color.GREEN:
            continue
        for j in range(n):
            if matched[j] or answer[j] != guess[i]:
                continue
            matched[j] = True
            color[i] = Color.YELLOW

    return color

The check works as follows: after decryption, the server loops over both our guess and the correct answer in order to mark the correct letters at the correct position (the green letters). Afterwards, it makes another pass to mark letters which are contained within the answer, but which are at the incorrect position (the yellow letters). All other letters are marked as black. Note that if a letter appears more than once in our guess, at most one copy gets a color, because the answer contains exactly one copy of each letter in the alphabet: the green copy if there is one, otherwise the leftmost copy turns yellow, and every other copy is marked black. This was an annoying factor that needed to be taken into consideration, but nothing that was impossible to account for. For the remainder of the writeup, we will abstract this server as an oracle that takes a CBC ciphertext as input and outputs an array of colors, as defined in the description of the check.
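To make the duplicate-letter rule concrete, here is a minimal sketch that restates the two passes of the check on a toy four-letter answer. The function name `colors` and the G/Y/B string encoding are my own shorthand, not part of the challenge code, and the sketch assumes (as in the challenge) that the answer has no repeated letters.

```python
def colors(answer: bytes, guess: bytes) -> str:
    """Compact restatement of the server's check: G=green, Y=yellow, B=black."""
    n = len(answer)
    out = ["B"] * n
    used = [False] * n  # answer positions already consumed by a match

    # Pass 1: right letter in the right position.
    for i in range(n):
        if answer[i] == guess[i]:
            out[i], used[i] = "G", True

    # Pass 2: right letter in the wrong position; each answer letter is consumed once.
    for i in range(n):
        if out[i] == "G":
            continue
        for j in range(n):
            if not used[j] and answer[j] == guess[i]:
                out[i], used[j] = "Y", True
                break
    return "".join(out)


# 'B' appears twice: the copy in the right place is green, the other one is black.
# 'A' appears twice: the leftmost copy is yellow, the other one is black.
print(colors(b"ABCD", b"BBAA"))  # BGYB
```

The practical consequence for the attack is that a black square does not always mean "not a base64 character": it can also mean "a base64 character that already appeared elsewhere in the guess".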

After the 256 queries, either the player has correctly guessed the answer (i.e. all letters become green), after which the flag is printed, or the server will hash the correct answer and use it as an AES-ECB key to encrypt the flag. This is very good for us, since sub-par solutions (as ours will turn out to be) can still obtain the flag, by partially getting the answer and bruteforcing the remaining characters.

chaindle = Chaindle(answer)
for _ in range(256):
    guess = json.loads(input())
    result = chaindle.guess(
        bytes.fromhex(guess["iv"]),
        bytes.fromhex(guess["ciphertext"]),
    )
    print(json.dumps({"result": result}))
    if result == ALL_GREEN:
        print("Congrats!")
        break
else:
    key = sha256(answer).digest()
    cipher = AES.new(key, AES.MODE_ECB)
    FLAG = cipher.encrypt(pad(FLAG, 16))

print(FLAG)
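Because the fallback branch encrypts the flag under a hash of the answer, a partial answer can be completed offline. As a rough back-of-the-envelope sketch (worst case, assuming we know nothing about which of the missing letters goes where), k unknown positions leave at most k! arrangements of the k unused letters, each costing one SHA-256 and one AES decryption to test. The values of k below are purely illustrative.

```python
import math

# With k unknown positions, the k unused letters can be arranged in at most k! ways.
for k in (7, 12, 20):
    tries = math.factorial(k)
    print(f"{k} unknown characters -> up to {tries} candidates (~2^{math.log2(tries):.1f})")
```

For instance, 7! is 5040 and 12! is 479001600, so a handful of unknown characters is trivial, around a dozen is still within reach of a laptop, and anything near 20 is hopeless.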

Our Solution

Let's start by considering a single block (Block A) and the one that precedes it (Block B). Similarly to a padding oracle attack, we can modify block B in order to precisely change the contents of block A. Unlike a padding oracle attack, however, we do not only get information about the last byte in the block, but we get information about all the characters in the block at the same time! Intuitively this makes the 256 queries tight but not impossible to work with. Furthermore, this allows us to work with each byte independently. The only case in which bytes might interfere with each other is if they both decrypt to the same character, which, as described above, would make the output of the oracle black in the corresponding position. However, this is not an issue in our exploit.
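The malleability we rely on can be sketched without the real cipher. In the snippet below, `toy_decrypt_block` is a hypothetical stand-in for AES block decryption (a keyed hash, chosen only because it is unpredictable to the attacker), and `KEY`, `block_a` and `block_b` are illustrative names. The point is only that CBC computes P = D(C) XOR previous block, so XOR-ing a delta into block B XORs the same delta into the plaintext of block A.

```python
import hashlib
import os

KEY = os.urandom(16)  # hypothetical secret, unknown to the attacker


def toy_decrypt_block(block: bytes) -> bytes:
    """Stand-in for AES block decryption: a keyed function we cannot predict."""
    return hashlib.sha256(KEY + block).digest()[:16]


def xor(a: bytes, b: bytes) -> bytes:
    return bytes(x ^ y for x, y in zip(a, b))


def cbc_plaintext(prev_block: bytes, block: bytes) -> bytes:
    """CBC decryption of a single block: P = D(C) XOR previous ciphertext block."""
    return xor(toy_decrypt_block(block), prev_block)


block_b = os.urandom(16)  # the block we tamper with
block_a = bytes(16)       # the block whose plaintext we want to steer
before = cbc_plaintext(block_b, block_a)

delta = bytes([0x01, 0x20]) + bytes(14)  # touch bytes 0 and 1 only
after = cbc_plaintext(xor(block_b, delta), block_a)

# The plaintext of block A moved by exactly delta, byte by byte.
assert xor(before, after) == delta
```

We never learn `before` itself, but we can move each of its bytes around at will and watch how the colors react.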

The first information that we can get from the output of the oracle is which of the characters are in the base64 alphabet. Our objective is to force the entire block to become entirely made of base64 characters. This would allow us to greatly reduce the space of possibilities and, thus, the amount of guesses that we need to make in order to reveal the correct character. Since I wasn't expecting our solution to be able to make every character green, I was already aiming at doing some offline bruteforcing. Thus, our objective is to obtain as much information as possible about the answer and pray that it is enough to make bruteforcing feasible. In practice this means having information about every character in the answer, except for around 12 characters.

Our approach was fairly simple:

  1. Choose many values for block B at random
  2. For each value, check which of the corresponding bytes of block A becomes yellow
  3. For each character of the chosen block B, retain it if it generated a yellow character.

Eventually, this leads to the creation of a block B that makes the entirety of block A composed of base64 characters. We call this value a "good IV" (That's the best name I could come up with during the CTF!).
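The three steps above can be simulated locally. This is only a sketch under simplifying assumptions: `toy_decrypt_block` is a hypothetical stand-in for AES, and `toy_oracle` reports plain membership in the base64 alphabet, whereas the real server can additionally answer black for duplicated letters (which only makes the search slower, never wrong). Since a random byte lands in the alphabet with probability 64/256, each position is expected to be hit after about four tries.

```python
import hashlib
import random
import string

ALPHABET = (string.ascii_letters + string.digits + "+/").encode()
rng = random.Random(2022)  # fixed seed so the sketch is reproducible
KEY = bytes(rng.randrange(256) for _ in range(16))  # hypothetical secret


def toy_decrypt_block(block: bytes) -> bytes:
    """Stand-in for AES block decryption (not the real cipher)."""
    return hashlib.sha256(KEY + block).digest()[:16]


def toy_oracle(block_b: bytes, block_a: bytes) -> list[bool]:
    """Simplified oracle: True where the decrypted byte is a base64 character."""
    plain = bytes(d ^ b for d, b in zip(toy_decrypt_block(block_a), block_b))
    return [p in ALPHABET for p in plain]


block_a = bytes(16)
good_iv = [None] * 16  # None = no good byte found yet for this position
queries = 0
while any(b is None for b in good_iv):
    candidate = bytes(rng.randrange(256) for _ in range(16))  # step 1
    queries += 1
    for i, hit in enumerate(toy_oracle(candidate, block_a)):  # step 2
        if hit and good_iv[i] is None:
            good_iv[i] = candidate[i]                         # step 3

# Every byte of block A now decrypts to a base64 character.
assert all(toy_oracle(bytes(good_iv), block_a))
print(f"good IV found after {queries} simulated queries")
```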

Now that we have constrained the characters, all there is left is learning which characters are in the block. To do so we can manipulate the good IV slightly, in order to gain more information. More specifically, we need to know two pieces of information:

  1. Which base64 character we induced by using the good IV
  2. Which value \(\Delta_i\) we must XOR into the i-th byte of the good IV in order to obtain a green letter in the i-th byte of the block

The two pieces of information combined will yield the correct character in the guess. I will now walk through the inefficient method that I used while rushing for a solution. Towards the end, I will discuss some improvements that would make this way more efficient.

The first piece of information is fairly easy to get reliably, since it can be obtained by XOR-ing each byte with 64 different values of \(\Delta_i\). The behaviour of a byte when XOR-ed with each of those values uniquely determines the character. This means that we can easily get the characters that we induced with the good IV in around 64 queries. Unfortunately, the second piece of information is trickier to obtain. Clearly, every base64 character has 64 possible XOR values that map it to a base64 character. The set of these 64 possible XOR values depends on the character itself. The union of these sets across all base64 characters gives us a set of 128 possible values to choose from. To interpret this result: if we don't know the character that we induced, we might have to try 128 different values in order to obtain a green. This is definitely not ideal, since it uses half of our queries. Nonetheless, we kept going with this approach, making it a bit smarter by combining the search for both pieces of information in 128 queries.
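The counts in this paragraph are easy to check offline. The sketch below builds, for every base64 character, the set of XOR values that keep it inside the alphabet, then confirms that the union has 128 elements and that no two characters share the same set. The names `keeps_in_alphabet`, `induced` and `answer_char` are illustrative choices of mine.

```python
import string

ALPHABET = (string.ascii_letters + string.digits + "+/").encode()

# For an induced character c, the deltas for which c ^ delta is still a base64 character.
keeps_in_alphabet = {c: frozenset(c ^ a for a in ALPHABET) for c in ALPHABET}

all_deltas = set().union(*keeps_in_alphabet.values())
print(len(all_deltas))  # 128

# Every character has 64 such deltas, and the sets are pairwise different,
# so the full non-black/black pattern identifies the induced character.
assert all(len(s) == 64 for s in keeps_in_alphabet.values())
assert len(set(keeps_in_alphabet.values())) == 64

# Combining both pieces: if the good IV induces c and delta turns the square green,
# then the answer character at that position is c ^ delta.
induced, answer_char = ord("Q"), ord("7")  # illustrative values
green_delta = induced ^ answer_char
assert green_delta in keeps_in_alphabet[induced]
assert induced ^ green_delta == answer_char
```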

This gives us an entire block of the answer. What about the other blocks?

A quick observation is that the base64 alphabet is composed of (obviously) 64 characters, which make up 4 blocks. By including the IV, we send five blocks: C0, C1, C2, C3 and C4. The insight is that we can work independently on gaining information on the pair of blocks (C2, C4), by modifying blocks C1 and C3, and later on the pair (C1, C3), by modifying blocks C0 and C2. This enables us not to have to repeat the above process four times, but only twice. However, as we needed a few guesses at the start to obtain the good IV, we do not have enough queries to uniquely determine all characters in the second set of blocks. Thus, we will definitely have some bruteforcing to do.
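To see why the two pairs can be handled in parallel, here is a small sketch with a hypothetical stand-in for AES (`toy_decrypt_block`, a keyed hash) and illustrative names throughout. Tampering with C1 and C3 shifts the plaintexts of the second and fourth blocks by exactly the chosen deltas, independently of each other, while the first and third plaintext blocks get scrambled because C1 and C3 themselves go through the block cipher.

```python
import hashlib
import os

KEY = os.urandom(16)  # hypothetical secret key


def toy_decrypt_block(block: bytes) -> bytes:
    """Stand-in for AES block decryption (not the real cipher)."""
    return hashlib.sha256(KEY + block).digest()[:16]


def xor(a: bytes, b: bytes) -> bytes:
    return bytes(x ^ y for x, y in zip(a, b))


def cbc_decrypt(blocks: list[bytes]) -> list[bytes]:
    """blocks = [C0 (the IV), C1, C2, C3, C4]; returns [P1, P2, P3, P4]."""
    return [xor(toy_decrypt_block(blocks[i]), blocks[i - 1]) for i in range(1, 5)]


c = [os.urandom(16) for _ in range(5)]
base = cbc_decrypt(c)

d1, d3 = os.urandom(16), os.urandom(16)
tampered = [c[0], xor(c[1], d1), c[2], xor(c[3], d3), c[4]]
out = cbc_decrypt(tampered)

# P2 and P4 move by exactly the deltas we chose...
assert xor(base[1], out[1]) == d1
assert xor(base[3], out[3]) == d3
# ...while P1 and P3 change unpredictably, so their colors are ignored in this phase.
assert out[0] != base[0] and out[2] != base[2]
```

In the second phase the roles are swapped: C0 and C2 are tampered with, and only the colors of the first and third plaintext blocks are read.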

We had no better option than to run the exploit a few times in the last 15 minutes of the competition. Initially I was getting around 20 characters missing, which was way above feasibility. Suddenly, fortune smiled upon me, giving me an answer in which we missed a mere 7 characters! This happened with what I thought to be around 4 minutes left. It turns out that this was just the clock of my computer acting up and that I had an entire 7 minutes to bruteforce those seven characters. A child could do it in seven minutes, right?

I opened up an iPython instance and started typing as fast as I could (I'm sure that here my training with the Advent of Code helped me in getting down the code as fast as possible!). In the end, the bruteforcing worked, yielding us the flag: EPFL{A_chaindle_is_an_anagram_of_enchilada_310831124e0b5df3311a}.

What We Could Have Done Better

During the CTF: nothing. I essentially had no time left and a more complex solution would have definitely taken me longer than the time left in the competition. In this sense, I believe the usual mantra of "Keep It Simple Stupid" worked out in the end, albeit with a slight botta di culo, as we would say in Italian.

After the CTF, Marc decided to hurt himself by trying to re-interpret my awful code and re-writing it. I believe his code to be way more readable than mine, which is why I will include only that code, rather than the one I used to solve the challenge.

With more time, there was a much smarter approach to inducing green characters. Clearly, testing all 128 values would have given us a green character, but all of those values are only required as long as we do not know the character that we induced. Indeed, by combining information on when the character turns yellow, we could have identified the induced character early and reduced the number of values to test by quite a lot. This should bring the total number of queries way below the 256 allowed, which would also allow us to not have to bruteforce anything.
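One way this idea could look, for a single byte position, is sketched below. This is not the code we ran and not Marc's rewrite: `filter_candidates` and `best_delta` are hypothetical helpers, the delta selection is a simple greedy heuristic, and the simulation ignores the extra blacks caused by duplicated letters. Each query result (non-black or black) prunes the set of characters the good IV could have induced. Once only one candidate is left, the green delta has to be one of its 64 alphabet-preserving values rather than one of 128.

```python
import string

ALPHABET = (string.ascii_letters + string.digits + "+/").encode()


def filter_candidates(candidates: set[int], delta: int, non_black: bool) -> set[int]:
    """Keep the induced characters consistent with one observation:
    XOR-ing with delta produced a yellow/green (non_black) or a black square."""
    return {c for c in candidates if ((c ^ delta) in ALPHABET) == non_black}


def best_delta(candidates: set[int]) -> int:
    """Greedy choice: the delta whose worse outcome leaves the fewest candidates."""
    def worst(delta: int) -> int:
        hits = sum((c ^ delta) in ALPHABET for c in candidates)
        return max(hits, len(candidates) - hits)
    return min(range(1, 128), key=worst)


# Simulate one byte position; the induced character is hidden from the solver.
secret = ord("k")
candidates = set(ALPHABET)
queries = 0
while len(candidates) > 1:
    delta = best_delta(candidates)
    candidates = filter_candidates(candidates, delta, (secret ^ delta) in ALPHABET)
    queries += 1

assert candidates == {secret}
print(f"identified the induced character in {queries} simulated queries")
```

Since every byte of the tampered block can carry its own delta, nothing prevents running this selection for all 16 positions within the same query.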

#!/usr/bin/env python3
import itertools
import json
import math
import os
import string
from enum import Enum
import time

from pwn import *

class Color(Enum):
    BLACK = 0
    YELLOW = 1
    GREEN = 2

def decode_response(response: bytes) -> list[Color]:
    j = json.loads(response)["result"]
    r = []

    # I hate UTF-8 and idk how to decode this decently
    for ch in j:
        if ch.encode() == b'\xf0\x9f\x9f\xa8':
            r.append(Color.YELLOW)
        elif ch.encode() == b'\xe2\xac\x9b':
            r.append(Color.BLACK)
        else:
            r.append(Color.GREEN)
    return r

def arr_to_val(arr):
    return [x.value for x in arr]

def get_conn():
    if args.REMOTE:
        return remote('chall.polygl0ts.ch', 4400, level=logging.WARN)
    else:
        # r = remote('localhost', 9000)
        return process(['python', './chaindle.py'], level=logging.WARN)

def blockify(arr):
    return [arr[i:i+16] for i in range(0, len(arr), 16)]

def oracle(iv, ctxt, conn):
    ctxt = b''.join(ctxt)
    j = json.dumps({"iv": iv.hex(), "ciphertext": ctxt.hex()}).encode()
    conn.sendline(j)
    l = conn.recvline()
    return decode_response(l)

# Build reverse table (set of chars for which the xor with ch is good) -> ch
ALPHABET = (string.ascii_letters + string.digits + '+/').encode()

xor_vals = set(x ^ y for x in ALPHABET for y in ALPHABET)

# i in rev_idx[c]  <==>  c ^ i in ALPHABET and c in ALPHABET
rev_idx = {}
for c in ALPHABET:
    ll = set()
    for i in xor_vals:
        if bytes([c ^ i]) in ALPHABET:
            ll.add(i)
    rev_idx[c] = ll

# returns all alphabet characters c for which c ^ x stays in the ALPHABET for every XOR value x in s
def check_set(s: set[int]) -> list[bytes]:
    candidates = []
    for ch, ss in rev_idx.items():
        if s.issubset(ss):
            candidates.append((ch).to_bytes(1,"big"))
    return candidates


def attack(conn):
    query_cnt = 0

    # First try multiple random c3
    c4 = b'\x00' * 16  # Fix block4 and block2 to zeroes
    c2 = b'\x00' * 16
    good_iv = [0] * 16  # We'll fill this with good choices for block3 bytes (good = result is in alphabet)
    good_iv_2 = [0] * 16 # We'll fill this with good choices for block1 bytes

    while any(k == 0 for k in good_iv) or any(k==0 for k in good_iv_2):
        iv = os.urandom(16)
        iv_2 = os.urandom(16)
        res = oracle(os.urandom(16), [iv_2, c2] + [iv, c4], conn)  # actual IV is irrelevant because we don't look at the first block's result
        query_cnt += 1
        res_1 = res[-16:]
        for i, x in enumerate(res_1):
            if x.value >= 1 and good_iv[i] == 0:
                good_iv[i] = iv[i]

        res_2 = res[16:32]
        for i, x in enumerate(res_2):
            if x.value >= 1 and good_iv_2[i] == 0:
                good_iv_2[i] = iv_2[i]

    # Now we have [IV = rand] [B1 = good_iv_2] [B2 = 00..00] [B3 = good_iv] [B4 = 00..00]
    # and the result is YELLOW or GREEN for block4 and block2

    # Now we modify good_iv with the 64 different possible XOR values until P4 is green

    sets = [set() for _ in range(16)]  # saves all XOR differences ch where good_iv[_] ^ ch is YELLOW or GREEN
    good_xor = [-1 for _ in range(16)]  # saves the XOR difference ch where good_iv[_] ^ ch is GREEN

    sets_2 = [set() for _ in range(16)]
    good_xor_2 = [-1 for _ in range(16)]

    for ch in list(xor_vals)[:96]:
        iv = bytes([x ^ ch for x in good_iv])
        iv_2 = bytes([x ^ ch for x in good_iv_2])
        res = oracle(os.urandom(16), [iv_2, c2] + [iv, c4], conn)
        query_cnt += 1
        res_1 = res[-16:]
        for i, x in enumerate(res_1):
            if x.value >= 1:
                sets[i].add(ch)
            if x.value == 2:
                good_xor[i] = ch

        res_2 = res[16:32]
        for i, x in enumerate(res_2):
            if x.value >= 1:
                sets_2[i].add(ch)
            if x.value == 2:
                good_xor_2[i] = ch

    # Now we have [IV = rand] [B1 = good_iv_2 ^ good_xor_2] [B2 = 00..00] [B3 = good_iv ^ good_xor] [B4 = 00..00]
    # and the result is ALL_GREEN for block4 and block2

    answer = [(check_set(s), x) for s, x in zip(sets, good_xor)]
    answer_2 = [(check_set(s), x) for s, x in zip(sets_2, good_xor_2)]

    def compute_possible_plaintext(answer: list[tuple[list[bytes], int]]) -> list[bytes]:
        plain = []
        for ss, x in answer:
            if x != -1:
                # If we know the right choice, XOR it with all candidates to get all possible plaintext bytes
                plain.append(bytes([s[0] ^ x for s in ss]))
            else:
                # otherwise, we can't recover the plaintext character, so we keep the whole alphabet as candidates
                plain.append(ALPHABET)
        return plain

    p4 = compute_possible_plaintext(answer)
    p2 = compute_possible_plaintext(answer_2)

    # Then do the other two blocks

    c4 = b'\x00' * 16
    c2 = b'\x00' * 16
    good_iv = [0] * 16
    good_iv_2 = [0] * 16

    while any(k == 0 for k in good_iv) or any(k==0 for k in good_iv_2):
        iv = os.urandom(16)
        iv_2 = os.urandom(16)
        res = oracle(iv_2, [c2, iv] + [c4, os.urandom(16)], conn)
        query_cnt += 1
        res_1 = res[-32:-16]
        for i, x in enumerate(res_1):
            if x.value >= 1 and good_iv[i] == 0:
                good_iv[i] = iv[i]

        res_2 = res[0:16]
        for i, x in enumerate(res_2):
            if x.value >= 1 and good_iv_2[i] == 0:
                good_iv_2[i] = iv_2[i]

    # Now we modify good_iv and good_iv_2 with the possible XOR values until P3 and P1 are green

    sets = [set() for _ in range(16)]
    good_xor = [-1 for _ in range(16)]

    sets_2 = [set() for _ in range(16)]
    good_xor_2 = [-1 for _ in range(16)]

    for ch in xor_vals:
        iv = bytes([x ^ ch for x in good_iv])
        iv_2 = bytes([x ^ ch for x in good_iv_2])
        res = oracle(iv_2, [c2, iv] + [c4, os.urandom(16)], conn)
        query_cnt += 1
        res_1 = res[-32:-16]
        for i, x in enumerate(res_1):
            if x.value >= 1:
                sets[i].add(ch)
            if x.value == 2:
                good_xor[i] = ch

        res_2 = res[0:16]
        for i, x in enumerate(res_2):
            if x.value >= 1:
                sets_2[i].add(ch)
            if x.value == 2:
                good_xor_2[i] = ch

        if query_cnt == 256:
            break

    answer = [(check_set(s), x) for s, x in zip(sets, good_xor)]
    answer_2 = [(check_set(s), x) for s, x in zip(sets_2, good_xor_2)]

    p3 = compute_possible_plaintext(answer)
    p1 = compute_possible_plaintext(answer_2)

    # iterate to fixed point
    old_result = b""
    while old_result != p1+p2+p3+p4:
        old_result = p1+p2+p3+p4
        # remove impossible choices (repeat characters) from possibilities
        used_chars = []
        for x in p1+p2+p3+p4:
            if len(x) == 1:
                used_chars.append(x[0])

        total_plain = b"".join(p1+p2+p3+p4)

        # remove all used_chars from possibilities
        # we also check if some alphabet character appears only in one place in p1..p4 and then fix that position
        def update_possibilities(plain: list[bytes]) -> list[bytes]:
            new_plain = []
            for character_poss in plain:
                if len(character_poss) == 1:
                    new_plain.append(character_poss)
                    continue
                else:
                    new_poss = b""
                    for p in character_poss:
                        count = total_plain.count(bytes([p]))  # total occurrences of this character
                        if count == 1:
                            new_poss = bytes([p])
                            break
                        if p not in used_chars:
                            new_poss += bytes([p])
                    new_plain.append(new_poss)
            return new_plain

        
        p1 = update_possibilities(p1)
        p2 = update_possibilities(p2)
        p3 = update_possibilities(p3)
        p4 = update_possibilities(p4)

    unknowns = 0
    exp = 1
    for p in p1+p2+p3+p4:
        exp *= len(p)
        if len(p) > 1:
            unknowns += 1


    log.debug("### End of Run ###")
    log.debug(f"We have {unknowns} unknown characters")

    log.debug(f"Brute-Force would need to exhaust up to {exp} possibilities (~ 2^{math.log(exp, 2):.2f})")
    log.debug(f"Queries used in total: {query_cnt}")

    for _ in range(256-query_cnt):
        oracle(os.urandom(16), [os.urandom(16)]*4, conn)
    cipher = conn.recvline()[2:-2].decode('unicode_escape').encode('raw_unicode_escape')

    return p1+p2+p3+p4, cipher, unknowns, round(math.log(exp, 2))

def crack(plaintext: list[bytes], ciphertext: bytes):
    from Crypto.Cipher import AES
    from Crypto.Util.Padding import unpad
    from hashlib import sha256

    found = []
    for p in plaintext:
        if len(p) == 1:
            found.append(p[0])

    missing = [x for x in ALPHABET if x not in found]
    
    prog = log.progress("Brute-Forcing")
    num_tries = 0
    time_start = time.time()
    for permutation in itertools.permutations(missing, len(missing)):
        num_tries += 1
        guess = b""
        ctr = 0
        for i in range(len(plaintext)):
            if len(plaintext[i]) == 1:
                guess += plaintext[i]
            else:
                guess += bytes([permutation[ctr]])
                if permutation[ctr] not in plaintext[i]:
                    break
                ctr += 1
        else:
            fraction_complete = num_tries/math.factorial(len(missing))
            time_passed = time.time() - time_start
            seconds_left = time_passed/fraction_complete * (1-fraction_complete)
            prog.status(f"{100*fraction_complete:.0f}% -- ETA {seconds_left//60:.0f}:{seconds_left%60:.0f} -- {guess}")
            
            key = sha256(guess).digest()
            cipher = AES.new(key, AES.MODE_ECB)
            try:
                FLAG = unpad(cipher.decrypt(ciphertext), 16)
                #log.info(FLAG)
                if b"EPFL" in FLAG:
                    prog.success(guess)
                    log.success(f"Flag is: {FLAG}")
                    quit()
            except ValueError:
                # incorrect padding
                pass
    log.failure("Solving failed")

if __name__  == "__main__":
    log.info("We will run the protocol many times. When you are ready to start brute-forcing, press CTRL+C.")

    best_plain = b""
    best_cipher = b""
    best_score = 200
    best_score2 = 200
    prog = log.progress("Executing protocol runs")
    try:
        while True:
            conn = get_conn()
            plain, cipher, score, score2 = attack(conn)
            conn.close()
            
            if score < best_score or (score == best_score and score2 < best_score2):
                best_score = score
                best_score2 = score2
                best_cipher = cipher
                best_plain = plain
                prog.status(f"Best score = {best_score} unknowns, 2^{best_score2} guesses")
            if score <= 11:
                prog.success(f"Score ({best_score}) small enough. Beginning brute force.")
                break
    except KeyboardInterrupt:
        prog.success(f"Final best score: {best_score} unknowns")
        pass
    
    # Report the best run, which is also the one handed to crack() below
    print("plain:", b"".join([bytes([p[0]]) if len(p)==1 else b"?" for p in best_plain]))
    print("cipher:", repr(best_cipher))
    
    crack(best_plain, best_cipher)

Conclusions

Well, we got the flag, do we care about anything else? Not really! If this writeup may serve as a lesson I guess it would be "go fast and loose" when writing exploits and "don't overthink it". The important thing is to give yourself enough time to get lucky and bruteforce a solution ;)

Speaking of last minute, I also wrote this writeup in around one hour, and I will be submitting it around one minute before the deadline. I guess I never learn...