Skip to main content
Advertisement

5.3 First-class Functions and Lambda — Treating Functions as Values

In Python, functions are first-class objects: they can be assigned to variables, passed as arguments to other functions, returned from functions, and stored in data structures. This property is the foundation of functional-programming patterns such as higher-order functions, callbacks, and closures.


First-class Functions

# 1. Assign a function to a variable
def greet(name: str) -> str:
    """Return a greeting for *name*."""
    return f"Hello, {name}!"


hello = greet  # Bind the function object itself -- no parentheses, no call
print(hello("Alice"))  # Hello, Alice!
print(hello is greet)  # True -- both names refer to the same function object


# 2. Pass a function as an argument
def apply(func, value):
    """Call *func* with *value* and return whatever it returns."""
    return func(value)


def double(x: int) -> int:
    """Return twice *x*."""
    return x * 2


def square(x: int) -> int:
    """Return *x* squared."""
    return x ** 2


print(apply(double, 5))  # 10
print(apply(square, 5))  # 25
print(apply(str, 42))  # 42 -- the result is the str '42'; built-ins work too


# 3. Return a function as a value
def make_adder(n: int):
    """Return a function that adds *n* to its argument.

    The inner function is a closure: it keeps a reference to the *n*
    passed to this particular call of make_adder.
    """
    def adder(x: int) -> int:
        return x + n
    return adder  # Return the function object, not the result of calling it


add5 = make_adder(5)
add10 = make_adder(10)

print(add5(3))  # 8
print(add10(3))  # 13
print(type(add5))  # <class 'function'>


# 4. Store functions in data structures
operations = {
    "add": lambda x, y: x + y,
    "sub": lambda x, y: x - y,
    "mul": lambda x, y: x * y,
    # Return None for a zero divisor instead of raising ZeroDivisionError
    "div": lambda x, y: x / y if y != 0 else None,
}

# Look up each function by name and call it like any other function
for op_name, op_func in operations.items():
    result = op_func(10, 3)
    print(f" {op_name}(10, 3) = {result}")

lambda Syntax

# Syntax: lambda parameters: expression
# Unlike def, the body is a single expression whose value is returned
# implicitly -- there is no return keyword.

def multiply(x, y):
    return x * y


multiply_lambda = lambda x, y: x * y

print(multiply(3, 4))  # 12
print(multiply_lambda(3, 4))  # 12


# lambda with default values
power = lambda base, exp=2: base ** exp
print(power(3))  # 9 (default exp=2)
print(power(3, 3))  # 27


# Lambda limitations:
# 1. The body is a single expression: no statements such as if/for/while,
#    return, raise, or (non-walrus) assignment. The expression itself may
#    still wrap across several lines inside parentheses.
# 2. Parameters cannot carry type annotations, and there is no place to
#    write a docstring.
# 3. For anything non-trivial, use def -- it also gives the function a real
#    __name__ for tracebacks instead of '<lambda>'.

# A conditional expression is allowed, because it is an expression
classify = lambda x: "positive" if x > 0 else ("negative" if x < 0 else "zero")
print(classify(5))  # positive
print(classify(-3))  # negative
print(classify(0))  # zero

sorted() and lambda

students = [
    {"name": "Charlie", "score": 85, "age": 22},
    {"name": "Alice", "score": 92, "age": 20},
    {"name": "Bob", "score": 78, "age": 25},
    {"name": "Diana", "score": 92, "age": 21},
]

# Sort by score descending (sorted() is stable, so ties keep input order)
by_score = sorted(students, key=lambda s: s["score"], reverse=True)
for s in by_score:
    print(f" {s['name']}: {s['score']}")

# Multi-key sort: score descending, then name ascending for ties.
# Negating the numeric score flips only that key's order; reverse=True would
# also flip the name order.
by_score_name = sorted(
    students,
    key=lambda s: (-s["score"], s["name"]),
)
for s in by_score_name:
    print(f" {s['name']}: {s['score']}")


# String sorting
words = ["banana", "Apple", "cherry", "Date"]
# Case-insensitive sort. str.casefold() is the Unicode-aware way to compare
# text regardless of case; it is stronger than lower() (for example it maps
# the German "ß" to "ss").
case_insensitive = sorted(words, key=lambda w: w.casefold())
print(case_insensitive)  # ['Apple', 'banana', 'cherry', 'Date']

# Sort by length; sorted() is stable, so equal lengths keep their input order
by_length = sorted(words, key=len)
print(by_length)  # ['Date', 'Apple', 'banana', 'cherry']

map(): Apply a Transformation

# map(func, iterable, ...) applies func to every item lazily and returns an
# iterator; wrapping it in list() collects all of the results.
numbers = [1, 2, 3, 4, 5]

squared = list(map(lambda n: n ** 2, numbers))
print(squared)  # [1, 4, 9, 16, 25]

