BOJ 13547 : 수열과 쿼리 5
https://www.acmicpc.net/problem/13547
참고
- BOJ 2042 : 구간 합 구하기 with 제곱근 분할법 (Sqrt Decomposition)
원소의 개수 세기
구간에 업데이트가 없다면, 모스(Mo's) 알고리즘을 이용해 쿼리를 처리할 수 있다.
만약에 [left, right] 구간에 대한 답을 알고 있다면, [left, right ± 1]이나 [left ± 1, right]의 쿼리는 쉽게 처리할 수 있다.
즉, 이미 답을 구한 구간을 최대한 효율적으로 활용하는 방법이다.
위 문제는 회전 초밥에서 슬라이딩 윈도우를 사용했던 방법으로 원소의 개수를 셀 수 있다.
원소가 1 ~ 1,000,000이므로 테이블을 만든다.
int table[1000000 + 100];
left나 right를 한 칸씩 옮겨가면서 원소 A[i]에 있는 값을 센다.
원소가 추가될 때는 ++table[A[i]]를 했을 때 1이면 처음으로 추가되는 원소이므로 답이 1 증가한다.
if (++table[A[i]] == 1) ans++;
반대로 원소가 빠질 때는 --table[A[i]]가 0이 되는 경우다.
if (--table[A[i]] == 0) ans--;
그 외의 값은 모두 중복으로 원소를 세고 있는 것이므로 정답을 증가하거나 감소하지 않는다.
오프라인 쿼리
위의 방식대로 존재하는 서로 다른 수의 개수를
이미 답을 구한 구간을 최대한 효율적으로 활용하는 방법으로 구하더라도 아직은 비효율적이다.
N = 30인 경우, 쿼리의 순서가 [1, 5] → [28, 29] → [2, 6]로 주어진다고 하자.
[1, 5] 쿼리를 처리한 후, [28, 29] 쿼리를 처리하면 left와 right를 많이 이동시켜야 한다.
그리고 다시 [2, 6] 쿼리를 처리하면 left와 right를 많이 이동시킨다.
하지만 만약, [1, 5] → [2, 6] → [28, 29]로 쿼리가 들어온다면 성능이 많이 나아진다.
[1, 5]에서 답을 구했다면 left++을 해서 [2, 5]의 답을 구하고, right++을 해서 [2, 6]을 구하면 되기 때문이다.
따라서 쿼리를 순서대로 구하지 말고, 정렬해서 효율적으로 답을 구한 뒤, 답만 순서대로 출력한다.
이때 제곱근 분할법으로 쿼리를 정렬한다.
그러면 쿼리를 아래와 같이 정렬할 수 있다.
ex. N = 30이고 √N = 5 (= sqrt)라고 하자.
구간의 쿼리를 [left, right]라고 하자.
1. left / sqrt 가 작은 순으로 정렬한다.
2. left / sqrt 가 같은 경우는 right가 작은 순서대로 정렬한다.
예를 들면 아래와 같이 정렬된다.
쿼리 0 ~ 3에서 left는 모두 같은 √N에 포함된다.
같은 √N에 포함되는 left는 최대 √N번만 움직이게 된다.
그리고 right는 수열의 크기 N만큼만 움직인다.
쿼리 3 → 쿼리 4로 넘어가는 경우는 최악일 때, 2√N의 비용이 든다.
그리고 분할된 √N은 다시 돌아가지 않는다.
분할된 곳에서는 left는 최대 √N씩 움직이고, right는 N까지 증가하면서 움직인다.
따라서 비용을 최소화할 수 있다.
답을 구할 때는 각 쿼리에 대한 index를 저장한 다음 answer 배열에 저장하고 순서대로 출력하면 된다.
구현
query는 다음과 같이 정의한다.
쿼리를 정렬하기 때문에 index를 기억할 필요가 있다.
typedef struct st
{
int index;
int left;
int right;
}QUERY;
QUERY query[MAX];
정렬은 merge sort를 이용하였다.
원소의 수가 N = 100000이므로 √N = 약 316이다.
left / sqrt가 작은 순으로, 같은 경우 right가 작은 순으로 정렬한다.
#define MAX (100000 + 100)
#define SQRT (316)
QUERY b[MAX];
int isMin(QUERY a, QUERY b)
{
if (a.left / SQRT < b.left / SQRT) return 1;
else if (a.left / SQRT == b.left / SQRT && a.right < b.right) return 1;
return 0;
}
void merge(QUERY* a, int start, int end)
{
int mid, i, j, k;
mid = (start + end) >> 1;
i = start;
j = mid + 1;
k = 0;
while (i <= mid && j <= end)
{
if (isMin(a[i], a[j])) b[k++] = a[i++];
else b[k++] = a[j++];
}
while (i <= mid) b[k++] = a[i++];
while (j <= end) b[k++] = a[j++];
for (i = start; i <= end; i++)
a[i] = b[i - start];
}
void sort(QUERY* a, int start, int end)
{
int mid;
if (start >= end) return;
mid = (start + end) >> 1;
sort(a, start, mid);
sort(a, mid + 1, end);
merge(a, start, end);
}
입력을 받은 후 쿼리를 정렬한다.
scanf("%d", &N);
for (int i = 1; i <= N; i++) scanf("%d", &A[i]);
scanf("%d", &M);
for (int i = 0; i < M; i++)
{
int left, right;
scanf("%d %d", &left, &right);
query[i].index = i;
query[i].left = left;
query[i].right = right;
}
sort(query, 0, M - 1);
left와 right는 정렬된 첫 번째 쿼리의 left로 정의한다.
그리고 이 원소는 최초로 들어오는 첫 원소이므로 답 ans = 1로 설정하고 table에도 해당 원소의 수를 증가시킨다.
int left, right, ans;
left = right = query[0].left;
table[A[left]]++;
ans = 1;
이제 쿼리를 움직여가면서 답을 구한다.
예를 들어 현재 left가 쿼리의 left보다 작으면 left가 증가하고 이것은 원소를 빼는 과정이 된다.
query가 더 작다면 left는 감소하고 원소는 더하는 과정이 된다.
for (int i = 0; i < M; i++)
{
while (left < query[i].left)
{
left++;
if (--table[A[left - 1]] == 0) ans--;
}
while (left > query[i].left)
{
left--;
if (++table[A[left]] == 1) ans++;
}
while (right < query[i].right)
{
right++;
if (++table[A[right]] == 1) ans++;
}
while (right > query[i].right)
{
right--;
if (--table[A[right + 1]] == 0) ans--;
}
answer[query[i].index] = ans;
}
left는 증가시켰으나 빼는 원소는 left - 1이 되므로 코드가 아래와 같이 되는 것에 주의하자.
while (left < query[i].left)
{
left++;
if (--table[A[left - 1]] == 0) ans--;
}
반대로 left를 감소시키는 경우는 원소를 더하는 경우이므로 현재의 left 그대로 더해주면 된다.
while (left > query[i].left)
{
left--;
if (++table[A[left]] == 1) ans++;
}
right는 left와 반대로 하면 된다.
마지막에는 해당 쿼리의 index에 맞는 answer에 답을 저장한다.
answer[query[i].index] = ans;
위의 코드를 다음과 같이 줄일 수도 있다.
for (int i = 0; i < M; i++)
{
while (left < query[i].left)
if (--table[A[left++]] == 0) ans--;
while (left > query[i].left)
if (++table[A[--left]] == 1) ans++;
while (right < query[i].right)
if (++table[A[++right]] == 1) ans++;
while (right > query[i].right)
if (--table[A[right--]] == 0) ans--;
answer[query[i].index] = ans;
}
전체 코드는 다음과 같다.
#include <stdio.h>
#define MAX (100000 + 100)
#define SQRT (316)
int N, M;
int A[MAX];
int answer[MAX];
int table[1000000 + 100];
typedef struct st
{
int index;
int left;
int right;
}QUERY;
QUERY query[MAX];
QUERY b[MAX];
int isMin(QUERY a, QUERY b)
{
if (a.left / SQRT < b.left / SQRT) return 1;
else if (a.left / SQRT == b.left / SQRT && a.right < b.right) return 1;
return 0;
}
void merge(QUERY* a, int start, int end)
{
int mid, i, j, k;
mid = (start + end) >> 1;
i = start;
j = mid + 1;
k = 0;
while (i <= mid && j <= end)
{
if (isMin(a[i], a[j])) b[k++] = a[i++];
else b[k++] = a[j++];
}
while (i <= mid) b[k++] = a[i++];
while (j <= end) b[k++] = a[j++];
for (i = start; i <= end; i++)
a[i] = b[i - start];
}
void sort(QUERY* a, int start, int end)
{
int mid;
if (start >= end) return;
mid = (start + end) >> 1;
sort(a, start, mid);
sort(a, mid + 1, end);
merge(a, start, end);
}
int main()
{
scanf("%d", &N);
for (int i = 1; i <= N; i++) scanf("%d", &A[i]);
scanf("%d", &M);
for (int i = 0; i < M; i++)
{
int left, right;
scanf("%d %d", &left, &right);
query[i].index = i;
query[i].left = left;
query[i].right = right;
}
sort(query, 0, M - 1);
int left, right, ans;
left = right = query[0].left;
table[A[left]]++;
ans = 1;
for (int i = 0; i < M; i++)
{
while (left < query[i].left)
if (--table[A[left++]] == 0) ans--;
while (left > query[i].left)
if (++table[A[--left]] == 1) ans++;
while (right < query[i].right)
if (++table[A[++right]] == 1) ans++;
while (right > query[i].right)
if (--table[A[right--]] == 0) ans--;
answer[query[i].index] = ans;
}
for (int i = 0; i < M; i++) printf("%d\n", answer[i]);
return 0;
}