itertools Mastery — chain, product, groupby, islice in Practice
The itertools module is a collection of high-performance tools for combining iterators. It is more memory-efficient than building intermediate lists and lets you process data in a functional programming style.
Infinite Iterators
count — Infinite Counter
import itertools
# count(start=0, step=1) never stops by itself, so the consumer decides
# when to quit.
even_numbers = itertools.count(10, 2)
n = next(even_numbers)
while n <= 20:
    print(n)  # 10, 12, 14, 16, 18, 20
    n = next(even_numbers)

# Practical use: unique ID generator — each value is handed out exactly once
counter = itertools.count(1)
ids = list(itertools.islice(counter, 5))
print(ids)  # [1, 2, 3, 4, 5]

# enumerate-like numbering built by hand: pair a 1-based counter with the items
letters = ["a", "b", "c"]
for n, item in zip(itertools.count(1), letters):
    print(f"{n}: {item}")
cycle — Infinite Repetition
# cycle(iterable) — replay the items of iterable forever
colors = itertools.cycle(["red", "green", "blue"])
# zip stops as soon as range(7) runs out, so exactly seven colors are drawn
for i, color in zip(range(7), colors):
    print(color)
# red, green, blue, red, green, blue, red
# Practical use: round-robin scheduler
def round_robin(*iterables):
    """Yield one item from each iterable in turn until all are exhausted."""
    active = len(iterables)
    # Endless rotation over the bound __next__ method of every input
    getters = itertools.cycle(iter(it).__next__ for it in iterables)
    while active:
        try:
            for fetch in getters:
                yield fetch()
        except StopIteration:
            # The input that just ran dry is dropped: rebuild the rotation
            # from the next `active` getters, which are the survivors.
            active -= 1
            getters = itertools.cycle(itertools.islice(getters, active))


interleaved = list(round_robin("ABC", "D", "EF"))
print(interleaved)
# ['A', 'D', 'E', 'B', 'F', 'C']
repeat — Repeat a Value
# repeat(object, times=None) — the same object, `times` times (forever if omitted)
greetings = itertools.repeat("hello", 3)
for x in greetings:
    print(x)  # hello, hello, hello

# Combine with map to supply a fixed argument: map stops with the shorter
# input, so the endless stream of 2s is safe here.
exponents = itertools.repeat(2)
result = list(map(pow, range(5), exponents))
print(result)  # [0, 1, 4, 9, 16] — squares of each number
Terminating Iterators
chain — Connect Iterables
# chain(*iterables) — walk several iterables one after another
chained = itertools.chain([1, 2], [3, 4], [5, 6])
result = [*chained]
print(result)  # [1, 2, 3, 4, 5, 6]

# chain.from_iterable — same idea, but the inputs arrive as one iterable,
# which flattens a single level of nesting
nested = [[1, 2], [3, 4], [5, 6]]
flattened = itertools.chain.from_iterable(nested)
flat = [*flattened]
print(flat)  # [1, 2, 3, 4, 5, 6]
# Practical use: combine lines from multiple files
import glob


def lines_of(path):
    """Yield the lines of the UTF-8 text file at *path*, then close it."""
    # Opening inside a generator with `with` means the file is closed as soon
    # as its last line has been read; a bare open() in a generator expression
    # would leave every handle open until garbage collection.
    with open(path, encoding="utf-8") as handle:
        yield from handle


# chain.from_iterable pulls the per-file generators lazily, so only one
# file is open at any moment.
all_lines = itertools.chain.from_iterable(
    lines_of(f) for f in glob.glob("*.log")
)
for line in all_lines:
    # NOTE(review): process() is not defined in this snippet — supply your own.
    process(line)
islice — Slice an Iterator
# islice(iterable, stop)
# islice(iterable, start, stop[, step])
# Works like slicing, but on any iterator and without building a list first.
head = itertools.islice(range(100), 5)
first_5 = list(head)
print(first_5)  # [0, 1, 2, 3, 4]

# Skip first 2 and take 10
middle = list(itertools.islice(range(100), 2, 12))
print(middle)  # [2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

# Take N from an infinite iterator: the filter is endless, islice bounds it
evens = filter(lambda x: x % 2 == 0, itertools.count())
first_10_evens = list(itertools.islice(evens, 10))
print(first_10_evens)  # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
# Practical use: pagination
def paginate(iterable, page: int, size: int):
    """Return the items of 1-based page *page* (at most *size* of them) as a list."""
    first = (page - 1) * size
    last = first + size  # exclusive end of the page
    return [*itertools.islice(iterable, first, last)]
