[笔记] P1433

题意简述

给定 $n$ 个点,求从 $(0, 0)$ 处出发遍历完所有的点的最短距离。

分析

先写输入和输出。我们用一个结构体 node 来表示一个点。

#include <iostream>
#include <cstdio>
using namespace std;
struct node {double x, y;} a[20];
int n;
double ans = 2147483647;
int main()
{
	scanf("%d", &n);
	for (int i = 1; i <= n; i++)
		scanf("%lf%lf", &a[i].x, &a[i].y);
	printf("%.2lf\n", ans);
	return 0;
}

这个 DFS 有四个参数,分别是:层数(step)、状态(state)、已走距离(sum)和上一步的点(pre)。

状态用一个 $n$ 位的二进制数来表示,$1$ 表示没吃,$0$ 表示被吃了。

由于这道题复杂度极高,所以我们要用位运算优化。

// 全局定义
void dfs(int step, int state, double sum, node pre)
{
}
// 主函数
dfs(1, (1 << n) - 1, 0, node{0, 0});

如果递归到 $n + 1$ 层了,那说明已经遍历完所有的点了,所以要跳出并更新答案:

if (step == n + 1)
{
    ans = min(ans, sum);
    return ;
}

然后就写 DFS 里面的内容,我们对状态进行逐位 lowbit,以遍历所有的点,得到所有的情况。

而每次更新的时候,层数递增,状态要去掉最后一位 $1$,已走距离要加上从上一个位置到这一个位置的距离,再更新上一个位置。

$p$ 数组即把 lowbit 转换成第几个点。

// 全局定义
int p[1 << 16];
int lowbit(int x) {return x & -x;}
double getdis(node a, node b)
{
	return sqrt((a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y));
}
// DFS
int t = state;
while (t)
{
    int lt = lowbit(t);
    t -= lt;
    dfs(step + 1, state - lt, sum + getdis(pre, a[p[lt]]), a[p[lt]]);
}
// 主函数
for (int i = 0; i < 16; i++)
	p[1 << i] = i + 1;

接下来进行第一个剪枝:最优性剪枝。如果当前的距离已经超过最短距离,那么就不需要再递归了。

代码如下(出口后):

if (sum >= ans)
	return ;

然后就是第二个剪枝:最优性剪枝。这里如果现在的距离比从上一步到达这个状态的最小距离大,那么就不用继续搜索了。如果不比它大,那就更新这个值。

我们定义 $d_{i, j}$ 表示从 $j$ 到达状态 $i$ 的最小距离,而我们也需要给 node 加上编号这个属性,以便可以表示它是第几个点。

// 全局定义
struct node
{
	double x, y;
	int id; 
} a[20];
double d[1 << 16][20];
// DFS
if (d[state][pre.id] > 0 && d[state][pre.id] < sum)
    return ;
d[state][pre.id] = sum; 

完整代码

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include <iostream>
#include <cstdio>
#include <cmath>
using namespace std;
struct node
{
	double x, y;
	int id; 
} a[20];
int n, p[1 << 16];
double ans = 2147483647, d[1 << 16][20];
int lowbit(int x) {return x & -x;}
double getdis(node a, node b)
{
	return sqrt((a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y));
}
void dfs(int step, int state, double sum, node pre)
{
	if (step == n + 1)
	{
		ans = min(ans, sum);
		return ;
	}
	if (sum >= ans)
		return ;
	if (d[state][pre.id] > 0 && d[state][pre.id] < sum)
		return ;
	d[state][pre.id] = sum; 
	int t = state;
	while (t)
	{
		int lt = lowbit(t);
		t -= lt;
		dfs(step + 1, state - lt, sum + getdis(pre, a[p[lt]]), a[p[lt]]);
	}
}
int main()
{
	for (int i = 0; i < 16; i++)
		p[1 << i] = i + 1;
	scanf("%d", &n);
	for (int i = 1; i <= n; i++)
	{
		scanf("%lf%lf", &a[i].x, &a[i].y);
		a[i].id = i;
	}
	dfs(1, (1 << n) - 1, 0, node{0, 0, 0});
	printf("%.2lf\n", ans);
	return 0;
}

Last modified on 2024-01-14