tags: ctf,crypto,ecdsa


N1CTF-Qualifier-2023/Crypto/e2W@rmup writeup

1. challenge

  • time cost: 1 hour 15 min, 9:30-10:45
  • score: 80PT
  • solver: 49

Welcome 2 N1CTF2023! (^ω^)

https://drive.google.com/file/d/1g-c70UHiXSAhTkBKagpPp_sbSnq9ZPDt/view?usp=sharing

import hashlib
import ecdsa
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
from Crypto.Util.number import *
from secret import flag

def gen():
    """Create a NIST P-256 key pair whose private scalar is exactly 256 bits long.

    Returns:
        (pubkey, privkey, d): the ecdsa public and private key objects and
        the raw private scalar ``d``.
    """
    G = ecdsa.NIST256p.generator
    n = G.order()
    # Rejection sampling: redraw until the scalar's bit length is exactly 256.
    while True:
        d = randint(1, n - 1)
        if d.bit_length() == 256:
            break
    public = ecdsa.ecdsa.Public_key(G, G * d)
    private = ecdsa.ecdsa.Private_key(public, d)
    return public, private, d

def nonce_gen(msg, d):
    """Build the signing nonce from the upper halves of the hash and the key.

    Both integers are rendered in binary, left-padded with zeros to 256
    digits, and the first 128 digits of ``msg`` are concatenated with the
    first 128 digits of ``d``; the result is read back as an integer.
    """
    def top_half(value):
        # First 128 binary digits of the zero-padded 256-digit representation.
        return bin(value)[2:].zfill(256)[:128]

    return int(top_half(msg) + top_half(d), 2)

def sign(msg, privkey, d):
    """Sign ``msg`` using a nonce derived from its SHA-256 hash and the key ``d``.

    Returns:
        (s, r): the two signature components, ``s`` first.
    """
    digest = hashlib.sha256(msg).digest()
    msg_hash = bytes_to_long(digest)
    signature = privkey.sign(msg_hash, nonce_gen(msg_hash, d))
    return signature.s, signature.r

# Generate the key pair and sign the fixed welcome message.
pk, sk, d = gen()
msg = b'welcome to n1ctf2023!'
s, r = sign(msg, sk, d)
# The signature (s, r) is printed as part of the challenge output.
print(f's = {s}')
print(f'r = {r}')

# Encrypt the padded flag with AES in ECB mode, using the bytes of the private
# scalar d as the key (gen() only returns d with a bit length of 256).
m = pad(flag, 16)
aes = AES.new(long_to_bytes(d), mode=AES.MODE_ECB)
cipher = aes.encrypt(m)
print(f'cipher = {cipher}')

"""
s = 98064531907276862129345013436610988187051831712632166876574510656675679745081
r = 9821122129422509893435671316433203251343263825232865092134497361752993786340
cipher = b'\xf3#\xff\x17\xdf\xbb\xc0\xc6v\x1bg\xc7\x8a6\xf2\xdf~\x12\xd8]\xc5\x02Ot\x99\x9f\xf7\xf3\x98\xbc\x045\x08\xfb\xce1@e\xbcg[I\xd1\xbf\xf8\xea\n-'
"""
# Published challenge output (the same values as in the quoted output above):
# signature components s, r and the AES-ECB ciphertext c of the flag.
s = 98064531907276862129345013436610988187051831712632166876574510656675679745081
r = 9821122129422509893435671316433203251343263825232865092134497361752993786340
c = b'\xf3#\xff\x17\xdf\xbb\xc0\xc6v\x1bg\xc7\x8a6\xf2\xdf~\x12\xd8]\xc5\x02Ot\x99\x9f\xf7\xf3\x98\xbc\x045\x08\xfb\xce1@e\xbcg[I\xd1\xbf\xf8\xea\n-'

2. 分析

$$ \begin{align*} & 2^{255} \le d < 2^{256}, \quad d = d_0 \cdot 2^{128} + d_1 \newline & k = h_0 \cdot 2^{128} + d_0 \newline \Rightarrow & \newline & s\cdot(h_0 \cdot 2^{128} + d_0) \equiv h+r\cdot(d_0 \cdot 2^{128} + d_1) \pmod p \newline \Rightarrow & \newline & d_0 = a\cdot d_1 + b\cdot p + c \end{align*} $$

未知数只有 $d_0, d_1$ 以及模 $p$ 的系数 $b$ ,其中 $d_0, d_1$ 都是 128 bit 的相对较小的数。

