본문으로 건너뛰기
Advertisement

itertools 완전 정복 — chain, product, groupby, islice 실전 활용

itertools 모듈은 이터레이터를 조합하는 고성능 도구 모음입니다. 리스트보다 메모리 효율적이고, 함수형 프로그래밍 스타일로 데이터를 처리합니다.


무한 이터레이터

count — 무한 카운터

import itertools

# count(start=0, step=1)
for n in itertools.count(10, 2):
if n > 20:
break
print(n) # 10, 12, 14, 16, 18, 20

# 실용 예: 고유 ID 생성기
counter = itertools.count(1)
ids = [next(counter) for _ in range(5)]
print(ids) # [1, 2, 3, 4, 5]

# enumerate와 유사한 기능 직접 구현
for n, item in zip(itertools.count(1), ["a", "b", "c"]):
print(f"{n}: {item}")

cycle — 무한 반복

# cycle(iterable) — 이터러블을 무한 반복
colors = itertools.cycle(["빨강", "초록", "파랑"])
for i in range(7):
print(next(colors))
# 빨강, 초록, 파랑, 빨강, 초록, 파랑, 빨강

# 실용 예: 라운드로빈 스케줄러
def round_robin(*iterables):
"""여러 이터러블에서 번갈아가며 값 추출"""
nexts = itertools.cycle(iter(it).__next__ for it in iterables)
pending = len(iterables)
while pending:
try:
yield next(nexts)()
except StopIteration:
pending -= 1
nexts = itertools.cycle(itertools.islice(nexts, pending))

print(list(round_robin("ABC", "D", "EF")))
# ['A', 'D', 'E', 'B', 'F', 'C']

repeat — 값 반복

# repeat(object, times=None)
for x in itertools.repeat("hello", 3):
print(x) # hello, hello, hello

# map과 조합해 고정 인수 공급
result = list(map(pow, range(5), itertools.repeat(2)))
print(result) # [0, 1, 4, 9, 16] — 각 수의 제곱

종료 이터레이터

chain — 이터러블 연결

# chain(*iterables) — 여러 이터러블을 순서대로 연결
result = list(itertools.chain([1, 2], [3, 4], [5, 6]))
print(result) # [1, 2, 3, 4, 5, 6]

# chain.from_iterable — 중첩 구조 평탄화
nested = [[1, 2], [3, 4], [5, 6]]
flat = list(itertools.chain.from_iterable(nested))
print(flat) # [1, 2, 3, 4, 5, 6]

# 실용 예: 여러 파일의 줄 합치기
import glob
all_lines = itertools.chain.from_iterable(
open(f, encoding="utf-8") for f in glob.glob("*.log")
)
for line in all_lines:
process(line)

islice — 이터레이터 슬라이싱

# islice(iterable, stop)
# islice(iterable, start, stop[, step])
first_5 = list(itertools.islice(range(100), 5))
print(first_5) # [0, 1, 2, 3, 4]

# 처음 2개 건너뛰고 10개 취하기
middle = list(itertools.islice(range(100), 2, 12))
print(middle) # [2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

# 무한 이터레이터에서 N개만 취하기
first_10_evens = list(itertools.islice(
(x for x in itertools.count() if x % 2 == 0), 10
))
print(first_10_evens) # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]

# 실용 예: 페이지네이션
def paginate(iterable, page: int, size: int):
"""이터레이터를 페이지 단위로 슬라이싱"""
return list(itertools.islice(iterable, (page - 1) * size, page * size))

zip_longest — 길이가 다른 이터러블 묶기

# zip()은 짧은 쪽에서 종료
# zip_longest는 긴 쪽 기준, 짧은 쪽은 fillvalue로 채움
a = [1, 2, 3]
b = ["a", "b"]

print(list(zip(a, b))) # [(1, 'a'), (2, 'b')]
print(list(itertools.zip_longest(a, b))) # [(1, 'a'), (2, 'b'), (3, None)]
print(list(itertools.zip_longest(a, b, fillvalue=0))) # [(1, 'a'), (2, 'b'), (3, 0)]

# 실용 예: 병렬 배치 처리
def batch(iterable, size: int, fillvalue=None):
"""이터러블을 size 크기 배치로 묶기"""
args = [iter(iterable)] * size
return itertools.zip_longest(*args, fillvalue=fillvalue)

for chunk in batch(range(10), 3):
print(chunk)
# (0, 1, 2)
# (3, 4, 5)
# (6, 7, 8)
# (9, None, None)

