본문 바로가기
알고리즘/[PRO] 삼성 SW 역량 테스트 B형

BOJ 2042 : 구간 합 구하기 with 바텀 업 세그먼트 트리 (Bottom-Up Segment Tree)

by 피로물든딸기 2023. 1. 15.
반응형

알고리즘 문제 전체 링크
삼성 B형 전체 링크

삼성 C형 전체 링크

 

https://www.acmicpc.net/problem/2042
 
참고
구간 합 구하기 with 제곱근 분할법 (Sqrt Decomposition)
구간 합 구하기 with 탑 다운 세그먼트 트리 (Top-Down Segment Tree)
구간 합 구하기 with 바텀 업 세그먼트 트리 (Bottom-Up Segment Tree)
구간 합 구하기 2 with 나중에 업데이트하기 (Top-Down Segment Tree with Lazy Propagation)
 

 
여러 배열이 있고 구간의 합이 쿼리로 주어지면 합을 구해서 답을 구하는 문제다.
 
바텀 업 세그먼트 트리를 이용하여 구간 합을 구해보자.


주어진 구간의 합을 구해야할 배열이 다음과 같다고 하자.

 
Top-Down 세그먼트 트리는 아래와 같은 트리가 만들어졌다.

 
트리에 저장된 값은 다음과 같았다.

 
Bottom-Up 세그먼트 트리는 아래와 같은 트리가 만들어진다.

 
트리에 저장된 값은 다음과 같다.

 

7개의 원소를 가지기 때문에 7보다 큰 2n은 8 (n = 3)이다.

즉, 가장 아래에 8개의 노드를 채우고 다음에 4개, 그 다음 2개, 마지막 1개를 채운다.
 
이 문제는 덧셈이므로 더해도 상관 없는 값08번째 원소(tree에는 15번째)에 채워진다.


트리의 갱신과 초기화
 
먼저 첫 번째 원소를 해당 트리에 넣어보자.
아래의 빨간 node의 index는 8이 된다. (0번째 버림 + 1 + 2 + 4 = 7 다음 수는 8)

 
이 기준이 되는 start index는 아래와 같이 구할 수 있다.

이렇게 되면 N = 7 보다 큰 (1 << n) = 2n를 만족하는 n = 3을 찾을 수 있다.

따라서 start는 1 << 3 = 8에서 1을 뺀 값 7이 되고, 첫 번째 인덱스는 start + 1이 되어 8번째 node를 찾는다.

for (n = 1; (1 << n) <= N; n++);
start = 1 << n;
start--;

 
이 index를 반으로 나누면서 index가 1이 될 때까지 거치는 모든 node에 첫 번째 원소 -1을 더한다.

 
두 번째 원소 2도 마찬가지로 관련된 node를 모두 더한다.
1 ~ 2의 합을 저장하는 node는 -1 + 2가 되어 1이 된다.

 
세 번째 원소를 갱신하면 다음과 같다.
관련된 node에 모두 3을 더했다.

 
4번째 원소까지 처리하면 다음과 같다.

 
5번째 원소를 갱신하면 다음과 같다.

 
6번째 원소를 갱신하면 다음과 같다.

 
7번째 원소를 갱신하면 다음과 같다.

 
8번째 원소는 없으므로 그대로 두면 된다. 
따라서 최종적으로 아래와 같은 값이 트리에 저장된다.

 
갱신하는 update 함수 구현은 다음과 같다.
update에서 값을 계속 더해나가므로 tree에는 값을 누적만 한다.
(Top-Down과 마찬가지로 값을 바꾸면 위의 node에서 어떤 값으로 바꿔야 할지 알 방법이 없다.)
그리고 들어오는 index는 배열의 index이므로, 기준이 되는 start index에 값을 더한 후 갱신한다.

void update(int index, ll diff)
{
	index += start;

	while (index > 1)
	{
		tree[index] += diff;
		index /= 2;
	}
}

 
이 함수를 N번 입력을 받으면서 호출하면 초기화가 완료된다.

	for (int i = 1; i <= N; i++)
	{
		ll x;

		scanf("%lld", &x);
        
		update(i, x);
	}

 
Top Down과 달리 init을 따로 만들지 않고, update 함수로 초기화를 한다.


