Tags: crypto


We interact with a server running the following Ruby code, built on the `functions_framework` gem:

require "functions_framework"
require "digest/sha2"
require "json"

fail unless ENV["FLAG"]

# Public parameters read from pubkey.txt: the modulus n and the scheme
# parameter k (both stored as decimal strings).
key = JSON.parse(File.read("pubkey.txt"))
n = key["n"].to_i
k = key["k"].to_i


# HTTP entry point. Expects a POST whose JSON body has a "cmd" field:
#   "pubkey" -> returns n and k as decimal strings
#   "verify" -> checks x^2 + k*y^2 == H(msg) (mod n) for the supplied x, y, msg
FunctionsFramework.http("index") do |request|
  if request.request_method != "POST"
    return "Bad Request"
  end

  data = JSON.parse(request.body.read)
  cmd = data["cmd"]
  if cmd == "pubkey"
    return { pubkey: { n: n.to_s, k: k.to_s } }
  elsif cmd == "verify"
    x = data["x"].to_i
    y = data["y"].to_i
    msg = data["msg"].to_s
    # H(msg): concatenate the hex SHA-512 digests of msg+"0" .. msg+"3"
    # (4 * 512 = 2048 bits), read it as an integer and reduce it modulo n.
    hash = ""
    4.times do |i|
      hash += Digest::SHA512.hexdigest(msg + i.to_s)
    end
    hash = hash.to_i(16) % n
    signature = (x ** 2 + k * y ** 2) % n

    if signature == hash
      # NOTE(review): EXPECTED_MESSAGE is not defined in this excerpt;
      # presumably a constant from the original challenge source.
      if msg == EXPECTED_MESSAGE
        return { result: ENV["FLAG"] }
      end
      return { result: "verify success" }
    else
      return { result: "verify failed" }
    end
  else
    return "invalid command"
  end
end

Keys are generated with the following Ruby script:

require "openssl"
require "json"

# Two 1024-bit primes form the modulus n = p * q; k is an independent
# 2048-bit prime (generated with safe = false). Only n and k are written
# to pubkey.txt, as decimal strings.
p, q = Array.new(2) { OpenSSL::BN.generate_prime(1024) }
k = OpenSSL::BN.generate_prime(2048, false)
n = p * q
public_key = { n: n.to_s, k: k.to_s }
File.write("pubkey.txt", public_key.to_json)

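We have primes $p$, $q$, $k$ with $n = pq$. The server implements the Ong–Schnorr–Shamir signature scheme: a signature on a message $m$ is a pair $(x, y)$ such that $x^2 + k y^2 \equiv h(m) \pmod{n}$, where $h(m)$ is the concatenation of the SHA-512 digests of $m\|0, \dots, m\|3$ read as an integer.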

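This scheme was broken by Pollard and Schnorr, who showed that solutions to the above congruence can be computed efficiently without knowing the factorization of $n$. To solve the challenge, we implement the algorithm from J. M. Pollard and C. P. Schnorr, "An Efficient Solution of the Congruence $x^2 + ky^2 \equiv m \pmod{n}$", IEEE Transactions on Information Theory, 1987. We refer to the paper for a detailed explanation; the code below references the relevant equations.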

import gmpy2 as gmp
import math


def modular_sqrt(a, p):
    """Return a square root of ``a`` modulo the prime ``p``.

    Solves the congruence x^2 = a (mod p) and returns x; p - x is the
    other root. Returns 0 when no square root exists, and also when
    a = 0 (mod p), whose only root is 0.

    For p = 2 the root is simply a mod 2. For odd p the Tonelli-Shanks
    algorithm is used, except when p = 3 (mod 4), where the closed form
    a^((p + 1) / 4) is a root. The algorithm runs in polynomial time
    (unless the generalized Riemann hypothesis is false).
    """
    # Normalize so negative inputs (e.g. a = -k) and multiples of p are handled uniformly.
    a %= p

    # Simple cases
    if p == 2:
        return a
    if a == 0 or legendre_symbol(a, p) != 1:
        return 0
    if p % 4 == 3:
        return pow(a, (p + 1) // 4, p)

    # Write p - 1 = s * 2^e with s odd (strip all powers of 2 from p - 1).
    s = p - 1
    e = 0
    while s % 2 == 0:
        s //= 2
        e += 1

    # Find a quadratic non-residue z (Legendre symbol z|p = -1). Half of the
    # nonzero residues mod p are non-residues, so this loop is short.
    z = 2
    while legendre_symbol(z, p) != -1:
        z += 1

    # See "Square roots from 1; 24, 51, 10 to Dan Shanks" by Ezra Brown.
    # x is the current guess for the root; b is the "fudge factor", i.e. the
    # invariant x^2 = a * b (mod p) holds throughout. g holds successive
    # powers of z used to update x and b; r is the current exponent bound.
    x = pow(a, (s + 1) // 2, p)
    b = pow(a, s, p)
    g = pow(z, s, p)
    r = e

    while True:
        # Find the least m in [0, r) with b^(2^m) = 1 (mod p).
        t = b
        m = 0
        for m in range(r):
            if t == 1:
                break
            t = pow(t, 2, p)

        # b == 1 means x^2 = a (mod p): done.
        if m == 0:
            return x

        gs = pow(g, 2 ** (r - m - 1), p)
        g = (gs * gs) % p
        x = (x * gs) % p
        b = (b * g) % p
        r = m


def legendre_symbol(a, p):
    """Compute the Legendre symbol (a|p) using Euler's criterion.

    ``p`` must be an odd prime. Returns 1 if ``a`` is a nonzero quadratic
    residue modulo ``p``, -1 if it is a non-residue, and 0 if ``p`` divides
    ``a``.
    """
    ls = pow(a, (p - 1) // 2, p)
    return -1 if ls == p - 1 else ls

# Product identity from equations (2) and (3) of the paper:
#   (x_1^2 + k*y_1^2) * (x_2^2 + k*y_2^2) = X^2 + k*Y^2
# with X = x_1*x_2 +- k*y_1*y_2 and Y = x_1*y_2 -+ x_2*y_1.
# This helper always takes the upper signs (+ for X, - for Y).
def combine(k, x1, y1, x2, y2):
    """Compose two representations of the form x^2 + k*y^2.

    Given (x1, y1) and (x2, y2), return (X, Y) such that
    X^2 + k*Y^2 == (x1^2 + k*y1^2) * (x2^2 + k*y2^2) as integers.
    """
    first = x1 * x2 + k * y1 * y2
    second = x1 * y2 - x2 * y1
    return first, second

# Shared random state for choosing the blinding factors u, v in solve().
# NOTE(review): created without an explicit seed; check the gmpy2 docs for
# whether this makes runs reproducible.
rand_state = gmp.random_state()


def solve(k, m, n):
    """Solve x^2 + k*y^2 = m (mod n) with the Pollard-Schnorr algorithm.

    The factorization of n is not used. The numbered steps and equations in
    the comments refer to the Pollard-Schnorr paper.

    Args:
        k, m, n: integers; the top-level call passes gmpy2 ``mpz`` values.
            n is assumed to be composite.

    Returns:
        A pair (x, y) with x^2 + k*y^2 = m (mod n).
    """
    print((k, m, n))
    print(f"k length = {k.bit_length()}")

    # 1) assume n is composite

    # 2) Replace m with a prime m0 = m * (u^2 + k*v^2) mod n for random u, v
    #    such that -k is a square mod m0, i.e. x0^2 + k = 0 (mod m0).
    while True:
        u = gmp.mpz_random(rand_state, n)
        v = gmp.mpz_random(rand_state, n)
        m0 = m * (u ** 2 + k * v ** 2) % n
        # The paper uses a different algorithm for the square root and does
        # not require m0 to be prime; primality lets us use modular_sqrt.
        if gmp.is_prime(m0):
            x0 = modular_sqrt(-k, m0)
            # modular_sqrt returns 0 when no root exists, so verify the root.
            if (x0 ** 2 + k) % m0 == 0:
                break

    # (6) Reduction: m_{i+1} = (x_i^2 + k) / m_i and
    #     x_{i+1} = min(x_i mod m_{i+1}, m_{i+1} - x_i mod m_{i+1}),
    #     so that x_i^2 + k = m_i * m_{i+1} for every i.
    sqrt_k = gmp.isqrt(abs(k))
    mi = [m0]
    xi = [x0]
    while True:
        mi.append((xi[-1] ** 2 + k) // mi[-1])
        if (k > 0 and xi[-1] <= mi[-1] <= mi[-2]) or (k < 0 and abs(mi[-1]) <= sqrt_k):
            break
        xi.append(min(xi[-1] % mi[-1], mi[-1] - (xi[-1] % mi[-1])))

    # mi holds m_0 .. m_I and xi holds x_0 .. x_{I-1}.
    I = len(xi)
    print(f"I = {I}")
    for i in range(I):
        assert xi[i] ** 2 + k == mi[i] * mi[i + 1]

    m_ = mi[-1]

    # (8) Compose all (x_i, 1):
    #     s^2 + k*t^2 = prod_i (x_i^2 + k) = m_0 * (m_1 ... m_{I-1})^2 * m_I.
    s = xi[0]
    t = 1
    for i in range(1, I):
        s, t = combine(k, s, t, xi[i], 1)

    # (9) Divide by M = m_1 * ... * m_I, giving U^2 + k*V^2 = m_0 / m_I (mod n).
    M = math.prod(mi[1:]) % n
    Mi = gmp.invert(M, n)
    U = s * Mi % n
    V = t * Mi % n
    assert (U ** 2 + k * V ** 2) % n == m0 * gmp.invert(m_, n) % n

    # 3) Solve x^2 + k*y^2 = m_ (mod n) directly when m_ mod n is a perfect
    #    square (y = 0) or m_ = k (mod n) (x = 0, y = 1).
    if gmp.is_square(m_ % n):
        x, y = gmp.isqrt(m_ % n), 0
    elif (m_ - k) % n == 0:  # m_ == k
        x, y = 0, 1
    else:
        # 4) Recurse with the roles swapped: a solution of
        #    x_^2 - m_*y_^2 = -k gives (x_/y_)^2 + k*(1/y_)^2 = m_ (mod n).
        x_, y_ = solve(-m_, -k, n)
        y = gmp.invert(y_, n)
        x = x_ * y % n

    assert (x ** 2 + k * y ** 2 - m_) % n == 0

    # 5) Lift the solution back to the original m.
    # (U^2 + k*V^2)(x^2 + k*y^2) = (m0 / m_) * m_ = m0 (mod n)
    x_, y_ = combine(k, U, V, x, y)
    assert (x_ ** 2 + k * y_ ** 2 - m0) % n == 0

    # (u^2 + k*v^2) * m0 = (m0 / m) * m0 = m0^2 / m (mod n)
    x, y = combine(k, u, v, x_, y_)
    assert (x ** 2 + k * y ** 2 - m0 ** 2 * gmp.invert(m, n)) % n == 0

    # Scale by m / m0: (x*m/m0)^2 + k*(y*m/m0)^2 = m (mod n)
    m0i = gmp.invert(m0, n)
    x = x * m * m0i % n
    y = y * m * m0i % n
    assert (x ** 2 + k * y ** 2 - m) % n == 0

    return x, y

# Challenge instance. n and k are the public key (presumably as returned by
# the server's "pubkey" command); z is the target value, and the script finds
# x, y with x^2 + k*y^2 = z (mod n). NOTE(review): z is presumably the
# server-side hash of the message that unlocks the flag; it is not derived here.
n = 25299128324054183472341067223932160732879350179758036557232544635970111090474692853470743347443422497121006796606102551210094872253782062717537548880909979729182337501587763866901367212812697076494080678616385493076865655574412317879297160790121009524506015912113098690685202868184636344610142590510988192306870694667596904330867479578103616304053889409982447653859514868824002960431331342963562137691362725961627846051021103954795862501700267818317148154520620016172888281127685503677751830350686839873220480306266506898497203511851305686566444690384065880667273398255172752236076702247451872387522388546088290187449
k = 31019613858513746556266176233462864650379070310554671955689986199007361221356361736128815989480106678809272137963430923820800280374078610631771089089882153619351592434728588050285853284795554255483472955286848474793299446184220594124878818081534965835159741218233013815338595300394855159744354636541274026478456851924371621879725248093305782590590080796638483359868136648681381332610536250576568502512250581068814961097404403694071264894656697723213779631364079010490113719021172301802643377777927176399460547584115127172190000090756708138720022664973312744713394243720961199400948876916817452969615149776530401604593
z = 7647621505426523107416876116179503771882358225602233882688080500019700416074079641481433709829408187069837328548192260242213288375752541771980455357725520773932150990241200106710194070936521249397968646028251975570300856774501827288072745234025035794496061930945682174098054405440546453789891131472968763963211701605509016308957961233593474796671835082111731734173903413989101376338153157881536000538617570527656809006766854781353000424633201559848473157638877908589222891590426378306548494548233753531300670238544815960843418612332546932989237593742552831285809303013106494916607063536993429613432733764529882168669

if __name__ == "__main__":
    # Run the attack only when executed as a script, not on import.
    x, y = solve(gmp.mpz(k), gmp.mpz(z), gmp.mpz(n))  # use gmp for better performance
    print(f"x = {x}\ny = {y}")