Federated Learning

Master federated learning with Manta through comprehensive tutorials using real datasets and production-ready algorithms. These tutorials guide you from basic concepts to advanced federated learning patterns.

What you’ll master:

  • MNIST Tutorial: Classic federated learning with handwritten digit classification

  • CIFAR-10 Tutorial: More complex federated learning with natural image classification

  • Custom Dataset Integration: Adapt tutorials for your own data and use cases

  • Advanced FL Patterns: Explore cutting-edge federated learning techniques

Why These Tutorials?

Real-World Relevance: These tutorials mirror actual federated learning deployments used in healthcare, finance, and IoT applications.

Production-Ready Code: All examples use production patterns you can deploy in real environments.

Comprehensive Coverage: From basic FedAvg to advanced techniques like differential privacy and secure aggregation.

Hands-On Learning: Every concept is demonstrated with working code you can modify and extend.

Tutorial Overview

Learning Objectives

By completing these tutorials, you will:

Understand Federated Learning Fundamentals:
  • How federated learning preserves data privacy
  • Communication patterns between clients and servers
  • Aggregation algorithms (FedAvg, FedProx, etc.)
  • Handling non-IID data distributions

Master Manta’s FL Implementation:
  • Swarm and task architecture for federated learning
  • Node coordination and communication patterns
  • Real-time monitoring and result collection
  • Debugging and troubleshooting FL workloads

Implement Production Patterns:
  • Fault tolerance and failure recovery
  • Communication efficiency optimization
  • Security and privacy considerations
  • Scalability patterns for large deployments

Analyze and Optimize Performance:
  • Training convergence analysis
  • Communication cost measurement
  • Resource utilization optimization
  • Comparison with centralized training

Tutorial Architecture

Our federated learning tutorials follow a consistent 4-component architecture:

1. Worker Nodes (Training):

class WorkerTask:
    def execute(self):
        # Load local data partition
        local_data = self.local.get_dataset("mnist")

        # Get global model from aggregator
        global_model = self.world.get_global("model_params")

        # Perform local training
        local_update = train_local_model(global_model, local_data)

        # Send update to aggregator
        self.world.set_result("model_update", local_update)

2. Aggregator Node (Model Averaging):

class AggregatorTask:
    def execute(self):
        # Collect updates from all workers
        updates = self.world.get_results("model_update")

        # Perform federated averaging
        global_model = federated_average(updates)

        # Broadcast updated global model
        self.world.set_global("model_params", global_model)

3. Test Node (Global Evaluation):

class TestTask:
    def execute(self):
        # Get latest global model
        global_model = self.world.get_global("model_params")

        # Evaluate on test dataset
        test_accuracy = evaluate_model(global_model, test_data)

        # Report evaluation results
        self.world.set_result("test_metrics", {
            "accuracy": test_accuracy,
            "round": self.get_round_number()
        })

4. Scheduler Node (Coordination):

class SchedulerTask:
    def execute(self):
        # Check convergence criteria
        metrics = self.world.get_latest_result("test_metrics")

        if metrics["accuracy"] > convergence_threshold:
            return "CONVERGED"  # Stop training
        else:
            return "CONTINUE"   # Next round
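
The aggregator above calls federated_average without defining it. A minimal sketch of what such a helper might look like, assuming each worker update arrives as a list of NumPy arrays (one per model layer); the actual tutorial code may represent parameters differently:

import numpy as np

def federated_average(updates):
    # updates: list of client updates; each update is a list of NumPy
    # arrays (one per layer), with identical shapes across clients
    num_clients = len(updates)
    return [
        np.sum([client[layer] for client in updates], axis=0) / num_clients
        for layer in range(len(updates[0]))
    ]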

Key Concepts Covered

Data Partitioning Strategies:

Our tutorials demonstrate different data distribution patterns:

# IID Partitioning (uniform distribution)
iid_partitions = partition_iid(dataset, num_clients=10)

# Non-IID Partitioning (realistic distribution)
non_iid_partitions = partition_non_iid(
    dataset,
    num_clients=10,
    classes_per_client=2  # Each client has 2 classes
)

# Dirichlet Partitioning (configurable skew)
dirichlet_partitions = partition_dirichlet(
    dataset,
    num_clients=10,
    alpha=0.5  # Smaller alpha -> more skewed (more non-IID) splits
)
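
The partitioning helpers above are provided by the tutorial code. As an illustration only (not Manta's actual partition_dirichlet), a Dirichlet partitioner could be sketched like this, assuming labels is a NumPy array of integer class labels:

import numpy as np

def dirichlet_partition(labels, num_clients=10, alpha=0.5, seed=0):
    # Split sample indices across clients with Dirichlet-distributed
    # class proportions; smaller alpha -> more skewed (more non-IID) splits
    rng = np.random.default_rng(seed)
    num_classes = int(labels.max()) + 1
    client_indices = [[] for _ in range(num_clients)]
    for c in range(num_classes):
        idx = np.where(labels == c)[0]
        rng.shuffle(idx)
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        cut_points = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client_id, chunk in enumerate(np.split(idx, cut_points)):
            client_indices[client_id].extend(chunk.tolist())
    return client_indices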

Communication Optimization:

Reduce communication overhead with these techniques:

# Gradient Compression
compressed_update = compress_gradients(
    local_update,
    compression_ratio=0.1
)

# Periodic Communication (local epochs)
for local_epoch in range(local_epochs):
    model = train_one_epoch(model, local_data)

# Only communicate after multiple local updates
self.world.set_result("model_update", model.parameters())

# Differential Privacy: noise scale derived from the privacy budget
# (a smaller epsilon, i.e. a stricter budget, requires more noise)
noisy_update = add_gaussian_noise(
    local_update,
    noise_scale=2.0 / privacy_budget  # sensitivity / epsilon, as in the DP example later on this page
)
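
compress_gradients above is another tutorial helper. One common realization is top-k sparsification, where only the largest-magnitude entries are transmitted; a sketch under that assumption (the helper names and the flat NumPy representation are illustrative):

import numpy as np

def compress_gradients(gradient, compression_ratio=0.1):
    # Keep only the top fraction of entries by magnitude and transmit
    # their indices and values instead of the dense tensor
    flat = gradient.ravel()
    k = max(1, int(flat.size * compression_ratio))
    top_idx = np.argpartition(np.abs(flat), -k)[-k:]
    return {"shape": gradient.shape, "indices": top_idx, "values": flat[top_idx]}

def decompress_gradients(compressed):
    # Rebuild a dense gradient with zeros at the dropped positions
    flat = np.zeros(int(np.prod(compressed["shape"])))
    flat[compressed["indices"]] = compressed["values"]
    return flat.reshape(compressed["shape"])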

Advanced Aggregation Methods:

Beyond simple averaging:

# Weighted FedAvg (by data size)
def weighted_fedavg(updates, data_sizes):
    total_size = sum(data_sizes)
    weighted_params = []

    for params, size in zip(updates, data_sizes):
        weight = size / total_size
        weighted_params.append(weight * params)

    return sum(weighted_params)

# FedProx (simplified): pull the local update back toward the global model,
# mimicking the effect of FedProx's proximal regularization term
def fedprox_update(global_model, local_update, mu=0.01):
    proximal_term = mu * (local_update - global_model)
    return local_update - proximal_term

# Custom aggregation with outlier detection
def robust_aggregation(updates):
    # Remove statistical outliers
    filtered_updates = remove_outliers(updates)
    return federated_average(filtered_updates)
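
remove_outliers is left abstract above. One concrete robust choice is a coordinate-wise trimmed mean, which drops the most extreme client values at every parameter position before averaging; a sketch assuming updates are same-shaped NumPy arrays:

import numpy as np

def trimmed_mean_aggregation(updates, trim_ratio=0.1):
    # updates: list of same-shaped NumPy arrays, one per client;
    # trim_ratio: fraction of clients dropped from each end per coordinate
    stacked = np.stack(updates)              # shape: (num_clients, ...)
    k = int(len(updates) * trim_ratio)
    sorted_vals = np.sort(stacked, axis=0)   # sort each coordinate across clients
    if k > 0:
        sorted_vals = sorted_vals[k:-k]      # discard the extremes
    return sorted_vals.mean(axis=0)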

Performance Analysis

Each tutorial includes comprehensive performance analysis:

Convergence Analysis:

# Track training progress
training_metrics = {
    "round": round_number,
    "global_loss": global_test_loss,
    "global_accuracy": global_test_accuracy,
    "local_losses": [worker.local_loss for worker in workers],
    "communication_cost": total_bytes_transmitted,
    "training_time": round_duration
}
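
A rough way to fill in communication_cost is to estimate it from the model's parameter count, assuming every worker uploads one dense float32 update and downloads one global model per round (compression, shown earlier, reduces this):

def estimate_round_communication(num_parameters, num_workers, bytes_per_param=4):
    # Each worker uploads one update and downloads one global model,
    # both the size of the full parameter vector
    model_bytes = num_parameters * bytes_per_param
    return num_workers * model_bytes * 2  # upload + download

# Example: a ~100k-parameter MNIST model with 10 workers
# -> about 8 MB of traffic per round
print(estimate_round_communication(100_000, num_workers=10) / 1e6, "MB per round")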

Expected Performance Benchmarks:

Tutorial              Final Accuracy   Training Rounds    Communication Cost
MNIST (IID)           98.5%            10-15 rounds       ~50MB total
MNIST (Non-IID)       96.8%            15-25 rounds       ~75MB total
CIFAR-10 (IID)        85.2%            50-100 rounds      ~500MB total
CIFAR-10 (Non-IID)    82.1%            100-150 rounds     ~750MB total

Comparison with Centralized Training:

Understanding the federated learning trade-offs:

Centralized MNIST Training:
- Accuracy: 99.1% (0.6% higher than FL)
- Training Time: 5 minutes (vs 12 minutes FL)
- Privacy: Raw data is pooled on one server (vs staying on each node in FL)
- Scalability: Single machine (vs scaling out across many nodes in FL)

Real-World Considerations

Data Heterogeneity Simulation:

Our tutorials simulate real-world data distributions:

# Healthcare scenario: different hospitals see different patient populations
# (simulated here with illustrative filter helpers over the tutorial dataset)
hospital_1_data = filter_by_demographics(mnist_data, age_range="elderly")
hospital_2_data = filter_by_demographics(mnist_data, age_range="pediatric")

# IoT scenario: different sensors operate under different conditions
sensor_1_data = add_noise(cifar_data, noise_level="urban")
sensor_2_data = add_noise(cifar_data, noise_level="rural")

System Heterogeneity:

Handle varying computational resources:

# Adaptive batch sizes based on node capabilities
if node.memory_gb < 4:
    batch_size = 16  # Smaller batches for constrained devices
else:
    batch_size = 64  # Larger batches for powerful devices

# Asynchronous training for unreliable connections
import asyncio

async def train_with_timeout():
    try:
        result = await asyncio.wait_for(
            train_local_model(),       # assumed to be a coroutine here
            timeout=max_training_time  # per-round time budget in seconds
        )
        return result
    except asyncio.TimeoutError:
        return partial_result  # fall back to the progress made before the timeout

Security and Privacy:

Production-ready privacy preservation:

# Differential Privacy
import numpy as np

def add_differential_privacy(gradients, epsilon=1.0):
    # Laplace mechanism: scale = sensitivity / epsilon (sensitivity assumed 2.0);
    # a smaller epsilon (stricter budget) means more noise
    noise_scale = 2.0 / epsilon
    noise = np.random.laplace(0, noise_scale, gradients.shape)
    return gradients + noise

# Secure Aggregation (conceptual)
def secure_aggregation(encrypted_updates):
    # Homomorphic encryption allows summing updates while they stay encrypted
    encrypted_sum = sum(encrypted_updates)
    return decrypt(encrypted_sum)  # Only the aggregator can decrypt the sum
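
An alternative to the homomorphic-encryption view above is secure aggregation via pairwise masking: each pair of clients agrees on a random mask that one adds and the other subtracts, so the masks cancel in the server's sum and individual updates are never revealed. A simplified sketch that omits key agreement and dropout handling, which real protocols require:

import numpy as np

def apply_pairwise_masks(updates, seed=0):
    # updates: list of same-shaped NumPy arrays, one per client.
    # In a real protocol each mask comes from a secret shared only by
    # clients i and j; the server never sees it.
    rng = np.random.default_rng(seed)
    masked = [u.astype(float).copy() for u in updates]
    for i in range(len(updates)):
        for j in range(i + 1, len(updates)):
            mask = rng.normal(size=updates[0].shape)
            masked[i] += mask
            masked[j] -= mask
    return masked

def secure_sum(masked_updates):
    # The pairwise masks cancel in the sum (up to floating-point error),
    # so the server learns only the aggregate, not individual updates
    return np.sum(masked_updates, axis=0)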

Best Practices from Tutorials

Development Workflow:

  1. Start Simple: Begin with IID data and basic FedAvg

  2. Add Complexity Gradually: Introduce non-IID data, then optimization

  3. Monitor Carefully: Watch convergence and communication patterns

  4. Profile Performance: Measure training time and communication costs

  5. Test Robustness: Simulate node failures and network issues

Production Deployment:

  1. Validate on Realistic Data: Use real data distributions, not synthetic

  2. Plan for Heterogeneity: Design for varying node capabilities

  3. Implement Fault Tolerance: Handle node disconnections gracefully

  4. Monitor Privacy: Ensure privacy guarantees are maintained

  5. Scale Gradually: Start with small deployments and scale up

Debugging and Troubleshooting:

  1. Check Data Partitions: Ensure balanced and representative data splits (see the sketch after this list)

  2. Monitor Convergence: Watch for oscillating or stalled training

  3. Analyze Communication: Identify bottlenecks in model transmission

  4. Validate Aggregation: Ensure proper weight averaging

  5. Test Node Isolation: Verify nodes can’t access each other’s data
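
For the first debugging step, a quick way to inspect partition balance is to print the per-class sample counts for each client (a sketch assuming partitions are lists of sample indices and labels is a NumPy array, as in the Dirichlet example above):

import numpy as np

def summarize_partitions(partitions, labels):
    # Print the sample count and per-class label counts for every client
    num_classes = int(labels.max()) + 1
    for client_id, idx in enumerate(partitions):
        counts = np.bincount(labels[idx], minlength=num_classes)
        print(f"client {client_id}: {len(idx)} samples, per class {counts.tolist()}")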

Next Steps

Ready to Start?

  1. Beginners: MNIST Tutorial - A perfect introduction to federated learning

  2. Computer Vision Focus: CIFAR-10 tutorial - More realistic image classification

  3. Custom Applications: Custom dataset tutorial - Adapt patterns for your domain

After Completing Tutorials:

  • Advanced Patterns: Multi-cluster federated learning

  • Framework Integration: Use with existing FL frameworks

  • Production Deployment: Deploy in real environments

Each tutorial includes complete working code, detailed explanations, and troubleshooting guides to ensure your success with federated learning on Manta.