Rating:

## zer0lfsr-

The challenge presents us with the following Python script:

```python
import random
import signal
import socketserver
import string
from hashlib import sha256
from os import urandom
from secret import flag

def _prod(L):
p = 1
for x in L:
p *= x
return p

def _sum(L):
s = 0
for x in L:
s ^= x
return s

def n2l(x, l):
    """Return the binary expansion of ``x`` as a list of 0/1 ints, MSB first.

    The expansion is zero-padded on the left to at least ``l`` digits; values
    needing more than ``l`` digits produce a longer list.
    """
    pattern = '{{0:0{}b}}'.format(l)
    digits = pattern.format(x)
    return [int(d) for d in digits]

class Generator1:
    """Generator 1: a 48-bit NFSR coupled with a 16-bit LFSR.

    The 64-element key list is split into the NFSR (first 48 entries) and the
    LFSR (last 16 entries). Each call to clock() returns one output bit.
    """
    def __init__(self, key: list):
        assert len(key) == 64
        self.NFSR = key[: 48]
        self.LFSR = key[48: ]
        # LFSR feedback: the new bit is the XOR of LFSR[0], [1], [12] and [15].
        self.TAP = [0, 1, 12, 15]
        # NFSR feedback terms for g(): XOR of the product (AND for 0/1 bits)
        # of each index group.
        self.TAP2 = [[2], [5], [9], [15], [22], [26], [39], [26, 30], [5, 9], [15, 22, 26], [15, 22, 39], [9, 22, 26, 39]]
        # Inputs of h(): the first four index the LFSR, the last one the NFSR.
        self.h_IN = [2, 4, 7, 15, 27]
        # Monomials of h(), as indices into the five h_IN inputs (0..4).
        self.h_OUT = [[1], [3], [0, 3], [0, 1, 2], [0, 2, 3], [0, 2, 4], [0, 1, 2, 4]]

    def g(self):
        """Return the NFSR nonlinear feedback: XOR over TAP2 of the products."""
        x = self.NFSR
        return _sum(_prod(x[i] for i in j) for j in self.TAP2)

    def h(self):
        """Return the filter h over LFSR[2], [4], [7], [15] and NFSR[27]."""
        x = [self.LFSR[i] for i in self.h_IN[:-1]] + [self.NFSR[self.h_IN[-1]]]
        return _sum(_prod(x[i] for i in j) for j in self.h_OUT)

    def f(self):
        """Return the output bit NFSR[0] XOR h()."""
        return _sum([self.NFSR[0], self.h()])

    def clock(self):
        """Compute the output from the current state, then shift both registers."""
        o = self.f()
        # NFSR feedback uses LFSR[0] from before the LFSR shift below.
        self.NFSR = self.NFSR[1: ] + [self.LFSR[0] ^ self.g()]
        self.LFSR = self.LFSR[1: ] + [_sum(self.LFSR[i] for i in self.TAP)]
        return o

class Generator2:
    """Generator 2: a 16-bit NFSR coupled with a 48-bit LFSR.

    The 64-element key list is split into the NFSR (first 16 entries) and the
    LFSR (last 48 entries). Each call to clock() shifts both registers and
    returns f() XOR h() of the new state.
    """
    def __init__(self, key):
        assert len(key) == 64
        self.NFSR = key[: 16]
        self.LFSR = key[16: ]
        # LFSR feedback: the new bit is LFSR[0] XOR LFSR[35].
        self.TAP = [0, 35]
        # f() reads these six LFSR positions ...
        self.f_IN = [0, 10, 20, 30, 40, 47]
        # ... and XORs these monomials over them (indices into f_IN, 0..5).
        self.f_OUT = [[0, 1, 2, 3], [0, 1, 2, 4, 5], [0, 1, 2, 5], [0, 1, 2], [0, 1, 3, 4, 5], [0, 1, 3, 5], [0, 1, 3], [0, 1, 4], [0, 1, 5], [0, 2, 3, 4, 5], [
            0, 2, 3], [0, 3, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4], [1, 2, 3, 5], [1, 2], [1, 3, 5], [1, 3], [1, 4], [1], [2, 4, 5], [2, 4], [2], [3, 4], [4, 5], [4], [5]]
        # NFSR feedback terms for g(): XOR of the product of each index group.
        self.TAP2 = [[0, 3, 7], [1, 11, 13, 15], [2, 9]]
        # h() reads these seven NFSR positions ...
        self.h_IN = [0, 2, 4, 6, 8, 13, 14]
        # ... and XORs these monomials over them (indices into h_IN, 0..6).
        self.h_OUT = [[0, 1, 2, 3, 4, 5], [0, 1, 2, 4, 6], [1, 3, 4]]

    def f(self):
        """Return the LFSR filter: XOR over f_OUT of products of f_IN bits."""
        x = [self.LFSR[i] for i in self.f_IN]
        return _sum(_prod(x[i] for i in j) for j in self.f_OUT)

    def h(self):
        """Return the NFSR filter: XOR over h_OUT of products of h_IN bits."""
        x = [self.NFSR[i] for i in self.h_IN]
        return _sum(_prod(x[i] for i in j) for j in self.h_OUT)

    def g(self):
        """Return the NFSR nonlinear feedback: XOR over TAP2 of the products."""
        x = self.NFSR
        return _sum(_prod(x[i] for i in j) for j in self.TAP2)

    def clock(self):
        """Shift the LFSR, then the NFSR, and return f() XOR h() of the new state."""
        self.LFSR = self.LFSR[1: ] + [_sum(self.LFSR[i] for i in self.TAP)]
        # The LFSR has already shifted, so LFSR[1] is read from the updated
        # register; g() still sees the NFSR before its own shift.
        self.NFSR = self.NFSR[1: ] + [self.LFSR[1] ^ self.g()]
        return self.f() ^ self.h()

class Generator3:
    """Generator 3: a single 64-bit LFSR filtered by the nonlinear function f."""
    def __init__(self, key: list):
        assert len(key) == 64
        self.LFSR = key
        # Feedback: the new bit is LFSR[0] XOR LFSR[55].
        self.TAP = [0, 55]
        # f() reads these seven LFSR positions ...
        self.f_IN = [0, 8, 16, 24, 32, 40, 63]
        # ... and XORs these monomials (indices into f_IN):
        # x1 ^ x6 ^ x0*x1*x2*x3*x4*x5 ^ x0*x1*x2*x4*x6
        self.f_OUT = [[1], [6], [0, 1, 2, 3, 4, 5], [0, 1, 2, 4, 6]]

    def f(self):
        """Return the filter output over the f_IN positions of the LFSR."""
        x = [self.LFSR[i] for i in self.f_IN]
        return _sum(_prod(x[i] for i in j) for j in self.f_OUT)

    def clock(self):
        """Shift the LFSR once, then return f() of the new state."""
        # Rebinds self.LFSR to a new list, so the caller's key list is not mutated.
        self.LFSR = self.LFSR[1: ] + [_sum(self.LFSR[i] for i in self.TAP)]
        return self.f()

class zer0lfsr:
    """Wrap one of the three generators and expose a decimated bit stream.

    ``t`` selects the generator (1 -> Generator1, 2 -> Generator2, anything
    else -> Generator3) and is also the number of clocks per output bit.
    """

    def __init__(self, msk: int, t: int):
        if t == 1:
            generator_cls = Generator1
        elif t == 2:
            generator_cls = Generator2
        else:
            generator_cls = Generator3
        self.g = generator_cls(n2l(msk, 64))
        self.t = t

    def next(self):
        """Clock the generator ``t`` times and return the bit of the last clock."""
        for _ in range(self.t):
            last_bit = self.g.clock()
        return last_bit

class Task(socketserver.BaseRequestHandler):
    """Per-connection handler: a proof of work, then two key-recovery rounds.

    The client picks two distinct generator indices (1-3). For each one the
    server sends 5 x 1000 keystream bytes and the SHA-256 of the key, then
    expects the 64-bit key back. After both rounds the flag is sent.
    """
    def __init__(self, *args, **kargs):
        super().__init__(*args, **kargs)

    def proof_of_work(self):
        """Ask the client for the first 4 characters of a random 20-character string.

        Returns True if the reply, followed by the announced 16-character
        suffix, hashes to the announced SHA-256 digest; False otherwise.
        """
        # This seeds the module-level ``random`` generator, which handle() later
        # uses for random.getrandbits(64) to draw each key.
        random.seed(urandom(8))
        proof = ''.join([random.choice(string.ascii_letters + string.digits + '!#$%&*-?') for _ in range(20)])
        digest = sha256(proof.encode()).hexdigest()
        self.dosend('sha256(XXXX + {}) == {}'.format(proof[4: ], digest))
        self.dosend('Give me XXXX:')
        x = self.request.recv(10)
        x = (x.strip()).decode('utf-8')
        if len(x) != 4 or sha256((x + proof[4: ]).encode()).hexdigest() != digest:
            return False
        return True

    def dosend(self, msg):
        """Send ``msg`` encoded as latin-1 plus a newline; send errors are ignored."""
        try:
            self.request.sendall(msg.encode('latin-1') + b'\n')
        except:
            pass

    def timeout_handler(self, signum, frame):
        """SIGALRM handler: raise TimeoutError so handle() can report a timeout."""
        raise TimeoutError

    def handle(self):
        """Run the challenge protocol for one client connection."""
        try:
            signal.signal(signal.SIGALRM, self.timeout_handler)
            # 30 s to solve the proof of work ...
            signal.alarm(30)
            if not self.proof_of_work():
                self.dosend('You must pass the PoW!')
                return
            # ... then 50 s for both rounds together (a new alarm() replaces the old one).
            signal.alarm(50)
            available = [1, 2, 3]
            for _ in range(2):
                self.dosend('which one: ')
                idx = int(self.request.recv(10).strip())
                # Each generator may be chosen at most once.
                assert idx in available
                available.remove(idx)
                msk = random.getrandbits(64)
                lfsr = zer0lfsr(msk, idx)
                # 5 chunks x 1000 bytes; each byte packs 8 output bits, MSB first.
                for i in range(5):
                    keystream = ''
                    for j in range(1000):
                        b = 0
                        for k in range(8):
                            b = (b << 1) + lfsr.next()
                        keystream += chr(b)
                    self.dosend('start:::' + keystream + ':::end')
                hint = sha256(str(msk).encode()).hexdigest()
                self.dosend('hint: ' + hint)
                self.dosend('k: ')
                guess = int(self.request.recv(100).strip())
                if guess != msk:
                    self.dosend('Wrong ;(')
                    # NOTE(review): the loop is not exited here; a following recv on
                    # the closed socket presumably raises and lands in the bare except.
                    self.request.close()
                else:
                    self.dosend('Good :)')
            self.dosend(flag)
        except TimeoutError:
            self.dosend('Timeout!')
            self.request.close()
        except:
            # Any other error (e.g. a non-integer reply or the failed assert) ends the session.
            self.dosend('Wtf?')
            self.request.close()

class ThreadedServer(socketserver.ForkingMixIn, socketserver.TCPServer):
    """TCP server that handles each connection in a forked child process.

    Despite the name, ForkingMixIn forks a process per request rather than
    starting a thread.
    """

if __name__ == "__main__":
    HOST, PORT = '0.0.0.0', 31337
    # TCPServer.__init__ binds immediately, and server_bind() is what reads
    # allow_reuse_address, so setting the flag after construction would have
    # no effect. Construct the server unbound, set the flag, then bind.
    server = ThreadedServer((HOST, PORT), Task, bind_and_activate=False)
    server.allow_reuse_address = True
    try:
        server.server_bind()
        server.server_activate()
    except BaseException:
        # Mirror TCPServer.__init__: release the socket if binding fails.
        server.server_close()
        raise
    server.serve_forever()

```
### Challenge

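The server first asks for a proof of work: given the last 16 characters of a random 20-character string and the SHA-256 digest of the whole string, you have to find the missing first 4 characters (`XXXX`). Then you need to specify an index between 1 and 3. This index selects the generator used inside `zer0lfsr`. It also sets how many times the generator is clocked per output bit; only the bit from the last of those clocks is emitted. The server then sends you 5 chunks of 1000 keystream bytes (40,000 output bits in total), followed by the SHA-256 hash of the decimal key as a hint. If you guess the 64-bit key correctly, you have to choose another index that you have not used yet. The whole process repeats, and if both keys are correct, you are given the flag. Note that once the proof of work is passed, both rounds must be finished within 50 seconds.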

### Solution Script

We solve the challenge using the Z3 SMT solver. First, the script brute-forces the proof of work. It then converts the received keystream bytes into a list of bits. We use only a small prefix of the 40,000 bits as constraints for Z3 (200 seemed to be enough); using all of them would take far too long to solve. I chose Generator3 and Generator1 as targets because Z3 solved them reasonably fast. The script reuses `_prod`, `_sum`, the three generator classes and `zer0lfsr` from the task script, so those definitions have to be pasted into the solution file as well. Only `n2l(x, l)`, which converts an integer into a list of bits, had to be modified, because its format string does not work with a Z3 `BitVec`. The modified version splits the bit-vector into 1-bit `Extract` terms instead.

```python
from pwn import *
from hashlib import sha256
from itertools import product
import re
import z3
import bitarray
import random

# this was modified to be able to use z3
def n2l(x, l):
    """Convert ``x`` into a list of bits, most significant bit first.

    Plain Python ints are expanded exactly like the task script's ``n2l``
    (a list of 0/1 ints zero-padded to ``l`` digits). Any other value is
    treated as a z3 bit-vector and split into ``l`` one-bit ``z3.Extract``
    terms, so the generators can be evaluated symbolically.
    """
    if not isinstance(x, int):
        return [z3.Extract(pos, pos, x) for pos in range(l - 1, -1, -1)]
    pattern = '{{0:0{}b}}'.format(l)
    return [int(digit) for digit in pattern.format(x)]

host = args.HOST if args.HOST else '111.186.59.28'
port = int(args.PORT if args.PORT else 31337)


def start():
    """Open a connection (pwntools ``connect``) to the challenge at host:port."""
    conn = connect(host, port)
    return conn

def solve_proof_of_work(proof_of_work_line):
    """Brute-force the 4-character proof-of-work prefix announced by the server.

    Args:
        proof_of_work_line: the server line ``sha256(XXXX + <suffix>) == <hexdigest>``.

    Returns:
        The prefix ``XXXX`` such that sha256(XXXX + suffix) equals the digest.

    Raises:
        ValueError: if the line cannot be parsed or no prefix matches.
    """
    # ``string`` is otherwise only available implicitly via ``from pwn import *``.
    import string

    # Raw strings: '\(' and '\+' are invalid escape sequences in plain literals.
    suffix_match = re.search(r'sha256\(XXXX \+ (.*)\) ==', proof_of_work_line)
    digest_match = re.search(r'== (.*)', proof_of_work_line)
    if suffix_match is None or digest_match is None:
        raise ValueError('unexpected proof-of-work line: {!r}'.format(proof_of_work_line))
    # Encode the fixed suffix once instead of on every candidate.
    suffix = suffix_match.group(1).encode()
    target_digest = digest_match.group(1)

    alphabet = string.ascii_letters + string.digits + '!#$%&*-?'
    for candidate in product(alphabet, repeat=4):
        prefix = ''.join(candidate)
        if sha256(prefix.encode() + suffix).hexdigest() == target_digest:
            return prefix
    raise ValueError('no 4-character prefix matches the proof-of-work digest')

def bytes_to_bit_list(bytes):
    """Expand the given byte string into a flat list of bits via ``bitarray``.

    Bit order follows bitarray's default endianness (big-endian, i.e. the
    most significant bit of each byte first, unless that default was changed).
    """
    bits = bitarray.bitarray()
    bits.frombytes(bytes)
    as_list = bits.tolist()
    return as_list

def test_if_keystreams_matches_output_bits(keystreams, output_bits):
    """Check that ``output_bits`` packs back into exactly ``keystreams``.

    Args:
        keystreams: sequence of received keystream chunks (bytes).
        output_bits: flat list of 0/1 values (ints or bools) derived from
            those chunks.

    Returns:
        True if ``output_bits`` has exactly 8 bits per keystream byte and,
        packed most-significant bit first, reproduces every chunk; else False.
    """
    # Compare against the argument, not a global, and require an exact length
    # so truncated or padded bit lists are reported instead of crashing.
    if len(output_bits) != 8 * sum(len(chunk) for chunk in keystreams):
        return False
    rebuilt = []
    pos = 0
    for chunk in keystreams:
        packed = bytearray()
        for _ in range(len(chunk)):
            value = 0
            for bit in output_bits[pos:pos + 8]:
                value = (value << 1) | int(bit)
            packed.append(value)
            pos += 8
        rebuilt.append(bytes(packed))
    return rebuilt == [bytes(chunk) for chunk in keystreams]


# The name starts with ``test_``; keep pytest from collecting this helper as a test.
test_if_keystreams_matches_output_bits.__test__ = False

def solve_with_z3(lfsr_output_bits, generator_index, prefix_bits_to_check):
    """Recover the 64-bit key of the chosen generator from its output bits.

    A symbolic 64-bit key is run through ``zer0lfsr`` (using the z3-aware
    ``n2l``), and its first ``prefix_bits_to_check`` outputs are constrained
    to equal the observed bits.

    Args:
        lfsr_output_bits: observed output bits (0/1 ints or bools).
        generator_index: generator selected on the server (1, 2 or 3).
        prefix_bits_to_check: number of leading bits to add as constraints.

    Returns:
        The recovered key as a Python int.

    Raises:
        RuntimeError: if z3 does not report the constraints as satisfiable.
    """
    msk = z3.BitVec('msk', 64)
    lfsr = zer0lfsr(msk, generator_index)

    solver = z3.Solver()

    for i in range(prefix_bits_to_check):
        # int(): a bool bit would be coerced to a z3 Bool and cause a sort mismatch.
        solver.add(lfsr.next() == int(lfsr_output_bits[i]))

    result = solver.check()
    log.info("solver.check(): %s", result)
    if result != z3.sat:
        raise RuntimeError('z3 could not recover the key: %s' % result)

    model = solver.model()

    log.info("model: %s", model)

    # model_completion: still yield a concrete value if some key bit is unconstrained.
    return model.evaluate(msk, model_completion=True).as_long()

io = start()

# Proof of work: the first line is the challenge, the second is "Give me XXXX:".
proof_of_work_line = io.recvline(keepends=False).decode("utf-8")
io.recvline()

proof = solve_proof_of_work(proof_of_work_line)
io.sendline(proof.encode())

prefix_bits_to_check = 200
generator_indices = [3, 1]
for generator_index in generator_indices:
    io.recvline()  # which one:
    io.sendline(str(generator_index).encode())  # choose Generator[generator_index]

    keystreams = []
    for _ in range(5):
        # Each chunk is framed as b'start:::' + 1000 raw bytes + b':::end\n'.
        # (Not named ``start``: that would shadow the start() function.)
        header = io.recvn(8)
        keystream = io.recvn(1000)
        trailer = io.recvn(7)  # including the newline byte
        if header != b'start:::' or trailer != b':::end\n':
            log.error("unexpected keystream framing: %r ... %r", header, trailer)
        keystreams.append(keystream)

    hint_line = io.recvline(keepends=False).decode("utf-8")
    hint = re.search('hint: (.*)', hint_line).group(1)

    io.recvline()  # k:

    lfsr_output_bits = []
    for keystream in keystreams:
        lfsr_output_bits += bytes_to_bit_list(keystream)

    assert test_if_keystreams_matches_output_bits(keystreams, lfsr_output_bits)

    solved_msk = solve_with_z3(lfsr_output_bits, generator_index, prefix_bits_to_check)
    log.info("solved_msk: %s", solved_msk)

    hint_compare = sha256(str(solved_msk).encode()).hexdigest()
    log.info("hint_compare: %s", hint_compare)
    log.info("hint: %s", hint)
    # The server closes the connection on a wrong guess, so abort here rather
    # than send a key the hint already shows to be wrong.
    if hint_compare != hint:
        log.error("recovered key does not match the server's hint")

    io.sendline(str(solved_msk).encode())

    maybe_good = io.recvline(keepends=False)  # Good :)
    log.info("maybe_good: %s", maybe_good)
    if maybe_good != b'Good :)':
        log.error("server rejected the key: %r", maybe_good)

flag = io.recvline(keepends=False)
log.info("flag: %s", flag)
io.shutdown()
```