于是可以列格来求解

$$ \begin{align*} & [d_1, 1, b] \begin{bmatrix} 1&&a \newline &1&c \newline &&p \newline \end{bmatrix} \newline = & [d_1, 1, d_0] \end{align*} $$

3. LLL Test

生成测试数据测试, LLL 只有一定概率解决这个问题。

import hashlib
import ecdsa
from Crypto.Util.number import long_to_bytes, bytes_to_long
from Crypto.Util.Padding import pad, unpad
from Crypto.Cipher import AES
from tqdm.auto import tqdm

class Solver:
    """Recover the ECDSA private key ``d`` from one signature with a key-dependent nonce.

    The modelled nonce is ``k = h0 * 2**128 + d0`` where ``h0`` is the top
    128 bits of the SHA-256 message hash and ``d = d0 * 2**128 + d1``.
    Substituting into ``s*k = h + r*d (mod p)`` gives a single linear relation
    in ``d0`` and ``d1``, which is attacked with lattice reduction.

    Subclasses tune the attack by overriding ``gen_eq``, ``balance`` and
    ``check``.  The code relies on names such as ``ZZ``, ``GF``,
    ``PolynomialRing`` and ``block_matrix`` that are not imported in this
    file (presumably Sage globals).
    """

    def __init__(self, s, r, c, log:bool=True):
        """Store the signature ``(s, r)``, the ciphertext ``c`` and the logging switch."""
        self.s = ZZ(s)
        self.r = ZZ(r)
        self.c = c
        self.msg = b'welcome to n1ctf2023!'

        self.log = log
        # Progress bar; created by LLL() before the vector generators start running.
        self.tqdm = None
    
    def gen_data(self):
        """Compute the message hash ``h``, its top 128 bits ``h0`` and the group order ``p``."""
        self.h = bytes_to_long(hashlib.sha256(self.msg).digest())
        self.h0 = int(bin(self.h)[2:].zfill(256)[:128],2)
        self.p = ZZ(ecdsa.NIST256p.generator.order())
    
    def gen_eq(self):
        """Build ``s*k - h - r*d`` as a polynomial in ``d0, d1`` over GF(p).

        The polynomial is stored in ``self.eq``; Tester.test_gen_eq checks
        that it evaluates to zero at the real key halves.
        """
        PR = PolynomialRing(GF(self.p), ["d0", "d1"])
        d0, d1 = PR.gens()
        self.eq = self.s*(self.h0*(2**128)+d0) - self.h - self.r*(d0*(2**128)+d1)
        
    def resultant(self):
        """Rewrite ``self.eq`` so that evaluating it at ``d1`` yields ``d0``.

        The first Groebner-basis element ``g`` of the ideal generated by the
        relation is taken and its first monomial is isolated as
        ``monomial - g``.  Tester.test_resultant checks that the result maps
        the real ``d1`` to the real ``d0``.
        """
        # NOTE(review): presumably g is monic with leading monomial d0 -- confirm.
        g = Ideal(self.eq).groebner_basis()[0]
        self.eq = g.monomials()[0] - g

    def balance(self):
        """Set the expected size of each coordinate of the target vector ``(d1, 1, d0)``."""
        self.bound = [2**128, 1, 2**128]
        
    def lattice(self):    
        """Build the lattice basis ``self.L`` and the column-scaling matrix ``self.T``."""
        # Coefficients of self.eq lifted to ZZ, arranged as a column.
        A = Matrix([ZZ(i) for i in self.eq.coefficients()]).T
        # Block rows: (1 | A) on top, (0 | p) below.
        self.L = block_matrix([
            [1, A],
            [0, self.p]
        ])
        self.balance()
        # Scale each column by max(bound) // bound_i so that all target
        # coordinates end up with comparable magnitude.
        self.T = diagonal_matrix([max(self.bound) // i for i in self.bound])
        
    def get_flag(self, d):
        """Decrypt the ciphertext with candidate key ``d``.

        Returns:
            The unpadded plaintext, or ``None`` when ``unpad`` rejects the
            padding (``ValueError``).
        """
        from Crypto.Cipher import AES
        aes = AES.new(long_to_bytes(d), mode=AES.MODE_ECB)
        flag = aes.decrypt(self.c)
        try:
            flag = unpad(flag, 16)
        except ValueError:
            return None
        else:
            if self.log:
                print("[+] flag:", flag)
            return flag
    
    def check(self, v):
        """Test whether lattice vector ``v`` has the shape ``±(d1, 1, d0)``.

        Returns:
            ``(d, flag)`` when the vector passes the size checks (``flag`` is
            ``None`` if decryption with ``d`` failed), otherwise
            ``(None, None)``.
        """
        if abs(v[1]) == 1:
            # Normalise the sign so that the middle coordinate becomes +1.
            v*=v[1]
            d1, d0 = v[0]%self.p, v[2]%self.p
            if d0.nbits() == 128 and d1.nbits() <= 128:
                d = d0 * 2**128 + d1
                flag = self.get_flag(d)
                return d, flag
        return None, None

    def LLL_normal(self):
        """Yield the rows of the LLL-reduced basis with the column scaling undone."""
        B = (self.L*self.T).LLL()/self.T
        # The progress bar only exists when driven through LLL().
        if self.tqdm is not None:
            self.tqdm.reset(total=B.nrows())
        for v in B:
            yield v
        
    def LLL_enumeration(self):        
        """Yield additional candidate vectors found by enumeration.

        The scaled basis is LLL-reduced with fpylll, then lattice points are
        enumerated; each coefficient vector is multiplied back by the reduced
        basis and un-scaled before being yielded.  This extends the search
        beyond the reduced basis rows themselves.
        """
        from fpylll import IntegerMatrix, LLL
        from fpylll.fplll.gso import MatGSO
        from fpylll.fplll.enumeration import Enumeration, EvaluatorStrategy
        
        A = IntegerMatrix.from_matrix(self.L*self.T)
        LLL.reduction(A)
        M = MatGSO(A)
        M.update_gso()
        size = self.L.nrows()
        sol_cnt = 1000000
        enum = Enumeration(M, sol_cnt, strategy=EvaluatorStrategy.BEST_N_SOLUTIONS)
        # NOTE(review): third argument is taken to be the search-radius bound,
        # here size * max(bound)**2 -- confirm against the fpylll documentation.
        answers = enum.enumerate(0, size, (size * max(self.bound)**2), 0, pruning=None)
        # The progress bar only exists when driven through LLL().
        if self.tqdm is not None:
            self.tqdm.reset(total=len(answers))
        for _, a in answers:
            v = IntegerMatrix.from_iterable(1, size, map(int, a))
            B = v*A
            B = matrix(B)[0]/self.T
            yield B
    
    def LLL(self, enum:bool=False):        
        """Iterate over candidate vectors until one decrypts the ciphertext.

        Args:
            enum: use ``LLL_enumeration`` instead of ``LLL_normal``.

        Returns:
            ``(d, flag)`` for the first candidate whose decryption unpads
            correctly, or ``(None, None)`` when no candidate succeeds.
        """
        def tqdm_storage(t):
            # Only bars exposing a `container` attribute are re-displayed.
            # NOTE(review): `display` is assumed to be the notebook built-in.
            if hasattr(t, 'container') and t.leave:
                display(t.container)

        LLL_func = self.LLL_enumeration if enum else self.LLL_normal
        vectors = LLL_func()
        self.tqdm = tqdm(vectors, disable=not self.log)
        # Only a candidate that actually decrypts is reported; a candidate d
        # whose decryption failed must not leak into the return value.
        result = (None, None)
        for v in self.tqdm:
            d, flag = self.check(v)
            if flag is not None:
                result = (d, flag)
                break
        tqdm_storage(self.tqdm)
        return result
    
    def solve(self, enum:bool=False):
        """Run the full pipeline: data, equation, isolation of d0, lattice, search."""
        self.gen_data()
        self.gen_eq()
        self.resultant()
        self.lattice()
        return self.LLL(enum)

        
class Tester:
    """Generate a random instance of the challenge and verify each Solver stage against it."""

    def __init__(self):
        """Create key, signature and ciphertext for a test flag and attach a Solver."""
        self.msg = b'welcome to n1ctf2023!'
        self.flag = b'flag{test}'
        self.data()
        self.enc()
        self.solver = Solver(self.s, self.r, self.c)

    def data(self):
        """Draw a private scalar with bit length 256, build the key objects and sign."""
        G = ecdsa.NIST256p.generator
        self.p = G.order()
        # Rejection sampling until the scalar's bit length is exactly 256.
        while True:
            candidate = randint(1, self.p-1)
            if candidate.bit_length() == 256:
                break
        self.d = candidate
        self.pubkey = ecdsa.ecdsa.Public_key(G, G * candidate)
        self.privkey = ecdsa.ecdsa.Private_key(self.pubkey, candidate)
        self.sign()

    def nonce_gen(self, msg, d):
        """Return the nonce and record ``h0``, ``d0`` and ``d1`` as ground truth."""
        msg_bits = bin(msg)[2:].zfill(256)
        key_bits = bin(d)[2:].zfill(256)
        hash_top = msg_bits[:128]
        key_top, key_low = key_bits[:128], key_bits[128:]
        nonce = int(hash_top + key_top, 2)
        self.h0 = int(hash_top, 2)
        self.d0 = int(key_top, 2)
        self.d1 = int(key_low, 2)
        return nonce

    def sign(self):
        """Sign ``self.msg`` and store the hash ``h`` and the signature ``(s, r)``."""
        self.h = bytes_to_long(hashlib.sha256(self.msg).digest())
        k = self.nonce_gen(self.h, self.d)
        signature = self.privkey.sign(self.h, k)
        self.s, self.r = signature.s, signature.r

    def enc(self):
        """Encrypt the padded test flag with AES-ECB keyed by the bytes of ``d``."""
        padded = pad(self.flag, 16)
        cipher = AES.new(long_to_bytes(self.d), mode=AES.MODE_ECB)
        self.c = cipher.encrypt(padded)

    def test_gen_data(self):
        """The solver must derive the same ``h`` and ``h0`` as the generator."""
        solver = self.solver
        solver.gen_data()
        assert solver.h == self.h
        assert solver.h0 == self.h0

    def test_gen_eq(self):
        """The solver's relation must vanish at the real ``(d0, d1)``."""
        self.solver.gen_eq()
        residue = self.solver.eq(d0=self.d0, d1=self.d1)
        assert residue == 0

    def test_resultant(self):
        """After isolation, evaluating the relation at ``d1`` must give ``d0``."""
        self.solver.resultant()
        predicted_d0 = self.solver.eq(d1=self.d1)
        assert predicted_d0 == self.d0

    def test_lattice(self):
        """Building the lattice must not raise."""
        self.solver.lattice()

    def test_LLL(self, enum:bool=False):
        """The search must return the real key and the test flag."""
        recovered = self.solver.LLL(enum=enum)
        d, flag = recovered
        assert d == self.d
        assert flag == self.flag

    def test(self, enum:bool=False, log:bool=True):
        """Run all stage checks in pipeline order."""
        self.solver.log = log
        stages = (
            self.test_gen_data,
            self.test_gen_eq,
            self.test_resultant,
            self.test_lattice,
        )
        for stage in stages:
            stage()
        self.test_LLL(enum=enum)

    def test_benchmark(self, enum:bool=False, log:bool=False):
        """Re-generate and test 100 random instances and print how many passed."""
        count = 0
        for _ in range(100):
            try:
                self.__init__()
                self.test(enum=enum, log=log)
            except AssertionError:
                continue
            count += 1
        print(f"[+] sucess rate: {count}%")
        
Tester().test()

[out]:

0it [00:00, ?it/s]

[out]:

100%|##########| 3/3 [00:00<00:00, 833.86it/s]




---------------------------------------------------------------------------

AssertionError                            Traceback (most recent call last)

Cell In [2], line 202
    199                 count += Integer(1)
    200         print(f"[+] sucess rate: {count}%")
--> 202 Tester().test()


Cell In [2], line 188, in Tester.test(self, enum, log)
    186 self.test_resultant()
    187 self.test_lattice()
--> 188 self.test_LLL(enum=enum)


Cell In [2], line 179, in Tester.test_LLL(self, enum)
    177 def test_LLL(self, enum:bool=False):
    178     d, flag = self.solver.LLL(enum=enum)
--> 179     assert d == self.d
    180     assert flag == self.flag


AssertionError: 

经测试,只有 10% 左右的概率解决这个问题。

Tester().test_benchmark()

[out]:

[+] sucess rate: 13%

题目中的问题也解不出答案

Solver(s=s, r=r, c=c).solve()

[out]:

100%|##########| 3/3 [00:00<00:00, 1080.26it/s]

[out]:

(None, None)

4. 优化

测试说明单纯的 LLL 约束不足,有以下几种方法解决

4.1 LLL enumeration

找到更多组解

Tester().test_benchmark(enum=True)

[out]:

[+] sucess rate: 100%
Solver(s=s, r=r, c=c).solve(enum=True)

[out]:

[+] flag: b'n1ctf{Wow!__You_bre4k_my_s1gn_chal1enge!!___}'

[out]:

 30%|###       | 3/10 [00:00<00:00, 306.09it/s]

[out]:

(75767369414377063170504861698029562004628781620076152109208613361284011206099,
 b'n1ctf{Wow!__You_bre4k_my_s1gn_chal1enge!!___}')

4.2 由 d 256 bit 约束 d 的范围

$$ \begin{align*} & 2^{255} \le d \lt 2^{256} \newline \Rightarrow & \newline & d = 2^{255} + d^{'}, 0 \le d^{'} < 2^{255} \newline & d_0 = 2^{127} + d_0^{'}, 0 \le d_0^{'} \lt 2^{127} \newline \Rightarrow & \newline & s\cdot(h_0 \cdot 2^{128} + (2^{127}+d_0^{'})) \equiv h+r\cdot((2^{127}+d_0^{'}) \cdot 2^{128} + d_1) \pmod p \newline \end{align*} $$

class Solver2(Solver):
    """Solver variant that fixes the top bit of ``d0`` and solves for the remainder."""

    def gen_eq(self):
        """Build the signing relation with ``d0`` replaced by ``2**127 + d0``."""
        ring = PolynomialRing(GF(self.p), ["d0", "d1"])
        x0, x1 = ring.gens()
        hi = 2**127+x0
        self.eq = self.s*(self.h0*2**128 + hi) - self.h - self.r*(hi * 2**128+x1)

    def balance(self):
        """Coordinate sizes for ``(d1, 1, d0)``: the ``d0`` bound is 2**127 here."""
        self.bound = [2**128, 1, 2**127]

    def check(self, v):
        """Return ``(d, flag)`` if ``v`` is a plausible ``±(d1, 1, d0)`` vector, else ``(None, None)``."""
        if abs(v[1]) != 1:
            return None, None
        # Normalise the sign so that the middle coordinate becomes +1.
        v *= v[1]
        d1 = v[0] % self.p
        # Add back the fixed top bit removed in gen_eq.
        d0 = v[2] % self.p
        d0 += 2**127
        if d0.nbits() != 128 or d1.nbits() > 128:
            return None, None
        d = d0 * 2**128 + d1
        return d, self.get_flag(d)
        
class Tester2(Tester):
    """Tester for Solver2: the expected ``d0`` has its top bit (2**127) removed."""

    def __init__(self):
        super().__init__()
        self.solver = Solver2(self.s, self.r, self.c)
        top_bit = 2**127
        assert self.d0.bit_length() == 128
        self.d0 = self.d0 - top_bit
        # Both conditions of the original combined assertion, checked in order.
        assert self.d0 > 0
        assert self.d0.bit_length() <= 127
                
Tester2().test()

[out]:

0it [00:00, ?it/s]

[out]:

[+] flag: b'flag{test}'

[out]:

 33%|###3      | 1/3 [00:00<00:00, 506.93it/s]

概率有提升

Tester2().test_benchmark(log=False)

[out]:

[+] sucess rate: 44%

不过已经可以解决问题

Solver2(s=s,r=r,c=c).solve()

[out]:

[+] flag: b'n1ctf{Wow!__You_bre4k_my_s1gn_chal1enge!!___}'

[out]:

  0%|          | 0/3 [00:00<?, ?it/s]

[out]:

(75767369414377063170504861698029562004628781620076152109208613361284011206099,
 b'n1ctf{Wow!__You_bre4k_my_s1gn_chal1enge!!___}')

4.3 d /= 2 进一步约束

格规约得到的短向量分量可正可负,而目前的未知量只取非负值,相当于浪费了一半的范围;把未知量平移到关于 0 对称的区间后,界可以再缩小一半,从而进一步约束

$$ \begin{align*} & d_0^{''} = d_0^{'} - 2^{126}, - 2^{126} \le d_0^{''} \lt 2^{126} \newline & d_1^{'} = d_1 - 2^{127}, - 2^{127} \le d_1^{'} \lt 2^{127} \newline \Rightarrow & \newline & s\cdot(h_0 \cdot 2^{128} + (2^{127} + 2^{126} +d_0^{''})) \equiv h+r\cdot((2^{127} + 2^{126} + d_0^{''}) \cdot 2^{128} + (2^{127} + d_1^{'})) \pmod p \newline \end{align*} $$

