SOLID Principles in Python
SOLID is a set of five core principles for object-oriented design that guide you toward maintainable, extensible code. Each letter stands for one principle.
- S: Single Responsibility Principle
- O: Open/Closed Principle
- L: Liskov Substitution Principle
- I: Interface Segregation Principle
- D: Dependency Inversion Principle
S — Single Responsibility Principle
"A class should have only one reason to change."
When a class handles too many responsibilities, changing one aspect risks breaking another.
Violation Example
# SRP violation: this single class mixes four jobs —
# holding user data, sending email, writing to a DB, and building reports.
class User:
    """User record that (wrongly) also owns email, persistence and reporting."""

    def __init__(self, name: str, email: str):
        self.name, self.email = name, email

    def send_welcome_email(self):
        """Send a welcome email — not something a data class should do."""
        recipient = self.email
        print(f"Sending welcome email to {recipient}.")
        # SMTP config, template rendering, etc. all crammed in here...

    def save_to_db(self, connection):
        """Persist via *connection* — again outside the User class's job."""
        sql = "INSERT INTO users (name, email) VALUES (?, ?)"
        params = (self.name, self.email)
        connection.execute(sql, params)

    def generate_report(self) -> str:
        """Build a text report — a third unrelated responsibility."""
        return "User report: {} ({})".format(self.name, self.email)
Improved Code
# SRP compliant: each class below owns exactly one responsibility.
class User:
    """Plain user data holder; knows nothing about email, storage or reports."""

    def __init__(self, name: str, email: str):
        self.name = name
        self.email = email

    def __repr__(self) -> str:
        return "User(name={!r}, email={!r})".format(self.name, self.email)
class EmailService:
    """Outbound email only (simulated with print)."""

    def __init__(self, smtp_host: str, smtp_port: int = 587):
        self.smtp_host = smtp_host
        self.smtp_port = smtp_port

    def send_welcome(self, user: User) -> None:
        """Print a welcome-email line tagged with the SMTP endpoint."""
        server = f"{self.smtp_host}:{self.smtp_port}"
        print(f"[{server}] Sending welcome email to {user.email}")

    def send_notification(self, user: User, message: str) -> None:
        """Print a notification line for *user*."""
        line = f"[Notification] {user.email} → {message}"
        print(line)
class UserRepository:
    """In-memory persistence for User objects, keyed by email address."""

    def __init__(self):
        self._store: dict[str, User] = {}

    def save(self, user: User) -> None:
        """Store *user* (replacing any user with the same email) and log it."""
        key = user.email
        self._store[key] = user
        print(f"User saved: {user.name}")

    def find_by_email(self, email: str) -> User | None:
        """Return the user saved under *email*, or None if there is none."""
        try:
            return self._store[email]
        except KeyError:
            return None

    def find_all(self) -> list[User]:
        """Return every saved user in insertion order."""
        return [*self._store.values()]
class UserReportGenerator:
    """Builds text reports about users; does nothing else."""

    def generate(self, user: User) -> str:
        """Return a one-line report for a single user."""
        return "User report: {} ({})".format(user.name, user.email)

    def generate_summary(self, users: list[User]) -> str:
        """Return a header line followed by one bullet line per user."""
        header = f"=== User Summary ({len(users)} users) ==="
        bullets = (f" - {u.name}: {u.email}" for u in users)
        return "\n".join([header, *bullets])
# Usage example: each collaborator handles exactly one concern.
user = User("Alice", "alice@example.com")
repo = UserRepository()
email_svc = EmailService("smtp.example.com")
reporter = UserReportGenerator()

