当前位置: 代码迷 >> 综合 >> 【HYSBZ】【带权二分】【斜率优化】4518 征途
  详细解决方案

【HYSBZ】【带权二分】【斜率优化】4518 征途

热度:8   发布时间:2023-11-21 06:44:48.0

HYSBZ 4518 征途

◇题目传送门◆

题目大意

给定NNN个数的序列AAA,要求将这个序列划分成MMM个连续的段,求这MMM段的方差乘上M2M^2M2的值。

分析

设最终划分的MMM段的每段的总和为xix_ixi?

先来推一发式子:M2S2=M2×1M∑i=1M(xi?xˉ)2=M×∑i=1M(xi2?2xˉxi+xˉ2)=M×(∑i=1Mxi2?2xˉ∑i=1Mxi+Mxˉ2)=∑i=1Mxi2?(∑i=1Mxi)2\begin{aligned}M^2S^2&=M^2\times \frac{1}{M}\sum_{i=1}^{M}{(x_i-\bar{x})^2}\\&=M\times\sum_{i=1}^{M}({x_i}^2-2\bar{x}x_i+\bar{x}^2)\\&=M\times(\sum_{i=1}^{M}{x_i}^2-2\bar{x}\sum_{i=1}^{M}x_i+M\bar{x}^2)\\&=\sum_{i=1}^M{x_i}^2-\left(\sum_{i=1}^Mx_i\right)^2\end{aligned}M2S2?=M2×M1?i=1M?(xi??xˉ)2=M×i=1M?(xi?2?2xˉxi?+xˉ2)=M×(i=1M?xi?2?2xˉi=1M?xi?+Mxˉ2)=i=1M?xi?2?(i=1M?xi?)2?

不难发现后面的那个是定值,我们只需要最小化前面的那个就可以了。

然而恰好选出MMM段有点麻烦。而我们发现MMM越大,答案会变小。感性理解一下就是个下凸函数,所以我们可以考虑带权二分。

sssxxx的前缀和。

我们将每一段的代价都加上一个值vvv,则状态转移方程就变为f(i)=min?j=1i{f(j)+(sj?si)2+v}f(i)=\min_{j=1}^{i}\{f(j)+(s_j-s_i)^2+v\}f(i)=minj=1i?{ f(j)+(sj??si?)2+v}

看到这个转移方程我们不难发现可以进行斜率优化:

对于决策点k,jk,jk,j,若kkk优于jjj,即f(k)+(si?sk)2≤f(j)+(si?sj)2f(k)+(s_i-s_k)^2\le f(j)+(s_i-s_j)^2f(k)+(si??sk?)2f(j)+(si??sj?)2,那么拆掉平方项并移项,得到2si≥f(k)?f(j)+sk2?sj2sk?sj2s_i\ge \frac{f(k)-f(j)+{s_k}^2-{s_j}^2}{s_k-s_j}2si?sk??sj?f(k)?f(j)+sk?2?sj?2?

发现这个斜率的横坐标是单调递增的,所以直接单调队列维护即可。

参考代码

#include <queue>
#include <cstdio>
#include <algorithm>
using namespace std;typedef long long ll;
const int Maxn = 3000;int N, M;
int A[Maxn + 5];
ll sum[Maxn + 5], f[Maxn + 5];
int g[Maxn + 5];inline double slope(int x, int y) {
    return 1.0 * (f[y] - f[x] + sum[y] * sum[y] - sum[x] * sum[x]) / (sum[y] - sum[x]);
}
bool check(ll val) {
    static int q[Maxn + 5], head, tail;q[head = tail = 0] = 0;for(int i = 1; i <= N; i++) {
    while(head < tail && slope(q[head], q[head + 1]) < 2 * sum[i])head++;int j = q[head];f[i] = f[j] + val + (sum[i] - sum[j]) * (sum[i] - sum[j]);g[i] = g[j] + 1;while(head < tail && slope(q[tail - 1], q[tail]) > slope(q[tail], i))tail--;q[++tail] = i;}return g[N] > M;
}int main() {
    
#ifdef LOACLfreopen("in.txt", "r", stdin);freopen("out.txt", "w", stdout);
#endifscanf("%d %d", &N, &M);for(int i = 1; i <= N; i++)scanf("%d", &A[i]);for(int i = 1; i <= N; i++)sum[i] = sum[i - 1] + A[i];ll lb = 0, ub = sum[N] * sum[N];ll ans = 0;while(lb <= ub) {
    int mid = (lb + ub) >> 1;if(check(mid)) lb = mid + 1;else ub = mid - 1, ans = M * (f[N] - mid * M) - sum[N] * sum[N];}printf("%lld\n", ans);return 0;
}