본문으로 건너뛰기
Advertisement

5.5 재귀와 메모이제이션 — 피보나치, 트리 순회, lru_cache

재귀(Recursion)는 함수가 자기 자신을 호출하는 기법으로, 분할 정복(Divide and Conquer), 트리 탐색, 수학적 정의를 코드로 직접 표현할 때 강력합니다. 메모이제이션과 결합하면 성능 문제도 해결할 수 있습니다.


재귀의 기본 구조

재귀 함수는 반드시 두 부분으로 구성됩니다:

  1. 기저 사례(Base Case): 재귀를 멈추는 조건
  2. 재귀 호출(Recursive Case): 자기 자신을 더 작은 문제로 호출
# 팩토리얼: n! = n × (n-1)!
def factorial(n: int) -> int:
# 기저 사례
if n <= 1:
return 1
# 재귀 호출
return n * factorial(n - 1)


print(factorial(5)) # 120 (5 × 4 × 3 × 2 × 1)
print(factorial(10)) # 3628800


# 재귀 호출 트레이싱
def factorial_trace(n: int, depth: int = 0) -> int:
indent = " " * depth
print(f"{indent}factorial({n}) 호출")
if n <= 1:
print(f"{indent}→ 반환: 1 (기저 사례)")
return 1
result = n * factorial_trace(n - 1, depth + 1)
print(f"{indent}→ 반환: {n} × ... = {result}")
return result


factorial_trace(4)
# factorial(4) 호출
# factorial(3) 호출
# factorial(2) 호출
# factorial(1) 호출
# → 반환: 1 (기저 사례)
# → 반환: 2 × ... = 2
# → 반환: 3 × ... = 6
# → 반환: 4 × ... = 24

콜 스택 이해

import sys

# 기본 재귀 한계
print(sys.getrecursionlimit()) # 보통 1000

# 큰 재귀는 RecursionError 발생
def count_down(n: int) -> None:
if n == 0:
return
count_down(n - 1)

try:
count_down(5000) # 기본 한계 초과
except RecursionError as e:
print(f"RecursionError: {e}")

# 한계 늘리기 (신중하게!)
sys.setrecursionlimit(10000)
count_down(5000) # 이제 성공

# 실제 스택 프레임 확인
import traceback

def recursive_func(n):
if n == 0:
print(f"현재 스택 깊이 (traceback): {len(traceback.extract_stack())}")
return
recursive_func(n - 1)

recursive_func(10)

피보나치: 순진한 재귀의 문제

# 순진한 재귀 구현 (매우 느림!)
def fib_naive(n: int) -> int:
if n <= 1:
return n
return fib_naive(n - 1) + fib_naive(n - 2)

# fib(5) = fib(4) + fib(3)
# = (fib(3) + fib(2)) + (fib(2) + fib(1))
# = ... 지수적으로 증가!

import time

start = time.time()
print(fib_naive(30)) # 832040 (느림)
print(f"시간: {time.time() - start:.3f}초")
# 약 0.3초

# fib(35) → 약 3초
# fib(40) → 약 30초 (매우 느림)

functools.lru_cache: 메모이제이션

from functools import lru_cache, cache
import time

# lru_cache 적용
@lru_cache(maxsize=None) # maxsize=None이면 무제한 캐시
def fib_cached(n: int) -> int:
if n <= 1:
return n
return fib_cached(n - 1) + fib_cached(n - 2)

start = time.time()
print(fib_cached(100)) # 354224848179261915075
print(f"시간: {time.time() - start:.6f}초") # 거의 0초!

# 캐시 정보 확인
print(fib_cached.cache_info())
# CacheInfo(hits=98, misses=101, maxsize=None, currsize=101)

# 캐시 초기화
fib_cached.cache_clear()
print(fib_cached.cache_info())
# CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)


# Python 3.9+: functools.cache (maxsize=None의 단순 버전)
@cache
def fib_fast(n: int) -> int:
if n <= 1:
return n
return fib_fast(n - 1) + fib_fast(n - 2)

print(fib_fast(200)) # 280571172992510140037611932413038677189525

꼬리 재귀와 반복으로의 전환

Python은 꼬리 재귀 최적화(TCO)를 지원하지 않습니다. 깊은 재귀는 항상 RecursionError 위험이 있으므로, 반복으로 전환하는 것이 좋습니다.

# 꼬리 재귀 버전 (Python에서는 여전히 스택 쌓임)
def factorial_tail(n: int, acc: int = 1) -> int:
if n <= 1:
return acc
return factorial_tail(n - 1, acc * n)