repo.save(user)
email_svc.send_welcome(user)
print(reporter.generate(user))
O — Open/Closed Principle
"Software entities (classes, modules, functions) should be open for extension but closed for modification."
When adding new functionality, extend by adding new classes rather than modifying existing code.
Violation Example
# OCP violation: must modify this function every time a new discount type is added
def calculate_discount(price: float, discount_type: str) -> float:
if discount_type == "percentage":
return price * 0.9
elif discount_type == "fixed":
return price - 10
elif discount_type == "vip":
return price * 0.7
# Every new discount type requires adding another elif → OCP violation!
return price
Improved Code
from abc import ABC, abstractmethod
class DiscountStrategy(ABC):
    """Extension point for pricing rules: one subclass per discount.

    New discounts are added as new subclasses; existing code is untouched.
    """

    @abstractmethod
    def apply(self, price: float) -> float:
        """Return *price* with this discount applied."""

    @abstractmethod
    def describe(self) -> str:
        """Return a short human-readable label for this discount."""
class NoDiscount(DiscountStrategy):
    """Null-object strategy: the price passes through unchanged."""

    def describe(self) -> str:
        return "No discount"

    def apply(self, price: float) -> float:
        return price
class PercentageDiscount(DiscountStrategy):
    """Take *percent* percent off the price; percent must lie in (0, 100]."""

    def __init__(self, percent: float):
        in_range = 0 < percent <= 100
        if not in_range:
            raise ValueError(f"Percent must be between 0 and 100: {percent}")
        self.percent = percent

    def apply(self, price: float) -> float:
        factor = 1 - self.percent / 100
        return price * factor

    def describe(self) -> str:
        return format(self.percent, ".0f") + "% discount"
class FixedDiscount(DiscountStrategy):
    """Subtract a fixed dollar *amount*; the result never drops below zero.

    Raises:
        ValueError: if *amount* is negative (or NaN). A negative "discount"
            would increase the price. PercentageDiscount validates its input
            the same way.
    """

    def __init__(self, amount: float):
        # `not amount >= 0` rejects negatives and also NaN, for which every
        # comparison is False.
        if not amount >= 0:
            raise ValueError(f"Amount must be non-negative: {amount}")
        self.amount = amount

    def apply(self, price: float) -> float:
        # Clamp so a discount larger than the price yields 0.0, not a negative price.
        return max(0.0, price - self.amount)

    def describe(self) -> str:
        return f"${self.amount:,.2f} off"
class VIPDiscount(DiscountStrategy):
    """VIP customers pay 70% of the price — added without modifying existing code."""

    _PAYABLE_FRACTION = 0.7  # the share of the price a VIP pays (30% off)

    def apply(self, price: float) -> float:
        return price * self._PAYABLE_FRACTION

    def describe(self) -> str:
        return "VIP 30% discount"
class SeasonalDiscount(DiscountStrategy):
    """Seasonal percentage discount — added without modifying existing code.

    Args:
        percent: percentage taken off the price; must lie in (0, 100].
        season: label used in the description, e.g. "Summer".

    Raises:
        ValueError: if *percent* is outside (0, 100], the same rule
            PercentageDiscount enforces.
    """

    def __init__(self, percent: float, season: str):
        # Without this check, percent > 100 produces negative prices and
        # percent < 0 silently raises the price.
        if not (0 < percent <= 100):
            raise ValueError(f"Percent must be between 0 and 100: {percent}")
        self.percent = percent
        self.season = season

    def apply(self, price: float) -> float:
        return price * (1 - self.percent / 100)

    def describe(self) -> str:
        return f"{self.season} season {self.percent:.0f}% special"
class PriceCalculator:
    """Prices items with whatever DiscountStrategy it is given.

    OCP compliant: new strategies need no changes here.
    """

    def __init__(self, strategy: DiscountStrategy):
        self.strategy = strategy

    def calculate(self, price: float) -> float:
        """Return *price* after the configured strategy is applied."""
        strategy = self.strategy
        return strategy.apply(price)

    def receipt(self, price: float) -> str:
        """Return a one-line receipt: original, discount label, final and savings."""
        final = self.calculate(price)
        saved = price - final
        label = self.strategy.describe()
        head = f"Original: ${price:,.2f} | {label} | "
        tail = f"Final: ${final:,.2f} (saved: ${saved:,.2f})"
        return head + tail
# Using various discount strategies: every strategy plugs into the same calculator
original_price = 100.0
strategies = [
    NoDiscount(),
    PercentageDiscount(10),
    FixedDiscount(15.0),
    VIPDiscount(),
    SeasonalDiscount(25, "Summer"),
]
for strategy in strategies:
    print(PriceCalculator(strategy).receipt(original_price))