zip_longest — Zip Iterables of Different Lengths
# zip() stops at the shorter side
# zip_longest runs to the end of the longer side and pads the gaps with
# fillvalue (None unless specified)
a = [1, 2, 3]
b = ["a", "b"]

truncated = list(zip(a, b))
print(truncated)  # [(1, 'a'), (2, 'b')]

padded = list(itertools.zip_longest(a, b))
print(padded)  # [(1, 'a'), (2, 'b'), (3, None)]

zero_padded = list(itertools.zip_longest(a, b, fillvalue=0))
print(zero_padded)  # [(1, 'a'), (2, 'b'), (3, 0)]
# Practical use: parallel batch processing
def batch(iterable, size: int, fillvalue=None):
    """Return an iterator of size-long tuples; the last one is padded with fillvalue."""
    shared = iter(iterable)
    # The SAME iterator in every slot: each output tuple therefore takes
    # `size` consecutive items from it.
    slots = (shared,) * size
    return itertools.zip_longest(*slots, fillvalue=fillvalue)


batches = batch(range(10), 3)
for chunk in batches:
    print(chunk)
# (0, 1, 2)
# (3, 4, 5)
# (6, 7, 8)
# (9, None, None)
starmap — map with Unpacked Tuples
# starmap(function, iterable_of_tuples) — each tuple is unpacked into the
# function's positional arguments
pairs = [(2, 5), (3, 2), (10, 3)]
powers = itertools.starmap(pow, pairs)
result = list(powers)
print(result)  # [32, 9, 1000] — pow(2,5), pow(3,2), pow(10,3)

# Compare with map: the tuple has to be taken apart by hand
result2 = list(map(lambda pair: pow(pair[0], pair[1]), pairs))  # Same but more verbose
Combinatoric Iterators
product — Cartesian Product
# product(*iterables, repeat=1)
# Cartesian product: one element from each input, in input order
letter_number_pairs = itertools.product("AB", [1, 2])
for combo in letter_number_pairs:
    print(combo)
# ('A', 1) ('A', 2) ('B', 1) ('B', 2)

# repeat argument: the product of the input with itself
letter_pairs = itertools.product("AB", repeat=2)
for combo in letter_pairs:
    print(combo)
# ('A', 'A') ('A', 'B') ('B', 'A') ('B', 'B')

# Practical use: hyperparameter grid search
learning_rates = [0.001, 0.01, 0.1]
batch_sizes = [32, 64, 128]
epochs = [10, 20]

# One flat loop instead of three nested ones
search_space = itertools.product(learning_rates, batch_sizes, epochs)
for lr, bs, ep in search_space:
    print(f"LR={lr}, BatchSize={bs}, Epochs={ep}")
    # train_model(lr=lr, batch_size=bs, epochs=ep)
permutations — Ordered Arrangements
# permutations(iterable, r=None) — ordered, no repetition
items = ["A", "B", "C"]

# All permutations (r defaults to the full length)
full_orderings = list(itertools.permutations(items))
print(full_orderings)
# [('A','B','C'), ('A','C','B'), ('B','A','C'), ...] — 3! = 6 items

# r-length permutations
ordered_pairs = list(itertools.permutations(items, 2))
print(ordered_pairs)
# [('A','B'), ('A','C'), ('B','A'), ('B','C'), ('C','A'), ('C','B')] — 6 items
combinations — Unordered Selections
# combinations(iterable, r) — unordered, no repetition
items = ["A", "B", "C", "D"]
unordered_pairs = list(itertools.combinations(items, 2))
print(unordered_pairs)
# [('A','B'), ('A','C'), ('A','D'), ('B','C'), ('B','D'), ('C','D')]

# combinations_with_replacement — unordered, with repetition
pairs_with_repeats = list(itertools.combinations_with_replacement("AB", 2))
print(pairs_with_repeats)
# [('A','A'), ('A','B'), ('B','B')]

# Practical use: generate A/B test pairs (every version against every other, once)
versions = ["v1", "v2", "v3", "v4"]
test_pairs = [*itertools.combinations(versions, 2)]
print(f"Total {len(test_pairs)} A/B test pairs")
# Total 6 A/B test pairs
groupby — Group by Consecutive Keys
# groupby(iterable, key=None)
# Warning: a new group starts every time the key changes, so equal keys
# must be adjacent — sort by the same key first.
data = [
    {"dept": "Engineering", "name": "Alice"},
    {"dept": "Engineering", "name": "Bob"},
    {"dept": "Design", "name": "Carol"},
    {"dept": "Engineering", "name": "Dave"},  # Separate group if not sorted!
]


