专栏文章

#3. 浅谈离散对数问题

算法·理论参与者 24已保存评论 25

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
25 条
当前快照
1 份
快照标识符
@mhz5s749
此快照首次捕获于
2025/11/15 01:55
4 个月前
此快照最后确认于
2025/11/29 05:24
3 个月前
查看原文

0 前言

今天模拟赛遇到了一个这玩意的板子,然后我只会 bsgs,于是爆零了。

1 定义

离散对数(discrete logarithm)是一个整数 xx 对于给定的 a,b,ma,b,m 满足下面的方程:
axb(modm)a^{x}\equiv b\pmod m
记作 x=logabx=\log_{a}{b}。通常情况先我们把这叫做 【阶】,index\text{index}。记作 indab\text{ind}_{a}{b}
显然离散对数不一定存在。比如:2x3(mod7)2^{x}\equiv 3\pmod 7

2 BSGS

(a,m)=1(a,m)=1 时我们可以使用大步小步(Baby Step Giant Step)算法。
注意到 (a,m)=1(a,m)=1,我们有欧拉定理 aφ(m)1a^{\varphi(m)}\equiv 1。所以 aka^k 至多有 φ(m)\varphi(m) 种取值(也就是其循环节为 φ(m)\varphi(m))。设 x=kBr,0rB1x=kB-r,0\le r\le B-1BB 是我们随便取的一个数,那么有 akBbara^{kB}\equiv ba^{r},我们预处理 a0,a1aB1a^{0},a^{1}\dots a^{B-1},枚举 kk 即可求出 xx(其实这个过程已经可以求出了)。
时间复杂度 O(B+φ(m)B)O(B+\frac{\varphi(m)}{B}),随便根号平衡一下得 O(φ(m))O(\sqrt{\varphi(m)})

3 Ex-BSGS

用于解决 (a,m)1(a,m)\neq 1 时的情况。设 d=gcd(a,m)d=\gcd(a,m),那么有 adax1bd(modmd)\frac{a}{d}a^{x-1}\equiv \frac{b}{d}\pmod {\frac{m}{d}}。于是这么一路递归除下去即可。注意判断无解,当 dd 不整除 bb 时即无解。
递归完之后就可以正常的 bsgs 了。

4 Pohlig–Hellman algorithm

尝试自己口胡ing。
这里我们不妨设模数是个大质数 PP。我们可以找出一个原根 gg,然后求 gxh(modP)g^x\equiv h\pmod P
算法思想大概就是想把 p1p-1 质因数分解为 piei\prod p_i^{e_i},然后计算 xxi(modpiei)x\equiv x_i\pmod{p_i^{e_i}}
考虑 gp11(modP)g^{p-1}\equiv 1\pmod P。所以有
(gx)p1piei=(gxi+kpiei)p1piei(gp1piei)xihp1piei(modP)(g^x)^{\frac{p-1}{p_i^{e_i}}}=(g^{x_i+kp_i^{e_i}})^{\frac{p-1}{p_i^{e_i}}}\equiv (g^{\frac{p-1}{p_i^{e_i}}})^{x_i}\equiv h^{\frac{p-1}{p_i^{e_i}}}\pmod P
所以令 gp1piei,hp1pieig^{\frac{p-1}{p_i^{e_i}}},h^{\frac{p-1}{p_i^{e_i}}} 取代原来的 g,hg,h 就可以在 pieip_i^{e_i} 范围内求 xix_i 了。
所以我们需要解决的问题变成了 gxh(modP)g^x\equiv h\pmod P,其中 x[0,piei1]x\in [0,p_i^{e_i}-1]。考虑将 xx 写成 pip_i 进制数,显然有 eie_i 位,从低到高逐位确定。即 x=x0+x1pi+x2pi2++xe1piei1x=x_0+x_1p_i+x_2p_i^2+\dots+x_{e-1}p_i^{e_i-1}。然后当我们想要求 xjx_j 的时候,就计算 (gx)p1pij+1(g^x)^{\frac{p-1}{p_i^{j+1}}},容易发现这又可以写成 gxjh(modP)g^{x_j}\equiv h\pmod P 的形式,但是这时 xjx_j 的范围就变成了 [0,pi1][0,p_i-1]。这个时候我们就可以直接 BSGS 了。
综上所述,整个算法的复杂度为 O(ei(logP+pi))O(\sum e_i(\log P+\sqrt{p_i}))。比普通 BSGS 有了不小的提升。

5 例题

