'알고리즘'에 해당되는 글 4건

  1. 2012.03.07 Code Jam Korea 2012 본선 (C)
  2. 2012.03.07 Code Jam Korea 2012 본선 (D) (large)
  3. 2012.03.04 Code Jam Korea 2012 본선 (B)
  4. 2012.03.04 Code Jam Korea 2012 본선 (A)
  5. .
알고리즘2012. 3. 7. 19:04
문제 C. 모자 쓴 아이들

푸는데 0.3초밖에 안 걸린다. 추가로  sum of bar()도 dp로 바꿔버릴 수 있다. 사실 dp 테이블에서 삼각형 모양만 필요하니까 역삼각형 모양에 누적값을 저장해두는 식으로 활용하면 된다. 물론 dp[2001][2002]로 바꿔줘야 하지만... 그런데 이 문제에서는 그렇게 고친다고 엄청나게 빨라지지는 않는다. 그러므로 코드가 깔끔한 버전을 여기에 올림.


 
#include <stdio.h>
#include <string.h>
#include <assert.h>

int dp[2001][2001];

void foo()
{
	for (int i = 0; i <= 2000; ++i) {
		for (int j = 0; j <= i; ++j) {
			if (j > i / 2) {
				dp[i][j] = dp[i][i - j];
			} else if (j == 0) {
				dp[i][j] = 1;
			} else {
				dp[i][j] = dp[i - 1][j - 1] + dp[i - 1][j];
				if (dp[i][j] >= 32749)
					dp[i][j] -= 32749;
			}
		}
	}
}

int bar(int i, int j)
{
	if (i < 0 || i > 2000) {
		fprintf(stderr, "ERR_bar: %d\n", i);
	}
	if (j >= 0 && j <= i)
		return dp[i][j];
	return 0;
}

int main(void)
{
	foo();
	int T_;
	scanf("%d", &T_);
	for (int i_ = 1; i_ <= T_; ++i_) {
		int B, W, k, i;
		scanf("%d %d %d %d", &B, &W, &k, &i);
		assert(k <= B + W);
		if (i > k) {
			fprintf(stderr, "ERR_INPUT: %d %d\n", i, k);
		}
		const int kk = (B + W) - k;
		int ans = bar(k - i, B - i + 1) // i는 흰색
			+ bar(k - i, W - i + 1); // i는 검은색
		if (ans > 32749)
			ans -= 32749;
		if (ans) {
			int mul = 0;
			for (int m = i - 1 - kk; m <= i - 1; ++m) {
				mul += bar(i - 1, m);
			}
			mul %= 32749;
			for (int m = 1; m <= i - 1; ++m) {
				int x = 0;
				for (int p = 0; p <= kk; ++p) {
					x += bar(m - 1, p);
				}
				x %= 32749;
				mul -= bar(i - m - 1, kk - m + 1) * x;
				mul %= 32749;
			}
			//fprintf(stderr, "mul: %d\n", mul);
			ans *= mul;
			ans %= 32749;
		}
		printf("Case #%d: %d\n", i_, ans);
	}
	return 0;
}



Posted by asdfzxcv
알고리즘2012. 3. 7. 19:00
문제 D. 한강

Large 푸는데 1분 30초정도 걸린다. 그리고 메모리는 3.5기가 정도 차지한다. 메모리를 많이 먹는 건 조금 바보같이 코딩해서 그렇다. 소수를 구하는 코드를 꼬꼬마 시절 작성한 코드에서 긁어와서 좀 지저분하기도 하다.

키워드: 소수, 소인수분해, 에라토스테네스

http://algospot.com/forum/read/1322/
http://algospot.com/forum/read/1324/

#include <cstdio>
#include <cstdlib>
#include <climits>
#include <cassert>
#include <ctime>
#include <vector>
#include <cmath>
#include <algorithm>
using namespace std;

#ifndef CHAR_BIT
define CHAR_BIT 8
#endif
#define UINT_BITS 32
#define LARGEINT unsigned int
#define LARGEINT_MAX UINT_MAX
#define LINE 10
#define N_MAX 100000000000000000ll
#define PRIME_MAX 1000000007
typedef long long int lli;
vector<int> prime_list;
vector<lli> pow_table[51000000];
int pow_max[51000000];

