当前位置: 代码迷 >> 综合 >> POJ 3468 A Simple Problem with Integers (树状和差分,分块,线段树延迟标记,详细的解答)
  详细解决方案

POJ 3468 A Simple Problem with Integers (树状和差分,分块,线段树延迟标记,详细的解答)

热度:45   发布时间:2023-12-13 19:53:15.0

题目意思:
Q:求L到R的和。C:L,R的数都加x

树状数组版:

看了网上很多代码,综合以下的解题步骤:

本题用到的知识点:树状数组,差分
简单介绍差分:
假设原数组为 a[1] ~ a[n]
用数组b[1] ~ b[n] 作为 a数组的差分数组
b[1] = a[1]; //第一项特殊些
b[2] = a[2] - a[1];

b[n] = a[n] - a[n - 1];

那么b的前缀和(加入前k项),刚好为数组a的第k项的值,a[k]
sum{b[1], b[2], … , b[k]} = b[1] + b[2] + … + b[k] = a[k] //简单代入即可证明

如果在数组a区间 [L, R] 的范围内,每个位置加上值x, 对于差分数组b,发生变化的只有两项,
b[L] += x, b[R + 1] -= x // b[k] = a[k] - a[k - 1]; 2 <= k <= n, 取遍每一个k,很容易得出结论

思路:
1、使用差分分解:
现在计算a数组的前k项 的和,
sum{a[1], a[2], …, a[k]} = a[1] + a[2] + … + a[k] //依次展开每个 a[i] (1 <= i <= k)
= (b[1]) + (b[1] + b[2]) + (b[1] + b[2] + b[3]) + … + (b[1] + b[2] + … + b[k]) //有k个小括号
= k * (b[1] + b[2] + … + b[k]) - 0 * b[1] - 1 * b[2] - 2 * b[3] - … - (k - 1) * b[k]
//每个括号加上一些数, 都变成 (b[1] + b[2] + … + b[k]), 然后减去那些数(一共加上了 0 项的 b[1], 1 项的 b[2], … , 依次类推)
= k * sum{b[1], b[2], … , b[k]} - sum{0 * b[1], 1 * b[2], 2 * b[3], … , (k - 1) * b[k]}
//把数组a的前k 项和 sum{a[1], a[2], …, a[k]} 分解为两部分了
// 第一部分 : 差分数组b的前k项之和 sum{b[1], b[2], … , b[k]} 的 k倍
// 第二部分 : 序列 (i - 1) * b[i] 的前k项的和 sum{0 * b[1], 1 * b[2], 2 * b[3], … , (k - 1) * b[k]}

2、 使用树状数组分析:
第一部分的和 sum{b[1], b[2], … , b[k]} ( 简化记为:sum(c[1], k) ),用 树状数组 c[1] 来记录;
第二部分的和sum{0 * b[1], 1 * b[2], 2 * b[3], … , (k - 1) * b[k]}(简化记为:sum(c[0], k)), 用 树状数组 c[0] 来记录

sum{a[1], a[2], ..., a[k]} = k * sum(c[1], k) - sum(c[0], k)那么原始数组a, [L, R]区间上同时加上x就可以表示为:(差分数组的结论)
对于c[1]来说,在L位置上加上x,在R + 1位置上加上?x; 
对于c[0]来说,在L位置上加上?x * (L ? 1),在R+1位置上加上x * R; 	
//用 i = L 和 i = R + 1 代入式子 (i - 1) * b[i]得到,因为是负号,取相反数

3、 查询结果时候,求 前k项的和公式为
long long sumk = k * getSum(k, 1) + getSum(k, 0);

4、 最后,使用long long

