Class Decorators
Class decorators use the same @ syntax as function decorators but are applied to classes — they receive the class object and either modify it in place or replace it with a new object (usually another class).
Class Decorator Basics
# A class decorator is a function that takes a class and returns a modified (or new) class
def add_repr(cls):
    """Class decorator that installs an automatic ``__repr__``.

    The generated representation lists every public instance attribute
    (names that do not start with an underscore) found in ``vars(self)``,
    formatted as ``ClassName(attr=value, ...)``.

    Args:
        cls: The class to modify in place.

    Returns:
        The same class object, with ``__repr__`` replaced.
    """

    def __repr__(self):
        attrs = ", ".join(
            f"{k}={v!r}"
            for k, v in vars(self).items()
            if not k.startswith("_")
        )
        # Use the runtime type, not the decorated class, so that subclasses
        # are reported under their own name.
        return f"{type(self).__name__}({attrs})"

    cls.__repr__ = __repr__
    return cls
def add_eq(cls):
    """Class decorator that installs attribute-based ``__eq__`` and ``__hash__``.

    Two instances are equal when they have exactly the same type and equal
    ``vars()`` dictionaries. The hash is derived from the sorted
    ``(name, value)`` pairs, so every attribute value must be hashable.
    """

    def __eq__(self, other):
        # Exact type match only; anything else is left to the other operand.
        if type(self) is type(other):
            return vars(self) == vars(other)
        return NotImplemented

    def __hash__(self):
        ordered_items = sorted(vars(self).items())
        return hash(tuple(ordered_items))

    cls.__eq__, cls.__hash__ = __eq__, __hash__
    return cls
@add_repr
@add_eq
class Point:
    # 2-D point; __repr__, __eq__ and __hash__ are supplied by the decorators
    def __init__(self, x: float, y: float):
        self.x, self.y = x, y
# Two points with identical coordinates and one that differs
p1, p2, p3 = Point(1.0, 2.0), Point(1.0, 2.0), Point(3.0, 4.0)

print(p1)  # Point(x=1.0, y=2.0)
print(p1 == p2)  # True
print(p1 == p3)  # False
print(hash(p1) == hash(p2))  # True
Injecting Functionality into Classes
import time
from typing import Any
def singleton(cls):
    """Class decorator that turns ``cls`` into a lazily created singleton.

    The class name is rebound to a factory function. The first call builds
    the instance from the given arguments; every later call returns that
    same instance and its arguments are not used. The undecorated class
    stays reachable through the factory's ``_cls`` attribute.
    """
    instances: dict = {}

    def get_instance(*args, **kwargs):
        try:
            return instances[cls]
        except KeyError:
            pass
        # Construct outside the except block so a failing constructor is not
        # reported as happening "during handling" of the KeyError.
        created = instances[cls] = cls(*args, **kwargs)
        return created

    # Make the factory present itself under the class's name and docstring
    for attr in ("__name__", "__doc__"):
        setattr(get_instance, attr, getattr(cls, attr))
    get_instance._cls = cls  # Access to original class
    return get_instance
@singleton
class DatabaseConnection:
    """DB connection pool (singleton)"""

    def __init__(self, host: str = "localhost", port: int = 5432):
        self.host, self.port = host, port
        self._pool: list = []
        print(f"[DB] Connection initialized: {host}:{port}")

    def query(self, sql: str) -> str:
        # Prefix the result with the host this connection was created for
        outcome = f"[{self.host}] Executed: {sql}"
        return outcome
