알고리즘 2012. 3. 7. 19:00
문제 D. 한강

Large 데이터셋을 푸는 데 1분 30초 정도 걸린다. 그리고 메모리는 3.5GB 정도 차지한다. 메모리를 많이 먹는 건 조금 바보같이 코딩해서 그렇다. 소수를 구하는 코드는 꼬꼬마 시절 작성한 코드에서 긁어와서 좀 지저분하기도 하다.

키워드: 소수, 소인수분해, 에라토스테네스

http://algospot.com/forum/read/1322/
http://algospot.com/forum/read/1324/

#include <cstdio>
#include <cstdlib>
#include <climits>
#include <cassert>
#include <ctime>
#include <vector>
#include <cmath>
#include <algorithm>
using namespace std;

/* Fallback only: <climits> normally defines CHAR_BIT, so this is skipped. */
#ifndef CHAR_BIT
#define CHAR_BIT 8
#endif
/* ---- bit-array configuration ---- */
#define UINT_BITS 32                 /* bits per storage word; assumes a 32-bit unsigned int */
typedef unsigned int LARGEINT;       /* size / index type used by the bit-array helpers */
#define LARGEINT_MAX UINT_MAX        /* largest LARGEINT value */
#define LINE 10

/* ---- limits for the large data set ---- */
#define N_MAX 100000000000000000ll   /* 10^17: upper bound on the powers stored in pow_table */
#define PRIME_MAX 1000000007         /* calc_prime_list collects every prime <= PRIME_MAX */
typedef long long int lli;

/* ---- precomputed tables (filled by calc_prime_list / calc_pow_table) ---- */
vector<int> prime_list;              /* primes <= PRIME_MAX in ascending order */
vector<lli> pow_table[51000000];     /* pow_table[i][e] == prime_list[i]^e for e = 0 .. pow_max[i] */
int pow_max[51000000];               /* largest e with prime_list[i]^e <= N_MAX */

#if 0
// Small data set: change the "#if 0" above to "#if 1" to use these limits.
// The #undef lines are required: redefining a macro with a different
// replacement list is ill-formed and compilers emit a diagnostic.
#undef N_MAX
#undef PRIME_MAX
#define N_MAX 1000000ll
#define PRIME_MAX 1000000
#endif

/*
 * Allocate a zero-filled bit array able to hold `arraysize` bits, rounded up
 * to whole UINT_BITS-bit words.
 * Returns NULL when arraysize is 0, when the word count fails the
 * LARGEINT_MAX byte-size check, or when calloc fails.
 * Release the result with freebitarray().
 */
unsigned int *makebitarray(LARGEINT arraysize)
{
	const LARGEINT words = arraysize / UINT_BITS + (arraysize % UINT_BITS ? 1 : 0);
	if (words == 0 || words > LARGEINT_MAX / sizeof(unsigned int))
		return NULL;
	return static_cast<unsigned int *>(calloc(words, sizeof(unsigned int)));
}

/* Release a bit array obtained from makebitarray(); NULL is accepted. Always returns 0. */
int freebitarray(unsigned int *bitarray)
{
	if (bitarray != NULL)
		free(bitarray);
	return 0;
}

/* Read bit `nth` of `bitarray` (word nth / UINT_BITS, position nth % UINT_BITS). Returns 0 or 1. */
inline unsigned int getbit(const unsigned int * const bitarray, const LARGEINT nth)
{
	const unsigned int word = bitarray[nth / UINT_BITS];
	const unsigned int shift = nth % UINT_BITS;
	return (word >> shift) & 1u;
}

/*
 * Set bit `nth` of `bitarray` to 1 (word nth / UINT_BITS, position nth % UINT_BITS).
 * The mask is built from 1u: shifting a signed 1 into the sign bit (r == 31)
 * is undefined behaviour in C and in C++ before C++14.
 */
inline void setbit1(unsigned int * const bitarray, const LARGEINT nth)
{
	const LARGEINT q = nth / UINT_BITS;
	const unsigned int r = nth % UINT_BITS;
	bitarray[q] |= (1u << r);
}

