Skip to main content
Advertisement

Class Decorators

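Class decorators use the same `@` syntax as function decorators but are applied to classes. The decorator receives the class object after the class body has executed, and whatever it returns is bound to the class name. Usually that is the same class, modified in place. It can also be a new class, or even a plain function, as the singleton example below shows.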


Class Decorator Basics

# A class decorator is a function that takes a class and returns a modified (or new) class
def add_repr(cls):
    """Decorator that automatically adds __repr__

    The generated ``__repr__`` lists every public (non-underscore) instance
    attribute as ``name=value`` (values via ``repr``) in ``__dict__`` order.
    The class name is read from ``type(self)`` at call time, so instances of
    a subclass of the decorated class report their own class name.
    """

    def __repr__(self):
        attrs = ", ".join(
            f"{k}={v!r}"
            for k, v in vars(self).items()
            if not k.startswith("_")
        )
        # type(self) rather than the captured cls: a subclass must not
        # masquerade as its base class in debug output.
        return f"{type(self).__name__}({attrs})"

    cls.__repr__ = __repr__
    return cls


def add_eq(cls):
    """Class decorator installing value-based ``__eq__`` and ``__hash__``.

    Two instances are equal when they have exactly the same type and equal
    instance ``__dict__`` contents; any other operand type yields
    ``NotImplemented``.  The hash is built from the sorted ``(name, value)``
    pairs of the instance ``__dict__``, so every attribute value must be
    hashable.

    NOTE(review): the hash follows mutable state — mutating an instance that
    already sits in a set or serves as a dict key breaks those lookups.
    """

    def __eq__(self, other):
        # Only exact same-type comparisons are handled here.
        if type(self) is type(other):
            return vars(self) == vars(other)
        return NotImplemented

    def __hash__(self):
        ordered_items = sorted(vars(self).items())
        return hash(tuple(ordered_items))

    for dunder in (__eq__, __hash__):
        setattr(cls, dunder.__name__, dunder)
    return cls


@add_repr
@add_eq
class Point:
    """2-D point; ``__repr__``, ``__eq__`` and ``__hash__`` are generated by the decorators."""

    def __init__(self, x: float, y: float):
        # Tuple assignment stores x before y, keeping the attribute order used by repr.
        self.x, self.y = x, y


p1, p2, p3 = Point(1.0, 2.0), Point(1.0, 2.0), Point(3.0, 4.0)

print(p1)  # Point(x=1.0, y=2.0)
print(p1 == p2)  # True  (same type, same attribute values)
print(p1 == p3)  # False
print(hash(p1) == hash(p2))  # True  (equal objects hash equally)

Injecting Functionality into Classes

import time
from typing import Any


def singleton(cls):
    """Decorator that makes a class a singleton.

    The class is replaced by a factory function.  The first call constructs
    the instance with whatever arguments it receives; every later call returns
    that same object and ignores its arguments.  The factory copies the
    class's ``__name__`` and ``__doc__`` and exposes the undecorated class as
    ``_cls``.
    """
    cache: dict = {}

    def get_instance(*args, **kwargs):
        try:
            return cache[cls]
        except KeyError:
            pass
        # First call: build and remember the one instance.
        obj = cls(*args, **kwargs)
        cache[cls] = obj
        return obj

    get_instance.__name__, get_instance.__doc__ = cls.__name__, cls.__doc__
    get_instance._cls = cls  # handle on the original (undecorated) class
    return get_instance


@singleton
class DatabaseConnection:
    """Demo database connection; ``@singleton`` ensures only one instance is ever built."""

    def __init__(self, host: str = "localhost", port: int = 5432):
        self.host, self.port = host, port
        self._pool: list = []  # placeholder; not used by this demo
        print(f"[DB] Connection initialized: {host}:{port}")

    def query(self, sql: str) -> str:
        """Return a string describing *sql* as executed on this host (no real I/O)."""
        target = self.host
        return f"[{target}] Executed: {sql}"


db1 = DatabaseConnection(host="db.example.com", port=5432)
# The instance already exists, so these arguments are ignored.
db2 = DatabaseConnection(host="other.host", port=9999)

print(db1 is db2)  # True
print(db1.host)  # db.example.com


def add_timestamps(cls):
    """Decorator that adds created/updated timestamps

    Wraps ``__init__`` so every new instance gets ``created_at`` and
    ``updated_at`` (``time.time()`` epoch seconds), set after the original
    ``__init__`` has run.  Also adds two methods:

    - ``touch()``: set ``updated_at`` to the current time.
    - ``age()``: seconds elapsed since ``created_at``.
    """
    import functools

    original_init = cls.__init__

    # wraps() keeps __init__'s name/docstring/signature metadata intact.
    @functools.wraps(original_init)
    def new_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        # One clock reading so a fresh object has created_at == updated_at.
        now = time.time()
        self.created_at = now
        self.updated_at = now

    def touch(self):
        self.updated_at = time.time()

    def age(self) -> float:
        return time.time() - self.created_at

    cls.__init__ = new_init
    cls.touch = touch
    cls.age = age
    return cls


@add_timestamps
class Article:
    """A titled piece of text; creation/update timestamps come from ``@add_timestamps``."""

    def __init__(self, title: str, content: str):
        self.title, self.content = title, content


article = Article("Python Decorators", "Decorators are...")
print("Created at: " + format(article.created_at, ".0f"))
print("Age: {:.4f}s".format(article.age()))

# Let some time pass, then refresh the "updated" timestamp.
time.sleep(0.01)
article.touch()
print("Updated at: " + format(article.updated_at, ".0f"))

Auto-Wrapping Methods

import functools
import time
import logging

# Show INFO-level records on the root logger (the method-logging wrapper logs at INFO).
logging.basicConfig(level="INFO")


def log_all_methods(cls):
    """Decorator that adds logging to all public methods of a class

    Every plain function defined directly on ``cls`` whose name does not
    start with ``_`` is replaced by a wrapper. The wrapper logs the call and
    its ``time.perf_counter`` duration through the root ``logging`` logger at
    INFO level.

    Only plain functions are wrapped. ``staticmethod``/``classmethod``
    objects, properties and nested classes are left untouched, because
    replacing them with a plain function would change how they bind.
    Inherited methods are not wrapped either, since only ``vars(cls)`` is
    inspected.
    """
    import inspect

    def make_logged(m):
        # Factory function: binds `m` per method, avoiding late-binding closures.
        @functools.wraps(m)
        def logged(*args, **kwargs):
            logging.info("%s.%s called", cls.__name__, m.__name__)
            start = time.perf_counter()
            result = m(*args, **kwargs)
            elapsed = time.perf_counter() - start
            logging.info("%s.%s done (%.4fs)", cls.__name__, m.__name__, elapsed)
            return result
        return logged

    # Snapshot the namespace: we call setattr on the class while looping.
    for name, method in list(vars(cls).items()):
        if inspect.isfunction(method) and not name.startswith("_"):
            setattr(cls, name, make_logged(method))
    return cls


def validate_types(cls):
"""Decorator that automatically validates types based on __init__ type hints"""
import inspect

original_init = cls.__init__
hints = {}
try:
hints = original_init.__annotations__
except AttributeError:
pass

@functools.wraps(original_init)
def new_init(self, *args, **kwargs):
sig = inspect.signature(original_init)
bound = sig.bind(self, *args, **kwargs)
bound.apply_defaults()

for param_name, expected_type in hints.items():
if param_name == "return":
continue
if param_name in bound.arguments:
value = bound.arguments[param_name]
if not isinstance(value, expected_type):
raise TypeError(
f"{cls.__name__}.{param_name}: "
f"expected {expected_type.__name__}, "
f"got {type(value).__name__}"
)
original_init(self, *args, **kwargs)

cls.__init__ = new_init
return cls


@log_all_methods
class Calculator:
    """Two-operand arithmetic; public methods are wrapped with call logging by ``@log_all_methods``."""

    def add(self, a: float, b: float) -> float:
        """Return ``a + b``."""
        total = a + b
        return total

    def multiply(self, a: float, b: float) -> float:
        """Return ``a * b``."""
        product = a * b
        return product


calc = Calculator()
print(calc.add(a=3, b=5))  # 8
print(calc.multiply(a=4, b=6))  # 24


@validate_types
class Person:
    """A named person with an integer age; ``@validate_types`` checks the argument types."""

    def __init__(self, name: str, age: int):
        self.name, self.age = name, age

    def __repr__(self) -> str:
        return "Person(name={!r}, age={})".format(self.name, self.age)


p = Person("Alice", 30)
print(p)

# A str where an int is annotated is rejected before __init__ runs.
try:
    Person("Bob", "thirty")
except TypeError as err:
    print(f"Type error: {err}")

Registry Pattern

from typing import Type


def register(registry: dict):
    """Decorator factory that automatically registers classes in a registry

    The returned class decorator stores the class in *registry* under its
    lower-cased ``__name__``, prints a confirmation line, and returns the
    class unchanged.  A class whose lower-cased name is already present
    silently replaces the earlier entry.

    Args:
        registry: dict to populate, mapping lower-cased class name -> class.
    """
    def decorator(cls):
        name = cls.__name__.lower()
        registry[name] = cls
        # Show the registry key and the class name separately.
        print(f"Registered: {name} -> {cls.__name__}")
        return cls
    return decorator


# Plugin registry
# Registry of plugin classes, keyed by lower-cased class name.
PLUGINS: dict[str, type] = {}


@register(PLUGINS)
class TextPlugin:
    """Plugin that upper-cases its input."""

    def process(self, data: str) -> str:
        shouted = data.upper()
        return shouted


@register(PLUGINS)
class JsonPlugin:
    """Plugin that wraps its input in a JSON object under the key "processed"."""

    def process(self, data: str) -> str:
        import json

        payload = {"processed": data}
        return json.dumps(payload)


@register(PLUGINS)
class CompressPlugin:
    """Placeholder plugin: reports the input length instead of really compressing."""

    def process(self, data: str) -> str:
        char_count = len(data)
        return f"compressed({char_count} chars)"


def get_plugin(name: str):
    """Instantiate the registered plugin called *name* (case-insensitive).

    Raises:
        KeyError: no plugin is registered under that name; the message lists
            the available keys.
    """
    plugin_cls = PLUGINS.get(name.lower())
    if plugin_cls is not None:
        return plugin_cls()
    raise KeyError(f"Unknown plugin: {name}. Available: {list(PLUGINS)}")


# Look up each plugin by its registry key and run it on the same sample input.
for plugin_name in ("text", "json", "compress"):
    plugin = get_plugin(name=plugin_name)
    print(plugin.process("Hello, World!"))

Combining with dataclass

from dataclasses import dataclass, field
import json


def to_json_mixin(cls):
"""Decorator that adds JSON serialization methods"""
def to_dict(self) -> dict:
result = {}
for f in cls.__dataclass_fields__:
value = getattr(self, f)
if hasattr(value, "to_dict"):
result[f] = value.to_dict()
elif isinstance(value, list):
result[f] = [
v.to_dict() if hasattr(v, "to_dict") else v
for v in value
]
else:
result[f] = value
return result

def to_json(self, indent: int | None = None) -> str:
return json.dumps(self.to_dict(), ensure_ascii=False, indent=indent)

@classmethod
def from_dict(klass, data: dict):
return klass(**{
k: v for k, v in data.items()
if k in klass.__dataclass_fields__
})

@classmethod
def from_json(klass, json_str: str):
return klass.from_dict(json.loads(json_str))

cls.to_dict = to_dict
cls.to_json = to_json
cls.from_dict = from_dict
cls.from_json = from_json
return cls


@to_json_mixin
@dataclass
class Product:
    """Catalog entry with name, price, stock count and tags; JSON methods come from ``@to_json_mixin``."""

    name: str
    price: float
    stock: int
    # default_factory gives every instance its own list (no shared mutable default).
    tags: list[str] = field(default_factory=list)


product = Product(
    name="Python Book", price=35.0, stock=100, tags=["education", "programming"]
)
print(product.to_json(indent=2))

# Round trip: serialize compactly, then rebuild an equal Product.
json_str = product.to_json()
restored = Product.from_json(json_str)
print(restored)

Pro Tips

1. Class decorators vs inheritance

# Class decorator: better when adding the same feature to multiple classes
def serializable(cls):
    """Class decorator adding ``to_json``, which dumps the instance ``__dict__`` (non-ASCII kept as-is)."""
    import json

    def to_json(self) -> str:
        state = vars(self)
        return json.dumps(state, ensure_ascii=False)

    setattr(cls, "to_json", to_json)
    return cls


@serializable
class User:
    """Minimal user record; ``to_json`` is supplied by ``@serializable``."""

    def __init__(self, name: str):
        self.name = name


@serializable
class Product:
    """Named, priced item; ``to_json`` is supplied by ``@serializable``."""

    def __init__(self, name: str, price: float):
        self.name, self.price = name, price


# Inheritance: when a common interface is needed
class Serializable:
    """Mixin base class giving subclasses a ``to_json`` method."""

    def to_json(self) -> str:
        """Serialize the instance ``__dict__`` to JSON (non-ASCII kept as-is)."""
        import json

        state = vars(self)
        return json.dumps(state, ensure_ascii=False)


class Admin(Serializable):
    """Administrator with a name and a numeric level; ``to_json`` is inherited."""

    def __init__(self, name: str, level: int):
        self.name, self.level = name, level

2. An alternative to registry decorators: __init_subclass__

class AutoRegister:
    """Base class that auto-registers subclasses

    Every subclass, direct or indirect, is recorded by its ``__name__`` in a
    single dict stored on ``AutoRegister``.  ``create(name)`` instantiates a
    registered subclass with no arguments.  A later subclass with the same
    ``__name__`` replaces the earlier entry.
    """

    _subclasses: dict[str, type] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Write to AutoRegister's dict explicitly so every level shares one registry.
        registry = AutoRegister._subclasses
        registry[cls.__name__] = cls

    @classmethod
    def create(cls, name: str) -> "AutoRegister":
        """Instantiate the subclass registered as *name*; raise KeyError if unknown."""
        target = cls._subclasses.get(name)
        if target is not None:
            return target()
        raise KeyError(f"Not registered: {name}")


class PluginA(AutoRegister):
    """Demo plugin, registered under "PluginA" by ``AutoRegister.__init_subclass__``."""

    def run(self) -> str:
        message = "Plugin A running"
        return message


class PluginB(AutoRegister):
    """Demo plugin, registered under "PluginB" by ``AutoRegister.__init_subclass__``."""

    def run(self) -> str:
        message = "Plugin B running"
        return message


# Instantiate a subclass by its registered class name.
plugin = AutoRegister.create(name="PluginA")
print(plugin.run())  # Plugin A running

Summary

| Use case | Pattern |
|---|---|
| Add methods | `cls.method = new_method; return cls` |
| Singleton | Store the instance in a closure |
| Timestamps / auditing | Wrap `__init__` |
| Registry | Register the class in a dict |
| Type validation | Wrap `__init__` and check arguments against its `__annotations__` |

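Class decorators compose well with Python's built-in class decorators such as `@dataclass`. Stack them, placing your decorator above `@dataclass` when it needs the generated fields. They also coexist with method-level decorators like `@property` inside the class body, and are a clean way to separate cross-cutting concerns from a class's core logic.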

Advertisement