본문 바로가기
알고리즘/[PRO] 삼성 SW 역량 테스트 B형

최적화) 삼성 C형 샘플 문제 : 블록 부품 맞추기

by 피로물든딸기 2021. 4. 3.
반응형

삼성 B형 전체 링크

삼성 C형 전체 링크

 

블록 부품 맞추기 문제를 더 최적화 해보자.

 

아래의 makeBlock에서 코드를 하나씩 지워보며 시간을 재보면 sort에서 비용이 많이 드는 것을 알 수 있다.

int makeBlock(int module[][4][4])
{
	register int i, sum;

	bcnt1 = bcnt2 = mcnt1 = mcnt2 = 0;
	for (i = 0; i < 30000; i++) check1[i] = check2[i] = 0;

	makeHash(module);

	sort(match1, 0, mcnt1 - 1, isMinForMatch);
	sort(block1, 0, bcnt1 - 1, isMinForBlock);
	sort(match2, 0, mcnt2 - 1, isMinForMatch);
	sort(block2, 0, bcnt2 - 1, isMinForBlock);

	sum = 0;

	for (i = 0; i < mcnt1; i++)
	{
		if (check1[match1[i].index]) continue;

		check1[match1[i].index] = 1;

		register BLOCK find = binarysearch(block1, check1, COMPLETE1 - match1[i].hash, bcnt1);
		if (find.min == -1) continue;

		sum += find.max + match1[i].min;
		check1[find.index] = 1;
	}

	for (i = 0; i < mcnt2; i++)
	{
		if (check2[match2[i].index]) continue;

		check2[match2[i].index] = 1;

		register BLOCK find = binarysearch(block2, check2, COMPLETE2 - match2[i].hash, bcnt2);
		if (find.min == -1) continue;

		sum += find.max + match2[i].min;
		check2[find.index] = 1;
	}

	return sum;
}

 

따라서, 정렬이분 탐색으로 hashing된 블럭을 찾지말고, hash table에 블럭을 저장하도록 하자.

블럭의 hashing된 값의 최대 크기는 0x2222222222222222이므로 그래도 저장하기에는 너무 큰 값이다.

따라서 PRIME으로 나눈 나머지 값으로 다시 hashing 한다.

(이 방법은 Rush Hour Puzzle에서 사용한 방법이다.)

 

블럭은 총 15만개가 나오므로, 넉넉하개 24만보다 큰 소수로 잡았다. 

hashTable에 original과 rotate된 블럭을 저장하기 때문에 block 배열은 더 이상 필요가 없다.

그리고 메모리 풀 방식을 이용해 channing 방식으로 hash table을 만든다. 

따라서 POOL도 넉넉하게 15만 정도 선언한다.

typedef unsigned long long int ull;
typedef unsigned int ui;

#define COMPLETE1 (0x1111111111111111)
#define COMPLETE2 (0x2222222222222222)
#define PRIME (240011)

typedef struct st1
{
	int index;
	int min;
	int max;
	ull hash;
}BLOCK;

BLOCK b[30300 * 5]; /* merge sort용 배열 */

BLOCK match1[30300];
BLOCK match2[30300];

int mcnt1, mcnt2;

typedef struct st2
{
	int index;
	int min;
	int max;
	ull hash;
	struct st2 *next;
}HASH;

HASH hashTable1[PRIME];
HASH hashTable2[PRIME];
HASH POOL[30300 * 5];
int pcnt;

이전의 makeHash 함수와 마찬가지로 최솟값, 최댓값을 찾아 block을 깎는다.

그리고 block에 저장하는 코드를 지우고 node를 선언하여 hashing된 블럭 % PRIME의 table에 연결한다.