# Built-in callables plug straight in -- int() parses each string here
string_nums = ["1", "2", "3", "4", "5"]
integers = list(map(int, string_nums))
print(integers)  # [1, 2, 3, 4, 5]

# Given several iterables, func receives one item from each, in lockstep
a = [1, 2, 3]
b = [10, 20, 30]
sums = list(map(lambda left, right: left + right, a, b))
print(sums)  # [11, 22, 33]

# The same squares two ways; the comprehension usually reads more naturally
result1 = list(map(lambda n: n ** 2, range(10)))
result2 = [n ** 2 for n in range(10)]
print(result1 == result2)  # True

# map is at its best when the function already exists -- no lambda needed
names = ["alice", "bob", "charlie"]
capitalized = list(map(str.capitalize, names))
print(capitalized)  # ['Alice', 'Bob', 'Charlie']

filter(): Conditional Filtering

numbers = [-5, -3, -1, 0, 2, 4, 6]

# filter(pred, iterable) keeps only the items for which pred is truthy
positives = list(filter(lambda n: n > 0, numbers))
print(positives)  # [2, 4, 6]

# With None as the predicate, filter keeps the items that are truthy themselves
mixed = [0, 1, "", "hello", None, False, True, [], [1, 2]]
truthy = list(filter(None, mixed))
print(truthy)  # [1, 'hello', True, [1, 2]]

# The equivalent comprehension builds the very same list
evens1 = list(filter(lambda n: not n % 2, range(20)))
evens2 = [n for n in range(20) if n % 2 == 0]
print(evens1 == evens2)  # True

operator Module: Alternatives to lambda

from functools import reduce
from operator import add, attrgetter, itemgetter, methodcaller, mul

# itemgetter(key) builds a callable equivalent to `lambda obj: obj[key]`;
# it works for dict keys and sequence indexes alike.
students = [
    {"name": "Charlie", "score": 85},
    {"name": "Alice", "score": 92},
    {"name": "Bob", "score": 78},
]

# Hand-written lambda key ...
by_score_lambda = sorted(students, key=lambda row: row["score"])
# ... versus the operator-module equivalent (implemented in C, so it is
# typically a little faster than a Python-level lambda)
by_score_op = sorted(students, key=itemgetter("score"))

print([row["name"] for row in by_score_op])  # ['Bob', 'Charlie', 'Alice']

# Passing several keys makes the getter return a tuple -> multi-key sort
by_score_name = sorted(students, key=itemgetter("score", "name"))


# attrgetter: access object attributes
from dataclasses import dataclass


@dataclass
class Product:
    """A catalog item used to demonstrate attribute-based sort keys."""

    name: str
    price: float
    category: str


products = [
    Product("Laptop", 1500000, "electronics"),
    Product("Phone", 800000, "electronics"),
    Product("Shirt", 30000, "clothing"),
]

by_price = sorted(products, key=attrgetter("price"))
print([p.name for p in by_price])  # ['Shirt', 'Phone', 'Laptop']

# Several attribute names -> tuple key: category first, then price
by_cat_price = sorted(products, key=attrgetter("category", "price"))


# methodcaller: call a method
words = ["hello", "WORLD", "Python"]
lowered = list(map(methodcaller("lower"), words))
print(lowered)  # ['hello', 'world', 'python']

# Extra arguments are forwarded: methodcaller("replace", "l", "r")(s)
# is the same as s.replace("l", "r")
replaced = list(map(methodcaller("replace", "l", "r"), words[:2]))
print(replaced)  # ['herro', 'WORLD']


# Arithmetic operators as functions
numbers = [1, 2, 3, 4, 5]
total = reduce(add, numbers)  # Same as sum(numbers)
product = reduce(mul, numbers)  # 1*2*3*4*5 = 120 (math.prod(numbers) also works)
print(total, product)

functools.reduce()

from functools import reduce
from operator import add, mul

# reduce(func, iterable[, initial]) folds the items from left to right:
# reduce(f, [a, b, c]) == f(f(a, b), c)
numbers = [1, 2, 3, 4, 5]

total = reduce(add, numbers)
print(total)  # 15  ->  ((((1 + 2) + 3) + 4) + 5)

product = reduce(mul, numbers)
print(product)  # 120

# An initial value seeds the accumulator before the first item is consumed
total_with_init = reduce(add, numbers, 100)
print(total_with_init)  # 115

# Any two-argument callable works; this one keeps the larger value (like max())
maximum = reduce(lambda best, item: best if best > item else item, numbers)
print(maximum)  # 5

# Merging dicts: each step builds a new dict from the accumulator plus the
# next dict, so later keys win on a collision
dicts = [{"a": 1}, {"b": 2}, {"c": 3}]
merged = reduce(lambda so_far, nxt: {**so_far, **nxt}, dicts)
print(merged)  # {'a': 1, 'b': 2, 'c': 3}

