MNIST Tutorial¶
Complete walkthrough of federated learning with the MNIST handwritten digit dataset. This tutorial demonstrates the core concepts of federated learning using a classic computer vision problem that’s perfect for learning the fundamentals.
What you’ll accomplish:
Set up federated learning with multiple simulated participants
Train a neural network across distributed nodes without centralizing data
Monitor real-time training progress and convergence
Analyze results and compare with centralized training
Understand privacy preservation and communication efficiency
Prerequisites:
Manta platform installed and running (Getting Started)
Created and started a cluster (Step 2: Cluster Setup)
Basic understanding of neural networks and PyTorch
Tutorial Overview¶
This tutorial uses the MNIST dataset of handwritten digits (0-9) to demonstrate federated learning. We’ll simulate a scenario where multiple participants (hospitals, schools, or organizations) want to collaboratively train a digit recognition model without sharing their raw data.
Key Learning Objectives:
Data Partitioning: How to distribute data across federated participants
Federated Averaging: How model updates are aggregated across participants
Privacy Preservation: How raw data never leaves its origin
Communication Patterns: Efficient communication between participants and server
Convergence Analysis: How federated learning converges compared to centralized training
Architecture Overview¶
Our MNIST federated learning system uses 4 types of tasks:
┌─────────────┐    ┌──────────────┐    ┌─────────────┐    ┌───────────────┐
│   Worker    │ -> │  Aggregator  │ -> │    Test     │ -> │   Scheduler   │
│  Training   │    │   (FedAvg)   │    │ Evaluation  │    │ Coordination  │
└─────────────┘    └──────────────┘    └─────────────┘    └───────────────┘
       ^                                                          |
       |                 (next round if not converged)            |
       +----------------------------------------------------------+
Task Responsibilities:
Worker Training: Performs local SGD on each participant’s data
Aggregator: Combines model updates using federated averaging
Test Evaluation: Tests the global model on held-out test data
Scheduler: Controls training rounds and checks convergence
Setting Up the Environment¶
1. Get the MNIST Example:
# Navigate to the MNIST example directory
cd examples/fl_pytorch_mnist
2. Install Dependencies:
pip install -r requirements.txt
The requirements include:
- PyTorch for neural network training
- torchvision for MNIST dataset handling
- manta-core for federated learning orchestration
- numpy for numerical computations
3. Prepare the MNIST Dataset:
python prepare_data.py
This script:
- Downloads the MNIST dataset (if not already present)
- Creates data partitions simulating different participants
- Stores partitioned data in ../../temp/partitioned/mnist/
4. Verify Data Preparation:
ls ../../temp/partitioned/mnist/
# Should show: node_0/ node_1/ node_2/ ... (one per participant)
Understanding Data Partitioning¶
The prepare_data.py script creates realistic data distributions:
Default Partitioning (Non-IID):
# Each participant gets 2 digit classes
participant_1: digits [0, 1] (3000 samples)
participant_2: digits [2, 3] (3000 samples)
participant_3: digits [4, 5] (3000 samples)
participant_4: digits [6, 7] (3000 samples)
participant_5: digits [8, 9] (3000 samples)
This simulates realistic scenarios where:
- Different hospitals specialize in different conditions
- Different schools have students from different demographics
- Different organizations collect different types of data
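For reference, here is a minimal sketch of how such a label-sorted split could be produced with torchvision; the function name and details below are illustrative, not the actual contents of prepare_data.py:

import numpy as np
from torchvision import datasets, transforms

def create_label_sorted_partitions(num_participants=5, classes_per_participant=2):
    """Give participant p the digit classes [2p, 2p+1], as in the default split."""
    mnist = datasets.MNIST("./data", train=True, download=True,
                           transform=transforms.ToTensor())
    labels = mnist.targets.numpy()
    partitions = {}
    for p in range(num_participants):
        classes = list(range(p * classes_per_participant,
                             (p + 1) * classes_per_participant))
        indices = np.where(np.isin(labels, classes))[0]
        partitions[f"node_{p}"] = indices  # sample indices for this participant
    return partitions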
Alternative Partitioning Options:
You can modify prepare_data.py for different scenarios:
# IID Partitioning (uniform distribution)
from torch.utils.data import random_split

def create_iid_partitions(dataset, num_participants):
    # Each participant gets random samples from all classes
    samples_per_participant = len(dataset) // num_participants
    lengths = [samples_per_participant] * num_participants
    lengths[-1] += len(dataset) - sum(lengths)  # absorb any remainder
    return random_split(dataset, lengths)

# Dirichlet Partitioning (configurable heterogeneity)
def create_dirichlet_partitions(dataset, num_participants, alpha=0.5):
    # Lower alpha = more heterogeneous, higher alpha = more homogeneous
    return dirichlet_partition(dataset, num_participants, alpha)
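The dirichlet_partition helper referenced above is not shown in the example; one possible implementation (an assumption, not the actual prepare_data.py code) looks like this:

import numpy as np

def dirichlet_partition(dataset, num_participants, alpha):
    # Split sample indices so each class is spread across participants
    # according to a Dirichlet(alpha) distribution
    labels = np.array([label for _, label in dataset])
    partitions = [[] for _ in range(num_participants)]
    for c in np.unique(labels):
        class_idx = np.where(labels == c)[0]
        np.random.shuffle(class_idx)
        # Fraction of this class that each participant receives
        proportions = np.random.dirichlet([alpha] * num_participants)
        split_points = (np.cumsum(proportions)[:-1] * len(class_idx)).astype(int)
        for p, chunk in enumerate(np.split(class_idx, split_points)):
            partitions[p].extend(chunk.tolist())
    return partitions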
Understanding the Neural Network¶
The MNIST model is a simple Multi-Layer Perceptron (MLP):
# modules/worker/model.py
import torch.nn as nn

class MLP(nn.Module):
def __init__(self):
super(MLP, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10), # 10 classes (digits 0-9)
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
Model Characteristics:
- Input: 28x28 grayscale images (784 pixels)
- Hidden Layers: Two layers with 512 neurons each
- Output: 10 classes (digits 0-9)
- Parameters: ~670,000 total
- Activation: ReLU for hidden layers, no activation for output
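A quick way to verify the parameter count, assuming the MLP class above is importable (the import path below is illustrative; adjust it to your layout):

from modules.worker.model import MLP  # hypothetical import path

model = MLP()
total = sum(p.numel() for p in model.parameters())
print(f"{total:,} trainable parameters")  # 669,706 -> ~670K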
Examining the Federated Learning Components¶
1. Worker Training Task (modules/worker/worker_task.py):
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from manta.light.task import Task
from manta.light.utils import bytes_to_numpy, numpy_to_bytes
from .model import MLP, device
class Worker(Task):
def __init__(self):
super().__init__()
# Load MNIST dataset from local context
raw_data = self.local.get_binary_data("mnist")
self.data = np.load(raw_data)
self.model = MLP()
def run(self):
# Get hyperparameters from world globals
hyperparameters = self.world.globals["hyperparameters"]
# Get current global model weights
weights = self.world.globals["global_model_params"]
self.model.set_weights(bytes_to_numpy(weights))
# Train the model
metrics = self.train_model(hyperparameters)
# Save metrics and model weights to world
self.world.results.add("train_metrics", metrics)
self.world.results.add("model_params", numpy_to_bytes(self.model.get_weights()))
def train_model(self, hyperparameters):
X_train, y_train = self.data["x_train"], self.data["y_train"]
# Define loss and optimizer
criterion = getattr(nn, hyperparameters["loss"])()
optimizer = getattr(optim, hyperparameters["optimizer"])(
self.model.parameters(), **hyperparameters["optimizer_params"]
)
metrics = {"loss": []}
# Training loop
for epoch in range(hyperparameters["epochs"]):
self.model.train()
for i in range(0, X_train.shape[0], hyperparameters["batch_size"]):
optimizer.zero_grad()
inputs = torch.tensor(X_train[i:i+hyperparameters["batch_size"]], dtype=torch.float32, device=device)
targets = torch.tensor(y_train[i:i+hyperparameters["batch_size"]], dtype=torch.long, device=device)
output = self.model(inputs)
loss = criterion(output, targets)
loss.backward()
optimizer.step()
metrics["loss"].append(loss.item())
self.logger.info(f"Epoch {epoch + 1}/{hyperparameters['epochs']} Loss: {loss.item()}")
return metrics
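The worker calls model.get_weights() and model.set_weights(), which are not standard nn.Module methods; they are presumably defined in modules/worker/model.py (not shown here). A minimal sketch of how such helpers could look, assuming a layer-name-to-numpy-array mapping as consumed by the aggregator below:

import torch

class MLPWithWeightHelpers(MLP):
    # Hypothetical helpers mirroring the get_weights/set_weights calls above

    def get_weights(self):
        # Export the state_dict as plain numpy arrays keyed by layer name
        return {name: tensor.detach().cpu().numpy()
                for name, tensor in self.state_dict().items()}

    def set_weights(self, weights):
        # Convert numpy arrays back to tensors and load them into the module
        self.load_state_dict({name: torch.as_tensor(array)
                              for name, array in weights.items()})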
2. Aggregator Task (modules/aggregator.py):
from manta.light.task import Task
from manta.light.utils import bytes_to_numpy, numpy_to_bytes
import numpy as np
class Aggregator(Task):
def run(self):
# Collect model parameters from all workers
models = bytes_to_numpy(self.world.results.select("model_params"))
if not models:
self.logger.warning("No model updates received")
return
# Perform federated averaging
aggregated_model = self.aggregate_models(list(models.values()))
# Update global model parameters
self.world.globals["global_model_params"] = numpy_to_bytes(aggregated_model)
self.logger.info(f"Aggregated {len(models)} models")
def aggregate_models(self, models):
"""Perform federated averaging across models."""
aggregated_model = {}
# Average each layer across all models
for layer in models[0]:
aggregated_model[layer] = np.mean(
[model[layer] for model in models], axis=0
)
return aggregated_model
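To see what aggregate_models does numerically, here is a tiny standalone example with two toy single-layer "models" (values made up for illustration):

import numpy as np

model_a = {"fc1.weight": np.array([[1.0, 2.0], [3.0, 4.0]])}
model_b = {"fc1.weight": np.array([[3.0, 4.0], [5.0, 6.0]])}

# Element-wise average across models, exactly as in aggregate_models()
averaged = {layer: np.mean([model_a[layer], model_b[layer]], axis=0)
            for layer in model_a}
print(averaged["fc1.weight"])
# [[2. 3.]
#  [4. 5.]]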
3. Test Evaluation Task (modules/worker_test/worker_task.py):
def execute(self):
# Get the latest global model
global_model_bytes = self.world.get_global("global_model_params")
model = MLP()
model.set_weights(bytes_to_numpy(global_model_bytes))
# Load test dataset
test_loader = self.local.get_dataloader("mnist", train=False, batch_size=1000)
# Evaluate model performance
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
output = model(data)
test_loss += nn.functional.cross_entropy(output, target,
reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
accuracy = 100. * correct / len(test_loader.dataset)
# Report evaluation results
self.world.set_result("test_metrics", {
"round": self.get_round_number(),
"test_loss": test_loss,
"test_accuracy": accuracy,
"samples_tested": len(test_loader.dataset)
})
4. Scheduler Task (modules/scheduler.py):
from manta.light.task import Task
class Scheduler(Task):
def run(self):
# Get hyperparameters
self.hp = self.world.globals["hyperparameters"]
# Get test metrics from previous iteration
metrics = self.world.results.select("metrics")
selected_nodes = self.select_nodes(metrics)
if len(selected_nodes) == 0:
# All nodes converged, stop the swarm
self.logger.info("All nodes converged. Stopping swarm.")
self.world.stop_swarm()
return
# Schedule next iteration for selected nodes
self.logger.info(f"Nodes selected for next round: {selected_nodes}")
self.world.schedule_next_iter(
node_ids=selected_nodes,
task_to_schedule_alias="worker"
)
def select_nodes(self, metrics):
"""Select nodes that haven't converged yet."""
selected_nodes = []
val_acc_threshold = self.hp.get("val_acc_threshold", 0.99)
for node_id, metr in metrics.items():
self.logger.info(f"Node {node_id} validation accuracy: {metr['val_acc']}")
if metr["val_acc"] < val_acc_threshold:
selected_nodes.append(node_id)
return selected_nodes
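To illustrate the selection logic in isolation, here is a toy run of the same convergence check with made-up node IDs and metrics:

# Hypothetical per-node metrics, shaped like the input to select_nodes()
metrics = {
    "node_001": {"val_acc": 0.992},  # above threshold -> considered converged
    "node_002": {"val_acc": 0.974},  # below threshold -> trains another round
}
val_acc_threshold = 0.99

selected = [node_id for node_id, m in metrics.items()
            if m["val_acc"] < val_acc_threshold]
print(selected)  # ['node_002'] -> only node_002 is scheduled for the next round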
Running the MNIST Federated Learning¶
1. Ensure Your Cluster is Ready:
# Check that Manta platform is running
manta status
# Verify your cluster is RUNNING in the dashboard
# Get your cluster_id and user token from the dashboard
2. Connect Nodes to Your Cluster:
# Check available clusters
manta sdk cluster list
3. Deploy the MNIST Federated Learning Swarm:
# Deploy using the real deployment script
python start_swarm.py --token <your_token> --cluster_id <your_cluster_id>
Expected Output:
UserAPI availability: Service is available
Defining swarm with image 'manta_light:pytorch' and GPU=False
Swarm details: <fl_pytorch_mnist.swarm.FLSwarm object at 0x...>
Deploying swarm...
Swarm Deployment Overview:
swarm_id: swarm_a1b2c3d4e5f6
status: DEPLOYED
tasks_scheduled: 4
nodes_allocated: 2
Swarm 'swarm_a1b2c3d4e5f6' deployed successfully.
4. Monitor Real-Time Training:
The script will stream training updates:
Task Update: worker_train_001 - RUNNING - node_001
Task Update: worker_train_002 - RUNNING - node_002
Round 1 Results:
- Worker 1: Local loss 2.156, trained on 3000 samples
- Worker 2: Local loss 2.091, trained on 3000 samples
- Aggregator: Combined models from 2 workers (6000 total samples)
- Global Test Accuracy: 92.3%
- Scheduler: Continue training (threshold 99.0% not reached)
Task Update: worker_train_001 - RUNNING - node_001
Task Update: worker_train_002 - RUNNING - node_002
Round 2 Results:
- Worker 1: Local loss 1.234, improved convergence
- Worker 2: Local loss 1.189, improved convergence
- Aggregator: Updated global model
- Global Test Accuracy: 96.7%
- Scheduler: Continue training
...
Round 8 Results:
- Global Test Accuracy: 99.1%
- Scheduler: CONVERGED (threshold 99.0% reached)
- Training completed successfully!
Understanding the Results¶
Training Progression:
Typical MNIST federated learning shows this progression:
| Round | Global Accuracy | Avg Local Loss | Communication | Status |
|---|---|---|---|---|
| 1 | 92.3% | 2.123 | 2.4MB | Starting to learn |
| 3 | 96.7% | 1.234 | 7.2MB | Rapid improvement |
| 5 | 98.1% | 0.876 | 12.0MB | Approaching convergence |
| 8 | 99.1% | 0.543 | 19.2MB | Converged! |
Key Observations:
Rapid Initial Learning: Accuracy jumps from ~10% (random guessing) to over 90% within the first few rounds
Gradual Convergence: Final 1-2% accuracy improvement takes several rounds
Communication Efficiency: Only ~20MB total communication for full training
Privacy Preservation: Raw MNIST images never transmitted, only model parameters
Communication Analysis:
Model Size: ~670K parameters × 4 bytes/param ≈ 2.68MB per model copy (exchanged each round)
Round Communication:
- 2 workers upload model updates: 2 × 2.68MB = 5.36MB upload
- 2 workers download global model: 2 × 2.68MB = 5.36MB download
- Total per round: 10.72MB
- 8 rounds total: ~86MB communication
Compare to centralized training:
- Would need to transfer all training data: ~47MB (raw MNIST)
- Plus ongoing model checkpoints and logs
- Federated learning saves bandwidth and preserves privacy
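These figures can be reproduced with a few lines of arithmetic (float32 weights, 2 workers, 8 rounds, as assumed above):

# Back-of-the-envelope estimate matching the analysis above
params = 669_706                 # MLP parameter count (~670K)
model_mb = params * 4 / 1e6      # float32 weights -> ~2.68 MB per model copy
per_round_mb = 2 * model_mb * 2  # 2 workers, each uploads and downloads once
total_mb = 8 * per_round_mb      # 8 training rounds

print(f"{model_mb:.2f} MB/model, {per_round_mb:.2f} MB/round, {total_mb:.0f} MB total")
# 2.68 MB/model, 10.72 MB/round, 86 MB total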
Analyzing Privacy Preservation¶
What Stays Private:
# This data NEVER leaves each participant's node:
participant_1_data = load_mnist_partition("node_0") # Only digits 0,1
participant_2_data = load_mnist_partition("node_1") # Only digits 2,3
# Raw images, labels, intermediate activations all stay local
# Only model parameters (weights/biases) are shared
What Gets Shared:
# Only these model parameters are communicated:
shared_update = {
"model_params": compressed_weights_and_biases, # ~2.68MB
"num_samples": 3000, # Just the count, not the data
"local_loss": 1.234 # Just the final loss value
}
# No raw data, no gradients, no intermediate results
Privacy Analysis:
Input Privacy: Raw MNIST images never transmitted
Label Privacy: True labels remain on originating nodes
Gradient Privacy: Only final model parameters shared, not gradients
Inference Privacy: Difficult to reverse-engineer training data from model weights
Comparing with Centralized Training¶
Run this comparison to understand federated learning trade-offs:
# Compare federated vs centralized training
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, ConcatDataset

# Simulate centralized training (load_partition, evaluate_model and
# num_participants are assumed helpers/values from the surrounding script)
def centralized_training():
    # Combine all data into one dataset (exactly what FL avoids)
    all_data = []
for i in range(num_participants):
participant_data = load_partition(f"node_{i}")
all_data.append(participant_data)
combined_dataset = ConcatDataset(all_data)
dataloader = DataLoader(combined_dataset, batch_size=64, shuffle=True)
# Train on combined data
model = MLP()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()
for epoch in range(10): # Same total training
for data, target in dataloader:
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
return evaluate_model(model)
Expected Comparison Results:
| Approach | Final Accuracy | Training Time | Privacy |
|---|---|---|---|
| Centralized | 99.3% | 5 minutes | None |
| Federated | 99.1% | 8 minutes | Full |
| Difference | -0.2% | +60% time | +100% privacy |
Troubleshooting Common Issues¶
Training Doesn’t Start:
# Check data preparation
ls ../../temp/partitioned/mnist/
# Should show node_0, node_1, etc. directories
# Check nodes are connected
# Verify in cluster dashboard that nodes show as "Connected"
# Check for Python errors
docker logs manager-api | grep ERROR
Poor Convergence:
# Check data distribution
import torch
for i in range(2): # For each node
data = torch.load(f"../../temp/partitioned/mnist/node_{i}/train.pt")
print(f"Node {i}: {len(data)} samples")
print(f"Classes: {torch.unique(data.targets)}")
# Should show reasonable distribution
# If one node has all data, check prepare_data.py
Slow Training:
# Check resource usage
docker stats
# Look for CPU/memory bottlenecks
# Consider reducing batch size or using GPU
Memory Issues:
# In swarm.py, reduce batch size
self.set_global("hyperparameters", {
"batch_size": 16, # Instead of 32
# ... other params
})
Communication Failures:
# Check MQTT broker
docker logs emqx
# Verify network connectivity
docker exec manager-api ping emqx
Extending the Tutorial¶
1. Try Different Data Distributions:
# Modify prepare_data.py for extreme non-IID
# (filter_by_class and save_partition are assumed helper functions)
def create_extreme_non_iid():
    # Give each participant only 1 digit class
for i in range(10):
participant_data = filter_by_class(mnist_data, class_id=i)
save_partition(participant_data, f"node_{i}")
2. Experiment with Hyperparameters:
# In swarm.py, try different settings
self.set_global("hyperparameters", {
"epochs": 2, # More local training per round
"batch_size": 64, # Larger batches
"optimizer_params": {
"lr": 0.001, # Lower learning rate
"momentum": 0.95 # Higher momentum
}
})
3. Add Custom Metrics:
# In test evaluation, add more metrics
from sklearn.metrics import precision_recall_fscore_support
def execute(self):
    # ... existing evaluation code ...
    # (y_true / y_pred are collected from the test loader in the loop above)
    # Add per-class metrics
precision, recall, f1, _ = precision_recall_fscore_support(
y_true, y_pred, average=None
)
self.world.set_result("detailed_metrics", {
"precision_per_class": precision.tolist(),
"recall_per_class": recall.tolist(),
"f1_per_class": f1.tolist()
})
4. Implement Differential Privacy:
# Add noise to model updates for enhanced privacy
import numpy as np

def add_differential_privacy(model_params, epsilon=1.0):
    # model_params is a dict of layer name -> numpy array (see the Worker task)
    noise_scale = 2.0 / epsilon
    return {
        layer: weights + np.random.laplace(0, noise_scale, weights.shape)
        for layer, weights in model_params.items()
    }

# In the worker task, share the noisy parameters instead of the raw ones:
noisy_params = add_differential_privacy(self.model.get_weights())
self.world.results.add("model_params", numpy_to_bytes(noisy_params))
Next Steps¶
Congratulations! You’ve successfully completed MNIST federated learning. Next steps:
Immediate Next Steps:
1. Try CIFAR-10: CIFAR-10 tutorial - More complex image classification
2. Custom Data: Custom dataset tutorial - Use your own datasets
3. Advanced Monitoring: Real-time monitoring - Deep dive into metrics

Advanced Exploration:
4. Framework Integration: Flower integration - Use with Flower
5. Multi-Cluster: Multi-cluster deployment - Scale across multiple clusters
6. Production Deployment: Production deployment - Real-world deployment

Research Directions:
7. Algorithm Variations: Implement FedProx, FedNova, or other FL algorithms
8. Privacy Enhancements: Add secure aggregation and differential privacy
9. Robustness: Test with node failures and Byzantine participants
You now understand the fundamentals of federated learning and have hands-on experience with Manta’s federated learning platform!