#if 0
// for small data set
#define N_MAX 1000000ll
#define PRIME_MAX 1000000
#endif

unsigned int *makebitarray(LARGEINT arraysize)
{
	unsigned int *bitarray;
//	if (!arraysize) arraysize=1;
	LARGEINT q=arraysize/UINT_BITS;
	if (arraysize%UINT_BITS) q++;
	/* -----_check overflow */
	if ( (!q) || (q>LARGEINT_MAX/sizeof(unsigned int)) ) {
		bitarray=NULL;
	} else {
		bitarray=(unsigned int*) calloc(sizeof(unsigned int),q);
	}
	return bitarray;
}

int freebitarray(unsigned int *bitarray)
{
	free(bitarray);
	return 0;
}

inline unsigned int getbit(const unsigned int * const bitarray, const LARGEINT nth)
{
	LARGEINT q=nth/UINT_BITS;
	unsigned int r=nth%UINT_BITS;
	return ((*(bitarray+q))>>r)&1;
}

inline void setbit1(unsigned int * const bitarray, const LARGEINT nth)
{
	LARGEINT q=nth/UINT_BITS;
	unsigned int r=nth%UINT_BITS;
	(*(bitarray+q)|=(1<<r));
}

inline void setbit0(unsigned int * const bitarray, const LARGEINT nth)
{
	LARGEINT q=nth/UINT_BITS;
	unsigned int r=nth%UINT_BITS;
	(*(bitarray+q)&=~(1<<r));
}

int initbitarray(unsigned int *bitarray, const LARGEINT arraysize, const int TorF)
{
	LARGEINT q=arraysize/UINT_BITS;
	if (arraysize%UINT_BITS) q++;
	if (TorF) {
		while (q--) *(bitarray++)=~0;
	} else {
		while (q--) *(bitarray++)=0;
	}
	return 0;
}

void calc_prime_list()
{
	unsigned int n_of_prime=0, end, i, j, n, nsq;
	unsigned int *numbers;
	time_t a,b;

	end = PRIME_MAX; // 10^9
	if (end<3) {
		for (i=2;i<=end;i++) {
			for (j=2; (j*j<=i) && (i%j) ; j++);
			if (j*j>i) printf("%d%c",i,(++n_of_prime%10)?' ':'\n');
		}
		printf("\n\t%d prime(s)\n",n_of_prime);
		exit(0);
	}
	//a=clock();
	/* 홀수만 저장하므로 절반만 할당 */
	numbers=makebitarray(end/2);
	if (NULL==numbers) {
		printf("error: 메모리 할당이 불가능합니다.\n"
				"프로그램을 종료합니다.\n");
		fflush(stdout);
		exit(1);
	}
	initbitarray(numbers,end/2,1);
	setbit0(numbers,0); /* 1은 소수가 아님 */
	//printf("2 ");
	n_of_prime++;

	for (i=1, n=3, nsq=9; (nsq<=end) && (nsq>2*n) ; i++, nsq+=(4*n+4), n+=2) {
		if ( getbit(numbers,i) ) {
			for (j=nsq/2; (j<=end/2) ; j+=n) {
				setbit0(numbers,j);
			}
		}
	}
	fprintf(stderr, "calc_prime_list: sieving done\n");

	prime_list.push_back(2);
	i=0, n=1;
	while( i++, n+=2, (n<=end) && (n>i) ) {
		if ( getbit(numbers,i) ) {
			prime_list.push_back(n);
		}
	}

	freebitarray(numbers);
	//b=clock();
	//printf("\n\t%u prime(s)",n_of_prime);
	//printf("\n\ttime: %fs\n",(double) (b-a)/CLOCKS_PER_SEC);
	//fflush(stdout);
	fprintf(stderr, "calc_prime_list: done\n");
}

void calc_pow_table()
{
	int imax = prime_list.size();
	//double logN = log((double) N_MAX);
	for (int i = 0; i < imax; ++i) {
		const int base = prime_list[i];
		//fprintf(stderr, "calc_pow_table: %d, %d\n", i, base);
		//double log_base = log((double) base);
		//int jmax = (int) (logN / log_base);
		pow_table[i].push_back(1);
		lli tmp = 1;
		int j = 0;
		for (j = 1; tmp <= N_MAX/base; ++j) {
			tmp *= base;
			pow_table[i].push_back(tmp);
		}
		pow_max[i] = j - 1;
	}
	fprintf(stderr, "calc_pow_table: done\n");
}