/*
 * Clear bit `nth` of `bitarray` (word nth / UINT_BITS, position nth % UINT_BITS).
 * The mask is built from 1u: shifting a signed 1 into the sign bit (r == 31)
 * is undefined behaviour in C and in C++ before C++14.
 */
inline void setbit0(unsigned int * const bitarray, const LARGEINT nth)
{
	const LARGEINT q = nth / UINT_BITS;
	const unsigned int r = nth % UINT_BITS;
	bitarray[q] &= ~(1u << r);
}

int initbitarray(unsigned int *bitarray, const LARGEINT arraysize, const int TorF)
{
	LARGEINT q=arraysize/UINT_BITS;
	if (arraysize%UINT_BITS) q++;
	if (TorF) {
		while (q--) *(bitarray++)=~0;
	} else {
		while (q--) *(bitarray++)=0;
	}
	return 0;
}

void calc_prime_list()
{
	unsigned int n_of_prime=0, end, i, j, n, nsq;
	unsigned int *numbers;
	time_t a,b;

	end = PRIME_MAX; // 10^9
	if (end<3) {
		for (i=2;i<=end;i++) {
			for (j=2; (j*j<=i) && (i%j) ; j++);
			if (j*j>i) printf("%d%c",i,(++n_of_prime%10)?' ':'\n');
		}
		printf("\n\t%d prime(s)\n",n_of_prime);
		exit(0);
	}
	//a=clock();
	/* 홀수만 저장하므로 절반만 할당 */
	numbers=makebitarray(end/2);
	if (NULL==numbers) {
		printf("error: 메모리 할당이 불가능합니다.\n"
				"프로그램을 종료합니다.\n");
		fflush(stdout);
		exit(1);
	}
	initbitarray(numbers,end/2,1);
	setbit0(numbers,0); /* 1은 소수가 아님 */
	//printf("2 ");
	n_of_prime++;

	for (i=1, n=3, nsq=9; (nsq<=end) && (nsq>2*n) ; i++, nsq+=(4*n+4), n+=2) {
		if ( getbit(numbers,i) ) {
			for (j=nsq/2; (j<=end/2) ; j+=n) {
				setbit0(numbers,j);
			}
		}
	}
	fprintf(stderr, "calc_prime_list: sieving done\n");

	prime_list.push_back(2);
	i=0, n=1;
	while( i++, n+=2, (n<=end) && (n>i) ) {
		if ( getbit(numbers,i) ) {
			prime_list.push_back(n);
		}
	}

	freebitarray(numbers);
	//b=clock();
	//printf("\n\t%u prime(s)",n_of_prime);
	//printf("\n\ttime: %fs\n",(double) (b-a)/CLOCKS_PER_SEC);
	//fflush(stdout);
	fprintf(stderr, "calc_prime_list: done\n");
}

void calc_pow_table()
{
	int imax = prime_list.size();
	//double logN = log((double) N_MAX);
	for (int i = 0; i < imax; ++i) {
		const int base = prime_list[i];
		//fprintf(stderr, "calc_pow_table: %d, %d\n", i, base);
		//double log_base = log((double) base);
		//int jmax = (int) (logN / log_base);
		pow_table[i].push_back(1);
		lli tmp = 1;
		int j = 0;
		for (j = 1; tmp <= N_MAX/base; ++j) {
			tmp *= base;
			pow_table[i].push_back(tmp);
		}
		pow_max[i] = j - 1;
	}
	fprintf(stderr, "calc_pow_table: done\n");
}

/*
 * Count the divisors of N by trial division with prime_list.
 *
 * Returns d(N): 1 for N == 1, and 2 for any N whose trial division finds no
 * factor (primes, and also N <= 0).
 * *d0 receives the first (smallest) prime that divides N during trial
 * division, or 0 when none does (N prime, or N <= 1).
 *
 * Trial division stops once p*p exceeds the *remaining* cofactor. By then
 * every prime below p has been divided out, so what is left is 1 or a single
 * prime, which contributes a factor of 2. This yields the same result as
 * bounding p by sqrt(original N), without a floating-point square root, and
 * stops much earlier for numbers with small factors.
 */