# The first call constructs the one and only instance
db1 = DatabaseConnection("db.example.com", 5432)
# Later calls return that same instance; these arguments are never used
db2 = DatabaseConnection("other.host", 9999) # Ignored
print(db1 is db2) # True
print(db1.host) # db.example.com
def add_timestamps(cls):
    """Class decorator that adds creation/update timestamps to instances.

    ``__init__`` is wrapped so that, after the original initializer runs,
    each instance gets ``created_at`` and ``updated_at`` (seconds since the
    epoch, from ``time.time()``). Two methods are added as well:

    * ``touch()`` sets ``updated_at`` to the current time.
    * ``age()`` returns the seconds elapsed since ``created_at``.

    Args:
        cls: The class to modify in place.

    Returns:
        The same class object.
    """
    import functools  # local import keeps the decorator self-contained

    original_init = cls.__init__

    # wraps() keeps the original __init__ name, docstring, annotations and
    # signature visible to tools (and to other decorators) that inspect them.
    @functools.wraps(original_init)
    def new_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        # Read the clock once so a fresh instance has created_at == updated_at.
        now = time.time()
        self.created_at = now
        self.updated_at = now

    def touch(self):
        self.updated_at = time.time()

    def age(self) -> float:
        return time.time() - self.created_at

    cls.__init__ = new_init
    cls.touch = touch
    cls.age = age
    return cls
@add_timestamps
class Article:
    # Plain content holder; the timestamp attributes come from the decorator
    def __init__(self, title: str, content: str):
        self.title, self.content = title, content
article = Article("Python Decorators", "Decorators are...")
print(f"Created at: {article.created_at:.0f}")
print(f"Age: {article.age():.4f}s")
# Pause briefly so the refreshed timestamp is later than the creation time
time.sleep(0.01)
article.touch()
print(f"Updated at: {article.updated_at:.0f}")
Auto-Wrapping Methods
import functools
import time
import logging
# Configure the root logger at INFO so the logging.info() calls below are emitted
logging.basicConfig(level=logging.INFO)
def log_all_methods(cls):
    """Class decorator that adds call/timing logs to public methods.

    Every plain function defined directly on ``cls`` whose name does not
    start with an underscore is replaced by a wrapper that logs when the
    call starts and how long it took. Other attributes (staticmethods,
    classmethods, properties, nested classes, callable objects) are left
    untouched, because wrapping them in a plain function would change how
    they bind.

    Args:
        cls: The class to modify in place.

    Returns:
        The same class object.
    """
    import inspect

    def make_logged(m):
        # A factory gives each wrapper its own ``m`` instead of sharing the
        # loop variable.
        @functools.wraps(m)
        def logged(*args, **kwargs):
            logging.info("%s.%s called", cls.__name__, m.__name__)
            start = time.perf_counter()
            result = m(*args, **kwargs)
            elapsed = time.perf_counter() - start
            logging.info("%s.%s done (%.4fs)", cls.__name__, m.__name__, elapsed)
            return result

        return logged

    # Iterate over a snapshot: setattr() below mutates the class namespace.
    for name, method in list(vars(cls).items()):
        if inspect.isfunction(method) and not name.startswith("_"):
            setattr(cls, name, make_logged(method))
    return cls
def validate_types(cls):
    """Class decorator that checks ``__init__`` arguments against type hints.

    ``__init__`` is wrapped so that every argument (including defaults) whose
    parameter carries an annotation is tested with ``isinstance`` before the
    original initializer runs. Annotations that ``isinstance`` cannot handle,
    such as ``list[str]`` or string annotations, are skipped instead of
    crashing the constructor.

    Args:
        cls: The class to modify in place.

    Returns:
        The same class object.

    Raises:
        TypeError: Raised by the wrapped ``__init__`` when an argument is not
            an instance of its annotated type, or when the arguments do not
            match the signature.
    """
    import inspect

    original_init = cls.__init__
    # Built-in initializers such as object.__init__ carry no annotations.
    hints = getattr(original_init, "__annotations__", {})
    checked = [(name, hint) for name, hint in hints.items() if name != "return"]
    # The signature never changes, so inspect it once rather than per call.
    sig = inspect.signature(original_init)

    @functools.wraps(original_init)
    def new_init(self, *args, **kwargs):
        bound = sig.bind(self, *args, **kwargs)
        bound.apply_defaults()
        for param_name, expected_type in checked:
            if param_name not in bound.arguments:
                continue
            value = bound.arguments[param_name]
            try:
                matches = isinstance(value, expected_type)
            except TypeError:
                # The hint is not something isinstance() accepts; skip it.
                continue
            if not matches:
                # Not every hint has __name__ (e.g. unions), so fall back to repr.
                expected_name = getattr(expected_type, "__name__", repr(expected_type))
                raise TypeError(
                    f"{cls.__name__}.{param_name}: "
                    f"expected {expected_name}, "
                    f"got {type(value).__name__}"
                )
        original_init(self, *args, **kwargs)

    cls.__init__ = new_init
    return cls
