알고리즘
Segment Tree
컴영
2020. 3. 17. 17:20
Segment Tree 세그먼트 트리
- 개념 : 배열의 특정 구간의 합, 혹은 특정 구간에서 최대값·최솟값 등을 효율적으로 구하기 위한 자료구조
- 기본 구조 : 완전 이진 트리 구조
가장 최상단인 루트 노드 - 전체 구간의 정보
가장 최하단의 리프 노드 - 배열의 그 수 자체 즉 배열의 각 요소
- 배열의 크기가 N = 10개 일때의 세그먼트 트리를 그리면,
- 세그먼트 트리는 1차원 배열로 구현한다.
- 크기 : 입력 배열의 크기가 N일때, 리프 노드의 개수가 N이 된다. 따라서 세그먼트 트리의 높이는 [logN]이 되고, 세그먼트 트리 배열의 크기는 2^(H+1) 이 된다. 이를 코드로 나타내면
1 2 | int h = (int)ceil(log2(n)); int tree_size = (1 << (h+1)); | cs |
하지만, 간단하게 N * 4로 세그먼트 트리 배열의 크기를 정할 수 있다,(메모리는 더 사용하지만)
- 어떤 노드의 인덱스 번호가 i라면, 왼쪽 자식의 번호는 2*i 이고 오른쪽 자식의 번호는 2*i +1이 된다. 이를 그림으로 보면
이런식으로, tree의 배열이 저장된다.
- 세그먼트 트리의 연산은 크게 2가지로 나뉜다.
- update 연산 : 데이터 갱신
- query 연산 : 구간의 정보 반환
먼저, update연산에 대해 알아보자.
update연산은 특정 인덱스 위치에 데이터를 삽입하거나, 수정, 삭제하는 연산으로 하나의 함수로 처리할 수있다.
1 2 3 4 5 6 7 8 9 10 | //배열에서 변경할 값 위치, 변경할 값, 시작노드번호, 현재 구간[nodeL,nodeR] long long update(int pos, int val, int node, int nodeL, int nodeR) { //1. 변경할 값 위치를 포함하지 않는 경우 if (pos < nodeL || pos > nodeR) return tree[node]; //2. 리프 노드일 경우, 즉 변경할 값 위치인 경우 if (nodeL == nodeR) return tree[node] = val; //3. 변경할 값 위치를 포함하는 구간일 경우 -> node의 자식들로 들어감 int mid = (nodeL + nodeR) / 2; return tree[node] = update(pos, val, node * 2, nodeL, mid) + update(pos, val, node * 2 + 1, mid + 1, nodeR); } | cs |
query연산은 해당 구간의 정보를 얻기 위해 해당 구간에 포함된 서브트리들의 정보를 합하는 연산이다.
탐색범위가 [nodeL, nodeR]이고 구하려는 범위가 [L, R]일때 총 3가지 경우로 나뉠 수 있다.
1. [nodeL, nodeR]과 [L, R]이 전혀 겹치지 않는 경우 → 구하고자 하는 범위와 상관없는 경우, 반환값에서 제외
2. [L, R] 안에 [nodeL, nodeR]가 완전히 포함되는 경우 → 구하고자 하는 범위가 포함된 경우, 해당 서브 트리의 루트 노드를 반환
3. [nodeL, nodeR]과 [L, R]의 범위가 일부 겹치는 경우 → [L, R] 구간에 포함되지 않은 정보 제거 위해서 재탐색 필요한 경우
1 2 3 4 5 6 7 8 9 10 | //찾고자하는 구간[L,R], 시작노드번호, 현재 노드의 구간[nodeL,nodeR] long long query(int L, int R, int node, int nodeL, int nodeR) { //1. 현재 노드의 구간이 찾고자 하는 구간과 전혀 겹치지 않을 때 if (R < nodeL || nodeR < L) return 0; //2. 현재 노드의 구간이 찾고자 하는 구간에 포함될 때 if (L <= nodeL && nodeR <= R) return tree[node]; //3. 현재 노드의 구간이 찾고자 하는 구간 일부에 포함될때 int mid = (nodeL + nodeR) / 2; return query(L, R, node * 2, nodeL, mid) + query(L, R, node * 2 + 1, mid + 1, nodeR); } | cs |
위 코드들은 백준의 2042 구간 합 구하기를 푼 코드들이다.
아래에 최종 코드를 적어두겠다.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 | #include <stdio.h> using namespace std; int n, m, k; long long tree[4000001]; //배열에서 변경할 값 위치, 변경할 값, 시작노드번호, 현재 구간[nodeL,nodeR] long long update(int pos, int val, int node, int nodeL, int nodeR) { //1. 변경할 값 위치를 포함하지 않는 경우 if (pos < nodeL || pos > nodeR) return tree[node]; //2. 리프 노드일 경우, 즉 변경할 값 위치인 경우 if (nodeL == nodeR) return tree[node] = val; //3. 변경할 값 위치를 포함하는 구간을 경우 -> node의 자식들로 들어감 int mid = (nodeL + nodeR) / 2; return tree[node] = update(pos, val, node * 2, nodeL, mid) + update(pos, val, node * 2 + 1, mid + 1, nodeR); } //찾고자하는 구간[L,R], 시작노드번호, 현재 노드의 구간[nodeL,nodeR] long long query(int L, int R, int node, int nodeL, int nodeR) { //1. 현재 노드의 구간이 찾고자 하는 구간과 전혀 겹치지않을때 if (R < nodeL || nodeR < L) return 0; //2. 현재 노드의 구간이 찾고자 하는 구간에 포함될때 if (L <= nodeL && nodeR <= R) return tree[node]; //3. 현재 노드의 구간이 찾고자 하는 구간 일부에 포함될때 int mid = (nodeL + nodeR) / 2; return query(L, R, node * 2, nodeL, mid) + query(L, R, node * 2 + 1, mid + 1, nodeR); } int main() { scanf("%d %d %d", &n, &m, &k); for (int i = 1; i <= n; i++) { int temp; scanf("%d", &temp); update(i, temp, 1, 1, n); } for (int i = 0; i < m + k; i++) { int a, b, c; scanf("%d %d %d", &a, &b, &c); if (a == 1) { //update update(b, c, 1, 1, n); } else { // query printf("%lld\n", query(b, c, 1, 1, n)); } } return 0; } | cs |