본문 바로가기
알고리즘/[PRO] 삼성 SW 역량 테스트 B형

BOJ 2042 : 구간 합 구하기 with 제곱근 분할법 (Sqrt Decomposition)

by 피로물든딸기 2023. 1. 15.
반응형

알고리즘 문제 전체 링크

삼성 B형 전체 링크

삼성 C형 전체 링크

 

https://www.acmicpc.net/problem/2042

참고
- 구간 합 구하기 with 제곱근 분할법 (Sqrt Decomposition)
- 구간 합 구하기 with 탑 다운 세그먼트 트리 (Top-Down Segment Tree)
- 구간 합 구하기 with 바텀 업 세그먼트 트리 (Bottom-Up Segment Tree)
- 구간 합 구하기 2 with 나중에 업데이트하기 (Top-Down Segment Tree with Lazy Propagation)

 


여러 배열이 있고 구간의 합이 쿼리로 주어지면 합을 구해서 답을 구하는 문제다.

아래와 같이 크기가 20인 배열이 있다고 가정하자.
문제에서는 배열이 1부터 시작하지만, 여기서는 편의상 0부터 시작한다.


만약 0 ~ 19번 배열까지의 합을 구하려면 매번 20개의 원소를 더해야 한다.
0 ~ 19번 정도는 괜찮지만, 문제에서 주어진 최대 N인 1,000,000의 배열을 더하고,
구간의 합을 구하는 횟수 K가 10,000이라면 1,000,000 x 10,000번의 연산이 필요하다.
따라서 Time Out이 발생하게 된다.


제곱근 분할법

연산을 줄이기 위해 미리 합을 구해두면 Time Out을 해결할 수 있다.
합을 얼마나 미리 구해두냐는 √N으로 분할한다.
예를 들어 √20 = 4.xxx이기 때문에 4나 5로 분할할 수 있다.


만약 2부터 18까지의 합을 구한다면,
2, 3번째 원소는 그냥 더하고 4 ~ 15는 분할된 합이 저장된 공간으로 2 + 10 + 2을 구하고,
16, 17, 18번 원소를 더하면 된다.

이전 방식이라면 17번 연산이 필요하지만 위 방식을 이용하면 2 + 3 + 3 = 8번만에 구할 수 있다.

만약 1,000,000의 경우라면 √1,000,000 = 1,000이 되고
최악의 연산은 1부터 999,998 원소까지 더하는 경우다.

이 경우
1 부터 999까지의 합 = 999번 연산
+ 1,000 ~ 999,899 까지의 합 = 998번 연산
+ 999,000 ~ 999,999 까지의 합 = 999번 연산

으로 999,998번 해야할 연산을 3,996번으로 줄일 수 있다.
메모리가 √N 만큼 늘어난 대신 연산 횟수는 √N 정도 줄어들게 된다.


구현

main문은 아래와 같다.
문제에서 주어진 답이 long 범위이기 때문에 long long int로 배열을 선언한다.
N = 1,000,000이기 때문에 SQRT의 크기는 1,000으로 잡았다.
그리고 문제에서 1 ~ 1,000,000번째 배열이 주어진다고 하였지만, 여기서는 편의상 0부터 시작하므로,
update나 getSum에서 index를 1씩 빼고 넘겨준다.

typedef long long int ll;

#define MAX_SQRT (1000)

int N, M, K;
ll arr[1001000];
ll sqrtSum[MAX_SQRT + 10];

int main(void)
{
	scanf("%d %d %d", &N, &M, &K);

	for (int i = 0; i < N; i++) scanf("%lld", &arr[i]);

	sqrtDecomposition();

	for (int i = 0; i < M + K; i++)
	{
		int a;

		scanf("%d", &a);

		if (a == 1)
		{
			int b;
			ll c;

			scanf("%d %lld", &b, &c);

			update(b - 1, c);
		}
		else
		{
			int b, c;
			scanf("%d %d", &b, &c);

			printf("%lld\n", getSum(b - 1, c - 1));
		}
	}

	return 0;
}


N의 크기와 상관없이 1000개 단위로 배열을 저장해둔다.
N의 크기 1000보다 작다면 그냥 다 더하는 연산도 충분히 빠르다.
1000개씩 나눠서 sqrtSum에 저장한다.
(ex. 0 ~ 999번째 배열의 합은 sqrtSum[0]에 저장된다.)

