
문제
N줄에 0 이상 9 이하의 숫자가 세 개씩 적혀 있다. 내려가기 게임을 하고 있는데, 이 게임은 첫 줄에서 시작해서 마지막 줄에서 끝나게 되는 놀이이다.
먼저 처음에 적혀 있는 세 개의 숫자 중에서 하나를 골라서 시작하게 된다. 그리고 다음 줄로 내려가는데, 다음 줄로 내려갈 때에는 다음과 같은 제약 조건이 있다. 바로 아래의 수로 넘어가거나, 아니면 바로 아래의 수와 붙어 있는 수로만 이동할 수 있다는 것이다. 이 제약 조건을 그림으로 나타내어 보면 다음과 같다.

별표는 현재 위치이고, 그 아랫 줄의 파란 동그라미는 원룡이가 다음 줄로 내려갈 수 있는 위치이며, 빨간 가위표는 원룡이가 내려갈 수 없는 위치가 된다. 숫자표가 주어져 있을 때, 얻을 수 있는 최대 점수, 최소 점수를 구하는 프로그램을 작성하시오. 점수는 원룡이가 위치한 곳의 수의 합이다.
입력
첫째 줄에 N(1 ≤ N ≤ 100,000)이 주어진다. 다음 N개의 줄에는 숫자가 세 개씩 주어진다. 숫자는 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 중의 하나가 된다.
출력
첫째 줄에 얻을 수 있는 최대 점수와 최소 점수를 띄어서 출력한다.
문제풀이
되게 쉽다 생각하고, 바로 DP로 풀이했다가 시간초과가 아닌 '메모리 초과'가 발생했다.
메모리 초과가 발생하는건 너무 오랜만이라... 그리고, 코드도 충분히 짧다고 생각했었다.
오답 코드
import sys
N = int(sys.stdin.readline().strip())
table = [list(map(int,sys.stdin.readline().split())) for _ in range(N)]
dp_max = [[0] * 3 for _ in range(N)]
dp_min = [[0] * 3 for _ in range(N)]
dp_max[0][0] = table[0][0]
dp_max[0][1] = table[0][1]
dp_max[0][2] = table[0][2]
dp_min[0][0] = table[0][0]
dp_min[0][1] = table[0][1]
dp_min[0][2] = table[0][2]
for i in range(1,N):
dp_max[i][0] = table[i][0] + max(dp_max[i-1][0], dp_max[i-1][1])
dp_max[i][1] = table[i][1] + max(dp_max[i-1][0], dp_max[i-1][1], dp_max[i-1][2])
dp_max[i][2] = table[i][2] + max(dp_max[i-1][1], dp_max[i-1][2])
for i in range(1,N):
dp_min[i][0] = table[i][0] + min(dp_min[i-1][0], dp_min[i-1][1])
dp_min[i][1] = table[i][1] + min(dp_min[i-1][0], dp_min[i-1][1], dp_min[i-1][2])
dp_min[i][2] = table[i][2] + min(dp_min[i-1][1], dp_min[i-1][2])
print(max(dp_max[N-1]), end=" ")
print(min(dp_min[N-1]), end=" ")왜 틀렸는지를 도저히 모르겠어서, GPT에게 도움을 받았다.
python의 리스트는 포인터이며, 어쩌고 저쩌고...
정말 필요한 만큼의 변수만 선언해서 쓰라는 내용의 답변을 받았다.
그렇기에, DP를 2차원 배열이 아닌, 1차원 배열로 선언해서 풀기로 마음을 먹었다.
정답 코드
import sys
N = int(sys.stdin.readline().strip())
dp = [0 for _ in range(6)]
dp[0], dp[2], dp[4] = map(int,sys.stdin.readline().split())
dp[1] = dp[0]
dp[3] = dp[2]
dp[5] = dp[4]
max_result = 0
min_result = 1e3
for i in range(1,N):
one, two , three = map(int, sys.stdin.readline().split())
one_max = one + max(dp[0], dp[2])
one_min = one + min(dp[1], dp[3])
two_max = two + max(dp[0], dp[2], dp[4])
two_min = two + min(dp[1], dp[3], dp[5])
three_max = three + max(dp[2], dp[4])
three_min = three + min(dp[3], dp[5])
dp[0] = one_max
dp[1] = one_min
dp[2] = two_max
dp[3] = two_min
dp[4] = three_max
dp[5] = three_min
print(max(dp), min(dp))