void makeHash(int module[][4][4])
{
	ull cut[9] = { 0,
		0x1111111111111111, 0x2222222222222222, 0x3333333333333333,
		0x4444444444444444, 0x5555555555555555, 0x6666666666666666,
		0x7777777777777777, 0x8888888888888888 };

	register int i, k, min, max, diff;
	register ull hash, phash;
	
	for (i = 0; i < 30000; i++)
	{
		register ui* ptr = (ui*)module[i][0];

		min = 10;
		max = 0;
		for (k = 0; k < 16; k++)
		{
			max = max < ptr[k] ? ptr[k] : max;
			min = min > ptr[k] ? ptr[k] : min;
		}

		diff = max - min;

		register BLOCK* match = (diff == 1) ? match1 : match2;
		register HASH* hashTable = (diff == 1) ? hashTable1 : hashTable2;

		register int& mcnt = (diff == 1) ? mcnt1 : mcnt2;

		hash = (ull)ptr[0] << 60 | (ull)ptr[1] << 56 | (ull)ptr[2] << 52 | (ull)ptr[3] << 48
			| (ull)ptr[4] << 44 | (ull)ptr[5] << 40 | (ull)ptr[6] << 36 | (ull)ptr[7] << 32
			| (ull)ptr[8] << 28 | (ull)ptr[9] << 24 | (ull)ptr[10] << 20 | (ull)ptr[11] << 16
			| (ull)ptr[12] << 12 | (ull)ptr[13] << 8 | (ull)ptr[14] << 4 | (ull)ptr[15];

		hash -= cut[min];

		/* hashTable에 저장 */
		register HASH* nd = &POOL[pcnt++];
		nd->hash = hash;
		nd->min = min;
		nd->max = max;
		nd->index = i;

		phash = hash % PRIME;
		nd->next = hashTable[phash].next;
		hashTable[phash].next = nd;

		/* rotate x 3 hash table에 저장 */
		...
        
		/* flip */
		hash = (ull)ptr[3] << 60 | (ull)ptr[2] << 56 | (ull)ptr[1] << 52 | (ull)ptr[0] << 48
			| (ull)ptr[7] << 44 | (ull)ptr[6] << 40 | (ull)ptr[5] << 36 | (ull)ptr[4] << 32
			| (ull)ptr[11] << 28 | (ull)ptr[10] << 24 | (ull)ptr[9] << 20 | (ull)ptr[8] << 16
			| (ull)ptr[15] << 12 | (ull)ptr[14] << 8 | (ull)ptr[13] << 4 | (ull)ptr[12];

		hash -= cut[min];

		match[mcnt].hash = hash;
		match[mcnt].min = min;
		match[mcnt].max = max;
		match[mcnt++].index = i;
	}
}

match할 블럭은 hashTable에 저장할 필요가 없다. 따라서 기존 코드 그대로 둔다.

 

init에서는 POOL의 메모리를 초기화하고, hashTable도 모두 NULL을 가르키게 한다.

완전한 블럭의 높이를 최대로 맞춰야 하기 때문에 여전히 match 블럭의 sorting은 필요하다.

sorting이 완료되면 찾아야할 블럭의 hash 값인 COMPLETE - match[i].hash를 hashTable에서 찾는다.

이 hash 값은 PRIME으로 나눈 나머지 값의 hashTable에 저장되어 있을 것이다.

int makeBlock(int module[][4][4])
{
	register int i, sum, maxBlock, index;
	register ull hash;

	/* init */
	mcnt1 = mcnt2 = pcnt = 0;
	for (i = 0; i < 30000; i++) check1[i] = check2[i] = 0;
	for (i = 0; i < PRIME; i++) hashTable1[i].next = hashTable2[i].next = 0;

	makeHash(module);

	sort(match1, 0, mcnt1 - 1, isMinForMatch);
	sort(match2, 0, mcnt2 - 1, isMinForMatch);

	sum = 0;

	for (i = 0; i < mcnt1; i++)
	{
		if (check1[match1[i].index]) continue;

		check1[match1[i].index] = 1;

		hash = COMPLETE1 - match1[i].hash;

		register HASH *nd = hashTable1[hash % PRIME].next;

		maxBlock = 0;
		index = -1;
		while (nd)
		{
			if (nd->hash == hash && maxBlock < nd->max && check1[nd->index] == 0)
			{
				maxBlock = nd->max;
				index = nd->index;
			}

			nd = nd->next;
		}

		if (index == -1) continue;

		sum += maxBlock + match1[i].min;
		check1[index] = 1;
	}
    
    ...
    
}

 

hashTable에 충돌이 있을 수 있으므로 실제 hash값과 비교한다.

그리고 가장 큰 값을 찾아야 하기 때문에 break하지 않고 max를 갱신하며 모든 블럭을 탐색한다.

이전에 찾았던 블럭을 제외하기 위해 check[nd->index] == 0인 경우만 찾는다.

블럭을 찾지 못하면 index = -1이기 때문에 continue해서 다음 블럭을 찾는다.

블럭을 찾았다면, sum에 높이를 저장하고, 찾은 블럭을 check 배열에 표시한다.

 

마찬가지로, match 블럭도 check에 표시를 해야, 자기 자신을 찾지 않는다.


높이가 2인 코드는 최종 코드를 참고하자.

 

Main

#include <iostream>
#include <stdlib.h>

#define MAX 30000

using namespace std;

extern int makeBlock(int module[][4][4]);

int main(void)
{
    static int module[MAX][4][4];

    srand(3); // 3 will be changed

    for (int tc = 1; tc <= 10; tc++)
    {
        for (int c = 0; c < MAX; c++)
        {
            int base = 1 + (rand() % 6);
			for (int y = 0; y < 4; y++)
			{
				for (int x = 0; x < 4; x++)
				{
					module[c][y][x] = base + (rand() % 3);
				}
			}
        }
		cout << "#" << tc << " " << makeBlock(module) << endl;
    }

	return 0;
}

 