def dept_of(record):
    """Return the department of a record (used for both sorting and grouping)."""
    return record["dept"]


# Group by department (sort first!)
sorted_data = sorted(data, key=dept_of)
groups = itertools.groupby(sorted_data, key=dept_of)
for dept, members in groups:
    print(f"{dept}: {[m['name'] for m in members]}")
# Design: ['Carol']
# Engineering: ['Alice', 'Bob', 'Dave']
# Practical use: aggregate log entries by date
from datetime import datetime  # NOTE(review): not used in this snippet
logs = [
    {"date": "2024-03-15", "level": "ERROR", "msg": "error1"},
    {"date": "2024-03-15", "level": "INFO", "msg": "info1"},
    {"date": "2024-03-16", "level": "ERROR", "msg": "error2"},
    {"date": "2024-03-16", "level": "ERROR", "msg": "error3"},
    {"date": "2024-03-17", "level": "INFO", "msg": "info2"},
]
# Count errors by date.
# groupby only merges neighbouring entries, so sort by the grouping key
# first instead of relying on the log already being in date order.
# ISO "YYYY-MM-DD" strings sort chronologically as plain text.
logs_by_date = sorted(logs, key=lambda x: x["date"])
for date, entries in itertools.groupby(logs_by_date, key=lambda x: x["date"]):
    # Materialize the group: it is read twice (len + error count), and a
    # groupby sub-iterator can only be consumed once.
    entry_list = list(entries)
    error_count = sum(1 for e in entry_list if e["level"] == "ERROR")
    print(f"{date}: {len(entry_list)} total, {error_count} errors")
# 2024-03-15: 2 total, 1 errors
# 2024-03-16: 2 total, 2 errors
# 2024-03-17: 1 total, 0 errors
accumulate — Cumulative Operations
# accumulate(iterable, func=operator.add, initial=None)
# Yields every intermediate result of a running fold, not just the last one.
import operator

numbers = [1, 2, 3, 4, 5]

# Cumulative sum (addition is the default operation)
running_sum = list(itertools.accumulate(numbers))
print(running_sum)
# [1, 3, 6, 10, 15]

# Cumulative product
running_product = list(itertools.accumulate(numbers, operator.mul))
print(running_product)
# [1, 2, 6, 24, 120]

# Running maximum: any two-argument function works, including builtins
prices = [100, 95, 110, 105, 120, 115]
running_max = [*itertools.accumulate(prices, max)]
print(running_max)  # [100, 100, 110, 110, 120, 120]

# initial argument (Python 3.8+): seeds the fold and is emitted first
seeded = list(itertools.accumulate([1, 2, 3], initial=10))
print(seeded)
# [10, 11, 13, 16]
Practical: Building a Data Pipeline
import itertools
import csv
from typing import Iterator, TypeVar
T = TypeVar("T")
# Pipeline utilities
def take(n: int, iterable) -> list:
    """Return the first n items of iterable as a list (fewer if it is shorter)."""
    head = itertools.islice(iterable, n)
    return [*head]
def drop(n: int, iterable) -> Iterator:
    """Return a lazy iterator over iterable that leaves out its first n items."""
    start, stop = n, None  # stop=None means "run to the end"
    return itertools.islice(iterable, start, stop)
def window(iterable, size: int) -> Iterator[tuple]:
    """Return an iterator of overlapping size-long tuples of consecutive items."""
    copies = itertools.tee(iterable, size)
    for offset, copy in enumerate(copies):
        # Advance the k-th copy by k items (fewer if it runs out) so that
        # zipping the copies lines up k consecutive elements.
        for _ in itertools.islice(copy, offset):
            pass
    # zip ends with the shortest copy, i.e. when the last window is complete
    return zip(*copies)
def chunk(iterable, size: int) -> Iterator[tuple]:
    """Return an iterator of consecutive tuples of up to size items each."""
    source = iter(iterable)

    def next_chunk():
        # Pulls the next `size` items; an empty tuple means the source is done.
        return tuple(itertools.islice(source, size))

    # Two-argument iter(): call next_chunk repeatedly until it returns ()
    return iter(next_chunk, ())
