https://www.luogu.com.cn/problem/P3368
学习树状数组:https://www.youtube.com/watch?v=v_wj_mOAlig。
首先简单介绍一下树状数组,以及为什么要使用树状数组。
假设有一个数组,我们需要对该数组执行如下操作(以下全部假设数组下标从1开始):
求前i项的和。将第i项元素加上某个给定的值。如果使用普通的数组,执行第一个操作的时间为 O ( n ) O(n) O(n),执行第二个操作的时间为 O ( 1 ) O(1) O(1)。
如果用数组保存的内容为前缀和,执行第一个操作的时间为 O ( 1 ) O(1) O(1),执行第二个操作的时间为 O ( n ) O(n) O(n)。
而使用树状数组,这两个操作的时间都为 O ( log 2 n ) O(\log_2n) O(log2n)。
关于树状数组的原理,请观看上面链接中的视频或寻找其他资料。
下面是C++的简单实现代码:
int n; int arr[500001]; inline int last_bit(int num) { return num & (-num); } // return the sum of [1, index] int sum(int index) { int res = 0; while (index) { res += arr[index]; index -= last_bit(index); } return res; } // plus value to the element of index of arr void add(int index, int value) { while (index <= n) { arr[index] += value; index += last_bit(index); } }其中 a r r arr arr的第一个元素(下标为 0 0 0)不使用, n n n为元素总数。
假设有一个数组 a r r 1 = [ 1 , 5 , 4 , 2 , 3 ] arr1=[1, 5, 4, 2, 3] arr1=[1,5,4,2,3],那么它的差分数组 a r r 2 = [ 1 , 4 , − 1 , − 2 , 1 ] arr2=[1, 4, -1, -2, 1] arr2=[1,4,−1,−2,1]。
即 a r r 2 [ i ] = a r r 1 [ i ] − a r r 1 [ i − 1 ] arr2[i] = arr1[i] - arr1[i - 1] arr2[i]=arr1[i]−arr1[i−1](其中 a r r [ 0 ] arr[0] arr[0]为 0 0 0)。
可以发现 a r r 1 [ i ] = a r r 2 [ 1 ] + a r r 2 [ 2 ] + a r r 2 [ 3 ] + ⋯ + a r r 2 [ i ] arr1[i] = arr2[1] + arr2[2] + arr2[3] + \dots + arr2[i] arr1[i]=arr2[1]+arr2[2]+arr2[3]+⋯+arr2[i]。例如 a r r 1 [ 2 ] = a r r 2 [ 1 ] + a r r 2 [ 2 ] = 1 + 4 = 5 arr1[2] = arr2[1] + arr2[2] = 1 + 4 = 5 arr1[2]=arr2[1]+arr2[2]=1+4=5。
即 a r r 1 [ i ] arr1[i] arr1[i]等于 a r r 2 arr2 arr2的前 i i i项和。
假设将区间 [ 2 , 4 ] [2, 4] [2,4]内所有数加 2 2 2,那么 a r r 1 arr1 arr1变为: [ 1 , 7 , 6 , 4 , 3 ] [1, 7, 6, 4, 3] [1,7,6,4,3],而它的差分数组 a r r 2 arr2 arr2为: [ 1 , 6 , − 1 , − 2 , − 1 ] [1, 6, -1, -2, -1] [1,6,−1,−2,−1]。
可以发现 a r r 2 arr2 arr2只有第 2 2 2个和第 5 5 5个元素的值改变了。
即 a r r [ 2 ] = a r r [ 2 ] + 2 arr[2] = arr[2] + 2 arr[2]=arr[2]+2, a r r [ 4 + 1 ] = a r r [ 4 + 1 ] − 2 arr[4 + 1] = arr[4 + 1] - 2 arr[4+1]=arr[4+1]−2。
因此,我们只需要在树状数组中维护一个差分数组,就能解决这个问题。
下面是C++的代码:
#include <iostream> int n, m; int arr[500001]; inline int last_bit(int num) { return num & (-num); } // return the sum of [1, index] int sum(int index) { int res = 0; while (index) { res += arr[index]; index -= last_bit(index); } return res; } // plus value to the element of index of arr void add(int index, int value) { while (index <= n) { arr[index] += value; index += last_bit(index); } } int main() { std::ios_base::sync_with_stdio(false); std::cin >> n >> m; int value, last_value = 0; for (int i = 1; i <= n; i++) { std::cin >> value; add(i, value - last_value); last_value = value; } int op; int x, y, k; for (int i = 0; i != m; i++) { std::cin >> op; switch (op) { case 1: std::cin >> x >> y >> k; add(x, k); add(y + 1, -k); break; case 2: std::cin >> x; std::cout << sum(x) << "\n"; break; default: break; } } }