BOJ 2042 : 구간 합 구하기 with 다이나믹 세그먼트 트리 (Dynamic Segment Tree)
https://www.acmicpc.net/problem/2042
참고
- 구간 합 구하기 with 탑 다운 세그먼트 트리 (Top-Down Segment Tree)
이전 글에서는 tree의 크기를 넉넉하게 잡았다.
구간의 길이가 100만이기 때문에 실제로는 2배 정도의 크기만 있으면 된다.
여기서는 필요한 만큼만 적절히 잡아서 다이나믹하게 트리를 만들어보자.
배열의 index를 이용하는 방법
원래 update 코드는 아래와 같다.
void update(int left, int right, int node, int index, ll value)
{
if (index < left || right < index) return;
if (left == right)
{
tree[node] = value;
return;
}
int mid = (left + right) / 2;
int leftNodeIndex = node * 2;
int rightNodeIndex = node * 2 + 1;
update(left, mid, leftNodeIndex, index, value);
update(mid + 1, right, rightNodeIndex, index, value);
tree[node] = tree[leftNodeIndex] + tree[rightNodeIndex];
}
왼쪽과 오른쪽 노드를 이용하기 위해 node에 2를 곱하거나 곱한 후 1을 더해서 접근했다.
여기서 2를 곱하는 과정 때문에 메모리가 넉넉하게 필요하였다.
하지만 현재 노드가 왼쪽 노드나 오른쪽 노드의 위치만 제대로 저장만 한다면 node에 2를 곱할 필요가 없다.
먼저 tree를 구성하는 NODE를 아래와 같이 정의하자.
typedef struct st
{
ll value;
int left;
int right;
}NODE;
tree의 크기를 아래와 같이 정의하고 tcnt로 할당한다.
left, right가 0일 때 할당되지 않은 node로 구분하기 위해 tcnt를 2로 시작한다.
(root는 1부터 시작)
NODE tree[1001000 * 2];
int tcnt = 2;
update 코드는 아래와 같이 바뀐다.
left와 right가 존재하지 않는 경우라면 left와 right에 index를 할당한다.
if (tree[node].left == 0)
{
tree[node].left = tcnt++;
tree[node].right = tcnt++;
}
그러면 node * 2나 node * 2 + 1이 아니라 저장된 index를 보고 update를 계속 진행하면 된다.
int mid = (left + right) / 2;
int leftNodeIndex = tree[node].left;
int rightNodeIndex = tree[node].right;
update(left, mid, leftNodeIndex, index, value);
update(mid + 1, right, rightNodeIndex, index, value);
탑 다운 세그먼트 트리에서 마지막 방법과 비교하면 아래와 같이 코드가 변경된다.
query의 경우 node가 없는 경우는 0으로 return하는 코드를 추가한다.
ll getSum(int left, int right, int a, int b, int node)
{
if (node == 0) return 0;
if (b < left || right < a) return 0;
if (a <= left && right <= b) return tree[node].value;
int mid = (left + right) / 2;
return
getSum(left, mid, a, b, tree[node].left) + getSum(mid + 1, right, a, b, tree[node].right);
}
전체 코드는 다음과 같다.
#include <stdio.h>
typedef long long int ll;
typedef struct st
{
ll value;
int left;
int right;
}NODE;
int N, M, K;
NODE tree[1001000 * 2];
int tcnt = 2;
void update(int left, int right, int node, int index, ll value)
{
if (index < left || right < index) return;
if (left == right)
{
tree[node].value = value;
return;
}
if (tree[node].left == 0) tree[node].left = tcnt++;
if (tree[node].right == 0) tree[node].right = tcnt++;
int mid = (left + right) / 2;
int leftNodeIndex = tree[node].left;
int rightNodeIndex = tree[node].right;
update(left, mid, leftNodeIndex, index, value);
update(mid + 1, right, rightNodeIndex, index, value);
tree[node].value = tree[leftNodeIndex].value + tree[rightNodeIndex].value;
}
ll getSum(int left, int right, int a, int b, int node)
{
if (node == 0) return 0;
if (b < left || right < a) return 0;
if (a <= left && right <= b) return tree[node].value;
int mid = (left + right) / 2;
return
getSum(left, mid, a, b, tree[node].left) + getSum(mid + 1, right, a, b, tree[node].right);
}
int main()
{
scanf("%d %d %d", &N, &M, &K);
for (int i = 1; i <= N; i++)
{
ll input;
scanf("%lld", &input);
update(1, N, 1, i, input);
}
for (int i = 0; i < M + K; i++)
{
int a;
scanf("%d", &a);
if (a == 1)
{
int b;
ll c;
scanf("%d %lld", &b, &c);
update(1, N, 1, b, c);
}
else
{
int b, c;
scanf("%d %d", &b, &c);
printf("%lld\n", getSum(1, N, b, c, 1));
}
}
return 0;
}
링크드 리스트를 이용하는 방법
메모리 풀을 이용하였다.
코드의 원리는 같으므로 전체 코드만 참고하자.
int node가 NODE* node로 변경되고, ROOT 노드의 주소를 넘겨줘야 한다.
#include <stdio.h>
typedef long long int ll;
typedef struct st
{
ll value;
struct st* left;
struct st* right;
}NODE;
int N, M, K;
NODE ROOT;
NODE POOL[1001000 * 2];
int pcnt;
void update(int left, int right, NODE* node, int index, ll value)
{
if (index < left || right < index) return;
if (left == right)
{
node->value = value;
return;
}
if (node->left == NULL)
{
node->left = &POOL[pcnt++];
node->right = &POOL[pcnt++];
}
int mid = (left + right) / 2;
NODE* leftNode = node->left;
NODE* rightNode = node->right;
update(left, mid, leftNode, index, value);
update(mid + 1, right, rightNode, index, value);
node->value = leftNode->value + rightNode->value;
}
ll getSum(int left, int right, int a, int b, NODE* node)
{
if (node == NULL) return 0;
if (b < left || right < a) return 0;
if (a <= left && right <= b) return node->value;
int mid = (left + right) / 2;
return getSum(left, mid, a, b, node->left) + getSum(mid + 1, right, a, b, node->right);
}
int main()
{
scanf("%d %d %d", &N, &M, &K);
for (int i = 1; i <= N; i++)
{
ll input;
scanf("%lld", &input);
update(1, N, &ROOT, i, input);
}
for (int i = 0; i < M + K; i++)
{
int a;
scanf("%d", &a);
if (a == 1)
{
int b;
ll c;
scanf("%d %lld", &b, &c);
update(1, N, &ROOT, b, c);
}
else
{
int b, c;
scanf("%d %d", &b, &c);
printf("%lld\n", getSum(1, N, b, c, &ROOT));
}
}
return 0;
}