这道题的出题人非常良心,只要是时间复杂度在
O(n3) 以内的算法就能通过,然而事实上这道题可以做到
O(1),本篇题解是
O(1) 算法。
解题思路
1. 特殊情况
我们特殊处理
m=1 的情况,因为此时中心点不是交点之一。共有
2n 个点分布在同一条直线上,除第
n 个点和第
n+1 的点间隔为
2 以外,相邻的点间隔均为
1。不妨视那条特殊的线段的长度也为
1,最后再加上路径包含这条线段的点对数量。所以,我们可以很容易地推断出答案为:
n2+i=1∑2n−1i(2n−i)
经过化简,可以得到如下的式子:
n2−3n(2n−1)(4n−1)+2n2(2n−1)
2. 一般情况
剩下的情况中,我们视
m 条直线为
2m 条端点为中心点的射线。另外,由于路径只经过圆弧和长度为整数的线段,不妨设答案为
pπ+q。
特殊处理两点在同一射线上(包括其中一点在中心点上)的情况,不难得到此部分答案
q1 为:
q1=2mi=1∑ni(n−i+1)=2m(2n(n+1)2−6n(n+1)(2n+1))=mn(n+1)2−3mn(n+1)(2n+1)
然后,考虑将其余情况分为两点在同圆上的和两点不在同圆上的。
对于
两点在同圆上的情况,可以证明最优路径一定是
两点之间的劣弧和
中心点与两点之间的两条线段这两种情况之一。如果两点之间的劣弧隔了
x−1 个点,或者说
x 段长度为
mπr 的弧,那么两种情况的代价分别是
mxπr 和
2r。不难发现哪种更优仅与
x 有关。由此,设
k 表示当
x≥k 时后者更优,否则前者更优。此处的
k 可以在
O(1) 的时间复杂度内求得其值为
⌈π2m⌉。
考虑在半径为
r 的圆上的情况,两两点对之间最优路径的长度和为
arπ+br,其中:
ar=21×2m×2i=1∑k−1mir=kr(k−1)br=21×2m×(2m−2k+1)×2r=2mr(2m−2k+1)
考虑两点不在同圆上的情况,可以证明先让处于更外侧的圆上的点向内走到另一个点所在的圆上,然后按照两点在同圆上的方法做一定是最优的。
因此,枚举两点所在的圆,一定有
2n−2r+1 种情况使得最后按照两点都在半径为
r 的圆上的情况做,那么忽略两点不在同圆的情况中外侧圆上的点走向内侧圆的代价,总代价为
pπ+q2,其中:
p=i=1∑ni(2n−2i+1)ai=i=1∑nik(2n−2i+1)(k−1)=ki=1∑n(2−2k)i2+(2n+1)(k−1)i=3nk(n+1)(2n+1)(1−k)+2nk(n+1)(2n+1)(k−1)=6nk(n+1)(2n+1)(k−1)q2=i=1∑ni(2n−2i+1)bi=i=1∑n2im(2n−2i+1)(2m−2k+1)=2mi=1∑n(4k−4m−2)i2+(2m−2k+1)(2n+1)i=2m(3n(n+1)(2n+1)(2k−2m−1)+2n(n+1)(2n+1)(2m−2k+1))=3mn(n+1)(2n+1)(2m−2k+1)
再考虑两点不在同圆的情况中外侧圆上的点走向内侧圆的代价
q3,考虑枚举内外两圆,任取不在同一射线上的点对,故其为:
q3=21×2m×(2m−1)×i=1∑nj=i+1∑n(j−i)=m(2m−1)i=1∑n(n−i)(n−i+1)=m(2m−1)i=0∑n−1i(i+1)=m(2m−1)i=1∑ni(i−1)=3m(2m−1)(n(n+1)(2n+1)−3n(n+1))=3mn(2m−1)(n+1)(n−1)
综上所述,答案
pπ+q 至此已经可以在
O(1) 的时间复杂度内计算完毕。其中
q 可以由前面的
q1+q2+q3 得到,具体为:
q=q1+q2+q3=31mn(n+1)(3(n+1)−(2n+1)+(2n+1)(2m−2k+1)+(n−1)(2m−1))=31mn(n+1)(3(n+1)+2(2n+1)(m−k)+(n−1)(2m−1))
参考代码
CPP#define _USE_MATH_DEFINES
#include <bits/stdc++.h>
using namespace std;
int main() {
long long n, m;
scanf("%lld %lld", &n, &m);
if (m == 1) {
printf("%lld\n", n * n - n * (2 * n - 1) * (4 * n - 1) / 3 + 2 * n * n * (2 * n - 1));
return 0;
}
long long k = ceil(2 * m / M_PI);
long long p = n * (n + 1) * (2 * n + 1) / 6 * k * (k - 1);
long long q = m * n * (n + 1) * (3 * (n + 1) + 2 * (2 * n + 1) * (m - k) + (n - 1) * (2 * m - 1)) / 3;
printf("%.10lf\n", p * M_PI + q);
return 0;
}