在前面几章中,我们已经介绍了RSA加解密的基本原理和一些相关攻击方法。在本章中,我们将继续深入探讨RSA加解密中的预言机相关攻击。这些攻击利用了预言机(Oracle)的功能来推断出明文或私钥,从而威胁到RSA的安全性。

选择明密文相关攻击 😆

这种攻击适用于 oracle(预言机)攻击场景,一般输入明文,返回密文,并且输入的明文中不能包含 flag 还有输入密文,返回明文,并且返回的明文中不能包含flag。

首先,我们选择明文攻击,模数一定是和偶数互质的,所以我们直接选择加密 $2$, $4$, $8$,然后可以知道:

由于我们知道自己输入的明文,所以我们可以知道 $c_1$、$c_2$、$c_3$ 之间的关系:

所以,可以找到公因数:

因此我们就可以通过 $gcd$ 来计算出 $n$ 的值了,但是要考虑 $n$ 是否为质数。

然后,我们选择密文攻击,假设密文 $c = m^e \mod n$ 运用下列步骤求出 $m$ 的值:

  1. 选择任意的 $X \in \mathbb{Z}_n^*$,即 $X$ 与 $n$ 互素
  2. 计算 $Y = c \cdot X^e \mod n$
  3. 由于我们可以进行选择密文攻击,那么我们求得 $Y$ 对应的解密结果 $Z = Y^d = (c \cdot X^e)^d = c^d \cdot X = m \cdot X \mod n$
  4. 那么,由于 $X$ 与 $n$ 互素,我们很容易求得相应的逆元,进而可以得到 $m$ 的值了。

选择明密文相关攻击例题

选择明密文相关攻击例题下载

import os
from Crypto.Util.number import getPrime, bytes_to_long, long_to_bytes

class RsaCrypt:
    def __init__(self, p, q, e):
        self.n = p * q
        self.e = e
        self.d = pow(e, -1, (p-1)*(q-1))

    def enc(self, data):
        m = bytes_to_long(data)
        c = pow(m, self.e, self.n)
        return long_to_bytes(c)

    def dec(self, data):
        c = bytes_to_long(data)
        m = pow(c, self.d, self.n)
        return long_to_bytes(m)

def handle_client():
    if not os.path.exists("flag.txt"):
        with open("flag.txt", "wb") as f:
            f.write()

    p = getPrime(1024)
    q = getPrime(1024)
    e = 65537
    crypt = RsaCrypt(p, q, e)
    while 1 :
        cmd = input("cmd >").strip()
        data = input("data >").strip()

        # 补齐奇数长度的字符串以满足 bytes.fromhex 要求 (例如 '2' -> '02')
        if len(data) % 2 != 0:
            data = "0" + data

        if cmd == "enc":
            try:
                data = bytes.fromhex(data)
            except ValueError:
                print("invalid hex data")
                continue
            if b"flag" in data:
                print("data can't contain \'flag\'")
            else:
                res = crypt.enc(data)
                if res:
                    print(res.hex())
                else:
                    print("args wrong")
        elif cmd == "dec":
            try:
                data = bytes.fromhex(data)
            except ValueError:
                print("invalid hex data")
                continue
            res = crypt.dec(data)
            if res:
                if b"flag" in res:
                    print("you can't decrypt flag")
                else:
                    print(res.hex())
            else:
                print("args wrong")
        elif cmd == "get_flag":
            with open("flag.txt", "rb") as file:
                flag = file.read()
            res = crypt.enc(flag)
            print(res.hex())

if __name__ == "__main__":
    handle_client()

首先我们可以通过选择明文攻击来获取到模数 $n$ 的值了,接下来我们就可以通过选择密文攻击来获取到明文的值了。

解题代码示例:

from Crypto.Util.number import *

# get_flag
c = 0x246ed3b729330b0799ec47da886dfb2b6276bcffc66c0ede83fd7e8e1d36a82f0e51e1ba5f95253c04213c3236cf4e102cd3a1578a66ea1a1f03c2e488302da49a76d18aacab364ae1447f4d71f75658eaa3bc8fa1fc1cc9871fc7dccc8ba6f03a5d9343a1aab2404c579823e971448b7ee74bfa2ec0adb70e8fde943687e86036abe5bdf9fba8ecbfaf2bf2d6806ee6d62f62e6e344b0050840a7d167f61aa99606e5cbc51758bbb8633bc16294fcbe286ad592ea3c4d1209dea83cae7a2a99865bc13422a246e624b21da6481b8362226ece44ad12180b5631ced203a877e3e5028917824c19e377c00c5ecee6cc054cbeab9fa168a8bebe6c57684cbb3f3

