https://www.acmicpc.net/problem/19663
정말 어려웠던 세그트리 연습문제
처음에는 뭔가 DP일 것 같았고, 실제로 분류에도 DP가 있어서 DP로 풀이를 짜고, 세그트리로 최적화시키는 건가 싶었다.
하지만 아무리 고민해도 DP 점화식이 생각이 안났고, 세그트리로 구현하는 방법만 계속 떠올랐다.
문제 이해
수열이 주어질 때, 주어진 수열에서 인덱스 순에 맞게 3개의 수를 뽑아 만든 임의의 튜플 (x, y, z) 에 대해, x < y, z < y 를 만족하는 모든 튜플 쌍의 개수를 구하는 문제이다.
첫 번째 시도
알고리즘 분류에 있는 '정렬' 에서 힌트를 얻어 먼저 데이터를 높이순으로 정렬한다.
어떤 기준점 y가 정해졌을 때, 이 값보다 작은 값들 중, y보다 왼쪽에 있는 값의 개수, 오른쪽에 있는 값의 개수를 센 뒤,
두 값을 곱한 결과를 정답에 더해주면 된다.
파란색으로 표시한 점이 기준점이라고 하면, 빨간색 부분이 정답을 찾을 후보지다.
나이브하게 찾는다면, 빨간색 영역을 한번 훑어서 O(n)의 시간에 찾을 수 있다.
이 경우, y를 모든 높이에 대해 설정해야 하므로, 전체 시간 복잡도는 O(n²) 이다.
이를 최적화 하기 위해, y보다 작은 높이의 개수를 세그먼트 트리를 통해 빠르게 찾고자 한 것이 첫번째 시도인데,
내가 생각한 아이디어는 y를 최댓값으로 해서, y보다 최댓값이 작은 데이터의 개수를 정렬하기 전 기존의 높이 데이터에 대해 선분 쿼리를 날리는 것이다.
미리 기존 데이터에 대해 최댓값에 대한 세그먼트 트리를 만든 뒤, 최댓값이 y-1 이하인 선분들의 총 길이를 y기준 왼쪽, 오른족에 대해 나눠서 구하면 된다.
시간초과
하지만 이 아이디어는 시간초과를 받았다.
대표적인 반례가 1, 0 으로만 구성된 데이터 셋이다.
길이는 최대 30만이므로, 30만 길이의 0 1 이 반복되는 문자열이 주어진다고 생각해보자.
그러면 주어진 y 값이 1일 때, 그 왼쪽, 오른쪽에서 0의 개수를 찾아야 한다.
이때 만약 0, 1 이 번갈아서만 나온다면, 사실상 선분쿼리는 의미가 없어지고 모든 0을 하나하나 탐색하면서 찾는 것과 다를바 없어진다.
길이가 2인 임의 선분의 최댓값은 항상 1이기 때문이다.
N = int(input())
heights = [0] + list(map(int, input().split()))
tree = [-1 for _ in range(4*N+1)]
def init(node, seg_start, seg_end):
if seg_start == seg_end:
tree[node] = heights[seg_start]
return tree[node]
mid = (seg_start + seg_end) // 2
left = init(node*2, seg_start, mid)
right = init(node*2 + 1, mid+1, seg_end)
tree[node] = max(left, right)
return tree[node]
init(1, 1, N)
def query_max(query_start, query_end, node, seg_start, seg_end):
if seg_end < query_start or query_end < seg_start:
return 0
if query_start <= seg_start <= seg_end <= query_end:
return tree[node]
mid = (seg_start + seg_end) // 2
left = query_max(query_start, query_end, node*2, seg_start, mid)
right = query_max(query_start, query_end, node*2 + 1, mid+1, seg_end)
return max(left, right)
def query_count(max_value, query_start, query_end, node, seg_start, seg_end):
if seg_end < query_start or query_end < seg_start:
return 0
# 현재 보고 있는 세그먼트가 찾는 대상 세그먼트에 포함되면
if query_start <= seg_start <= seg_end <= query_end:
if tree[node] < max_value:
return seg_end - seg_start + 1
if seg_start == seg_end:
if tree[node] >= max_value:
return 0
mid = (seg_start + seg_end) // 2
left = query_count(max_value, query_start, query_end, node*2, seg_start, mid)
right = query_count(max_value, query_start, query_end, node*2 + 1, mid+1, seg_end)
return left + right
answer = 0
for i in range(2, N):
left = query_count(heights[i], 1, i-1, 1, 1, N)
right = query_count(heights[i], i+1, N, 1, 1, N)
answer += left*right
# print(left, right)
print(answer)
이 코드가 해당 코드이다.
세그트리의 업데이트 기능 없이 오직 빠른 쿼리만을 활용하고자 했던 풀이였다.
두 번째 시도
첫번째 시도에서 막힌 뒤, 아이디어가 떠오르지 않아 아이디어를 찾아 인터넷을 헤맸다.
이 문제는 푼 사람이 아주 적은 문제로 (내가 풀 당시 22명..) 한국어 검색으로는 나오지 않을 것 같았다.
그러다 이 문제가 외국의 올림피아드 기출문제라는 점에서 착안하여 혹시 영어로 검색하면 솔루션이 나올까 싶어 해당 대회를 검색했고, 아이디어 솔루션을 기록한 깃 레포를 찾을 수 있었다.
하지만 이 풀이를 봐도 이해가 되지 않았다.
세그트리를 이용하는 것은 똑같은데, 정렬은 하지 않고, 값을 하나 처리할 때마다 세그트리의 업데이트 기능으로 높이를 더해준다고 한다. 높이를 왜 더하는지 모르겠어서 이 풀이를 적용하는 것을 포기했다..
세상엔 똑똑한 사람이 많다는 것을 다시 알아간다..ㅎ
세 번째 시도
이 문제는 신촌지역 연합 알고리즘 캠프의 강의 연습문제로 풀게 된 문제이다.
강의록을 다시 보다보면 힌트가 보이지 않을까 싶었는데, 알고보니 강의록에 문제에 대한 접근 힌트가 있었다.
이 접근 방법은 높이순으로 정렬하는 것은 같은 접근이었다.
다만 그 이후에 처리 방법이 조금 달랐다.
한번 사용한 낮은 높이의 데이터는 이후의 높은 기준점에 대해서도 계속 사용된다.
어떤 배열 P를 현재 보고 있는 데이터가 처리되었다면 1, 아직 처리되지 않았다면 0을 저장하는 배열이라고 해보자.
이 배열의 인덱스는 당연히 정렬하기 전 기준으로 구성된다.
만약 현재 보고 있는 기준 점에 대해서 왼쪽에서 자신보다 낮은 산들은 배열 P에서 1로 처리되었을 것이다!
따라서 배열 P에서 자신보다 왼쪽에 있는 1의 개수, 오른쪽에 있는 1의 개수를 세서 곱하면 된다.
그러면 이 문제는 구간합을 구하는 세그트리 문제로 바뀌게 된다.
한가지 사소한 부분을 신경쓴다면, 중복된 값이 들어올 수 있다는 점이다.
따라서 업데이트는 즉시즉시 하는게 아니라, 중복된 값에 대해서는 업데이트할 예정인 값으로 배열에 저장해뒀다가,
현재 보는 기준점이 직전 기준점보다 높아졌다면 업데이트할 예정인 값들을 모두 업데이트하는 식으로 구현했다.
import sys
input = sys.stdin.readline
N = int(input())
heights = list(map(int, input().split()))
height_index = zip(heights, range(1, len(heights) + 1))
tree = [0 for _ in range(4*N+1)]
def update(update_index, update_value, node, seg_start, seg_end):
if update_index < seg_start or seg_end < update_index:
return tree[node]
if update_index <= seg_start and seg_end <= update_index:
tree[node] = update_value
return tree[node]
mid = (seg_start + seg_end) // 2
left = update(update_index, update_value, node*2, seg_start, mid)
right = update(update_index, update_value, node*2+1, mid + 1, seg_end)
tree[node] = left + right
return tree[node]
def query(query_seg_start, query_seg_end, node, seg_start, seg_end):
if query_seg_start > query_seg_end:
return 0
if query_seg_end < seg_start or seg_end < query_seg_start:
return 0
if query_seg_start <= seg_start and seg_end <= query_seg_end:
return tree[node]
mid = (seg_start + seg_end) // 2
left = query(query_seg_start, query_seg_end, node*2, seg_start, mid)
right = query(query_seg_start, query_seg_end, node*2+1, mid + 1, seg_end)
return left + right
answer = 0
last_value = -1
will_update = []
for value, index in sorted(height_index):
if value > last_value:
for update_index in will_update:
update(update_index, 1, 1, 1, N)
will_update.clear()
last_value = value
will_update.append(index)
left = query(1, index-1, 1, 1, N)
right = query(index+1, N, 1, 1, N)
answer += left*right
print(answer)
약 4초정도로 꽤 긴 시간이 걸렸다.
빠른 시간 안에 푼 사람들 코드를 보니 세그트리를 구현할 때 재귀가 아니라 반복문 방식으로 구현한 것 같았다.
이 부분은 추가로 공부해봐야겠다.
'알고리즘 (PS) > BOJ' 카테고리의 다른 글
[백준] 1135 - 뉴스 전하기 (Java) (0) | 2024.11.04 |
---|---|
[백준] 13549 - 숨바꼭질 3 (BFS 풀이) (0) | 2024.10.01 |
[백준] 11003 - 최솟값 찾기 (Python, 우선순위 큐) (0) | 2024.07.23 |
[백준] 1600 - 말이 되고픈 원숭이 (Python) (2) | 2024.07.14 |
[백준] 11780 - 플로이드 2 (Python) (2) | 2024.07.07 |