lli get_num_div(lli N, lli *d0)
{
	*d0 = 0;
	lli sqrtN = (lli) sqrt((double) N);
	if ((sqrtN + 1)*(sqrtN + 1) <= N) {
		fprintf(stderr, "!\n");
		++sqrtN;
	}
	int D = 1;
	int imax = prime_list.size();
	for (int i = 0; i < imax; ++i) {
		int p = prime_list[i];
		if (p > sqrtN)
			break;
		int cnt = 1;
		while (!(N % p)) {
			N /= p;
			++cnt;
			if (*d0 == 0)
				*d0 = p;
		}
		D *= cnt;
		if (N == 1)
			break;
	}
	if (N != 1)
		D *= 2;
	return D;
}

bool comp(pair<int,lli> a, pair<int,lli> b)
{
	return (a.second < b.second);
}

lli foo2(const lli N, const int mi) // D == 2
{
	int low = mi;
	int high = prime_list.size() - 1;
	if (N > prime_list[high]) {
		fprintf(stderr, "error binary search\n");
		return high - low;
	}
	const int N2 = (int) N; // for faster comparison
	while (low < high) {
		int mid = low + (high - low) / 2;
		if (prime_list[mid] > N2) {
			high = mid;
		} else {
			low = mid + 1;
		}
		if (mid == low)
			break;
	}
	if (prime_list[low] > N2) {
		return (low - mi);
	} else if (prime_list[high] > N2) {
		fprintf(stderr, "warning binary search\n");
		return (high - mi);
	} else {
		fprintf(stderr, "error binary search\n");
		fprintf(stderr, "\t%lld %d\n", N, mi);
		return (high - mi);
	}
}

// D == (p + 1) && x == a ** p
lli foo_pow(const lli N, const int mi, const int p)
{
	if (mi >= prime_list.size())
		return 0ll;
	if (p > pow_max[mi] || pow_table[mi][p] > N)
		return 0ll;
	int low = mi;
	int high = prime_list.size() - 1;
	while (low < high) {
		int mid = low + (high - low) / 2;
		if (p > pow_max[mid] || pow_table[mid][p] > N) {
			high = mid;
		} else {
			low = mid + 1;
		}
		if (mid == low)
			break;
	}
	if (p > pow_max[low] || pow_table[low][p] > N) {
		return (low - mi);
	} else if (p > pow_max[high] || pow_table[high][p] > N) {
		fprintf(stderr, "warning binary search pow\n");
		return (high - mi);
	} else {
		fprintf(stderr, "error binary search pow\n");
		fprintf(stderr, "\t%lld %d\n", N, mi);
		return (high - mi);
	}
}

lli foo4(const lli N, const int mi) // D == 4
{
	if (mi >= prime_list.size()) {
		return 0ll;
	}
	if (prime_list[mi] > N) {
		return 0ll;
	}
	if (2 > pow_max[mi] || pow_table[mi][2] > N)
		return 0ll;
	lli r = 0;
	// case 1: x = a * a * a
	r += foo_pow(N, mi, 3);
	// case 2: x = a * b
	const int imax = prime_list.size();
	for (int i = mi; i < imax; ++i) {
		lli N2 = N / prime_list[i];
		if (i + 1 >= imax || N2 < prime_list[i + 1])
			break;
		r += foo2(N2, i + 1);
	}
	return r;
}