# 반복으로 전환 (권장)
def factorial_iter(n: int) -> int:
result = 1
for i in range(2, n + 1):
result *= i
return result

print(factorial_tail(10)) # 3628800
print(factorial_iter(10)) # 3628800

# 피보나치 반복 버전
def fib_iter(n: int) -> int:
if n <= 1:
return n
a, b = 0, 1
for _ in range(2, n + 1):
a, b = b, a + b
return b

print(fib_iter(100)) # 354224848179261915075 — 즉시 계산!

고전 예제: 하노이 탑

def hanoi(n: int, source: str, target: str, auxiliary: str, moves: list = None) -> list:
"""n개 디스크를 source에서 target으로 이동하는 최소 이동 순서"""
if moves is None:
moves = []

if n == 1:
moves.append((source, target))
return moves

# 1. (n-1)개를 auxiliary로 이동
hanoi(n - 1, source, auxiliary, target, moves)
# 2. 가장 큰 디스크를 target으로 이동
moves.append((source, target))
# 3. (n-1)개를 auxiliary에서 target으로 이동
hanoi(n - 1, auxiliary, target, source, moves)

return moves


moves = hanoi(3, "A", "C", "B")
for i, (src, tgt) in enumerate(moves, 1):
print(f" {i}: {src}{tgt}")

print(f"총 이동 횟수: {len(moves)}") # 7 (2^3 - 1)

# 수학적 검증: n개 디스크의 최소 이동 횟수 = 2^n - 1
for n in range(1, 8):
moves = hanoi(n, "A", "C", "B")
expected = 2**n - 1
assert len(moves) == expected
print(f" n={n}: {len(moves)}회 (예상: {expected})")

트리 순회

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class TreeNode:
value: int
left: Optional["TreeNode"] = None
right: Optional["TreeNode"] = None


def insert(root: Optional[TreeNode], value: int) -> TreeNode:
"""BST에 값 삽입"""
if root is None:
return TreeNode(value)
if value < root.value:
root.left = insert(root.left, value)
elif value > root.value:
root.right = insert(root.right, value)
return root


# BST 구성
root = None
for v in [5, 3, 7, 1, 4, 6, 8]:
root = insert(root, v)


# 재귀 트리 순회
def inorder(node: Optional[TreeNode]) -> list[int]:
"""중위 순회: Left → Root → Right (BST에서 정렬 순서)"""
if node is None:
return []
return inorder(node.left) + [node.value] + inorder(node.right)

def preorder(node: Optional[TreeNode]) -> list[int]:
"""전위 순회: Root → Left → Right"""
if node is None:
return []
return [node.value] + preorder(node.left) + preorder(node.right)

def postorder(node: Optional[TreeNode]) -> list[int]:
"""후위 순회: Left → Right → Root"""
if node is None:
return []
return postorder(node.left) + postorder(node.right) + [node.value]


print("중위:", inorder(root)) # [1, 3, 4, 5, 6, 7, 8]
print("전위:", preorder(root)) # [5, 3, 1, 4, 7, 6, 8]
print("후위:", postorder(root)) # [1, 4, 3, 6, 8, 7, 5]


# 트리 높이 계산
def tree_height(node: Optional[TreeNode]) -> int:
if node is None:
return 0
return 1 + max(tree_height(node.left), tree_height(node.right))

print("트리 높이:", tree_height(root)) # 3

이진 탐색 (재귀 구현)

def binary_search(arr: list[int], target: int, left: int = 0, right: int = None) -> int:
"""정렬된 리스트에서 target의 인덱스 반환, 없으면 -1"""
if right is None:
right = len(arr) - 1

# 기저 사례: 탐색 범위가 없음
if left > right:
return -1

mid = (left + right) // 2

if arr[mid] == target:
return mid
elif arr[mid] < target:
return binary_search(arr, target, mid + 1, right)
else:
return binary_search(arr, target, left, mid - 1)


sorted_arr = list(range(0, 100, 2)) # [0, 2, 4, ..., 98]
print(binary_search(sorted_arr, 42)) # 21
print(binary_search(sorted_arr, 43)) # -1

실전 예제 1: 파일 시스템 탐색

import os
from pathlib import Path


def find_files(directory: str, extension: str) -> list[str]:
"""재귀적으로 특정 확장자 파일 찾기"""
result = []
path = Path(directory)

for item in path.iterdir():
if item.is_dir():
# 재귀: 하위 디렉토리 탐색
result.extend(find_files(str(item), extension))
elif item.is_file() and item.suffix == extension:
result.append(str(item))

return result