L — Liskov Substitution Principle
"Subtypes must be substitutable for their base types."
A child class must be usable anywhere the parent class is expected without breaking the caller's assumptions. In other words, a child class must honor the parent's contract: it may not strengthen the parent's preconditions, weaken its postconditions, or break its invariants.
Violation Example
class Rectangle:
    """Mutable rectangle whose width and height can be set independently."""

    def __init__(self, width: float, height: float):
        self._width = width
        self._height = height

    def _get_width(self) -> float:
        return self._width

    def _set_width(self, value: float) -> None:
        self._width = value

    def _get_height(self) -> float:
        return self._height

    def _set_height(self, value: float) -> None:
        self._height = value

    # Explicit property() objects. Subclasses can still derive new setters
    # through ``Rectangle.width.setter`` / ``Rectangle.height.setter``.
    width = property(_get_width, _set_width)
    height = property(_get_height, _set_height)

    def area(self) -> float:
        return self._width * self._height
class Square(Rectangle):
    """LSP violation: a Square cannot stand in for a Rectangle.

    Setting either side silently changes the other. This breaks the
    Rectangle expectation that width and height vary independently.
    """

    @Rectangle.width.setter
    def width(self, value: float) -> None:
        # Square, so the height follows the width
        self._width = self._height = value

    @Rectangle.height.setter
    def height(self, value: float) -> None:
        # Square, so the width follows the height
        self._width = self._height = value
def resize_rectangle(rect: Rectangle, width: float, height: float) -> None:
    """Resize *rect*, then check that its area equals width * height.

    A plain Rectangle passes. The LSP-violating Square fails the assertion,
    because its second setter overwrites the first.
    """
    rect.width, rect.height = width, height
    expected_area, actual_area = width * height, rect.area()
    print(f"Expected area: {expected_area}, Actual area: {actual_area}")
    assert actual_area == expected_area, "LSP violation!"
rect = Rectangle(4, 5)
resize_rectangle(rect, width=6, height=3)  # OK: expected area 18, actual area 18

square = Square(4, 4)
# resize_rectangle(square, width=6, height=3)  # LSP violation: actual area 9 (3×3), expected 18 → AssertionError
Improved Code
from abc import ABC, abstractmethod
import math
class Shape(ABC):
    """Common base of the shape hierarchy; every subclass is fully substitutable."""

    @abstractmethod
    def area(self) -> float:
        """Return the enclosed area."""

    @abstractmethod
    def perimeter(self) -> float:
        """Return the boundary length."""

    def describe(self) -> str:
        """Return '<ClassName>: area=…, perimeter=…' with two decimals each."""
        name = self.__class__.__name__
        return f"{name}: area={self.area():.2f}, perimeter={self.perimeter():.2f}"
class Rectangle(Shape):
    """Immutable rectangle; "resizing" builds a new object instead of mutating."""

    def __init__(self, width: float, height: float):
        if any(side <= 0 for side in (width, height)):
            raise ValueError("Width and height must be positive.")
        self._width, self._height = width, height

    @property
    def width(self) -> float:
        """Horizontal side length (read-only)."""
        return self._width

    @property
    def height(self) -> float:
        """Vertical side length (read-only)."""
        return self._height

    def area(self) -> float:
        return self._width * self._height

    def perimeter(self) -> float:
        return (self._width + self._height) * 2

    def with_size(self, width: float, height: float) -> "Rectangle":
        """Return a fresh Rectangle with the given dimensions; self is unchanged."""
        return Rectangle(width, height)
class Square(Shape):
    """Square is a sibling of Rectangle under Shape, not a subclass of it.

    As a result, there is no Rectangle contract it could break.
    """

    def __init__(self, side: float):
        if side <= 0:
            raise ValueError("Side length must be positive.")
        self._side = side

    @property
    def side(self) -> float:
        """Side length (read-only)."""
        return self._side

    def area(self) -> float:
        return self._side ** 2

    def perimeter(self) -> float:
        return self._side * 4

    def with_side(self, side: float) -> "Square":
        """Return a fresh Square with the given side; self is unchanged."""
        return Square(side)
