Minimum Spanning Tree (MST)

최소 신장 트리

YEAHx4

YEAHx4

2026-02-23
11 mins

MST

MST는 그래프에서 모든 정점을 연결하는 트리 중 간선의 가중치 합이 가장 작은 트리를 말합니다. 트리의 특징 중 하나로 간선의 수는 항상 노드의 수보다 하나 적다는 점이 있습니다. 즉, MST는 \( N \)개의 노드와 \( M \)개의 간선으로 이루어진 그래프에서 \( N-1 \)개의 간선을 선택해 모든 노드를 연결하는 트리를 찾으면서 간선의 가중치 합이 최소가 되도록 하는 문제입니다. 이 글을 이해하기 위해서 Union-Find에 대한 이해가 필요합니다. Union-Find에 대한 설명은 이 글에서 확인할 수 있습니다.

Spanning Tree

스패닝 트리는 그래프에서 모든 정점을 포함하는 트리를 말합니다. 예를 들어, 아래의 그래프가 있을 때,

그래프

이 그래프는 노드가 5개인 그래프니 4개의 간선만 있으면 모든 노드를 연결할 수 있습니다.

MST

이 그래프가 스패닝 트리입니다. 물론 이 형태 말고도 다른 모양의 스패닝 트리도 존재합니다. MST의 목적은 이런 여러 스패닝 트리 중 간선의 가중치 합이 가장 작은 트리를 찾는 것입니다.

구현 원리

가장 간단히 생각하면 모든 스패닝 트리를 만들어보고 가중치 합의 최솟값을 찾으면 되지만, 시간이 오래 걸려 현실적으로 불가능합니다. 그래서 조금 더 효율적인 방법을 사용해야 합니다. 간선 가중치의 합이 최소가 되려면 가중치가 작은 간선을 먼저 선택하는 것이 좋습니다. 그래서 간선을 가중치 순으로 정렬한 후 가장 작은 간선부터 선택해 나가는 방법을 생각해 볼 수 있습니다. 정리하면 다음과 같습니다.

  1. 간선을 가중치 순으로 정렬한다.
  2. 가장 작은 간선부터 선택한다.
  3. 선택한 간선이 사이클을 형성하지 않는다면 트리에 포함시킨다.
  4. 모든 노드가 연결될 때까지 2-3을 반복한다

정렬은 \( O(n \log n) \) 시간에 처리할 수 있습니다. 그런데 간선이 사이클을 형성하는지 빠르게 확인할 수 있는 방법이 필요합니다. 이떄 Union-Find 자료구조가 유용하게 사용됩니다. 처음에는 모든 노드가 서로 다른 집합에 속해 있습니다. 간선을 선택하면 두 노드의 집합이 서로 합병됩니다. 그리고 두 노드가 같은 집합에 속해 있는지를 Union-Find를 사용하면 \( O(1) \)에 처리할 수 있습니다.

이 방법을 크루스칼 알고리즘(Kruskal's Algorithm)이라고 합니다. 다른 방법으로는 프림 알고리즘(Prim's Algorithm)이 있습니다. 이 글에서는 프림 알고리즘에 대해서는 설명하지 않을 예정이지만, 두 알고리즘의 시간복잡도는 비슷합니다.

구현

BOJ 1197번: 최소 스패닝 트리 문제를 풀어보면서 MST를 구현해 보겠습니다. 일단 Union-Find를 구현합니다. 최대 개수가 1만 개이기 때문에 1-indexed를 위해 10001로 선언합니다.

int parent[10001];

int find(int x) {
    if (parent[x] == x) return x;
    return parent[x] = find(parent[x]);
}

void unite(int a, int b) {
    int pa = find(a);
    int pb = find(b);
    if (pa != pb) {
        parent[pb] = pa;
    }
}

그리고 \( v, e \)를 입력받고 parents 배열을 초기화합니다. 초기에는 모든 노드가 서로 다른 집합에 속해 있으므로, parent[i] = i로 초기화합니다.

int v, e; cin >> v >> e;
for (int i = 1; i <= v; i++)
    parent[i] = i;

이제 간선 정보를 입력받아 저장하고 가중치 순으로 정렬합니다. 정렬을 위해 정렬하는 순서를 가중치, 노드1, 노드2로 설정합니다.

vector<pair<int, pair<int,int>>> edges;
while (e--) {
    int a, b, c;
    cin >> a >> b >> c;
    edges.push_back({c, {a, b}});
}

sort(edges.begin(), edges.end());

이제 간선을 하나씩 선택하면서 사이클이 형성되는지 확인하고, 사이클이 형성되지 않는다면 트리에 포함시킵니다. 그리고 두 노드가 같은 집합에 속하도록 합병합니다. 간선이 선택될 때마다 가중치를 합산합니다. 여기서는 간선의 개수가 10만개이기 때문에 모든 간선을 확인해도 괜찮습니다. 하지만, 간선은 \( v - 1 \)개를 선택해야 함을 알고 있기 때문에 중간에 종료해 주어도 괜찮습니다.

int ans = 0;
for (auto &[cost, edge] : edges) {
    auto &[a, b] = edge;

    if (find(a) != find(b)) {
        unite(a, b);
        ans += cost;
    }
}

ans를 출력하면 정답을 받을 수 있습니다. 전체 코드는 다음과 같습니다.

#include <bits/stdc++.h>

using namespace std;
void fio() { cin.tie(nullptr); cout.tie(nullptr); ios::sync_with_stdio(false); }

int parent[10001];

int find(int x) {
    if (parent[x] == x) return x;
    return parent[x] = find(parent[x]);
}

void unite(int a, int b) {
    int pa = find(a);
    int pb = find(b);
    if (pa != pb) {
        parent[pb] = pa;
    }
}

int main() {
    fio();

    int v, e; cin >> v >> e;
    for (int i = 1; i <= v; i++)
        parent[i] = i;

    vector<pair<int, pair<int,int>>> edges;
    while (e--) {
        int a, b, c;
        cin >> a >> b >> c;
        edges.push_back({c, {a, b}});
    }

    int ans = 0;
    for (auto &[cost, edge] : edges) {
        auto &[a, b] = edge;

        if (find(a) != find(b)) {
            unite(a, b);
            ans += cost;
        }
    }

    cout << ans;
}

연습문제