@log_all_methods
class Calculator:
    # Both public methods are wrapped with call/timing logs by the decorator
    def add(self, a: float, b: float) -> float:
        total = a + b
        return total

    def multiply(self, a: float, b: float) -> float:
        product = a * b
        return product
calc = Calculator()
# Each call below also emits "called" and "done" log records via the decorator
print(calc.add(3, 5)) # 8
print(calc.multiply(4, 6)) # 24
@validate_types
class Person:
    # Constructor arguments are checked against the hints by the decorator
    def __init__(self, name: str, age: int):
        self.name, self.age = name, age

    def __repr__(self) -> str:
        text = f"Person(name={self.name!r}, age={self.age})"
        return text
p = Person("Alice", 30)
print(p)
# A str passed for the int-annotated ``age`` is rejected before __init__ runs
try:
    Person("Bob", "thirty") # TypeError: Person.age: expected int, got str
except TypeError as e:
    print(f"Type error: {e}")
Registry Pattern
from typing import Type
def register(registry: dict):
    """Build a class decorator that records classes in ``registry``.

    The returned decorator stores the class under its lower-cased name,
    prints a confirmation line, and hands the class back unchanged.
    """

    def decorator(cls):
        name = cls.__name__.lower()
        registry.update({name: cls})
        print(f"Registered: {name} → {cls.__name__}")
        return cls

    return decorator
# Plugin registry: maps a lower-cased class name to the plugin class.
# It is filled as a side effect of the @register(PLUGINS) decorators below.
PLUGINS: dict[str, Type] = {}
@register(PLUGINS)
class TextPlugin:
    # Plugin that upper-cases its input
    def process(self, data: str) -> str:
        shouted = data.upper()
        return shouted
@register(PLUGINS)
class JsonPlugin:
    # Plugin that wraps its input in a small JSON document
    def process(self, data: str) -> str:
        import json

        payload = {"processed": data}
        return json.dumps(payload)
@register(PLUGINS)
class CompressPlugin:
    # Plugin that only reports the input length; nothing is actually compressed
    def process(self, data: str) -> str:
        summary = f"compressed({len(data)} chars)"
        return summary
def get_plugin(name: str):
    """Instantiate the plugin registered under ``name`` (case-insensitive).

    Raises:
        KeyError: If no plugin is registered under that name; the message
            lists the available plugin names.
    """
    plugin_cls = PLUGINS.get(name.lower())
    if plugin_cls is not None:
        return plugin_cls()
    raise KeyError(f"Unknown plugin: {name}. Available: {list(PLUGINS)}")
# Run the same input through every registered plugin, looked up by name
for plugin_name in ("text", "json", "compress"):
    plugin = get_plugin(plugin_name)
    print(plugin.process("Hello, World!"))
