1 条题解
-
0
方法思路
- 输入处理:读取节点数量、节点权值以及树的边。
- 树的构建:使用邻接表表示树结构,其中每个节点存储其相邻节点列表。
- 深度优先搜索(DFS):从根节点开始遍历树,维护两个计数器
one
和two
,分别记录当前路径中权值为1和2的节点数。 - 路径统计:在遍历过程中,根据当前节点的权值更新计数器,并统计满足权值和为3的路径数目。具体来说:
- 如果当前节点权值为2,则累加之前路径中权值为1的节点数。
- 如果当前节点权值为1,则累加之前路径中权值为2的节点数,并组合之前权值为1的节点数计算可能的路径数。
- 结果输出:最终输出满足条件的路径数目。
这种方法通过动态维护路径中的权值计数,高效地统计了所有可能的路径数目,确保在较大的树结构下也能快速运行。
代码实现
Java
import java.io.*; import java.util.*; public class Main { static List<Integer>[] tree; static int[] vals; static long ans = 0; public static void main(String[] args) throws IOException { BufferedReader br = new BufferedReader(new InputStreamReader(System.in)); int n = Integer.parseInt(br.readLine()); vals = new int[n]; StringTokenizer st = new StringTokenizer(br.readLine()); for (int i = 0; i < n; i++) vals[i] = Integer.parseInt(st.nextToken()); tree = new List[n]; for (int i = 0; i < n; i++) tree[i] = new ArrayList<>(); for (int i = 1; i < n; i++) { st = new StringTokenizer(br.readLine()); int u = Integer.parseInt(st.nextToken()) - 1; int v = Integer.parseInt(st.nextToken()) - 1; tree[u].add(v); tree[v].add(u); } dfs(0, -1); System.out.println(ans); } static int[] dfs(int x, int parent) { int one = 0, two = 0; for (int child : tree[x]) { if (child != parent) { int[] res = dfs(child, x); if (vals[x] == 2) { ans += res[0]; } else { ans += res[1]; two += res[0]; } one += res[0]; } } if (vals[x] == 1) { ans += (long)one * (one - 1) / 2; } return new int[]{vals[x] == 1 ? 1 : 0, two + (vals[x] == 2 ? 1 : 0)}; } }
Python
import sys input = sys.stdin.readline def main(): n = int(input()) vals = list(map(int, input().split())) tree = [[] for _ in range(n)] for i in range(n - 1): u, v = map(int, input().split()) tree[u-1].append(v-1) tree[v-1].append(u-1) ans = 0 def dfs(x, parent): nonlocal ans one = 0 two = 0 for child in tree[x]: if child != parent: r_one, r_two = dfs(child, x) if vals[x] == 2: ans += r_one else: ans += r_two two += r_one one += r_one if vals[x] == 1: ans += one * (one - 1) // 2 return (1 if vals[x] == 1 else 0, two + (1 if vals[x] == 2 else 0)) dfs(0, -1) print(ans) if __name__ == "__main__": main()
C++
#include <iostream> #include <vector> using namespace std; vector<vector<int>> tree; vector<int> vals; long long ans = 0; pair<int, int> dfs(int x, int parent) { int one = 0, two = 0; for (int child : tree[x]) { if (child != parent) { auto [r_one, r_two] = dfs(child, x); if (vals[x] == 2) { ans += r_one; } else { ans += r_two; two += r_one; } one += r_one; } } if (vals[x] == 1) { ans += (long long)one * (one - 1) / 2; } return {vals[x] == 1 ? 1 : 0, two + (vals[x] == 2 ? 1 : 0)}; } int main() { ios::sync_with_stdio(false); cin.tie(nullptr); int n; cin >> n; vals.resize(n); for (int i = 0; i < n; i++) { cin >> vals[i]; } tree.resize(n); for (int i = 1; i < n; i++) { int u, v; cin >> u >> v; u--; v--; tree[u].push_back(v); tree[v].push_back(u); } dfs(0, -1); cout << ans << endl; return 0; }
- 1
信息
- ID
- 40
- 时间
- 1000ms
- 内存
- 256MiB
- 难度
- 5
- 标签
- 递交数
- 1
- 已通过
- 1
- 上传者