# enc 2
c_2 = 0x114a75c97c0334cc5148eedb4b9f66a9f4d94c66d571ca1e2f69a6ee72c9cd2a5de3e3c8b93e8bb9328a36d2e6eec828c3315705cc1e54cf8f65a777a7aae5134f395769c38448c26572a53c901879f7245a6ea689e149bfba84035a645530f7ac998fd2c48ae8f41b3fed7e2837911565d5aa69f712fc78c20c47a6742da9804737a9d30edbfdcbe7c94d38201625c3fa39234d19b75088dfd98eae4ff145e8c689e8a54c2d717d24d4829d95d7a5de3d31e04ec298038c3c85b18600db770fd4931434f5d222b55187ff1309ef91f646fc4f2f431c86ce71cbace8ad14d82266bf5f3cb8220645156226d609a409068f31587998e0c727ec91163c06f7822c
# enc 4
c_4 = 0x4c1fb7110722eaf6ea9a51f8c0f68d2de3e90a86408afa76cef8d0bd255d2b9e1f581c56c5478726082f23e066e60be1086e6d14b23d8b401cfcb57c3eb65ee4c39a820341858cc9ba32f9e9a2aa71cbb6ef03edd6b54ba1eb8db5d327e2ad5722cfabc8311cf26607ef146138ea14d310c10859c3686abd4b55dff7a54aba1bc57e494681ca5edde24264de16c72ee8855630038299cb2702c2149d09cecb6ddccdfe813d7697745551c5ec1d9b066d4e13f082b44094b0607721b544aed6b65d5a338917babb6142772b3b72a3b8c289adcc7649adb2cbd13d5cc0032fd36bfbd2934377f86acb8d9b97b4d7821832f79af9d5a0d80df3aa54ed9e067c3ee
# enc 8
c_8 = 0x418fdbe7174db8c25395457fae06ce57f1761b4e3b73ae26150d1b951813d41646d8c93315b19c47cc65294d5ea0cebba6663fe750d741eae1d8fd19d7da87324c7df04fff9125b34e821cac0e7c1eabdf1d3224529029e0d91545db97eb314bc0040e0ff225f8a42f1b0fc1dd5a51bf05b8f70d9b6b3992ac7a290496768cdab6d136c3566a0ea9623dfa40fa29dd684629ddf5fc447682c6c3c656ec30d2cdc5f51bc4dba4e87b9323b47a296f2326ad6014f905b70b928478273478dff01b71dffd6d54d6382ddb8d4a6e636148ee3b314a442eecaea2582df53549e5fb975c679c6d56367641c1860889edf73dfc002a7f0f214e45d41082cc8b6a92c9a0

n = GCD(c_2**2 - c_4, c_2**3 - c_8)
while n % 2 == 0:
    n //= 2
print(f"n: {n}")

X = 2
Y = (c * c_2) % n
print(f"Y: {hex(Y)}")

# dec Y
Yd = 92654804595522157479766957049954882296889731182656163601277452832538918832378
m = (Yd * inverse(X, n)) % n
print(long_to_bytes(m))
# b'flag{u_know_what_the_oracle_did}'

奇偶预言机 😎

有一个 oracle,它会对一个给定的密文进行解密,并且会检查解密的明文的奇偶性,并根据奇偶性返回相应的值,比如 1 表示奇数,0 表示偶数。

这一种预言机我们可以通过构造不断放缩确定密文的范围,最终确定明文的值了。我们有 $c \equiv m^e \mod n$。

我们可以将 $c \cdot 2^e \equiv (2 \cdot m)^e \mod n$ 作为输入,那么预言机就会计算出 $2m \mod n$。

  • 如果服务器返回 $1$,说明 $2m \mod n$ 是奇数,则说明 $2m$ 大于 $n$,且减去了奇数个 $n$,又因为 $2m < 2n$,那么

  • 如果服务器返回 $0$,说明 $2m \mod n$ 是偶数,则说明 $2m$ 小于 $n$,又因为 $m < n$,那么

依据这个思路,一直放缩 $i$ 次,有

那么第 $i + 1$ 次放缩时,输入 $c \cdot 2^{(i+1)e} \mod n$,得到

根据第 $i$ 次放缩的结果,有

  • 如果服务器返回 $1$,则 $k$ 必然是奇数,设 $k = 2y + 1$,则

