본문 바로가기
개발/C, C++

C, C++ - 1 비트 개수 세기 (Bit Counter)

by 피로물든딸기 2023. 7. 29.
반응형

C, C++ 전체 링크

삼성 C형 전체 링크

 

주어진 숫자에 대해 비트 단위로 1이 몇 개 인지 세는 함수를 만들어보자.

먼저 비트 단위로 출력하기를 이용하여 아래의 코드를 실행해보자.

#include <stdio.h>
#include <iostream>

using namespace std;

typedef long long int ll;

template <typename T>
void printBitNumber(T number)
{
	unsigned int bitSize = sizeof(number) * 8;
	T mask = (1ull) << (bitSize - 1);

	printf("%d", (number & mask) == mask);

	mask = (1ull) << (bitSize - 2);
	for (int i = 1; i < bitSize; i++)
	{
		printf("%d", (number & mask) == mask);
		mask >>= 1;
		if (i % 8 == 7) printf(" ");
	}
	putchar('\n');
}

int main(void)
{
	ll bit = 1234123412341234123;

	printf("bit   : "); printBitNumber(bit);

	return 0;
}

 

출력되는 비트를 보면 1234123412341234123은 1의 개수30개임을 알 수 있다.


while문을 이용한 bit counter

 

자신의 수 그 수에서 1을 뺀 수& 비트 연산하면 가장 오른쪽에 있는 1 비트를 삭제한다.

아래 코드를 실행해보자.

#include <stdio.h>
#include <iostream>

using namespace std;

typedef long long int ll;

template <typename T>
void printBitNumber(T number)
{
	unsigned int bitSize = sizeof(number) * 8;
	T mask = (1ull) << (bitSize - 1);

	printf("%d", (number & mask) == mask);

	mask = (1ull) << (bitSize - 2);
	for (int i = 1; i < bitSize; i++)
	{
		printf("%d", (number & mask) == mask);
		mask >>= 1;
		if (i % 8 == 7) printf(" ");
	}
	putchar('\n');
}

int main(void)
{
	ll bit = 1234123412341234123;
	ll temp1 = bit & (bit - 1);
	ll temp2 = temp1 & (temp1 - 1);
	ll temp3 = temp2 & (temp2 - 1);
	ll temp4 = temp3 & (temp3 - 1);
	ll temp5 = temp4 & (temp4 - 1);

	printf("bit   : "); printBitNumber(bit);
	printf("temp1 : "); printBitNumber(temp1);
	printf("temp2 : "); printBitNumber(temp2);
	printf("temp3 : "); printBitNumber(temp3);
	printf("temp4 : "); printBitNumber(temp4);
	printf("temp5 : "); printBitNumber(temp5);

	return 0;
}

 

주어진 수 bit부터 시작해서 반복하면 가장 오른쪽에 있는 숫자 1이 사라지는 것을 알 수 있다.

 

이 방법을 while문을 이용하여 반복하면 된다.

number &= (number - 1)은 오른쪽에 있는 1 비트를 제거하면서 0이 될 때까지 반복하게 된다.

 

이때 count는 bit를 세는 숫자가 0이 아닌 경우 1부터 시작해야 한다.

비트가 1개만 남은 경우에도 number &= (number - 1)은 0이 되기 때문에 counting이 되지 않는다.

 

따라서 ! 연산자두 번 사용해서 0은 0으로, 그 외의 숫자는 1로 만드는 작업이 필요하다.

template <typename T>
int getBitCount(T number)
{
	int count = !!number;
	while (number &= (number - 1)) count++;
	return count;
}

 

getBitCount로 bit를 출력하면 30이 나오는 것을 알 수 있다.

 

전체 코드는 다음과 같다.

#include <stdio.h>
#include <iostream>

using namespace std;

typedef long long int ll;

template <typename T>
void printBitNumber(T number)
{
	unsigned int bitSize = sizeof(number) * 8;
	T mask = (1ull) << (bitSize - 1);

	printf("%d", (number & mask) == mask);

	mask = (1ull) << (bitSize - 2);
	for (int i = 1; i < bitSize; i++)
	{
		printf("%d", (number & mask) == mask);
		mask >>= 1;
		if (i % 8 == 7) printf(" ");
	}
	putchar('\n');
}

template <typename T>
int getBitCount(T number)
{
	int count = !!number;
	while (number &= (number - 1)) count++;
	return count;
}

int main(void)
{
	ll bit = 1234123412341234123;
	ll temp1 = bit & (bit - 1);
	ll temp2 = temp1 & (temp1 - 1);
	ll temp3 = temp2 & (temp2 - 1);
	ll temp4 = temp3 & (temp3 - 1);
	ll temp5 = temp4 & (temp4 - 1);

	printf("bit   : "); printBitNumber(bit);
	printf("temp1 : "); printBitNumber(temp1);
	printf("temp2 : "); printBitNumber(temp2);
	printf("temp3 : "); printBitNumber(temp3);
	printf("temp4 : "); printBitNumber(temp4);
	printf("temp5 : "); printBitNumber(temp5);
	putchar('\n');
	printf("num of 1 bit : %d\n", getBitCount(1024));

	return 0;
}