lli foo(const lli N, const int mi, const int D)
{
	//printf("!foo %lld %d %d\n", N, mi, D);
	if (N <= 0) {
		fprintf(stderr, "ERROR\n");
		return 0ll;
	}
	switch (D) {
	case 1:
		return 1ll;
		break;
	case 2:
		return foo2(N, mi);
		break;
	case 3:
		//return foo_pow(N, mi, 2);
		break;
	case 4:
		return foo4(N, mi);
		break;
	case 5:
		//return foo_pow(N, mi, 4);
		break;
	}
	if (mi >= prime_list.size()) {
		return 0ll;
	}
	if (prime_list[mi] > N) {
		return 0ll;
	}
	if (D >= 5) {
		if (3 > pow_max[mi] || pow_table[mi][3] > N)
			return 0ll;
	} else if (D >= 3) {
		if (2 > pow_max[mi] || pow_table[mi][2] > N)
			return 0ll;
	}

	lli r = 0;
	int imax = D;
	if (imax > (pow_max[mi] + 1))
		imax = pow_max[mi] + 1;
	for (int i = 1; (i + 1) <= imax; ++i) {
		if (D % (i + 1))
			continue;
		if (i >= 1 && pow_table[mi][i] <= 1)
			fprintf(stderr, "ERROR2\n");
		lli N2 = N / pow_table[mi][i];
		if (!N2)
			break;
		int D2 = D / (i + 1);
		if (D2 > 1 && N2 < prime_list[mi + 1])
			break;
		if (D2 == 1)
			r += 1;
		else
			r += foo(N2, mi + 1, D2);
	}
	return r + foo(N, mi + 1, D); // skip prime_list[mi]
}

int find_mi(const lli M)
{
	int low = 0;
	int high = prime_list.size() - 1;
	while (low < high) {
		int mid = low + (high - low) / 2;
		if (prime_list[mid] >= M) {
			high = mid;
		} else {
			low = mid + 1;
		}
		if (mid == low)
			break;
	}
	if (prime_list[low] >= M) {
		return low;
	} else if (prime_list[high] >= M) {
		return high;
	} else {
		assert(high == prime_list.size() - 1);
		return -1;
	}
}

int main(void)
{
	calc_prime_list();
	calc_pow_table();
	int T_;
	scanf("%d", &T_);
	vector< pair<int,lli> > gg;
	for (int i_ = 1; i_ <= T_; ++i_) {
		gg.clear();
		long long int ans = 0;
		lli N, M;
		scanf("%lld %lld", &N, &M);
		//fprintf(stderr, "! step 0\n");
		int mi = -1;
		lli d0;
		const lli num_div_N = get_num_div(N, &d0);
		if (num_div_N == 2) {
			//fprintf(stderr, "! num_div\n");
			goto final;
		}
		//fprintf(stderr, "! step 1\n");
		//printf("num_div_%lld: %lld\n", N, num_div_N);
		mi = find_mi(M);
		if (mi == -1) {
			//fprintf(stderr, "! mi\n");
			goto final;
		}
		//fprintf(stderr, "! step 2\n");
		ans = foo(N, mi, num_div_N);
		// subtract 1 to exclude N
		if (d0 >= M)
			--ans;
final:
		fprintf(stderr, "Case #%d: %lld\n", i_, ans);
		printf("Case #%d: %lld\n", i_, ans);
	}
	return 0;
}
Posted by asdfzxcv
알고리즘2012. 3. 4. 20:27
문제 B. 장터판

생각보다 훨씬 어렵다. 확률 계산을 잘 해야 함. 시간 복잡도는 O(NMK). Large 푸는데 약 2-3초 걸린다.
#include <stdio.h>
#include <string.h>
#include <float.h>
#include <assert.h>
#define OFFSET 4
#define MAX 100

/* -1: ?
 * 0: not a dice (buffer)
 * 1-K: face of the dice
 */
int board[MAX + 2*OFFSET][MAX + 2*OFFSET];
int KK[10];

int mycnt(int n,
	  int x1, int x2=0, int x3=0, int x4=0, int x5=0, int x6=0, int x7=0)
{
	int cnt = 0;
	switch (n) {
	case 7:
		cnt += (x7 == -1) ? 1 : 0;
	case 6:
		cnt += (x6 == -1) ? 1 : 0;
	case 5:
		cnt += (x5 == -1) ? 1 : 0;
	case 4:
		cnt += (x4 == -1) ? 1 : 0;
	case 3:
		cnt += (x3 == -1) ? 1 : 0;
	case 2:
		cnt += (x2 == -1) ? 1 : 0;
	case 1:
		cnt += (x1 == -1) ? 1 : 0;
	}
	return cnt;
}