starmap — 튜플을 언패킹하며 map

# starmap(function, iterable_of_tuples)
pairs = [(2, 5), (3, 2), (10, 3)]
result = list(itertools.starmap(pow, pairs))
print(result) # [32, 9, 1000] — pow(2,5), pow(3,2), pow(10,3)

# map과 비교
result2 = list(map(lambda pair: pow(*pair), pairs)) # 동일하지만 더 복잡

조합 이터레이터

product — 데카르트 곱

# product(*iterables, repeat=1)
# 모든 조합 (순서 있음, 중복 있음)
for combo in itertools.product("AB", [1, 2]):
print(combo)
# ('A', 1) ('A', 2) ('B', 1) ('B', 2)

# repeat 인수
for combo in itertools.product("AB", repeat=2):
print(combo)
# ('A', 'A') ('A', 'B') ('B', 'A') ('B', 'B')

# 실용 예: 하이퍼파라미터 그리드 탐색
learning_rates = [0.001, 0.01, 0.1]
batch_sizes = [32, 64, 128]
epochs = [10, 20]

for lr, bs, ep in itertools.product(learning_rates, batch_sizes, epochs):
print(f"LR={lr}, BatchSize={bs}, Epochs={ep}")
# train_model(lr=lr, batch_size=bs, epochs=ep)

permutations — 순열

# permutations(iterable, r=None) — 순서 있음, 중복 없음
items = ["A", "B", "C"]

# 모든 순열
print(list(itertools.permutations(items)))
# [('A','B','C'), ('A','C','B'), ('B','A','C'), ...] — 3! = 6개

# r 길이 순열
print(list(itertools.permutations(items, 2)))
# [('A','B'), ('A','C'), ('B','A'), ('B','C'), ('C','A'), ('C','B')] — 6개

combinations — 조합

# combinations(iterable, r) — 순서 없음, 중복 없음
items = ["A", "B", "C", "D"]
print(list(itertools.combinations(items, 2)))
# [('A','B'), ('A','C'), ('A','D'), ('B','C'), ('B','D'), ('C','D')]

# combinations_with_replacement — 순서 없음, 중복 있음
print(list(itertools.combinations_with_replacement("AB", 2)))
# [('A','A'), ('A','B'), ('B','B')]

# 실용 예: A/B 테스트 쌍 생성
versions = ["v1", "v2", "v3", "v4"]
test_pairs = list(itertools.combinations(versions, 2))
print(f"총 {len(test_pairs)}쌍의 A/B 테스트")
# 총 6쌍의 A/B 테스트

groupby — 연속된 키로 그룹화

# groupby(iterable, key=None)
# 주의: 연속된 같은 키끼리만 그룹화! 사전 정렬 필수
data = [
{"dept": "개발", "name": "홍길동"},
{"dept": "개발", "name": "김철수"},
{"dept": "디자인", "name": "이영희"},
{"dept": "개발", "name": "박민수"}, # 정렬 안 되면 별도 그룹!
]

# 부서별 그룹화 (정렬 먼저!)
sorted_data = sorted(data, key=lambda x: x["dept"])
for dept, members in itertools.groupby(sorted_data, key=lambda x: x["dept"]):
print(f"{dept}: {[m['name'] for m in members]}")
# 개발: ['홍길동', '김철수', '박민수']
# 디자인: ['이영희']
# 실용 예: 로그 날짜별 집계
from datetime import datetime

logs = [
{"date": "2024-03-15", "level": "ERROR", "msg": "오류1"},
{"date": "2024-03-15", "level": "INFO", "msg": "정보1"},
{"date": "2024-03-16", "level": "ERROR", "msg": "오류2"},
{"date": "2024-03-16", "level": "ERROR", "msg": "오류3"},
{"date": "2024-03-17", "level": "INFO", "msg": "정보2"},
]

# 날짜별 오류 수 집계
for date, entries in itertools.groupby(logs, key=lambda x: x["date"]):
entry_list = list(entries)
error_count = sum(1 for e in entry_list if e["level"] == "ERROR")
print(f"{date}: 총 {len(entry_list)}건, 오류 {error_count}건")

accumulate — 누적 연산

# accumulate(iterable, func=operator.add, initial=None)
import operator

# 누적 합
print(list(itertools.accumulate([1, 2, 3, 4, 5])))
# [1, 3, 6, 10, 15]

