2) DP - BOJ 거리
위의 링크에서 문제를 꼭 읽어보세요!
예)
idx | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
---|---|---|---|---|---|---|---|---|---|
블럭 | B | O | J | B | O | J | B | O | J |
0 -> 7 -> 8과 같이 한번에 건너갈 수도 있지만,
0 -> 1 -> 2 -> 3 -> 4 -> 5 -> 6 -> 7 -> 8 ->9 처럼 한칸씩만 건널 수도 있습니다.
결과적으로 소모값은 칸수의 제곱에 비례해서 증가하니까 가급적이면 짧게 건너는게 바람직할 것입니다.
시도1 : Greedy
그러면 Greedy하게 가장 가까이에 있는 칸을 계속 건너는 방식을 생각해볼 수 있습니다.
그렇지만, 이런 방식은 최적해에는 도달하지 못합니다.
예)
idx | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
블럭 | B | J | B | O | J | O | J | O | O | J | O | B | O | O | O |
방법1 | 1 | 2 | 3 | 4 | 5 | ?? | |||||||||
방법2 | 1 | 2 | 3 | 4 | 5 | ||||||||||
최적 | 1 | 2 | 3 | 4 | 5 |
방법 1은 Greedy하게 가장 가까이에 있는 블럭을 계속해서 탐색한 방법입니다.
위 방식에 의하면, 11번째 B에서 바로 다음에 O가 있으므로 건너가는데, 그 다음에는 J가 없기 때문에 끝까지 건널 수가 없습니다.
방법 2는 방법 1을 약간 수정하여 다시 11번째 B로 돌아가서 바로 마지막 블럭이 O인 것을 확인하고 건너는 방법입니다.
하지만, 최적해는 전혀 다른 방식으로 건너고 있음을 확인할 수 있습니다.
시도2 : 완전탐색
결국 어떤 보도블록을 건너야 할 지 쉽게 알 수가 없습니다.
답이 딱히 안보이니 모든 경우를 탐색해봅시다.
논리는 간단합니다.
예를 들어서, 3번째 O에서 다음에 건널 J를 선택한다고 해봅시다.
가능한 J는 총 3가지 경우가 있으므로, 이 3가지 경우를 모두 판단하는 것입니다.
3가지 경우를 전부 판단해보고, 이 중에 최솟값을 선택하면 됩니다.
idx | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
블럭 | B | J | B | O | J | O | J | O | O | J | O | B | O | O | O |
1 | 2 | 3 | -> | ||||||||||||
1 | 2 | 3 | -> | ||||||||||||
1 | 2 | 3 | -> |
이런 논리를 재귀함수로 구성하면 다음과 같습니다.
N = int(input())
string = input()
# go(index) : index -> N-1까지 도달하기 위한 에너지의 최솟값
def go(index):
if index == N-1:
return 0
else:
curr = string[index]
ans = -1
# 1. 가능한 후보들(i)을 찾는다.
for i in range(index+1, N):
dest = string[i]
if dest == getnext(curr):
# 2. 이 후보(i)에서 N-1까지 도달할 수 있는 최솟값을 측정한다.
temp = go(i)
if temp != -1:
val = (i-index)*(i-index) + temp
# 3. 만약 도달가능하다면 최솟값인지 확인하고 갱신한다.
if ans == -1 or ans > val:
ans = val
return ans
def getnext(curr):
if curr == "B": return "O"
elif curr == "O": return "J"
return "B"
print(go(0))
이 방식은 시간초과가 발생할 것입니다.
왜냐하면 go(index)를 활용하지 않고 계속적으로 구하기 때문입니다.
예)
idx | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
블럭 | B | J | B | O | J | O | J | O | O | J | O | B | O | O | O |
1 | 2 | 3 | 4 | ||||||||||||
1 | 2 | 3 | 4 | ||||||||||||
1 | 2 | 3 | 4 |
위의 경우에서 생각해보면,
go(3) = 1 + go(4) = 1 + 49 + go(11)
or
= 9 + go(6) = 9 + 25 + go(11)
or
= 36 + go(9) = 36 + 4 + go(11)
이 되는데,
go(3)을 구하기 위해서 똑같은 go(11)을 3번이나 구해야합니다.
예시의 문자열이 적어서 망정이지 N이 1000에 근접해서 go(11) 뒤에 950개의 BOJ들이 있다고 생각하면 막막합니다.
시도3 : Dynamic Programming(DP)
그러면, 해결책은 간단합니다.
go(11)을 어딘가에 저장해놓고 쓰면 됩니다.
go(index)를 저장하고 꺼내서 사용하는 코드 몇줄만 작성하면 훨씬 빨라집니다.
N = int(input())
string = input()
d = [-1 for _ in range(N)]
# d[i] : go(i)의 값을 저장해놓는 배열
# go(i)는 최소 0이기 때문에 아직 저장하지 않았다면 -1로 설정한다.
def go(index):
if index == N - 1:
return 0
else:
# go(index)가 저장되어 있다면 활용하기
if d[index] != -1:
return d[index]
curr = string[index]
ans = -1
for i in range(index + 1, N):
dest = string[i]
if dest == getnext(curr):
temp = go(i)
if temp != -1:
val = (i - index) * (i - index) + temp
if ans == -1 or ans > val:
ans = val
# go(index)가 저장되어 있지 않다면 d에 저장하기.
ans = d[index]
return ans
def getnext(curr):
if curr == "B":
return "O"
elif curr == "O":
return "J"
return "B"
print(go(0))
위 방식에서는 같은 go(index)를 계속해서 계산하지 않기 때문에
시간복잡도는 O(N)에 불과할 것입니다.
정리
왜 이러한 방식이 가능할까요?
idx | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
블럭 | B | J | B | O | J | O | J | O | O | J | O | B | O | O | O |
1 | 2 | 3 | 4 | ||||||||||||
1 | 2 | 3 | 4 | ||||||||||||
1 | 2 | 3 | 4 |
왜냐하면 go(11)의 값이 위 3가지 경우에 대해서 전부 똑같기 때문입니다.
만약 go(11)이 이전 경로에 따라서 변하게 된다면, 부분적으로 문제를 해결할 수가 없을 것입니다.