이전 글 에서 언급된 문제와 유사한 문제가 오일러 프로젝트에 있어서 같은 방법으로 풀어보고자 한다.
숫자 1406357289는 0-9팬디지털인데, 부분열에 관련된 재미있는 성질을 가지고 있습니다.
d1을 첫째자리수, d2를 둘째자리수...라고 했을 때 다음과 같은 재미있는 사실을 발견할 수 있습니다.
- d2d3d4 = 406 : 2로 나눠짐
- d3d4d5 = 063 : 3으로 나눠짐
- d4d5d6 = 635 : 5로 나눠짐
- d5d6d7 = 357 : 7로 나눠짐
- d6d7d8 = 572 : 11로 나눠짐
- d7d8d9 = 728 : 13으로 나눠짐
- d8d9d10 = 289 : 17로 나눠짐
위와 같은 성질을 갖는 0~9 팬디지털을 모두 찾아서 그 합을 구하면 얼마입니까?
(출처 : 오일러프로젝트 43 - http://euler.synap.co.kr/prob_detail.php?id=43)
이전과 같은 방식으로 앞에서부터 숫자를 붙여나간다고 생각할 때 4자리가되는 시점부터 해당 수의 끝 세자리는 (n-3)번째 소수로 나눠진다.
팬디지털 숫자이기 때문에 0-9의 숫자들의 순열을 만들어서 점검해보려해도 이는 10!가지나 되기 때문에 그다지 좋은 아이디어가 아니고, 앞에서부터 조건을 만족하는지를 평가하면서 답을 구성할 수 없는 조합을 빠르게 치워나가는 것이 중요하다.
재귀함수로 이 문제를 풀어보자면, 첫번째 자리를 찾는 깊이를 0이라 할 때 레벨 9까지 진행하고, 레벨 3부터는 2, 3, 5... 의 소수로 나눠지는 조건을 만족해야함을 알 수 있다.
따라서 재귀함수로 구현된 풀이는 다음과 같아야 한다.
primes = (1,1,1,2,3,5,7,11,13,17)
def check(n:int, level:int, s:set) -> [int]:
if level > 9:
return [n]
result = []
### 맨 앞자리에 0이 오는 것을 방지하기 위해 범위는 레벨에 따라
### 구분한다.
r = range(10) if level > 0 else range(1, 10)
for x in (x for x in r if x not in s):
m = n * 10 + x
if (m % 1000) % primes[level] != 0:
continue
s.add(x)
temp = check(m, level+1, s)
if temp:
result += temp
return result
%time print(sum(check(0, 0, set())))
## 16695334890
## Wall time: 83ms
다른 풀이
참고로 이 문제는 리스트 축약(list comprehension)으로도 풀 수 있다. 그다지 예쁜 풀이는 아니지만, 각 단계별로 조건을 만들어서 검사하는 것이다. 이런식으로 엄청나게 깊이 중첩된 리스트 축약을 사용하는 것도 가능하다는 것을 보여주는 예시라 보면 되겠다.
심지어 이 풀이는 함수 호출에 드는 비용이 없기 때문에 더 빠르기 까지 하다.
from functools import reduce
def solve():
nums = [(d1, d2, d3, d4, d5, d6, d7, d8, d9, d10)\
for d1 in range(9)\
## d2는 d1과 같지 않음
for d2 in range(10) if d2 != d1 \
## d3는 d2, d1과 같지 않음
for d3 in range(10) if d3 not in (d1, d2) \
## d4는 짝수만 가능
for d4 in range(0, 10, 2) \
if d4 not in (d1, d2, d3) \
## d5는 d1~d4와 다른 숫자이면서
for d5 in range(10) \
if d5 not in (d1, d2, d3, d4) \
### d3d4d5가 3의 배수일 것
and (d3 * 100 + d4 * 10 + d5) % 3 == 0\
## d6는 5의 배수의 끝자리이므로 0혹은 5만 가능
for d6 in (0, 5) \
if d6 not in (d1, d2, d3, d4, d5)\
## 이후 같은 식으로 조건에 맞는 중첩되지 않는 숫자만 고름
for d7 in range(10) \
if d7 not in (d1, d2, d3, d4, d5, d6)\
and (d5 * 100 + d6 * 10 + d7) % 7 == 0\
for d8 in range(10)\
if d8 not in (d1, d2, d3, d4, d5, d6, d7)\
and (d6 * 100 + d7 * 10 + d8) % 11 == 0\
for d9 in range(10)\
if d9 not in (d1, d2, d3, d4, d5, d6, d7, d8)\
and (d7 * 100 + d8 * 10 + d9) % 13 == 0\
for d10 in range(10)\
if d10 not in (d1, d2, d3, d4, d5, d6, d7, d8, d9)\
and (d8 * 100 + d9 * 10 + d10) % 17 == 0\
]
## 모여진 결과는 정수튜플이므로 하나의 정수로 합치고
## 그 결과를 합산
result = sum(reduce(lambda x,y: x*10+y, d) for d in nums)
print(result)
%time solve()
## 16695334890
## Wall time: 68 ms