与此同时,由于 $m$ 必然存在,所以第 $i+1$ 次得到的这个范围和第 $i$ 次得到的范围必然存在交集。所以 $y$ 必然与 $x$ 相等。

  • 服务器返回偶数,则 $k$ 必然是一个偶数,$k = 2y$,此时 $y$ 必然也与 $x$ 相等。

如此反复放缩,最终就可以确定 $m$ 的值了。

奇偶预言机例题

奇偶预言机例题下载

import socket
import threading
import os
from Crypto.Util.number import getPrime, bytes_to_long

def generate_key(bits=512):
    p = getPrime(bits)
    q = getPrime(bits)
    n = p * q
    e = 65537
    d = pow(e, -1, (p-1)*(q-1))
    return (n, e), d

def handle_client(c, pub, priv, flag_c):
    n, e = pub
    d = priv
    c.sendall(f"n = {n}\ne = {e}\nflag_c = {flag_c}\n".encode())
    while True:
        try:
            c.sendall(b"ct (hex): ")
            data = c.recv(1024).decode().strip()
            if not data: break
            ct = int(data, 16)
            pt = pow(ct, d, n)
            c.sendall(f"{pt % 2}\n".encode())
        except Exception as err:
            break
    c.close()

pub, priv = generate_key(512)
with open("flag.txt", "rb") as f:
    flag = f.read()
flag_c = pow(bytes_to_long(flag), pub[1], pub[0])

s = socket.socket()
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind(('127.0.0.1', 8888))
s.listen(5)
print("Oracle Server listening on 127.0.0.1:8888")
while True:
    conn, addr = s.accept()
    threading.Thread(target=handle_client, args=(conn, pub, priv, flag_c)).start()

按照上面的思路,我们就可以通过不断放缩来确定明文的值了。

# 记得结束后杀掉服务器进程
sudo kill -9 $(sudo lsof -t -i:8888)

解题代码示例:

from pwn import *
from Crypto.Util.number import long_to_bytes
from fractions import Fraction
import sys

# 设置大整数转换限制,防止溢出或报错
sys.set_int_max_str_digits(10000)

context.log_level = 'error' # 关闭多余的 pwntools 输出
r = remote('127.0.0.1', 8888)

r.recvuntil(b'n = ')
n = int(r.recvline().strip())
r.recvuntil(b'e = ')
e = int(r.recvline().strip())
r.recvuntil(b'flag_c = ')
flag_c = int(r.recvline().strip())

print(f"n = {n}\ne = {e}\nc = {flag_c}")
print("开始执行奇偶预言机攻击...")

# 我们想要解密明文,依据题意:输入 c * 2^e mod n 来获取 2m mod n 的奇偶性
multiplier = pow(2, e, n)
ct = flag_c

low = Fraction(0)
high = Fraction(n)

# 需要的交互次数为 n 的比特长度(二分查找的次数)
bits = n.bit_length()

for i in range(1, bits + 2):
    ct = (ct * multiplier) % n

    r.recvuntil(b"ct (hex): ")
    r.sendline(hex(ct)[2:].encode())

    try:
        res = r.recvline()
        parity = int(res.strip())
    except EOFError:
        break
    except ValueError:
        print("无效返回:", res)
        break

    mid = (low + high) / 2

    if parity == 1:
        low = mid
    else:
        high = mid

    if i % 50 == 0:
        print(f"[*] 已推进 {i}/{bits} 步...")

m = int(high)
print()
print("攻击完成,解密得到的明文十六进制为:", hex(m))
print("转换为 ASCII:", long_to_bytes(m).decode(errors='replace'))

r.close()
# flag{you_have_gaind_a_deep_insight_into_the_rsa_parity_oracle!!!}

字节预言机 🥰

有一个 oracle,它会对一个给定的密文进行解密,并且会给出明文的最后一个字节。这一种可以看作上一种预言机的扩展,一样地放缩来确定明文的范围,最终确定明文的值了。

我们有 $c \equiv m^e \mod n$,构造 $c \cdot 256^e \equiv (256 \cdot m)^e \mod n$ 作为输入,那么预言机就会计算出 $256m \mod n$ 的值。

由于 $m$ 一般是小于 $n$ 的,所以 $256m \mod n = 256m - kn$,其中 $k < 256$。

而且对于两个不同的 $k_1$ 和 $k_2$,$256m - k_1n$ 和 $256m - k_2n$ 的最后一个字节是不同的。