# 누적 곱
print(list(itertools.accumulate([1, 2, 3, 4, 5], operator.mul)))
# [1, 2, 6, 24, 120]

# 누적 최대값 (러닝 최대)
prices = [100, 95, 110, 105, 120, 115]
running_max = list(itertools.accumulate(prices, max))
print(running_max) # [100, 100, 110, 110, 120, 120]

# initial 인수 (Python 3.8+)
print(list(itertools.accumulate([1, 2, 3], initial=10)))
# [10, 11, 13, 16]

실전: 데이터 파이프라인 구축

import itertools
import csv
from typing import Iterator, TypeVar

T = TypeVar("T")

# 파이프라인 유틸리티
def take(n: int, iterable) -> list:
"""이터러블에서 n개만 취하기"""
return list(itertools.islice(iterable, n))

def drop(n: int, iterable) -> Iterator:
"""이터러블에서 n개 건너뛰기"""
return itertools.islice(iterable, n, None)

def window(iterable, size: int) -> Iterator[tuple]:
"""슬라이딩 윈도우"""
iters = itertools.tee(iterable, size)
for i, it in enumerate(iters):
next(itertools.islice(it, i, i), None) # i번 건너뜀
return zip(*iters)

def chunk(iterable, size: int) -> Iterator[tuple]:
"""고정 크기 청크로 분할"""
it = iter(iterable)
return iter(lambda: tuple(itertools.islice(it, size)), ())


# 실제 데이터 분석 파이프라인
def analyze_sales_data(csv_file: str):
def read_csv(filepath: str) -> Iterator[dict]:
with open(filepath, newline="", encoding="utf-8") as f:
yield from csv.DictReader(f)

def parse_amounts(rows: Iterator[dict]) -> Iterator[dict]:
for row in rows:
row["amount"] = float(row.get("amount", 0))
yield row

def filter_valid(rows: Iterator[dict]) -> Iterator[dict]:
return (row for row in rows if row["amount"] > 0)

# 파이프라인 조합
raw = read_csv(csv_file)
parsed = parse_amounts(raw)
valid = filter_valid(parsed)

# 지역별 그룹화 (정렬 필요)
sorted_data = sorted(valid, key=lambda x: x["region"])
region_totals = {}

for region, sales in itertools.groupby(sorted_data, key=lambda x: x["region"]):
total = sum(s["amount"] for s in sales)
region_totals[region] = total

# 상위 3개 지역
top_3 = sorted(region_totals.items(), key=lambda x: x[1], reverse=True)[:3]
for region, total in top_3:
print(f"{region}: {total:,.0f}원")


# 이동 평균 계산
def moving_average(data: list[float], window_size: int) -> list[float]:
windows = window(data, window_size)
return [sum(w) / window_size for w in windows]

prices = [10, 12, 11, 13, 15, 14, 16, 18, 17, 19]
ma3 = moving_average(prices, 3)
print(ma3) # [11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0]

고수 팁

팁 1: tee로 이터레이터 복제

import itertools

def split_stream(iterable, n: int = 2):
"""이터레이터를 n개로 복제 (각각 독립적으로 사용 가능)"""
return itertools.tee(iterable, n)

gen = (x**2 for x in range(10))
gen1, gen2 = split_stream(gen)
print(list(gen1)) # [0, 1, 4, 9, ...]
print(list(gen2)) # [0, 1, 4, 9, ...] — 독립적인 복사본

팁 2: takewhile / dropwhile

# takewhile: 조건이 False가 되면 중단
data = [1, 3, 5, 2, 8, 4]
print(list(itertools.takewhile(lambda x: x < 5, data))) # [1, 3]

# dropwhile: 조건이 False가 될 때까지 건너뜀
print(list(itertools.dropwhile(lambda x: x < 5, data))) # [5, 2, 8, 4]

팁 3: filterfalse로 조건의 역 필터링

evens = list(itertools.filterfalse(lambda x: x % 2, range(10)))
print(evens) # [0, 2, 4, 6, 8]

팁 4: pairwise (Python 3.10+)

# 연속된 쌍 생성
pairs = list(itertools.pairwise([1, 2, 3, 4, 5]))
print(pairs) # [(1, 2), (2, 3), (3, 4), (4, 5)]

# 변화율 계산
prices = [100, 110, 105, 120]
changes = [(b - a) / a * 100 for a, b in itertools.pairwise(prices)]
print(changes) # [10.0, -4.54..., 14.28...]
Advertisement