class Circle(Shape):
    """Circle defined by its radius.

    Like Rectangle and Square, it exposes its defining dimension through a
    read-only property and offers a copy-with-new-size method instead of setters.

    Raises:
        ValueError: if *radius* is not positive.
    """

    def __init__(self, radius: float):
        if radius <= 0:
            raise ValueError("Radius must be positive.")
        self._radius = radius

    @property
    def radius(self) -> float:
        """Radius (read-only)."""
        return self._radius

    def area(self) -> float:
        return math.pi * self._radius ** 2

    def perimeter(self) -> float:
        return 2 * math.pi * self._radius

    def with_radius(self, radius: float) -> "Circle":
        """Return a fresh Circle with the given radius; self is unchanged."""
        return Circle(radius)
def print_shape_info(shape: Shape) -> None:
    """Print the summary of any Shape; every subclass works here (LSP)."""
    summary = shape.describe()
    print(summary)
shapes: list[Shape] = [Rectangle(6, 4), Square(5), Circle(3)]
for shape in shapes:
    print_shape_info(shape)  # every Shape subclass substitutes cleanly
I — Interface Segregation Principle
"Clients should not be forced to depend on interfaces they do not use."
Many small, client-specific interfaces are better than one large, general-purpose interface.
Violation Example
from abc import ABC, abstractmethod
# ISP violation: one "fat" interface forces every implementer to provide all six methods
class WorkerInterface(ABC):
    """Fat interface mixing basic needs, coding, management and reporting."""

    @abstractmethod
    def work(self):
        """Do general work."""

    @abstractmethod
    def eat(self):
        """Eat a meal."""

    @abstractmethod
    def sleep(self):
        """Sleep."""

    @abstractmethod
    def code(self):
        """Write code."""

    @abstractmethod
    def manage_team(self):
        """Manage a team."""

    @abstractmethod
    def write_report(self):
        """Write a report."""
class Robot(WorkerInterface):
    """Worker forced by the fat interface to stub out meaningless methods."""

    def work(self):
        return "Robot is working."

    # Robots neither eat nor sleep, yet the interface forces these — ISP violation!
    def eat(self):
        raise NotImplementedError("Robots don't eat.")

    def sleep(self):
        raise NotImplementedError("Robots don't sleep.")

    # More forced stubs that silently do nothing.
    def code(self):
        return None

    def manage_team(self):
        return None

    def write_report(self):
        return None
Improved Code
from abc import ABC, abstractmethod
from typing import Protocol
# ISP compliant: the fat interface split into small, single-purpose Protocols
class Workable(Protocol):
    """Anything that can do general work."""

    def work(self) -> str:
        """Return a description of the work being done."""
class Eatable(Protocol):
    """Anything that eats."""

    def eat(self) -> str:
        """Return a description of the meal."""
class Sleepable(Protocol):
    """Anything that sleeps."""

    def sleep(self) -> str:
        """Return a description of the rest."""
class Codeable(Protocol):
    """Anything that can write code in a given language."""

    def code(self, language: str) -> str:
        """Return a description of coding in *language*."""
class Manageable(Protocol):
    """Anything that can manage a team."""

    def manage_team(self, team_size: int) -> str:
        """Return a description of managing *team_size* people."""
class Developer:
    """Developer: structurally satisfies Workable, Eatable, Sleepable and Codeable only."""

    def __init__(self, name: str):
        self.name = name

    def _doing(self, activity: str) -> str:
        # Shared "<name> is <activity>..." phrasing for every method below.
        return f"{self.name} is {activity}..."

    def work(self) -> str:
        return self._doing("developing")

    def eat(self) -> str:
        return self._doing("eating")

    def sleep(self) -> str:
        return self._doing("sleeping")

    def code(self, language: str) -> str:
        return self._doing(f"coding in {language}")