$256m - kn$ 的最后一个字节其实就是 $-kn$ 在模 $256$ 的情况下获取的。那么其实我们可以首先枚举出0~255情况下的最后一个字节,构造一个 $k$ 和最后一个字节的映射表 $map$。

当服务器返回最后一个字节,那么我们可以根据上述构造的映射表得知 $k$,即减去了$k$ 个 $n$,即

一样地放缩 $i$ 次,有

第 $i+1$ 次放缩时,输入 $c \cdot 256^{(i+1)e} \mod n$,得到

根据第 $i$ 次放缩的结果,有

这里可以假设 $k = 256y + t$,而这里的 $t$ 就是我们可以通过映射表获得的。

与此同时,由于 $m$ 必然存在,所以第 $i+1$ 次得到的这个范围和第 $i$ 次得到的范围必然存在交集。所以 $y$ 必然与 $x$ 相等。如此反复放缩,最终就可以确定 $m$ 的值了。

字节预言机例题

字节预言机例题下载

from Crypto.Util.number import getPrime, bytes_to_long
import os

def gen_keys(bits=1024):
    p = getPrime(bits // 2)
    q = getPrime(bits // 2)
    n = p * q
    e = 65537
    d = pow(e, -1, (p-1)*(q-1))
    return (n, e), d

class Oracle:
    def __init__(self, n, d):
        self.n = n
        self.d = d

    def decrypt_last_byte(self, c):
        m = pow(c, self.d, self.n)
        return m % 256

pub, priv = gen_keys(1024)
n, e = pub
d = priv
with open("flag.txt", "rb") as f:
    flag = f.read()
m = bytes_to_long(flag)
c = pow(m, e, n)

print(f"n = {n}")
print(f"e = {e}")
print(f"c = {c}")

oracle = Oracle(n, d)

while True:
    try:
        query = int(input("c> "))
        last_byte = oracle.decrypt_last_byte(query)
        print(f"Last byte: {last_byte}")
    except:
        break

这个题的解法和前面的思路是一样的,只不过这里我们需要构造一个 $k$ 和最后一个字节的映射表来获取 $k$ 的值了。

解题代码示例:

import math
from pwn import *
from Crypto.Util.number import long_to_bytes

p = process(['python3', 'server.py'])

p.recvuntil(b"n = ")
n = int(p.recvline().strip())
p.recvuntil(b"e = ")
e = int(p.recvline().strip())
p.recvuntil(b"c = ")
c = int(p.recvline().strip())

def query(c_query):
    p.sendlineafter(b"c> ", str(c_query).encode())
    p.recvuntil(b"Last byte: ")
    return int(p.recvline().strip())

# 提前计算映射表,实际上就是求 (-n) 模 256 的逆元
# 满足 equation: (k * (-n)) % 256 = b  =>  k = (b * inv_n) % 256
inv_n = pow(-n, -1, 256)
x = 0

# 总查询次数取决于位数,一次查询泄露一个字节(8位)
total_queries = math.ceil(n.bit_length() / 8.0)

# 为减少日志输出,我们每10次打印一次进度
for i in range(1, total_queries + 1):
    c_query = (c * pow(pow(256, i, n), e, n)) % n
    b = query(c_query)

    # 我们已知 256^i * m % n = 256^i * m - K * n
    # 当前迭代的余数贡献主要来自于 (-k * n) 模 256
    # 通过预先计算好的逆元,可以直接求出减去的个数值 k
    k = (b * inv_n) % 256
    x = x * 256 + k
    if i % 10 == 0:
        print(f"[*] 已完成 {i}/{total_queries} 次查询")
print("[*] 所有查询已完成。")

# 根据所有迭代的 x 值,m 的取值范围最终被限缩到以下闭区间内
# [x * n / 256^total_queries, (x+1) * n / 256^total_queries)
lower = x * n // (256**total_queries)
upper = (x + 1) * n // (256**total_queries)
print(f"预期下界: {lower}")
print(f"预期上界: {upper}")

if lower == upper:
    m = lower
    print(f"明文 m: {m}")
    print(f"Flag : {long_to_bytes(m)}")
else:
    # 为了容灾计算的精度误差,我们在上下界附近测试真正的明文值
    for m in range(lower - 2, upper + 2):
        if pow(m, e, n) == c:
            print(f"明文 m: {m}")
            print(f"Flag : {long_to_bytes(m)}")
            break

p.close()

上一章:RSA加解密专题 - Coppersmith 相关攻击二 👈
下一章:
回到开始:关于我 👈

相关链接:

参考:
CTF Wiki - RSA 选择明密文攻击 👈