class Solver3(Solver):
    """Solver variant that shifts both unknowns by fixed offsets before solving."""

    def gen_eq(self):
        """Build the relation with ``d0 -> 2**127 + 2**126 + d0`` and ``d1 -> 2**127 + d1``."""
        ring = PolynomialRing(GF(self.p), ["d0", "d1"])
        x0, x1 = ring.gens()
        hi = 2**127+2**126+x0
        lo = 2**127 + x1
        self.eq = self.s*(self.h0*2**128 + hi) - self.h - self.r*(hi * 2**128 + lo)

    def balance(self):
        """Coordinate sizes for the shifted ``(d1, 1, d0)``."""
        self.bound = [2**127, 1, 2**126]

    def check(self, v):
        """Return ``(d, flag)`` if ``v`` is a plausible ``±(d1, 1, d0)`` vector, else ``(None, None)``."""
        if abs(v[1]) != 1:
            return None, None
        # Normalise the sign so that the middle coordinate becomes +1.
        v *= v[1]
        d1 = ZZ(v[0])
        d0 = ZZ(v[2])
        # Undo the offsets applied in gen_eq.
        d0 += 2**127 + 2**126
        d1 += 2**127

        if d0 <= 0 or d1 <= 0 or d0.nbits() != 128 or d1.nbits() > 128:
            return None, None
        d = d0 * 2**128 + d1
        return d, self.get_flag(d)
        
