
# Mixed Cipher

## Introduction

```
I heard bulldozer is on this channel, be careful!
nc crypto.chal.ctf.westerns.tokyo 5643
```

We are given the following server program, which is listening at the above
address:

```py
from Crypto.PublicKey import RSA
from Crypto.Cipher import AES
from Crypto.Util.number import long_to_bytes

import random
import signal
import os
import sys

sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', 0)
privkey = RSA.generate(1024)
pubkey = privkey.publickey()
flag = open('./flag').read().strip()
aeskey = os.urandom(16)
BLOCK_SIZE = 16

def pad(s):
    n = 16 - len(s)%16
    return s + chr(n)*n

def unpad(s):
    n = ord(s[-1])
    return s[:-n]

def aes_encrypt(s):
    iv = long_to_bytes(random.getrandbits(BLOCK_SIZE*8), 16)
    aes = AES.new(aeskey, AES.MODE_CBC, iv)
    return iv + aes.encrypt(pad(s))

def aes_decrypt(s):
    iv = s[:BLOCK_SIZE]
    aes = AES.new(aeskey, AES.MODE_CBC, iv)
    return unpad(aes.decrypt(s[BLOCK_SIZE:]))

def bulldozer(s):
    s = bytearray(s)
    print('Bulldozer is coming!')
    for idx in range(len(s) - 1):
        s[idx] = '#'
    return str(s)

def encrypt():
    p = raw_input('input plain text: ').strip()
    print('RSA: {}'.format(pubkey.encrypt(p, 0)[0].encode('hex')))
    print('AES: {}'.format(aes_encrypt(p).encode('hex')))

def decrypt():
    c = raw_input('input hexencoded cipher text: ').strip().decode('hex')
    print('RSA: {}'.format(bulldozer(privkey.decrypt(c)).encode('hex')))

def print_flag():
    print('here is encrypted flag :)')
    p = flag
    print('another bulldozer is coming!')
    print(('#'*BLOCK_SIZE+aes_encrypt(p)[BLOCK_SIZE:]).encode('hex'))

def print_key():
    print('here is encrypted key :)')
    p = aeskey
    c = pubkey.encrypt(p, 0)[0]
    print(c.encode('hex'))

signal.alarm(300)
while True:
    print("""Welcome to mixed cipher :)
I heard bulldozer is on this channel, be careful!
1: encrypt
2: decrypt
3: get encrypted flag
4: get encrypted key""")
    n = int(raw_input())

    menu = {
        1: encrypt,
        2: decrypt,
        3: print_flag,
        4: print_key,
    }

    if n not in menu:
        print('bye :)')
        exit()
    menu[n]()
```

We can see that:

* the server uses both RSA and AES: RSA is used directly (without
any padding) and AES is used in CBC mode;
* the AES and RSA keys are both generated randomly when the server starts and
they are never regenerated for the remainder of the session;
* AES IVs are sent together with the encrypted message, and are not reused:
a different IV is generated every time a message is encrypted with AES.

The server has four commands that we can use:

1. **encrypt** encrypts a message with both RSA and AES;
2. **decrypt** decrypts a message with RSA but replaces every byte of the
decrypted message except the last with `#` before sending it back to us;
3. **get encrypted flag** sends us the flag, encrypted with AES. To make things
more difficult, the entire IV is replaced with `#`;
4. **get encrypted key** sends us the AES key, encrypted with the RSA key.

To summarize, the flag is encrypted with AES-CBC with unknown IV and key. We
will need to:

* break the RSA encryption to recover the AES key;
* break the IV generation to recover the IV.

## RSA parity oracle

A [parity oracle attack](https://cryptopals.com/sets/6/challenges/46) is an
attack on RSA that can be used to recover the plaintext of an encrypted message
if, given an encrypted message, the attacker can learn whether the corresponding
plaintext is even or odd.

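We can use this attack in this challenge since the server doesn't destroy the
last byte (i.e. the least significant byte) of the plaintext when we ask it to
decrypt a message with its RSA private key, and the lowest bit of that byte is
the parity of the plaintext. We can keep multiplying the encrypted AES key by
$2^{e} \bmod n$, which doubles the underlying plaintext modulo $n$, and every
time ask the server if the corresponding plaintext is even or odd. Since $n$ is
odd, an odd result means the doubling wrapped around $n$, so each response
halves the interval that can contain the plaintext.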
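The attack above can be sketched locally. This is a minimal illustration rather than the solver used against the server: it assumes a toy 216-bit modulus built from two Mersenne primes, and `parity_oracle` is a hypothetical stand-in for the server's decrypt command. The helper names `modinv` and `recover_plaintext` are ours, and the sketch targets Python 3, unlike the Python 2 scripts in this writeup.

```python
def modinv(a, m):
    """Modular inverse of a mod m via the extended Euclidean algorithm."""
    r0, r1 = a, m
    s0, s1 = 1, 0
    while r1:
        q = r0 // r1
        r0, r1 = r1, r0 - q * r1
        s0, s1 = s1, s0 - q * s1
    assert r0 == 1
    return s0 % m

# Toy key (assumption: far too small for real use, chosen so the demo is fast)
p = 2**127 - 1
q = 2**89 - 1
n = p * q
e = 65537
d = modinv(e, (p - 1) * (q - 1))

def parity_oracle(c):
    """Stand-in for the server: decrypt and reveal only the lowest bit."""
    return pow(c, d, n) & 1

def recover_plaintext(c, e, n, oracle):
    k = n.bit_length()          # 2**k > n, so the final interval is shorter than 1
    two_e = pow(2, e, n)
    lo = 0                      # plaintext lies in [lo*n/2**i, (lo+1)*n/2**i)
    for _ in range(k):
        c = (c * two_e) % n     # doubles the plaintext modulo n
        lo = 2 * lo + oracle(c) # odd result: the doubling wrapped, take upper half
    return ((lo + 1) * n) >> k  # the only integer left in the interval

secret = int.from_bytes(b'0123456789abcdef', 'big')
recovered = recover_plaintext(pow(secret, e, n), e, n, parity_oracle)
```

Keeping the bounds as an exact integer numerator over a power-of-two denominator avoids rounding in the low bits of the result.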

However, this attack requires knowledge of at least $e$ and $n$, respectively
the public exponent and modulus of the server's RSA key. We know $e$ since the
server is just using PyCrypto's default value ($e = 65537$) but we will need to
recover $n$. In order to do this, we can simply ask the server to encrypt a few
small numbers $i$ greater than 1 and compute the GCD of the differences between
$i^{e}$ and the corresponding ciphertexts. Since RSA encryption is done by
computing $c = m^{e} \bmod n$, every difference $i^{e} - c$ is a multiple of
$n$, so the GCD will either be $n$ or $n$ times a small number which we can
factor out. In practice using 2, 3 and 4 seems to work.
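The GCD trick can be tried offline with a short sketch. It assumes a toy modulus standing in for the server's unknown $n$, and the helper name `recover_modulus` and its `small_bound` parameter are illustrative. Stray small factors are stripped by trial division, which is enough when the leftover cofactor is small. The sketch is written for Python 3.

```python
from functools import reduce
from math import gcd

E = 65537

def recover_modulus(pairs, e=E, small_bound=1000):
    """pairs: (plaintext, ciphertext) integers with plaintext > 1."""
    # Each m**e - c is a multiple of n, so n divides their GCD
    g = reduce(gcd, (m ** e - c for m, c in pairs))
    # Remove any small cofactor left over in the GCD
    for f in range(2, small_bound):
        while g % f == 0:
            g //= f
    return g

# Toy modulus (assumption: stands in for the server's key)
n = (2**127 - 1) * (2**89 - 1)
pairs = [(m, pow(m, E, n)) for m in (2, 3, 5, 7)]
recovered_n = recover_modulus(pairs)
```

Using more plaintexts makes it less likely that a cofactor survives in the GCD.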

## Mersenne Twister seed recovery

We still have to recover the AES IV, which is randomly generated with `iv =
long_to_bytes(random.getrandbits(BLOCK_SIZE*8), 16)` every time a message is
encrypted with AES. The pseudorandom number generator used by Python's `random`
is the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister), which
is not suitable for cryptography as its output can be predicted just by
observing enough of it.

Instead of implementing our own Mersenne Twister cracker, we used [an existing
one](https://github.com/tna0y/Python-random-module-cracker).

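All we have to do is ask the server to encrypt some messages and feed the IVs to
the cracker. The generator's internal state consists of 624 32-bit words and
each 128-bit IV gives us four outputs, so 156 encryptions are enough. Strictly
speaking this recovers the internal state rather than the seed, which is all we
need to predict the next IV.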
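To show what such a cracker does internally, here is a minimal sketch that clones a Mersenne Twister from 156 consecutive 128-bit outputs using only the standard library. It is not the code of the linked cracker: `undo_right`, `undo_left`, `untemper` and `clone_from_ivs` are hypothetical names, the `victim` generator stands in for the server, and the sketch assumes CPython's `getrandbits`, which emits the least significant 32-bit word first. It is written for Python 3.

```python
import random

def undo_right(z, shift):
    """Invert y ^= y >> shift on a 32-bit word."""
    y = z
    for _ in range(32 // shift + 1):
        y = z ^ (y >> shift)
    return y

def undo_left(z, shift, mask):
    """Invert y ^= (y << shift) & mask on a 32-bit word."""
    y = z
    for _ in range(32 // shift + 1):
        y = (z ^ ((y << shift) & mask)) & 0xffffffff
    return y

def untemper(z):
    """Undo the MT19937 output tempering, last step first."""
    y = undo_right(z, 18)
    y = undo_left(y, 15, 0xefc60000)
    y = undo_left(y, 7, 0x9d2c5680)
    return undo_right(y, 11)

def clone_from_ivs(ivs):
    words = []
    for iv in ivs:
        # Always take four words so an IV with a zero top word stays aligned
        for _ in range(4):
            words.append(iv & 0xffffffff)
            iv >>= 32
    assert len(words) == 624
    clone = random.Random()
    # Index 624 forces a state regeneration on the next output
    state = tuple(untemper(w) for w in words) + (624,)
    clone.setstate((3, state, None))
    return clone

victim = random.Random(1234)  # stand-in for the server's generator
ivs = [victim.getrandbits(128) for _ in range(156)]
clone = clone_from_ivs(ivs)
predicted = clone.getrandbits(128)
actual = victim.getrandbits(128)
```

After this, `predicted` plays the role of the IV that the server would use for its next AES encryption.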

## Putting it all together

Now that we have recovered both the key and the IV we have everything we need
to decrypt the flag.

```py
from pwn import *
from fractions import gcd
from Crypto.Cipher import AES
from Crypto.Util.number import bytes_to_long, long_to_bytes

import re
import struct

# https://github.com/tna0y/Python-random-module-cracker
from randcrack import RandCrack

e = 65537
host = 'crypto.chal.ctf.westerns.tokyo'
port = 5643
rsa_re = re.compile('RSA: ([a-z0-9]+)')
aes_re = re.compile('AES: ([a-z0-9]+)')

# Skip the five-line menu the server prints before each command
def discard_prompt(t):
    for _ in xrange(5):
        t.recvline()

# Ask the server to encrypt a message
def server_encrypt(t, msg):
    discard_prompt(t)
    t.sendline('1')
    t.sendline(msg)
    t.recvline()

    rsaline = t.recvline()
    rsa = rsa_re.search(rsaline).group(1).decode('hex')

    aesline = t.recvline()
    aes = aes_re.search(aesline).group(1).decode('hex')

    assert len(aes) % 16 == 0

    return rsa, aes

# Ask the server to decrypt a message
def server_decrypt(t, msg):
    discard_prompt(t)
    t.sendline('2')
    t.sendline(msg.encode('hex'))

    t.recvline()
    t.recvline()
    rsaline = t.recvline()
    return rsa_re.search(rsaline).group(1).decode('hex')

# Get the encrypted flag from the server
def get_enc_flag(t):
    discard_prompt(t)
    t.sendline('3')
    t.recvline()
    t.recvline()
    t.recvline()

    flagline = t.recvline().strip()

    # Drop the first 16 bytes: the IV has been replaced with '#'
    enc = flagline.decode('hex')[16:]
    assert len(enc) % 16 == 0

    return enc

# Get the encrypted AES key from the server
def get_enc_key(t):
    discard_prompt(t)
    t.sendline('4')
    t.recvline()
    t.recvline()
    keyline = t.recvline().strip()
    return keyline.decode('hex')

# Recover the RSA modulus: msgs = list of (plaintext, ciphertext) tuples
def recover_modulus(msgs):
    msg_gcd = reduce(gcd, [(p ** e) - c for p, c in msgs[:5] if p > 1])

    assert msg_gcd > 1

    return msg_gcd

# Recover the Mersenne Twister RNG's state from the AES IVs
def recover_seed(ivs):
    cracker = RandCrack()

    for iv in ivs:
        tmp = iv

        # The cracker takes 32 bits at a time, least significant word first.
        # Always submit four words so an IV with a zero top word stays aligned.
        for _ in xrange(4):
            cracker.submit(tmp % (1 << 32))
            tmp = tmp >> 32

    return cracker

# Ask the server to encrypt some messages and use the responses to recover the
# RSA modulus and RNG state
def recover_modulus_and_seed(t):
    rsa_pairs = []
    ivs = []

    # 157 values minus the skipped newline give 156 IVs, i.e. 624 32-bit words
    for i in range(157):
        # The server will treat newlines as the end of the message
        if i == ord('\n'):
            continue

        rsa, aes = server_encrypt(t, struct.pack('B', i))

        rsa_pairs.append((i, bytes_to_long(rsa)))
        ivs.append(bytes_to_long(aes[:16]))

    return recover_modulus(rsa_pairs), recover_seed(ivs)

# Recover the AES key encrypted with RSA with modulus n
def recover_aes_key(t, n, key):
    # The AES key is at most 2^128 so we don't need to use n as our upper bound
    lb = 0
    ub = n // (2 ** 890)
    key = (pow(2 ** 890, e, n) * bytes_to_long(key)) % n

    while lb < ub:
        key = (pow(2, e, n) * key) % n
        d = bytes_to_long(server_decrypt(t, long_to_bytes(key)))

        # Even: the doubling did not wrap around n, keep the lower half
        if d % 2 == 0:
            ub = (lb + ub) // 2
        else:
            lb = (lb + ub) // 2

    assert lb == ub

    return long_to_bytes(lb)

def unpad(s):
    n = ord(s[-1])
    return s[:-n]

def main():
    t = remote(host, port)
    #t = remote('localhost', port)

    p = log.progress('Recovering RNG seed and RSA modulus')
    n, cracker = recover_modulus_and_seed(t)
    # Pad to 16 bytes in case the predicted IV starts with a zero byte
    iv = long_to_bytes(cracker.predict_getrandbits(128), 16)
    p.success('Seed and modulus recovered!')

    enc_flag = get_enc_flag(t)

    p = log.progress('Recovering AES key')
    aes_key = recover_aes_key(t, n, get_enc_key(t))
    p.success('AES key recovered!')

    # The last byte of the key is not recovered correctly, most likely because
    # the integer divisions in recover_aes_key round the bounds, but we can
    # just bruteforce it
    for i in range(256):
        aes = AES.new(aes_key[:-1] + struct.pack('B', i), AES.MODE_CBC, iv)
        try:
            dec = unpad(aes.decrypt(enc_flag))
        except Exception:
            continue
        if 'TWCTF' in dec:
            log.success('Flag: {}'.format(dec))
            return

if __name__ == '__main__':
    main()
```

```
$ python solve.py
[+] Opening connection to crypto.chal.ctf.westerns.tokyo on port 5643: Done
[+] Recovering RNG seed and RSA modulus: Seed and modulus recovered!
[+] Recovering AES key: AES key recovered!
[+] Flag: TWCTF{L#B_de#r#pti#n_ora#le_c9630b129769330c9498858830f306d9}
[*] Closed connection to crypto.chal.ctf.westerns.tokyo port 5643
```

Original writeup (https://github.com/ctf-epfl/writeups/blob/master/twctf18/mixed-cipher/mixed-cipher.ipynb).