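借此写一下二维树状数组。 1. 单点修改,区间查询 click here 区间查询:getsum(x,y) 返回的是从 (1,1) 到 (x,y) 的二维前缀和(a[i][j] 只是树状数组的内部结点,并不是前缀和本身),因此查询子矩阵 (x1,y1)–(x2,y2) 时要用容斥原理减去多计算的部分,再加回被重复减掉的左上角:getsum(x2,y2) − getsum(x1−1,y2) − getsum(x2,y1−1) + getsum(x1−1,y1−1)。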
// 2D Fenwick tree (binary indexed tree): point update, rectangle-sum query.
//
// Input: "n m", then operations until end of input (coordinates are 1-based):
//   1 x y k            add k to cell (x, y)
//   <other> a b c d    print the sum of the cells in rows a..c, columns b..d
#include <stdio.h>
#include <string.h>
#include <algorithm>

typedef long long ll;

const int maxn = 4100;
int n, m;

// Fenwick nodes: a[i][j] holds the sum of the cells in rows (i-lowbit(i), i]
// and columns (j-lowbit(j), j]. Being a global, it starts zero-initialised,
// so no explicit memset is needed.
ll a[maxn][maxn];

// Lowest set bit of x (0 when x == 0).
int lowbit(int x)
{
    return x & (-x);
}

// Adds val to cell (x, y). Coordinates outside the grid are ignored.
void update(int x, int y, ll val)
{
    // lowbit(0) == 0 would make the loops below spin forever, and negative
    // indices would write outside the table.
    if (x <= 0 || y <= 0)
        return;
    for (int i = x; i <= n; i += lowbit(i))
        for (int j = y; j <= m; j += lowbit(j))
            a[i][j] += val;
}

// Returns the sum of all cells in rows 1..x and columns 1..y
// (0 when x <= 0 or y <= 0).
ll getsum(int x, int y)
{
    // Cells beyond the grid do not exist; clamping keeps the walk inside the
    // nodes that update() actually maintains.
    x = std::min(x, n);
    y = std::min(y, m);

    ll res = 0;
    for (int i = x; i > 0; i -= lowbit(i))
        for (int j = y; j > 0; j -= lowbit(j))
            res += a[i][j];
    return res;
}

int main()
{
    // Refuse sizes the static table cannot hold instead of indexing past it.
    if (scanf("%d%d", &n, &m) != 2 || n >= maxn || m >= maxn)
        return 1;

    int t;
    // Compare against the number of converted items: a bare ~scanf() test only
    // stops on EOF and loops forever when a token fails to parse.
    while (scanf("%d", &t) == 1) {
        int x1, y1, x2, y2;
        ll val;
        if (t == 1) {
            if (scanf("%d%d%lld", &x1, &y1, &val) != 3)
                break;
            update(x1, y1, val);
        } else {
            if (scanf("%d%d%d%d", &x1, &y1, &x2, &y2) != 4)
                break;
            // Inclusion-exclusion over four prefix sums: subtract the strips
            // above and to the left, then add back the corner removed twice.
            printf("%lld\n", getsum(x2, y2) - getsum(x1 - 1, y2)
                             - getsum(x2, y1 - 1) + getsum(x1 - 1, y1 - 1));
        }
    }
    return 0;
}

// 2. Range update, point query (click here).
// Range update works on the same principle as the range query above.
// 2D Fenwick tree over a difference array: rectangle update, point query.
//
// Input: "n m", then operations until end of input (coordinates are 1-based):
//   1 a b c d k    add k to every cell in rows a..c, columns b..d
//   <other> x y    print the current value of cell (x, y)
#include <stdio.h>
#include <string.h>
#include <algorithm>

typedef long long ll;

const int maxn = 4100;
int n, m;

// Fenwick tree over the 2D difference array; the prefix sum up to (x, y) is
// the value of cell (x, y). Being a global, it starts zero-initialised, so no
// explicit memset is needed.
ll a[maxn][maxn];

// Lowest set bit of x (0 when x == 0).
int lowbit(int x)
{
    return x & (-x);
}

// Adds val to difference-array entry (x1, y1). Coordinates outside the grid
// are ignored, which is what lets callers pass x2+1 / y2+1 unconditionally.
void update(int x1, int y1, ll val)
{
    // lowbit(0) == 0 would make the loops below spin forever, and negative
    // indices would write outside the table.
    if (x1 <= 0 || y1 <= 0)
        return;
    for (int x = x1; x <= n; x += lowbit(x))
        for (int y = y1; y <= m; y += lowbit(y))
            a[x][y] += val;
}

// Returns the value of cell (x, y): the prefix sum of the difference array
// over rows 1..x and columns 1..y.
ll getsum(int x, int y)
{
    // No cell is stored outside the n x m grid; report 0 instead of walking
    // nodes that update() never maintains (or indexing past the table).
    if (x > n || y > m)
        return 0;

    ll res = 0;
    for (int i = x; i > 0; i -= lowbit(i))
        for (int j = y; j > 0; j -= lowbit(j))
            res += a[i][j];
    return res;
}

int main()
{
    // Refuse sizes the static table cannot hold instead of indexing past it.
    if (scanf("%d%d", &n, &m) != 2 || n >= maxn || m >= maxn)
        return 1;

    int t;
    // Compare against the number of converted items: a bare ~scanf() test only
    // stops on EOF and loops forever when a token fails to parse.
    while (scanf("%d", &t) == 1) {
        int x1, y1, x2, y2;
        ll val;
        if (t == 1) {
            if (scanf("%d%d%d%d%lld", &x1, &y1, &x2, &y2, &val) != 5)
                break;
            // Four-corner difference update: +val opens the rectangle at its
            // top-left corner, the two -val entries close it past the right
            // and bottom edges, and the final +val undoes the double
            // cancellation past the bottom-right corner.
            update(x1, y1, val);
            update(x1, y2 + 1, -val);
            update(x2 + 1, y1, -val);
            update(x2 + 1, y2 + 1, val);
        } else {
            if (scanf("%d%d", &x1, &y1) != 2)
                break;
            printf("%lld\n", getsum(x1, y1));
        }
    }
    return 0;
}

// 3. Range update, range query (click here).
// 2D Fenwick trees over a difference array: rectangle update, rectangle-sum
// query.
//
// Input: "n m", then operations until end of input (coordinates are 1-based):
//   1 a b c d k        add k to every cell in rows a..c, columns b..d
//   <other> a b c d    print the sum of the cells in rows a..c, columns b..d
//
// With d the 2D difference array, the prefix sum over rows 1..x, columns 1..y
// is
//   sum over p<=x, q<=y of d[p][q] * (x-p+1) * (y-q+1)
//     = (x+1)(y+1)*S(d) - (y+1)*S(d*p) - (x+1)*S(d*q) + S(d*p*q),
// so four trees keep prefix sums of d, d*p, d*q and d*p*q.
#include <stdio.h>
#include <string.h>
#include <algorithm>

typedef long long ll;

// NOTE(review): the four tables below reserve roughly 538 MB of static
// storage at maxn = 4100; lower maxn if the real bound on n, m is smaller.
const int maxn = 4100;

// s1: d, s2: d*row, s3: d*column, s4: d*row*column (see the header comment).
// Globals start zero-initialised.
ll s1[maxn][maxn], s2[maxn][maxn], s3[maxn][maxn], s4[maxn][maxn];
int n, m;

// Lowest set bit of x (0 when x == 0).
int lowbit(int x)
{
    return x & (-x);
}

// Adds val to difference-array entry (x, y) in all four trees. Coordinates
// outside the grid are ignored, which is what lets callers pass x2+1 / y2+1
// unconditionally.
void update(int x, int y, ll val)
{
    // lowbit(0) == 0 would make the loops below spin forever, and negative
    // indices would write outside the tables.
    if (x <= 0 || y <= 0)
        return;

    // Weighted deltas are computed once, in 64 bits: val * x * y does not fit
    // in int for large val or large grids.
    const ll byRow = val * x;
    const ll byCol = val * y;
    const ll byBoth = byRow * y;

    for (int i = x; i <= n; i += lowbit(i))
        for (int j = y; j <= m; j += lowbit(j)) {
            s1[i][j] += val;
            s2[i][j] += byRow;
            s3[i][j] += byCol;
            s4[i][j] += byBoth;
        }
}

// Returns the sum of all cells in rows 1..x and columns 1..y
// (0 when x <= 0 or y <= 0).
ll getsum(int x, int y)
{
    // Cells beyond the grid do not exist; clamping keeps the walk inside the
    // nodes that update() actually maintains.
    x = std::min(x, n);
    y = std::min(y, m);

    ll d = 0, dRow = 0, dCol = 0, dBoth = 0;
    for (int i = x; i > 0; i -= lowbit(i))
        for (int j = y; j > 0; j -= lowbit(j)) {
            d += s1[i][j];
            dRow += s2[i][j];
            dCol += s3[i][j];
            dBoth += s4[i][j];
        }
    return d * (x + 1) * (y + 1) - dRow * (y + 1) - dCol * (x + 1) + dBoth;
}

int main()
{
    // Refuse sizes the static tables cannot hold instead of indexing past them.
    if (scanf("%d%d", &n, &m) != 2 || n >= maxn || m >= maxn)
        return 1;

    int t;
    // Compare against the number of converted items: a bare ~scanf() test only
    // stops on EOF and loops forever when a token fails to parse.
    while (scanf("%d", &t) == 1) {
        int x1, y1, x2, y2;
        if (t == 1) {
            ll val;
            if (scanf("%d%d%d%d%lld", &x1, &y1, &x2, &y2, &val) != 5)
                break;
            // Four-corner difference update of the rectangle.
            update(x1, y1, val);
            update(x1, y2 + 1, -val);
            update(x2 + 1, y1, -val);
            update(x2 + 1, y2 + 1, val);
        } else {
            // The conversion specifiers must match the variable types:
            // reading a long long through %d is undefined behaviour.
            if (scanf("%d%d%d%d", &x1, &y1, &x2, &y2) != 4)
                break;
            // Inclusion-exclusion over four prefix sums.
            ll res = getsum(x2, y2) - getsum(x1 - 1, y2)
                     - getsum(x2, y1 - 1) + getsum(x1 - 1, y1 - 1);
            printf("%lld\n", res);
        }
    }
    return 0;
}