algorithm/UnionFind

[백준24542/실버1] 튜터-튜티 관계의 수 - Python

ayeongjin 2025. 2. 5. 14:53

https://www.acmicpc.net/problem/24542

 

 

스터디에 유니온 파인드로 풀어야 시간초과 안나는 문제가 나와서 유니온파인드 공부했다.

https://ayeongjin.tistory.com/45

 

유니온-파인드(Union-Find) 알고리즘

1. 유니온-파인드란?유니온-파인드(Union-Find) 알고리즘은 서로소 집합(Disjoint Set) 을 관리하는 자료구조로, 대표적으로 그래프에서 사이클 판별, 네트워크 연결성 검사, 최소 신장 트리(MST) Kruskal

ayeongjin.tistory.com

 

 

일단 이론대로 문제 설계를 해보면

1. 필요한 값

  1. 어떤 집합인지 알아볼 수 있는 parent배열 (parent배열의 값이 같으면 같은 집합)
  2. 집합의 크기를 나타내는 size배열
  3. 랭크 기반 유니온을 위한 rank배열

2. 각 연결 관계마다 집합으로 묶어주기

  • 연결한 값마다 부모 노드와 집합의 크기, 트리의 깊이 갱신

3. 모든 관계를 집합으로 묶어준 후 각 집합의 크기를 전부 곱하기

  • 각 집합마다 한명만 찬솔이의 자료를 받으면 됨
  • 누가 자료를 받을지 한명씩 뽑아서 모든 경우의 수 구하기

 

# 성공 코드 1

# python 83768 KB, 548 ms

import sys
input = sys.stdin.readline

# 경로 압축을 적용한 find 함수
def find(x):
    if parent[x] != x:               # 루트 노드가 아니라면
        parent[x] = find(parent[x])  # 루트 노드 찾으면서 경로 압축
    return parent[x]                 # 루트 노드 반환

# 유니온 연산 시, size를 업데이트하여 트리 크기 관리
def union(x, y):
    rootX = find(x) # x의 루트 노드 찾기
    rootY = find(y) # y의 루트 노드 찾기

    # 서로의 루트 노드가 다르면 (다른 집합이면) 연결해주기
    if rootX != rootY:
        if rank[rootX] < rank[rootY]:   # 트리의 높이가 더 작은 쪽(rootX)을 큰 쪽(rootY)에 붙임
            parent[rootX] = rootY
            size[rootY] += size[rootX]  # 합쳐진 집합의 크기 합치기
        elif rank[rootX] > rank[rootY]:
            parent[rootY] = rootX
            size[rootX] += size[rootY]
        else:                           # 트리의 높이가 같으면
            parent[rootY] = rootX       # 아무거나 루트 노드 갱신
            size[rootX] += size[rootY]  # 갱신한 루트 쪽으로 집합의 크기 합치기
            rank[rootX] += 1            # 갱신한 루트 쪽으로 트리 높이 추가

# 입력 처리
N, M = map(int, input().split())
relationship = [list(map(int, input().split())) for _ in range(M)]

# 초기화
parent = [i for i in range(N)]  # 각 노드는 처음에 자기 자신이 루트
rank = [1] * N  # 각 토드가 속한 트리의 높이 (초기에는 1)
size = [1] * N  # 각 노드가 속한 집합의 크기 (초기에는 자기 자신만 있으므로 1)

# 유니온 파인드 적용
for m1, m2 in relationship:
    union(m1-1, m2-1)

# 각 집합의 크기 계산
result = 1
seen = set()

for i in range(N):
    root = find(i)        # i번 노드의 최상위 루트 찾기
    if root not in seen:  # 처음보는 집합이 나오면
        seen.add(root)    # 루트 기록 후 결과 값 갱신
        result = (result * size[root]) % 1000000007

print(result)

 

 

계산을 하고 보니 rank는 필요 없었다.

 

집합의 크기를 알아야하는 문제는 size기반 유니온으로 풀이하고, 집합의 크기를 몰라도 되는 문제는 rank기반으로 트리의 깊이만 최적화하면서 풀면 될 것 같다.

 

 

# 성공 코드 2

# python 42728 KB, 368 ms

import sys
input = sys.stdin.readline

# 경로 압축을 적용한 find 함수
def find(x):
    if parent[x] != x:               # 루트 노드가 아니라면
        parent[x] = find(parent[x])  # 루트 노드 찾으면서 경로 압축
    return parent[x]                 # 루트 노드 반환

# 유니온 연산 시, size를 업데이트하여 트리 크기 관리
def union(x, y):
    rootX = find(x) # x의 루트 노드 찾기
    rootY = find(y) # y의 루트 노드 찾기

    # 서로의 루트 노드가 다르면 (다른 집합이면) 연결해주기
    if rootX != rootY:
        if size[rootX] < size[rootY]:   # 트리의 높이가 더 작은 쪽(rootX)을 큰 쪽(rootY)에 붙임
            parent[rootX] = rootY
            size[rootY] += size[rootX]  # 합쳐진 집합의 크기 합치기
        else:                           # 트리의 높이가 같으면
            parent[rootY] = rootX       # 아무거나 루트 노드 갱신
            size[rootX] += size[rootY]  # 갱신한 루트 쪽으로 집합의 크기 합치기

# 입력 처리
N, M = map(int, input().split())

# 초기화
parent = [i for i in range(N)]  # 각 노드는 처음에 자기 자신이 루트
size = [1] * N  # 각 노드가 속한 집합의 크기 (초기에는 자기 자신만 있으므로 1)

# 유니온 파인드 적용
for _ in range(M):
    m1, m2 = map(int, input().split())
    union(m1-1, m2-1)

# 각 집합의 크기 계산
result = 1

for i in range(N):
    if parent[i] == i:
        result = (result * size[i]) % 1000000007

print(result)