MNIST Tutorial

Complete walkthrough of federated learning with the MNIST handwritten digit dataset. This tutorial demonstrates the core concepts of federated learning using a classic computer vision problem that’s perfect for learning the fundamentals.

What you’ll accomplish:

  • Set up federated learning with multiple simulated participants

  • Train a neural network across distributed nodes without centralizing data

  • Monitor real-time training progress and convergence

  • Analyze results and compare with centralized training

  • Understand privacy preservation and communication efficiency

Prerequisites:

  • A running Manta platform with an active cluster you can connect to (see Running the MNIST Federated Learning below)

  • The example dependencies installed: PyTorch, torchvision, manta-core, and numpy

Tutorial Overview

This tutorial uses the MNIST dataset of handwritten digits (0-9) to demonstrate federated learning. We’ll simulate a scenario where multiple participants (hospitals, schools, or organizations) want to collaboratively train a digit recognition model without sharing their raw data.

Key Learning Objectives:

  1. Data Partitioning: How to distribute data across federated participants

  2. Federated Averaging: How model updates are aggregated across participants

  3. Privacy Preservation: How raw data never leaves its origin

  4. Communication Patterns: Efficient communication between participants and server

  5. Convergence Analysis: How federated learning converges compared to centralized training

Architecture Overview

Our MNIST federated learning system uses 4 types of tasks:

┌─────────────┐    ┌──────────────┐    ┌─────────────┐    ┌───────────────┐
│   Worker    │ -> │  Aggregator  │ -> │    Test     │ -> │   Scheduler   │
│  Training   │    │  (FedAvg)    │    │ Evaluation  │    │ Coordination  │
└─────────────┘    └──────────────┘    └─────────────┘    └───────────────┘
       ^                                                            |
       |                                                            v
       +------------------------------------------------------------+
                         (Next Round if not converged)

Task Responsibilities:

  • Worker Training: Performs local SGD on each participant’s data

  • Aggregator: Combines model updates using federated averaging

  • Test Evaluation: Tests the global model on held-out test data

  • Scheduler: Controls training rounds and checks convergence
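
Conceptually, one training round cycles through these four tasks. The toy loop below is a framework-agnostic sketch (plain Python and numpy with synthetic "models", not the Manta API) that mirrors the control flow in the diagram above:

import numpy as np

# Two participants with different local optima (a simple non-IID flavour)
PARTICIPANT_OPTIMA = [
    {"w1": np.full((784, 512), 0.00), "w2": np.full((512, 10), 0.00)},
    {"w1": np.full((784, 512), 0.10), "w2": np.full((512, 10), 0.10)},
]
GLOBAL_OPTIMUM = {l: np.mean([opt[l] for opt in PARTICIPANT_OPTIMA], axis=0)
                  for l in PARTICIPANT_OPTIMA[0]}

def local_training(global_model, optimum):
    # Worker Training: each participant moves its copy halfway toward its own optimum
    return {l: w + 0.5 * (optimum[l] - w) for l, w in global_model.items()}

def fedavg(updates):
    # Aggregator: average each layer across participants (FedAvg)
    return {l: np.mean([u[l] for u in updates], axis=0) for l in updates[0]}

def evaluate(model):
    # Test Evaluation: toy proxy metric -- closeness to the shared global optimum
    err = np.mean([np.abs(model[l] - GLOBAL_OPTIMUM[l]).mean() for l in model])
    return 1.0 - err / 0.05

# Scheduler: keep scheduling rounds until the convergence threshold is reached
global_model = {l: np.zeros_like(w) for l, w in GLOBAL_OPTIMUM.items()}
for round_id in range(1, 20):
    updates = [local_training(global_model, opt) for opt in PARTICIPANT_OPTIMA]
    global_model = fedavg(updates)
    score = evaluate(global_model)
    print(f"Round {round_id}: convergence proxy = {score:.3f}")
    if score >= 0.99:  # analogous to the tutorial's accuracy threshold
        break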

Setting Up the Environment

1. Get the MNIST Example:

# Navigate to the MNIST example directory
cd examples/fl_pytorch_mnist

2. Install Dependencies:

pip install -r requirements.txt

The requirements include:

  • PyTorch for neural network training

  • torchvision for MNIST dataset handling

  • manta-core for federated learning orchestration

  • numpy for numerical computations

3. Prepare the MNIST Dataset:

python prepare_data.py

This script:

  • Downloads the MNIST dataset (if not already present)

  • Creates data partitions simulating different participants

  • Stores partitioned data in ../../temp/partitioned/mnist/

4. Verify Data Preparation:

ls ../../temp/partitioned/mnist/
# Should show: node_0/ node_1/ node_2/ ... (one per participant)

Understanding Data Partitioning

The prepare_data.py script creates realistic data distributions:

Default Partitioning (Non-IID):

# Each participant gets 2 digit classes
participant_1: digits [0, 1] (3000 samples)
participant_2: digits [2, 3] (3000 samples)
participant_3: digits [4, 5] (3000 samples)
participant_4: digits [6, 7] (3000 samples)
participant_5: digits [8, 9] (3000 samples)

This simulates realistic scenarios where:

  • Different hospitals specialize in different conditions

  • Different schools have students from different demographics

  • Different organizations collect different types of data
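
A minimal sketch of how such class-based (non-IID) partitions could be produced, assuming torchvision's MNIST download is used (the real prepare_data.py may differ in file names and output format):

import numpy as np
from torchvision import datasets, transforms

# Download MNIST and group training samples by digit class
mnist = datasets.MNIST("./data", train=True, download=True, transform=transforms.ToTensor())
images = mnist.data.numpy().astype(np.float32) / 255.0   # shape (60000, 28, 28)
labels = mnist.targets.numpy()

class_pairs = [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]
samples_per_participant = 3000

for node_id, classes in enumerate(class_pairs):
    idx = np.where(np.isin(labels, classes))[0][:samples_per_participant]
    # Store one .npz per participant, matching the x_train / y_train keys used by the worker task
    np.savez(f"node_{node_id}.npz", x_train=images[idx], y_train=labels[idx])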

Alternative Partitioning Options:

You can modify prepare_data.py for different scenarios:

# IID Partitioning (uniform distribution)
from torch.utils.data import random_split

def create_iid_partitions(dataset, num_participants):
    # Each participant gets random samples from all classes
    samples_per_participant = len(dataset) // num_participants
    lengths = [samples_per_participant] * num_participants
    lengths[-1] += len(dataset) - sum(lengths)  # give any remainder to the last participant
    return random_split(dataset, lengths)

# Dirichlet Partitioning (configurable heterogeneity)
def create_dirichlet_partitions(dataset, num_participants, alpha=0.5):
    # Lower alpha = more heterogeneous, higher alpha = more homogeneous
    # dirichlet_partition is not part of the example; a sketch is given below
    return dirichlet_partition(dataset, num_participants, alpha)
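
dirichlet_partition is not defined in the example; a minimal sketch of what it could look like, splitting each class's indices across participants with a Dirichlet draw (the helper name and exact behavior are assumptions):

import numpy as np
from torch.utils.data import Subset

def dirichlet_partition(dataset, num_participants, alpha=0.5, seed=0):
    """Split dataset indices across participants using a per-class Dirichlet draw."""
    rng = np.random.default_rng(seed)
    targets = np.array([int(dataset[i][1]) for i in range(len(dataset))])
    participant_indices = [[] for _ in range(num_participants)]

    for cls in np.unique(targets):
        cls_idx = np.where(targets == cls)[0]
        rng.shuffle(cls_idx)
        # Proportion of this class assigned to each participant
        proportions = rng.dirichlet([alpha] * num_participants)
        split_points = (np.cumsum(proportions)[:-1] * len(cls_idx)).astype(int)
        for pid, chunk in enumerate(np.split(cls_idx, split_points)):
            participant_indices[pid].extend(chunk.tolist())

    return [Subset(dataset, idx) for idx in participant_indices]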

Understanding the Neural Network

The MNIST model is a simple Multi-Layer Perceptron (MLP):

# modules/worker/model.py
from torch import nn

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),  # 10 classes (digits 0-9)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

Model Characteristics:

  • Input: 28x28 grayscale images (784 pixels)

  • Hidden Layers: Two layers with 512 neurons each

  • Output: 10 classes (digits 0-9)

  • Parameters: ~670,000 total parameters

  • Activation: ReLU for hidden layers, no activation for output
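
You can sanity-check the parameter count by summing the layer sizes directly (a quick check, assuming the MLP class above is importable from the example's modules/worker/model.py):

# 784*512 + 512  +  512*512 + 512  +  512*10 + 10  =  669,706 parameters
from modules.worker.model import MLP

model = MLP()
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")  # ~670K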

Examining the Federated Learning Components

1. Worker Training Task (modules/worker/worker_task.py):

import numpy as np
import torch
from torch import nn, optim

from manta.light.task import Task
from manta.light.utils import bytes_to_numpy, numpy_to_bytes
from .model import MLP, device

class Worker(Task):
    def __init__(self):
        super().__init__()
        # Load MNIST dataset from local context
        raw_data = self.local.get_binary_data("mnist")
        self.data = np.load(raw_data)
        self.model = MLP()

    def run(self):
        # Get hyperparameters from world globals
        hyperparameters = self.world.globals["hyperparameters"]

        # Get current global model weights
        weights = self.world.globals["global_model_params"]
        self.model.set_weights(bytes_to_numpy(weights))

        # Train the model
        metrics = self.train_model(hyperparameters)

        # Save metrics and model weights to world
        self.world.results.add("train_metrics", metrics)
        self.world.results.add("model_params", numpy_to_bytes(self.model.get_weights()))

    def train_model(self, hyperparameters):
        X_train, y_train = self.data["x_train"], self.data["y_train"]

        # Define loss and optimizer
        criterion = getattr(nn, hyperparameters["loss"])()
        optimizer = getattr(optim, hyperparameters["optimizer"])(
            self.model.parameters(), **hyperparameters["optimizer_params"]
        )

        metrics = {"loss": []}

        # Training loop
        for epoch in range(hyperparameters["epochs"]):
            self.model.train()
            for i in range(0, X_train.shape[0], hyperparameters["batch_size"]):
                optimizer.zero_grad()

                inputs = torch.tensor(X_train[i:i+hyperparameters["batch_size"]], dtype=torch.float32, device=device)
                targets = torch.tensor(y_train[i:i+hyperparameters["batch_size"]], dtype=torch.long, device=device)

                output = self.model(inputs)
                loss = criterion(output, targets)
                loss.backward()
                optimizer.step()

            metrics["loss"].append(loss.item())
            self.logger.info(f"Epoch {epoch + 1}/{hyperparameters['epochs']} Loss: {loss.item()}")

        return metrics

2. Aggregator Task (modules/aggregator.py):

from manta.light.task import Task
from manta.light.utils import bytes_to_numpy, numpy_to_bytes
import numpy as np

class Aggregator(Task):
    def run(self):
        # Collect model parameters from all workers
        models = bytes_to_numpy(self.world.results.select("model_params"))

        if not models:
            self.logger.warning("No model updates received")
            return

        # Perform federated averaging
        aggregated_model = self.aggregate_models(list(models.values()))

        # Update global model parameters
        self.world.globals["global_model_params"] = numpy_to_bytes(aggregated_model)

        self.logger.info(f"Aggregated {len(models)} models")

    def aggregate_models(self, models):
        """Perform federated averaging across models."""
        aggregated_model = {}

        # Average each layer across all models
        for layer in models[0]:
            aggregated_model[layer] = np.mean(
                [model[layer] for model in models], axis=0
            )

        return aggregated_model
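
Note that aggregate_models above takes an unweighted mean, which matches FedAvg only when every participant contributes the same number of samples (3,000 each under the default partitioning). A minimal sketch of a sample-weighted variant, assuming each worker also reported a num_samples value (not part of the aggregator above):

import numpy as np

def weighted_fedavg(models, sample_counts):
    """Average layer-by-layer, weighting each model by its number of training samples."""
    weights = np.array(sample_counts, dtype=np.float64)
    weights /= weights.sum()

    aggregated = {}
    for layer in models[0]:
        # Weighted average across the participant axis
        aggregated[layer] = np.average(
            [model[layer] for model in models], axis=0, weights=weights
        )
    return aggregated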

3. Test Evaluation Task (modules/worker_test/worker_task.py):

# Excerpt from the test task's execute method (imports and class definition omitted)
def execute(self):
    # Get the latest global model
    global_model_bytes = self.world.globals["global_model_params"]
    model = MLP()
    model.set_weights(bytes_to_numpy(global_model_bytes))

    # Load test dataset
    test_loader = self.local.get_dataloader("mnist", train=False, batch_size=1000)

    # Evaluate model performance
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            test_loss += nn.functional.cross_entropy(output, target,
                                                   reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)

    # Report evaluation results
    self.world.results.add("test_metrics", {
        "round": self.get_round_number(),
        "test_loss": test_loss,
        "test_accuracy": accuracy,
        "samples_tested": len(test_loader.dataset)
    })

4. Scheduler Task (modules/scheduler.py):

from manta.light.task import Task

class Scheduler(Task):
    def run(self):
        # Get hyperparameters
        self.hp = self.world.globals["hyperparameters"]

        # Get test metrics from previous iteration
        metrics = self.world.results.select("metrics")
        selected_nodes = self.select_nodes(metrics)

        if len(selected_nodes) == 0:
            # All nodes converged, stop the swarm
            self.logger.info("All nodes converged. Stopping swarm.")
            self.world.stop_swarm()
            return

        # Schedule next iteration for selected nodes
        self.logger.info(f"Nodes selected for next round: {selected_nodes}")
        self.world.schedule_next_iter(
            node_ids=selected_nodes,
            task_to_schedule_alias="worker"
        )

    def select_nodes(self, metrics):
        """Select nodes that haven't converged yet."""
        selected_nodes = []
        val_acc_threshold = self.hp.get("val_acc_threshold", 0.99)

        for node_id, metr in metrics.items():
            self.logger.info(f"Node {node_id} validation accuracy: {metr['val_acc']}")
            if metr["val_acc"] < val_acc_threshold:
                selected_nodes.append(node_id)

        return selected_nodes

Running the MNIST Federated Learning

1. Ensure Your Cluster is Ready:

# Check that Manta platform is running
manta status

# Verify your cluster is RUNNING in the dashboard
# Get your cluster_id and user token from the dashboard

2. Connect Nodes to Your Cluster:

# Check available clusters
manta sdk cluster list

3. Deploy the MNIST Federated Learning Swarm:

# Deploy using the real deployment script
python start_swarm.py --token <your_token> --cluster_id <your_cluster_id>

Expected Output:

UserAPI availability: Service is available

Defining swarm with image 'manta_light:pytorch' and GPU=False
Swarm details: <fl_pytorch_mnist.swarm.FLSwarm object at 0x...>

Deploying swarm...
Swarm Deployment Overview:
  swarm_id: swarm_a1b2c3d4e5f6
  status: DEPLOYED
  tasks_scheduled: 4
  nodes_allocated: 2

Swarm 'swarm_a1b2c3d4e5f6' deployed successfully.

4. Monitor Real-Time Training:

The script will stream training updates:

Task Update: worker_train_001 - RUNNING - node_001
Task Update: worker_train_002 - RUNNING - node_002

Round 1 Results:
- Worker 1: Local loss 2.156, trained on 3000 samples
- Worker 2: Local loss 2.091, trained on 3000 samples
- Aggregator: Combined models from 2 workers (6000 total samples)
- Global Test Accuracy: 92.3%
- Scheduler: Continue training (threshold 99.0% not reached)

Task Update: worker_train_001 - RUNNING - node_001
Task Update: worker_train_002 - RUNNING - node_002

Round 2 Results:
- Worker 1: Local loss 1.234, improved convergence
- Worker 2: Local loss 1.189, improved convergence
- Aggregator: Updated global model
- Global Test Accuracy: 96.7%
- Scheduler: Continue training

...

Round 8 Results:
- Global Test Accuracy: 99.1%
- Scheduler: CONVERGED (threshold 99.0% reached)
- Training completed successfully!

Understanding the Results

Training Progression:

Typical MNIST federated learning shows this progression:

| Round | Global Accuracy | Avg Local Loss | Communication | Status |
|-------|-----------------|----------------|---------------|--------|
| 1 | 92.3% | 2.123 | 2.4MB | Starting to learn |
| 3 | 96.7% | 1.234 | 7.2MB | Rapid improvement |
| 5 | 98.1% | 0.876 | 12.0MB | Approaching convergence |
| 8 | 99.1% | 0.543 | 19.2MB | Converged! |

Key Observations:

  1. Rapid Initial Learning: Accuracy jumps from ~10% (random) to 90%+ in first few rounds

  2. Gradual Convergence: Final 1-2% accuracy improvement takes several rounds

  3. Communication Efficiency: Full training exchanges only tens of megabytes of model parameters (see the analysis below)

  4. Privacy Preservation: Raw MNIST images never transmitted, only model parameters

Communication Analysis:

Model Size: ~670K parameters × 4 bytes/param ≈ 2.68MB per model copy

Round Communication:
- 2 workers upload model updates: 2 × 2.68MB = 5.36MB upload
- 2 workers download global model: 2 × 2.68MB = 5.36MB download
- Total per round: 10.72MB
- 8 rounds total: ~86MB communication

Compare to centralized training:
- Would need to transfer all training data: ~47MB (raw MNIST)
- Plus ongoing model checkpoints and logs
- Federated learning saves bandwidth and preserves privacy
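
These estimates are easy to reproduce; a quick back-of-the-envelope check in Python, using the parameter count computed earlier:

params = 669_706             # MLP parameter count
model_mb = params * 4 / 1e6  # float32, ~2.68 MB per model copy

workers = 2
per_round_mb = workers * model_mb * 2   # each round: upload + download per worker, ~10.7 MB
total_mb = 8 * per_round_mb             # 8 rounds, ~86 MB

print(f"{model_mb:.2f} MB per model, {per_round_mb:.2f} MB per round, {total_mb:.0f} MB total")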

Analyzing Privacy Preservation

What Stays Private:

# This data NEVER leaves each participant's node:
participant_1_data = load_mnist_partition("node_0")  # Only digits 0,1
participant_2_data = load_mnist_partition("node_1")  # Only digits 2,3

# Raw images, labels, intermediate activations all stay local
# Only model parameters (weights/biases) are shared

What Gets Shared:

# Only these model parameters are communicated:
shared_update = {
    "model_params": compressed_weights_and_biases,  # ~2.68MB
    "num_samples": 3000,  # Just the count, not the data
    "local_loss": 1.234   # Just the final loss value
}

# No raw data, no gradients, no intermediate results

Privacy Analysis:

  • Input Privacy: Raw MNIST images never transmitted

  • Label Privacy: True labels remain on originating nodes

  • Gradient Privacy: Only final model parameters shared, not gradients

  • Inference Privacy: Difficult to reverse-engineer training data from model weights

Comparing with Centralized Training

Run this comparison to understand federated learning trade-offs:

# Compare federated vs centralized training
import torch
from torch import nn
from torch.utils.data import DataLoader, ConcatDataset

from modules.worker.model import MLP  # the model defined earlier

# Simulate centralized training
def centralized_training(num_participants=5):
    # Combine all data into one dataset (exactly what federated learning avoids)
    all_data = []
    for i in range(num_participants):
        # load_partition is a placeholder for reading one node's partition from disk
        participant_data = load_partition(f"node_{i}")
        all_data.append(participant_data)

    combined_dataset = ConcatDataset(all_data)
    dataloader = DataLoader(combined_dataset, batch_size=64, shuffle=True)

    # Train on combined data
    model = MLP()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(10):  # Same total training budget as the federated run
        for data, target in dataloader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

    # evaluate_model is sketched below
    return evaluate_model(model)
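
evaluate_model is not defined in the snippet above; a minimal sketch using torchvision's MNIST test split (the data directory and normalization are assumptions):

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def evaluate_model(model, data_dir="./data", batch_size=1000):
    """Return test accuracy (%) of the model on the MNIST test split."""
    test_set = datasets.MNIST(data_dir, train=False, download=True,
                              transform=transforms.ToTensor())
    test_loader = DataLoader(test_set, batch_size=batch_size)

    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()

    return 100.0 * correct / len(test_set)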

Expected Comparison Results:

| Approach | Final Accuracy | Training Time | Privacy |
|----------|----------------|---------------|---------|
| Centralized | 99.3% | 5 minutes | None |
| Federated | 99.1% | 8 minutes | Full |
| Difference | -0.2% | +60% time | +100% privacy |

Troubleshooting Common Issues

Training Doesn’t Start:

# Check data preparation
ls ../../temp/partitioned/mnist/
# Should show node_0, node_1, etc. directories

# Check nodes are connected
# Verify in cluster dashboard that nodes show as "Connected"

# Check for Python errors
docker logs manager-api | grep ERROR

Poor Convergence:

# Check data distribution
import torch

for i in range(2):  # For each node
    data = torch.load(f"../../temp/partitioned/mnist/node_{i}/train.pt")
    print(f"Node {i}: {len(data)} samples")
    print(f"Classes: {torch.unique(data.targets)}")

# Should show reasonable distribution
# If one node has all data, check prepare_data.py

Slow Training:

# Check resource usage
docker stats

# Look for CPU/memory bottlenecks
# Consider reducing batch size or using GPU

Memory Issues:

# In swarm.py, reduce batch size
self.set_global("hyperparameters", {
    "batch_size": 16,  # Instead of 32
    # ... other params
})

Communication Failures:

# Check MQTT broker
docker logs emqx

# Verify network connectivity
docker exec manager-api ping emqx

Extending the Tutorial

1. Try Different Data Distributions:

# Modify prepare_data.py for extreme non-IID
def create_extreme_non_iid():
    # Give each participant only one digit class
    # (filter_by_class / save_partition stand in for prepare_data.py's own helpers)
    for i in range(10):
        participant_data = filter_by_class(mnist_data, class_id=i)
        save_partition(participant_data, f"node_{i}")

2. Experiment with Hyperparameters:

# In swarm.py, try different settings
self.set_global("hyperparameters", {
    "epochs": 2,           # More local training per round
    "batch_size": 64,      # Larger batches
    "optimizer_params": {
        "lr": 0.001,       # Lower learning rate
        "momentum": 0.95   # Higher momentum
    }
})

3. Add Custom Metrics:

# In test evaluation, add more metrics
from sklearn.metrics import precision_recall_fscore_support

def execute(self):
    # ... existing evaluation code ...
    # (collect y_true and y_pred lists batch by batch in the evaluation loop above)

    # Add per-class metrics
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average=None
    )

    self.world.results.add("detailed_metrics", {
        "precision_per_class": precision.tolist(),
        "recall_per_class": recall.tolist(),
        "f1_per_class": f1.tolist()
    })

4. Implement Differential Privacy:

# Add Laplace noise to model updates for enhanced privacy
import numpy as np

def add_differential_privacy(model_params, epsilon=1.0, sensitivity=2.0):
    # model_params is a dict of layer name -> numpy array (as used by the aggregator)
    noise_scale = sensitivity / epsilon
    return {
        layer: weights + np.random.laplace(0, noise_scale, weights.shape)
        for layer, weights in model_params.items()
    }

# In the worker task, after local training:
noisy_params = add_differential_privacy(self.model.get_weights())
self.world.results.add("model_params", numpy_to_bytes(noisy_params))

Next Steps

Congratulations! You’ve successfully completed MNIST federated learning. Next steps:

Immediate Next Steps:

  1. Try CIFAR-10: CIFAR-10 tutorial - More complex image classification

  2. Custom Data: Custom dataset tutorial - Use your own datasets

  3. Advanced Monitoring: Real-time monitoring - Deep dive into metrics

Advanced Exploration:

  4. Framework Integration: Flower integration - Use with Flower

  5. Multi-Cluster: Multi-cluster deployment - Scale across multiple clusters

  6. Production Deployment: Production deployment - Real-world deployment

Research Directions:

  7. Algorithm Variations: Implement FedProx, FedNova, or other FL algorithms

  8. Privacy Enhancements: Add secure aggregation and differential privacy

  9. Robustness: Test with node failures and byzantine participants

You now understand the fundamentals of federated learning and have hands-on experience with Manta’s federated learning platform!