int pp(int K, int n,
	  int x1, int x2, int x3=0, int x4=0, int x5=0, int x6=0, int x7=0)
{
	int x = -1;
	int t = 1;
	//int cnt = mycnt(n, x1, x2, x3, x4, x5, x6, x7);
	switch (n) {
	case 7:
		x = (x != -1) ? (x) : (x7);
	case 6:
		x = (x != -1) ? (x) : (x6);
	case 5:
		x = (x != -1) ? (x) : (x5);
	case 4:
		x = (x != -1) ? (x) : (x4);
	case 3:
		x = (x != -1) ? (x) : (x3);
	case 2:
		x = (x != -1) ? (x) : (x2);
		x = (x != -1) ? (x) : (x1);
	}
	switch (n) {
	case 7:
		t = t && (x7 == -1 || x7 == x);
	case 6:
		t = t && (x6 == -1 || x6 == x);
	case 5:
		t = t && (x5 == -1 || x5 == x);
	case 4:
		t = t && (x4 == -1 || x4 == x);
	case 3:
		t = t && (x3 == -1 || x3 == x);
	case 2:
		t = t && (x2 == -1 || x2 == x);
		t = t && (x1 == -1 || x1 == x);
	}
	if (t && x != 0) {
		return 1;
	}
	return 0;
}

double pp4(int K, int x1, int x2, int x3, int x4, int x5, int x6, int x7)
{
	int tmp4 = 0;
	tmp4 += pp(K, 4, x1, x2, x3, x4) ? KK[mycnt(3, x5, x6, x7)] : 0;
	tmp4 += pp(K, 4, x2, x3, x4, x5) ? KK[mycnt(3, x1, x6, x7)] : 0;
	tmp4 += pp(K, 4, x3, x4, x5, x6) ? KK[mycnt(3, x1, x2, x7)] : 0;
	tmp4 += pp(K, 4, x4, x5, x6, x7) ? KK[mycnt(3, x1, x2, x3)] : 0;
	tmp4 -= pp(K, 5, x1, x2, x3, x4, x5) ? KK[mycnt(2, x6, x7)] : 0;
	tmp4 -= pp(K, 5, x2, x3, x4, x5, x6) ? KK[mycnt(2, x1, x7)] : 0;
	tmp4 -= pp(K, 5, x3, x4, x5, x6, x7) ? KK[mycnt(2, x1, x2)] : 0;
	return (double) tmp4 / KK[mycnt(7, x1, x2, x3, x4, x5, x6, x7)];
}

double pp3(int K, int x1, int x2, int x3, int x4, int x5)
{
	int tmp3 = 0;
	tmp3 += pp(K, 3, x1, x2, x3) ? KK[mycnt(2, x4, x5)] : 0;
	tmp3 += pp(K, 3, x2, x3, x4) ? KK[mycnt(2, x1, x5)] : 0;
	tmp3 += pp(K, 3, x3, x4, x5) ? KK[mycnt(2, x1, x2)] : 0;
	tmp3 -= pp(K, 4, x1, x2, x3, x4) ? KK[mycnt(1, x5)] : 0;
	tmp3 -= pp(K, 4, x2, x3, x4, x5) ? KK[mycnt(1, x1)] : 0;
	return (double) tmp3 / KK[mycnt(5, x1, x2, x3, x4, x5)];
}

double pp2(int K, int x1, int x2, int x3)
{
	int tmp2 = 0;
	tmp2 += pp(K, 2, x1, x2) ? KK[mycnt(1, x3)] : 0;
	tmp2 += pp(K, 2, x2, x3) ? KK[mycnt(1, x1)] : 0;
	tmp2 -= pp(K, 3, x1, x2, x3);
	return (double) tmp2 / KK[mycnt(3, x1, x2, x3)];
}

