Federated Learning¶
Master federated learning with Manta through comprehensive tutorials using real datasets and production-ready algorithms. These tutorials guide you from basic concepts to advanced federated learning patterns.
What you’ll master:
MNIST Tutorial: Classic federated learning with handwritten digit classification
CIFAR-10 Tutorial: More complex federated learning with natural image classification
Custom Dataset Integration: Adapt tutorials for your own data and use cases
Advanced FL Patterns: Explore cutting-edge federated learning techniques
Why These Tutorials?¶
Real-World Relevance: These tutorials mirror actual federated learning deployments used in healthcare, finance, and IoT applications.
Production-Ready Code: All examples use production patterns you can deploy in real environments.
Comprehensive Coverage: From basic FedAvg to advanced techniques like differential privacy and secure aggregation.
Hands-On Learning: Every concept is demonstrated with working code you can modify and extend.
Tutorial Overview¶
- MNIST Tutorial
  - Tutorial Overview
  - Architecture Overview
  - Setting Up the Environment
  - Understanding Data Partitioning
  - Understanding the Neural Network
  - Examining the Federated Learning Components
  - Running the MNIST Federated Learning
  - Understanding the Results
  - Analyzing Privacy Preservation
  - Comparing with Centralized Training
  - Troubleshooting Common Issues
  - Extending the Tutorial
  - Next Steps
Learning Objectives¶
By completing these tutorials, you will:
Understand Federated Learning Fundamentals:
- How federated learning preserves data privacy
- Communication patterns between clients and servers
- Aggregation algorithms (FedAvg, FedProx, etc.)
- Handling non-IID data distributions

Master Manta's FL Implementation:
- Swarm and task architecture for federated learning
- Node coordination and communication patterns
- Real-time monitoring and result collection
- Debugging and troubleshooting FL workloads

Implement Production Patterns:
- Fault tolerance and failure recovery
- Communication efficiency optimization
- Security and privacy considerations
- Scalability patterns for large deployments

Analyze and Optimize Performance:
- Training convergence analysis
- Communication cost measurement
- Resource utilization optimization
- Comparison with centralized training
Tutorial Architecture¶
Our federated learning tutorials follow a consistent 4-component architecture:
1. Worker Nodes (Training):
class WorkerTask:
    def execute(self):
        # Load local data partition
        local_data = self.local.get_dataset("mnist")

        # Get global model from aggregator
        global_model = self.world.get_global("model_params")

        # Perform local training
        local_update = train_local_model(global_model, local_data)

        # Send update to aggregator
        self.world.set_result("model_update", local_update)
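The train_local_model helper is left to the task implementation. As a rough sketch of what it might look like with a PyTorch model, assuming a hypothetical build_model() factory that is not part of the tutorial API:

import torch

def train_local_model(global_params, local_data, lr=0.01, epochs=1):
    # Hypothetical helper: load the global weights, run a few epochs
    # of SGD on the local partition, and return the updated weights.
    model = build_model()  # assumed model factory, not part of the tutorial API
    model.load_state_dict(global_params)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()
    model.train()
    for _ in range(epochs):
        for inputs, targets in local_data:
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.step()
    return model.state_dict()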
2. Aggregator Node (Model Averaging):
class AggregatorTask:
    def execute(self):
        # Collect updates from all workers
        updates = self.world.get_results("model_update")

        # Perform federated averaging
        global_model = federated_average(updates)

        # Broadcast updated global model
        self.world.set_global("model_params", global_model)
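federated_average here denotes plain FedAvg. Assuming updates arrive as parameter dicts (for example, PyTorch state_dicts), an unweighted version could be as simple as this sketch:

def federated_average(updates):
    # Unweighted FedAvg: element-wise mean of the collected parameter dicts
    num_updates = len(updates)
    return {
        name: sum(update[name] for update in updates) / num_updates
        for name in updates[0]
    }

A weighted variant that accounts for unequal client data sizes appears under Advanced Aggregation Methods below.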
3. Test Node (Global Evaluation):
class TestTask:
    def execute(self):
        # Get latest global model
        global_model = self.world.get_global("model_params")

        # Evaluate on test dataset
        test_accuracy = evaluate_model(global_model, test_data)

        # Report evaluation results
        self.world.set_result("test_metrics", {
            "accuracy": test_accuracy,
            "round": self.get_round_number()
        })
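Similarly, evaluate_model can be any standard evaluation loop. A minimal sketch, again assuming the hypothetical build_model() factory:

import torch

def evaluate_model(global_params, test_loader):
    # Load the global weights and measure accuracy on the held-out test set
    model = build_model()  # assumed model factory, as above
    model.load_state_dict(global_params)
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            predictions = model(inputs).argmax(dim=1)
            correct += (predictions == targets).sum().item()
            total += targets.numel()
    return correct / total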
4. Scheduler Node (Coordination):
class SchedulerTask:
    def execute(self):
        # Check convergence criteria
        metrics = self.world.get_latest_result("test_metrics")
        if metrics["accuracy"] > convergence_threshold:
            return "CONVERGED"  # Stop training
        else:
            return "CONTINUE"  # Next round
Key Concepts Covered¶
Data Partitioning Strategies:
Our tutorials demonstrate different data distribution patterns:
# IID Partitioning (uniform distribution)
iid_partitions = partition_iid(dataset, num_clients=10)

# Non-IID Partitioning (realistic distribution)
non_iid_partitions = partition_non_iid(
    dataset,
    num_clients=10,
    classes_per_client=2  # Each client has 2 classes
)

# Dirichlet Partitioning (configurable skew)
dirichlet_partitions = partition_dirichlet(
    dataset,
    num_clients=10,
    alpha=0.5  # Controls the degree of non-IID skew
)
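For reference, one common way to implement Dirichlet partitioning with NumPy is sketched below; the tutorials' partition_dirichlet helper may differ in its details:

import numpy as np

def partition_dirichlet(labels, num_clients=10, alpha=0.5, seed=0):
    # For each class, sample a Dirichlet(alpha) proportion vector and
    # split that class's sample indices across clients accordingly.
    rng = np.random.default_rng(seed)
    partitions = [[] for _ in range(num_clients)]
    for class_id in np.unique(labels):
        class_idx = np.flatnonzero(labels == class_id)
        rng.shuffle(class_idx)
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        split_points = (np.cumsum(proportions)[:-1] * len(class_idx)).astype(int)
        for client_id, shard in enumerate(np.split(class_idx, split_points)):
            partitions[client_id].extend(shard.tolist())
    return partitions

With alpha=0.5 most clients end up dominated by a few classes; as alpha grows the split approaches IID.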
Communication Optimization:
Reduce communication overhead with these techniques:
# Gradient Compression
compressed_update = compress_gradients(
    local_update,
    compression_ratio=0.1
)

# Periodic Communication (local epochs)
for local_epoch in range(local_epochs):
    model = train_one_epoch(model, local_data)

# Only communicate after multiple local updates
self.world.set_result("model_update", model.parameters())

# Differential Privacy
noisy_update = add_gaussian_noise(
    local_update,
    noise_scale=privacy_budget
)
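compress_gradients is treated as a black box above. One simple realization is top-k sparsification, sketched here under the assumption that updates are NumPy arrays; the tutorials may use a different scheme:

import numpy as np

def compress_gradients(update, compression_ratio=0.1):
    # Top-k sparsification: transmit only the largest-magnitude entries
    # and their indices instead of the full dense update.
    flat = update.ravel()
    k = max(1, int(flat.size * compression_ratio))
    top_idx = np.argpartition(np.abs(flat), -k)[-k:]
    return top_idx, flat[top_idx], update.shape

def decompress_gradients(top_idx, values, shape):
    # Scatter the transmitted entries back into a zero-filled array
    flat = np.zeros(int(np.prod(shape)))
    flat[top_idx] = values
    return flat.reshape(shape)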
Advanced Aggregation Methods:
Beyond simple averaging:
# Weighted FedAvg (by data size)
def weighted_fedavg(updates, data_sizes):
    total_size = sum(data_sizes)
    weighted_params = []
    for params, size in zip(updates, data_sizes):
        weight = size / total_size
        weighted_params.append(weight * params)
    return sum(weighted_params)

# FedProx (with proximal term)
def fedprox_update(global_model, local_update, mu=0.01):
    proximal_term = mu * (local_update - global_model)
    return local_update - proximal_term

# Custom aggregation with outlier detection
def robust_aggregation(updates):
    # Remove statistical outliers before averaging
    filtered_updates = remove_outliers(updates)
    return federated_average(filtered_updates)
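remove_outliers is left unspecified above. A simple norm-based filter is sketched below, assuming each update is a flat array; more robust alternatives include coordinate-wise medians or trimmed means:

import numpy as np

def remove_outliers(updates, z_threshold=2.0):
    # Drop updates whose L2 norm deviates strongly from the cohort,
    # a crude defense against faulty or malicious clients.
    norms = np.array([np.linalg.norm(u) for u in updates])
    z_scores = (norms - norms.mean()) / (norms.std() + 1e-12)
    return [u for u, z in zip(updates, z_scores) if abs(z) <= z_threshold]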
Performance Analysis¶
Each tutorial includes comprehensive performance analysis:
Convergence Analysis:
# Track training progress
training_metrics = {
    "round": round_number,
    "global_loss": global_test_loss,
    "global_accuracy": global_test_accuracy,
    "local_losses": [worker.local_loss for worker in workers],
    "communication_cost": total_bytes_transmitted,
    "training_time": round_duration
}
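A round-over-round history of these metrics also supports a plateau-based convergence check, complementing the scheduler's accuracy threshold. A sketch, with patience and min_delta as illustrative knobs:

def has_converged(history, patience=3, min_delta=0.001):
    # Stop once global accuracy has improved by less than min_delta
    # for `patience` consecutive rounds.
    accuracies = [m["global_accuracy"] for m in history]
    if len(accuracies) <= patience:
        return False
    recent = accuracies[-(patience + 1):]
    return all(later - earlier < min_delta
               for earlier, later in zip(recent, recent[1:]))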
Expected Performance Benchmarks:
Tutorial | Final Accuracy | Training Rounds | Communication Cost
---|---|---|---
MNIST (IID) | 98.5% | 10-15 rounds | ~50MB total
MNIST (Non-IID) | 96.8% | 15-25 rounds | ~75MB total
CIFAR-10 (IID) | 85.2% | 50-100 rounds | ~500MB total
CIFAR-10 (Non-IID) | 82.1% | 100-150 rounds | ~750MB total
Comparison with Centralized Training:
Understanding the federated learning trade-offs:
Centralized MNIST Training:
- Accuracy: 99.1% (0.6 percentage points higher than FL)
- Training Time: 5 minutes (vs. 12 minutes for FL)
- Privacy: none, since all raw data is pooled on one machine (FL keeps raw data on each node)
- Scalability: single machine (FL scales across many nodes)
Real-World Considerations¶
Data Heterogeneity Simulation:
Our tutorials simulate real-world data distributions:
# Healthcare scenario: Different hospitals have different patient populations
hospital_1_data = filter_by_demographics(mnist_data, age_range="elderly")
hospital_2_data = filter_by_demographics(mnist_data, age_range="pediatric")
# IoT scenario: Different sensors have different operating conditions
sensor_1_data = add_noise(cifar_data, noise_level="urban")
sensor_2_data = add_noise(cifar_data, noise_level="rural")
System Heterogeneity:
Handle varying computational resources:
import asyncio

# Adaptive batch sizes based on node capabilities
if node.memory_gb < 4:
    batch_size = 16  # Smaller batches for constrained devices
else:
    batch_size = 64  # Larger batches for powerful devices

# Asynchronous training for unreliable connections
async def train_with_timeout():
    try:
        result = await asyncio.wait_for(
            train_local_model(),
            timeout=max_training_time
        )
        return result
    except asyncio.TimeoutError:
        # Fall back to whatever progress was checkpointed before the timeout
        return partial_result
Security and Privacy:
Production-ready privacy preservation:
import numpy as np

# Differential Privacy
def add_differential_privacy(gradients, epsilon=1.0):
    noise_scale = 2.0 / epsilon  # Noise grows as the privacy budget shrinks
    noise = np.random.laplace(0, noise_scale, gradients.shape)
    return gradients + noise

# Secure Aggregation (conceptual)
def secure_aggregation(encrypted_updates):
    # Homomorphic encryption allows computation on encrypted data
    encrypted_sum = sum(encrypted_updates)
    return decrypt(encrypted_sum)  # Only the aggregator can decrypt the sum
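Homomorphic encryption is one route; another common building block is pairwise additive masking, where each pair of clients derives a shared mask that cancels in the server's sum, so the server only ever sees the aggregate. A sketch, assuming NumPy-array updates and with derive_shared_seed standing in for a real symmetric key agreement (e.g., Diffie-Hellman):

import numpy as np

def mask_update(update, client_id, peer_ids, derive_shared_seed):
    # Clients i and j derive the same seed; i adds the mask while j
    # subtracts it, so all masks cancel when every update is summed.
    masked = update.astype(float)
    for peer_id in peer_ids:
        rng = np.random.default_rng(derive_shared_seed(client_id, peer_id))
        mask = rng.normal(size=update.shape)
        masked += mask if client_id < peer_id else -mask
    return masked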
Best Practices from Tutorials¶
Development Workflow:
1. Start Simple: Begin with IID data and basic FedAvg
2. Add Complexity Gradually: Introduce non-IID data, then optimization
3. Monitor Carefully: Watch convergence and communication patterns
4. Profile Performance: Measure training time and communication costs
5. Test Robustness: Simulate node failures and network issues
Production Deployment:
1. Validate on Realistic Data: Use real data distributions, not synthetic ones
2. Plan for Heterogeneity: Design for varying node capabilities
3. Implement Fault Tolerance: Handle node disconnections gracefully
4. Monitor Privacy: Ensure privacy guarantees are maintained
5. Scale Gradually: Start with small deployments and scale up
Debugging and Troubleshooting:
- Check Data Partitions: Ensure balanced and representative data splits (see the sketch after this list)
- Monitor Convergence: Watch for oscillating or stalled training
- Analyze Communication: Identify bottlenecks in model transmission
- Validate Aggregation: Ensure proper weight averaging
- Test Node Isolation: Verify nodes can't access each other's data
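For the partition check in particular, a quick audit like the following sketch makes skewed or unrepresentative splits obvious:

from collections import Counter

def summarize_partitions(partitions, labels):
    # Print each client's sample count and label histogram so
    # imbalances stand out immediately.
    for client_id, indices in enumerate(partitions):
        counts = Counter(int(labels[i]) for i in indices)
        print(f"client {client_id}: {len(indices)} samples, "
              f"labels={dict(sorted(counts.items()))}")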
Next Steps¶
Ready to Start?
Beginners: MNIST Tutorial - the ideal introduction to federated learning
Computer Vision Focus: CIFAR-10 Tutorial - more realistic image classification
Custom Applications: Custom Dataset Tutorial - adapt the patterns to your own domain
After Completing Tutorials:
Advanced Patterns: Multi-cluster federated learning
Framework Integration: Use with existing FL frameworks
Production Deployment: Deploy in real environments
Each tutorial includes complete working code, detailed explanations, and troubleshooting guides to ensure your success with federated learning on Manta.