User Code

int makeBlock(int module[][4][4])
{
     
}

정답

typedef unsigned long long int ull;
typedef unsigned int ui;

#define COMPLETE1 (0x1111111111111111)
#define COMPLETE2 (0x2222222222222222)
#define PRIME (240011)

typedef struct st1
{
	int index;
	int min;
	int max;
	ull hash;
}BLOCK;

BLOCK b[30300 * 5]; /* merge sort용 배열 */

BLOCK match1[30300];
BLOCK match2[30300];

int mcnt1, mcnt2;

typedef struct st2
{
	int index;
	int min;
	int max;
	ull hash;
	struct st2 *next;
}HASH;

HASH hashTable1[PRIME];
HASH hashTable2[PRIME];
HASH POOL[30300 * 5];
int pcnt;

int check1[30300];
int check2[30300];

void makeHash(int module[][4][4])
{
	ull cut[9] = { 0,
		0x1111111111111111, 0x2222222222222222, 0x3333333333333333,
		0x4444444444444444, 0x5555555555555555, 0x6666666666666666,
		0x7777777777777777, 0x8888888888888888 };

	register int i, k, min, max, diff;
	register ull hash, phash;
	
	for (i = 0; i < 30000; i++)
	{
		register ui* ptr = (ui*)module[i][0];

		min = 10;
		max = 0;
		for (k = 0; k < 16; k++)
		{
			max = max < ptr[k] ? ptr[k] : max;
			min = min > ptr[k] ? ptr[k] : min;
		}

		diff = max - min;

		register BLOCK* match = (diff == 1) ? match1 : match2;
		register HASH* hashTable = (diff == 1) ? hashTable1 : hashTable2;

		register int& mcnt = (diff == 1) ? mcnt1 : mcnt2;

		hash = (ull)ptr[0] << 60 | (ull)ptr[1] << 56 | (ull)ptr[2] << 52 | (ull)ptr[3] << 48
			| (ull)ptr[4] << 44 | (ull)ptr[5] << 40 | (ull)ptr[6] << 36 | (ull)ptr[7] << 32
			| (ull)ptr[8] << 28 | (ull)ptr[9] << 24 | (ull)ptr[10] << 20 | (ull)ptr[11] << 16
			| (ull)ptr[12] << 12 | (ull)ptr[13] << 8 | (ull)ptr[14] << 4 | (ull)ptr[15];

		hash -= cut[min];

		register HASH* nd = &POOL[pcnt++];
		nd->hash = hash;
		nd->min = min;
		nd->max = max;
		nd->index = i;

		phash = hash % PRIME;
		nd->next = hashTable[phash].next;
		hashTable[phash].next = nd;

		hash = (ull)ptr[12] << 60 | (ull)ptr[8] << 56 | (ull)ptr[4] << 52 | (ull)ptr[0] << 48
			| (ull)ptr[13] << 44 | (ull)ptr[9] << 40 | (ull)ptr[5] << 36 | (ull)ptr[1] << 32
			| (ull)ptr[14] << 28 | (ull)ptr[10] << 24 | (ull)ptr[6] << 20 | (ull)ptr[2] << 16
			| (ull)ptr[15] << 12 | (ull)ptr[11] << 8 | (ull)ptr[7] << 4 | (ull)ptr[3];

		hash -= cut[min];

		nd = &POOL[pcnt++];
		nd->hash = hash;
		nd->min = min;
		nd->max = max;
		nd->index = i;

		phash = hash % PRIME;
		nd->next = hashTable[phash].next;
		hashTable[phash].next = nd;

		hash = (ull)ptr[15] << 60 | (ull)ptr[14] << 56 | (ull)ptr[13] << 52 | (ull)ptr[12] << 48
			| (ull)ptr[11] << 44 | (ull)ptr[10] << 40 | (ull)ptr[9] << 36 | (ull)ptr[8] << 32
			| (ull)ptr[7] << 28 | (ull)ptr[6] << 24 | (ull)ptr[5] << 20 | (ull)ptr[4] << 16
			| (ull)ptr[3] << 12 | (ull)ptr[2] << 8 | (ull)ptr[1] << 4 | (ull)ptr[0];

		hash -= cut[min];

		nd = &POOL[pcnt++];
		nd->hash = hash;
		nd->min = min;
		nd->max = max;
		nd->index = i;

		phash = hash % PRIME;
		nd->next = hashTable[phash].next;
		hashTable[phash].next = nd;

		hash = (ull)ptr[3] << 60 | (ull)ptr[7] << 56 | (ull)ptr[11] << 52 | (ull)ptr[15] << 48
			| (ull)ptr[2] << 44 | (ull)ptr[6] << 40 | (ull)ptr[10] << 36 | (ull)ptr[14] << 32
			| (ull)ptr[1] << 28 | (ull)ptr[5] << 24 | (ull)ptr[9] << 20 | (ull)ptr[13] << 16
			| (ull)ptr[0] << 12 | (ull)ptr[4] << 8 | (ull)ptr[8] << 4 | (ull)ptr[12];

		hash -= cut[min];

		nd = &POOL[pcnt++];
		nd->hash = hash;
		nd->min = min;
		nd->max = max;
		nd->index = i;

		phash = hash % PRIME;
		nd->next = hashTable[phash].next;
		hashTable[phash].next = nd;

		/* flip */
		hash = (ull)ptr[3] << 60 | (ull)ptr[2] << 56 | (ull)ptr[1] << 52 | (ull)ptr[0] << 48
			| (ull)ptr[7] << 44 | (ull)ptr[6] << 40 | (ull)ptr[5] << 36 | (ull)ptr[4] << 32
			| (ull)ptr[11] << 28 | (ull)ptr[10] << 24 | (ull)ptr[9] << 20 | (ull)ptr[8] << 16
			| (ull)ptr[15] << 12 | (ull)ptr[14] << 8 | (ull)ptr[13] << 4 | (ull)ptr[12];

		hash -= cut[min];

		match[mcnt].hash = hash;
		match[mcnt].min = min;
		match[mcnt].max = max;
		match[mcnt++].index = i;
	}
}

