알고리즘

Segment Tree

컴영 2020. 3. 17. 17:20

Segment Tree 세그먼트 트리

  • 개념 : 배열의 특정 구간의 합, 혹은 특정 구간에서 최대값·최솟값 등을 효율적으로 구하기 위한 자료구조
  • 기본 구조 : 완전 이진 트리 구조 

                   가장 최상단인 루트 노드 - 전체 구간의 정보

                   가장 최하단의 리프 노드 - 배열의 그 수 자체 즉 배열의 각 요소

  •  배열의 크기가 N = 10개 일때의 세그먼트 트리를 그리면,

     


  •  세그먼트 트리는 1차원 배열로 구현한다.
  •  크기 : 입력 배열의 크기가 N일때, 리프 노드의 개수가 N이 된다. 따라서 세그먼트 트리의 높이는 [logN]이 되고, 세그먼트 트리 배열의 크기는 2^(H+1) 이 된다. 이를 코드로 나타내면

1
2
int h = (int)ceil(log2(n));
int tree_size = (1 << (h+1));  
cs


하지만, 간단하게 N * 4로 세그먼트 트리 배열의 크기를 정할 수 있다,(메모리는 더 사용하지만)

  •  어떤 노드의 인덱스 번호가 i라면, 왼쪽 자식의 번호는 2*i 이고 오른쪽 자식의 번호는 2*i +1이 된다. 이를 그림으로 보면
    

이런식으로, tree의 배열이 저장된다.


  •  세그먼트 트리의 연산은 크게 2가지로 나뉜다.

    1. update 연산 : 데이터 갱신
    2. query 연산 : 구간의 정보 반환

먼저, update연산에 대해 알아보자.

update연산은 특정 인덱스 위치에 데이터를 삽입하거나, 수정, 삭제하는 연산으로 하나의 함수로 처리할 수있다.


1
2
3
4
5
6
7
8
9
10
//배열에서 변경할 값 위치, 변경할 값, 시작노드번호, 현재 구간[nodeL,nodeR]
long long update(int pos, int val, int node, int nodeL, int nodeR) {
    //1. 변경할 값 위치를 포함하지 않는 경우
    if (pos < nodeL || pos > nodeR) return tree[node];
    //2. 리프 노드일 경우, 즉 변경할 값 위치인 경우
    if (nodeL == nodeR) return tree[node] = val;
    //3. 변경할 값 위치를 포함하는 구간일 경우 -> node의 자식들로 들어감
    int mid = (nodeL + nodeR) / 2;
    return tree[node] = update(pos, val, node * 2, nodeL, mid) + update(pos, val, node * 2 + 1, mid + 1, nodeR);
}
cs


query연산은 해당 구간의 정보를 얻기 위해 해당 구간에 포함된 서브트리들의 정보를 합하는 연산이다.

탐색범위가 [nodeL, nodeR]이고 구하려는 범위가 [L, R]일때 총 3가지 경우로 나뉠 수 있다.


1. [nodeL, nodeR]과 [L, R]이 전혀 겹치지 않는 경우 → 구하고자 하는 범위와 상관없는 경우, 반환값에서 제외

2. [L, R] 안에 [nodeL, nodeR]가 완전히 포함되는 경우 → 구하고자 하는 범위가 포함된 경우, 해당 서브 트리의 루트 노드를 반환

3. [nodeL, nodeR]과 [L, R]의 범위가 일부 겹치는 경우 → [L, R] 구간에 포함되지 않은 정보 제거 위해서 재탐색 필요한 경우


1
2
3
4
5
6
7
8
9
10
//찾고자하는 구간[L,R], 시작노드번호, 현재 노드의 구간[nodeL,nodeR]
long long query(int L, int R, int node, int nodeL, int nodeR) {
    //1. 현재 노드의 구간이 찾고자 하는 구간과 전혀 겹치지 않을 때
    if (R < nodeL || nodeR < L) return 0;
    //2. 현재 노드의 구간이 찾고자 하는 구간에 포함될 때
    if (L <= nodeL && nodeR <= R) return tree[node]; 
    //3. 현재 노드의 구간이 찾고자 하는 구간 일부에 포함될때
    int mid = (nodeL + nodeR) / 2;
    return query(L, R, node * 2, nodeL, mid) + query(L, R, node * 2 + 1, mid + 1, nodeR);
}
cs


 



위 코드들은 백준의 2042 구간 합 구하기를 푼 코드들이다.

아래에 최종 코드를 적어두겠다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#include <stdio.h>
using namespace std;
 
int n, m, k;
long long tree[4000001];
 
//배열에서 변경할 값 위치, 변경할 값, 시작노드번호, 현재 구간[nodeL,nodeR]
long long update(int pos, int val, int node, int nodeL, int nodeR) {
    //1. 변경할 값 위치를 포함하지 않는 경우
    if (pos < nodeL || pos > nodeR) return tree[node];
    //2. 리프 노드일 경우, 즉 변경할 값 위치인 경우
    if (nodeL == nodeR) return tree[node] = val;
    //3. 변경할 값 위치를 포함하는 구간을 경우 -> node의 자식들로 들어감
    int mid = (nodeL + nodeR) / 2;
    return tree[node] = update(pos, val, node * 2, nodeL, mid) + update(pos, val, node * 2 + 1, mid + 1, nodeR);
}
//찾고자하는 구간[L,R], 시작노드번호, 현재 노드의 구간[nodeL,nodeR]
long long query(int L, int R, int node, int nodeL, int nodeR) {
    //1. 현재 노드의 구간이 찾고자 하는 구간과 전혀 겹치지않을때
    if (R < nodeL || nodeR < L) return 0;
    //2. 현재 노드의 구간이 찾고자 하는 구간에 포함될때
    if (L <= nodeL && nodeR <= R) return tree[node];
    //3. 현재 노드의 구간이 찾고자 하는 구간 일부에 포함될때
    int mid = (nodeL + nodeR) / 2;
    return query(L, R, node * 2, nodeL, mid) + query(L, R, node * 2 + 1, mid + 1, nodeR);
}
int main() {
    scanf("%d %d %d"&n, &m, &k);
    for (int i = 1; i <= n; i++) {
        int temp;
        scanf("%d"&temp);
        update(i, temp, 11, n);
    }
    for (int i = 0; i < m + k; i++) {
        int a, b, c;
        scanf("%d %d %d"&a, &b, &c);
        
        if (a == 1) { //update
            update(b, c, 11, n);
        }
        else { // query
            printf("%lld\n", query(b, c, 11, n));
        }
    }
    return 0;
}
cs