double po(int i0, int j0, int S4, int S3, int S2, int K) // wrong
{
	double point = 0.0;
	int kl = 1;
	int kh = K;
	int div = K;
	int gg = 1;
	if (board[i0][j0] != -1) {
		kl = kh = board[i0][j0];
		div = 1;
		gg = 0;
	}
	for (int k = kl; k <= kh; ++k) {
		board[i0][j0] = k;
		double p4, p3, p2;
		double tmp4, tmp3, tmp2;
		p4 = p3 = p2 = 1.0;
		int i, j;
		i = i0;
		j = j0;
		// row
		tmp4 = pp4(K, board[i][j-3], board[i][j-2], board[i][j-1], board[i][j], board[i][j+1], board[i][j+2], board[i][j+3]);
		p4 *= 1.0 - tmp4;
		tmp3 = pp3(K, board[i][j-2], board[i][j-1], board[i][j], board[i][j+1], board[i][j+2]);
		p3 *= 1.0 - tmp3;
		tmp2 = pp2(K, board[i][j-1], board[i][j], board[i][j+1]);
		p2 *= 1.0 - tmp2;
		// column
		tmp4 = pp4(K, board[i-3][j], board[i-2][j], board[i-1][j], board[i][j], board[i+1][j], board[i+2][j], board[i+3][j]);
		p4 *= 1.0 - tmp4;
		tmp3 = pp3(K, board[i-2][j], board[i-1][j], board[i][j], board[i+1][j], board[i+2][j]);
		p3 *= 1.0 -tmp3;
		tmp2 = pp2(K, board[i-1][j], board[i][j], board[i+1][j]);
		p2 *= 1.0 - tmp2;
		// diagonal 1
		tmp4 = pp4(K, board[i-3][j-3], board[i-2][j-2], board[i-1][j-1], board[i][j], board[i+1][j+1], board[i+2][j+2], board[i+3][j+3]);
		p4 *= 1.0 - tmp4;
		tmp3 = pp3(K, board[i-2][j-2], board[i-1][j-1], board[i][j], board[i+1][j+1], board[i+2][j+2]);
		p3 *= 1.0 -tmp3;
		tmp2 = pp2(K, board[i-1][j-1], board[i][j], board[i+1][j+1]);
		p2 *= 1.0 - tmp2;
		// diagonal 2
		tmp4 = pp4(K, board[i-3][j+3], board[i-2][j+2], board[i-1][j+1], board[i][j], board[i+1][j-1], board[i+2][j-2], board[i+3][j-3]);
		p4 *= 1.0 - tmp4;
		tmp3 = pp3(K, board[i-2][j+2], board[i-1][j+1], board[i][j], board[i+1][j-1], board[i+2][j-2]);
		p3 *= 1.0 - tmp3;
		tmp2 = pp2(K, board[i-1][j+1], board[i][j], board[i+1][j-1]);
		p2 *= 1.0 - tmp2;
		//
		p4 = 1.0 - p4;
		p3 = 1.0 - p3;
		p2 = 1.0 - p2;
		p3 -= p4;
		p2 -= (p3 + p4);
		assert(p4 + DBL_EPSILON >= 0);
		assert(p3 + DBL_EPSILON >= 0);
		assert(p2 + DBL_EPSILON >= 0);
		point += (p4 * S4 + p3 * S3 + p2 * S2) / div;
		//printf("! %d, %d: %f, %f, %f\n", i0, j0, p4, p3, p2);
	}
	if (gg) {
		board[i0][j0] = -1;
	}
	//printf("!--------------------\n");
	return point;
}

int main(void)
{
	int T_;
	scanf("%d", &T_);
	for (int i_ = 1; i_ <= T_; ++i_) {
		memset(board, 0, sizeof(board));
		int N, M, K, S4, S3, S2;
		scanf("%d %d %d %d %d %d", &N, &M, &K, &S4, &S3, &S2);
		scanf("%*c");
		for (int i = OFFSET; i < N + OFFSET; ++i) {
			for (int j = OFFSET; j < M + OFFSET; ++j) {
				char c;
				scanf("%c", &c);
				board[i][j] = (c == '?') ? (-1) : (c - '0');
			}
			scanf("%*c");
		}
		KK[0] = 1;
		for (int i = 1; i < 10; ++i)
			KK[i] = KK[i-1] * K;
		double ans = 0.0;
		for (int i = OFFSET; i < N + OFFSET; ++i) {
			for (int j = OFFSET; j < M + OFFSET; ++j) {
				ans += po(i, j, S4, S3, S2, K);
			}
		}
		printf("Case #%d: %.7f\n", i_, ans);
	}
	return 0;
}
Posted by asdfzxcv
알고리즘2012. 3. 4. 10:00
문제 A. 생존자

삽질만 하다가 틀린 문제다. ㅠㅠㅠㅠ