lli get_num_div(lli N, lli *d0)
{
	*d0 = 0;
	lli D = 1;
	const int imax = prime_list.size();
	for (int i = 0; i < imax && N > 1; ++i) {
		const lli p = prime_list[i];
		if (p * p > N)
			break;
		int cnt = 1;
		while (N % p == 0) {
			N /= p;
			++cnt;
			if (*d0 == 0)
				*d0 = p;
		}
		D *= cnt;
	}
	if (N != 1)
		D *= 2; /* leftover cofactor is a single prime */
	return D;
}

/* Strict weak ordering of pairs by their second member only. */
bool comp(pair<int,lli> a, pair<int,lli> b)
{
	return b.second > a.second;
}

/*
 * D == 2: count the primes q in prime_list[mi..] with q <= N. These are the
 * numbers <= N with exactly two divisors whose prime lies at index >= mi.
 *
 * Uses upper_bound (first prime strictly greater than N), so N equal to the
 * largest tabulated prime is counted correctly. prime_list only reaches
 * PRIME_MAX, so when N exceeds the largest tabulated prime the true count is
 * unknown. In that case an error is reported and the number of tabulated
 * candidates is returned. An out-of-range mi yields 0.
 */
lli foo2(const lli N, const int mi) // D == 2
{
	const int n_primes = prime_list.size();
	if (mi < 0 || mi >= n_primes)
		return 0ll;
	if (N > prime_list[n_primes - 1])
		fprintf(stderr, "error binary search\n");
	const vector<int>::iterator first = prime_list.begin() + mi;
	const vector<int>::iterator past = upper_bound(first, prime_list.end(), N);
	return past - first;
}

/*
 * D == p + 1 with x == q^p: count the primes q in prime_list[mi..] with
 * q^p <= N, looking powers up in pow_table/pow_max.
 *
 * "q^p <= N" holds for a prefix of prime_list[mi..] and fails afterwards, so
 * a half-open binary search for the first failing index gives the count.
 * If every remaining prime qualifies, the table may be too short; an error is
 * reported and the number of tabulated candidates is returned.
 * An out-of-range mi or a negative p yields 0.
 */
lli foo_pow(const lli N, const int mi, const int p)
{
	const int n_primes = prime_list.size();
	if (mi < 0 || mi >= n_primes || p < 0)
		return 0ll;
	int lo = mi, hi = n_primes; /* answer index lies in [mi, n_primes] */
	while (lo < hi) {
		const int mid = lo + (hi - lo) / 2;
		if (p <= pow_max[mid] && pow_table[mid][p] <= N)
			lo = mid + 1; /* prime_list[mid]^p still fits */
		else
			hi = mid;
	}
	if (lo == n_primes) {
		fprintf(stderr, "error binary search pow\n");
		fprintf(stderr, "\t%lld %d\n", N, mi);
	}
	return lo - mi;
}

/*
 * D == 4: count the numbers x <= N with exactly four divisors whose prime
 * factors all lie in prime_list[mi..]. Such x are either a^3, or a*b with
 * primes a < b.
 */
lli foo4(const lli N, const int mi) // D == 4
{
	const int n_primes = prime_list.size();
	/* Nothing to count if no prime is left, or if the smallest candidate already has prime^2 > N. */
	if (mi >= prime_list.size() || prime_list[mi] > N || pow_max[mi] < 2 || pow_table[mi][2] > N)
		return 0ll;

	lli total = foo_pow(N, mi, 3); /* x = a^3 */
	for (int ia = mi; ia + 1 < n_primes; ++ia) { /* x = a * b, b later in the list */
		const lli limit = N / prime_list[ia];
		if (limit < prime_list[ia + 1])
			break;
		total += foo2(limit, ia + 1);
	}
	return total;
}