#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
#define lowbit(i) ((i)&(-i)) // lowbit(x) 表示x的二进制对应的值
const int MaxN = 100010;
long long c[2][MaxN];			//树状数组c[1] 和 c[0]
int n, q;
int A[MaxN];//getSum()函数, 返回前x个数的和
long long getSum(int x, int index)
{
    long long sum = 0;for(int i = x; i > 0; i-= lowbit(i)){
    sum += c[index][i];}return sum;
}//update()函数,将第x个数加上v
void update(int x, int v, int index)
{
    for(int i = x; i < MaxN; i += lowbit(i)){
    c[index][i] += v;}
}int main()
{
    char cmd[2] = {
    0};scanf("%d%d", &n, &q);memset(c, 0, sizeof(c));for(int i = 1; i <= n; ++i){
    scanf("%d", &A[i]);	update(i, A[i], 0);	// 数组 c[0] 进行初始化}int s, e, x;for(int i = 0; i < q; ++i){
    scanf("%s", cmd);	if(cmd[0] == 'Q'){
    scanf("%d%d", &s, &e);long long suml = (s - 1) * getSum(s - 1, 1) + getSum(s - 1, 0);long long sumr = e * getSum(e, 1) + getSum(e, 0);printf("%lld\n", sumr - suml);}else{
    scanf("%d%d%d", &s, &e, &x);	//对于c[0]来说,在L位置上加上?x * (L ? 1),在R+1位置上加上x * Rupdate(s, -x * (s - 1), 0);update(e + 1, x * e, 0);//对于c[1]来说,在L位置上加上x,在R + 1位置上加上?xupdate(s, x, 1);update(e + 1, -x, 1);}}return 0;
}/* 10 5 1 2 3 4 5 6 7 8 9 10 Q 4 4 Q 1 10 Q 2 4 C 3 6 3 Q 2 4 *//* 4 55 9 15 */

分块的做法:

1、长度为n的区间,分成若干个长度不超过 sqrt(n) 的段,sum[i] 表示第i段的区间的和,add[i]表示第i段的增量标记
第 i段的左端点 (i- 1) * sqrt(n) + 1, 右端点 min(n, i * sqrt(n))
2、对于 “C l r d”, 增加指令, 对于完整的段, 标记该区间 的增量 add[i] 增加了多少,
对于不完整的段,朴素的做法,每个值都加上 d, 同时改变 sum[i], 也就是 sum[i] += d * 区间的大小
3、对于 “Q l r”, 查询指令,同样分为完整段和不完整段处理

#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
const int MaxN = 100010;
long long a[MaxN], sum[MaxN], add[MaxN];
int L[MaxN], R[MaxN];	//每段的左右端点
int pos[MaxN];			// 每个位置属于那一段
int n, m, t;void change(int l, int r, long long d)
{
    int p = pos[l], q = pos[r];if(p == q){
    for(int i = l; i <= r; ++i){
    a[i] += d;}sum[p] += d * (r - l + 1);}else{
    for(int i = p + 1; i <= q - 1; ++i){
    add[i] += d;	}for(int i = l; i <= R[p]; ++i){
    a[i] += d;}sum[p] += d * (R[p] - l + 1);for(int i = L[q]; i <= r; ++i){
    a[i] += d;}sum[q] += d * (r - L[q] + 1);}
}long long ask(int l, int r)
{
    int p = pos[l], q = pos[r];long long ans = 0;if(p == q){
    for(int i = l; i <= r; ++i){
    ans += a[i];}ans += add[p] * (r - l + 1);}else{
    for(int i = p + 1; i <= q - 1; ++i){
    ans += sum[i] + add[i] * (R[i] - L[i] + 1);}for(int i = l; i <= R[p]; ++i){
    ans += a[i];}ans += add[p] * (R[p] - l + 1);for(int i = L[q]; i <= r; ++i){
    ans += a[i];}ans += add[q] * (r - L[q] + 1);}return ans;
}int main()
{
    scanf("%d%d", &n, &m);for(int i = 1; i <= n; ++i){
    scanf("%lld", &a[i]);}t = sqrt(n);for(int i = 1; i <= t; ++i){
    L[i] = (i - 1) * sqrt(n) + 1;R[i] = i * sqrt(n);}if(R[t] < n){
    ++t, L[t] = R[t - 1] + 1, R[t] = n;}//预处理for(int i = 1; i <= t; ++i){
    for(int j = L[i]; j <= R[i]; ++j){
    pos[j] = i;sum[i] += a[j];}}//指令while(m--){
    char op[3];int l, r, d;scanf("%s%d%d", op, &l, &r);if('C' == op[0]){
    scanf("%d", &d);change(l, r, d);}else{
    printf("%lld\n", ask(l, r));}}return 0;
}

