알고리즘/graph

세그먼트트리

씩씩한 IT블로그 2020. 7. 26. 17:23
반응형

수열이 있을 때 중에서 연속된 특정구간을 주고, 특정구간의

1. 최댓값([백준]2357 최솟값과 최댓값)

2. 최솟값([백준]10868 최솟값)

3. 합([백준]2042 구간합, [백준]1275 커피숍2)

4. 곱([백준]11505 구간곱 구하기)

을 여러번 구하는 문제는 세그먼트트리를 쓰면 된다. 

 

세그먼트 트리를 사용하기 위해선 보통 아래와 같은 3가지 함수가 필요하다.

1. 주어진 수열을 이용하여 처음으로 세그먼트 트리를 만드는 함수

2. 세그먼트트리의 노드하나를 수정하는 함수

3. 세그먼트트리를 이용하여 구간합(or 구간곱 or 최댓값,최솟값)을 구하는 함수. 

각 함수코드는 다음과 같다.

 

1. 세그먼트트리 초기화 함수.

def init(node,start,end):
    if start==end:
        tree[node]=L[start-1]
        return tree[node]

    mid=(start+end)//2
    tree[node]=init(node*2,start,mid)+init(node*2+1,mid+1,end)
    return tree[node]

start와 ends는 원래 수열의 인덱스를 나타내고, node는 트리의 인덱스를 나타낸다. 트리를 타고 내려가면서 두개로 계속해서 쪼개다가, 하나가 남으면 거기에 해당값을 집어넣는다.

 

2. 세그먼트트리의 노드하나를 수정하는 함수

def update(node,start,end,index,diff):
    if not start<=index<=end:
        return

    tree[node]+=diff

    if start!=end:
        mid=(start+end)//2
        update(node*2,start,mid,index,diff)
        update(node*2+1,mid+1,end,index,diff)

노드하나를 수정하는 함수이다. 세그먼트트리의 특성상, 수정된 노드의 부모노드들은 모두 수정된 값 만큼 똑같이 수정되어야 한다.

수정될 노드(index)가 포함되어있지 않은 부모노드는 볼 필요가 없으므로 중간에 return해주고, 그렇지 않은 경우 내려가면서 계속해서 수정해준다.

함수를 사용하기 전에, (1)수정되는 값을 구하는 것과, (2)원래 수열을 수정하는것을 잊지 말자!

v=b-L[a-1] # 수정되는 값 구하기
L[a-1]=b # 원래값 수정하기

 

 

* 곱을 구하는 경우 함수의 구성이 조금 다른데, 그 이유는 특정 노드를 0->x로 수정하고자 할 때 특정노드의 부모노드에 모두 x/0을 해줘야 하기 때문이다(0으로 나누기를 해야하기 때문이다)

따라서 함수의 구성을 다음과 같이 한다. (초기화 하는 함수(init)와 비슷한 알고리즘을 쓴다)

def update(node,start,end,index,val):
    if not start<=index<=end:
        return tree[node]
    elif start==end:
        tree[node]=val
        return tree[node]
    else:
        mid=(start+end)//2
        tree[node]=(update(node*2,start,mid,index,val)*update(node*2+1,mid+1,end,index,val))
        return tree[node]

 

 

3. 세그먼트트리를 이용해서 구간의 특정값 구하기.

def treeSum(node,start,end,left,right):
    if left<=start and end<=right:
        return tree[node]
    elif right<start or end<left:
        return 0
    else:
        mid=(start+end)//2
        return treeSum(node*2,start,mid,left,right)+treeSum(node*2+1,mid+1,end,left,right)

3가지 경우로 나눌 수 있다. 

(1) left<=start<=end<=right (현재 영역이 구하고자 하는 영역에 모두 포함되어 있는 경우)

해당 값을 모두 써야 하므로 해당노드를 리턴한다.

(2) left<=right<start or end<left<=right (현재 영역이 구하고자 하는 영역 밖에 있는 경우)

탐색할 필요가 없으므로 합을 구하는것이면 0, 곱을 구하는것이면 1, 최댓값을 구하는것이면 음의무한대, 최솟값을 구하는것이면 양의 무한대를 return한다 

(3) 그외의 경우 (현재 영역이 구하고자 하는 영역의 일부에 포함하는 경우)

쪼개가면서 계속해서 탐색한다.

 

* 예시 [백준]2042_구간합구하기 소스코드

import sys
import math

def treeInit(node,start,end):
    if start==end:
        tree[node]=L[start-1]
        return tree[node]

    mid=(start+end)//2

    tree[node]=treeInit(2*node,start,mid)+treeInit(2*node+1,mid+1,end)
    return tree[node]

def update(node,start,end,index,diff):
    if not (start<=index<=end):
        return

    tree[node]+=diff

    if (start!=end):
        mid=(start+end)//2
        update(node*2,start,mid,index,diff)
        update(node*2+1,mid+1,end,index,diff)

def mySum(node,start,end,left,right):
    if left>end or right<start:
        return 0
    elif left<=start and end <=right:
        return tree[node]
    mid=(start+end)//2

    return mySum(node*2,start,mid,left,right)+mySum(node*2+1,mid+1,end,left,right)




N,M,K=map(int,input().split())

# 입력값 받
L=[]
for i in range(N):
    L.append(int(sys.stdin.readline().strip()))

# 트리 초기화
treeSize=2**(math.ceil(math.log(N,2))+1)
tree=[0 for i in range(treeSize)]
treeInit(1,1,N)


for i in range(M+K):
    #print("이전",tree)
    a,b,c=map(int,input().split())
    if a==1:
        v=c-L[b-1]
        L[b-1]=c
        update(1,1,N,b,v)
    else:
        print(mySum(1,1,N,b,c))
    #print("이후",tree)

 

반응형