구간의 합
 
이제 바텀 업 세그먼트 트리에서 1 ~ 6번째 원소의 합을 구해보자.
left = 1, right = 6으로 시작한다. 실제 node의 index는 start인 7을 더해야하므로 left = 8, right = 13이 된다.

 
left의 index가 2로 나누어 떨어진다면 left index를 2로 나눈다.
그렇지 않다면, 현재 tree[left]의 값을 더하고, left에 1을 더한 후 2로 나눈다.
 
right의 index가 2로 나누어 떨어지지 않는다면 right index를 2로 나눈다.
그렇지 않다면, 현재 tree[right]의 값을 더하고, right에 1을 뺀 후, 2로 나눈다.
 
위 방식을 left < right를 만족할 때까지 반복하고,
최종적으로 left == right가 되는 경우라면 마지막으로 값을 누적한다.
 
따라서 left = 4, right = 6이 되고 아래의 상태가 된다.

 
left는 2로 나누어 떨어지므로 다시 left를 2로 나눈다.
right는 2로 나누어 떨어지므로 현재의 값에 (5 ~ 6에 저장된 값)을 누적한다.
그리고 right를 1 뺀 후 2로 나눈다.
 
left = 2가 되고, right도 5가 된 후 2로 나누었으므로 2가 된다.
즉, left = 2, right = 2가 된다.

 
left와 right의 값이 같으므로 1 ~ 4에 저장된 값을 누적하면 된다.
따라서 [1 ~ 4] + [5 ~ 6]의 값을 구하게 된다.
 
실제로 구현하면 다음과 같다.

ll getSum(int left, int right)
{
	ll ans = 0;

	left += start;
	right += start;

	while (left < right)
	{
		if (left % 2 == 0) left /= 2;
		else
		{
			ans += tree[left];
			left += 1;
			left /= 2;
		}

		if (right % 2 == 1) right /= 2;
		else
		{
			ans += tree[right];
			right -= 1;
			right /= 2;
		}
	}

	if (left == right) ans += tree[left];

	return ans;
}

 
따라서 구해야할 구간에서 최대한 구간이 포함되도록 node를 올라간다.
덧셈이 완료되는 node라면 해당 구간은 더 이상 덧셈을 할 필요가 없으므로,
left는 오른쪽으로, right는 왼쪽으로 이동한 후, 위의 node로 이동하게 된다.


구현
 
위의 내용을 바탕으로 문제를 풀면 아래와 같이 구현할 수 있다.
메모리는 주어진 N보다 큰 2의 배수의 2배가 필요하므로 (1,000,000 → 1,048,576 → 2,097,152) 가 된다.
메모리가 넉넉하다면 4를 곱한 값으로 충분하다.

#include <stdio.h>

typedef long long int ll;

int N, M, K;
ll tree[4004000];
int start;

void update(int index, ll diff)
{
	index += start;

	while (index > 1)
	{
		tree[index] += diff;
		index /= 2;
	}
}

ll getSum(int left, int right)
{
	ll ans = 0;

	left += start;
	right += start;

	while (left < right)
	{
		if (left % 2 == 0) left /= 2;
		else
		{
			ans += tree[left];
			left += 1;
			left /= 2;
		}

		if (right % 2 == 1) right /= 2;
		else
		{
			ans += tree[right];
			right -= 1;
			right /= 2;
		}
	}

	if (left == right) ans += tree[left];

	return ans;
}

int main()
{
	int n;

	scanf("%d %d %d", &N, &M, &K);

	for (n = 1; (1 << n) <= N; n++);
	start = 1 << n;
	start--;

	for (int i = 1; i <= N; i++)
	{
		ll x;

		scanf("%lld", &x);

		update(i, x);
	}

	for (int i = 0; i < M + K; i++)
	{
		int a;

		scanf("%d", &a);

		if (a == 1)
		{
			int b;
			ll c;

			scanf("%d %lld", &b, &c);

			update(b, c - tree[start + b]);
		}
		else
		{
			int b, c;

			scanf("%d %d", &b, &c);

			printf("%lld\n", getSum(b, c));
		}
	}

	return 0;
}
반응형

댓글