树状数组与线段树经典问题的python实现

it2025-09-16  5


title: 树状数组与线段树经典问题的python实现 date: 2020-03-26 22:13:26 categories: 算法 tags: [python, 树状数组与线段树]

树状数组

作用:单点修改,区间求和

时间复杂度:修改和查询的复杂度都是O(logN)

要点:

l o w b i t ( x ) = x & ( − x ) = 2 k lowbit(x)=x\&(-x) =2^{k} lowbit(x)=x&(x)=2k k:x在二进制位下面末尾连续0的个数

原理:

利用的负数的存储特性,负数是以补码存储的,对于整数运算 x&(-x)有

1. 当x为0时,即 0 & 0,结果为0; 2. 当x为奇数时,最后一个比特位为1,取反加1没有进位,故x和-x除最后一位外**前面的位正好相反,且最后一位均为1**。结果为1。5&(-5)=(101)&(011)=1 3. 当x为偶数时,取反时,末尾连续k个0均为变成1,加1时,往前进一位,正好是2^k次方。

则定**C[x]=sum(x-lowbit(x),x](**sum(l,r]表示数组a区间(l,r]的区间和)

区间求和有C[x]定义很明了。

对于单点增加:在图上树的过程,假设增加a[i],我们只需要修改其父亲和祖宗节点,例如增加a[5],我们需要修改C[5],C[6],C[8],C[16],可以证明x的父亲节点有且仅有一个为x+lowbit(x)

动态求连续区间和

给定 n 个数组成的一个数列,规定有两种操作,一是修改某个元素,二是求子数列 [a,b][a,b] 的连续和。

输入格式

第一行包含两个整数n 和 m,分别表示数的个数和操作次数。

第二行包含 n 个整数,表示完整数列。

接下来 m 行,每行包含三个整数 k,k,a,b (k=0,表示求子数列[a,b]的和;k=1,表示第 a 个数加 b)。

数列从 11 开始计数。

输出格式

输出若干行数字,表示 k=0 时,对应的子数列 [a,b] 的连续和。

数据范围

1≤n≤100000 1≤m≤100000, 1≤a≤b≤n

输入样例:

10 5 1 2 3 4 5 6 7 8 9 10 1 1 5 0 1 3 0 4 8 1 7 5 0 4 8

输出样例:

11 30 35

挑战模式

n,m=map(int,input().split()) a=[0 for i in range(0,n+1)] tr=[0 for i in range(0,n+1)] def lowbit(x): return x&(-x) def query(x): res=0 while x: res+=tr[x] x-=lowbit(x) return res def add(x,val): while x<=n: tr[x]+=val x+=lowbit(x) a=list(map(int,input().split())) for i in range(0,n): add(i+1,a[i]) for t in range(0,m): k,l,r=map(int,input().split()) if k==0: print(query(r)-query(l-1)) else: add(l,r)

线段树

本质是二叉树,分治思想。

u的左儿子u<<1(2u)和右儿子u<<1|1(2u+1)

数列区间最大值

输入一串数字,给你 M 个询问,每次询问就给你两个数字 X,Y,要求你说出 X到 Y 这段区间内的最大数。

输入格式

第一行两个整数 N,M 表示数字的个数和要询问的次数;

接下来一行为 N 个数;

接下来 MM 行,每行都有两个整数 X,Y。

输出格式

输出共 M行,每行输出一个数。

数据范围

1≤N≤105, 1≤M≤106, 1≤X≤Y≤N, 数列中的数字均不超过2^31−1

输入样例:

10 2 3 2 4 5 6 8 1 2 9 7 1 4 3 8

输出样例:

5 8 class Node: def __init__(self,l=0,r=0,maxx=0): self.l=l self.r=r self.maxx=maxx n,m=map(int,input().split()) tr=[Node() for i in range(0,100005*4+100)] def pushup(u): tr[u].maxx=max(tr[u<<1].maxx, tr[u<<1|1].maxx) def build(u,l,r): if l == r: tr[u] = Node(l, r, w[r-1]) else: tr[u]=Node(l,r) mid=(l+r) >> 1 build(u << 1, l, mid) build(u << 1 | 1, mid+1, r) pushup(u) def query_max(u, l, r): if tr[u].l >= l and tr[u].r <= r: return tr[u].maxx mid=(tr[u].l + tr[u].r) >> 1 maxx=-1e18 if l <= mid : maxx=max(query_max(u << 1, l, r), maxx) if r > mid : maxx=max(query_max(u << 1 | 1, l, r), maxx) return maxx w=list(map(int,input().split())) build(1, 1, n) for i in range(0,m): x,y=map(int,input().split()) print(query_max(1, x, y))
最新回复(0)