Tags: matrix crt root_of_unity 

Rating:

  • Undo the random power by raising to its modular inverse mod |GL(3, p)|.
  • Take a linear combination of the polynomial evaluations at the 4th roots of unity to recover the sum along the diagonal.
  • Combine the results for multiple prime moduli with CRT.

Sage solution:

# Sage script (run with `sage`): `^` is exponentiation here, not XOR.
proof.all(False)  # let Sage use faster, non-provable primality/arithmetic checks

from pwn import remote, info, process
from Crypto.Util.number import long_to_bytes

N = 3              # matrix dimension: we invert the power modulo |GL(N, p)|
NUM_PRIMES = 10    # how many prime moduli to collect before combining with CRT
MAX_ATTEMPTS = 20  # sessions tried per prime before giving up on that prime
LOCAL = False      # True: run the challenge locally instead of connecting


def _pick_prime():
    """Return (p, F, I): a random prime p < 2^64 with p = 1 (mod 4), F = GF(p), and I in F with I^2 = -1."""
    while True:
        p = random_prime(2^64)
        # -1 is a quadratic residue mod p exactly when p = 1 (mod 4), so this
        # guarantees a primitive 4th root of unity and avoids a try/except retry.
        if p % 4 == 1:
            F = GF(p)
            I = F(-1).sqrt()
            assert I * I == F(-1)
            return p, F, I


def _query_evals(p, F, I, order):
    """Run one session modulo p.

    Returns the four values the server reports for the inputs 1, -1, I, -I (as
    elements of F), or None when the server's random power is not invertible
    modulo `order`. The connection is always closed before returning.
    """
    if LOCAL:
        io = process(["python3", "-u", "susschal.py"])
    else:
        io = remote("litctf.live", "31782")
    try:
        io.sendlineafter(b"mod: ", str(p).encode())
        io.recvuntil(b" by ")
        power = int(io.recvline())
        info("%d, %d, %d", power, order, gcd(power, order))
        if gcd(power, order) != 1:
            # Can't invert it; the caller will try a fresh session.
            return None
        # Raising by this inverse undoes the server's random power.
        inverse = pow(power, -1, order)
        evals = []
        for x in [1, p - 1, I, p - I]:
            io.sendlineafter(b": ", str(inverse).encode())
            io.sendlineafter(b": ", str(x).encode())
            io.recvuntil(b" encryption is ")
            evals.append(F(int(io.recvline())))
        io.stream()  # echo whatever the server prints until it closes
        return evals
    finally:
        io.close()


residues, moduli = [], []
while len(moduli) < NUM_PRIMES:
    info(".")
    p, F, I = _pick_prime()
    if p in moduli:
        # CRT needs pairwise-coprime moduli; a repeated prime adds nothing.
        continue
    order = GL(N, p).order()

    for _ in range(MAX_ATTEMPTS):
        evals = _query_evals(p, F, I, order)
        if evals is not None:
            break
    else:
        # The server never gave an invertible power for this prime; pick another.
        continue

    # Find c with sum_j c_j * x_j^i = [i == 1] for i = 1..4. Since every x_j is a
    # 4th root of unity, x_j^k depends only on k mod 4, so combining polynomial
    # evaluations with c keeps exactly the terms whose degree is 1 (mod 4).
    points = [F(1), F(-1), I, -I]
    M = Matrix(F, [[x^i for i in range(1, 5)] for x in points])
    coeffs = M.solve_left(vector(F, [1, 0, 0, 0]))
    residues.append(int(coeffs * vector(F, evals)))
    moduli.append(p)

flag = int(crt(residues, moduli))
print(flag.bit_length())
print(long_to_bytes(flag))