void sqrtDecomposition()
{
	int sqrtIndex = 0;
	int count = N / MAX_SQRT;
	int start, end;

	for (int i = 0; i < count; i++)
	{
		start = i * 1000;
		end = start + 1000;

		for (int k = start; k < end; k++)
			sqrtSum[sqrtIndex] += arr[k];

		sqrtIndex++;
	}

	start = count * 1000;
	end = N;

	for (int k = start; k < end; k++)
		sqrtSum[sqrtIndex] += arr[k];
}


구간의 합을 구할 때 변경되는 부분이 생긴다면 기존의 수를 sqrtSum에서 빼주고,
arr에서 값을 변경한 뒤에 다시 더한다.
배열을 0부터 시작했기 때문에 sqrtSum에 저장된 위치는 MAX_SQRT = 1000으로 나누기만 하면 된다.

void update(int index, ll value)
{
	int sqrtIndex = index / MAX_SQRT;

	sqrtSum[sqrtIndex] -= arr[index];
	arr[index] = value;
	sqrtSum[sqrtIndex] += arr[index];
}


getSum에는 왼쪽 좌표와 오른쪽 좌표가 들어온다.
sqrtIndex를 구해서 leftSqrtIndex와 rightSqrtIndex가 같다면 그냥 모두 더하면 된다.
그 외에는 sqrtSum으로 구할 수 있는 부분은 sqrtSum으로 구하고 나머지 left와 right 원소를 더하면 된다.

ll getSum(int left, int right)
{
	long long sum = 0;
	int leftSqrtIndex = left / MAX_SQRT;
	int rightSqrtIndex = right / MAX_SQRT;

	if (leftSqrtIndex == rightSqrtIndex)
	{
		for (int i = left; i <= right; i++) sum += arr[i];

		return sum;
	}

	for (int i = leftSqrtIndex + 1; i < rightSqrtIndex; i++) sum += sqrtSum[i];

	int end = (leftSqrtIndex + 1) * 1000;
	for (int i = left; i < end; i++) sum += arr[i];

	int start = rightSqrtIndex * 1000;
	for (int i = start; i <= right; i++) sum += arr[i];

	return sum;
}


전체 코드는 다음과 같다.

#include <stdio.h>

typedef long long int ll;

#define MAX_SQRT (1000)

int N, M, K;
ll arr[1001000];
ll sqrtSum[MAX_SQRT + 10];

void sqrtDecomposition()
{
	int sqrtIndex = 0;
	int count = N / MAX_SQRT;
	int start, end;

	for (int i = 0; i < count; i++)
	{
		start = i * 1000;
		end = start + 1000;

		for (int k = start; k < end; k++)
			sqrtSum[sqrtIndex] += arr[k];

		sqrtIndex++;
	}

	start = count * 1000;
	end = N;

	for (int k = start; k < end; k++)
		sqrtSum[sqrtIndex] += arr[k];
}

void update(int index, ll value)
{
	int sqrtIndex = index / MAX_SQRT;

	sqrtSum[sqrtIndex] -= arr[index];
	arr[index] = value;
	sqrtSum[sqrtIndex] += arr[index];
}

ll getSum(int left, int right)
{
	long long sum = 0;
	int leftSqrtIndex = left / MAX_SQRT;
	int rightSqrtIndex = right / MAX_SQRT;

	if (leftSqrtIndex == rightSqrtIndex)
	{
		for (int i = left; i <= right; i++) sum += arr[i];

		return sum;
	}

	for (int i = leftSqrtIndex + 1; i < rightSqrtIndex; i++) sum += sqrtSum[i];

	int end = (leftSqrtIndex + 1) * 1000;
	for (int i = left; i < end; i++) sum += arr[i];

	int start = rightSqrtIndex * 1000;
	for (int i = start; i <= right; i++) sum += arr[i];

	return sum;
}

int main(void)
{
	scanf("%d %d %d", &N, &M, &K);

	for (int i = 0; i < N; i++) scanf("%lld", &arr[i]);

	sqrtDecomposition();

	for (int i = 0; i < M + K; i++)
	{
		int a;

		scanf("%d", &a);

		if (a == 1)
		{
			int b;
			ll c;

			scanf("%d %lld", &b, &c);

			update(b - 1, c);
		}
		else
		{
			int b, c;
            
			scanf("%d %d", &b, &c);

			printf("%lld\n", getSum(b - 1, c - 1));
		}
	}

	return 0;
}
반응형

댓글