분할 정복을 이용한 bit counter

 

위의 getBitCount는 while문을 이용하는데, 1 비트가 적을수록 성능이 좋다.

반대로 말하면 1 비트가 많을수록 성능이 나빠진다.

만약 1이 64개라면 while이 64번 반복되어야 한다.

 

분할 정복을 이용하면 26 = 64 비트의 수를 항상 일정하게 6번만에 count 할 수 있다.

 

먼저 아래의 코드를 실행해보자.

#include <stdio.h>
#include <iostream>

using namespace std;

typedef long long int ll;

template <typename T>
void printBitNumber(T number)
{
	unsigned int bitSize = sizeof(number) * 8;
	T mask = (1ull) << (bitSize - 1);

	printf("%d", (number & mask) == mask);

	mask = (1ull) << (bitSize - 2);
	for (int i = 1; i < bitSize; i++)
	{
		printf("%d", (number & mask) == mask);
		mask >>= 1;
		if (i % 8 == 7) printf(" ");
	}
	putchar('\n');
}

int main(void)
{
	ll bit = 1234123412341234123;
	ll temp1 = bit & (bit - 1);
	ll temp2 = temp1 & (temp1 - 1);
	ll temp3 = temp2 & (temp2 - 1);
	ll temp4 = temp3 & (temp3 - 1);
	ll temp5 = temp4 & (temp4 - 1);

	ll mask1 = 0x5555555555555555;
	ll mask2 = 0x3333333333333333;
	ll mask3 = 0x0F0F0F0F0F0F0F0F;
	ll mask4 = 0x00FF00FF00FF00FF;
	ll mask5 = 0x0000FFFF0000FFFF;
	ll mask6 = 0x00000000FFFFFFFF;

	printf("bit   : "); printBitNumber(bit);
	printf("mask1 : "); printBitNumber(mask1);
	printf("mask2 : "); printBitNumber(mask2);
	printf("mask3 : "); printBitNumber(mask3);
	printf("mask4 : "); printBitNumber(mask4);
	printf("mask5 : "); printBitNumber(mask5);
	printf("mask6 : "); printBitNumber(mask6);

	return 0;
}

 

mask1 ~ mask6이 아래와 같이 규칙적인 패턴을 보인다.

 

그림으로 나타내면 다음과 같다.

 

이제 bitbit를 >> 1 연산한 값을 mask와 & 연산으로 합쳐보자.

 

(bit & mask1) + (bit >> 1) & mask1 의 연산 값은 아래와 같다.

 

위의 연산에서 4번째 칸만 다시 분석해보자.

 

4번째 칸만 보면 bit는 1 0 1 1 0 1 0 0 이다.

그 결과 나온 sum은 0 1 / 1 0 / 0 1 / 0 0 이다.

이 sum이 2칸씩 의미하는 것은 bit 칸을 2칸씩 나눌 때, 각각 1의 합이 된다.

1의 개수가 1개 / 2개 / 1개 / 0개을 의미하고 실제로도 bit를 보면 2칸씩 나눌 때 1 비트1개 / 2개 / 1개 / 0개 다.

 

mask1 단계에서 2비트씩 합쳐서 bit의 개수를 각 bit에 저장하였다.

mask2 단계에서는 mask1의 sum을 바탕으로 4개씩 bit를 저장하면 된다.

 

이제 4개씩 bit를 저장하기 위해 mask1에 의한 sum 결과를 이번에는 2칸 >> 옮긴다.

 

(bit = sum & mask2) + (bit >> 2) & mask2 의 연산 값은 아래와 같다.

 

마찬가지로 4번째 칸을 분석해보자.

원래의 비트는 1 0 1 1 0 1 0 0 이다.

2번째 mask 연산을 완료한 sum은 0 0 1 1 / 0 0 0 1 이다.

이것은 1의 개수를 4칸씩 볼 때 3개 / 1개 임을 의미하고 실제 bit에서도 4칸씩 보면 1 비트가 3개 / 1개임을 알 수 있다.

 

bit를 1칸 오른쪽으로 옮긴 후 mask1 연산을 하면 1 bit의 개수가 2 bit 단위로 저장되었고,

bit를 2칸 오른쪽으로 옮긴 후 mask2 연산을 하면 1 bit의 개수가 4 bit 단위로 저장되었다.

 

이와 같은 방식으로 mask3 ~ 6까지 4칸 / 8칸 / 16칸 / 32칸을 진행한다.