class Manager:
    """Manager: structurally satisfies Workable, Eatable, Sleepable and Manageable (no code())."""

    def __init__(self, name: str):
        self.name = name

    def _doing(self, activity: str) -> str:
        # Shared "<name> is <activity>..." phrasing for every method below.
        return f"{self.name} is {activity}..."

    def work(self) -> str:
        return self._doing("coordinating tasks")

    def eat(self) -> str:
        return self._doing("at a lunch meeting")

    def sleep(self) -> str:
        return self._doing("sleeping")

    def manage_team(self, team_size: int) -> str:
        return self._doing(f"managing a team of {team_size}")
class Robot:
    """Robot: implements only work() and code() — no forced eat/sleep stubs."""

    def __init__(self, model: str):
        self.model = model

    def work(self) -> str:
        return "{} is running...".format(self.model)

    def code(self, language: str) -> str:
        return "{} is auto-generating {} code...".format(self.model, language)
def assign_work(worker: Workable) -> None:
    """Print what *worker* is doing; only work() is required of it."""
    report = worker.work()
    print(report)
def assign_coding_task(coder: Codeable, language: str) -> None:
    """Print the result of asking *coder* to write *language*; only code() is required."""
    outcome = coder.code(language)
    print(outcome)
dev, manager, robot = Developer("Alice"), Manager("Bob"), Robot("GPT-Bot")

for worker in (dev, manager, robot):
    assign_work(worker)  # all three provide work()

assign_coding_task(dev, language="Python")
assign_coding_task(robot, language="TypeScript")
# assign_coding_task(manager, "Java")  # Manager has no code(): a static type checker rejects this, and at runtime it raises AttributeError
D — Dependency Inversion Principle
"High-level modules should not depend on low-level modules. Both should depend on abstractions."
Depending on abstractions (interfaces) rather than concrete implementations makes code more flexible and easier to test.
Violation Example
# DIP violation: the high-level OrderService below constructs this concrete low-level class itself
class MySQLDatabase:
    """Concrete MySQL stand-in (simulated with print)."""

    def save(self, data: dict) -> None:
        message = f"Saving to MySQL: {data}"
        print(message)
class EmailNotifier:
    """Concrete email sender (simulated with print)."""

    def notify(self, message: str) -> None:
        text = f"Sending email: {message}"
        print(text)
class OrderService:
    """High-level order workflow hard-wired to concrete MySQL and email classes."""

    def __init__(self):
        # Directly instantiating concrete classes — DIP violation!
        self.db, self.notifier = MySQLDatabase(), EmailNotifier()

    def place_order(self, order: dict) -> None:
        """Persist *order*, then announce it."""
        self.db.save(order)
        summary = f"Order placed: {order}"
        self.notifier.notify(summary)

# Swapping in another database or notifier means editing this class
Improved Code
from abc import ABC, abstractmethod
from typing import Any
# Abstraction layer — the interfaces that both high- and low-level code depend on
class Database(ABC):
    """Storage abstraction: save dict records and look them up with a query dict."""

    @abstractmethod
    def save(self, data: dict[str, Any]) -> None:
        """Persist one record."""

    @abstractmethod
    def find(self, query: dict[str, Any]) -> list[dict[str, Any]]:
        """Return the records matching *query*."""
class Notifier(ABC):
    """Messaging abstraction: deliver *message* to *recipient* over some channel."""

    @abstractmethod
    def send(self, recipient: str, message: str) -> None:
        """Deliver *message* to *recipient*."""
# Low-level implementations
class MySQLDatabase(Database):
    """Simulated MySQL backend: keeps rows in a list and logs each save.

    Rows are stored and returned as shallow copies. Later changes to the
    caller's dict do not alter "persisted" data, and callers cannot mutate
    the store through find() results. Nested containers (e.g. an "items"
    list) are still shared.
    """

    def __init__(self, host: str, port: int = 3306):
        self.host = host
        self.port = port
        self._data: list[dict] = []

    def save(self, data: dict[str, Any]) -> None:
        # Snapshot the record instead of aliasing the caller's dict.
        self._data.append(data.copy())
        print(f"[MySQL {self.host}:{self.port}] Saved: {data}")

    def find(self, query: dict[str, Any]) -> list[dict[str, Any]]:
        """Return copies of rows whose fields equal every key/value in *query*."""
        return [row.copy() for row in self._data
                if all(row.get(k) == v for k, v in query.items())]
