代数推导保平安。
思路
首先将
p 分解,得到一堆奇质数。根据 CRT,如果在每个数下的解的组数分别为
t1,t2,⋯,则最后的答案为
i∏ti。
下面只讨论
p 为奇质数的情况(事实上,
p=2 也可以用组合意义分讨出来)。出于自己的习惯,将题面中的
x 记为
r。
写成式子,答案为:
a=0∑p−1b=0∑p−1[a2+b2≡r(modp)]
考虑单位根反演,得到:
a=0∑p−1b=0∑p−1p1i=0∑p−1(ωpa2+b2−r)i
整理得:
p1i=0∑p−1ωp−ri(x=0∑p−1ωpx2i)2
里面的式子比较碍事,考虑化简。枚举数变成枚举平方,一个数
t=x2 对应
1+(pt) 个
x,因此有:
x=0∑p−1ωpx2i=t=0∑p−1(1+(pt))ωpti
带回原式得到答案为:
p1i=0∑p−1ωp−ri(t=0∑p−1(1+(pt))ωpti)2
拆开括号:
p1i=0∑p−1ωp−ri(t=0∑p−1(ωpi)t+t=0∑p−1(pt)ωpti)2
我们发现
t=0∑p−1(ωpi)t 是单位根反演的形式,可以还原成
p×[p∣i]。这一项仅在
i=0 时有值,把它单拎出来,得到上式等于:
p1(p+t=0∑p−1(pt))2+p1i=1∑p−1ωp−ri(t=0∑p−1(pt)ωpti)2
再根据除了
0 以外,
p 的二次剩余与二次非剩余均有
2p−1 个,得到
t=0∑p−1(pt)=0,因此上式等于:
p+p1i=1∑p−1ωp−ri(t=0∑p−1(pt)ωpti)2
现在上式满足
(p,i)=1,因此
(pi)2=1 且
it 与
t 共同遍历模
p 完系。因此我们有:
t=0∑p−1(pt)ωpti=t=0∑p−1(pt)(pi)2ωpti=(pi)t=0∑p−1(pit)ωpit=(pi)t=0∑p−1(pt)ωpt
上面用到了勒让德符号的完全积性。带回原式得到答案为:
p+p1i=1∑p−1ωp−ri((pi)t=0∑p−1(pt)ωpt)2
又有
(pi)2=1,得到:
p+p1(i=1∑p−1ωp−ri)(t=0∑p−1(pt)ωpt)2
r=0 是否成立会影响第一个括号的值。分两类。当
r=0 时:
p+pp−1(t=0∑p−1(pt)ωpt)2
否则考虑给第一个括号补上
i=0,逆用单位根反演再展开,得到
r=0 时的结果:
p−p1(t=0∑p−1(pt)ωpt)2
上面的式子均有共同的一项:
t=0∑p−1(pt)ωpt。问题转化为求该式。
回到一开始的
t=1∑p−1ωpt2,其等于:
t=1∑p−1(1+(pt))ωpt
去括号:
t=1∑p−1ωpt+t=1∑p−1(pt)ωpt
代入
t=1∑p−1ωpt=0,第二个循环提前到从
0 开始,得到:
t=0∑p−1(pt)ωpt
因此
t=0∑p−1(pt)ωpt=t=1∑p−1ωpt2。
等式右边被称为二次高斯和。根据二次互反律证明的相关知识,我们有:
\sqrt p & p\equiv 1\pmod 4\\
i\sqrt p & p\equiv 3\pmod 4
\end{matrix}\right.$$
其中 $i$ 为虚数单位。读者自证不难,具体思路大概为因式分解其平方后确定正负号。
带回到原式得到答案 $A$。
当 $r=0$ 时,$A=\left\{\begin{matrix}
2p-1 & p\equiv 1\pmod 4\\
1 & p\equiv 3\pmod 4\end{matrix}\right.$。
当 $r\neq0$ 时,$A=\left\{\begin{matrix}
p-1 & p\equiv 1\pmod 4\\
p+1 & p\equiv 3\pmod 4\end{matrix}\right.$。
解决!
### 代码
```cpp
#include<bits/stdc++.h>
#define F(i,a,b) for(int i(a),i##i##end(b);i<=i##i##end;++i)
#define R(i,a,b) for(int i(a),i##i##end(b);i>=i##i##end;--i)
#define ll long long
using namespace std;
int n,p,r;
int v[10000001];
vector<int>prime;
inline void euler(){
F(i,2,10000000){
if(!v[i]) v[i]=i,prime.push_back(i);
for(int j:prime){
ll t(1ll*i*j);
if(t>=10000001) break;
v[t]=j;
if(v[i]==j) break;
}
}
return;
}
int main(){
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin>>n;
for(euler();n;--n){
cin>>p>>r;
ll ans(1);
while(p!=1){
int t=v[p];
p/=t;
if(r%t==0) ans=ans*((t&3)==1?(t<<1)-1:1);
else ans=ans*((t&3)==1?t-1:t+1);
}
cout<<ans<<"\n";
}
return 0;
}
```