그러면 각각의 결과는 8칸 / 16칸 / 32칸 / 64칸에 1 bit의 개수가 저장된다.

	ll mask3 = 0x0F0F0F0F0F0F0F0F;
	ll mask4 = 0x00FF00FF00FF00FF;
	ll mask5 = 0x0000FFFF0000FFFF;
	ll mask6 = 0x00000000FFFFFFFF;

 

모든 칸에 대해서 진행하게 되면 숫자 12341234123412341231 비트의 개수 30을 얻게 된다.

 

함수를 정리하면 다음과 같다.

int getBitCountLong(ll number)
{
	ll mask1 = 0x5555555555555555;
	ll mask2 = 0x3333333333333333;
	ll mask3 = 0x0F0F0F0F0F0F0F0F;
	ll mask4 = 0x00FF00FF00FF00FF;
	ll mask5 = 0x0000FFFF0000FFFF;
	ll mask6 = 0x00000000FFFFFFFF;

	ll count = number;
	count = (count & mask1) + ((count >>  1) & mask1);
	count = (count & mask2) + ((count >>  2) & mask2);
	count = (count & mask3) + ((count >>  4) & mask3);
	count = (count & mask4) + ((count >>  8) & mask4);
	count = (count & mask5) + ((count >> 16) & mask5);
	count = (count & mask6) + ((count >> 32) & mask6);
	
	return count;
}

 

전체 테스트 코드는 다음과 같다.

#include <stdio.h>
#include <iostream>

using namespace std;

typedef long long int ll;

template <typename T>
void printBitNumber(T number)
{
	unsigned int bitSize = sizeof(number) * 8;
	T mask = (1ull) << (bitSize - 1);

	printf("%d", (number & mask) == mask);

	mask = (1ull) << (bitSize - 2);
	for (int i = 1; i < bitSize; i++)
	{
		printf("%d", (number & mask) == mask);
		mask >>= 1;
		if (i % 8 == 7) printf(" ");
	}
	putchar('\n');
}

int getBitCountLong(ll number)
{
	ll mask1 = 0x5555555555555555;
	ll mask2 = 0x3333333333333333;
	ll mask3 = 0x0F0F0F0F0F0F0F0F;
	ll mask4 = 0x00FF00FF00FF00FF;
	ll mask5 = 0x0000FFFF0000FFFF;
	ll mask6 = 0x00000000FFFFFFFF;

	ll count = number;
	count = (count & mask1) + ((count >>  1) & mask1);
	count = (count & mask2) + ((count >>  2) & mask2);
	count = (count & mask3) + ((count >>  4) & mask3);
	count = (count & mask4) + ((count >>  8) & mask4);
	count = (count & mask5) + ((count >> 16) & mask5);
	count = (count & mask6) + ((count >> 32) & mask6);
	
	return count;
}

int main(void)
{
	ll bit = 1234123412341234123;

	ll mask1 = 0x5555555555555555;
	ll mask2 = 0x3333333333333333;
	ll mask3 = 0x0F0F0F0F0F0F0F0F;
	ll mask4 = 0x00FF00FF00FF00FF;
	ll mask5 = 0x0000FFFF0000FFFF;
	ll mask6 = 0x00000000FFFFFFFF;

	ll count = bit;

	printf("bit   : "); printBitNumber(bit);
	printf("mask1 : "); printBitNumber(mask1);
	printf("mask2 : "); printBitNumber(mask2);
	printf("mask3 : "); printBitNumber(mask3);
	printf("mask4 : "); printBitNumber(mask4);
	printf("mask5 : "); printBitNumber(mask5);
	printf("mask6 : "); printBitNumber(mask6);

	putchar('\n');

	printf("before: "); printBitNumber(count & mask1);
	printf("before: "); printBitNumber((count >> 1) & mask1);
	count = (count & mask1) + ((count >>  1) & mask1);
	printf("after : "); printBitNumber(count);
	putchar('\n');
	printf("before: "); printBitNumber(count & mask2);
	printf("before: "); printBitNumber((count >> 2) & mask2);
	count = (count & mask2) + ((count >>  2) & mask2);
	printf("after : "); printBitNumber(count);
	putchar('\n');
	printf("before: "); printBitNumber(count & mask3);
	printf("before: "); printBitNumber((count >> 4) & mask3);
	count = (count & mask3) + ((count >>  4) & mask3);
	printf("after : "); printBitNumber(count);
	putchar('\n');
	printf("before: "); printBitNumber(count & mask4);
	printf("before: "); printBitNumber((count >> 8) & mask4);
	count = (count & mask4) + ((count >>  8) & mask4);
	printf("after : "); printBitNumber(count);
	putchar('\n');
	printf("before: "); printBitNumber(count & mask5);
	printf("before: "); printBitNumber((count >> 16) & mask5);
	count = (count & mask5) + ((count >> 16) & mask5);
	printf("after : "); printBitNumber(count);
	putchar('\n');
	printf("before: "); printBitNumber(count & mask6);
	printf("before: "); printBitNumber((count >> 32) & mask6);
	count = (count & mask6) + ((count >> 32) & mask6);
	printf("after : "); printBitNumber(count);

	putchar('\n');

	printf("num of 1 bit : %d\n", getBitCountLong(bit));

	return 0;
}

