BOJ 7453 : 합이 0인 네 정수
https://www.acmicpc.net/problem/7453
참고
합이 0이 되는지 확인하기 위해 4000 x 4000 x 4000 x 4000의 경우의 수를 모두 확인하는 것은 비용이 너무 많이 든다.
따라서 A와 B의 합인 AB 배열 (16,000,000)과 C와 D의 합인 CD 배열(16,000,000)을 만든다.
그리고 AB 배열의 값이 x라면 CD에서 -x가 되는 경우를 찾으면 된다.
이 경우 이분 탐색으로 log(16,000,000) ≒ 24번 만에 답을 찾을 수 있게 된다.
물론 이분 탐색을 위해 배열을 머지 소트로 정렬하였다.
구현
입력을 받은 후, AB, CD 배열을 만든다.
idx = 0;
for (int i = 0; i < N; i++)
{
for (int k = 0; k < N; k++)
{
AB[idx] = A[i] + B[k];
CD[idx++] = C[i] + D[k];
}
}
그리고 정렬한다.
sort(AB, 0, idx - 1);
sort(CD, 0, idx - 1);
답을 찾는 방법은 다음과 같다.
AB[]의 원소의 값이 x라면 CD에서 -x가 있는지 이분 탐색으로 찾는다.
값을 찾은 경우에는 -x가 여러 개 있을 수 있으므로 배열의 앞, 뒤로 같은 값인지 센다.
값을 찾지 못하면 0을 리턴하면 된다.
ll binarysearch(int data)
{
int l, r, m;
ll ret;
ret = 0;
l = 0, r = idx - 1;
while (l <= r)
{
m = (l + r) / 2;
if (CD[m] == data)
{
int tmp = m - 1;
while (CD[tmp] == data && tmp) ret++, tmp--;
while (CD[m] == data && m <= idx) ret++, m++;
return ret;
}
else if (CD[m] < data) l = m + 1;
else r = m - 1;
}
return 0;
}
AB[c]의 원소의 값이 x이고, AB[c + 1]도 x일 수 있다.
AB[c]에 대응하는 값 -x의 개수가 count라면 AB [c + 1]은 count 만큼 다시 값에 더해주면 된다.
즉, 현재 값(tmp)을 계속 기억해뒀다가, 값이 바뀌면 이분 탐색을 한다.
그래서 최초의 tmp는 AB[0]과 다른 값이 되기 위해서 AB[0]에 1을 빼고 시작한다.
ll ans, count;
int tmp = AB[0] - 1;
ans = count = 0;
for (int i = 0; i < idx; i++)
{
if (tmp != AB[i])
{
count = binarysearch(-AB[i]);
ans += count;
tmp = AB[i];
}
else
ans += count;
}
참고로 ans는 long 타입이다.
배열이 모두 0인 경우만 생각해봐도 경우의 수가 매우 많은 것을 알 수 있다.
전체 코드는 다음과 같다.
#include <stdio.h>
typedef long long int ll;
#define MAX (4010)
int N;
int A[MAX], B[MAX], C[MAX], D[MAX];
int AB[MAX * MAX], CD[MAX * MAX];
int idx;
int b[MAX * MAX];
void merge(int* a, int start, int end)
{
int mid, i, j, k;
mid = (start + end) >> 1;
i = start;
j = mid + 1;
k = 0;
while (i <= mid && j <= end)
{
if (a[i] <= a[j]) b[k++] = a[i++];
else b[k++] = a[j++];
}
while (i <= mid) b[k++] = a[i++];
while (j <= end) b[k++] = a[j++];
for (i = start; i <= end; i++)
a[i] = b[i - start];
}
void sort(int* a, int start, int end)
{
int mid;
if (start >= end) return;
mid = (start + end) >> 1;
sort(a, start, mid);
sort(a, mid + 1, end);
merge(a, start, end);
}
ll binarysearch(int data)
{
int l, r, m;
ll ret;
ret = 0;
l = 0, r = idx - 1;
while (l <= r)
{
m = (l + r) / 2;
if (CD[m] == data)
{
int tmp = m - 1;
while (CD[tmp] == data && tmp) ret++, tmp--;
while (CD[m] == data && m <= idx) ret++, m++;
return ret;
}
else if (CD[m] < data) l = m + 1;
else r = m - 1;
}
return 0;
}
int main()
{
scanf("%d", &N);
for (int i = 0; i < N; i++)
scanf("%d %d %d %d", &A[i], &B[i], &C[i], &D[i]);
idx = 0;
for (int i = 0; i < N; i++)
{
for (int k = 0; k < N; k++)
{
AB[idx] = A[i] + B[k];
CD[idx++] = C[i] + D[k];
}
}
sort(AB, 0, idx - 1);
sort(CD, 0, idx - 1);
ll ans, count;
int tmp = AB[0] - 1;
ans = count = 0;
for (int i = 0; i < idx; i++)
{
if (tmp != AB[i])
{
count = binarysearch(-AB[i]);
ans += count;
tmp = AB[i];
}
else
ans += count;
}
printf("%lld\n", ans);
return 0;
}