int isMinForMatch(BLOCK a, BLOCK b)
{
	return (a.max > b.max);
}

void merge(BLOCK* block, int start, int end, int(*comp)(BLOCK, BLOCK))
{
	register int mid, i, j, k;

	mid = (start + end) >> 1;
	i = start;
	j = mid + 1;
	k = 0;

	while (i <= mid && j <= end)
	{
		if (comp(block[i], block[j])) b[k++] = block[i++];
		else b[k++] = block[j++];
	}

	while (i <= mid) b[k++] = block[i++];
	while (j <= end) b[k++] = block[j++];

	for (i = start; i <= end; i++)
		block[i] = b[i - start];

}

void sort(BLOCK* block, int start, int end, int(*comp)(BLOCK a, BLOCK b))
{
	register int mid;
	if (start >= end) return;

	mid = (start + end) >> 1;

	sort(block, start, mid, comp);
	sort(block, mid + 1, end, comp);
	merge(block, start, end, comp);
}

int makeBlock(int module[][4][4])
{
	register int i, sum, maxBlock, index;
	register ull hash;
	
	mcnt1 = mcnt2 = pcnt = 0;
	for (i = 0; i < 30000; i++) check1[i] = check2[i] = 0;
	for (i = 0; i < PRIME; i++) hashTable1[i].next = hashTable2[i].next = 0;

	makeHash(module);

	sort(match1, 0, mcnt1 - 1, isMinForMatch);
	sort(match2, 0, mcnt2 - 1, isMinForMatch);

	sum = 0;

	for (i = 0; i < mcnt1; i++)
	{
		if (check1[match1[i].index]) continue;

		check1[match1[i].index] = 1;

		hash = COMPLETE1 - match1[i].hash;

		register HASH *nd = hashTable1[hash % PRIME].next;

		maxBlock = 0;
		index = -1;
		while (nd)
		{
			if (nd->hash == hash && maxBlock < nd->max && check1[nd->index] == 0)
			{
				maxBlock = nd->max;
				index = nd->index;
			}

			nd = nd->next;
		}

		if (index == -1) continue;

		sum += maxBlock + match1[i].min;
		check1[index] = 1;
	}

	for (i = 0; i < mcnt2; i++)
	{
		if (check2[match2[i].index]) continue;

		check2[match2[i].index] = 1;

		hash = COMPLETE2 - match2[i].hash;

		register HASH *nd = hashTable2[hash % PRIME].next;

		maxBlock = 0;
		index = -1;
		while (nd)
		{
			if (nd->hash == hash && maxBlock < nd->max && check2[nd->index] == 0)
			{
				maxBlock = nd->max;
				index = nd->index;
			}

			nd = nd->next;
		}

		if (index == -1) continue;

		sum += maxBlock + match2[i].min;
		check2[index] = 1;
	}

	return sum;
}

정렬에 NlogN 비용이, 이분 탐색에 logN * 30000의 비용이 들었지만,

hashTable은 hash값을 찾는데 O(1)이면 되고, 충돌이 적기 때문에 O(30000 * 충돌 횟수)의 비용이 들게 된다.

따라서 hash table을 이용하면 실행시간이 절반 가까이 줄어든다.

반응형

댓글