union-find, disjoint set

간단 정리

 

서로소 집합을 표현할 수 있다.

트리 구조를 사용하며, 한 집합은 하나의 트리구조로 표현된다. 

무방향 그래프의 사이클을 검사할 수 있다.

시간 복잡도는 대략 (find 연산의 수) x (log_2 노드의 개수) 이다.

 

union

두 원소를 같은 집합으로 합치는 연산

원소 a, b와 둘의 루트 노드 A, B가 있을 때, A를 B의 자식으로 (혹은 반대로) 연결한다.

def union(x, y):
    x = find_parent(x)
    y = find_parent(y)

    parents[x] = y

 

find

원소의 부모를 찾는 연산

자식에서 부모 쪽으로 거슬러 올라간다.

노드의 부모가 자기 자신이 될 때까지(루트 노드) 반복한다.

경로 압축을 통해 최적화를 할 수 있다.

def find_parent(x):
    if parents[x] != x:
        parents[x] = find_parent(parents[x])
    return parents[x]

 

 


 

 

예제

 

집합의 표현

https://www.acmicpc.net/problem/1717

 

1717번: 집합의 표현

초기에 $n+1$개의 집합 $\{0\}, \{1\}, \{2\}, \dots , \{n\}$이 있다. 여기에 합집합 연산과, 두 원소가 같은 집합에 포함되어 있는지를 확인하는 연산을 수행하려고 한다. 집합을 표현하는 프로그램을 작

www.acmicpc.net

 

import sys
sys.setrecursionlimit(1_000_000)


def find_parent(x):
    if parents[x] != x:
        parents[x] = find_parent(parents[x])
    return parents[x]


def is_same_set(x, y):
    x = find_parent(x)
    y = find_parent(y)

    return "yes" if x == y else "no"


def union(x, y):
    x = find_parent(x)
    y = find_parent(y)

    parents[x] = y


input = lambda: sys.stdin.readline().rstrip()
n, m = map(int, input().split())
parents = [i for i in range(n + 1)]
for _ in range(m):
    op, a, b = map(int, input().split())
    if op == 0:
        union(a, b)
    elif op == 1:
        print(is_same_set(a, b))

 

find_parent가 재귀 함수이므로, 노드의 개수에 따라 재귀로 접근할 수 있는 횟수를 조절해야한다.

 

 


 

 

친구 네트워크

https://www.acmicpc.net/problem/4195

 

4195번: 친구 네트워크

첫째 줄에 테스트 케이스의 개수가 주어진다. 각 테스트 케이스의 첫째 줄에는 친구 관계의 수 F가 주어지며, 이 값은 100,000을 넘지 않는다. 다음 F개의 줄에는 친구 관계가 생긴 순서대로 주어진

www.acmicpc.net

 

import sys
sys.setrecursionlimit(100_000)


def add_people(x):
    global index

    if x not in people:
        people[x] = index
        index += 1


def find_parent(x):
    if parents[x] != x:
        parents[x] = find_parent(parents[x])
    return parents[x]


def union(x, y):
    x = people[x]
    y = people[y]

    x = find_parent(x)
    y = find_parent(y)

    parents[x] = y
    if x != y:
        number_of_set[y] += number_of_set[x]


input = lambda: sys.stdin.readline().rstrip()
for _ in range(int(input())):
    F = int(input())

    index = 0
    people = {}
    parents = [i for i in range(2*F+1)]
    number_of_set = [1 for i in range(2*F+1)]
    for _ in range(F):
        a, b = input().split()

        add_people(a)
        add_people(b)

        union(a, b)
        a, b = find_parent(people[a]), find_parent(people[b])
        print(max(number_of_set[a], number_of_set[b]))

 

a, b = find_parent(people[a]), find_parent(people[b]) 처럼 집합의 정체성은 항상 루트 노드에 있다는 사실을 놓치면 안된다.

 

union 연산으로 매번 집합의 원소가 추가되는 것이 아니므로 추가되는 경우와 그렇지 않은 경우를 나누어 생각한다.

 

 


 

 

도시 분할 계획

https://www.acmicpc.net/problem/1647

 

1647번: 도시 분할 계획

첫째 줄에 집의 개수 N, 길의 개수 M이 주어진다. N은 2이상 100,000이하인 정수이고, M은 1이상 1,000,000이하인 정수이다. 그 다음 줄부터 M줄에 걸쳐 길의 정보가 A B C 세 개의 정수로 주어지는데 A번

www.acmicpc.net

 

import sys
sys.setrecursionlimit(100_000)


def has_cycle(x, y):
    return x == y


def find_parent(x):
    if parents[x] != x:
        parents[x] = find_parent(parents[x])
    return parents[x]


def union(x, y, w):
    global answer, last_weight

    x = find_parent(x)
    y = find_parent(y)

    if not has_cycle(x, y):
        parents[x] = y
        answer += w
        last_weight = w


input = lambda: sys.stdin.readline().rstrip()
N, M = map(int, input().split())
roads = [list(map(int, input().split())) for _ in range(M)]

roads.sort(key=lambda x: x[2])

answer, last_weight = 0, None
parents = [i for i in range(N+1)]
for a, b, weight in roads:
    union(a, b, weight)

print(answer - last_weight)

 

트리의 부분집합 또한 트리라는 성질을 이용하여, 그리디로 접근

크루스칼은 간선의 가중치를 오름차순으로 정렬하기 때문에, 마지막에 이어지는 간선을 저장하고 총합에 빼면 된다.

'알고리즘' 카테고리의 다른 글

[알고리즘] 투 포인터  (1) 2024.03.22
[알고리즘] BFS DFS  (0) 2024.03.20
[알고리즘] 문제 풀이  (0) 2024.03.18
순열과 조합을 구현해보자  (0) 2024.01.08
이진분류, 하노이 탑  (0) 2024.01.08