1 条题解

  • 0
    @ 2025-7-12 19:08:26

    方法思路

    利用树状数组和哈希表优化解法,其中 f(l,r,x) 表示 x 在 a[l..r] 中出现的次数:

    1. 预处理所有位置i的f(0,i,a[i])值和所有位置j的f(j,n-1,a[j])值
    2. 对于每个j,我们需要知道有多少个i满足i<j且f(0,i,a[i])>f(j,n-1,a[j])
    3. 使用树状数组维护前缀中 f(0,i,a[i]) 的分布,单次查询和更新均为 O(log n)
    4. 从左到右枚举 j:先将 f(0,j-1,a[j-1]) 加入树状数组,再查询其中有多少个值大于 f(j,n-1,a[j]) 并累加到答案;总时间复杂度 O(n log n)

    代码实现

    Java
    import java.io.*;
    import java.util.*;
    
    public class Main {
        /**
         * Reads n followed by n integers (possibly spread over several lines) from
         * standard input and prints the number of qualifying pairs.
         */
        public static void main(String[] args) throws IOException {
            Tokens in = new Tokens(new BufferedReader(new InputStreamReader(System.in)));
            String first = in.next();
            int n = (first == null) ? 0 : Integer.parseInt(first);

            int[] arr = new int[Math.max(n, 0)];
            int read = 0;
            // Stop early on truncated input instead of throwing.
            while (read < arr.length) {
                String token = in.next();
                if (token == null) {
                    break;
                }
                arr[read++] = Integer.parseInt(token);
            }

            System.out.println(solve(read, arr));
        }

        /**
         * Counts pairs (i, j) with i < j and f(0, i, a[i]) > f(j, n-1, a[j]), where
         * f(l, r, x) is the number of occurrences of x in arr[l..r].
         *
         * Every prefix count f(0, i, a[i]) and suffix count f(j, n-1, a[j]) lies in
         * [1, n], so the Fenwick tree is indexed by the count itself and no
         * coordinate compression is needed.
         *
         * @param n   number of leading elements of arr to consider (clamped to arr.length)
         * @param arr input values
         * @return number of qualifying pairs; 0 when n <= 0
         */
        public static long solve(int n, int[] arr) {
            n = Math.min(n, arr.length);
            if (n <= 0) {
                return 0;
            }

            // leftCount[i] = f(0, i, arr[i]): occurrences of arr[i] in arr[0..i].
            int[] leftCount = new int[n];
            HashMap<Integer, Integer> hLeft = new HashMap<>();
            for (int i = 0; i < n; i++) {
                leftCount[i] = hLeft.merge(arr[i], 1, Integer::sum);
            }

            // rightCount[j] = f(j, n-1, arr[j]): occurrences of arr[j] in arr[j..n-1].
            int[] rightCount = new int[n];
            HashMap<Integer, Integer> hRight = new HashMap<>();
            for (int i = n - 1; i >= 0; i--) {
                rightCount[i] = hRight.merge(arr[i], 1, Integer::sum);
            }

            int[] bit = new int[n + 1];
            long ans = 0;
            for (int j = 0; j < n; j++) {
                if (j > 0) {
                    update(bit, leftCount[j - 1], 1);
                }
                // The tree holds j prefix counts; exclude those <= rightCount[j].
                int countLessEqual = query(bit, rightCount[j]);
                ans += j - countLessEqual;
            }
            return ans;
        }

        /**
         * Adds val at 1-based position idx. Positions outside [1, bit.length - 1]
         * are ignored (idx == 0 would otherwise never advance).
         */
        private static void update(int[] bit, int idx, int val) {
            while (idx > 0 && idx < bit.length) {
                bit[idx] += val;
                idx += idx & -idx;
            }
        }

        /** Returns the sum of positions 1..idx, with idx clamped to the tree size. */
        private static int query(int[] bit, int idx) {
            int res = 0;
            for (int i = Math.min(idx, bit.length - 1); i > 0; i -= i & -i) {
                res += bit[i];
            }
            return res;
        }

        /** Whitespace tokenizer over a reader that transparently spans line breaks. */
        private static final class Tokens {
            private final BufferedReader reader;
            private StringTokenizer current = new StringTokenizer("");

            Tokens(BufferedReader reader) {
                this.reader = reader;
            }

            /** Returns the next token, or null once the input is exhausted. */
            String next() throws IOException {
                while (!current.hasMoreTokens()) {
                    String line = reader.readLine();
                    if (line == null) {
                        return null;
                    }
                    current = new StringTokenizer(line);
                }
                return current.nextToken();
            }
        }
    }
    
    
    Python
    def solve(n, arr):
        """Count pairs (i, j) with i < j and f(0, i, a[i]) > f(j, n-1, a[j]).

        f(l, r, x) is the number of occurrences of x in arr[l..r]. Every prefix
        count and suffix count lies in [1, n], so the Fenwick tree is indexed by
        the count itself; no coordinate compression (and no fallback rank) is
        needed.

        Args:
            n: number of elements of ``arr`` to consider.
            arr: the input values.

        Returns:
            The number of qualifying pairs (0 when n <= 0).
        """
        if n <= 0:
            return 0

        # left_count[i] = f(0, i, arr[i]): occurrences of arr[i] in arr[0..i]
        left_count = [0] * n
        h_left = {}
        for i in range(n):
            num = arr[i]
            h_left[num] = h_left.get(num, 0) + 1
            left_count[i] = h_left[num]

        # right_count[j] = f(j, n-1, arr[j]): occurrences of arr[j] in arr[j..n-1]
        right_count = [0] * n
        h_right = {}
        for i in range(n - 1, -1, -1):
            num = arr[i]
            h_right[num] = h_right.get(num, 0) + 1
            right_count[i] = h_right[num]

        # Fenwick tree over count values 1..n (bit[0] unused)
        bit = [0] * (n + 1)

        def update(idx, val):
            while idx <= n:
                bit[idx] += val
                idx += idx & -idx

        def query(idx):
            res = 0
            while idx > 0:
                res += bit[idx]
                idx -= idx & -idx
            return res

        ans = 0
        for j in range(n):
            if j > 0:
                # Add the prefix count of the previous position to the tree
                update(left_count[j - 1], 1)

            # The tree holds j prefix counts; keep those strictly greater than
            # right_count[j] (which is always within 1..n).
            count_greater = j - query(right_count[j])
            ans += count_greater

        return ans
    
    import sys

    # Read all of stdin at once: faster than input() for large n, and tolerant
    # of the numbers being split across several lines (or of empty input).
    _tokens = sys.stdin.buffer.read().split()
    n = int(_tokens[0]) if _tokens else 0
    arr = [int(t) for t in _tokens[1:1 + max(n, 0)]]
    print(solve(len(arr), arr))
    
    
    C++
    #include <iostream>
    #include <vector>
    #include <unordered_map>
    #include <set>
    #include <algorithm>
    using namespace std;
    
    // Fenwick tree (binary indexed tree) over 1-based positions 1..n.
    // Supports point updates and prefix-sum queries, each in O(log n).
    class BIT {
    private:
        vector<int> tree;  // tree[0] is unused; positions are 1-based
        int n;

    public:
        // Creates a tree over positions 1..size with every value zero.
        // A non-positive size yields an empty tree.
        explicit BIT(int size) : tree(size > 0 ? size + 1 : 1, 0), n(size > 0 ? size : 0) {}

        // Adds val at position idx. Positions outside [1, n] are ignored;
        // idx == 0 would otherwise loop forever because 0 & -0 == 0.
        void update(int idx, int val) {
            if (idx <= 0) {
                return;
            }
            for (; idx <= n; idx += idx & -idx) {
                tree[idx] += val;
            }
        }

        // Returns the sum of positions 1..idx. idx is clamped to n so that an
        // oversized argument cannot index past the end of the tree.
        int query(int idx) const {
            if (idx > n) {
                idx = n;
            }
            int sum = 0;
            for (; idx > 0; idx -= idx & -idx) {
                sum += tree[idx];
            }
            return sum;
        }
    };
    
    // Counts pairs (i, j) with i < j and f(0, i, a[i]) > f(j, n-1, a[j]), where
    // f(l, r, x) is the number of occurrences of x in arr[l..r].
    //
    // leftCount[i] <= i + 1 <= n and rightCount[j] <= n - j <= n, and both are
    // >= 1, so the Fenwick tree is indexed by the count itself. No coordinate
    // compression (set + sort + rank map + binary search) is needed.
    //
    // n is clamped to arr.size(); returns 0 when there are no elements.
    long long solve(int n, vector<int>& arr) {
        if (n > static_cast<int>(arr.size())) {
            n = static_cast<int>(arr.size());
        }
        if (n <= 0) {
            return 0;
        }

        // leftCount[i] = f(0, i, arr[i]): occurrences of arr[i] in arr[0..i].
        vector<int> leftCount(n);
        unordered_map<int, int> hLeft;
        hLeft.reserve(n);
        for (int i = 0; i < n; i++) {
            leftCount[i] = ++hLeft[arr[i]];
        }

        // rightCount[j] = f(j, n-1, arr[j]): occurrences of arr[j] in arr[j..n-1].
        vector<int> rightCount(n);
        unordered_map<int, int> hRight;
        hRight.reserve(n);
        for (int i = n - 1; i >= 0; i--) {
            rightCount[i] = ++hRight[arr[i]];
        }

        BIT bit(n);
        long long ans = 0;
        for (int j = 0; j < n; j++) {
            if (j > 0) {
                // Insert the prefix count of the previous position.
                bit.update(leftCount[j - 1], 1);
            }
            // The tree now holds j prefix counts; those <= rightCount[j] do not
            // qualify, the rest are strictly greater.
            ans += j - bit.query(rightCount[j]);
        }

        return ans;
    }
    
    // Reads n followed by n integers from standard input and prints the number
    // of qualifying pairs.
    int main() {
        // Decouple C++ streams from C stdio and from cout for fast bulk input;
        // the page lists a 1000 ms time limit.
        ios::sync_with_stdio(false);
        cin.tie(nullptr);

        int n = 0;
        if (!(cin >> n) || n <= 0) {
            // Missing or non-positive n: there are no pairs to count.
            cout << 0 << '\n';
            return 0;
        }

        vector<int> arr(n);
        for (int& value : arr) {
            cin >> value;
        }

        // '\n' instead of endl: no need to force a flush before exit.
        cout << solve(n, arr) << '\n';

        return 0;
    }
    
    
    • 1

    信息

    ID
    58
    时间
    1000ms
    内存
    256MiB
    难度
    5
    标签
    递交数
    1
    已通过
    1
    上传者