Rating: 4.5

Because of how the RSA modulus `n` is generated, `n - 1` is always divisible by `s`, where `s` is a prime derived from the user input.
For the input `000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010000000000000000000` the algorithm generates the prime `s = 0x1000000000000000000000000000 + x` (hexadecimal), where `x` is a 16-bit number.
That leaves only 2^16 candidates, which is trivial to brute-force: `s` is the candidate `c` for which `n mod c = 1`.

From this prime, the program generates the two RSA factors `p` and `q`:

`p = 2 * s * a + 1`, where `a` is in `[1, s]`

`q = 2 * s * b + 1`, where `b` is in `[1, s]` and `b != a`

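Since `n = p * q` and we know `s` from the brute force, it is easy to recover `a + b` and `a * b`. The derivation below needs `a + b < s`, which holds with a good chance. Reducing modulo `2s` instead removes this restriction, because `a + b < 2s` always holds (see the comments at the end).
Once `a + b` and `a * b` are known, `a` and `b` are the roots of the quadratic `x^2 - (a + b) x + a b = 0`. Knowing them gives `p` and `q`, and therefore the private key.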

```
n     = (2*s*a + 1) * (2*s*b + 1)
      = 4*s*s*a*b + 2*s*a + 2*s*b + 1
n - 1 = 4*s*s*a*b + 2*s*a + 2*s*b                     (subtract 1 from both sides)

k := (n - 1) / (2*s) = 2*s*a*b + a + b                (divide both sides by 2*s)

k % s = a + b                                         (ONLY VALID IF a + b < s)
                                                      (k % (2*s) = a + b is valid whenever a + b < 2*s)

k - (a + b) = 2*s*a*b                                 (subtract a + b)
(k - (k % s)) / (2*s) = a * b                         (divide both sides by 2*s)
```

POC:
```
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <openssl/bn.h>
#include <openssl/err.h>
#include <openssl/opensslv.h>
#include <openssl/pem.h>
#include <openssl/rsa.h>

/* Shared small BIGNUM constants; allocated in init_BN(). */
BIGNUM *g_BN_0;
BIGNUM *g_BN_1;
/*
 * NOTE(review): g_BN_3, g_BN_5 and g_BN_7 are allocated and freed but not
 * used by any code visible here. They are presumably leftovers from the
 * decompiled binary (see the stack-offset comments in createRsaStruct).
 * Confirm before removing.
 */
BIGNUM *g_BN_3;
BIGNUM *g_BN_5;
BIGNUM *g_BN_7;
/* Context handed to the BN_* calls in this file that need one; created in init_BN(). */
BN_CTX *g_BN_ctx;

/*
 * Allocate a fresh BIGNUM, give it the value val, and store the new
 * BIGNUM through bn. The caller owns the result (free with BN_free).
 */
void create_and_set(BIGNUM **bn, uint64_t val)
{
    BIGNUM *fresh = BN_new();

    BN_set_word(fresh, val);
    *bn = fresh;
}

void init_BN()
{
create_and_set(&g_BN_0, 0LL);
create_and_set(&g_BN_1, 1LL);
create_and_set(&g_BN_3, 3LL);
create_and_set(&g_BN_5, 5LL);
create_and_set(&g_BN_7, 7LL);
g_BN_ctx = BN_CTX_new();
}

/*
 * Build a complete RSA private key from its prime factors p and q.
 *
 * The function computes:
 *   n    = p * q
 *   e    = 65537
 *   d    = e^-1 mod lcm(p-1, q-1)
 *   dmp1 = d mod (p-1)
 *   dmq1 = d mod (q-1)
 *   iqmp = q^-1 mod p
 * It prints "n: <hex>" and returns the new key. Ownership of p and q passes
 * to the returned RSA object, so RSA_free() releases them.
 *
 * If either modular inverse does not exist, it prints "FAIL" (the line the
 * Python wrapper checks for) and exits with status 0.
 */
RSA * createRsaStruct(BIGNUM * p, BIGNUM * q)
{
    RSA *rsa = RSA_new();
    BIGNUM *n = BN_new();
    BIGNUM *e = BN_new();
    BIGNUM *d = BN_new();
    BIGNUM *dmp1 = BN_new();
    BIGNUM *dmq1 = BN_new();
    BIGNUM *iqmp = BN_new();
    BIGNUM *pm1 = BN_new();
    BIGNUM *qm1 = BN_new();
    BIGNUM *gcd = BN_new();
    BIGNUM *phi = BN_new();
    BIGNUM *lambda = BN_new();
    char *hex;

    BN_mul(n, p, q, g_BN_ctx);
    BN_set_word(e, 0x10001);
    BN_sub(pm1, p, g_BN_1);
    BN_sub(qm1, q, g_BN_1);

    /* Carmichael lambda(n) = lcm(p-1, q-1) = (p-1)(q-1) / gcd(p-1, q-1). */
    BN_gcd(gcd, pm1, qm1, g_BN_ctx);
    BN_mul(phi, pm1, qm1, g_BN_ctx);
    BN_div(lambda, NULL, phi, gcd, g_BN_ctx);

    if (BN_mod_inverse(d, e, lambda, g_BN_ctx) == NULL ||
        BN_mod_inverse(iqmp, q, p, g_BN_ctx) == NULL) {
        printf("FAIL\n");
        exit(0);
    }

    /* The CRT exponents are reductions of d, so they can only be computed once d exists. */
    BN_mod(dmp1, d, pm1, g_BN_ctx);
    BN_mod(dmq1, d, qm1, g_BN_ctx);

    hex = BN_bn2hex(n);
    printf("n: %s\n", hex);
    OPENSSL_free(hex);

#if OPENSSL_VERSION_NUMBER < 0x10100000L
    /* OpenSSL 1.0.x: the RSA structure is public. */
    rsa->n = n;
    rsa->e = e;
    rsa->d = d;
    rsa->p = p;
    rsa->q = q;
    rsa->dmp1 = dmp1;
    rsa->dmq1 = dmq1;
    rsa->iqmp = iqmp;
#else
    /* OpenSSL 1.1+: RSA is opaque; the set0 calls take ownership. */
    RSA_set0_key(rsa, n, e, d);
    RSA_set0_factors(rsa, p, q);
    RSA_set0_crt_params(rsa, dmp1, dmq1, iqmp);
#endif

    BN_free(pm1);
    BN_free(qm1);
    BN_free(gcd);
    BN_free(phi);
    BN_free(lambda);
    return rsa;
}

/*
 * Integer square root of a perfect square. Adapted from BoringSSL's BN_sqrt.
 *
 * Return values:
 *   - 1 on success: out_sqrt is set to sqrt(in) (in == 0 gives 0).
 *   - 0 if in is negative, if in is not a perfect square, or if a BIGNUM
 *     operation fails.
 *
 * out_sqrt may be the same BIGNUM as in. In that case it is overwritten only
 * on success. When out_sqrt is a different BIGNUM, it may be left holding a
 * partial estimate after a failure.
 *
 * Sign handling uses BN_is_negative()/BN_set_negative() rather than the
 * ->neg field, so this also builds against OpenSSL 1.1+, where BIGNUM is
 * opaque.
 */
int BN_sqrt(BIGNUM *out_sqrt, const BIGNUM *in, BN_CTX *ctx) {
  BIGNUM *estimate, *tmp, *delta, *last_delta, *tmp2;
  int ok = 0, last_delta_valid = 0;

  if (BN_is_negative(in)) {
    return 0;
  }
  if (BN_is_zero(in)) {
    BN_zero(out_sqrt);
    return 1;
  }

  BN_CTX_start(ctx);
  if (out_sqrt == in) {
    estimate = BN_CTX_get(ctx);
  } else {
    estimate = out_sqrt;
  }
  tmp = BN_CTX_get(ctx);
  last_delta = BN_CTX_get(ctx);
  delta = BN_CTX_get(ctx);
  if (estimate == NULL || tmp == NULL || last_delta == NULL || delta == NULL) {
    goto err;
  }

  // We estimate that the square root of an n-bit number is 2^{n/2}.
  if (!BN_lshift(estimate, BN_value_one(), BN_num_bits(in)/2)) {
    goto err;
  }

  // This is Newton's method for finding a root of the equation |estimate|^2 -
  // |in| = 0.
  for (;;) {
    // |estimate| = 1/2 * (|estimate| + |in|/|estimate|)
    if (!BN_div(tmp, NULL, in, estimate, ctx) ||
        !BN_add(tmp, tmp, estimate) ||
        !BN_rshift1(estimate, tmp) ||
        // |tmp| = |estimate|^2
        !BN_sqr(tmp, estimate, ctx) ||
        // |delta| = |in| - |tmp|
        !BN_sub(delta, in, tmp)) {
      goto err;
    }

    // Take the absolute value of delta.
    BN_set_negative(delta, 0);
    // The difference between |in| and |estimate| squared is required to always
    // decrease. This ensures that the loop always terminates, but I don't have
    // a proof that it always finds the square root for a given square.
    if (last_delta_valid && BN_cmp(delta, last_delta) >= 0) {
      break;
    }

    last_delta_valid = 1;

    tmp2 = last_delta;
    last_delta = delta;
    delta = tmp2;
  }

  // Only an exact square root is accepted.
  if (BN_cmp(tmp, in) != 0) {
    goto err;
  }

  ok = 1;

err:
  if (ok && out_sqrt == in && !BN_copy(out_sqrt, estimate)) {
    ok = 0;
  }
  BN_CTX_end(ctx);
  return ok;
}

/*
 * Factor the challenge modulus n and return the matching RSA private key.
 *
 * The server builds p = 2*s*a + 1 and q = 2*s*b + 1, with a, b in [1, s] and
 * a != b. For the crafted input, s is 0x1000000000000000000000000000 plus a
 * 16-bit offset. Hence n - 1 = 2*s * (2*s*a*b + (a + b)).
 *
 *   1. Brute-force s: the first candidate above the base with n mod s == 1.
 *   2. Compute k = (n - 1) / (2*s) = 2*s*a*b + (a + b).
 *   3. Recover the sum and product:
 *        a + b = k mod (2*s)   (exact, since a, b <= s and a != b give a + b < 2*s)
 *        a * b = (k - (a + b)) / (2*s)
 *   4. Solve the quadratic: with r = sqrt((a + b)^2 - 4ab) = |a - b|,
 *        a = (a + b + r) / 2
 *        b = (a + b - r) / 2
 *      Using the full discriminant avoids the rounding error of halving an
 *      odd a + b.
 *
 * Prints "prime:", "a:" and "b:" progress lines. On any failure it prints
 * "FAIL" (the Python wrapper checks for that line) and exits with status 0.
 * n is only read. The returned key owns the newly allocated p and q.
 */
RSA *haxme(BIGNUM * n)
{
    BN_CTX *ctx = g_BN_ctx;
    BIGNUM *prime = BN_new();
    BIGNUM *two_s = BN_new();
    BIGNUM *nm1 = BN_new();
    BIGNUM *k = BN_new();
    BIGNUM *rem = BN_new();
    BIGNUM *sum = BN_new();
    BIGNUM *prod = BN_new();
    BIGNUM *scratch = BN_new();
    BIGNUM *disc = BN_new();
    BIGNUM *root = BN_new();
    BIGNUM *a = BN_new();
    BIGNUM *b = BN_new();
    BIGNUM *p = BN_new();
    BIGNUM *q = BN_new();
    char *hex;
    int found = 0;

    /* 1. s is the base plus a 16-bit offset, and n == 1 (mod s). */
    BN_hex2bn(&prime, "1000000000000000000000000000");
    for (int i = 0; i < 0x10000; i++) {
        BN_add_word(prime, 1);
        BN_mod(rem, n, prime, ctx);
        if (BN_cmp(rem, g_BN_1) == 0) {
            found = 1;
            break;
        }
    }
    if (!found) {
        goto fail;
    }

    hex = BN_bn2hex(prime);
    printf("prime: %s\n", hex);
    OPENSSL_free(hex);

    /* 2. k = (n - 1) / (2*s); n is odd and s | n - 1, so this must be exact. */
    BN_lshift1(two_s, prime);
    BN_sub(nm1, n, g_BN_1);
    BN_div(k, rem, nm1, two_s, ctx);
    if (!BN_is_zero(rem)) {
        goto fail;
    }

    /* 3. a + b = k mod 2s, a * b = (k - (a + b)) / 2s (exact by construction). */
    BN_mod(sum, k, two_s, ctx);
    BN_sub(scratch, k, sum);
    BN_div(prod, rem, scratch, two_s, ctx);

    /* 4. disc = (a + b)^2 - 4ab = (a - b)^2; its root is |a - b|. */
    BN_sqr(disc, sum, ctx);
    BN_lshift(scratch, prod, 2);
    BN_sub(disc, disc, scratch);
    if (!BN_sqrt(root, disc, ctx)) {
        goto fail;
    }

    /* sum and root have the same parity, so both halvings are exact. */
    BN_add(a, sum, root);
    BN_rshift1(a, a);
    BN_sub(b, sum, root);
    BN_rshift1(b, b);

    hex = BN_bn2hex(a);
    printf("a: %s\n", hex);
    OPENSSL_free(hex);
    hex = BN_bn2hex(b);
    printf("b: %s\n", hex);
    OPENSSL_free(hex);

    /* p = 2*s*a + 1, q = 2*s*b + 1 */
    BN_mul(p, two_s, a, ctx);
    BN_add(p, p, g_BN_1);
    BN_mul(q, two_s, b, ctx);
    BN_add(q, q, g_BN_1);

    /* Sanity check: the recovered factors must multiply back to n. */
    BN_mul(scratch, p, q, ctx);
    if (BN_cmp(scratch, n) != 0) {
        goto fail;
    }

    BN_free(prime);
    BN_free(two_s);
    BN_free(nm1);
    BN_free(k);
    BN_free(rem);
    BN_free(sum);
    BN_free(prod);
    BN_free(scratch);
    BN_free(disc);
    BN_free(root);
    BN_free(a);
    BN_free(b);
    return createRsaStruct(p, q);

fail:
    printf("FAIL\n");
    exit(0);
}

void do_BN()
{
FILE *f = fopen("pub.pem", "rb");
RSA *rsa = NULL;
PEM_read_RSAPublicKey(f, &rsa, 0, 0);

rsa = haxme(rsa->n);

FILE *fd = fopen("q.hex", "rb");
char data[1000];
int len = fread(data, 1, 1000, fd);

char dec[1000];
memset(dec, 0, 1000);
RSA_private_decrypt(len, data, dec, rsa, 1);

printf("%s\n", dec);

int a, b;
sscanf (dec,"What's the sum of %d and %d?", &a, &b);

int64_t x = a;
x += b;

printf("%ld\n", x);
}

void free_BN()
{
BN_free(g_BN_0);
BN_free(g_BN_1);
BN_free(g_BN_3);
BN_free(g_BN_5);
BN_free(g_BN_7);
}

/* Entry point: set up the shared BIGNUM state, run the attack, tear down. */
int main(void)
{
    init_BN();

    do_BN();

    free_BN();
    return 0;
}
```
A small Python wrapper that talks to the server and runs the compiled POC as `./keygen`:
```
#!/usr/bin/env python
"""Client for the 'coooppersmith' challenge.

The script:
  - sends the crafted input;
  - saves the server's RSA public key to pub.pem;
  - writes each hex-encoded ciphertext to q.hex as raw bytes;
  - runs the compiled C solver ./keygen on it.

The last line keygen prints is the answer to the question. For the flag
message, the decrypted text is the second-to-last line. Runs on Python 2.7
and Python 3.
"""

import socket, binascii, os, sys

HOST = 'coooppersmith.challenges.ooo'
PORT = 5000

s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect((HOST, PORT))


def readln(sock):
    """Return the next line received on sock, without the trailing newline.

    Raises EOFError if the peer closes the connection first. recv() signals
    EOF with an empty result (never None), so waiting for a newline without
    this check would spin forever.
    """
    data = b''
    while True:
        c = sock.recv(1)
        if not c:
            raise EOFError('connection closed before end of line')
        if c == b'\n':
            # latin-1 maps every byte to one character, so decoding cannot fail.
            return data.decode('latin-1')
        data += c


def run_keygen(hex_line):
    """Write the hex-encoded ciphertext to q.hex, run ./keygen, return its output lines."""
    # The with-block closes (and flushes) q.hex before keygen reads it.
    with open('q.hex', 'wb') as f:
        f.write(binascii.unhexlify(hex_line.strip().encode('ascii')))
    stream = os.popen('./keygen')
    try:
        return stream.read().splitlines()
    finally:
        stream.close()


s.sendall(b'000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010000000000000000000\n')

while readln(s) != '-----BEGIN RSA PUBLIC KEY-----':
    pass

rsa = '-----BEGIN RSA PUBLIC KEY-----\n'

line = readln(s)
while line != '-----END RSA PUBLIC KEY-----':
    rsa += line + '\n'
    line = readln(s)
rsa += line + '\n'

# Close (and so flush) pub.pem before keygen reads it.
with open('pub.pem', 'wb') as f:
    f.write(rsa.encode('latin-1'))

while readln(s) != 'Question: ':
    pass

lines = run_keygen(readln(s))
response = lines[-1] if lines else 'FAIL'
print(response)
if response == 'FAIL':
    sys.exit()

s.sendall((response + '\n').encode('ascii'))

while readln(s) != 'Your flag message:':
    pass

lines = run_keygen(readln(s))
print(lines[-2] if len(lines) >= 2 else 'FAIL')
```

nns2009May 21, 2020, 9:28 p.m.

(n - 1) / (2 * s) = (2 * s * a * b) + a + b / % s THIS ONLY WORKS IF a + b < s
((n - 1) / (2 * s)) % s = a + b

Why not just use:
((n - 1) / (2 * s)) % (2 * s) = a + b
?

It will work if a + b < 2 * s
which is almost always the case,
unless a = b = s (which might happen in one single case)


__loose_r__May 25, 2020, 2:45 p.m.

@nns2009: You are absolutely right. What we published is the solution we actually used during the challenge. We noticed this mistake later too, but the approach worked in around 20% of cases, which was more than enough to get the flag. :D

BTW: a + b is always less than 2s, because a and b are at most s and cannot be equal.