线段树做法:延迟标记

1、 在区间修改的时候,如果某个节点的区间被 修改的区间完全覆盖(if(l <= tree[p].l && r >= tree[p].r) //完全覆盖), 以该节点为根的所有子树的存储信息都发生变化。如果后面的查询信息没有查询到这些相关的节点,说明目前这些节点的信息不需要更新。
也就是说,遇到完全覆盖的情况,直接返回,不过需要给这个节点打上延迟标记 add, 这个标记表示,该节点区间内的所有数加上 add 。
2、后续指令中,如果需要从节点p向下递归,再检查p是否有延迟标记。
因此建立线段树节点的时候,信息如下

struct segTree
{
    int l, r;long long sum, add;
}tree[MaxN * 4];
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
const int MaxN = 100010;
int a[MaxN];
int n, m;struct segTree
{
    int l, r;long long sum, add;
}tree[MaxN * 4];void build(int p, int l, int r)
{
    tree[p].l = l, tree[p].r = r;if(l == r){
    tree[p].sum = a[l];return;}int mid = (l + r) / 2;build(p * 2, l, mid);build(p * 2 + 1, mid + 1, r);tree[p].sum = tree[p * 2].sum + tree[p * 2 + 1].sum;
}void spread(int p)		//从节点p向下递归
{
    if(tree[p].add)	//节点p有更新标记{
    tree[p * 2].sum += tree[p].add * (tree[p * 2].r - tree[p * 2].l + 1);	//更新左孩子节点tree[p * 2 + 1].sum += tree[p].add * (tree[p * 2 + 1].r - tree[p * 2 + 1].l + 1);	//更新右孩子节点tree[p * 2].add += tree[p].add;		//给左孩子打上延迟标记tree[p * 2 + 1].add += tree[p].add;	//给右孩子打上延迟标记tree[p].add = 0;					//清除p的标记}
}void change(int p, int l, int r, int d)
{
    if(l <= tree[p].l && r >= tree[p].r)	//完全覆盖{
    tree[p].sum += (long long)d * (tree[p].r - tree[p].l + 1);	//更新节点信息tree[p].add += d;	//给节点打上延迟标记return;}spread(p);int mid = (tree[p].l + tree[p].r) / 2;if(l <= mid){
    change(p * 2, l, r, d);	}if(r > mid){
    change(p * 2 + 1, l, r, d);}tree[p].sum = tree[2 * p].sum + tree[p * 2 + 1].sum;
}long long ask(int p, int l, int r)
{
    if(l <= tree[p].l && r >= tree[p].r){
    return tree[p].sum;}spread(p);int mid = (tree[p].l + tree[p].r) / 2;long long val = 0;if(l <= mid){
    val += ask(p * 2, l, r);	}if(r > mid){
    val += ask(p * 2 + 1, l, r);}return val;
}int main()
{
    scanf("%d%d", &n, &m);for(int i = 1; i <= n; ++i){
    scanf("%d", &a[i]);}build(1, 1, n);while(m--){
    char op[3];int l, r, d;scanf("%s%d%d", op, &l, &r);if('C' == op[0]){
    scanf("%d", &d);change(1, l, r, d);}else{
    printf("%lld\n", ask(1, l, r));}}return 0;
}
  相关解决方案