Combining with dataclass
from dataclasses import dataclass, field
import json
def to_json_mixin(cls):
"""Decorator that adds JSON serialization methods"""
def to_dict(self) -> dict:
result = {}
for f in cls.__dataclass_fields__:
value = getattr(self, f)
if hasattr(value, "to_dict"):
result[f] = value.to_dict()
elif isinstance(value, list):
result[f] = [
v.to_dict() if hasattr(v, "to_dict") else v
for v in value
]
else:
result[f] = value
return result
def to_json(self, indent: int | None = None) -> str:
return json.dumps(self.to_dict(), ensure_ascii=False, indent=indent)
@classmethod
def from_dict(klass, data: dict):
return klass(**{
k: v for k, v in data.items()
if k in klass.__dataclass_fields__
})
@classmethod
def from_json(klass, json_str: str):
return klass.from_dict(json.loads(json_str))
cls.to_dict = to_dict
cls.to_json = to_json
cls.from_dict = from_dict
cls.from_json = from_json
return cls
# @dataclass runs first and generates __init__/__repr__/__eq__; to_json_mixin
# then attaches to_dict/to_json/from_dict/from_json to the resulting class.
@to_json_mixin
@dataclass
class Product:
    name: str
    price: float
    stock: int
    # default_factory gives every instance its own empty list
    tags: list[str] = field(default_factory=list)
product = Product("Python Book", 35.0, 100, ["education", "programming"])
print(product.to_json(indent=2))
# Round trip: serialize to a JSON string, then rebuild a Product from it
json_str = product.to_json()
restored = Product.from_json(json_str)
print(restored)
Pro Tips
1. Class decorators vs inheritance
# Class decorator: better when adding the same feature to multiple classes
def serializable(cls):
    """Class decorator that attaches a ``to_json`` method.

    The method dumps the instance's ``vars()`` dictionary as a JSON string
    without escaping non-ASCII characters.
    """
    import json

    def to_json(self) -> str:
        state = vars(self)
        return json.dumps(state, ensure_ascii=False)

    setattr(cls, "to_json", to_json)
    return cls
@serializable
class User:
    # Gains to_json() from the decorator
    def __init__(self, name: str):
        setattr(self, "name", name)
@serializable
class Product:
    # Gains to_json() from the decorator, with no shared base class needed
    def __init__(self, name: str, price: float):
        self.name, self.price = name, price
# Inheritance: when a common interface is needed
class Serializable:
    # Base class that gives every subclass a JSON dump of its instance attributes
    def to_json(self) -> str:
        import json

        state = vars(self)
        return json.dumps(state, ensure_ascii=False)
class Admin(Serializable):
    # Inherits to_json() from Serializable
    def __init__(self, name: str, level: int):
        self.name, self.level = name, level
2. Combining with __init_subclass__
class AutoRegister:
    """Base class whose subclasses are recorded automatically by name.

    Each subclass is stored in the shared ``_subclasses`` mapping when it is
    defined; ``create`` instantiates a recorded subclass given its name.
    """

    _subclasses: dict[str, type] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Always write to the base class's mapping so there is one registry
        registry = AutoRegister._subclasses
        registry[cls.__name__] = cls

    @classmethod
    def create(cls, name: str) -> "AutoRegister":
        klass = cls._subclasses.get(name)
        if klass is not None:
            return klass()
        raise KeyError(f"Not registered: {name}")
class PluginA(AutoRegister):
    # Recorded under "PluginA" by AutoRegister.__init_subclass__
    def run(self) -> str:
        status = "Plugin A running"
        return status
class PluginB(AutoRegister):
    # Recorded under "PluginB" by AutoRegister.__init_subclass__
    def run(self) -> str:
        status = "Plugin B running"
        return status
# Subclasses are looked up by their class name and instantiated with no arguments
plugin = AutoRegister.create("PluginA")
print(plugin.run()) # Plugin A running
Summary
| Use Case | Pattern |
|---|---|
| Add methods | cls.method = new_method; return cls |
| Singleton | Store instance in closure |
| Timestamps / auditing | Wrap __init__ |
| Registry | Register class into a dict |
| Type validation | __annotations__-based __init__ wrapping |
Class decorators compose well with Python's standard decorators such as @dataclass and @property, and are great for cleanly separating cross-cutting concerns.