int / short / char ver

 

위에서 만든 함수는 long ver이다.

int getBitCountLong(ll number)
{
	ll mask1 = 0x5555555555555555;
	ll mask2 = 0x3333333333333333;
	ll mask3 = 0x0F0F0F0F0F0F0F0F;
	ll mask4 = 0x00FF00FF00FF00FF;
	ll mask5 = 0x0000FFFF0000FFFF;
	ll mask6 = 0x00000000FFFFFFFF;

	ll count = number;
	count = (count & mask1) + ((count >>  1) & mask1);
	count = (count & mask2) + ((count >>  2) & mask2);
	count = (count & mask3) + ((count >>  4) & mask3);
	count = (count & mask4) + ((count >>  8) & mask4);
	count = (count & mask5) + ((count >> 16) & mask5);
	count = (count & mask6) + ((count >> 32) & mask6);
	
	return count;
}

 

int, short, char 변수도 ll에서 동작하지만, 불필요한 연산이 들어가게 되므로 아래와 같이 최적화가 가능하다.

int는 32비트이므로 5번만, short는 4번, char는 3번만 하면 된다.

#include <stdio.h>
#include <iostream>

using namespace std;

typedef long long int ll;

template <typename T>
void printBitNumber(T number)
{
	unsigned int bitSize = sizeof(number) * 8;
	T mask = (1ull) << (bitSize - 1);

	printf("%d", (number & mask) == mask);

	mask = (1ull) << (bitSize - 2);
	for (int i = 1; i < bitSize; i++)
	{
		printf("%d", (number & mask) == mask);
		mask >>= 1;
		if (i % 8 == 7) printf(" ");
	}
	putchar('\n');
}

int getBitCountLong(ll number)
{
	ll mask1 = 0x5555555555555555;
	ll mask2 = 0x3333333333333333;
	ll mask3 = 0x0F0F0F0F0F0F0F0F;
	ll mask4 = 0x00FF00FF00FF00FF;
	ll mask5 = 0x0000FFFF0000FFFF;
	ll mask6 = 0x00000000FFFFFFFF;

	ll count = number;
	count = (count & mask1) + ((count >>  1) & mask1);
	count = (count & mask2) + ((count >>  2) & mask2);
	count = (count & mask3) + ((count >>  4) & mask3);
	count = (count & mask4) + ((count >>  8) & mask4);
	count = (count & mask5) + ((count >> 16) & mask5);
	count = (count & mask6) + ((count >> 32) & mask6);
	
	return count;
}

int getBitCountInt(int number)
{
	int mask1 = 0x55555555;
	int mask2 = 0x33333333;
	int mask3 = 0x0F0F0F0F;
	int mask4 = 0x00FF00FF;
	int mask5 = 0x0000FFFF;

	int count = number;
	count = (count & mask1) + ((count >>  1) & mask1);
	count = (count & mask2) + ((count >>  2) & mask2);
	count = (count & mask3) + ((count >>  4) & mask3);
	count = (count & mask4) + ((count >>  8) & mask4);
	count = (count & mask5) + ((count >> 16) & mask5);
	
	return count;
}

int getBitCountShort(short number)
{
	short mask1 = 0x5555;
	short mask2 = 0x3333;
	short mask3 = 0x0F0F;
	short mask4 = 0x00FF;

	short count = number;
	count = (count & mask1) + ((count >> 1) & mask1);
	count = (count & mask2) + ((count >> 2) & mask2);
	count = (count & mask3) + ((count >> 4) & mask3);
	count = (count & mask4) + ((count >> 8) & mask4);

	return count;
}

int getBitCountChar(char number)
{
	char mask1 = 0x55;
	char mask2 = 0x33;
	char mask3 = 0x0F;
	
	char count = number;
	count = (count & mask1) + ((count >> 1) & mask1);
	count = (count & mask2) + ((count >> 2) & mask2);
	count = (count & mask3) + ((count >> 4) & mask3);
	
	return count;
}

int main(void)
{
	ll bit = 1234123412341234123;
	int a = 12341234;
	short b = 4321;
	char c = 100;

	printBitNumber(bit);
	printf("num of 1 bit : %d\n", getBitCountLong(bit));
	putchar('\n');

	printBitNumber(a);
	printf("num of 1 bit : %d\n", getBitCountInt(a));
	putchar('\n');

	printBitNumber(b);
	printf("num of 1 bit : %d\n", getBitCountShort(b));
	putchar('\n');

	printBitNumber(c);
	printf("num of 1 bit : %d\n", getBitCountChar(c));
	putchar('\n');

	return 0;
}

 

