题目链接
Rikka with Sequence II
题解
这题单纯枚举中位数在枚举两边子集是n^2的,可是我们发现其实排序之后,如果枚举到的位置是(i,j),即中位数mid=(a[i]+a[j])/2;则a[i+1…j-1]都不用考虑,直接把两边拿出来做折半查找就OK。
时间复杂度分析:考虑规模为i(0<=i<=n-2)的折半查找的做的次数为i+1,做一次的时间复杂度为 O(2i2) 那么总时间复杂度的计算如下
后面括号里面,大概可以看成 n+n2+n4+n8+... ,也就2n的样子?所以时间复杂度分析出来是O(2^(n/2)*n)的,我也觉得非常神奇,不考虑中间的数就能去掉一个n的复杂度
然后抽象成另外一个问题,把剩下的数全部拿出来,对于每个数,设两个权值a,b,如果数小于枚举出来的中位数,则a权值为-1,大于为1,b为原来的权值减去中位数,则要求就是选出一个子集使得
怎么折半搜索?之前一直觉得需要一个map< int.set< int> >而且set还得兹磁rank操作,要不是这题带log不能过我估计我还真去写个treap(扶额),数据结构学傻了。实际上是这样的,先把所有枚举出来的子集按大小排序,然后维护2n个前缀和,表示在此之前有多少个是 ∑a==i 的集合,然后two pointer 一下应该就可以不要log了,复杂度 O(n?2n2) ,很棒啊。(所以如果再带一个排序或者二分查找的log就T了)
排序的时候需要用极其棒棒的类似归并的思想,因为全是正整数所以选了肯定不不选要的 ∑b 要大,所以相对顺序不会变,然后归并一下就排好序了,复杂度只要 O(2n2)
代码稍后补上来。
然而我代码被卡常了,本机5.8s不能再少了,我也很迷啊,这个复杂度算出来。。我自带大常数吧,等吉司机的std拿出来了我再看我常数哪里写丑了。
为了卡常我还把wys的wc课间重新看了一遍,发现没卵用。
而且这代码很迷啊,改成 ∑a 相等也能过样例,把返回的数组改成x也能过,好像自己出的n=40的数据也能过?好迷,静态差错才看出来。估计这题我是卡常卡不过去了,要不松爷您来帮我看下?(雾)vjudge上3days前的提交记录一片TLE,看了看代码发现和我其实是差不多的?(雾)
//QWsin
#include<set>
#include<cmath>
#include<vector>
#include<cstdio>
#include<cstring>
#include<assert.h>
#include<iostream>
#include<algorithm>
using namespace std;
typedef long long ll;
const double eps=1e-9;
const int maxn=40+10;
const int maxm=1050000+10;ll ans=0;
int a[maxn];
typedef pair<int,double> Node;
#define mp make_pair
#define a first
#define b second
//两种权值a=+-1,b=a[i]-中位数 int top,SUM[maxn],vc_sz,tot;
Node t1[maxm],t2[maxm],t3[maxm],t4[maxm],vc[maxn];inline void getsort(int l,int r,Node *t1,Node *t2,Node* &res,int &L)
{Node *x=t1,*y=t2;x[top=1]=mp(0,0);for(int i=l;i<=r;++i,swap(x,y)){int A=vc[i].a,B=vc[i].b;for(int j=1,k=top+1;j<=top;j+=4,k+=4) {x[k].a=x[j].a+A;x[k].b=x[j].b+B;x[k+1].a=x[j+1].a+A;x[k+1].b=x[j+1].b+B;x[k+2].a=x[j+2].a+A;x[k+2].b=x[j+2].b+B;x[k+3].a=x[j+3].a+A;x[k+3].b=x[j+3].b+B;}int p1=1,p2=top+1,p3=0;top<<=1;// if(x[top>>1] < x[(top>>1)+1]) {++tot;swap(x,y);continue;}while(p3<top)y[++p3]= p1>top>>1||(p2<=top&&x[p1].b > x[p2].b) ? x[p2++]:x[p1++];}L=top;res=x;
}int L1,L2;//对于vc里的Node计算sigma a=0 && sigma b <=0 的答案
inline ll work()
{int sz=vc_sz;if(sz<2) return 1;Node *a1,*a2;getsort(1,sz>>1,t1,t2,a1,L1);getsort((sz>>1)+1,sz,t3,t4,a2,L2);memset(SUM,0,sizeof SUM);#define sum(i) SUM[i+20]ll ret=0;int p1=1;//反着来for(int p2=L2;p2>=1;--p2){while(p1<=L1&&a1[p1].b+a2[p2].b <= eps)++sum(a1[p1++].a);ret+=sum(-a2[p2].a);}return ret;
}
#undef a
#undef b
set<int>vis;
int main()
{int n;cin>>n; for(int i=1;i<=n;++i) {scanf("%d",a+i);assert(!vis.count(a[i]));vis.insert(a[i]);}sort(a+1,a+n+1);for(int i=1;i<=n;++i)for(int j=i;j<=n;++j){double mid=(a[i]+a[j])/2;vc_sz=0;for(int k=1;k<i;++k)vc[++vc_sz]=mp(-1,a[k]-mid);for(int k=j+1;k<=n;++k)vc[++vc_sz]=mp(1,a[k]-mid);ans+=work();}cout<<ans;return 0;
}