class Tester3(Tester):
    """Tester for Solver3: expected ``d0`` and ``d1`` are shifted by Solver3's offsets."""

    def __init__(self):
        super().__init__()
        self.solver = Solver3(self.s, self.r, self.c)
        d0_offset = 2**127 + 2**126
        d1_offset = 2**127

        assert self.d0.bit_length() == 128
        self.d0 = self.d0 - d0_offset
        assert self.d0.bit_length() <= 126

        self.d1 = self.d1 - d1_offset
        assert self.d1.bit_length() <= 127
                        
Tester3().test()

[out]:

100%|##########| 3/3 [00:00<00:00, 1413.49it/s]




---------------------------------------------------------------------------

AssertionError                            Traceback (most recent call last)

Cell In [10], line 36
     33         self.d1 -= Integer(2)**Integer(127)
     34         assert self.d1.bit_length() <= Integer(127)
---> 36 Tester3().test()


Cell In [2], line 188, in Tester.test(self, enum, log)
    186 self.test_resultant()
    187 self.test_lattice()
--> 188 self.test_LLL(enum=enum)


Cell In [2], line 179, in Tester.test_LLL(self, enum)
    177 def test_LLL(self, enum:bool=False):
    178     d, flag = self.solver.LLL(enum=enum)
--> 179     assert d == self.d
    180     assert flag == self.flag