class MongoDatabase(Database):
    """Simulated MongoDB backend; swapping it in needs no OrderService changes.

    Documents are grouped by their "_collection" field (default "default").
    They are stored and returned as shallow copies, so caller mutations and
    stored data stay independent.
    """

    def __init__(self, uri: str):
        self.uri = uri
        self._collections: dict[str, list] = {}

    def save(self, data: dict[str, Any]) -> None:
        collection = data.get("_collection", "default")
        # Snapshot the document instead of aliasing the caller's dict.
        self._collections.setdefault(collection, []).append(data.copy())
        print(f"[MongoDB {self.uri}] Saved to {collection}: {data}")

    def find(self, query: dict[str, Any]) -> list[dict[str, Any]]:
        """Return copies of matching documents, searching every collection."""
        return [doc.copy()
                for docs in self._collections.values()
                for doc in docs
                if all(doc.get(k) == v for k, v in query.items())]
class InMemoryDatabase(Database):
    """In-memory DB for testing.

    Both save() and find() work with shallow copies. The stored rows can
    therefore not be changed from outside, either through the dict passed
    to save() or through the dicts returned by find().
    """

    def __init__(self):
        self._store: list[dict] = []

    def save(self, data: dict[str, Any]) -> None:
        self._store.append(data.copy())

    def find(self, query: dict[str, Any]) -> list[dict[str, Any]]:
        """Return copies of rows whose fields equal every key/value in *query*."""
        # Copy on the way out too; returning the stored dicts would undo the
        # protection that save() provides.
        return [row.copy() for row in self._store
                if all(row.get(k) == v for k, v in query.items())]
class EmailNotifier(Notifier):
    """Notifier that 'emails' the recipient (simulated with print)."""

    def send(self, recipient: str, message: str) -> None:
        print("[Email → {}] {}".format(recipient, message))
class SMSNotifier(Notifier):
    """Notifier that 'texts' the recipient (simulated with print)."""

    def send(self, recipient: str, message: str) -> None:
        print("[SMS → {}] {}".format(recipient, message))
class SlackNotifier(Notifier):
    """Notifier that posts to a Slack channel, mentioning the recipient (simulated)."""

    def __init__(self, channel: str):
        self.channel = channel

    def send(self, recipient: str, message: str) -> None:
        prefix = f"[Slack #{self.channel}]"
        print(f"{prefix} @{recipient}: {message}")
# High-level module — depends only on abstractions
class OrderService:
    """Order workflow that depends only on the Database and Notifier abstractions (DIP)."""

    def __init__(self, db: Database, notifier: Notifier):
        # Dependency Injection: collaborators are supplied from outside.
        self._db, self._notifier = db, notifier

    def place_order(self, customer: str, items: list[str], total: float) -> str:
        """Save a confirmed order, notify the customer and return a summary line."""
        order = dict(customer=customer, items=items, total=total, status="confirmed")
        self._db.save(order)
        self._notifier.send(customer, f"Order confirmed! Total: ${total:,.2f}")
        item_list = ", ".join(items)
        return f"Order placed: {customer} — {item_list}"

    def get_customer_orders(self, customer: str) -> list[dict]:
        """Return every stored order whose "customer" field equals *customer*."""
        criteria = {"customer": customer}
        return self._db.find(criteria)
# Production environment
prod_service = OrderService(db=MySQLDatabase("db.example.com"), notifier=EmailNotifier())
prod_service.place_order(customer="Alice", items=["Python Book", "Mouse"], total=55.0)

# Test environment — no real DB or email involved
test_service = OrderService(db=InMemoryDatabase(), notifier=SlackNotifier("alerts"))
test_service.place_order(customer="TestUser", items=["Product A"], total=10.0)

