[笔记] 二分答案

简化题意

首先,给出这样的三道题:

T1:

给定一个有 $N$ 个元素的集合 $A$,一个数 $K$,一个数 $L$。

请执行以下操作:

  • 用集合中所有的数除以L,取其整数部分。
  • 把所有整数部分相加。
  • 如果相加的和不小于 $K$,则输出 YES;如果相加的和小于 $K$,则输出 NO

T2:

给定一个有 $N$ 个元素的集合 $A$,一个数 $H$,一个数 $M$。

请执行以下操作:

  • 选出集合中所有大于 $H$ 的数;
  • 计算出这些数减去 $H$ 的差值;
  • 把这些差值相加。

如果相加的和不小于 $M$,则输出 YES;如果相加的和小于 $M$,则输出 NO

T3:

给定一个有 $N$ 个元素的有序集合 $A$,要从中选出 $C$ 个数,要求选中的数与相邻两数的差都不小于 $D$,问是否能选出至少 $C$ 个?

若能选出,则输出 YES;若不能选出,则输出 NO

且上述题目均有 $Q$ 组数据,所有数为不超过 $10^3$ 的正整数。

根据以上三道题的题面,容易写出代码:

T1:

while (q--)
{
    int l, sum = 0;
    scanf("%d", &l);
    for (int i = 1; i <= n; i++)
        sum += a[i] / l;
    if (sum >= k)
        printf("YES\n");
    else printf("NO\n");
}

T2:

while (q--)
{
    int h, sum = 0;
    scanf("%d", &h);
    for (int i = 1; i <= n; i++)
        if (a[i] > h)
            sum += a[i] - h;
    if (sum >= m)
        printf("YES\n");
    else printf("NO\n");
}

T3:

while (q--)
{
    int d, cnt = 1, prev = a[1];
    scanf("%d", &d);
    for (int i = 2; i <= n; i++)
    {
        if (a[i] - prev >= d)
        {
            cnt++;
            prev = a[i];
        }
    }
    if (cnt >= c)
        printf("YES\n");
    else printf("NO\n");
}

check 函数

现在,我们要把上述代码中的判断部分改为一个名为 check 的函数。可以这样修改:

T1:

bool check(int l)
{
	int sum = 0;
	for (int i = 1; i <= n; i++)
		sum += a[i] / l;
	return sum >= k;
}
// ...
while (q--)
{
    int l;
    scanf("%d", &l);
    if (check(l))
        printf("YES\n");
    else printf("NO\n");
}

T2:

bool check(int h)
{
	int sum = 0;
	for (int i = 1; i <= n; i++)
		if (a[i] > h)
			sum += a[i] - h;
	return sum >= m;
}
// ...
while (q--)
{
    int h;
    scanf("%d", &h);
    if (check(h))
        printf("YES\n");
    else printf("NO\n");
}

T3:

bool check(int d)
{
	int cnt = 1, prev = a[1];
	for (int i = 2; i <= n; i++)
	{
		if (a[i] - prev >= d)
		{
			cnt++;
			prev = a[i];
		}
	}
	return cnt >= c;
}
// ...
while (q--)
{
    int d;
    scanf("%d", &d);
    if (check(d))
        printf("YES\n");
    else printf("NO\n");
}

检查输出

我们以 T3 为例。

假设 $N = 5, A = \lbrace 1, 2, 4, 8, 9\rbrace, C = 3$,并且询问为 $1, 2, \dotsc, 9$,那么得到输出:

YES
YES
YES
NO
NO
NO
NO
NO
NO

发现前三行都是 YES,后面都是 NO。说明有一个临界点,即为 $3$。

回到原题

这三题的原题分别为 P2440 木材加工 P1873 [COCI 2011/2012 #5] EKO / 砍树P1824 进击的奶牛 P1676 [USACO05FEB] Aggressive cows G)。

这三道题都是询问上面所述的临界点,那么我们可以枚举这个临界点,但是线性枚举对于这个数据来说实在是太慢了。

二分算法

我们可以使用二分算法。二分算法,即每次折半,每次循环时要移动左端点和右端点,直到右端点跑到左端点的左边去。复杂度为 $O(\log n)$。

二分算法就是每次用一个 check 函数检查,最终查找到想要的结果。基本上就是这个模型:

int l = MIN, r = MAX, ans = 0;
while (l <= r)
{
    int mid = l + r >> 1; // 或写作 (l + r) / 2
    if (check(mid)) // 包含等于的情况
    {
        ans = mid;
        l = mid + 1;
    }
    else r = mid - 1;
}

最终代码

最后,根据题意修改数据范围、枚举范围等即可。

最终的代码如下所示:

P2440:

#include <iostream>
#include <cstdio>
#define MAXN 100001
using namespace std;
int a[MAXN], n, k;
bool check(int l)
{
	int sum = 0;
	for (int i = 1; i <= n; i++)
		sum += a[i] / l;
	return sum >= k;
}
int main()
{
	scanf("%d%d", &n, &k);
	for (int i = 1; i <= n; i++)
		scanf("%d", &a[i]);
	int l = 1, r = 1e8, ans = 0;
	while (l <= r)
	{
		int mid = l + r >> 1;
		if (check(mid))
		{
			ans = mid;
			l = mid + 1;
		}
		else r = mid - 1;
	}
	printf("%d\n", ans);
	return 0;
}

P1873:

#include <iostream>
#include <cstdio>
#define MAXN 1000001
using namespace std;
typedef long long ll; 
int a[MAXN], n, m;
bool check(int h)
{
	ll sum = 0;
	for (int i = 1; i <= n; i++)
		if (a[i] > h)
			sum += a[i] - h;
	return sum >= m;
}
int main()
{
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n; i++)
		scanf("%d", &a[i]);
	int l = 1, r = 4e5, ans = 0;
	while (l <= r)
	{
		int mid = l + r >> 1;
		if (check(mid))
		{
			ans = mid;
			l = mid + 1;
		}
		else r = mid - 1;
	}
	printf("%d\n", ans);
	return 0;
}

P1824(P1676):

#include <iostream>
#include <cstdio>
#include <algorithm>
#define MAXN 100001
using namespace std;
int a[MAXN], n, c;
bool check(int d)
{
	int cnt = 1, prev = a[1];
	for (int i = 2; i <= n; i++)
	{
		if (a[i] - prev >= d)
		{
			cnt++;
			prev = a[i];
		}
	}
	return cnt >= c;
}
int main()
{
	scanf("%d%d", &n, &c);
	for (int i = 1; i <= n; i++)
		scanf("%d", &a[i]);
	sort(a + 1, a + n + 1); // 题意中是无序的
	int l = 1, r = 1e9, ans = 0;
	while (l <= r)
	{
		int mid = l + r >> 1;
		if (check(mid))
		{
			ans = mid;
			l = mid + 1;
		}
		else r = mid - 1;
	}
	printf("%d\n", ans);
	return 0;
}

Last modified on 2024-03-05