Processing math: 100%

Tags: crypto 

Rating:

We interact with a server running the following Ruby code (built on the Functions Framework):

require "functions_framework"
require "digest/sha2"

# Abort at load time if the flag is not configured.
fail unless ENV["FLAG"]

# Public key written by the key-generation script: the modulus n = p*q and the
# extra prime k, both stored as decimal strings.
# NOTE(review): JSON is used without an explicit `require "json"`; this relies
# on it already being loaded (presumably by functions_framework) — confirm.
key = JSON.parse(File.read("pubkey.txt"))
n = key["n"].to_i
k = key["k"].to_i

# A valid signature on exactly this message is rewarded with the flag.
EXPECTED_MESSAGE = 'SUNSHINE RHYTHM'

# HTTP handler "index". Only POST is accepted; the body is JSON with a "cmd":
#   "pubkey" -> returns n and k as decimal strings
#   "verify" -> checks x^2 + k*y^2 == H(msg) (mod n) for the supplied x, y, msg
FunctionsFramework.http("index") do |request|
  if request.request_method != "POST"
    return "Bad Request"
  end

  data = JSON.parse(request.body.read)
  cmd = data["cmd"]
  if cmd == "pubkey"
    return { pubkey: { n: n.to_s, k: k.to_s } }
  elsif cmd == "verify"
    x = data["x"].to_i
    y = data["y"].to_i
    msg = data["msg"].to_s
    # H(msg): concatenate the hex SHA-512 digests of msg+"0" .. msg+"3"
    # (4 * 512 = 2048 bits), read the result as one hex integer, reduce mod n.
    hash = ""
    4.times do |i|
      hash += Digest::SHA512.hexdigest(msg + i.to_s)
    end
    hash = hash.to_i(16) % n
    # Ong-Schnorr-Shamir verification equation: x^2 + k*y^2 (mod n).
    signature = (x ** 2 + k * y ** 2) % n


    if signature == hash
      # Any valid signature verifies; only the expected message yields the flag.
      if msg == EXPECTED_MESSAGE
        return { result: ENV["FLAG"] }
      end
      return { result: "verify success" }
    else
      return { result: "verify failed" }
    end
  else
    return "invalid command"
  end
end

The keys are generated with:

require "openssl"
require "json"

# Generate the challenge public key and store it in pubkey.txt.
#   p, q : two 1024-bit primes (OpenSSL::BN.generate_prime with default options)
#   k    : a 2048-bit prime (generated with the `safe` argument set to false)
# Only n = p * q and k are written out; the factors p and q are not saved.
prime_p, prime_q = Array.new(2) { OpenSSL::BN.generate_prime(1024) }
k = OpenSSL::BN.generate_prime(2048, false)
n = prime_p * prime_q

public_key = { n: n.to_s, k: k.to_s }
File.write("pubkey.txt", public_key.to_json)

We have primes $p, q, k$ with $n = pq$. The server implements the Ong–Schnorr–Shamir signature scheme: a signature on a message $m$ is a pair $(x, y)$ such that $x^2 + k y^2 \equiv h(m) \pmod{n}$.

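This signature scheme was broken by Pollard and Schnorr, who showed that solutions to the above congruence can be computed efficiently without knowing the factorization of $n$. In particular, we can forge a signature on the message `SUNSHINE RHYTHM`, which the server rewards with the flag. To solve the challenge, we implement the algorithm given in https://ieeexplore.ieee.org/document/1057350. We refer to the paper for a detailed explanation; the code below references the relevant equations and steps by their numbers in the paper.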

import gmpy2 as gmp
import math

# https://ieeexplore.ieee.org/document/1057350

# https://eli.thegreenplace.net/2009/03/07/computing-modular-square-roots-in-python
def modular_sqrt(a, p):
    """ Find a square root of 'a' modulo the prime p.

        Solve the congruence of the form:
            x^2 = a (mod p)
        and return x. Note that p - x is also a root.

        'a' may be any integer, including a negative one; it is reduced
        modulo p first. p must be prime. For p = 2 the unique root
        (a mod 2) is returned.

        0 is returned if no square root exists for these a and p (and also
        when a = 0 (mod p), where 0 is the only root).

        The Tonelli-Shanks algorithm is used (except for some simple cases
        in which the solution is known from an identity). This algorithm
        runs in polynomial time (unless the generalized Riemann hypothesis
        is false).
    """
    a %= p

    # Simple cases. p == 2 must be handled before legendre_symbol: Euler's
    # criterion is only meaningful for odd p (for p == 2 it reports -1 for
    # every a, which made a later p == 2 check unreachable).
    if p == 2:
        return a
    if a == 0 or legendre_symbol(a, p) != 1:
        return 0
    if p % 4 == 3:
        return pow(a, (p + 1) // 4, p)

    # Write p - 1 = s * 2^e with s odd.
    s = p - 1
    e = 0
    while s % 2 == 0:
        s //= 2
        e += 1

    # Find some z with Legendre symbol (z | p) = -1. Half of the non-zero
    # residues qualify, so this search ends quickly in practice.
    z = 2
    while legendre_symbol(z, p) != -1:
        z += 1

    # Here be dragons! See "Square roots from 1; 24, 51, 10 to Dan Shanks"
    # by Ezra Brown.
    #
    # x is the current guess of the root, b the "fudge factor" by which the
    # guess is off: the invariant x^2 = a*b (mod p) holds throughout.
    # g supplies the successive powers of z used to update x and b, and r is
    # the exponent bound, which strictly decreases each round.
    x = pow(a, (s + 1) // 2, p)
    b = pow(a, s, p)
    g = pow(z, s, p)
    r = e

    while True:
        # Least m in [0, r) with b^(2^m) = 1 (mod p).
        t = b
        m = 0
        for m in range(r):
            if t == 1:
                break
            t = pow(t, 2, p)

        if m == 0:
            return x

        gs = pow(g, 2 ** (r - m - 1), p)
        g = (gs * gs) % p
        x = (x * gs) % p
        b = (b * g) % p
        r = m


def legendre_symbol(a, p):
    """Return the Legendre symbol (a | p), computed with Euler's criterion.

    p must be an odd prime. The result is 1 when a is a non-zero quadratic
    residue modulo p, -1 when it is a non-residue, and 0 when p divides a.
    a may be negative; Python's three-argument pow() handles that.
    """
    euler = pow(a, (p - 1) // 2, p)
    # Euler's criterion yields p - 1 (i.e. -1 mod p) for non-residues.
    if euler == p - 1:
        return -1
    return euler


# Composition law for the form x^2 + k*y^2, equations (2) and (3):
#   (2) (x_1^2 + k*y_1^2) * (x_2^2 + k*y_2^2) = X^2 + k*Y^2
#   (3) X = x_1*x_2 +- k*y_1*y_2,  Y = x_1*y_2 -+ x_2*y_1
def combine(k, x1, y1, x2, y2):
    """Multiply two representations by the quadratic form x^2 + k*y^2.

    Uses the upper signs of (3): returns (X, Y) with
    X = x1*x2 + k*y1*y2 and Y = x1*y2 - x2*y1, so that
    X^2 + k*Y^2 == (x1^2 + k*y1^2) * (x2^2 + k*y2^2).
    No modular reduction is performed.
    """
    big_x = x1 * x2 + k * y1 * y2
    big_y = x1 * y2 - x2 * y1
    return big_x, big_y


# gmpy2 random-number state shared by every (including recursive) call of
# solve(); it is used to draw the random multipliers u, v when searching for
# a prime m0.
rand_state = gmp.random_state()

def solve(k, m, n, verbose=True):
    """Solve x^2 + k*y^2 = m (mod n) without factoring n (Pollard-Schnorr).

    Implements the algorithm of https://ieeexplore.ieee.org/document/1057350;
    the numbered comments refer to the paper's steps and equations.

    Args:
        k, m, n: integers (the top-level call passes gmpy2 mpz values for
            speed). n is assumed composite. m, and the intermediate values
            inverted below, must be invertible modulo n; otherwise
            gmp.invert raises ZeroDivisionError. k may be negative, because
            step 4 recurses with solve(-m_, -k, n).
        verbose: print progress information (enabled by default, matching
            the original behavior).

    Returns:
        (x, y) with x^2 + k*y^2 = m (mod n).
    """
    if verbose:
        print((k, m, n))
        print(f"k length = {k.bit_length()}")

    # 1) n is assumed to be composite.

    # 2) Replace m by m0 = m*(u^2 + k*v^2) mod n for random u, v, such that
    #    m0 is prime and -k is a square modulo m0 (with root x0).
    while True:
        u = gmp.mpz_random(rand_state, n)
        v = gmp.mpz_random(rand_state, n)
        m0 = m * (u ** 2 + k * v ** 2) % n
        # The paper uses a different square-root algorithm and does not test
        # primality; modular_sqrt needs a prime modulus.
        if gmp.is_prime(m0):
            x0 = modular_sqrt(-k, m0)
            if (x0 ** 2 + k) % m0 == 0:
                break
    # Take the root with |x0| <= m0/2, as the paper requires. A root close
    # to m0 would make m1 = (x0^2 + k) / m0 nearly as large as m0, so the
    # descent would gain nothing on its first step.
    x0 = min(x0, m0 - x0)

    # (6) Descent: m_{i+1} = (x_i^2 + k) / m_i, then choose
    #     x_{i+1} = +-x_i (mod m_{i+1}) with |x_{i+1}| <= |m_{i+1}| / 2,
    #     until m_i is small.
    sqrt_k = gmp.isqrt(abs(k))
    mi = [m0]
    xi = [x0]
    while True:
        mi.append((xi[-1] ** 2 + k) // mi[-1])
        if k > 0 and xi[-1] <= mi[-1] <= mi[-2] or k < 0 and abs(mi[-1]) <= sqrt_k:
            break
        # Reduce modulo |m|: when k < 0 the m_i may be negative, and Python's
        # x % m would then be <= 0, so min() would pick the root of *larger*
        # absolute value instead of the smaller one.
        abs_m = abs(mi[-1])
        r = xi[-1] % abs_m
        xi.append(min(r, abs_m - r))

    I = len(xi)
    if verbose:
        print(f"I = {I}")
    # Sanity check of the chain: x_i^2 + k = m_i * m_{i+1}.
    for i in range(I):
        assert xi[i] ** 2 + k == mi[i] * mi[i + 1]

    m_ = mi[-1]

    # (8) s + t*sqrt(-k) = prod_i (x_i + sqrt(-k)), hence
    #     s^2 + k*t^2 = prod_i (x_i^2 + k) = m0 * (m_1 ... m_{I-1})^2 * m_.
    s, t = xi[0], 1
    for x_i in xi[1:]:
        s, t = combine(k, s, t, x_i, 1)

    # (9) Divide by M = m_1 * ... * m_I (m_I = m_) so that
    #     U^2 + k*V^2 = m0 / m_ (mod n).
    M = math.prod(mi[1:]) % n
    Mi = gmp.invert(M, n)
    U = s * Mi % n
    V = t * Mi % n
    assert (U ** 2 + k * V ** 2) % n == m0 * gmp.invert(m_, n) % n

    # 3) Solve x^2 + k*y^2 = m_ directly when that is trivial.
    if gmp.is_square(m_ % n):
        x, y = gmp.isqrt(m_ % n), 0
    elif (m_ - k) % n == 0:  # m_ == k (mod n)
        x, y = 0, 1
    else:
        # 4) Recursion with the roles of k and m swapped. A solution of
        #    X^2 - m_*Y^2 = -k gives x = X/Y, y = 1/Y for x^2 + k*y^2 = m_.
        x_, y_ = solve(-m_, -k, n, verbose=verbose)
        y = gmp.invert(y_, n)
        x = x_ * y % n

    assert (x ** 2 + k * y ** 2 - m_) % n == 0

    # 5) Lift back. Combining with (U, V) represents m0; combining with
    #    (u, v) then represents m0^2 / m; scaling by m / m0 represents m.
    x_, y_ = combine(k, U, V, x, y)
    assert (x_ ** 2 + k * y_ ** 2 - m0) % n == 0

    x, y = combine(k, u, v, x_, y_)
    assert (x ** 2 + k * y ** 2 - m0 ** 2 * gmp.invert(m, n)) % n == 0

    m0i = gmp.invert(m0, n)
    x = x * m * m0i % n
    y = y * m * m0i % n
    assert (x ** 2 + k * y ** 2 - m) % n == 0

    return x, y


# Challenge instance: public key (n, k) and the target value z. Presumably
# z = H("SUNSHINE RHYTHM") mod n as computed by the server — TODO confirm.
n = 25299128324054183472341067223932160732879350179758036557232544635970111090474692853470743347443422497121006796606102551210094872253782062717537548880909979729182337501587763866901367212812697076494080678616385493076865655574412317879297160790121009524506015912113098690685202868184636344610142590510988192306870694667596904330867479578103616304053889409982447653859514868824002960431331342963562137691362725961627846051021103954795862501700267818317148154520620016172888281127685503677751830350686839873220480306266506898497203511851305686566444690384065880667273398255172752236076702247451872387522388546088290187449
k = 31019613858513746556266176233462864650379070310554671955689986199007361221356361736128815989480106678809272137963430923820800280374078610631771089089882153619351592434728588050285853284795554255483472955286848474793299446184220594124878818081534965835159741218233013815338595300394855159744354636541274026478456851924371621879725248093305782590590080796638483359868136648681381332610536250576568502512250581068814961097404403694071264894656697723213779631364079010490113719021172301802643377777927176399460547584115127172190000090756708138720022664973312744713394243720961199400948876916817452969615149776530401604593
z = 7647621505426523107416876116179503771882358225602233882688080500019700416074079641481433709829408187069837328548192260242213288375752541771980455357725520773932150990241200106710194070936521249397968646028251975570300856774501827288072745234025035794496061930945682174098054405440546453789891131472968763963211701605509016308957961233593474796671835082111731734173903413989101376338153157881536000538617570527656809006766854781353000424633201559848473157638877908589222891590426378306548494548233753531300670238544815960843418612332546932989237593742552831285809303013106494916607063536993429613432733764529882168669

# Solve x^2 + k*y^2 = z (mod n). The arguments are converted to gmpy2
# integers first, which makes the big-number arithmetic much faster.
x, y = solve(*(gmp.mpz(value) for value in (k, z, n)))
print(f"x = {x}")
print(f"y = {y}")