tags: ctf,crypto,writeup,lattice,ecdsa


Hack.lu-Qualifier-2023/Crypto/Spooky Safebox writeup

1. challenge

  • Author: newton

Satoshi lost his private key, can you help him recover his secret?

nc flu.xxx 10030

Download challenge files

app.py

#!/usr/bin/env python3

import secrets
import os, sys, hmac
import cryptod
from proofofwork import challenge_proof_of_work

FLAG = os.environ.get("FLAG", "flag{FAKE_FLAG}") if "flag" in os.environ.get("FLAG","") else "flag{FAKE_FLAG}"
 
def main():
    print("Welcome to the Spooky Safebox!")
    if not challenge_proof_of_work():
        return
    kpriv, kpub = cryptod.make_keys()
    order = cryptod.get_order()
    encrypted_flag = cryptod.encrypt(kpub, FLAG)
    print("Here is the encrypted flag:", encrypted_flag)
    print("You've got 9 signatures, try to recover Satoshi's private key!")
    for i in range(9):
        msg_ = input("Enter a message to sign: >")
        msg = hmac.new(cryptod.int_to_bytes(kpub.point.x() * i), msg_.encode(), "sha224").hexdigest()
        checksum = 2**224 + (int(hmac.new(cryptod.int_to_bytes(kpriv.secret_multiplier) , msg_.encode(), "sha224").hexdigest(), 16) % (order-2**224))
        nonce = secrets.randbelow(2 ** 224 - 1) + 1 + checksum
        sig = kpriv.sign(int(msg, 16) % order, nonce)
        print("Signature",(cryptod.int_to_bytes(int(sig.r)) + bytes.fromhex("deadbeef") + cryptod.int_to_bytes(int(sig.s))).hex())
    
    print("Goodbye!")

if __name__ == '__main__':
    try:
        main()
    except EOFError:
        pass
    except KeyboardInterrupt:
        pass

cryptod.py

import ecdsa, ecdsa.ecdsa
from cryptography.hazmat.primitives.kdf.kbkdf import (
   CounterLocation, KBKDFHMAC, Mode
)
from cryptography.hazmat.primitives import hashes
import secrets
from Crypto.Cipher import ChaCha20_Poly1305

def get_order(): return ecdsa.NIST256p.generator.order()
def encrypt_sym(input_bytes: bytes, key:bytes):

    cipher = ChaCha20_Poly1305.new(key=key)
    ciphertext, tag = cipher.encrypt_and_digest(input_bytes)
    return ciphertext + tag + cipher.nonce

def derive_symkey(inp:bytes):
    kdf = KBKDFHMAC(
        algorithm=hashes.SHA3_256(),
        mode=Mode.CounterMode,
        length=32,
        rlen=4,
        llen=4,
        location=CounterLocation.BeforeFixed,
        label=b"safu",
        context=b"funds are safu",
        fixed=None,
    )
    return kdf.derive(inp)
    
def make_keys():
    gen = ecdsa.NIST256p.generator
    secret = secrets.randbelow(gen.order()-1) + 1
    pub_key = ecdsa.ecdsa.Public_key(gen, gen * secret)
    priv_key = ecdsa.ecdsa.Private_key(pub_key, secret)
    return priv_key, pub_key

def int_to_bytes(n: int) -> bytes:
    return n.to_bytes((n.bit_length() + 7) // 8, 'big') or b'\0'

def encrypt(kpub_dest:ecdsa.ecdsa.Public_key, msg:str):
    gen = ecdsa.NIST256p.generator
    r = secrets.randbelow(gen.order()-1) + 1
    R = gen * r
    S = kpub_dest.point * r
    key = derive_symkey(int_to_bytes(int(S.x())))
    cp = encrypt_sym(msg.encode(), key).hex() 
    return cp + "deadbeef" + R.to_bytes().hex()

2. Background

2.1 ecdsa

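A quick review of ECDSA.

Parameters

  • CURVE: the elliptic curve equation
  • G: the generator
  • n: the order of G, a prime
  • $d_{A}$ : private key, random(1, n-1)
  • $Q_{A}$ : public key, $Q_{A}=d_{A} \times G$

Signing algorithm

$$ \begin{align*} & 1.\ e = \mathrm{HASH}(m) \newline & 2.\ z = \text{the leftmost } L_n \text{ bits of } e, \text{ where } L_n \text{ is the bit length of } n \newline & 3.\ k = \mathrm{random}(1, n-1) \newline & 4.\ R = (x_1, y_1) = k \times G \newline & 5.\ r = x_1 \bmod n, \text{ restart if } r = 0 \newline & 6.\ s = k^{-1}(z + r d_A) \bmod n, \text{ restart if } s = 0 \newline & 7.\ \text{return } (r, s) \end{align*} $$

In this challenge the hash is HMAC-SHA224, whose 224-bit output is shorter than n, so z = e; the rest of this writeup calls it h.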
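As a concrete reference for the seven steps above, here is a minimal pure-Python sketch of textbook ECDSA on NIST P-256, the curve the challenge uses. It is a teaching sketch under stated assumptions: affine coordinates, no constant-time arithmetic, and the nonce k passed in by the caller (exactly the input this challenge derives badly). The helper names (ec_add, ec_mul, hash_to_z, ecdsa_sign, ecdsa_verify) are hypothetical and are not the API of the ecdsa library.

```python
import hashlib
import secrets

# NIST P-256: y^2 = x^3 + a*x + b over GF(p), generator G of prime order n
p = 0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF
a = p - 3
b = 0x5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B
G = (0x6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296,
     0x4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5)
n = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551


def ec_add(P, Q):
    """Affine point addition; None is the point at infinity."""
    if P is None:
        return Q
    if Q is None:
        return P
    (x1, y1), (x2, y2) = P, Q
    if x1 == x2 and (y1 + y2) % p == 0:
        return None
    if P == Q:
        lam = (3 * x1 * x1 + a) * pow(2 * y1 % p, -1, p) % p
    else:
        lam = (y2 - y1) * pow((x2 - x1) % p, -1, p) % p
    x3 = (lam * lam - x1 - x2) % p
    return (x3, (lam * (x1 - x3) - y1) % p)


def ec_mul(k, P):
    """Double-and-add k*P (P-256 has cofactor 1, so k can be reduced mod n)."""
    R, k = None, k % n
    while k:
        if k & 1:
            R = ec_add(R, P)
        P = ec_add(P, P)
        k >>= 1
    return R


def hash_to_z(msg, hash_name="sha256"):
    """Steps 1-2: e = HASH(m), keep the leftmost L_n bits."""
    digest = hashlib.new(hash_name, msg).digest()
    excess = 8 * len(digest) - n.bit_length()
    e = int.from_bytes(digest, "big")
    return e >> excess if excess > 0 else e


def ecdsa_sign(d, z, k):
    """Steps 4-7 with a caller-supplied nonce k in [1, n-1]."""
    r = ec_mul(k, G)[0] % n
    s = pow(k, -1, n) * (z + r * d) % n
    if r == 0 or s == 0:
        raise ValueError("choose another k")
    return r, s


def ecdsa_verify(Q, z, r, s):
    if not (1 <= r < n and 1 <= s < n):
        return False
    w = pow(s, -1, n)
    X = ec_add(ec_mul(z * w, G), ec_mul(r * w, Q))
    return X is not None and X[0] % n == r


d = secrets.randbelow(n - 1) + 1
Q = ec_mul(d, G)
z = hash_to_z(b"msg", "sha224")      # 224-bit digest < 256-bit n: z == e
r, s = ecdsa_sign(d, z, secrets.randbelow(n - 1) + 1)
print(ecdsa_verify(Q, z, r, s))      # True
```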

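3. Analysis

3.1 Flag encryption

The flag is encrypted with ChaCha20_Poly1305 under a key derived by KBKDFHMAC, so we need that key to decrypt it.

The key is derived from S.x(), where S = gen * r * d = R * d (here r is the ephemeral scalar chosen inside encrypt(), not the signature r). So if we can recover the ECDSA private key d, we can recompute S and decrypt the flag.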

r = secrets.randbelow(gen.order()-1) + 1
R = gen * r
S = kpub_dest.point * r
key = derive_symkey(int_to_bytes(int(S.x())))
cp = encrypt_sym(msg.encode(), key).hex() 

3.2 Improper ECDSA nonce

If k is chosen badly, for example much smaller than n, lattice reduction can recover it. In this challenge:

checksum = 2**224 + (int(hmac.new(cryptod.int_to_bytes(kpriv.secret_multiplier) , msg_.encode(), "sha224").hexdigest(), 16) % (order-2**224))

nonce = secrets.randbelow(2 ** 224 - 1) + 1 + checksum

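The HMAC-SHA224 value is below $2^{224}$, far smaller than $n - 2^{224}$, so the reduction modulo $n - 2^{224}$ changes nothing and $2^{224} \le$ checksum $< 2^{225}$. The random part adds at most $2^{224} - 1$, so nonce $< 2^{224} + 2^{225} < 2^{226}$: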
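The bounds can also be checked empirically. This sketch mirrors the checksum and nonce lines of app.py in plain Python; `priv` is a hypothetical stand-in for int_to_bytes(d), since only its HMAC matters. It also shows what sending the same message nine times buys: one shared checksum, so the nonce differences stay below $2^{224}$.

```python
import hmac
import secrets

n = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551  # P-256 order


def make_nonce(priv_bytes, msg):
    """Mirror of app.py: returns (checksum, nonce)."""
    mac = int(hmac.new(priv_bytes, msg, "sha224").hexdigest(), 16)
    # mac < 2**224 < n - 2**224, so this reduction never changes mac
    assert mac % (n - 2**224) == mac
    checksum = 2**224 + mac % (n - 2**224)
    return checksum, secrets.randbelow(2**224 - 1) + 1 + checksum


priv = b"hypothetical private key bytes"   # placeholder for int_to_bytes(d)
samples = [make_nonce(priv, b"msg") for _ in range(9)]
checksums = {c for c, _ in samples}
nonces = [k for _, k in samples]

print(len(checksums))                                       # 1: same message, same checksum
print(max(k.bit_length() for k in nonces))                  # at most 226
print(max(abs(k - nonces[0]) for k in nonces).bit_length())  # at most 224
```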

import hmac, secrets, ecdsa
checksum = 2**224 + (int(hmac.new(b"\x00" , b"msg", "sha224").hexdigest(), 16) % (ecdsa.NIST256p.generator.order()-2**224))
nonce = secrets.randbelow(2 ** 224 - 1) + 1 + checksum
ZZ(checksum).nbits(), ZZ(nonce).nbits(), ZZ(nonce-checksum).nbits()

[out]:

(225, 226, 224)

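So we can recover the nonces with lattice reduction and then recover d. Either of two approaches works:

  1. The nonce has 226 bits, which is small compared with the 256-bit modulus n, and we have 9 signatures, so we can build a lattice and solve for the nonces directly.
  2. If the same message is sent every time, all 9 nonces share the same checksum, so we can also take differences k_i' = k_i - k_0, which cancels the common part and leaves 8 equations.

With this little data the second approach is probably under-constrained: taking differences uses up one equation but only shrinks the unknowns by about 2 bits (226 to 224), so plain LLL may miss the target vector and more lattice vectors have to be searched.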
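A rough way to quantify "under-constrained" is to compare the per-dimension size of the balanced lattice, det^(1/dim), with the largest scaled target entry. The sketch below assumes the same basis shape and column scaling as the balance() methods in section 6 (every column scaled up to the largest bound); it is a Gaussian-heuristic style rule of thumb, not a proof. The function name lattice_margin is hypothetical.

```python
n = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551  # P-256 order


def lattice_margin(m, bound_bits):
    """
    m:          number of rows n*e_i; the basis is block lower triangular, so det = n**m
    bound_bits: log2 of the expected size of each target entry (last entry is the constant 1)
    returns (bit length of the scaled determinant, dimension,
             det_bits/dim - largest bound in bits), a rough size margin
    """
    bounds = [2**e for e in bound_bits]
    top = max(bounds)
    det = n**m
    for bnd in bounds:
        det *= top // bnd          # column scaling T = diag(top // bound_j)
    dim = len(bounds)
    det_bits = det.bit_length()
    return det_bits, dim, det_bits / dim - (top.bit_length() - 1)


configs = {
    "6.1 small k": (9, [226] * 9 + [256, 0]),
    "6.2 eliminate d": (8, [226] * 9 + [0]),
    "6.3 common prefix": (8, [224] * 8 + [256, 0]),
    "6.4 common prefix, eliminate d": (7, [224] * 8 + [0]),
}
for name, (m, bits) in configs.items():
    det_bits, dim, margin = lattice_margin(m, bits)
    print(f"{name}: dim={dim}, det={det_bits} bits, margin={margin:.2f} bits")
```

With these settings the dimensions and determinant sizes agree with what the testers in section 6 print (11/2830, 10/2274, 10/2560, 9/2016). The margin is about 1.3 to 1.4 bits for 6.1 and 6.2 but essentially 0 for 6.3 and 6.4, which fits the author switching to LLL_enumeration for the difference-based variants.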

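3.3 Unknown h

The challenge does not give us the public key, so we only know h for the first signature. We need to recover the public key first, and from it every h.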

msg = hmac.new(cryptod.int_to_bytes(kpub.point.x() * i), msg_.encode(), "sha224").hexdigest()

For the first signature (i = 0) the HMAC key is int_to_bytes(0) = b'\x00', which does not depend on the public key:

hmac.new(b"\x00", b"msg", "sha224").hexdigest()

[out]:

'57be1bda660c3f024b7a6b3bbc05ab6a5424481fed33b21016bd2f93'
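A quick check of why h_0 is known in advance: with i = 0 the HMAC key is x·0 = 0, and the helper below (copied from cryptod.py) encodes 0 as a single zero byte. The two x-coordinates are arbitrary hypothetical values, and msg_hash is a hypothetical name for the h_i computation in app.py.

```python
import hmac


def int_to_bytes(n):
    # same helper as cryptod.py: 0 has bit_length 0, b"" is falsy, so 0 -> b"\x00"
    return n.to_bytes((n.bit_length() + 7) // 8, "big") or b"\0"


def msg_hash(qx, i, msg=b"msg"):
    """h_i as computed in app.py; qx is the public key's x-coordinate."""
    return int(hmac.new(int_to_bytes(qx * i), msg, "sha224").hexdigest(), 16)


qx_a, qx_b = 12345, 67890                    # two arbitrary x-coordinates
print(int_to_bytes(0))                       # b'\x00'
print(msg_hash(qx_a, 0) == msg_hash(qx_b, 0))  # True: h_0 ignores Q
print(msg_hash(qx_a, 1) == msg_hash(qx_b, 1))  # False: h_1 depends on Q
```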

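3.4 Recovering the public key from a signature

The HMAC key uses the x-coordinate of the public key Q, which is a curve point, so we move the signing equation onto the curve by multiplying both sides by G: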

$$ \begin{align*} s & \equiv k^{-1}(h+d*r) \pmod n \newline \Rightarrow sG &= k^{-1}(h+d*r)G \newline skG &= hG + Qr \newline sR &= hG + Qr \newline \Rightarrow Q &= (sR-hG)*r^{-1} \end{align*} $$

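So the first signature gives us the public key, and from it all the h values.

Note: when substituting the point R for kG, we only know R's x-coordinate r, so R has two candidates (R and -R). The second signature is used to discard the wrong one.

One could also try to recover h with a lattice, but there would be too many unknowns (twice the number of equations), and in each equation both h and k are about 224 bits, so the target would not be a shortest vector.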
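For readers without SageMath, here is a pure-Python sketch of the same recovery: Q = r^{-1}(sR - hG) for both lifts of r, then a check against the second signature. Assumptions: textbook affine P-256 arithmetic (hypothetical helpers ec_add, ec_mul, lift_x), square roots via p ≡ 3 (mod 4), and the negligible case x_1 = r + n is ignored. The two signatures are generated locally with the challenge's h_i = HMAC-SHA224(int_to_bytes(x_Q · i), "msg"), using random nonces rather than the challenge's biased ones.

```python
import hmac
import secrets

# NIST P-256 parameters
p = 0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF
a = p - 3
b = 0x5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B
G = (0x6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296,
     0x4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5)
n = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551


def ec_add(P, Q):
    if P is None:
        return Q
    if Q is None:
        return P
    (x1, y1), (x2, y2) = P, Q
    if x1 == x2 and (y1 + y2) % p == 0:
        return None                                   # P + (-P) = infinity
    if P == Q:
        lam = (3 * x1 * x1 + a) * pow(2 * y1 % p, -1, p) % p
    else:
        lam = (y2 - y1) * pow((x2 - x1) % p, -1, p) % p
    x3 = (lam * lam - x1 - x2) % p
    return (x3, (lam * (x1 - x3) - y1) % p)


def ec_mul(k, P):
    R, k = None, k % n
    while k:
        if k & 1:
            R = ec_add(R, P)
        P = ec_add(P, P)
        k >>= 1
    return R


def lift_x(x):
    """Both curve points with this x-coordinate (sqrt via p % 4 == 3)."""
    rhs = (x ** 3 + a * x + b) % p
    y = pow(rhs, (p + 1) // 4, p)
    return [(x, y), (x, (-y) % p)] if y * y % p == rhs else []


def int_to_bytes(v):
    return v.to_bytes((v.bit_length() + 7) // 8, "big") or b"\0"


def h_of(qx, i):
    """h_i as in app.py: HMAC-SHA224 keyed with int_to_bytes(x_Q * i)."""
    return int(hmac.new(int_to_bytes(qx * i), b"msg", "sha224").hexdigest(), 16) % n


def sign(d, h, k):
    r = ec_mul(k, G)[0] % n
    return r, pow(k, -1, n) * (h + r * d) % n


def verify(Q, h, r, s):
    w = pow(s, -1, n)
    X = ec_add(ec_mul(h * w, G), ec_mul(r * w, Q))
    return X is not None and X[0] % n == r


def recover_pubkey(r0, s0, r1, s1):
    """Q = r0^-1 (s0*R - h0*G) for both lifts R; keep the Q that verifies signature 1."""
    h0 = h_of(0, 0)                                   # h_0 does not depend on Q
    found = []
    for R in lift_x(r0):
        Q = ec_mul(pow(r0, -1, n), ec_add(ec_mul(s0, R), ec_mul(n - h0, G)))
        if Q is not None and verify(Q, h_of(Q[0], 1), r1, s1):
            found.append(Q)
    return found


d = secrets.randbelow(n - 1) + 1                      # locally generated test key
Q_true = ec_mul(d, G)
sig0 = sign(d, h_of(Q_true[0], 0), secrets.randbelow(n - 1) + 1)
sig1 = sign(d, h_of(Q_true[0], 1), secrets.randbelow(n - 1) + 1)
print(recover_pubkey(*sig0, *sig1) == [Q_true])       # True
```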

3.5 lattice

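The $k_i$ are the small unknowns of the system, so lattice reduction can recover the vector $K$:

$$ \begin{align*} & s_i\underbrace{k_i}_{226\text{ bits}} - h_i -\underbrace{d}_{256\text{ bits}}r_i + \underbrace{t_i}_{256\text{ bits}}n = 0 \newline & \underbrace{k_i}_{226\text{ bits}} = s_i^{-1}(h_i + dr_i) + t_in \newline \Rightarrow & \newline & K = Tn + Ad + B \end{align*} $$

Here $a_i = s_i^{-1} r_i \bmod n$, $b_i = s_i^{-1} h_i \bmod n$, and $t_i$ is whatever integer absorbs the reduction modulo $n$.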

$$ \begin{align*} & \begin{bmatrix} t_0 & t_1& \cdots & t_8 & d & 1 \end{bmatrix} \begin{bmatrix} n& \newline &n \newline &&\ddots \newline &&&n \newline a_0 & a_1 & \cdots & a_8 & 1 \newline b_0 & b_1 & \cdots & b_8 & & 1\newline \end{bmatrix} \newline = & \begin{bmatrix} k_0 & k_1 & \cdots & k_8 & d & 1 \end{bmatrix} \end{align*} $$
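The matrix identity can be checked numerically in plain Python. This sketch fabricates 9 signatures for a random key; the r_i are random stand-ins for x(k_i G) mod n (the lattice only uses the linear relation s_i k_i ≡ h_i + r_i d, so no curve arithmetic is needed), and the nonces are drawn at about 226 bits like the challenge's. It computes a_i, b_i, t_i and confirms that [t_0, ..., t_8, d, 1] times the basis gives [k_0, ..., k_8, d, 1].

```python
import secrets

n = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551  # P-256 order
N = 9

# hypothetical test data satisfying s_i k_i = h_i + r_i d (mod n)
d = secrets.randbelow(n - 1) + 1
k = [2**224 + secrets.randbelow(2**225) for _ in range(N)]   # about 226-bit nonces
r = [secrets.randbelow(n - 1) + 1 for _ in range(N)]         # stand-ins for x(k_i G) mod n
h = [secrets.randbits(224) for _ in range(N)]
s = [pow(ki, -1, n) * (hi + ri * d) % n for ki, ri, hi in zip(k, r, h)]

# k_i = a_i d + b_i + t_i n with a_i = s_i^-1 r_i, b_i = s_i^-1 h_i
A = [pow(si, -1, n) * ri % n for si, ri in zip(s, r)]
B = [pow(si, -1, n) * hi % n for si, hi in zip(s, h)]
t = [(ki - ai * d - bi) // n for ki, ai, bi in zip(k, A, B)]  # exact division

# basis rows: n*e_i for i < N, then [a_0..a_8, 1, 0] and [b_0..b_8, 0, 1]
M = [[n if j == i else 0 for j in range(N + 2)] for i in range(N)]
M.append(A + [1, 0])
M.append(B + [0, 1])

x = t + [d, 1]
target = [sum(x[i] * M[i][j] for i in range(N + 2)) for j in range(N + 2)]
print(target == k + [d, 1])   # True
```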

4. io

import hashlib  # used by solve_pow; imported explicitly instead of relying on pwn's star import
from pwn import *
from tqdm.auto import tqdm

context.log_level = 'info'

def int_to_bytes(n):
    n = int(n)
    return n.to_bytes((n.bit_length() + 7) // 8, 'big') or b'\0'

def solve_pow(challenge, prefix):
    check = lambda s, challenge, prefix: hashlib.sha256((challenge + s).encode('utf-8')).hexdigest()[:len(prefix)] == prefix
    for i in range(0, 2**32):
        if check(str(i), challenge, prefix):
            return challenge + str(i)
    return -1

class IO:
    def __init__(self):
        self.conn = remote("flu.xxx", int(10030))
        self.r = []
        self.s = []
        self.cp = b''
        self.R = b''

    def proof_of_work(self):
        self.conn.recvuntil(b"Please provide a string that starts with ")
        challenge = self.conn.recvuntil(b" ").strip().decode()
        self.conn.recvuntil(b"and whose sha256 hash starts with ")
        prefix = self.conn.recvuntil(b"\n").strip().decode()
        answer = solve_pow(challenge, prefix)
        self.conn.sendlineafter(b"POW: >", answer.encode())

    def get_ciphertext(self):
        self.conn.recvuntil(b"Here is the encrypted flag: ")
        data = self.conn.recvuntil(b"\n").strip()
        cp, R = data.split(b"deadbeef")
        cp = bytes.fromhex(cp.decode())
        R = bytes.fromhex(R.decode())
        return cp, R

    def sign(self, msg):
        self.conn.sendlineafter(b"Enter a message to sign: >", msg)
        self.conn.recvuntil(b"Signature ")
        data = self.conn.recvuntil(b"\n").strip()
        r, s = data.split(b"deadbeef")
        r = int(r, 16)
        s = int(s, 16)
        return r, s

    def io(self):
        self.proof_of_work()
        self.cp, self.R = self.get_ciphertext()
        # always sign the same message: the checksum depends only on (d, msg_),
        # so all 9 nonces share it
        msg = b"msg"
        sign_data = [self.sign(msg) for _ in range(9)]
        self.conn.close()
        
        self.r = [l[0] for l in sign_data]
        self.s = [l[1] for l in sign_data]
        return self.r, self.s
    
io = IO()
io.io()
io.r, io.s, io.cp, io.R

[out]:

[x] Opening connection to flu.xxx on port 10030
[x] Opening connection to flu.xxx on port 10030: Trying 31.22.123.45
[+] Opening connection to flu.xxx on port 10030: Done
[*] Closed connection to flu.xxx port 10030

[out]:

([66776018195218013399494446979250504724682261554077795620228956487173130982407,
  114393210838116895842488336835856891681912934409026519271680553193022766720235,
  21232334096857148309057278755590109742418871193384517097597879490134339581745,
  107877821140492489493798722326358698274250384685323320766728100562366105427588,
  102495212179002103657752583001260755991244243197103500816724627193733443605955,
  85529476778740201234867988221848339745257036877970448889833080401422187732083,
  79792748475804441023545822528321341706813486183101139174592893432510198808374,
  17792792240509124689782055028944235134269166148366307092647824644914942731383,
  68516731927048496453209927566395053817976932541180090504689051332338679892015],
 [80536611067786802008272168014557974590911628552718596016081812230348652863632,
  107682597726684409302528106095784696134018121944700855780130868557886875958118,
  70648019642365910447275961156153902076938371904919048635455544712005137902090,
  92874539716167498505378614601497895622085212224923265723001224217549212722567,
  78852469261728790093495397136011535824650766075320148994737573509066481895401,
  101977509697759437790189866129685168866896774533412502924649800226264038571921,
  96370691258700369068290864081138490037131368543234717431817771468666832596059,
  74377478619125401547888185098971258007989519955304529036163727254502765743702,
  61306501505091400839917378031470378622213308722696893596402188037651109181275],
 b'\xaf\xddP\xdc\xabt\xb9j=\x88j\xdbZU\xfcV\x12\x07\xb9\xa3\x05\xfc\x85\xf4\x96\xea\xe0\x97\xb0M\x8c\xfb\xb6\x7fJ\x17\xa1\xaf\xba\x87\xd7\x99\x97\xefD\xc7\xd8Y\xb7\x02p:\xd2i\xb10\x9e\x02\xf7\x1ec\xd7C3\x90\xf0\xa8\x1d\x0cp',
 b'P\t\xd9\x85\x10\xe54\x17\xd7\xc8\xadVW\xffG\xe4b\xdf\xfc\t\xecB\x15\x8ejU\xc1)\x92\xb5aP\xa4o\x0b\x89\t^\xe2\xf8L \x8e\x91\xd1F\x87\x06[A!|_8u\xf5/!K\x0b\x91\x9dwQ')

5. Recovering the ECDSA public key

Convert the curve from the ecdsa library into a SageMath object.

import ecdsa

def ecdsa2sage(curve):
    """
    convert an ecdsa curve to sagemath curve
    e.g.: ecdsa2sage(ecdsa.NIST256p)
    """
    a, b, p = curve.curve.a(), curve.curve.b(), curve.curve.p()
    E = EllipticCurve(GF(p),[a,b])
    Gx, Gy = curve.generator.x(), curve.generator.y()
    G = E(Gx, Gy)
    return E, G

E, G = ecdsa2sage(ecdsa.NIST256p)

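Recover the public key from the first signature. Because the y-coordinate of R is unknown, the second signature is needed to rule out the wrong candidate.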

class RecoveryEcdsaPublicKey:
    """
    recover an ecdsa public key from a signature and its hash
    E: elliptic curve
    G: curve generator
    usage: Q = RecoveryEcdsaPublicKey(E, G).recovery_and_verify(h0, r0, s0, h1, r1, s1)
    """
    def __init__(self, E, G):
        self.E = E
        self.G = G
        self.n = E.order()
        self.Q = []

    def recovery(self, h:int, r:int, s:int) -> list:
        """
        return: public key
                there will be at most 2 public keys;
                if the signature is wrong, an empty list will be returned
        """
        h, r, s = map(ZZ, (h, r, s))
        # r is only the x-coordinate of R = kG, so both lifts +R and -R are candidates
        Rs = self.E.lift_x(ZZ(r), all=True)
        self.Q = [inverse_mod(r, self.n)*(s*R-h*self.G) for R in Rs]
        return self.Q

    def verify(self, Q, h:int, r:int, s:int) -> bool:
        """
        r, s: a different group of signature to verify public key
        return: whether the Q is a correct public key
        """
        h, r, s = map(ZZ, (h, r, s))
        u1 = h*inverse_mod(s, self.n)
        u2 = r*inverse_mod(s, self.n)
        R = u1 * self.G + u2 * Q
        # compare x(R) mod n with r, as in standard ECDSA verification
        return ZZ(R[0]) % self.n == r
    
    def recovery_and_verify(self, h0:int, r0:int, s0:int, h1:int=None, r1:int=None, s1:int=None) -> list:
        """
        h0, r0, s0: first signature, used to compute the candidate public keys
        h1, r1, s1: second signature, used to keep the candidate that verifies it
        return: at most 1 public key if h1, r1, s1 are provided
        """
        self.recovery(h0, r0, s0)
        if None not in [h1,r1,s1]:
            self.Q = [Q for Q in self.Q if self.verify(Q, h1, r1, s1)]
        return self.Q
    
    @classmethod
    def example(cls):
        import ecdsa, secrets, hashlib
        
        # sign
        _private_key = ecdsa.SigningKey.generate(curve=ecdsa.NIST256p)
        _public_key = _private_key.get_verifying_key()
        _h = int(hashlib.sha256(b'msg').hexdigest(), 16)
        _r0, _s0 = _private_key.sign_number(_h)
        _r1, _s1 = _private_key.sign_number(_h)

        # recovery
        _E = EllipticCurve(GF(ecdsa.NIST256p.curve.p()), [ecdsa.NIST256p.curve.a(), ecdsa.NIST256p.curve.b()])
        _G = _E(ecdsa.NIST256p.generator.x(), ecdsa.NIST256p.generator.y())
        _R = RecoveryEcdsaPublicKey(_E, _G)
        _Q = _R.recovery_and_verify(_h, _r0, _s0, _h, _r1, _s1)
        
        assert _Q[0][0] == ZZ(_public_key.pubkey.point.x()) and _Q[0][1] == ZZ(_public_key.pubkey.point.y())
        
RecoveryEcdsaPublicKey.example()

6. lattice

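There are several ways to solve this. Common methods are provided by the base class BaseSolver; the lattice-specific parts are implemented in each subclass.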

import hmac
import secrets

_hash = lambda k: int(hmac.new(k, b'msg', "sha224").hexdigest(), 16)

class BaseSolver():
    """
    recovery: public key, h
    generate symbols: d, k
    """
    def __init__(self, r:list, s:list, bits:int=226, N:int=9):
        """
        r, s: signature list
        bits: bound of k
        N:    the number of groups
        """
        self.r = [ZZ(i) for i in r]
        self.s = [ZZ(i) for i in s]
        self.bits = bits
        self.N = N
        
        self.n = E.order()
        
        self.h = []
        self.Q = None
        
        self.PR = None
        self.d = None
        self.k = None
        
        self.eq = []

    def recovery_pub(self):
        R = RecoveryEcdsaPublicKey(E, G)
        h0 = _hash(b'\x00')
        Qs = R.recovery(h0, self.r[0], self.s[0])
        self.Q = [Q for Q in Qs if R.verify(Q, _hash(int_to_bytes(Q[0])), self.r[1], self.s[1])][0]

    def recovery_h(self):
        self.h = [_hash(int_to_bytes(ZZ(self.Q[0])*i)) for i in range(self.N)]

    def gen_syms(self):
        syms = ["d"] + [f"k{i}" for i in range(self.N)]
        self.PR = PolynomialRing(GF(self.n), syms)
        self.d, *(self.k) = self.PR.gens()

    def gen_eq(self):
        """
        s*k - (h+d*r) = 0
        """
        self.eq = [self.s[i]*self.k[i] - (self.h[i]+self.d*self.r[i])  for i in range(self.N)]
        
    def base(self):
        self.recovery_pub()
        self.recovery_h()
        self.gen_syms()
        self.gen_eq()
        
class BaseTester:
    def __init__(self, bits:int=226, N:int=9):
        self.bits = bits
        self.N = N
        
        self.private_key = ecdsa.SigningKey.generate(curve=ecdsa.NIST256p)
        self.public_key = self.private_key.get_verifying_key()
        self.d = self.private_key.privkey.secret_multiplier
        self.signs()
        self.solver = BaseSolver(self.r, self.s, bits, N)

    def sign(self, i):
        h = int(hmac.new(int_to_bytes(self.public_key.pubkey.point.x() * i), b'msg', "sha224").hexdigest(), 16)
        k = secrets.randbelow(2 ** 224 - 1) + 1 + 2**224 + int(hmac.new(int_to_bytes(self.private_key.privkey.secret_multiplier) , b"msg", "sha224").hexdigest(), 16) % (ecdsa.NIST256p.generator.order()-2**224)
        r, s = self.private_key.sign_number(h, k=k)
        return r, s, h, k
    
    def signs(self):
        data = [self.sign(i) for i in range(self.N)]
        self.r, self.s, self.h, self.k = map(lambda i: [l[i] for l in data], range(4))

    def test_recovery_pub(self):
        self.solver.recovery_pub()
        assert self.solver.Q[0] == ZZ(self.public_key.pubkey.point.x())
        assert self.solver.Q[1] == ZZ(self.public_key.pubkey.point.y())

    def test_recovery_h(self):
        self.solver.recovery_h()
        assert self.solver.h == self.h
    
    def test_gen_syms(self):
        self.solver.gen_syms()
        assert str(self.solver.d) == "d"
        assert str(self.solver.k) == "[" + ", ".join([f"k{i}" for i in range(self.N)]) + "]"

    def test_gen_eq(self):
        self.solver.gen_eq()
        assert len(self.solver.eq) == self.N
        # k0, d, 1
        monomials = self.solver.eq[0].monomials()
        assert self.solver.d in monomials and self.solver.k[0] in monomials and self.solver.k[1] not in monomials
        for eq in self.solver.eq:
            assert eq(self.d, *(self.k)) == 0
        
    def test(self):
        self.test_recovery_pub()
        self.test_recovery_h()
        self.test_gen_syms()
        self.test_gen_eq()

BaseTester().test()

6.1 Small k

$$ K = Tn + Ad + B $$

$$ \begin{align*} & \begin{bmatrix} t_0 & t_1& \cdots & t_8 & d & 1 \end{bmatrix} \begin{bmatrix} n& \newline &n \newline &&\ddots \newline &&&n \newline a_0 & a_1 & \cdots & a_8 & 1 \newline b_0 & b_1 & \cdots & b_8 & & 1\newline \end{bmatrix} \newline = & \begin{bmatrix} k_0 & k_1 & \cdots & k_8 & d & 1 \end{bmatrix} \end{align*} $$

class Solver1_small_k(BaseSolver):
    """
    small k; with d in the lattice
    """
    def __init__(self, r, s, bits:int=226, N:int=9):
        super().__init__(r, s, bits, N)
        self.base()
        self.eq_resultant = []
        self.L = None
        self.bound = []
        self.T = None
        # self.LLL = self.LLL_enumeration
    
    def resultant(self):
        """
        represent ki by d
        """
        I = Ideal(self.eq)
        t = TermOrder('wdegrevlex', tuple([1] + list(range(self.N+1, 1, -1))))
        I = I.change_ring(self.PR.change_ring(order=t))
        b = I.groebner_basis()
        self.eq_resultant = [eq.monomials()[0] - eq for eq in b]
        
    def lattice(self):
        A = matrix([ZZ(i.coefficients()[0]) for i in self.eq_resultant])
        B = matrix([ZZ(i.constant_coefficient()) for i in self.eq_resultant])
        self.L = block_matrix([
            [ZZ(self.n), 0], 
            [A.stack(B), 1]
        ])
        
    def balance(self):
        self.bound = [2**self.bits]*len(self.eq_resultant) + [2**256, 1]
        self.T = diagonal_matrix([max(self.bound) // i for i in self.bound])
    
    def check(self, v):
        if abs(v[-1]) == 1:
            v*=v[-1]
            secret = v[-2] % self.n
            if self.Q == G*secret:
                print("d = ",secret)
                return secret
                
    def LLL(self):
        B = (self.L*self.T).LLL()/self.T
        for v in B:
            secret = self.check(v)
            if secret is not None:
                return secret      

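    def LLL_enumeration(self):
        """
        extend the search scope: enumerate many short lattice vectors
        instead of checking only the rows of the LLL-reduced basis
        """
        from tqdm.auto import tqdm

        def tqdm_storage(t):
            # keep the finished progress bar displayed in Jupyter
            if hasattr(t, 'container') and t.leave:
                display(t.container)
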
        from fpylll import IntegerMatrix, LLL
        from fpylll.fplll.gso import MatGSO
        from fpylll.fplll.enumeration import Enumeration, EvaluatorStrategy
        
        A = IntegerMatrix.from_matrix(self.L*self.T)
        LLL.reduction(A)
        M = MatGSO(A)
        M.update_gso()
        size = self.L.nrows()
        sol_cnt = 10000
        enum = Enumeration(M, sol_cnt, strategy=EvaluatorStrategy.BEST_N_SOLUTIONS)
        answers = enum.enumerate(0, size, (size * max(self.bound)**2), 0, pruning=None)
        
        with tqdm(answers) as t:
            d = None
            for _, a in t:
                v = IntegerMatrix.from_iterable(1, size, map(int, a))
                B = v*A
                B = matrix(B)[0]/self.T
                secret = self.check(B)
                if secret is not None:
                    d = secret
                    break
            tqdm_storage(t)
        return d
            
    def solve(self):
        self.resultant()
        self.lattice()
        self.balance()
        return self.LLL()
        
class Tester1(BaseTester):
    def __init__(self, bits:int=226, N:int=9):
        super().__init__(bits, N)
        self.solver = Solver1_small_k(self.r, self.s, bits, N)
    
    def test_resultant(self):
        self.solver.resultant()
        assert len(self.solver.eq_resultant) == self.N
        assert self.solver.eq_resultant[0].monomials() == [self.solver.d, 1]
        for i in range(len(self.solver.eq_resultant)):
            assert self.solver.eq_resultant[i](d=self.d) == ZZ(self.k[i])
        
    def test_lattice(self):
        self.solver.lattice()
        
    def test_balance(self):
        self.solver.balance()
        assert min(self.solver.bound) > 0
        L = self.solver.L * self.solver.T
        print("lattice rows:", L.nrows())
        print("lattice determinant:", L.determinant().nbits())

    def test_LLL(self):
        d = self.solver.LLL()
        assert d == self.d
    
    def test(self):
        self.test_resultant()
        self.test_lattice()
        self.test_balance()
        self.test_LLL()
        
Tester1().test()

[out]:

lattice rows: 11
lattice determinant: 2830
d =  52720939586054020232861055204615741160315315569381962981717239861703275582615

Solving the remote data:

d = Solver1_small_k(io.r, io.s).solve()

[out]:

d =  3853534948223110047007972915869293470762119648642476302498539377337395966720

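6.2 Eliminating d

Every equation contains d, so d can be eliminated without affecting the result. Moreover, d is relatively large (256 bits), and keeping it in the target vector may (this is not verified) hurt the lattice reduction.

$$ \begin{align*} k_i &= t_in + a_id + b_i, \quad i=0,\dots,8 \newline & \text{eliminate } d \newline \Rightarrow & \newline k_{i} &= t_i^{'}n + a_i^{'}k_8 + b_i^{'}, \quad i=0,\dots,7 \newline \end{align*} $$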

$$ \begin{align*} & \begin{bmatrix} t_0 & t_1& \cdots & t_7 & k_8 & 1 \end{bmatrix} \begin{bmatrix} n& \newline &n \newline &&\ddots \newline &&&n \newline a_0 & a_1 & \cdots & a_7 & 1 \newline b_0 & b_1 & \cdots & b_7 & & 1\newline \end{bmatrix} \newline = & \begin{bmatrix} k_0 & k_1 & \cdots & k_7 & k_8 & 1 \end{bmatrix} \end{align*} $$
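Because every relation is linear in d, the elimination has a simple closed form: from $k_8 \equiv a_8 d + b_8$ we get $a_i' = a_i a_8^{-1}$ and $b_i' = b_i - a_i' b_8 \pmod n$. The resultant plus Gröbner-basis step in the solver below should end up with this same linear relation; the sketch only checks the algebra on random data (the key and the a_i are hypothetical random values) and shows that d follows from k_8.

```python
import secrets

n = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551  # P-256 order
N = 9

# hypothetical relations k_i = a_i d + b_i (mod n)
d = secrets.randbelow(n - 1) + 1
k = [2**224 + secrets.randbelow(2**225) for _ in range(N)]
A = [secrets.randbelow(n - 1) + 1 for _ in range(N)]
B = [(ki - ai * d) % n for ki, ai in zip(k, A)]

# eliminate d with the last relation: d = a_8^-1 (k_8 - b_8)
inv_a8 = pow(A[-1], -1, n)
A1 = [ai * inv_a8 % n for ai in A[:-1]]                          # a_i' = a_i / a_8
B1 = [(bi - ai1 * B[-1]) % n for bi, ai1 in zip(B[:-1], A1)]     # b_i' = b_i - a_i' b_8

ok = all(ki % n == (ai1 * k[-1] + bi1) % n for ki, ai1, bi1 in zip(k, A1, B1))
d_back = (k[-1] - B[-1]) * inv_a8 % n     # once the lattice yields k_8, d follows
print(ok, d_back == d)                    # True True
```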

class Solver2_small_k_resultant_d(Solver1_small_k):
    """
    small k; eliminate d
    """
    def __init__(self, r, s, bits:int=226, N:int=9):
        super().__init__(r, s, bits, N)

    def resultant(self):
        """
        eliminate d
        """
        eq = [self.eq[-1].sylvester_matrix(self.eq[i], self.d).determinant() for i in range(len(self.eq)-1)]
        I = Ideal(eq).groebner_basis()
        self.eq_resultant = [i.monomials()[0]-i for i in I]
    
    def balance(self):
        self.bound = [2**self.bits]*(len(self.eq_resultant)+1) + [1]
        self.T = diagonal_matrix([max(self.bound) // i for i in self.bound])

    def check(self, v):
        if abs(v[-1]) == 1:
            v*=v[-1]
            secret = ZZ(self.eq[0](k0=ZZ(v[0])).univariate_polynomial().roots()[0][0])
            if self.Q == G*secret:
                print("d =",secret)
                return secret

class Tester2(Tester1):
    def __init__(self, bits:int=226, N:int=9):
        super().__init__(bits, N)
        self.solver = Solver2_small_k_resultant_d(self.r, self.s, bits, N)   
        
    def test_resultant(self):
        self.solver.resultant()
        assert len(self.solver.eq_resultant) == self.N - 1
        assert self.solver.eq_resultant[0].monomials() == [self.solver.k[-1], 1]
        for eq in self.solver.eq:
            assert eq(self.d, *(self.k)) == 0

Tester2().test()

[out]:

lattice rows: 10
lattice determinant: 2274
d = 96936376497036088647218081962720612633610294433642650775741882796691505321908

Solving the remote data:

d = Solver2_small_k_resultant_d(io.r, io.s).solve()

[out]:

d = 3853534948223110047007972915869293470762119648642476302498539377337395966720

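6.3 k with an unknown common prefix

When the nonces share a common prefix, we can sacrifice one equation and use the differences between the nonces as the new k, which turns the problem back into the small-k case.

$$ \begin{align*} k_i &= t_in + a_id + b_i, \quad i=0,\dots,8 \newline & \text{sacrifice the last equation to cancel the prefix} \newline \Rightarrow & \newline k^{'}_{i} &= k_i-k_8 \newline & = t_i^{'}n + a_i^{'}d + b_i^{'}, \quad i=0,\dots,7 \newline \end{align*} $$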

$$ \begin{align*} & \begin{bmatrix} t_0 & t_1& \cdots & t_7 & d & 1 \end{bmatrix} \begin{bmatrix} n& \newline &n \newline &&\ddots \newline &&&n \newline a_0 & a_1 & \cdots & a_7 & 1 \newline b_0 & b_1 & \cdots & b_7 & & 1\newline \end{bmatrix} \newline = & \begin{bmatrix} k_0^{'} & k_1^{'} & \cdots & k_7^{'} & d & 1 \end{bmatrix} \end{align*} $$
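The cancellation can be checked the same way: with nonces k_i = checksum + u_i, subtracting the last relation removes the checksum, leaves signed values below $2^{224}$, and keeps a relation of the same form in d. Random data only; the key, checksum, A and B are hypothetical.

```python
import secrets

n = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551  # P-256 order
N = 9

# hypothetical data: every nonce is checksum + u_i with a shared, unknown checksum
d = secrets.randbelow(n - 1) + 1
checksum = 2**224 + secrets.randbits(224)
k = [checksum + secrets.randbelow(2**224 - 1) + 1 for _ in range(N)]
A = [secrets.randbelow(n - 1) + 1 for _ in range(N)]
B = [(ki - ai * d) % n for ki, ai in zip(k, A)]           # k_i = a_i d + b_i (mod n)

# subtract the last relation: the checksum cancels
k_diff = [ki - k[-1] for ki in k[:-1]]                    # signed, |k_i'| < 2**224
A_diff = [(ai - A[-1]) % n for ai in A[:-1]]
B_diff = [(bi - B[-1]) % n for bi in B[:-1]]

print(all(abs(x) < 2**224 for x in k_diff))               # True
print(all(x % n == (ai * d + bi) % n
          for x, ai, bi in zip(k_diff, A_diff, B_diff)))   # True
```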

class Solver3_k_with_unknown_prefix(Solver1_small_k):
    """
    k with unknown prefix
    """
    def __init__(self, r, s, bits:int=224, N:int=9):
        super().__init__(r, s, bits, N)
        self.LLL = self.LLL_enumeration
        
    def resultant(self):
        """
        represent ki by d;
        ki -= k8
        """
        super().resultant()
        self.eq_resultant = [eq - self.eq_resultant[-1] for eq in self.eq_resultant[:-1]]

class Tester3(Tester1):
    def __init__(self, bits:int=224, N:int=9):
        super().__init__(bits, N)
        self.solver = Solver3_k_with_unknown_prefix(self.r, self.s, bits, N)
        
    def test_resultant(self):
        self.solver.resultant()
        assert len(self.solver.eq_resultant) == self.N - 1
        assert self.solver.eq_resultant[0].monomials() == [self.solver.d, 1]
        for eq in self.solver.eq:
            assert eq(self.d, *(self.k)) == 0

Tester3().test()

[out]:

lattice rows: 10
lattice determinant: 2560

[out]:

d =  9082122349319352506730782077170392789105758876101622123810653216888929947915

[out]:

  0%|          | 43/10000 [00:00<00:47, 208.02it/s]

Solving the remote data:

d = Solver3_k_with_unknown_prefix(io.r, io.s).solve()

[out]:

d =  3853534948223110047007972915869293470762119648642476302498539377337395966720

[out]:

  0%|          | 9/10000 [00:00<01:06, 150.53it/s]

6.4 k with an unknown common prefix, eliminating d

This may help when the prefix is large and unknown.

$$ \begin{align*} k_i &= t_in + a_id + b_i, \quad i=0,\dots,8 \newline & \text{sacrifice the last equation to cancel the prefix} \newline \Rightarrow & \newline k^{'}_{i} &= k_i-k_8 \newline & = t_i^{'}n + a_i^{'}d + b_i^{'}, \quad i=0,\dots,7 \newline & \text{sacrifice one more equation to eliminate } d \newline \Rightarrow & \newline k^{'}_{i} &= t_i^{''}n + a_i^{''}k^{'}_7 + b_i^{''}, \quad i=0,\dots,6 \newline \end{align*} $$

$$ \begin{align*} & \begin{bmatrix} t_0 & t_1& \cdots & t_6 & k_7^{'} & 1 \end{bmatrix} \begin{bmatrix} n& \newline &n \newline &&\ddots \newline &&&n \newline a_0 & a_1 & \cdots & a_6 & 1 \newline b_0 & b_1 & \cdots & b_6 & & 1\newline \end{bmatrix} \newline = & \begin{bmatrix} k_0^{'} & k_1^{'} & \cdots & k_6^{'} & k_7^{'} & 1 \end{bmatrix} \end{align*} $$
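Chaining both steps in plain Python (random hypothetical data again): first subtract the last relation to cancel the checksum, then use the relation for k'_7 to eliminate d. Seven relations expressing k'_0 ... k'_6 through k'_7 remain, and d follows from k'_7 once the lattice has found it.

```python
import secrets

n = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551  # P-256 order
N = 9

d = secrets.randbelow(n - 1) + 1
checksum = 2**224 + secrets.randbits(224)
k = [checksum + secrets.randbelow(2**224 - 1) + 1 for _ in range(N)]
A = [secrets.randbelow(n - 1) + 1 for _ in range(N)]
B = [(ki - ai * d) % n for ki, ai in zip(k, A)]           # k_i = a_i d + b_i (mod n)

# step 1: k_i' = k_i - k_8 cancels the checksum (8 relations left)
k1 = [ki - k[-1] for ki in k[:-1]]
A1 = [(ai - A[-1]) % n for ai in A[:-1]]
B1 = [(bi - B[-1]) % n for bi in B[:-1]]

# step 2: eliminate d with the relation for k_7' (7 relations left);
# assumes a_7' != 0 mod n, which holds except with negligible probability
inv = pow(A1[-1], -1, n)
A2 = [ai * inv % n for ai in A1[:-1]]
B2 = [(bi - ai2 * B1[-1]) % n for bi, ai2 in zip(B1[:-1], A2)]

ok = all(x % n == (ai2 * k1[-1] + bi2) % n for x, ai2, bi2 in zip(k1, A2, B2))
d_back = (k1[-1] - B1[-1]) * inv % n      # d from k_7' once the lattice finds it
print(len(A2), ok, d_back == d)           # 7 True True
```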

class Solver4_k_with_unknown_prefix_eliminate_d(Solver1_small_k):
    """
    k with unknown prefix, eliminate d
    """
    def __init__(self, r, s, bits:int=224, N:int=9):
        super().__init__(r, s, bits, N)
        self.LLL = self.LLL_enumeration
        # ki - k8
        self.eq_difference = []
        
    def resultant(self):
        """
        represent by d;
        ki -= k8
        eliminate d;
        """
        # represent by d
        super().resultant()
        # ki -= k8
        eq = [self.eq_resultant[i] - self.eq_resultant[-1] - self.k[i] for i in range(len(self.eq_resultant)-1)]
        self.eq_difference = eq
        # eliminate d
        eq = [eq[-1].sylvester_matrix(eq[i], self.d).determinant() for i in range(len(eq)-1)]
        I = Ideal(eq).groebner_basis()
        self.eq_resultant = [i.monomials()[0]-i for i in I]

    def balance(self):
        self.bound = [2**self.bits]*(len(self.eq_resultant)+1) + [1]
        self.T = diagonal_matrix([max(self.bound) // i for i in self.bound])        
        
    def check(self, v):
        if abs(v[-1]) == 1:
            v*=v[-1]
            secret = ZZ(self.eq_difference[0](k0=ZZ(v[0])).univariate_polynomial().roots()[0][0])
            if self.Q == G*secret:
                print("d =",secret)
                return secret

class Tester4(Tester1):
    def __init__(self, bits:int=224, N:int=9):
        super().__init__(bits, N)
        self.solver = Solver4_k_with_unknown_prefix_eliminate_d(self.r, self.s, bits, N)
        
    def test_resultant(self):
        self.solver.resultant()
        assert len(self.solver.eq_resultant) == self.N - 2
        assert self.solver.eq_resultant[0].monomials() == [self.solver.k[-2], 1]
        for eq in self.solver.eq:
            assert eq(self.d, *(self.k)) == 0

Tester4().test()

[out]:

lattice rows: 9
lattice determinant: 2016

[out]:

d = 23150186784329805719311451646846086424385190132360620022332472531390776173743

[out]:

  8%|7         | 764/10000 [00:19<04:50, 31.80it/s]

Solving the remote data:

Solver4_k_with_unknown_prefix_eliminate_d(io.r, io.s).solve()

[out]:

d = 3853534948223110047007972915869293470762119648642476302498539377337395966720

[out]:

  0%|          | 7/10000 [00:00<02:32, 65.32it/s]

[out]:

3853534948223110047007972915869293470762119648642476302498539377337395966720

7. getflag

Once we have the ECDSA private key d, we follow the challenge's own scheme: recompute S = gen * r * d = R * d, derive the key, and decrypt the flag.
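The decryption needs the printed blob split into its parts. Per cryptod.encrypt, it is ciphertext || 16-byte tag || nonce, the literal "deadbeef", then R.to_bytes(). Below is a hypothetical, slightly more defensive parser than the split in the io code: it assumes a 12-byte ChaCha20-Poly1305 nonce and a 64-byte raw R (the R received above is 64 bytes), and slices from the right, so "deadbeef" appearing inside the ciphertext hex cannot break it.

```python
def parse_encrypted_flag(line):
    """
    Split the hex string printed after "Here is the encrypted flag:".
    Assumed layout: ciphertext || tag(16) || nonce(12) || "deadbeef" || R(64 raw bytes, x || y)
    """
    line = line.strip()
    r_hex, sep, cp_hex = line[-128:], line[-136:-128], line[:-136]
    if sep != "deadbeef":
        raise ValueError("unexpected layout")
    cp, R = bytes.fromhex(cp_hex), bytes.fromhex(r_hex)
    ciphertext, tag, nonce = cp[:-28], cp[-28:-12], cp[-12:]
    return ciphertext, tag, nonce, R


blob = (b"C" * 42 + b"T" * 16 + b"N" * 12).hex() + "deadbeef" + (b"R" * 64).hex()
ciphertext, tag, nonce, R = parse_encrypted_flag(blob)
print(len(ciphertext), len(tag), len(nonce), len(R))   # 42 16 12 64
```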

from cryptography.hazmat.primitives.kdf.kbkdf import (
   CounterLocation, KBKDFHMAC, Mode
)
from cryptography.hazmat.primitives import hashes
from Crypto.Cipher import ChaCha20_Poly1305

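def get_flag(secret, R, cp):
    """
    secret: recovered ECDSA private key d
    R:      raw encoding (x || y) of R = gen * r, as printed by the server
    cp:     ciphertext || tag (16 bytes) || nonce (12 bytes)
    """
    def derive_symkey(inp:bytes):
        # same KDF parameters as cryptod.derive_symkey; int() presumably turns
        # Sage's preparsed Integer literals into the Python ints cryptography expects
        kdf = KBKDFHMAC(
            algorithm=hashes.SHA3_256(),
            mode=Mode.CounterMode,
            length=int(32),
            rlen=int(4),
            llen=int(4),
            location=CounterLocation.BeforeFixed,
            label=b"safu",
            context=b"funds are safu",
            fixed=None,
        )
        return kdf.derive(inp)

    def decrypt_sym(ciphertext: bytes, tag: bytes, nonce: bytes, key: bytes):
        cipher = ChaCha20_Poly1305.new(key=key, nonce=nonce)
        return cipher.decrypt_and_verify(ciphertext, tag)

    # S = R * d = gen * r * d, the same shared point encrypt() used for the key
    R = ecdsa.NIST256p.generator.from_bytes(ecdsa.NIST256p.curve, R)
    S = R*int(secret)
    key = derive_symkey(int_to_bytes(int(S.x())))
    return decrypt_sym(cp[:-28], cp[-28:-12], cp[-12:], key)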

get_flag(d, io.R, io.cp)

[out]:

b'flag{s4tosh1s_Funds_4re_safu_safeB0x_isnt}'

This article is also published on:

  • blog: ssst0n3.github.io
  • WeChat official account: 石头的安全料理屋
  • Zhihu column: 石头的安全料理屋