并查集

11  •  3个月前


#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

// A weighted edge between vertices `u` and `v`.
struct Edge {
    long long u;
    long long v;
    long long weight;

    // Edges are ordered by weight alone; the endpoints never take part in the
    // comparison, so two edges of equal weight compare as equivalent.
    bool operator<(const Edge& rhs) const {
        return rhs.weight > weight;
    }
};

// Returns the representative (root) of the set that contains `i`.
//
// A node is a root when `parent[x] == x`. Once the root is located, every node
// visited on the way up is re-pointed directly at it (path compression), so
// `parent` is modified as a side effect.
long long find(std::vector<long long>& parent, long long i) {
    // First pass: climb until reaching a node that is its own parent.
    long long root = i;
    while (parent[root] != root) {
        root = parent[root];
    }

    // Second pass: path compression - hang each visited node under the root.
    while (parent[i] != root) {
        const long long next = parent[i];
        parent[i] = root;
        i = next;
    }
    return root;
}

// Merges the sets containing `x` and `y` (union by rank).
//
// The root with the lower rank is attached beneath the root with the higher
// rank. When both ranks are equal, the root of `y`'s set is attached beneath
// the root of `x`'s set, whose rank then grows by one. Nothing is linked when
// both elements already share a root. The `find` calls may still compress
// paths in `parent`.
void unionSets(std::vector<long long>& parent, std::vector<long long>& rank, long long x, long long y) {
    const long long rootOfX = find(parent, x);
    const long long rootOfY = find(parent, y);

    // Already in the same set: nothing to link.
    if (rootOfX == rootOfY) {
        return;
    }

    // x's tree is strictly shallower, so it goes under y's root.
    if (rank[rootOfX] < rank[rootOfY]) {
        parent[rootOfX] = rootOfY;
        return;
    }

    // Otherwise y's root goes under x's root; a tie makes x's tree one deeper.
    parent[rootOfY] = rootOfX;
    if (rank[rootOfX] == rank[rootOfY]) {
        ++rank[rootOfX];
    }
}

long long kruskal(long long n, vector<Edge>& edges) {
    vector<long long> parent(n + 1);
    vector<long long> rank(n + 1, 0);
    for (long long i = 1; i <= n; i++) {
        parent[i] = i;
    }

    sort(edges.begin(), edges.end());

    long long mst_weight = 0;
    long long mst_edges = 0;

    for (const auto& edge : edges) {
        if (find(parent, edge.u) != find(parent, edge.v)) {
            unionSets(parent, rank, edge.u, edge.v);
            mst_weight += edge.weight;
            mst_edges++;
            if (mst_edges == n - 1) {
                break;
            }
        }
    }

    if (mst_edges != n - 1) {
        return -1; // 图不连通
    } else {
        return mst_weight;
    }
}

int main() {
    long long n, m;
    cin >> n >> m;

    vector<Edge> edges(m);
    for (long long i = 0; i < m; i++) {
        cin >> edges[i].u >> edges[i].v >> edges[i].weight;
    }

    long long result = kruskal(n, edges);
    if (result == -1) {
        cout << "orz" << endl;
    } else {
        cout << result << endl;
    }

    return 0;
}

评论:

请先登录,才能进行评论