실행 결과는 다음과 같다.


최적화

 

mask 연산은 다음과 같이 최적화가 가능하다.

int getBitCountLong(ll number)
{
	ll mask1 = 0x5555555555555555;
	ll mask2 = 0x3333333333333333;
	ll mask3 = 0x0F0F0F0F0F0F0F0F;

	ll count = number;
	count = count - ((count >> 1) & mask1);
	count = (count & mask2) + ((count >> 2) & mask2);
	count = (count + (count >> 4)) & mask3;
	count = count + (count >> 8);
	count = count + (count >> 16);
	count = count + (count >> 32);
	
	return count & 0x0000007F;
}

 

테스트 코드를 실행해보자.

#include <stdio.h>
#include <iostream>

using namespace std;

typedef long long int ll;

template <typename T>
void printBitNumber(T number)
{
	unsigned int bitSize = sizeof(number) * 8;
	T mask = (1ull) << (bitSize - 1);

	printf("%d", (number & mask) == mask);

	mask = (1ull) << (bitSize - 2);
	for (int i = 1; i < bitSize; i++)
	{
		printf("%d", (number & mask) == mask);
		mask >>= 1;
		if (i % 8 == 7) printf(" ");
	}
	putchar('\n');
}

int getBitCountLong(ll number)
{
	ll mask1 = 0x5555555555555555;
	ll mask2 = 0x3333333333333333;
	ll mask3 = 0x0F0F0F0F0F0F0F0F;

	ll count = number;
	count = count - ((count >> 1) & mask1);
	count = (count & mask2) + ((count >> 2) & mask2);
	count = (count + (count >> 4)) & mask3;
	count = count + (count >> 8);
	count = count + (count >> 16);
	count = count + (count >> 32);
	
	return count & 0x0000007F;
}

int main(void)
{
	ll bit = 1234123412341234123;

	ll mask1 = 0x5555555555555555;
	ll mask2 = 0x3333333333333333;
	ll mask3 = 0x0F0F0F0F0F0F0F0F;
	// ll mask4 = 0x00FF00FF00FF00FF;
	// ll mask5 = 0x0000FFFF0000FFFF;
	// ll mask6 = 0x00000000FFFFFFFF;

	ll count = bit;

	printf("bit   : "); printBitNumber(bit);
	putchar('\n');

	count = count - ((count >> 1) & mask1);
	printf("after : "); printBitNumber(count);

	count = (count & mask2) + ((count >> 2) & mask2);
	printf("after : "); printBitNumber(count);

	count = (count + (count >> 4)) & mask3;
	printf("after : "); printBitNumber(count);

	count = count + (count >> 8);
	printf("after : "); printBitNumber(count);

	count = count + (count >> 16);
	printf("after : "); printBitNumber(count);

	count = count + (count >> 32);
	printf("after : "); printBitNumber(count);

	printf("num of 1 bit : %d\n", getBitCountLong(bit));

	return 0;
}

 

마지막 after에서 끝 비트 7개에 1비트의 개수가 들어 있으므로 0x0000007F를 & 연산한다.

 

int / short / char ver을 포함한 테스트 코드는 아래와 같다.

#include <stdio.h>
#include <iostream>

using namespace std;

typedef long long int ll;

template <typename T>
void printBitNumber(T number)
{
	unsigned int bitSize = sizeof(number) * 8;
	T mask = (1ull) << (bitSize - 1);

	printf("%d", (number & mask) == mask);

	mask = (1ull) << (bitSize - 2);
	for (int i = 1; i < bitSize; i++)
	{
		printf("%d", (number & mask) == mask);
		mask >>= 1;
		if (i % 8 == 7) printf(" ");
	}
	putchar('\n');
}

int getBitCountLong(ll number)
{
	ll mask1 = 0x5555555555555555;
	ll mask2 = 0x3333333333333333;
	ll mask3 = 0x0F0F0F0F0F0F0F0F;

	ll count = number;
	count = count - ((count >> 1) & mask1);
	count = (count & mask2) + ((count >> 2) & mask2);
	count = (count + (count >> 4)) & mask3;
	count = count + (count >> 8);
	count = count + (count >> 16);
	count = count + (count >> 32);

	return count & 0x0000007F;
}

int getBitCountInt(int number)
{
	int mask1 = 0x55555555;
	int mask2 = 0x33333333;
	int mask3 = 0x0F0F0F0F;
	
	int count = number;
	count = count - ((count >> 1) & mask1);
	count = (count & mask2) + ((count >> 2) & mask2);
	count = (count + (count >> 4)) & mask3;
	count = count + (count >> 8);
	count = count + (count >> 16);

	return count & 0x003F;
}

