Skip to main content
Advertisement

5.5 Recursion and Memoization — Fibonacci, Tree Traversal, lru_cache

Recursion is a technique in which a function calls itself. It is well suited to divide-and-conquer problems, tree traversal, and expressing mathematical definitions directly in code. Combined with memoization (caching results that have already been computed), it also avoids the cost of recomputing the same subproblems over and over.


Basic Structure of Recursion

A recursive function must have two parts:

  1. Base Case: The condition that stops the recursion
  2. Recursive Case: The function calls itself on a smaller subproblem, moving one step closer to the base case
# Factorial: n! = n × (n-1)!
def factorial(n: int) -> int:
    """Return n! computed recursively.

    Raises:
        ValueError: if n is negative. Factorial is undefined there, and without
            this check the ``n <= 1`` base case would silently return 1.
    """
    if n < 0:
        raise ValueError("factorial() not defined for negative values")
    # Base case: 0! == 1! == 1
    if n <= 1:
        return 1
    # Recursive case: solve the one-smaller problem and scale it by n
    return n * factorial(n - 1)


print(factorial(5))  # 120 (5 × 4 × 3 × 2 × 1)
print(factorial(10))  # 3628800


# Trace the recursive calls
def factorial_trace(n: int, depth: int = 0) -> int:
    """Compute n! recursively, logging every call and every return.

    Each log line is indented by one space per recursion level (``depth``), so
    the printout mirrors the shape of the call stack.
    """
    pad = " " * depth
    print(f"{pad}factorial({n}) called")
    if n > 1:
        product = n * factorial_trace(n - 1, depth + 1)
        print(f"{pad}→ returning: {n} × ... = {product}")
        return product
    # Nothing left to multiply: this is the base case.
    print(f"{pad}→ returning: 1 (base case)")
    return 1


factorial_trace(4)
# factorial(4) called
#  factorial(3) called
#   factorial(2) called
#    factorial(1) called
#    → returning: 1 (base case)
#   → returning: 2 × ... = 2
#  → returning: 3 × ... = 6
# → returning: 4 × ... = 24

Understanding the Call Stack

import sys

# Default recursion limit
print(sys.getrecursionlimit())  # Usually 1000


def count_down(n: int) -> None:
    """Recurse n levels deep and then unwind; exists only to consume stack frames."""
    if n != 0:
        count_down(n - 1)


# Deep recursion causes RecursionError
try:
    count_down(5000)  # Exceeds default limit
except RecursionError as e:
    print(f"RecursionError: {e}")

# Increase the limit (use with care!). The setting is process-wide and stays in
# effect for everything that runs afterwards; a limit larger than the C stack
# can actually hold may crash the interpreter instead of raising RecursionError.
sys.setrecursionlimit(10000)
count_down(5000)  # Now succeeds

# Check current stack depth
import traceback


def recursive_func(n):
    """Recurse down to n == 0, then print how many frames are on the call stack."""
    if n != 0:
        recursive_func(n - 1)
        return
    print(f"Current stack depth: {len(traceback.extract_stack())}")


recursive_func(10)

Fibonacci: The Problem with Naive Recursion

# Naive recursive implementation (very slow!)
def fib_naive(n: int) -> int:
    """Return the n-th Fibonacci number by direct recursion, with no caching.

    Every call above the base case spawns two more calls, so the total number
    of calls grows exponentially with n.
    """
    return n if n <= 1 else fib_naive(n - 1) + fib_naive(n - 2)


# fib(5) = fib(4) + fib(3)
#        = (fib(3) + fib(2)) + (fib(2) + fib(1))
#        = ... grows exponentially!

import time

# time.perf_counter() is a monotonic, high-resolution clock intended for
# measuring elapsed time; time.time() follows the wall clock, which can be
# adjusted (even backwards) and may have coarser resolution.
start = time.perf_counter()
print(fib_naive(30))  # 832040 (slow)
print(f"Time: {time.perf_counter() - start:.3f}s")
# ~0.3 seconds (exact timings vary by machine)

# fib(35) → ~3 seconds
# fib(40) → ~30 seconds (very slow)

functools.lru_cache: Memoization

from functools import lru_cache, cache
import time

# Apply lru_cache
@lru_cache(maxsize=None)  # maxsize=None means unlimited cache
def fib_cached(n: int) -> int:
    """Return the n-th Fibonacci number, memoized so each distinct n is computed once."""
    if n <= 1:
        return n
    return fib_cached(n - 1) + fib_cached(n - 2)


# Use perf_counter() (monotonic, high resolution) rather than time.time() for
# measuring short durations like this one.
start = time.perf_counter()
print(fib_cached(100))  # 354224848179261915075
print(f"Time: {time.perf_counter() - start:.6f}s")  # Nearly instant!

# Check cache info
print(fib_cached.cache_info())
# CacheInfo(hits=98, misses=101, maxsize=None, currsize=101)

# Clear cache
fib_cached.cache_clear()
print(fib_cached.cache_info())
# CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)


# Python 3.9+: functools.cache is shorthand for lru_cache(maxsize=None)
@cache
def fib_fast(n: int) -> int:
    """Return the n-th Fibonacci number, memoized with an unbounded cache."""
    return n if n <= 1 else fib_fast(n - 1) + fib_fast(n - 2)


print(fib_fast(200))  # 280571172992510140037611932413038677189525

Tail Recursion and Converting to Iteration

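CPython does not perform tail call optimization (TCO), so even a tail-recursive function consumes one stack frame per call. Deep recursion therefore risks RecursionError, and converting to iteration is recommended whenever the depth can be large or unpredictable.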

# Tail recursive version (stack still grows in Python)
def factorial_tail(n: int, acc: int = 1) -> int:
    """Return acc * n! using an accumulator-passing (tail-recursive) style.

    The recursive call is the last operation, but CPython still pushes a new
    frame for it, so depth is limited by the recursion limit.

    Raises:
        ValueError: if n is negative (the ``n <= 1`` base case would otherwise
            silently return acc).
    """
    if n < 0:
        raise ValueError("factorial() not defined for negative values")
    if n <= 1:
        return acc
    return factorial_tail(n - 1, acc * n)


# Iterative version (recommended)
def factorial_iter(n: int) -> int:
    """Return n! with a simple loop; no recursion, so no depth limit.

    Raises:
        ValueError: if n is negative (the empty loop would otherwise return 1).
    """
    if n < 0:
        raise ValueError("factorial() not defined for negative values")
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result


print(factorial_tail(10))  # 3628800
print(factorial_iter(10))  # 3628800

# Fibonacci iterative version
def fib_iter(n: int) -> int:
    """Return the n-th Fibonacci number by walking the sequence forward.

    Inputs n <= 1 are returned unchanged; otherwise a pair of consecutive
    terms is advanced n - 1 times, using O(1) memory and no recursion.
    """
    if n <= 1:
        return n
    prev, curr = 0, 1
    for _ in range(n - 1):
        prev, curr = curr, prev + curr
    return curr


print(fib_iter(100))  # 354224848179261915075 — computed instantly!

Classic Example: Tower of Hanoi

def hanoi(n: int, source: str, target: str, auxiliary: str, moves: list = None) -> list:
"""Returns the minimum sequence of moves to move n disks from source to target."""
if moves is None:
moves = []

if n == 1:
moves.append((source, target))
return moves

# 1. Move (n-1) disks to auxiliary
hanoi(n - 1, source, auxiliary, target, moves)
# 2. Move the largest disk to target
moves.append((source, target))
# 3. Move (n-1) disks from auxiliary to target
hanoi(n - 1, auxiliary, target, source, moves)

return moves


moves = hanoi(3, "A", "C", "B")
for i, (src, tgt) in enumerate(moves, 1):
print(f" {i}: {src}{tgt}")

print(f"Total moves: {len(moves)}") # 7 (2^3 - 1)

# Mathematical verification: minimum moves for n disks = 2^n - 1
for n in range(1, 8):
moves = hanoi(n, "A", "C", "B")
expected = 2**n - 1
assert len(moves) == expected
print(f" n={n}: {len(moves)} moves (expected: {expected})")

Tree Traversal

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class TreeNode:
    """A binary-tree node: an int payload plus optional left/right children."""

    value: int
    left: Optional["TreeNode"] = None  # in a BST built by insert(): smaller values
    right: Optional["TreeNode"] = None  # in a BST built by insert(): larger values


def insert(root: Optional[TreeNode], value: int) -> TreeNode:
    """Add *value* to the BST rooted at *root* and return the (possibly new) root.

    A value equal to an existing node's value is ignored, so the tree never
    holds duplicates.
    """
    if root is None:
        return TreeNode(value)
    if value < root.value:
        root.left = insert(root.left, value)
    elif value > root.value:
        root.right = insert(root.right, value)
    # equal value: already present, leave the tree unchanged
    return root


# Build a BST
root = None
for v in (5, 3, 7, 1, 4, 6, 8):
    root = insert(root, v)


# Recursive tree traversals
def inorder(node: Optional[TreeNode]) -> list[int]:
    """Visit left subtree, then node, then right subtree (ascending order for a BST)."""
    if node is None:
        return []
    return [*inorder(node.left), node.value, *inorder(node.right)]


def preorder(node: Optional[TreeNode]) -> list[int]:
    """Visit node first, then its left subtree, then its right subtree."""
    if node is None:
        return []
    return [node.value, *preorder(node.left), *preorder(node.right)]


def postorder(node: Optional[TreeNode]) -> list[int]:
    """Visit left subtree, then right subtree, and the node itself last."""
    if node is None:
        return []
    return [*postorder(node.left), *postorder(node.right), node.value]


print("In-order:", inorder(root))  # [1, 3, 4, 5, 6, 7, 8]
print("Pre-order:", preorder(root))  # [5, 3, 1, 4, 7, 6, 8]
print("Post-order:", postorder(root))  # [1, 4, 3, 6, 8, 7, 5]


# Tree height
def tree_height(node: Optional[TreeNode]) -> int:
    """Return the node count on the longest root-to-leaf path (0 for an empty tree)."""
    if node is None:
        return 0
    left_height = tree_height(node.left)
    right_height = tree_height(node.right)
    return max(left_height, right_height) + 1


print("Tree height:", tree_height(root))  # 3

Binary Search (Recursive)

def binary_search(arr: list[int], target: int, left: int = 0, right: int = None) -> int:
    """Recursively locate *target* in the ascending list *arr*.

    Searches the inclusive index window ``[left, right]``; omitting *right*
    means "through the last element". Returns the index of a matching element,
    or -1 once the window is empty without a match.
    """
    hi = len(arr) - 1 if right is None else right
    if left > hi:  # empty window: target is not present
        return -1

    mid = (left + hi) // 2
    probe = arr[mid]
    if probe == target:
        return mid
    # Drop the half that cannot contain target and recurse on the other one.
    if probe < target:
        return binary_search(arr, target, mid + 1, hi)
    return binary_search(arr, target, left, mid - 1)


sorted_arr = list(range(0, 100, 2))  # [0, 2, 4, ..., 98]
print(binary_search(sorted_arr, 42))  # 21
print(binary_search(sorted_arr, 43))  # -1

Real-world Example 1: File System Traversal

import os
from pathlib import Path


def find_files(directory: str, extension: str) -> list[str]:
    """Return string paths of every file under *directory* whose suffix is *extension*.

    *extension* is compared verbatim with ``Path.suffix``, so it must include
    the leading dot (e.g. ``".py"``). Subdirectories are searched recursively,
    and results appear in directory-iteration order.
    """
    matches: list[str] = []
    for entry in Path(directory).iterdir():
        if entry.is_dir():
            matches += find_files(str(entry), extension)
        elif entry.is_file() and entry.suffix == extension:
            matches.append(str(entry))
    return matches


def directory_size(path: str) -> int:
    """Return the total size in bytes of *path*.

    A file contributes its own ``st_size``; a directory contributes the
    recursive sum over everything beneath it (0 when empty).
    """
    target = Path(path)
    if target.is_file():
        return target.stat().st_size
    return sum(
        directory_size(str(child)) if child.is_dir() else child.stat().st_size
        for child in target.iterdir()
    )


def print_tree(path: str, prefix: str = "", is_last: bool = True) -> None:
    """Print *path* and, recursively, everything beneath it as an ASCII tree.

    Directory entries are listed in sorted order.

    Args:
        path: File or directory to print.
        prefix: Guide-line text accumulated from ancestor levels; callers
            normally omit it.
        is_last: Whether *path* is the last entry among its siblings. It selects
            the ``└──`` vs ``├──`` connector and whether descendants get a
            continuing ``│`` guide line.
    """
    p = Path(path)
    connector = "└── " if is_last else "├── "
    print(prefix + connector + p.name)

    if p.is_dir():
        items = sorted(p.iterdir())
        # The prefix handed to children must be exactly as wide as the
        # 4-character connector printed above; otherwise each deeper level
        # drifts out of alignment with its parent.
        extension = "    " if is_last else "│   "
        for i, item in enumerate(items):
            print_tree(str(item), prefix + extension, i == len(items) - 1)

Real-world Example 2: Deep JSON Traversal

from typing import Any


def flatten_dict(data: dict, prefix: str = "", separator: str = ".") -> dict[str, Any]:
    """Collapse nested dicts into a single level keyed by joined key paths.

    Keys at each level are joined with *separator* (``{"a": {"b": 1}}`` becomes
    ``{"a.b": 1}``), prefixed by *prefix* when it is non-empty. Non-dict values,
    including lists, are kept as-is; an empty nested dict contributes no keys.
    """
    flat: dict[str, Any] = {}
    for key, value in data.items():
        full_key = key if not prefix else f"{prefix}{separator}{key}"
        if not isinstance(value, dict):
            flat[full_key] = value
            continue
        flat.update(flatten_dict(value, full_key, separator))
    return flat


config = {
    "database": {
        "host": "localhost",
        "port": 5432,
        "credentials": {
            "username": "admin",
            "password": "secret",
        },
    },
    "cache": {
        "ttl": 300,
        "max_size": 1000,
    },
    "debug": True,
}

flat = flatten_dict(config)
for key, value in flat.items():
    print(f" {key}: {value}")
# database.host: localhost
# database.port: 5432
# database.credentials.username: admin
# database.credentials.password: secret
# cache.ttl: 300
# cache.max_size: 1000
# debug: True


def deep_get(data: dict | list, path: str, default: Any = None) -> Any:
"""Access nested values using dot notation: deep_get(data, 'a.b.c')"""
keys = path.split(".")
current = data

for key in keys:
if isinstance(current, dict):
current = current.get(key)
elif isinstance(current, list):
try:
current = current[int(key)]
except (IndexError, ValueError):
return default
else:
return default

if current is None:
return default

return current


# Look up each dotted path; the third one does not exist, so its fallback is returned.
for dotted_path, fallback in (
    ("database.credentials.username", None),  # admin
    ("cache.ttl", None),  # 300
    ("missing.key", "default"),  # default
):
    print(deep_get(config, dotted_path, fallback))

Pro Tips

1. lru_cache Requires Hashable Arguments

from functools import lru_cache

# lru_cache requires hashable arguments: they are used to build the cache key.
@lru_cache(maxsize=128)
def process(n: int, mode: str) -> int:
    """Return n doubled when mode is "double", otherwise n unchanged (cached)."""
    if mode == "double":
        return n * 2
    return n


# Lists are unhashable — calling a cached function with one raises TypeError!
# @lru_cache
# def bad(data: list): ...


# Solution: convert to tuple
@lru_cache(maxsize=128)
def process_tuple(data: tuple[int, ...]) -> int:
    """Sum a tuple of ints; tuples are hashable, so results can be cached."""
    total = sum(data)
    return total


print(process_tuple((1, 2, 3)))  # 6


# Wrapper to accept lists
def process_list(data: list[int]) -> int:
    """Accept a list by freezing it into a tuple before hitting the cache."""
    frozen = tuple(data)
    return process_tuple(frozen)

2. When to Use Recursion vs Iteration

# Recursion is appropriate when:
# - The problem is naturally recursive (trees, graphs)
# - Depth is small (Python default limit is 1000)
# - Code clarity is the priority

# Iteration is better when:
# - Depth is deep or unpredictable
# - Performance is critical
# - Tail recursive pattern

# Convert recursion to stack + iteration
def inorder_iterative(root: "Optional[TreeNode]") -> list[int]:
    """In-order traversal (left → node → right) using an explicit stack.

    Produces the same list as the recursive ``inorder`` but never recurses, so
    arbitrarily deep (e.g. degenerate, list-like) trees do not hit the
    recursion limit. Any object with ``value``, ``left`` and ``right``
    attributes works as a node; missing children must be ``None``.
    """
    result: list[int] = []
    stack = []
    current = root

    # Compare against None explicitly: relying on truthiness would silently
    # skip node objects that define __len__ or __bool__ and evaluate as falsy.
    while current is not None or stack:
        # Descend as far left as possible, remembering the path.
        while current is not None:
            stack.append(current)
            current = current.left
        # The leftmost unvisited node is next in order.
        current = stack.pop()
        result.append(current.value)
        current = current.right

    return result

3. Preventing Infinite Recursion

def safe_recursive(data, visited: set = None) -> list:
"""Safe recursion on structures with circular references"""
if visited is None:
visited = set()

obj_id = id(data)
if obj_id in visited:
return ["<circular reference>"]

visited.add(obj_id)

if isinstance(data, dict):
return [safe_recursive(v, visited) for v in data.values()]
elif isinstance(data, list):
return [safe_recursive(item, visited) for item in data]
else:
return [data]


# Test with circular reference
a = [1, 2]
b = [3, a]
a.append(b) # a → b → a (circular)

result = safe_recursive(a)
print(result) # Circular part shown as '<circular reference>'
Advertisement