Tags: crypto 

Rating:


We interact with a server running
```rb
require "functions_framework"
require "digest/sha2"

# Refuse to start when the FLAG environment variable is not set.
fail unless ENV["FLAG"]

# Load the public key from pubkey.txt: n and k are stored as decimal strings.
key = JSON.parse(File.read("pubkey.txt"))
n = key["n"].to_i
k = key["k"].to_i

# The flag is only returned for a valid signature on exactly this message.
EXPECTED_MESSAGE = 'SUNSHINE RHYTHM'

# HTTP handler "index": accepts POST requests with a JSON body and
# dispatches on the "cmd" field ("pubkey" or "verify").
FunctionsFramework.http("index") do |request|
if request.request_method != "POST"
return "Bad Request"
end

data = JSON.parse(request.body.read)
cmd = data["cmd"]
if cmd == "pubkey"
# Hand out the public parameters as decimal strings.
return { pubkey: { n: n.to_s, k: k.to_s } }
elsif cmd == "verify"
# Client-supplied signature (x, y) and message.
x = data["x"].to_i
y = data["y"].to_i
msg = data["msg"].to_s
# Message hash: the hex SHA-512 digests of msg + "0" .. msg + "3" are
# concatenated, read as one hexadecimal integer and reduced modulo n.
hash = ""
4.times do |i|
hash += Digest::SHA512.hexdigest(msg + i.to_s)
end
hash = hash.to_i(16) % n
# Verification equation: x^2 + k * y^2 must equal the hash modulo n.
signature = (x ** 2 + k * y ** 2) % n

if signature == hash
# A valid signature on the expected message yields the flag;
# a valid signature on any other message only reports success.
if msg == EXPECTED_MESSAGE
return { result: ENV["FLAG"] }
end
return { result: "verify success" }
else
return { result: "verify failed" }
end
else
return "invalid command"
end
end
```

Keys are generated with
```rb
require "openssl"
require "json"

# Two 1024-bit primes; their product is the public modulus n.
p = OpenSSL::BN.generate_prime(1024)
q = OpenSSL::BN.generate_prime(1024)
# 2048-bit prime coefficient k.
# NOTE(review): the second argument (false) presumably turns off the
# "safe prime" requirement - confirm against the OpenSSL::BN documentation.
k = OpenSSL::BN.generate_prime(2048, false)
n = p * q
# Only n and k are written to pubkey.txt (as decimal strings); p and q are not.
File.write("pubkey.txt", { n: n.to_s, k: k.to_s }.to_json)
```

We have primes $p,q,k$ with $n=pq$.
The server implements the Ong-Schnorr-Shamir signature scheme.
A signature on a message $m$ consists of a pair $(x,y)$ such that $x^2+ky^2\equiv h(m)\pmod n$.

This scheme was shown to be insecure by Pollard and Schnorr: solutions to the above congruence can be computed efficiently without knowing the factorization of $n$.
To solve the challenge, we implement the algorithm given in https://ieeexplore.ieee.org/document/1057350.
We refer to the paper for a detailed explanation. The below code references the relevant equations.