int getBitCountShort(short number)
{
	short mask1 = 0x5555;
	short mask2 = 0x3333;
	short mask3 = 0x0F0F;
	
	short count = number;
	count = count - ((count >> 1) & mask1);
	count = (count & mask2) + ((count >> 2) & mask2);
	count = (count + (count >> 4)) & mask3;
	count = count + (count >> 8);

	return count & 0x1F;
}

int getBitCountChar(char number)
{
	char mask1 = 0x55;
	char mask2 = 0x33;
	char mask3 = 0x0F;
	
	char count = number;
	count = count - ((count >> 1) & mask1);
	count = (count & mask2) + ((count >> 2) & mask2);
	count = (count + (count >> 4)) & mask3;
	
	return count;
}

int main(void)
{
	ll bit = 1234123412341234123;
	int a = 12341234;
	short b = 4321;
	char c = 100;

	printBitNumber(bit);
	printf("num of 1 bit : %d\n", getBitCountLong(bit));
	putchar('\n');

	printBitNumber(a);
	printf("num of 1 bit : %d\n", getBitCountInt(a));
	putchar('\n');

	printBitNumber(b);
	printf("num of 1 bit : %d\n", getBitCountShort(b));
	putchar('\n');

	printBitNumber(c);
	printf("num of 1 bit : %d\n", getBitCountChar(c));
	putchar('\n');

	return 0;
}

while vs 분할 정복 vs 분할 정복 최적화 속도 비교

 

<time.h>를 이용해 각 카운터를 시간을 측정해보자.

아래의 코드는 long / int / short / char에 따라 1 비트의 수를 세는 함수의 성능을 측정하였다.

#include <stdio.h>
#include <time.h>
#include <iostream>

using namespace std;

typedef long long int ll;

template <typename T>
void printBitNumber(T number)
{
	unsigned int bitSize = sizeof(number) * 8;
	T mask = (1ull) << (bitSize - 1);

	printf("%d", (number & mask) == mask);

	mask = (1ull) << (bitSize - 2);
	for (int i = 1; i < bitSize; i++)
	{
		printf("%d", (number & mask) == mask);
		mask >>= 1;
		if (i % 8 == 7) printf(" ");
	}
	putchar('\n');
}

template <typename T>
int getBitCount(T number)
{
	int count = !!number;
	while (number &= (number - 1)) count++;
	return count;
}

int getBitCountLong1(ll number)
{
	ll mask1 = 0x5555555555555555;
	ll mask2 = 0x3333333333333333;
	ll mask3 = 0x0F0F0F0F0F0F0F0F;
	ll mask4 = 0x00FF00FF00FF00FF;
	ll mask5 = 0x0000FFFF0000FFFF;
	ll mask6 = 0x00000000FFFFFFFF;

	ll count = number;
	count = (count & mask1) + ((count >>  1) & mask1);
	count = (count & mask2) + ((count >>  2) & mask2);
	count = (count & mask3) + ((count >>  4) & mask3);
	count = (count & mask4) + ((count >>  8) & mask4);
	count = (count & mask5) + ((count >> 16) & mask5);
	count = (count & mask6) + ((count >> 32) & mask6);

	return count;
}

int getBitCountInt1(int number)
{
	int mask1 = 0x55555555;
	int mask2 = 0x33333333;
	int mask3 = 0x0F0F0F0F;
	int mask4 = 0x00FF00FF;
	int mask5 = 0x0000FFFF;

	int count = number;
	count = (count & mask1) + ((count >>  1) & mask1);
	count = (count & mask2) + ((count >>  2) & mask2);
	count = (count & mask3) + ((count >>  4) & mask3);
	count = (count & mask4) + ((count >>  8) & mask4);
	count = (count & mask5) + ((count >> 16) & mask5);

	return count;
}

int getBitCountShort1(short number)
{
	short mask1 = 0x5555;
	short mask2 = 0x3333;
	short mask3 = 0x0F0F;
	short mask4 = 0x00FF;

	short count = number;
	count = (count & mask1) + ((count >> 1) & mask1);
	count = (count & mask2) + ((count >> 2) & mask2);
	count = (count & mask3) + ((count >> 4) & mask3);
	count = (count & mask4) + ((count >> 8) & mask4);

	return count;
}

int getBitCountChar1(char number)
{
	char mask1 = 0x55;
	char mask2 = 0x33;
	char mask3 = 0x0F;

	char count = number;
	count = (count & mask1) + ((count >> 1) & mask1);
	count = (count & mask2) + ((count >> 2) & mask2);
	count = (count & mask3) + ((count >> 4) & mask3);

	return count;
}