이것도  DP로 풀면 되긴 하는데, 정렬 방법을 잘못 생각해서 틀렸다. 직관적으로 생각할 때 deadline이 앞에 있는 음식부터 먹어야 할 것 같지 않은가? 그래서 S로 먼저 정렬하고, P로 정렬했다. (P로 나중에 정렬했으니, P에 더 가중치를 두고 정렬한 것이 된다.) 그런데 답이 틀렸다. 정렬 방법을 바꿔봤는데도 틀렸다. 그래서 왜 틀렸나 고민을 해보다가 단순히 P로 정렬하면 안 되는 경우를 찾긴 했다. 그래서 gg치고 딴 거 했다.

답을 보고 하는 소리지만, 결론적으로 (P + S)로 정렬하면 된다. (P + S)는 해당 음식을 최대한 늦게 먹고 배부른 상태에서 그 상태가 끝나는 시간일 것이다. 이 값이 상대적으로 더 작은 음식을 먼저 먹는 것은 경우에 따라 가능하지만, 더 큰 음식을 먼저 먹고 더 작은 음식을 나중에 먹는 것은 절대 불가능하다. 그리고 이 값이 같으면 어느 것을 먼저 먹든 생존 시간에 영향을 주지 않는다. (엄밀한 증명은 좀 귀찮겠지만, 대충 따져보니 이런 것 같다.) 이렇게 답을 적어 놓으면 간단해 보이지만, 문제 풀 당시에 생각하는 건 잘 안 됐다. (원래 그런 거지만)

이렇게 음식을 정렬해두면, 그 다음엔 간단하게 풀 수 있다. 타임라인을 그어보면 P_MAX + S_MAX + 1이 최대일 것이다. 타임라인을 그은 다음, 각 시각에 대해 생존한 상태이고, 막 음식을 먹어야할 상황이면 1로 표시한다. 음식을 순서대로 먹을지 말지 적용하면서 타임라인을 업데이트하면 된다. 시간 복잡도는 O((P + S)N)이 된다. Large를 푸는데 40초정도 걸린다.

사실 small data set은 brute force하게 풀어도 답이 나온다. 통계를 보면 small은 126명 중 105명이 풀었고, large는 20명이 풀었다. (다른 사람들도 아마 정렬하는 걸 생각 못해서 틀렸을 것 같다.)


#include <cstdio>
#include <cstring>
#include <cassert>
#include <vector>
#include <algorithm>
using namespace std;
#define ITEM_MAX 1000
#define P_MAX 100000
#define S_MAX 1000
#define TIME_MAX (P_MAX + S_MAX + 1)

typedef pair<int,int> mypair;
vector<mypair> item;
int dp[TIME_MAX + 1];

bool comp3(mypair a, mypair b)
{
	return (a.first + a.second < b.first + b.second);
}

void foo(int i, int j) // wrong
{
	assert(dp[j]);
	if (i > item.size())
		return;

	if (item[i].first < j) {
		foo(i + 1, j);
	} else {
		// eat item i
		int x = j + item[i].second;
		if (x <= TIME_MAX) {
			dp[x] = 1;
			foo(i + 1, x);
		}
		// don't eat
		foo(i + 1, j);
	}
}

int bar()
{
	int large = 0;
	for (int i = 0; i < item.size(); ++i) {
		for (int j = item[i].first; j >= 0; --j) {
			if (dp[j]) {
				int x = j + item[i].second;
				dp[x] = 1;
				if (x > large)
					large = x;
			}
		}
	}
	return large;
}

int main(void)
{
	int T_;
	scanf("%d", &T_);
	for (int i_ = 1; i_ <= T_; ++i_) {
		item.clear();
		memset(dp, 0, sizeof(dp));
		dp[0] = 1;
		int ans = 0;
		int N;
		scanf("%d", &N);
		for (int i = 0; i<N; ++i) {
			int a, b;
			scanf("%d %d", &a, &b);
			assert(a <= P_MAX);
			assert(b <= S_MAX);
			item.push_back(make_pair(a,b));
		}
		sort(item.begin(), item.end(), comp3);
		//foo(0, 0);
		ans = bar();
		for (int i = 0; i<N; ++i) {
		//	printf("item%d: %d %d\n", i, item[i].first, item[i].second);
		}
		printf("Case #%d: %d\n", i_, ans);
	}
	return 0;
}
Posted by asdfzxcv