# Real data analysis pipeline
def analyze_sales_data(csv_file: str):
    """Print the three regions with the largest total sales in *csv_file*.

    The file is read as UTF-8 CSV with a header row; the ``amount`` and
    ``region`` columns are used. Rows whose amount is missing, blank, zero
    or negative are ignored. Each of the (at most) three top regions is
    printed as ``region: total`` with thousands separators and two decimals.

    Args:
        csv_file: Path of the CSV file to analyse.

    Raises:
        ValueError: If an ``amount`` cell contains non-numeric text.
        KeyError: If the file has no ``region`` column and at least one row
            passes the amount filter.
    """
    def read_csv(filepath: str) -> Iterator[dict]:
        # A generator: rows are produced on demand and the file is closed
        # once the last row has been read.
        with open(filepath, newline="", encoding="utf-8") as f:
            yield from csv.DictReader(f)

    def parse_amounts(rows: Iterator[dict]) -> Iterator[dict]:
        for row in rows:
            # DictReader yields "" for a blank cell and None for a row that is
            # too short, so `.get("amount", 0)` alone would pass float() a
            # value it rejects. `or 0` maps both cases to zero, which the
            # filter below then discards.
            row["amount"] = float(row.get("amount") or 0)
            yield row

    def filter_valid(rows: Iterator[dict]) -> Iterator[dict]:
        return (row for row in rows if row["amount"] > 0)

    # Compose pipeline (nothing is read until sorted() pulls the rows)
    raw = read_csv(csv_file)
    parsed = parse_amounts(raw)
    valid = filter_valid(parsed)

    # Group by region: groupby only merges adjacent rows, hence the sort
    # by the same key.
    sorted_data = sorted(valid, key=lambda x: x["region"])
    region_totals = {}
    for region, sales in itertools.groupby(sorted_data, key=lambda x: x["region"]):
        region_totals[region] = sum(s["amount"] for s in sales)

    # Top 3 regions by total, largest first
    top_3 = sorted(region_totals.items(), key=lambda x: x[1], reverse=True)[:3]
    for region, total in top_3:
        print(f"{region}: {total:,.2f}")
# Moving average calculation
def moving_average(data: list[float], window_size: int) -> list[float]:
    """Return the mean of every run of window_size consecutive values in data."""
    window_sums = map(sum, window(data, window_size))
    return [total / window_size for total in window_sums]


prices = [10, 12, 11, 13, 15, 14, 16, 18, 17, 19]
ma3 = moving_average(prices, 3)
print(ma3)  # [11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0]
Expert Tips
Tip 1: Clone an iterator with tee
import itertools
def split_stream(iterable, n: int = 2):
    """Return a tuple of n iterators that each replay the items of iterable."""
    # tee buffers items so every copy sees the full sequence; the source
    # iterable itself should not be advanced separately afterwards.
    copies = itertools.tee(iterable, n)
    return copies


gen = (x**2 for x in range(10))
gen1, gen2 = split_stream(gen)
first_pass = list(gen1)
print(first_pass)  # [0, 1, 4, 9, ...]
second_pass = list(gen2)
print(second_pass)  # [0, 1, 4, 9, ...] — independent copy
Tip 2: takewhile / dropwhile
data = [1, 3, 5, 2, 8, 4]


def below_five(x):
    """Predicate shared by both examples."""
    return x < 5


# takewhile: yield items until the predicate first fails, then stop for good
leading = list(itertools.takewhile(below_five, data))
print(leading)  # [1, 3]

# dropwhile: skip items until the predicate first fails, then yield the rest
remainder = list(itertools.dropwhile(below_five, data))
print(remainder)  # [5, 2, 8, 4]
Tip 3: Inverse filtering with filterfalse
def is_odd(x):
    """Return a truthy value (the remainder 1) for odd integers."""
    return x % 2


# filterfalse keeps the items for which the predicate is falsy
evens = [*itertools.filterfalse(is_odd, range(10))]
print(evens)  # [0, 2, 4, 6, 8]
Tip 4: pairwise (Python 3.10+)
# Generate consecutive pairs
pairs = list(itertools.pairwise([1, 2, 3, 4, 5]))
print(pairs) # [(1, 2), (2, 3), (3, 4), (4, 5)]
# Calculate rate of change
prices = [100, 110, 105, 120]
changes = [(b - a) / a * 100 for a, b in itertools.pairwise(prices)]
print(changes) # [10.0, -4.54..., 14.28...]