```py
import gmpy2 as gmp
import math

# https://ieeexplore.ieee.org/document/1057350

# https://eli.thegreenplace.net/2009/03/07/computing-modular-square-roots-in-python
def modular_sqrt(a, p):
    """Return a square root of ``a`` modulo ``p``; ``p`` must be an odd prime.

    Solves the congruence ``x^2 = a (mod p)`` and returns one solution
    ``x``; ``p - x`` is the other root.

    0 is returned when no root is computed: when ``a`` is not a non-zero
    quadratic residue modulo ``p`` (this includes ``a = 0``), or when
    ``p == 2``.

    For ``p = 3 (mod 4)`` the root is obtained directly as
    ``a^((p + 1) / 4)``; otherwise the Tonelli-Shanks algorithm is used.
    """
    # Cases in which 0 is returned instead of a root.
    if legendre_symbol(a, p) != 1 or a == 0 or p == 2:
        return 0

    # Closed form for p = 3 (mod 4).
    if p % 4 == 3:
        return pow(a, (p + 1) // 4, p)

    # Split p - 1 = odd_part * 2^two_exp with odd_part odd.
    odd_part = p - 1
    two_exp = 0
    while odd_part % 2 == 0:
        odd_part //= 2
        two_exp += 1

    # Smallest integer >= 2 whose Legendre symbol modulo p is -1.
    non_residue = 2
    while legendre_symbol(non_residue, p) != -1:
        non_residue += 1

    # Tonelli-Shanks state (see "Square roots from 1; 24, 51, 10 to
    # Dan Shanks" by Ezra Brown):
    #   root  - current guess of the square root,
    #   fudge - how far off the guess is; root^2 = a * fudge (mod p)
    #           is maintained throughout the loop,
    #   gen   - successive powers of the non-residue used for updates,
    #   bound - exponent that decreases with every update.
    root = pow(a, (odd_part + 1) // 2, p)
    fudge = pow(a, odd_part, p)
    gen = pow(non_residue, odd_part, p)
    bound = two_exp

    while True:
        # Smallest m with fudge^(2^m) == 1 (mod p), capped at bound - 1.
        probe = fudge
        m = 0
        while m < bound - 1 and probe != 1:
            probe = pow(probe, 2, p)
            m += 1

        # fudge == 1: the guess is exact.
        if m == 0:
            return root

        step = pow(gen, 2 ** (bound - m - 1), p)
        gen = (step * step) % p
        root = (root * step) % p
        fudge = (fudge * gen) % p
        bound = m

def legendre_symbol(a, p):
    """Evaluate the Legendre symbol ``a|p`` with Euler's criterion.

    ``p`` is expected to be an odd prime.

    Returns 1 when ``a`` is a non-zero square modulo ``p``, -1 when it is
    a non-square, and 0 when ``p`` divides ``a``.
    """
    euler = pow(a, (p - 1) // 2, p)
    # pow() yields p - 1 as the representative of -1; map it back.
    if euler == p - 1:
        return -1
    return euler

# use (2), (3)
# (2) (x_1^2 + ky_1^2)(x_2^2 + ky_2^2) = X^2 + kY^2
# (3) X = x_1x_2 +- ky_1y_2, Y = x_1y_2 -+ x_2y_1
def combine(k, x1, y1, x2, y2):
    """Compose two representations by the form ``x^2 + k*y^2``.

    Returns ``(X, Y)`` with ``X = x1*x2 + k*y1*y2`` and
    ``Y = x1*y2 - x2*y1``, so that
    ``X^2 + k*Y^2 == (x1^2 + k*y1^2) * (x2^2 + k*y2^2)``.
    No modular reduction is applied.
    """
    composed_x = x1 * x2 + k * y1 * y2
    composed_y = x1 * y2 - x2 * y1
    return composed_x, composed_y

# Shared gmpy2 random state; solve() draws its random u and v from it.
rand_state = gmp.random_state()

def solve(k, m, n):
    """Find ``(x, y)`` with ``x^2 + k*y^2 = m (mod n)``.

    Implements the Pollard-Schnorr algorithm
    (https://ieeexplore.ieee.org/document/1057350); the step numbers
    ``1)`` .. ``5)`` and equation numbers ``(6)``, ``(8)``, ``(9)`` in the
    comments refer to that paper.

    Progress information is printed to stdout on every (recursive) call.

    Args:
        k: coefficient of ``y^2``. May be negative: the recursion calls
            ``solve(-m', -k, n)``. Must provide ``bit_length()``.
        m: target value; it is inverted modulo ``n``, so it must be
            invertible.
        n: the modulus.

    Returns:
        A pair ``(x, y)``, both reduced modulo ``n``.

    Raises:
        AssertionError: if one of the internal consistency checks fails.
    """
    print((k, m, n))
    print(f"k length = {k.bit_length()}")

    # 1) n is assumed to be composite.

    # 2) Replace m with a much smaller m'.

    # Draw random u, v until m0 = m * (u^2 + k*v^2) mod n is a prime
    # modulo which -k has a square root x0.
    while True:
        u = gmp.mpz_random(rand_state, n)
        v = gmp.mpz_random(rand_state, n)
        m0 = m * (u ** 2 + k * v ** 2) % n
        # The paper uses a different square-root algorithm and does not
        # check primality; modular_sqrt expects a prime modulus.
        if gmp.is_prime(m0):
            x0 = modular_sqrt(-k, m0)
            if (x0 ** 2 + k) % m0 == 0:
                break

    # (6) Descent: x_i^2 + k = m_i * m_{i+1}, with x_{i+1} = +-x_i reduced
    # modulo m_{i+1}.
    sqrt_k = gmp.isqrt(abs(k))
    mi = [m0]
    xi = [x0]
    while True:
        mi.append((xi[-1] ** 2 + k) // mi[-1])
        if (k > 0 and xi[-1] <= mi[-1] <= mi[-2]) or (
            k < 0 and abs(mi[-1]) <= sqrt_k
        ):
            break
        residue = xi[-1] % mi[-1]
        xi.append(min(residue, mi[-1] - residue))

    steps = len(xi)
    print(f"I = {steps}")
    # mi has exactly one more entry than xi, so zip pairs every x_i with
    # m_i and m_{i+1}.
    for x_i, m_i, m_next in zip(xi, mi, mi[1:]):
        assert x_i ** 2 + k == m_i * m_next

    m_ = mi[-1]

    # (8) Compose all (x_i, 1): s^2 + k*t^2 = prod(x_i^2 + k) (mod n).
    # Only the residues modulo n are used below, so reduce at every step
    # instead of letting s and t grow by the size of x_i per iteration.
    s = xi[0]
    t = 1
    for x_i in xi[1:]:
        s, t = combine(k, s, t, x_i, 1)
        s %= n
        t %= n

    # (9) Divide by M = m_1 * ... * m_I, so that U^2 + k*V^2 = m0 / m'.
    # The product is likewise accumulated modulo n.
    M = 1
    for m_i in mi[1:]:
        M = M * m_i % n
    Mi = gmp.invert(M, n)
    U = s * Mi % n
    V = t * Mi % n
    assert (U ** 2 + k * V ** 2) % n == m0 * gmp.invert(m_, n) % n

    # 3) Compute x, y with x^2 + k*y^2 = m' (mod n).
    if gmp.is_square(m_ % n):
        x, y = gmp.isqrt(m_ % n), 0
    elif (m_ - k) % n == 0:  # m' == k
        x, y = 0, 1
    else:  # 4) recursion
        # x_^2 - m'*y_^2 = -k  =>  (x_/y_)^2 + k*(1/y_)^2 = m'
        x_, y_ = solve(-m_, -k, n)
        y = gmp.invert(y_, n)
        x = x_ * y % n

    assert (x ** 2 + k * y ** 2 - m_) % n == 0

    # 5) Undo the substitutions.
    # (U, V) composed with (x, y) represents (m0 / m') * m' = m0.
    x_, y_ = combine(k, U, V, x, y)
    assert (x_ ** 2 + k * y_ ** 2 - m0) % n == 0

    # (u, v) represents m0 / m, so this composition represents m0^2 / m.
    x, y = combine(k, u, v, x_, y_)
    assert (x ** 2 + k * y ** 2 - m0 ** 2 * gmp.invert(m, n)) % n == 0

    # Scaling both coordinates by m / m0 turns m0^2 / m into m.
    m0i = gmp.invert(m0, n)
    x = x * m * m0i % n
    y = y * m * m0i % n
    assert (x ** 2 + k * y ** 2 - m) % n == 0

    return x, y

# Challenge parameters: modulus n and coefficient k of x^2 + k*y^2.
# NOTE(review): presumably the values returned by the server's "pubkey"
# command - confirm.
n = 25299128324054183472341067223932160732879350179758036557232544635970111090474692853470743347443422497121006796606102551210094872253782062717537548880909979729182337501587763866901367212812697076494080678616385493076865655574412317879297160790121009524506015912113098690685202868184636344610142590510988192306870694667596904330867479578103616304053889409982447653859514868824002960431331342963562137691362725961627846051021103954795862501700267818317148154520620016172888281127685503677751830350686839873220480306266506898497203511851305686566444690384065880667273398255172752236076702247451872387522388546088290187449
k = 31019613858513746556266176233462864650379070310554671955689986199007361221356361736128815989480106678809272137963430923820800280374078610631771089089882153619351592434728588050285853284795554255483472955286848474793299446184220594124878818081534965835159741218233013815338595300394855159744354636541274026478456851924371621879725248093305782590590080796638483359868136648681381332610536250576568502512250581068814961097404403694071264894656697723213779631364079010490113719021172301802643377777927176399460547584115127172190000090756708138720022664973312744713394243720961199400948876916817452969615149776530401604593
# Target value passed to solve() as m.
# NOTE(review): presumably the server-side hash of the expected message,
# already reduced modulo n - confirm.
z = 7647621505426523107416876116179503771882358225602233882688080500019700416074079641481433709829408187069837328548192260242213288375752541771980455357725520773932150990241200106710194070936521249397968646028251975570300856774501827288072745234025035794496061930945682174098054405440546453789891131472968763963211701605509016308957961233593474796671835082111731734173903413989101376338153157881536000538617570527656809006766854781353000424633201559848473157638877908589222891590426378306548494548233753531300670238544815960843418612332546932989237593742552831285809303013106494916607063536993429613432733764529882168669

# Compute (x, y) with x^2 + k*y^2 = z (mod n) and print the pair.
x, y = solve(gmp.mpz(k), gmp.mpz(z), gmp.mpz(n)) # use gmp for better performance
print(f"x = {x}\ny = {y}")
```