/*
 * Count the numbers x <= N that have exactly D divisors and whose prime
 * factors all lie in prime_list[mi..].
 *
 * D == 1, 2 and 4 are delegated (x == 1, foo2, foo4). For any other D, each
 * prime q = prime_list[k], k = mi, mi+1, ..., is tried as the smallest prime
 * factor of x. If x = q^i * y with (i+1) | D, then y has D/(i+1) divisors and
 * uses only primes after q, and y is counted recursively.
 *
 * The candidates for q are scanned with a loop rather than by recursing once
 * per skipped prime. Per-prime recursion can nest as deep as the number of
 * primes below sqrt(N) (about 1.7e7 for N = 1e17, D == 3). With the loop,
 * recursion remains only for the cofactor y, whose depth is bounded by the
 * number of distinct prime factors.
 */
lli foo(const lli N, const int mi, const int D)
{
	if (N <= 0) {
		fprintf(stderr, "ERROR\n");
		return 0ll;
	}
	switch (D) {
	case 1:
		return 1ll; /* only x == 1 */
	case 2:
		return foo2(N, mi);
	case 4:
		return foo4(N, mi);
	default:
		break; /* all other D use the general scan below */
	}

	const int n_primes = prime_list.size();
	if (mi < 0)
		return 0ll;
	lli r = 0;
	for (int k = mi; k < n_primes; ++k) {
		const int q = prime_list[k];
		if (q > N)
			break;
		/*
		 * The smallest x built on q is at least q^3 when D >= 5 and q^2 when
		 * D == 3. Once that exceeds N, neither q nor any later prime can contribute.
		 */
		if (D >= 5) {
			if (3 > pow_max[k] || pow_table[k][3] > N)
				break;
		} else if (D >= 3) {
			if (2 > pow_max[k] || pow_table[k][2] > N)
				break;
		}

		/* Try q^i as the exact power of q in x, for every i with (i+1) | D. */
		int imax = D;
		if (imax > pow_max[k] + 1)
			imax = pow_max[k] + 1;
		for (int i = 1; i + 1 <= imax; ++i) {
			if (D % (i + 1))
				continue;
			const lli N2 = N / pow_table[k][i];
			if (!N2)
				break;
			const int D2 = D / (i + 1);
			if (D2 == 1) {
				r += 1; /* x == q^i itself */
			} else {
				/* The cofactor needs a prime after q that is <= N2. */
				if (k + 1 >= n_primes || N2 < prime_list[k + 1])
					break;
				r += foo(N2, k + 1, D2);
			}
		}
	}
	return r;
}

int find_mi(const lli M)
{
	int low = 0;
	int high = prime_list.size() - 1;
	while (low < high) {
		int mid = low + (high - low) / 2;
		if (prime_list[mid] >= M) {
			high = mid;
		} else {
			low = mid + 1;
		}
		if (mid == low)
			break;
	}
	if (prime_list[low] >= M) {
		return low;
	} else if (prime_list[high] >= M) {
		return high;
	} else {
		assert(high == prime_list.size() - 1);
		return -1;
	}
}

int main(void)
{
	calc_prime_list();
	calc_pow_table();
	int T_;
	scanf("%d", &T_);
	vector< pair<int,lli> > gg;
	for (int i_ = 1; i_ <= T_; ++i_) {
		gg.clear();
		long long int ans = 0;
		lli N, M;
		scanf("%lld %lld", &N, &M);
		//fprintf(stderr, "! step 0\n");
		int mi = -1;
		lli d0;
		const lli num_div_N = get_num_div(N, &d0);
		if (num_div_N == 2) {
			//fprintf(stderr, "! num_div\n");
			goto final;
		}
		//fprintf(stderr, "! step 1\n");
		//printf("num_div_%lld: %lld\n", N, num_div_N);
		mi = find_mi(M);
		if (mi == -1) {
			//fprintf(stderr, "! mi\n");
			goto final;
		}
		//fprintf(stderr, "! step 2\n");
		ans = foo(N, mi, num_div_N);
		// subtract 1 to exclude N
		if (d0 >= M)
			--ans;
final:
		fprintf(stderr, "Case #%d: %lld\n", i_, ans);
		printf("Case #%d: %lld\n", i_, ans);
	}
	return 0;
}
Posted by asdfzxcv