# Empty iterable: with no initial value there is nothing to return,
# so reduce() raises TypeError
try:
    reduce(add, [])
except TypeError as e:
    print(f"Error: {e}")

# Supplying an initial value makes the empty case well defined
result = reduce(add, [], 0)
print(result)  # 0

Functions Returning Functions (Factory Pattern)

def make_validator(min_val: float, max_val: float, name: str = "value"):
    """Create a range-checking function.

    Args:
        min_val: Smallest accepted value (inclusive).
        max_val: Largest accepted value (inclusive).
        name: Label used in the error message and the generated docstring.

    Returns:
        A function that returns its argument unchanged when it lies in
        [min_val, max_val] and raises ValueError otherwise.
    """
    def validate(x: float) -> float:
        if not min_val <= x <= max_val:
            raise ValueError(
                f"{name} must be between {min_val} and {max_val} (got: {x})"
            )
        return x

    # Give each generated validator a docstring describing its own range
    validate.__doc__ = f"Range check for {name}: [{min_val}, {max_val}]"
    return validate


validate_age = make_validator(0, 150, "age")
validate_score = make_validator(0, 100, "score")
validate_temp = make_validator(-273.15, 1000, "temperature")

print(validate_age(25))  # 25
print(validate_score(87))  # 87

try:
    validate_age(200)
except ValueError as e:
    print(f"Validation failed: {e}")


# Formatter factory
def make_formatter(prefix: str = "", suffix: str = "", width: int = 0):
    """Return a function that wraps a value in *prefix* and *suffix*.

    When *width* is non-zero the result is right-aligned to that width;
    text already wider than *width* is returned unchanged (str.rjust never
    truncates).
    """
    def format_value(value) -> str:
        text = f"{prefix}{value}{suffix}"
        return text.rjust(width) if width else text
    return format_value


currency = make_formatter(prefix="$", width=15)
percentage = make_formatter(suffix="%", width=8)

for val in [1000, 50000, 1234567]:
    print(currency(f"{val:,}"))  # e.g. '         $1,000'

for pct in [10, 85.5, 100]:
    print(percentage(pct))  # e.g. '   85.5%'

Real-world Example: Composable Sort Keys

from typing import Callable, TypeVar

T = TypeVar("T")


def compose(*funcs: Callable) -> Callable:
    """Compose functions, applying them right to left.

    compose(f, g, h)(x) == f(g(h(x))). With no functions the result is
    the identity function.
    """
    def composed(x):
        result = x
        for f in reversed(funcs):
            result = f(result)
        return result
    return composed


# Sort pipeline
data = [
    {"name": "Alice", "score": 92, "grade": "A"},
    {"name": "Bob", "score": 78, "grade": "C"},
    {"name": "Charlie", "score": 92, "grade": "A"},
    {"name": "Diana", "score": 85, "grade": "B"},
]

# Score descending, name ascending on tie
result = sorted(data, key=lambda x: (-x["score"], x["name"]))
for d in result:
    print(f" {d['name']}: {d['score']}")

# compose() assembles a sort key from small single-purpose steps:
# first pick the "name" field, then casefold it for a case-insensitive order
by_name = sorted(data, key=compose(str.casefold, lambda d: d["name"]))
print([d["name"] for d in by_name])  # ['Alice', 'Bob', 'Charlie', 'Diana']

Pro Tips

1. Use partial Instead of lambda for Partial Application

from functools import partial


def power(base: float, exponent: float) -> float:
    """Return base raised to the given exponent."""
    return base ** exponent


# lambda way
square_lambda = lambda x: power(x, 2)

# partial way: clearer intent, introspectable via .func/.keywords, and
# picklable when the wrapped function is (a lambda is not)
square = partial(power, exponent=2)
cube = partial(power, exponent=3)

print(square(5))  # 25  (int in, int out; square(5.0) would give 25.0)
print(cube(5))  # 125

# A partial object does NOT copy the wrapped function's metadata:
# square.__doc__ is functools.partial's own class docstring, and square has
# no __name__. The wrapped function and the frozen arguments are exposed as
# .func, .args and .keywords instead.
print(square.func.__doc__)  # Return base raised to the given exponent.
print(square.keywords)  # {'exponent': 2}

2. Callback Pattern

from typing import Callable

def process_with_callback(
data: list[int],
on_success: Callable[[list[int]], None],
on_error: Callable[[Exception], None] | None = None
) -> None:
try:
result = [x * 2 for x in data]
on_success(result)
except Exception as e:
if on_error:
on_error(e)
else:
raise


process_with_callback(
[1, 2, 3],
on_success=lambda result: print(f"Success: {result}"),
on_error=lambda e: print(f"Error: {e}")
)
# Success: [2, 4, 6]
Advertisement