# MySQL storage with SMS notifications — recombined without touching OrderService
mixed_service = OrderService(db=MySQLDatabase("mysql.local"), notifier=SMSNotifier())
mixed_service.place_order(customer="Bob", items=["Chair"], total=120.0)
Real-World Example: Applying All SOLID Principles
from abc import ABC, abstractmethod
from typing import Protocol
import json
# --- Interfaces (I: kept small) ---
class Storable(Protocol):
def save(self, key: str, data: str) -> None: ...
def load(self, key: str) -> str | None: ...
class Serializable(Protocol):
    """Two-way conversion between dicts and text."""

    def serialize(self, obj: dict) -> str:
        """Encode *obj* as text."""

    def deserialize(self, data: str) -> dict:
        """Decode *data* back into a dict."""
class Validatable(Protocol):
    """Checks a dict and reports problems."""

    def validate(self, data: dict) -> list[str]:
        """Return error messages for *data*; an empty list means valid."""
# --- Low-level implementations (D: depend on abstractions) ---
class FileStorage:
"""File storage implementation"""
def __init__(self, base_dir: str = "/tmp"):
self.base_dir = base_dir
self._store: dict[str, str] = {} # Simulation
def save(self, key: str, data: str) -> None:
self._store[key] = data
print(f"[File] Saved to {self.base_dir}/{key}")
def load(self, key: str) -> str | None:
return self._store.get(key)
class JSONSerializer:
    """Converts dicts to and from JSON text, keeping non-ASCII characters as-is."""

    def serialize(self, obj: dict) -> str:
        encoder = json.JSONEncoder(ensure_ascii=False)
        return encoder.encode(obj)

    def deserialize(self, data: str) -> dict:
        return json.loads(data)
class ProductValidator:
    """Product validation (S: only handles validation).

    validate() returns human-readable error messages; an empty list means the
    product data is valid.
    """

    def validate(self, data: dict) -> list[str]:
        errors: list[str] = []
        if not data.get("name"):
            errors.append("Product name is required.")

        price = data.get("price")
        # bool is a subclass of int, so True/False must be excluded explicitly.
        # The chained comparison rejects NaN (all comparisons are False) and
        # infinity. Python compares int and float exactly, so huge ints are fine.
        price_ok = (
            isinstance(price, (int, float))
            and not isinstance(price, bool)
            and 0 < price < float("inf")
        )
        if not price_ok:
            errors.append("Price must be a positive number.")

        stock = data.get("stock")
        stock_ok = isinstance(stock, int) and not isinstance(stock, bool) and stock >= 0
        if not stock_ok:
            errors.append("Stock must be a non-negative integer.")
        return errors
# --- High-level module (O: open for extension, D: depends on abstractions) ---
class ProductService:
    """Product workflow built only on the Storable/Serializable/Validatable abstractions."""

    def __init__(
        self,
        storage: Storable,
        serializer: Serializable,
        validator: Validatable,
    ):
        self._storage = storage
        self._serializer = serializer
        self._validator = validator

    @staticmethod
    def _key(name: str) -> str:
        # Storage key layout shared by create_product and get_product.
        return f"product:{name}"

    def create_product(self, product_data: dict) -> dict:
        """Validate, serialize and store *product_data*; return it unchanged.

        Raises ValueError listing the validator's messages if validation fails.
        """
        problems = self._validator.validate(product_data)
        if problems:
            raise ValueError(f"Validation failed: {problems}")
        key = self._key(product_data["name"])
        self._storage.save(key, self._serializer.serialize(product_data))
        print(f"Product registered: {product_data['name']}")
        return product_data

    def get_product(self, name: str) -> dict | None:
        """Return the stored product called *name*, or None if it was never saved."""
        raw = self._storage.load(self._key(name))
        return None if raw is None else self._serializer.deserialize(raw)
# Wiring (Dependency Injection): concrete parts are chosen here, not inside ProductService
service = ProductService(
    storage=FileStorage("/data/products"),
    serializer=JSONSerializer(),
    validator=ProductValidator(),
)

