Federated Learning

In Short

Federated learning trains a shared ML model across many devices or institutions without any device ever sending its raw data to a central server. Local devices train on their own data and transmit only model updates, which are aggregated into a global model. It enables collaborative learning under strong privacy constraints but introduces new challenges around communication cost, data heterogeneity, and security.

01. What It Is

Federated learning is a decentralized training approach in which a global model is improved by aggregating parameter updates computed locally on each participating node, rather than by collecting all data in one place. The term was introduced by McMahan et al. at Google in 2016 (arXiv:1602.05629) and first deployed in Gboard, the Google Keyboard for Android, in 2017.

The key architectural property: raw data never leaves the device or silo where it was generated. Only model weights, gradients, or parameter deltas are transmitted, and only after local training has occurred. The central server sees no individual training record.

Federated learning divides into two primary deployment patterns:

Cross-device:
Millions of consumer devices (phones, tablets, IoT sensors) participate intermittently. Each device holds a small local dataset. Devices typically participate only when idle, plugged in, and on an unmetered connection such as Wi-Fi. Scale is massive, but individual device resources are constrained and connectivity is unreliable.

Cross-silo:
A small number of institutions (hospitals, banks, research consortia) participate, each holding a large, well-maintained local dataset. Connectivity and compute are reliable. The main concern is inter-institutional privacy and regulatory compliance rather than resource constraints.

02. Why It Matters

Privacy-preserving collaboration:
Medical institutions can jointly train a diagnostic model on patient data from multiple hospitals without any hospital sharing identifiable patient records with another. Financial institutions can train a shared fraud-detection model across transaction datasets that are legally prohibited from being combined. Google can improve Gboard's language model using what users actually type without any keystrokes leaving the phone.

Regulatory compliance:
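GDPR (European Union), HIPAA (US healthcare), and equivalent frameworks restrict cross-border or inter-institutional data transfer. Federated learning can make these requirements easier to meet because the regulated raw records never move, although whether the transmitted model updates themselves count as regulated data still has to be assessed case by case.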

Latency and bandwidth:
In cross-device settings, the on-device model can be improved continuously without uploading the raw sensor stream. Only a model update, which can be compressed and is often far smaller than the raw data it was trained on, is transmitted.

03. How It Works

The standard training loop:

  1. Initialization:
    A central server initializes a global model and distributes it to participating nodes. Configuration (hyperparameters, number of local epochs, batch size) is sent alongside.

  2. Local training:
    Each node trains the global model on its local dataset for one or more epochs using standard gradient-based optimization. No raw data is transmitted.

  3. Update transmission:
    Each node sends its updated model parameters or gradients back to the server. In many implementations, a compressed or quantized delta (the difference between the local model and the received global model) is sent instead of full weights.

  4. Global aggregation:
    The server aggregates the updates from the nodes that took part in the round. The standard method is Federated Averaging (FedAvg), which computes a weighted average of node updates, with each weight proportional to that node's number of local training samples. McMahan et al. (2016) reported that FedAvg needs 10-100x fewer communication rounds than a baseline federated SGD in which each node computes a single gradient per round.

  5. Iteration:
    The updated global model is redistributed and the cycle repeats until convergence.
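
The five steps above can be sketched in a few dozen lines. This is a minimal illustration, not a production implementation: the model is a hypothetical one-dimensional linear regressor (y = w*x + b), the three "nodes" are simulated in one process, and the helper names (local_train, fedavg, make_node) are invented for this example.

```python
import random

def local_train(weights, data, epochs=5, lr=0.05):
    """Step 2: run plain SGD on one node's data and return the new weights."""
    w, b = weights
    for _ in range(epochs):
        for x, y in data:
            err = (w * x + b) - y   # prediction error under squared loss
            w -= lr * err * x       # gradient step on the slope
            b -= lr * err           # gradient step on the intercept
    return (w, b)

def fedavg(updates):
    """Step 4: average (weights, n_samples) pairs, weighting by n_samples."""
    total = sum(n for _, n in updates)
    dim = len(updates[0][0])
    return tuple(sum(wts[i] * n for wts, n in updates) / total
                 for i in range(dim))

def make_node(n, rng):
    """Hypothetical local dataset drawn from y = 2x + 1 plus small noise."""
    data = []
    for _ in range(n):
        x = rng.uniform(-1.0, 1.0)
        data.append((x, 2.0 * x + 1.0 + rng.gauss(0.0, 0.01)))
    return data

rng = random.Random(0)
nodes = [make_node(n, rng) for n in (20, 50, 200)]  # raw data stays per node

global_weights = (0.0, 0.0)                         # step 1: initialization
for round_idx in range(30):                         # step 5: iteration
    updates = []
    for data in nodes:
        local = local_train(global_weights, data)   # step 2: local training
        updates.append((local, len(data)))          # step 3: weights + count only
    global_weights = fedavg(updates)                # step 4: aggregation

print(global_weights)  # should end up close to (2.0, 1.0)
```

Note that the server-side code only ever touches weight tuples and sample counts. Real systems also sample a subset of nodes per round and handle dropouts, which this sketch leaves out.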

04. Key Terms / Methods

  • FedAvg (Federated Averaging). The foundational aggregation algorithm. Each round, a subset of nodes is sampled, each sampled node trains locally, and the server replaces the global model with the average of the resulting parameters, weighted by local dataset size.
  • Differential Privacy (DP). Each node's update is clipped to a bounded norm and calibrated Gaussian or Laplacian noise is added, either on the node before transmission or by the server during aggregation. This gives a formal, quantifiable bound (the privacy budget, epsilon) on how much anyone observing the result can infer about whether a particular record or user was included in training. McMahan et al. (2018, arXiv:1710.06963) showed that differentially private recurrent language models can be trained with negligible accuracy loss given a sufficiently large number of participating users.
  • Secure Aggregation. A cryptographic protocol in which the server can compute the sum or average of node updates without being able to inspect any individual node's update. Google deployed a secure aggregation protocol alongside FedAvg for Gboard; it uses threshold secret sharing so that the aggregate can be recovered only when enough participants contribute, while individual updates stay hidden even if some devices drop out.
  • FedProx. A variant of FedAvg that adds a proximal term to each node's local objective, limiting how far local updates can drift from the global model. Improves convergence in heterogeneous networks where nodes have different data distributions and compute capabilities.
  • Horizontal federated learning. Nodes share the same feature schema but different samples (e.g., multiple clinics each with patient records of the same variables but different patients).
  • Vertical federated learning. Nodes share overlapping samples but different feature sets (e.g., a bank and a retailer both have records for the same customers, but one holds financial features and the other holds purchase history).
  • Non-IID data. Data that is not independently and identically distributed across nodes. In cross-device settings, each user's data reflects their individual behavior, which differs systematically from other users. Non-IID data can cause FedAvg to converge more slowly, and to a worse solution, than centralized training on the pooled data.
  • Gradient inversion attacks. Research has shown that shared gradients can sometimes be used to reconstruct training samples (e.g., Zhu et al., 2019). This demonstrates that transmitting gradients is not inherently private and motivates differential privacy and secure aggregation as additional defenses.
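
To make the differential privacy entry above more concrete, here is a minimal sketch of the clip-then-add-noise step applied to a single update vector. The function names and the noise_multiplier parameter are assumptions made for this example, and the sketch deliberately omits privacy accounting (computing the resulting epsilon), which a real deployment would need.

```python
import math
import random

def clip_update(update, clip_norm):
    """Scale the update down so its L2 norm is at most clip_norm."""
    norm = math.sqrt(sum(v * v for v in update))
    scale = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    return [v * scale for v in update]

def privatize(update, clip_norm, noise_multiplier, rng):
    """Clip, then add Gaussian noise with std = noise_multiplier * clip_norm."""
    clipped = clip_update(update, clip_norm)
    sigma = noise_multiplier * clip_norm
    return [v + rng.gauss(0.0, sigma) for v in clipped]

rng = random.Random(42)
raw_update = [3.0, 4.0]  # L2 norm is 5.0, well above the clipping bound

clipped = clip_update(raw_update, clip_norm=1.0)
noisy = privatize(raw_update, clip_norm=1.0, noise_multiplier=0.5, rng=rng)

print(clipped)  # direction preserved, norm reduced to about 1.0
print(noisy)    # clipped update plus Gaussian noise
```

Clipping is what makes the noise meaningful: it bounds how much any one contribution can move the result, so a fixed noise scale can mask it.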

05. Examples

  • Gboard. Google's keyboard uses federated learning to improve next-word prediction and query suggestion models from actual user typing, without sending keystrokes to Google's servers. This was the first large-scale production deployment of federated learning.
  • Apple Siri and keyboard. Apple uses on-device learning with differential privacy for features like emoji suggestions and QuickType, following a privacy architecture similar to federated learning.
  • Healthcare consortia. The US National Cancer Institute's Cancer Research Data Commons and similar consortia use federated analysis frameworks to query across hospital datasets without centralizing patient records.
  • NVIDIA FLARE. An open-source FL framework from NVIDIA used in clinical AI research to train radiology models across hospitals.
  • TensorFlow Federated (TFF). Google's open-source library for FL research, providing both a high-level API for FL training tasks and a low-level API for defining new aggregation algorithms.

06. Common Pitfalls / Misconceptions

Federated learning is inherently private:
Transmitting gradients leaks information. Gradient inversion attacks can reconstruct training images or text from shared gradients, especially when the gradient was computed on a small batch. Differential privacy and secure aggregation are required to provide meaningful privacy guarantees, and they add computational and accuracy costs.
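
One of the defenses named above, secure aggregation, rests on a simple idea: nodes add random masks that cancel when the server sums the masked updates. The sketch below is a toy version of pairwise additive masking only. It is not the threshold secret sharing protocol used in production, it does not handle dropouts, and it generates all masks in one place for brevity, whereas a real protocol would have each pair of nodes derive its mask from a shared secret. All names are hypothetical.

```python
import random

MODULUS = 2 ** 32  # assumes updates are quantized to integers mod 2^32

def pairwise_masks(node_ids, dim, seed=0):
    """Draw one random mask vector per pair (i, j) with i < j."""
    rng = random.Random(seed)
    return {(i, j): [rng.randrange(MODULUS) for _ in range(dim)]
            for i in node_ids for j in node_ids if i < j}

def mask_update(node_id, update, node_ids, masks):
    """Add masks shared with higher-numbered peers, subtract the others."""
    masked = list(update)
    for peer in node_ids:
        if peer == node_id:
            continue
        pair = (min(node_id, peer), max(node_id, peer))
        sign = 1 if node_id < peer else -1
        masked = [(m + sign * r) % MODULUS
                  for m, r in zip(masked, masks[pair])]
    return masked

def aggregate(masked_updates):
    """Server-side sum: each mask is added once and subtracted once."""
    dim = len(masked_updates[0])
    return [sum(u[k] for u in masked_updates) % MODULUS for k in range(dim)]

node_ids = [0, 1, 2]
raw = {0: [5, 1], 1: [7, 2], 2: [11, 3]}  # integer-quantized updates
masks = pairwise_masks(node_ids, dim=2)
sent = [mask_update(i, raw[i], node_ids, masks) for i in node_ids]

print(sent[0])          # looks random to the server
print(aggregate(sent))  # [23, 6], the sum of the raw updates
```

The server recovers the exact sum while each individual message it receives is statistically masked.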

All nodes have equal contribution:
FedAvg weights updates by local dataset size. Nodes with large datasets dominate the global update. In cross-silo settings, this can cause the global model to be biased toward the largest institution's data distribution. Fairness-aware aggregation methods exist but are not the default.
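
A small worked example shows how strongly dataset size can skew the aggregate. The institution names, sample counts, and parameter values below are made up for illustration.

```python
def aggregation_weights(sample_counts):
    """FedAvg-style weights: each node's share of the total sample count."""
    total = sum(sample_counts.values())
    return {name: n / total for name, n in sample_counts.items()}

def weighted_mean(values, weights):
    """Combine one scalar parameter per node using the given weights."""
    return sum(values[name] * weights[name] for name in values)

# Hypothetical silos: one large hospital and two small ones.
counts = {"hospital_a": 90_000, "hospital_b": 8_000, "hospital_c": 2_000}
# A single model parameter as learned locally at each silo.
local_param = {"hospital_a": 1.0, "hospital_b": 0.0, "hospital_c": 0.0}

size_weights = aggregation_weights(counts)
uniform_weights = {name: 1 / len(counts) for name in counts}

print(size_weights["hospital_a"])                   # 0.9
print(weighted_mean(local_param, size_weights))     # 0.9
print(weighted_mean(local_param, uniform_weights))  # about 0.333
```

With size-proportional weights the aggregate sits almost exactly where the largest institution put it. Uniform weighting is shown only as a contrast, not as a recommended fairness method.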

Federated learning works on any data:
Non-IID data is the core technical challenge. When data distributions differ substantially across nodes (e.g., medical images from a children's hospital differ systematically from those of a geriatric ward), the federated model may perform worse for a given institution than a model trained only on that institution's own data. Techniques like FedProx and clustered federated learning partially address this.
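
The effect of FedProx's proximal term can be seen on a toy problem. The sketch below assumes a one-parameter node whose local loss is 0.5*(w - local_optimum)^2 and adds the penalty (mu/2)*(w - w_global)^2 to it. The function name and the specific numbers are illustrative assumptions, not taken from any library.

```python
def local_steps(w_global, local_optimum, mu, steps=200, lr=0.1):
    """Gradient descent on the local loss plus a FedProx-style proximal term."""
    w = w_global
    for _ in range(steps):
        grad_loss = w - local_optimum    # gradient of the node's own loss
        grad_prox = mu * (w - w_global)  # gradient of the proximal penalty
        w -= lr * (grad_loss + grad_prox)
    return w

w_global = 0.0
local_optimum = 10.0  # hypothetical node whose data pulls far from the global model

plain = local_steps(w_global, local_optimum, mu=0.0)  # no proximal term
prox = local_steps(w_global, local_optimum, mu=1.0)   # proximal term, mu = 1

print(plain)  # converges to about 10.0, the node's own optimum
print(prox)   # converges to about 5.0, halfway back toward the global model
```

With mu = 0 the node drifts all the way to its own optimum. With mu = 1 it settles at (local_optimum + mu * w_global) / (1 + mu), which limits how far an outlier node can pull the average.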

Federated learning eliminates the need for a trusted central party:
The central server is still a point of trust and a potential attack surface. It sees aggregated updates, orchestrates the training schedule, and controls the global model. Fully decentralized peer-to-peer approaches exist but are less common in practice.

Communication overhead is solved:
McMahan et al. showed a 10-100x reduction in communication rounds relative to a federated SGD baseline, but each round still carries a full model update. A transformer with one billion parameters is about 4 GB at 32-bit precision, so an uncompressed update is far too large to send from a phone every round. Cross-device deployments therefore rely on smaller models, careful scheduling (idle, plugged-in, unmetered connection), and gradient compression techniques to make this practical.
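
The payload arithmetic, and one common compression idea, can be sketched as follows. Top-k sparsification is used here purely as an example of a gradient compression technique; the text does not say which methods any particular deployment uses, and the function names are hypothetical.

```python
def payload_bytes(num_params, bits_per_param=32):
    """Size of one dense update if every parameter is sent."""
    return num_params * bits_per_param // 8

def top_k_sparsify(update, k):
    """Keep the k largest-magnitude entries as (index, value) pairs."""
    ranked = sorted(range(len(update)),
                    key=lambda i: abs(update[i]), reverse=True)
    return sorted((i, update[i]) for i in ranked[:k])

def densify(sparse, dim):
    """Server side: rebuild a dense vector, treating dropped entries as zero."""
    dense = [0.0] * dim
    for i, v in sparse:
        dense[i] = v
    return dense

# One billion parameters at 32 bits versus 8 bits, in GB.
print(payload_bytes(1_000_000_000) / 1e9)                    # 4.0
print(payload_bytes(1_000_000_000, bits_per_param=8) / 1e9)  # 1.0

update = [0.01, -0.9, 0.02, 0.5, -0.03, 0.0]
sparse = top_k_sparsify(update, k=2)
print(sparse)                        # [(1, -0.9), (3, 0.5)]
print(densify(sparse, len(update)))  # [0.0, -0.9, 0.0, 0.5, 0.0, 0.0]
```

Sparsification and quantization both trade some update fidelity for bandwidth, so in practice the compression level is tuned against convergence speed.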