1 条题解

  • 0
    @ 2025-7-9 18:26:03

    方法思路

    1. 输入处理:读取节点数量、节点权值以及树的边。
    2. 树的构建:使用邻接表表示树结构,其中每个节点存储其相邻节点列表。
    3. 深度优先搜索(DFS):从根节点开始遍历树,维护两个计数器 onetwo,分别记录当前路径中权值为1和2的节点数。
    4. 路径统计:在遍历过程中,根据当前节点的权值更新计数器,并统计满足权值和为3的路径数目。具体来说:
      • 如果当前节点权值为2,则累加之前路径中权值为1的节点数。
      • 如果当前节点权值为1,则累加之前路径中权值为2的节点数,并组合之前权值为1的节点数计算可能的路径数。
    5. 结果输出:最终输出满足条件的路径数目。

    这种方法通过动态维护路径中的权值计数,高效地统计了所有可能的路径数目,确保在较大的树结构下也能快速运行。

    代码实现

    Java
    import java.io.*;
    import java.util.*;
    
    public class Main {
        static List<Integer>[] tree;
        static int[] vals;
        static long ans = 0;
    
        public static void main(String[] args) throws IOException {
            BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
            int n = Integer.parseInt(br.readLine());
            vals = new int[n];
            StringTokenizer st = new StringTokenizer(br.readLine());
            for (int i = 0; i < n; i++) vals[i] = Integer.parseInt(st.nextToken());
    
            tree = new List[n];
            for (int i = 0; i < n; i++) tree[i] = new ArrayList<>();
    
            for (int i = 1; i < n; i++) {
                st = new StringTokenizer(br.readLine());
                int u = Integer.parseInt(st.nextToken()) - 1;
                int v = Integer.parseInt(st.nextToken()) - 1;
                tree[u].add(v);
                tree[v].add(u);
            }
    
            dfs(0, -1);
            System.out.println(ans);
        }
    
        static int[] dfs(int x, int parent) {
            int one = 0, two = 0;
            
            for (int child : tree[x]) {
                if (child != parent) {
                    int[] res = dfs(child, x);
                    if (vals[x] == 2) {
                        ans += res[0];
                    } else {
                        ans += res[1];
                        two += res[0];
                    }
                    one += res[0];
                }
            }
            
            if (vals[x] == 1) {
                ans += (long)one * (one - 1) / 2; 
            }
            
            return new int[]{vals[x] == 1 ? 1 : 0, two + (vals[x] == 2 ? 1 : 0)};
        }
    }
    
    
    Python
    import sys
    input = sys.stdin.readline
    
    def main():
        n = int(input())
        vals = list(map(int, input().split()))
        
        tree = [[] for _ in range(n)]
        for i in range(n - 1):
            u, v = map(int, input().split())
            tree[u-1].append(v-1)
            tree[v-1].append(u-1)
        
        ans = 0
        
        def dfs(x, parent):
            nonlocal ans
            one = 0
            two = 0
            
            for child in tree[x]:
                if child != parent:
                    r_one, r_two = dfs(child, x)
                    if vals[x] == 2:
                        ans += r_one
                    else:
                        ans += r_two
                        two += r_one
                    one += r_one
            
            if vals[x] == 1:
                ans += one * (one - 1) // 2
            
            return (1 if vals[x] == 1 else 0, two + (1 if vals[x] == 2 else 0))
        
        dfs(0, -1)
        print(ans)
    
    if __name__ == "__main__":
        main()
    
    
    C++
    #include <iostream>
    #include <vector>
    using namespace std;
    
    vector<vector<int>> tree;
    vector<int> vals;
    long long ans = 0;
    
    pair<int, int> dfs(int x, int parent) {
        int one = 0, two = 0;
        
        for (int child : tree[x]) {
            if (child != parent) {
                auto [r_one, r_two] = dfs(child, x);
                if (vals[x] == 2) {
                    ans += r_one;
                } else {
                    ans += r_two;
                    two += r_one;
                }
                one += r_one;
            }
        }
        
        if (vals[x] == 1) {
            ans += (long long)one * (one - 1) / 2;
        }
        
        return {vals[x] == 1 ? 1 : 0, two + (vals[x] == 2 ? 1 : 0)};
    }
    
    int main() {
        ios::sync_with_stdio(false);
        cin.tie(nullptr);
        
        int n;
        cin >> n;
        
        vals.resize(n);
        for (int i = 0; i < n; i++) {
            cin >> vals[i];
        }
        
        tree.resize(n);
        for (int i = 1; i < n; i++) {
            int u, v;
            cin >> u >> v;
            u--; v--;
            tree[u].push_back(v);
            tree[v].push_back(u);
        }
        
        dfs(0, -1);
        cout << ans << endl;
        
        return 0;
    }
    
    
    • 1

    信息

    ID
    40
    时间
    1000ms
    内存
    256MiB
    难度
    5
    标签
    递交数
    1
    已通过
    1
    上传者