def directory_size(path: str) -> int:
"""디렉토리 전체 크기 계산 (바이트)"""
total = 0
p = Path(path)

if p.is_file():
return p.stat().st_size

for item in p.iterdir():
if item.is_dir():
total += directory_size(str(item)) # 재귀
else:
total += item.stat().st_size

return total


def print_tree(path: str, prefix: str = "", is_last: bool = True) -> None:
"""디렉토리 트리 출력"""
p = Path(path)
connector = "└── " if is_last else "├── "
print(prefix + connector + p.name)

if p.is_dir():
items = sorted(p.iterdir())
for i, item in enumerate(items):
is_last_item = i == len(items) - 1
extension = " " if is_last else "│ "
print_tree(str(item), prefix + extension, is_last_item)


# 현재 디렉토리 구조 출력 (예시)
# print_tree(".")

실전 예제 2: JSON 깊은 탐색

from typing import Any


def flatten_dict(data: dict, prefix: str = "", separator: str = ".") -> dict[str, Any]:
"""중첩 딕셔너리를 평탄화"""
result = {}

for key, value in data.items():
new_key = f"{prefix}{separator}{key}" if prefix else key

if isinstance(value, dict):
# 재귀: 중첩 딕셔너리
nested = flatten_dict(value, new_key, separator)
result.update(nested)
else:
result[new_key] = value

return result


config = {
"database": {
"host": "localhost",
"port": 5432,
"credentials": {
"username": "admin",
"password": "secret"
}
},
"cache": {
"ttl": 300,
"max_size": 1000
},
"debug": True
}

flat = flatten_dict(config)
for key, value in flat.items():
print(f" {key}: {value}")
# database.host: localhost
# database.port: 5432
# database.credentials.username: admin
# database.credentials.password: secret
# cache.ttl: 300
# cache.max_size: 1000
# debug: True


def deep_get(data: dict | list, path: str, default: Any = None) -> Any:
"""점 표기법으로 중첩 값 접근: deep_get(data, "a.b.c")"""
keys = path.split(".")
current = data

for key in keys:
if isinstance(current, dict):
current = current.get(key)
elif isinstance(current, list):
try:
current = current[int(key)]
except (IndexError, ValueError):
return default
else:
return default

if current is None:
return default

return current


print(deep_get(config, "database.credentials.username")) # admin
print(deep_get(config, "cache.ttl")) # 300
print(deep_get(config, "missing.key", "default")) # default

고수 팁

1. lru_cache 사용 시 주의사항

from functools import lru_cache

# lru_cache는 인수가 hashable해야 함
@lru_cache(maxsize=128)
def process(n: int, mode: str) -> int:
return n * 2 if mode == "double" else n

# 리스트는 unhashable — TypeError!
# @lru_cache
# def bad(data: list): ...

# 해결책: tuple로 변환
@lru_cache(maxsize=128)
def process_tuple(data: tuple[int, ...]) -> int:
return sum(data)

print(process_tuple((1, 2, 3))) # 6

# 래퍼로 리스트를 튜플로 변환
def process_list(data: list[int]) -> int:
return process_tuple(tuple(data))

2. 재귀 vs 반복 선택 기준

# 재귀가 적합한 경우:
# - 자연스럽게 재귀적 구조인 문제 (트리, 그래프)
# - 깊이가 크지 않음 (Python 기본 한계 1000)
# - 코드 명확성이 우선

# 반복이 더 나은 경우:
# - 깊이가 깊거나 예측 불가
# - 성능이 중요
# - 꼬리 재귀 패턴

# 재귀를 스택 + 반복으로 변환
def inorder_iterative(root) -> list[int]:
"""중위 순회를 스택으로 구현 (재귀 없음)"""
result = []
stack = []
current = root

while current or stack:
while current:
stack.append(current)
current = current.left
current = stack.pop()
result.append(current.value)
current = current.right

return result

3. 무한 재귀 방지 패턴

def safe_recursive(data, visited: set = None) -> list:
"""순환 참조가 있는 구조에서 안전한 재귀"""
if visited is None:
visited = set()

obj_id = id(data)
if obj_id in visited:
return ["<circular reference>"]

visited.add(obj_id)

if isinstance(data, dict):
return [safe_recursive(v, visited) for v in data.values()]
elif isinstance(data, list):
return [safe_recursive(item, visited) for item in data]
else:
return [data]


# 순환 참조 테스트
a = [1, 2]
b = [3, a]
a.append(b) # a → b → a (순환)

result = safe_recursive(a)
print(result) # 순환 참조 부분은 '<circular reference>'로 표시
Advertisement