알고리즘/graph

[백준]2887 행성터널 #mst#크루스칼

씩씩한 IT블로그 2020. 6. 14. 16:34
반응형

1. 풀이

일반적인 mst문제이고, N개의 노드를 모두 연결하는 edge를 만들어서(nC2개) 크루스칼을 적용하면 정답은 맞지만 시간초과가 발생한다. (소스코드1)​

시간초과를 줄일 수 있는 방법은 아래의 조건을 이용해서 사용하는 edge의 개수를 줄이는것. (소스코드2)

더보기

행성은 3차원 좌표위의 한 점으로 생각하면 된다. 두 행성 A(xA, yA, zA)와 B(xB, yB, zB)를 터널로 연결할 때 드는 비용은 min(|xA-xB|, |yA-yB|, |zA-zB|)이다.

그리고 위조건을 이용하는 방법은 아래와 같다.

(1) 1차원 좌표에서 N개의 점이 있고, 이때 N개의 점으로 mst를 만드는 방법은 왼쪽부터 i번째 점을 i+1번째 점과 연결해나가는 방법이다. (n-1개)

(2) 3차원에서도 똑같은 방법을 쓰면된다. 단 좌표축이 3개니까

- x축기준으로 i번째와 i+1번째점

- y축기준을 i번째와 i+1번째점

- z축기준으로 i번째와 i+1번째점을 연결한다  => 총 3*(N-1)개

2.소스코드

(1) 소스코드1 (시간초과)

from itertools import combinations
N=int(input())
L=[]
for i in range(N):
    L.append(list(map(int,input().split())))
parent=[i for i in range(N)]

def calCost(A,B):
    return min(abs(A[0]-B[0]),abs(A[1]-B[1]),abs(A[2]-B[2]))

def find(a):
    if parent[a]==a:
        return a
    else:
        parent[a]=find(parent[a])
        return parent[a]

def union(a,b):
    a=find(a)
    b=find(b)
    parent[b]=a

edge=[]
comb=list(combinations([i for i in range(N)],2))
for a,b in comb:
    edge.append([calCost(L[a],L[b]),a,b])
edge.sort()
edgeSize=len(edge)

ans=0
for i in range(edgeSize):
    cost,a,b=edge[i]
    if find(a)!=find(b):
        ans+=cost
        union(a,b)

print(ans)

 

(2) 소스코드2 (통과)

from itertools import combinations
N=int(input())
L=[]
for i in range(N):
    x,y,z=map(int, input().split())
    L.append([x,y,z,i])
parent=[i for i in range(N)]

def returnX(A):
    return A[0]
def returnY(A):
    return A[1]
def returnZ(A):
    return A[2]

def find(a):
    if parent[a]==a:
        return a
    else:
        parent[a]=find(parent[a])
        return parent[a]

def union(a,b):
    a=find(a)
    b=find(b)
    parent[b]=a


edge=[]
L.sort(key=returnX)
for i in range(N-1):
    edge.append([L[i + 1][0] - L[i][0], L[i][3], L[i+1][3]])
L.sort(key=returnY)
for i in range(N-1):
    edge.append([L[i + 1][1] - L[i][1], L[i][3], L[i+1][3]])
L.sort(key=returnZ)
for i in range(N-1):
    edge.append([L[i + 1][2] - L[i][2], L[i][3], L[i+1][3]])

edge.sort()
edgeSize=len(edge)
ans=0
for i in range(edgeSize):
    cost,a,b=edge[i]
    if find(a)!=find(b):
        ans+=cost
        union(a,b)

print(ans)
반응형