都是一些板子题。
CPP
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
template <typename T>
void read(T &x) {
	T flag = 1;
	char ch = getchar();
	for (; '0' > ch || ch > '9'; ch = getchar()) if (ch == '-') flag = -1;
	for (x = 0; '0' <= ch && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
	x *= flag;
}
ll ksc(ll a, ll b, ll m) {
	return (a * b - (ll)((long double)a / m * b) * m + m) % m;
}
ll ksm(ll a, ll b, ll m) {
	ll ret = 1;
	for (; b; b >>= 1, a = a * a % m) if (b & 1) ret = ret * a % m;
	return ret;
}
ll exgcd(ll a, ll b, ll &x, ll &y) {
	if (b == 0) {
		x = 1; y = 0;
		return a;
	}
	ll d = exgcd(b, a % b, y, x);
	y -= a / b * x;
	return d;
}
ll A[100005], B[100005];
ll crt(int n) {
	ll ans = B[1], M = A[1];
	for (int i = 2; i <= n; i++) {
		ll x0, y0;
		ll now = ((B[i] - ans) % A[i] + A[i]) % A[i];
		ll d = exgcd(M, A[i], x0, y0);
		if (now % d) return -1;
		now /= d;
		ll m = A[i] / d;
		x0 = ksc(x0, now, m);
		ans = ans + x0 * M;
		M = M / d * A[i];
	}
	ans = (ans % M + M) % M;
	return ans;
}
map<ll, ll> mp;
ll bsgs(ll a, ll b, ll p) {
	ll len = sqrt(p) + 1;
	mp.clear();
	ll base = ksm(a, len, p), val = base;
	for (ll i = len; i < p; i += len) {
		if (mp.find(val) == mp.end()) mp[val] = i;
		val = val * base % p;
	}
	ll ret = 0x7f7f7f7f7f7f7f7f;
	val = b;
	for (ll i = 0; i < len; i++) {
		if (mp.find(val) != mp.end()) {
			ret = min(ret, mp[val] - i);
		}
		val = val * a % p;
	}
	return ret;
}
int main() {
	ll p, b, n;
	read(p); read(b); read(n);
	ll ans = bsgs(b, n, p);
	if (ans == 0x7f7f7f7f7f7f7f7f) puts("no solution");
	else cout << ans << "\n";
	return 0;
}
CPP
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
template <typename T>
void read(T &x) {
	T flag = 1;
	char ch = getchar();
	for (; '0' > ch || ch > '9'; ch = getchar()) if (ch == '-') flag = -1;
	for (x = 0; '0' <= ch && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
	x *= flag;
}
ll ksc(ll a, ll b, ll m) {
	return (a * b - (ll)((long double)a / m * b) * m + m) % m;
}
ll ksm(ll a, ll b, ll m) {
	ll ret = 1;
	for (; b; b >>= 1, a = a * a % m) if (b & 1) ret = ret * a % m;
	return ret;
}
ll gcd(ll a, ll b) {
	return b == 0 ? a : gcd(b, a % b);
}
ll exgcd(ll a, ll b, ll &x, ll &y) {
	if (b == 0) {
		x = 1; y = 0;
		return a;
	}
	ll d = exgcd(b, a % b, y, x);
	y -= a / b * x;
	return d;
}
ll A[100005], B[100005];
ll crt(int n) {
	ll ans = B[1], M = A[1];
	for (int i = 2; i <= n; i++) {
		ll x0, y0;
		ll now = ((B[i] - ans) % A[i] + A[i]) % A[i];
		ll d = exgcd(M, A[i], x0, y0);
		if (now % d) return -1;
		now /= d;
		ll m = A[i] / d;
		x0 = ksc(x0, now, m);
		ans = ans + x0 * M;
		M = M / d * A[i];
	}
	ans = (ans % M + M) % M;
	return ans;
}
map<ll, ll> mp;
ll bsgs(ll a, ll b, ll p) {
	ll len = sqrt(p) + 1;
	mp.clear();
	ll base = ksm(a, len, p), val = 1;
	for (ll i = 0; i < p; i += len) {
		if (mp.find(val) == mp.end()) mp[val] = i;
		val = val * base % p;
	}
	ll ret = 0x7f7f7f7f7f7f7f7f;
	val = b;
	for (ll i = 0; i < len; i++) {
		if (mp.find(val) != mp.end() && mp[val] >= i) {
			ret = min(ret, mp[val] - i);
		}
		val = val * a % p;
	}
	return ret;
}
ll exbsgs(ll a, ll b, ll p) {
	a %= p; b %= p;
	ll d, k = 0;
	while ((d = gcd(a, p)) > 1) {
		if (b % d) return -1;
		p /= d;
		k++;
	}
	ll x0, y0;
	exgcd(ksm(a, k, p), p, x0, y0);
	x0 = (x0 % p + p) % p;
	b = b * x0 % p;
	ll ans = bsgs(a, b, p);
	if (ans == 0x7f7f7f7f7f7f7f7f) return -1;
	return ans + k;
}
int main() {
	ll a, p, b;
	while (1) {
		read(a); read(p); read(b);
		if (a == 0 && p == 0 && b == 0) break;
		ll ans = exbsgs(a, b, p);
		if (ans == -1) puts("No Solution");
		else printf("%lld\n", ans);
	}
	return 0;
}
这里 p1p-1 的质因子只有 2,32,3,那么直接上 ph 算法。
CPP
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
template <typename T>
void read(T &x) {
	T sgn = 1;
	char ch = getchar();
	for (; !isdigit(ch); ch = getchar()) if (ch == '-') sgn = -1;
	for (x = 0; isdigit(ch); ch = getchar()) x = x * 10 + ch - '0';
	x *= sgn;
}
ll exgcd(ll a, ll b, ll &x0, ll &y0) {
	if (b == 0) {
		x0 = 1, y0 = 0;
		return a;
	}
	ll d = exgcd(b, a % b, y0, x0);
	y0 = y0 - a / b * x0;
	return d;
}
ll mul(ll a, ll b, ll m) {
	return (__int128)a * b % m;
}
ll ksm(ll a, ll b, ll m) {
	ll ret = 1;
	for (; b; b >>= 1, a = mul(a, a, m)) if (b & 1) ret = mul(ret, a, m);
	return ret;
}
ll work(ll m, ll p, ll e, ll g, ll h) {
	vector<int> ans; ans.resize(e);
	ll z = ksm(p, e, m);
	g = ksm(g, (m - 1) / z, m);
	h = ksm(h, (m - 1) / z, m);
	for (int i = 0; i < e; i++) {
		ans[i] = -1;
		ll gi = ksm(g, (m - 1) / p, m), hi = ksm(h, (m - 1) / ksm(p, i + 1, m), m), cur = 0;
		for (int j = 0; j < i; j++) {
			cur = (cur - mul(ans[j], (m - 1) / ksm(p, i + 1 - j, m), m - 1) + m - 1) % (m - 1);
		}
		hi = mul(hi, ksm(g, cur, m), m);
		for (int j = 0; j < p; j++) {
			if (ksm(gi, j, m) == hi) {
				ans[i] = j;
				break;
			}
		}
		if (ans[i] == -1) return -1;
	}
	ll ret = 0, now = 1;
	for (int i = 0; i < e; i++) {
		ret = (ret + mul(ans[i], now, z)) % z;
		now = now * p;
	}
	return ret;
}
ll calc(ll g, ll h, ll p) {
	ll x = p - 1, x2, x3;
	ll e[5];
	e[2] = e[3] = 0;
	while (x % 2 == 0) e[2]++, x /= 2;
	while (x % 3 == 0) e[3]++, x /= 3;
	if (!e[2]) return work(p, 3, e[3], g, h);
	if (!e[3]) return work(p, 2, e[2], g, h);
	x2 = work(p, 2, e[2], g, h);
	x3 = work(p, 3, e[3], g, h);
	if (x2 == -1 || x3 == -1) return -1;
	ll p2 = ksm(2, e[2], p);
	ll p3 = ksm(3, e[3], p);
	x2 = (x2 % p2 + p2) % p2;
	x3 = (x3 % p3 + p3) % p3;
	ll q2, q3, t;
	exgcd(p3, p2, q2, t);
	q2 = (q2 % p2 + p2) % p2;
	exgcd(p2, p3, q3, t);
	q3 = (q3 % p3 + p3) % p3;
	ll ret = ((mul(mul(x2, p3, p - 1), q2, p - 1) + mul(mul(x3, p2, p - 1), q3, p - 1)) % (p - 1) + p - 1) % (p - 1);
	return ret;
}
void solve() {
	ll p, a, b;
	read(p); read(a); read(b);
	ll g = 2;
	while (1) {
		bool flg = true;
		if ((p - 1) % 2 == 0 && ksm(g, (p - 1) / 2, p) == 1) flg = false; 
		if ((p - 1) % 3 == 0 && ksm(g, (p - 1) / 3, p) == 1) flg = false;
		if (flg == true) break;
		g++;
	}
	a = calc(g, a, p); b = calc(g, b, p);
	if (a == -1 || b == -1) return puts("-1"), void();
	ll x, y;
	ll d = exgcd(a, p - 1, x, y);
	if (b % d != 0) puts("-1");
	else {
		x = mul(x, b / d, p - 1);
		ll s = (p - 1) / d;
		x = (x % s + s) % s;
		printf("%lld\n", x);
	}
}
int main() {
	int T; read(T); while (T--) solve();
	return 0;
}
这种问题还有一个可以优化的地方,就是当模数不变的时候,假设有 TT 组数据,那么可以通过更改块大小,将复杂度从 O(Tmod)O(T\sqrt{mod}) 变成 O(Tmod)O(\sqrt{Tmod})

评论

25 条评论,欢迎与作者交流。

正在加载评论...