AssertionError: 

现在概率较高了

Tester3().test_benchmark(log=False)

[out]:

[+] sucess rate: 89%
Solver3(s, r, c).solve()

[out]:

[+] flag: b'n1ctf{Wow!__You_bre4k_my_s1gn_chal1enge!!___}'

[out]:

  0%|          | 0/3 [00:00<?, ?it/s]

[out]:

(75767369414377063170504861698029562004628781620076152109208613361284011206099,
 b'n1ctf{Wow!__You_bre4k_my_s1gn_chal1enge!!___}')

4.4 爆破

如果只有这么几位的边界值,完全可以爆破来实现。

class Solver4(Solver):
    """Solver variant that guesses the top bits of ``d0`` and ``d1``.

    ``solve_brute`` sets ``self.bits``, ``self.d0_known`` and
    ``self.d1_known``; ``gen_eq``, ``balance`` and ``check`` read them, so
    these attributes must exist before ``solve`` is called directly.
    """

    def gen_eq(self):
        """Build the signing relation with the guessed high parts substituted in."""
        PR = PolynomialRing(GF(self.p), ["d0", "d1"])
        d0, d1 = PR.gens()
        d0_ = self.d0_known + d0
        d1_ = self.d1_known + d1
        self.eq = self.s*(self.h0*2**128 + d0_) - self.h - self.r*(d0_ * 2**128 + d1_)

    def balance(self):
        """Coordinate sizes for ``(d1, 1, d0)`` once ``128 - self.bits`` top bits are guessed."""
        self.bound = [2**self.bits, 1, 2**self.bits]

    def check(self, v):
        """Return ``(d, flag)`` if ``v`` is a plausible ``±(d1, 1, d0)`` vector, else ``(None, None)``."""
        if abs(v[1]) == 1:
            # Normalise the sign so that the middle coordinate becomes +1.
            v*=v[1]
            d1, d0 = ZZ(v[0]), ZZ(v[2])
            # Add back the guessed high parts.
            d0 += self.d0_known
            d1 += self.d1_known         
            if d0 > 0 and d1 > 0 and d0.nbits() == 128 and d1.nbits() <= 128:
                d = d0 * 2**128 + d1
                flag = self.get_flag(d)
                return d, flag
        return None, None
    
    def solve_brute(self, bits=2):
        """Try every guess of the top ``bits`` bits of ``d0`` and of ``d1``.

        Args:
            bits: number of leading bits guessed for each half.

        Returns:
            ``(d, flag)`` from the first guess whose solve yields a flag, or
            ``(None, None)`` when every guess fails -- the same shape as
            ``solve()``, so callers can always unpack the result.
        """
        self.bits = 128 - bits
        for i in range(2**bits):
            for j in range(2**bits):
                self.d0_known = i << (128-bits)
                self.d1_known = j << (128-bits)
                d, flag = self.solve()
                if flag is not None:
                    return d, flag
        return None, None
                