int getBitCountLong2(ll number)
{
	ll mask1 = 0x5555555555555555;
	ll mask2 = 0x3333333333333333;
	ll mask3 = 0x0F0F0F0F0F0F0F0F;
	ll mask4 = 0x00FF00FF00FF00FF;
	ll mask5 = 0x0000FFFF0000FFFF;
	ll mask6 = 0x00000000FFFFFFFF;

	ll count = number;
	count = (count & mask1) + ((count >>  1) & mask1);
	count = (count & mask2) + ((count >>  2) & mask2);
	count = (count & mask3) + ((count >>  4) & mask3);
	count = (count & mask4) + ((count >>  8) & mask4);
	count = (count & mask5) + ((count >> 16) & mask5);
	count = (count & mask6) + ((count >> 32) & mask6);

	return count;
}

int getBitCountInt2(int number)
{
	int mask1 = 0x55555555;
	int mask2 = 0x33333333;
	int mask3 = 0x0F0F0F0F;
	int mask4 = 0x00FF00FF;
	int mask5 = 0x0000FFFF;

	int count = number;
	count = (count & mask1) + ((count >>  1) & mask1);
	count = (count & mask2) + ((count >>  2) & mask2);
	count = (count & mask3) + ((count >>  4) & mask3);
	count = (count & mask4) + ((count >>  8) & mask4);
	count = (count & mask5) + ((count >> 16) & mask5);

	return count;
}

int getBitCountShort2(short number)
{
	short mask1 = 0x5555;
	short mask2 = 0x3333;
	short mask3 = 0x0F0F;
	short mask4 = 0x00FF;

	short count = number;
	count = (count & mask1) + ((count >> 1) & mask1);
	count = (count & mask2) + ((count >> 2) & mask2);
	count = (count & mask3) + ((count >> 4) & mask3);
	count = (count & mask4) + ((count >> 8) & mask4);

	return count;
}

int getBitCountChar2(char number)
{
	char mask1 = 0x55;
	char mask2 = 0x33;
	char mask3 = 0x0F;

	char count = number;
	count = (count & mask1) + ((count >> 1) & mask1);
	count = (count & mask2) + ((count >> 2) & mask2);
	count = (count & mask3) + ((count >> 4) & mask3);

	return count;
}