try:
    for payload in (
        {"name": "Python Book", "price": 35.0, "stock": 100},
        {"name": "Keyboard", "price": 150.0, "stock": 50},
    ):
        service.create_product(payload)
except ValueError as e:
    print(f"Error: {e}")

product = service.get_product("Python Book")
print(f"Retrieved: {product}")

# Invalid data must be rejected by the validator
try:
    service.create_product({"name": "", "price": -100, "stock": -1})
except ValueError as e:
    print(f"Expected error: {e}")
Pro Tips
1. SOLID is a guideline, not a rule
# You don't always need to strictly follow every principle
# Excessive abstraction in small scripts or simple code is harmful
# Bad: unnecessary abstraction for simple config loading
class ConfigLoader(ABC):
    """Needless interface: it will only ever have one trivial implementation."""

    @abstractmethod
    def load(self) -> dict:
        """Return configuration values."""
class FileConfigLoader(ConfigLoader):
    """The lone implementation — despite the name, it returns a hard-coded dict."""

    def load(self) -> dict:
        return dict(host="localhost")
# Good: a simple function is enough for simple cases
def load_config() -> dict:
    """Return the (hard-coded) configuration as a fresh dict."""
    return dict(host="localhost")
2. Dependency injection container pattern for DIP
class Container:
    """Simple dependency injection container mapping keys to providers.

    ``bind`` registers a provider. If the provider is callable (a class or a
    factory function), ``make`` calls it with no arguments on every request;
    otherwise the bound object itself is returned.

    ``bind_instance`` registers an object that ``make`` always returns as-is,
    even when it is callable. Without it, a function or an object with
    ``__call__`` could never be registered as a service itself.
    """

    def __init__(self):
        self._bindings: dict = {}
        self._instances: set = set()  # keys whose binding is returned without calling

    def bind(self, interface, implementation) -> None:
        """Register *implementation* (class, factory or plain value) under *interface*."""
        self._bindings[interface] = implementation
        self._instances.discard(interface)  # a rebind replaces any instance binding

    def bind_instance(self, interface, instance) -> None:
        """Register *instance* under *interface*; make() returns it unchanged."""
        self._bindings[interface] = instance
        self._instances.add(interface)

    def make(self, interface):
        """Resolve *interface*.

        Raises:
            KeyError: if nothing is registered under *interface*.
        """
        try:
            impl = self._bindings[interface]
        except KeyError:
            raise KeyError(f"Unregistered interface: {interface}") from None
        if interface in self._instances or not callable(impl):
            return impl
        return impl()
# Register: a class yields a new instance per make(); a lambda acts as a factory
container = Container()
container.bind("db", InMemoryDatabase)
container.bind("notifier", lambda: SlackNotifier("general"))

# Resolve
db, notifier = container.make("db"), container.make("notifier")
3. Structural subtyping with Protocol
from typing import Protocol, runtime_checkable
@runtime_checkable
class Drawable(Protocol):
    """Anything that can render itself and be scaled; matched structurally."""

    def draw(self) -> str:
        """Return a textual rendering."""

    def resize(self, factor: float) -> None:
        """Scale in place by *factor*."""
class Circle:
    """Circle that satisfies Drawable without inheriting from it."""

    def __init__(self, r: float):
        self.r = r

    def draw(self) -> str:
        return "Circle(r={})".format(self.r)

    def resize(self, factor: float) -> None:
        """Scale the radius in place."""
        self.r = self.r * factor
# Circle never inherits from Drawable, yet satisfies it structurally
c = Circle(5)
conforms = isinstance(c, Drawable)
print(conforms)  # True (runtime_checkable)
Summary
| Principle | Key Concept | Effect |
|---|---|---|
| SRP | Single responsibility | Minimizes scope of change |
| OCP | Abstraction & extension | Add features without modifying existing code |
| LSP | Correct inheritance | Guarantees safe polymorphism |
| ISP | Small interfaces | Eliminates unnecessary dependencies |
| DIP | Depend on abstractions (via dependency injection) | Testable and flexibly replaceable |
The ultimate goal of SOLID principles is to reduce the cost of change. They help you design systems where new requirements can be met with minimal code modifications.