https://www.acmicpc.net/problem/2213
트리를 구성하는 정점에 가중치가 있을 때, 인접하지 않은 정점들로만 구성된 정점의 부분집합에 대해 가중치의 합이 최대가 되는 부분 집합을 구하는 문제이다. 그에 더해 이 문제는 이 가중치의 합의 최댓값에 더해, 그 값이 나오도록 하는 부분 집합의 원소까지 구해야 한다.
이 문제는 트리 DP 로 유명하다.
트리는 그 형태가 이미 재귀적이기 때문에 DP 로 표현하기가 매우 좋다.
이 문제의 DP 테이블을 다음과 같이 정의해보자.
DP[i][j] = i 번째 노드를 루트로 하는 트리에 대해, 이 노드를 정점에 포함 (j = 0) 또는 포함하지 않을 때 (j = 1) 의 가중치 합의 최댓값
이제 점화식을 세워보면
만약 i 번째 노드를 계산에 포함한다면, 이 노드의 직속 자식 노드들은 모두 사용할 수 없다.
따라서 DP[i][0] = sum (DP[k][1]) + v[i] (k는 모든 자식 노드의 정점 번호) 이다.
만약 i 번째 노드를 계산에 포함시키지 않는다면, 이 노드의 직속 자식 노드들은 모두 포함하거나 포함하지 않을 수 있다.
각각의 경우에 대해 최댓값을 취하면 되므로
DP[i][1] = sum( max(DP[k][0], DP[k][1]) ) (k는 모든 자식 노드들의 정점 번호) 이다.
이 과정에서 DP 값을 중복해서 구하는 경우가 발생하므로 메모이제이션을 통해 최적화를 해주면 된다.
이 문제의 답을 구할 때는 가중치 합의 최대 뿐만 아니라, 그 최댓값을 구성하는 정점들의 정보도 알아야 한다.
역추적 방법을 고민하기 위해 리프노드에서부터 위로 올라가는 방법을 떠올려보자.
리프노드 2개를 자식으로 갖는 노드가 있다고 해보자.
이 노드에 대해 DP[i][0] 을 구하는 경우, 리프노드 2개를 모두 사용하지 않고 자신만을 사용하므로, 자신의 정점 번호만 집합에 추가하면 된다.
이 노드에 대해 DP[i][1] 을 구하는 경우, 각 리프노드 2개를 포함할 지 말지 결정해서, 최댓값이 포함하는 경우라면 리프노드 정점 번호를 포함하고, 최댓값이 포함하지 않는 경우라면 리프노드 정점 번호를 포함하지 않도록 해서 그 정보를 부모 노드로 올려보낸 뒤, 부모 노드는 그 정보를 합쳐서 그대로 자신의 정점 집합으로 쓰면 된다. 자기 자신읜 포함하지 않기 때문이다.
이 과정을 위로 갈 수록 재귀적으로 구현하면 된다.
import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**5)
USE, NOT_USE = 0, 1
def solve(node, used):
if dp[node][used][0] != 0:
return dp[node][used]
if used == USE:
ans = weights[node]
combi = [node]
for conn_node in graph[node]:
if visit[conn_node]:
continue
visit[conn_node] = True
_ans, _combi = solve(conn_node, NOT_USE)
ans += _ans
combi += _combi
else:
ans = 0
combi = []
for conn_node in graph[node]:
if visit[conn_node]:
continue
visit[conn_node] = True
_ans1, _combi1 = solve(conn_node, USE)
visit[conn_node] = True
_ans2, _combi2 = solve(conn_node, NOT_USE)
if _ans1 > _ans2:
ans += _ans1
combi += _combi1
else:
ans += _ans2
combi += _combi2
dp[node][used] = (ans, tuple(combi))
visit[node] = False
return dp[node][used]
n = int(input())
graph = [[] for _ in range(n+1)]
weights = [0] + list(map(int, input().split()))
visit = [False]*(n+1)
for _ in range(n-1):
s, e = map(int, input().split())
graph[s].append(e)
graph[e].append(s)
dp = [[(0, tuple()), (0, tuple())] for _ in range(n+1)]
visit[1] = True
answer1, combi1 = solve(1, USE)
visit[1] = True
answer2, combi2 = solve(1, NOT_USE)
#
# print(answer1, combi1)
# print(answer2, combi2)
if answer1 > answer2:
print(answer1)
print(*sorted(combi1))
else:
print(answer2)
print(*sorted(combi2))
코드 구현은 위와 같다.
solve 함수는 정점 번호와, 사용 여부를 받아서 그 상태의 최대 가중치 합의 값과 그 때의 노드 구성을 반환하는 함수이다.
'알고리즘 (PS) > BOJ' 카테고리의 다른 글
[백준] 1949 - 우수 마을 (Java) (0) | 2024.11.06 |
---|---|
[백준] 2533 - 사회망 서비스(SNS) (Python) (0) | 2024.11.05 |
[백준] 1135 - 뉴스 전하기 (Java) (0) | 2024.11.04 |
[백준] 13549 - 숨바꼭질 3 (BFS 풀이) (0) | 2024.10.01 |
[백준] 19663 - Mountains (0) | 2024.07.26 |