int main(void)
{
	{ // long
		int TIME_WHILE, TIME_DIVIDE1, TIME_DIVIDE2;

		TIME_WHILE = TIME_DIVIDE1 = TIME_DIVIDE2 = 0;

		// Error Check
		for (ll i = 0xFFFFFFFFFF000000; i < 0xFFFFFFFFFFFFFFFF; i++)
		{
			int count_while = getBitCount(i);
			int count_divide1 = getBitCountLong1(i);
			int count_divide2 = getBitCountLong2(i);
			if ((count_while != count_divide1) || (count_divide1 != count_divide2)) printf("Counter Error!!!\n");
		}

		{ // while ver 
			clock_t start = clock();
			for (ll i = 0xFFFFFFFFFF000000; i < 0xFFFFFFFFFFFFFFFF; i++)
				int count = getBitCount(i);

			TIME_WHILE += ((int)clock() - start) / (CLOCKS_PER_SEC / 1000);
		}
		
		{ // divide ver1
			clock_t start = clock();
			for (ll i = 0xFFFFFFFFFF000000; i < 0xFFFFFFFFFFFFFFFF; i++)
				int count = getBitCountLong1(i);

			TIME_DIVIDE1 += ((int)clock() - start) / (CLOCKS_PER_SEC / 1000);
		}

		{ // divide ver2
			clock_t start = clock();
			for (ll i = 0xFFFFFFFFFF000000; i < 0xFFFFFFFFFFFFFFFF; i++)
				int count = getBitCountLong2(i);

			TIME_DIVIDE2 += ((int)clock() - start) / (CLOCKS_PER_SEC / 1000);
		}

		printf("Long WHILE  ver  : %d ms\n", TIME_WHILE);
		printf("Long DIVIDE ver1 : %d ms\n", TIME_DIVIDE1);
		printf("Long DIVIDE ver2 : %d ms\n\n", TIME_DIVIDE2);
	}

	{ // int
		int TIME_WHILE, TIME_DIVIDE1, TIME_DIVIDE2;

		TIME_WHILE = TIME_DIVIDE1 = TIME_DIVIDE2 = 0;

		// Error Check
		for (int i = 0xFF000000; i < 0xFFFFFFFF; i++)
		{
			int count_while = getBitCount(i);
			int count_divide1 = getBitCountInt1(i);
			int count_divide2 = getBitCountInt2(i);
			if ((count_while != count_divide1) || (count_divide1 != count_divide2)) printf("Counter Error!!!\n");
		}

		{ // while ver 
			clock_t start = clock();
			for (int i = 0xFF000000; i < 0xFFFFFFFF; i++)
				int count = getBitCount(i);

			TIME_WHILE += ((int)clock() - start) / (CLOCKS_PER_SEC / 1000);
		}

		{ // divide ver1
			clock_t start = clock();
			for (int i = 0xFF000000; i < 0xFFFFFFFF; i++)
				int count = getBitCountInt1(i);

			TIME_DIVIDE1 += ((int)clock() - start) / (CLOCKS_PER_SEC / 1000);
		}

		{ // divide ver2
			clock_t start = clock();
			for (int i = 0xFF000000; i < 0xFFFFFFFF; i++)
				int count = getBitCountInt2(i);

			TIME_DIVIDE2 += ((int)clock() - start) / (CLOCKS_PER_SEC / 1000);
		}

		printf("Int WHILE  ver  : %d ms\n", TIME_WHILE);
		printf("Int DIVIDE ver1 : %d ms\n", TIME_DIVIDE1);
		printf("Int DIVIDE ver2 : %d ms\n\n", TIME_DIVIDE2);
	}

	{ // short
		int TIME_WHILE, TIME_DIVIDE1, TIME_DIVIDE2;

		TIME_WHILE = TIME_DIVIDE1 = TIME_DIVIDE2 = 0;

		// Error Check
		for (int i = 0; i < 32767; i++)
		{
			int count_while = getBitCount((short)i);
			int count_divide1 = getBitCountShort1((short)i);
			int count_divide2 = getBitCountShort2((short)i);
			if ((count_while != count_divide1) || (count_divide1 != count_divide2)) printf("Counter Error!!!\n");
		}

		{ // while ver 
			clock_t start = clock();
			for (int tc = 0; tc < 100; tc++)
			{
				for (int i = 0; i < 32767; i++)
					int count = getBitCount((short)i);
			}

			TIME_WHILE += ((int)clock() - start) / (CLOCKS_PER_SEC / 1000);
		}

		{ // divide ver1
			clock_t start = clock();
			for (int tc = 0; tc < 100; tc++)
			{
				for (int i = 0; i < 32767; i++)
					int count = getBitCountShort1((short)i);
			}

			TIME_DIVIDE1 += ((int)clock() - start) / (CLOCKS_PER_SEC / 1000);
		}

		{ // divide ver2
			clock_t start = clock();
			for (int tc = 0; tc < 100; tc++)
			{
				for (int i = 0; i < 32767; i++)
					int count = getBitCountShort2((short)i);
			}
			
			TIME_DIVIDE2 += ((int)clock() - start) / (CLOCKS_PER_SEC / 1000);
		}

		printf("Short WHILE  ver  : %d ms\n", TIME_WHILE);
		printf("Short DIVIDE ver1 : %d ms\n", TIME_DIVIDE1);
		printf("Short DIVIDE ver2 : %d ms\n\n", TIME_DIVIDE2);
	}

	{ // char
		int TIME_WHILE, TIME_DIVIDE1, TIME_DIVIDE2;

		TIME_WHILE = TIME_DIVIDE1 = TIME_DIVIDE2 = 0;

		// Error Check
		for (int i = 0; i < 255; i++)
		{
			int count_while = getBitCount((char)i);
			int count_divide1 = getBitCountChar1((char)i);
			int count_divide2 = getBitCountChar2((char)i);
			if ((count_while != count_divide1) || (count_divide1 != count_divide2)) printf("Counter Error!!!\n");
		}

		{ // while ver 
			clock_t start = clock();
			for (int tc = 0; tc < 10000; tc++)
			{
				for (int i = 0; i < 255; i++)
					int count = getBitCount((char)i);
			}

			TIME_WHILE += ((int)clock() - start) / (CLOCKS_PER_SEC / 1000);
		}

		{ // divide ver1
			clock_t start = clock();
			for (int tc = 0; tc < 10000; tc++)
			{
				for (int i = 0; i < 255; i++)
					int count = getBitCountChar1((char)i);
			}

			TIME_DIVIDE1 += ((int)clock() - start) / (CLOCKS_PER_SEC / 1000);
		}

		{ // divide ver2
			clock_t start = clock();
			for (int tc = 0; tc < 10000; tc++)
			{
				for (int i = 0; i < 255; i++)
					int count = getBitCountChar2((char)i);

			}

			TIME_DIVIDE2 += ((int)clock() - start) / (CLOCKS_PER_SEC / 1000);
		}

		printf("Char WHILE  ver  : %d ms\n", TIME_WHILE);
		printf("Char DIVIDE ver1 : %d ms\n", TIME_DIVIDE1);
		printf("Char DIVIDE ver2 : %d ms\n\n", TIME_DIVIDE2);
	}

	return 0;
}

 

모든 케이스에서 최적화가 반드시 유리한 것은 아니다.

또한 char의 경우 1이 적기 때문에 while이 유리할 수도 있다.

반응형

댓글