ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • Segment Tree
    알고리즘 2020. 3. 17. 17:20

    Segment Tree 세그먼트 트리

    • 개념 : 배열의 특정 구간의 합, 혹은 특정 구간에서 최대값·최솟값 등을 효율적으로 구하기 위한 자료구조
    • 기본 구조 : 완전 이진 트리 구조 

                       가장 최상단인 루트 노드 - 전체 구간의 정보

                       가장 최하단의 리프 노드 - 배열의 그 수 자체 즉 배열의 각 요소

    •  배열의 크기가 N = 10개 일때의 세그먼트 트리를 그리면,

         


    •  세그먼트 트리는 1차원 배열로 구현한다.
    •  크기 : 입력 배열의 크기가 N일때, 리프 노드의 개수가 N이 된다. 따라서 세그먼트 트리의 높이는 [logN]이 되고, 세그먼트 트리 배열의 크기는 2^(H+1) 이 된다. 이를 코드로 나타내면

    1
    2
    int h = (int)ceil(log2(n));
    int tree_size = (1 << (h+1));  
    cs


    하지만, 간단하게 N * 4로 세그먼트 트리 배열의 크기를 정할 수 있다,(메모리는 더 사용하지만)

    •  어떤 노드의 인덱스 번호가 i라면, 왼쪽 자식의 번호는 2*i 이고 오른쪽 자식의 번호는 2*i +1이 된다. 이를 그림으로 보면
        

    이런식으로, tree의 배열이 저장된다.


    •  세그먼트 트리의 연산은 크게 2가지로 나뉜다.

      1. update 연산 : 데이터 갱신
      2. query 연산 : 구간의 정보 반환

    먼저, update연산에 대해 알아보자.

    update연산은 특정 인덱스 위치에 데이터를 삽입하거나, 수정, 삭제하는 연산으로 하나의 함수로 처리할 수있다.


    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    //배열에서 변경할 값 위치, 변경할 값, 시작노드번호, 현재 구간[nodeL,nodeR]
    long long update(int pos, int val, int node, int nodeL, int nodeR) {
        //1. 변경할 값 위치를 포함하지 않는 경우
        if (pos < nodeL || pos > nodeR) return tree[node];
        //2. 리프 노드일 경우, 즉 변경할 값 위치인 경우
        if (nodeL == nodeR) return tree[node] = val;
        //3. 변경할 값 위치를 포함하는 구간일 경우 -> node의 자식들로 들어감
        int mid = (nodeL + nodeR) / 2;
        return tree[node] = update(pos, val, node * 2, nodeL, mid) + update(pos, val, node * 2 + 1, mid + 1, nodeR);
    }
    cs


    query연산은 해당 구간의 정보를 얻기 위해 해당 구간에 포함된 서브트리들의 정보를 합하는 연산이다.

    탐색범위가 [nodeL, nodeR]이고 구하려는 범위가 [L, R]일때 총 3가지 경우로 나뉠 수 있다.


    1. [nodeL, nodeR]과 [L, R]이 전혀 겹치지 않는 경우 → 구하고자 하는 범위와 상관없는 경우, 반환값에서 제외

    2. [L, R] 안에 [nodeL, nodeR]가 완전히 포함되는 경우 → 구하고자 하는 범위가 포함된 경우, 해당 서브 트리의 루트 노드를 반환

    3. [nodeL, nodeR]과 [L, R]의 범위가 일부 겹치는 경우 → [L, R] 구간에 포함되지 않은 정보 제거 위해서 재탐색 필요한 경우


    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    //찾고자하는 구간[L,R], 시작노드번호, 현재 노드의 구간[nodeL,nodeR]
    long long query(int L, int R, int node, int nodeL, int nodeR) {
        //1. 현재 노드의 구간이 찾고자 하는 구간과 전혀 겹치지 않을 때
        if (R < nodeL || nodeR < L) return 0;
        //2. 현재 노드의 구간이 찾고자 하는 구간에 포함될 때
        if (L <= nodeL && nodeR <= R) return tree[node]; 
        //3. 현재 노드의 구간이 찾고자 하는 구간 일부에 포함될때
        int mid = (nodeL + nodeR) / 2;
        return query(L, R, node * 2, nodeL, mid) + query(L, R, node * 2 + 1, mid + 1, nodeR);
    }
    cs


     



    위 코드들은 백준의 2042 구간 합 구하기를 푼 코드들이다.

    아래에 최종 코드를 적어두겠다.


    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    #include <stdio.h>
    using namespace std;
     
    int n, m, k;
    long long tree[4000001];
     
    //배열에서 변경할 값 위치, 변경할 값, 시작노드번호, 현재 구간[nodeL,nodeR]
    long long update(int pos, int val, int node, int nodeL, int nodeR) {
        //1. 변경할 값 위치를 포함하지 않는 경우
        if (pos < nodeL || pos > nodeR) return tree[node];
        //2. 리프 노드일 경우, 즉 변경할 값 위치인 경우
        if (nodeL == nodeR) return tree[node] = val;
        //3. 변경할 값 위치를 포함하는 구간을 경우 -> node의 자식들로 들어감
        int mid = (nodeL + nodeR) / 2;
        return tree[node] = update(pos, val, node * 2, nodeL, mid) + update(pos, val, node * 2 + 1, mid + 1, nodeR);
    }
    //찾고자하는 구간[L,R], 시작노드번호, 현재 노드의 구간[nodeL,nodeR]
    long long query(int L, int R, int node, int nodeL, int nodeR) {
        //1. 현재 노드의 구간이 찾고자 하는 구간과 전혀 겹치지않을때
        if (R < nodeL || nodeR < L) return 0;
        //2. 현재 노드의 구간이 찾고자 하는 구간에 포함될때
        if (L <= nodeL && nodeR <= R) return tree[node];
        //3. 현재 노드의 구간이 찾고자 하는 구간 일부에 포함될때
        int mid = (nodeL + nodeR) / 2;
        return query(L, R, node * 2, nodeL, mid) + query(L, R, node * 2 + 1, mid + 1, nodeR);
    }
    int main() {
        scanf("%d %d %d"&n, &m, &k);
        for (int i = 1; i <= n; i++) {
            int temp;
            scanf("%d"&temp);
            update(i, temp, 11, n);
        }
        for (int i = 0; i < m + k; i++) {
            int a, b, c;
            scanf("%d %d %d"&a, &b, &c);
            
            if (a == 1) { //update
                update(b, c, 11, n);
            }
            else { // query
                printf("%lld\n", query(b, c, 11, n));
            }
        }
        return 0;
    }
    cs


    '알고리즘' 카테고리의 다른 글

    brute force  (0) 2020.03.26
    Lazy propagation  (0) 2020.03.17
    Greedy Algorithm  (0) 2020.03.17
    Floyd-Warshall  (0) 2020.03.17
    Dijkstra  (0) 2020.03.17

    댓글

Designed by Tistory.