class Tester4(Tester):
    """Tester for Solver4: checks the pipeline for every guess of the leading bits."""

    def __init__(self):
        super().__init__()
        self.solver = Solver4(self.s, self.r, self.c)
        # Keep the true halves; self.d0 / self.d1 are overwritten per guess.
        self.real_d0, self.real_d1 = self.d0, self.d1

    def test_solve_brute(self, bits=2, log:bool=True):
        """Run the stage checks for each guess of the top ``bits`` bits.

        For every guess the expected residuals ``self.d0`` / ``self.d1`` and
        the solver's ``d0_known`` / ``d1_known`` / ``bits`` are set, then the
        pipeline is tested silently.

        Returns:
            ``True`` at the first guess that passes all checks, ``False`` if
            none does.
        """
        for i in range(2**bits):
            for j in range(2**bits):
                d0_known = i << (128-bits)
                d1_known = j << (128-bits)
                self.d0 = self.real_d0 - d0_known
                self.d1 = self.real_d1 - d1_known
                self.solver.d0_known, self.solver.d1_known, self.solver.bits = d0_known, d1_known, 128-bits
                try:
                    self.test(enum=False, log=False)
                except AssertionError:
                    if log:
                        print("AssertionError")
                else:
                    # The verbose re-run exists only to print the flag and
                    # progress bar; skip it when nothing would be shown.
                    if log:
                        self.test(enum=False, log=log)
                    return True
        return False
    
    def test_benchmark(self, bits, rounds:int=1000):
        """Print the percentage of random instances solved with ``bits`` guessed bits.

        Args:
            bits: number of leading bits guessed for each half.
            rounds: number of random instances to try (must be positive).
        """
        count = 0
        for _ in tqdm(range(rounds), leave=False):
            self.__init__()
            if self.test_solve_brute(bits, log=False):
                count += 1
        # Integer percentage; equals count//10 for the default 1000 rounds.
        print(f"[+] sucess rate: {count * 100 // rounds}%")
        
_ = Tester4().test_solve_brute(1)

[out]:

AssertionError
AssertionError

[out]:

[+] flag: b'flag{test}'

[out]:

 67%|######6   | 2/3 [00:00<00:00, 882.27it/s]
Solver4(s, r, c, log=False).solve_brute()

[out]:

(75767369414377063170504861698029562004628781620076152109208613361284011206099,
 b'n1ctf{Wow!__You_bre4k_my_s1gn_chal1enge!!___}')

不爆破

Tester4().test_benchmark(0)

[out]:

[+] sucess rate: 11%

爆破 1 bit

Tester4().test_benchmark(1)

[out]:

[+] sucess rate: 87%

爆破 2 bit

Tester4().test_benchmark(2)

[out]:

[+] sucess rate: 98%

爆破 3 bit

Tester4().test_benchmark(3)

[out]:

[+] sucess rate: 100%

本文同步发表于

  • blog: ssst0n3.github.io
  • 公众号: 石头的安全料理屋
  • 知乎专栏: 石头的安全料理屋