Overview
Generative AI-powered federated learning framework using knowledge graphs to identify potential drug candidates for rare metabolic disorders while preserving patient data privacy across multiple medical institutions.
The Problem
Drug repurposing for rare diseases faces critical challenges: limited patient data scattered across institutions, the high cost of traditional drug discovery (an average of about $2.6B per approved drug), privacy regulations that prevent data sharing, and slow approval processes. Rare disease patients often wait years for potential treatments, while existing drugs that could help them go unrecognized.
The Solution
Developed a novel federated learning framework that enables collaborative AI training across multiple medical institutions without sharing raw patient data. The system builds knowledge graphs from biomedical ontologies and encodes them with a Relational Graph Convolutional Network (R-GCN), then applies generative AI (VAE-based models) to predict drug-disease relationships. Differential privacy mechanisms bound how much any individual patient record can influence the shared model, protecting patient confidentiality while maintaining high prediction accuracy.
Technical Architecture
Multi-institutional federated learning system with privacy-preserving knowledge graph integration
Federated Learning Coordinator
Orchestrates distributed training across 3 medical institutions, managing model synchronization and aggregation without centralizing patient data.
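The aggregation step a coordinator like this performs is commonly Federated Averaging (FedAvg), where each institution's parameters are weighted by its number of local training samples. The project does not say how its aggregation weights clients, so sample-count weighting is an assumption here. This minimal sketch also assumes models are represented as dicts mapping parameter names to flat lists of floats, a simplification of a framework state dict; `fedavg` is a hypothetical name.

```python
from typing import Dict, List

# Hypothetical simplified model representation: parameter name -> flat values
Params = Dict[str, List[float]]


def fedavg(client_params: List[Params], num_samples: List[int]) -> Params:
    """Weighted average of client parameters (FedAvg).

    Each client's contribution is weighted by its share of the total
    number of local training samples.
    """
    if not client_params or len(client_params) != len(num_samples):
        raise ValueError("need one sample count per client")
    total = sum(num_samples)
    averaged: Params = {}
    for name in client_params[0]:
        length = len(client_params[0][name])
        averaged[name] = [
            sum(p[name][i] * n for p, n in zip(client_params, num_samples)) / total
            for i in range(length)
        ]
    return averaged


# Three hypothetical institutions holding different amounts of local data
clients = [{"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}, {"w": [5.0, 6.0]}]
print(fedavg(clients, num_samples=[100, 100, 200]))  # {'w': [3.5, 4.5]}
```

The institution with twice as much data pulls the average twice as hard, which is why FedAvg is usually preferred over an unweighted mean when sites differ in size.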
Knowledge Graph Builder (R-GCN)
Constructs multi-relational graphs from biomedical databases (DrugBank, KEGG, Gene Ontology) and encodes them with a Relational Graph Convolutional Network (R-GCN) to capture complex drug-disease-gene relationships.
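An R-GCN layer updates each node by combining its own embedding with messages from its neighbours, using a separate weight matrix per relation type and normalising by the neighbour count per relation (the propagation rule from Schlichtkrull et al., 2018). The pure-Python sketch below shows one such layer with ReLU as the activation; the node names, relation names and identity weights are illustrative assumptions, and the basis-decomposition regularisation used in the original paper is omitted.

```python
from typing import Dict, List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, x: Vector) -> Vector:
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def relu(x: Vector) -> Vector:
    return [max(0.0, v) for v in x]


def rgcn_layer(
    h: Dict[str, Vector],
    edges: List[Tuple[str, str, str]],  # (source, relation, target)
    w_rel: Dict[str, Matrix],           # one weight matrix W_r per relation
    w_self: Matrix,                     # self-connection weight W_0
) -> Dict[str, Vector]:
    """h_i' = ReLU(W_0 h_i + sum_r sum_{j in N_i^r} W_r h_j / |N_i^r|)."""
    # Group each node's incoming neighbours by relation type
    incoming: Dict[str, Dict[str, List[str]]] = {}
    for src, rel, dst in edges:
        incoming.setdefault(dst, {}).setdefault(rel, []).append(src)

    out = {}
    for node, vec in h.items():
        acc = matvec(w_self, vec)  # self-connection term
        for rel, srcs in incoming.get(node, {}).items():
            norm = 1.0 / len(srcs)  # c_{i,r} = |N_i^r|
            for src in srcs:
                msg = matvec(w_rel[rel], h[src])
                acc = [a + norm * m for a, m in zip(acc, msg)]
        out[node] = relu(acc)
    return out


identity = [[1.0, 0.0], [0.0, 1.0]]
h = {"drugA": [1.0, 0.0], "geneX": [0.0, 1.0], "diseaseY": [0.0, 0.0]}
edges = [("drugA", "targets", "geneX"), ("geneX", "associated_with", "diseaseY")]
print(rgcn_layer(h, edges, {"targets": identity, "associated_with": identity}, identity))
# {'drugA': [1.0, 0.0], 'geneX': [1.0, 1.0], 'diseaseY': [0.0, 1.0]}
```

Stacking two such layers lets information from a drug reach a disease through a shared gene, which is the kind of multi-hop drug-gene-disease path the knowledge graph is meant to expose.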
Generative Model (VAE)
Variational Autoencoder learns latent representations of drug-disease pairs to predict novel repurposing candidates.
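A VAE is typically trained by minimising the negative evidence lower bound: a reconstruction term plus a KL divergence that keeps the latent distribution close to a standard normal, with sampling made differentiable by the reparameterisation trick. The sketch below assumes a diagonal-Gaussian latent and a decoder that outputs a single link probability for a drug-disease pair; the encoder and decoder networks are omitted, and the function names and the `beta` weight are hypothetical rather than taken from the project.

```python
import math
import random
from typing import List


def reparameterize(mu: List[float], logvar: List[float], rng: random.Random) -> List[float]:
    """Sample z = mu + sigma * eps with eps ~ N(0, 1) (reparameterisation trick)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def kl_divergence(mu: List[float], logvar: List[float]) -> float:
    """Closed-form KL(N(mu, sigma^2) || N(0, I)) for a diagonal Gaussian."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def bce(prob: float, label: float, eps: float = 1e-7) -> float:
    """Binary cross-entropy for a predicted drug-disease link probability."""
    p = min(max(prob, eps), 1.0 - eps)
    return -(label * math.log(p) + (1.0 - label) * math.log(1.0 - p))


def vae_loss(link_prob: float, label: float, mu: List[float], logvar: List[float],
             beta: float = 1.0) -> float:
    """Negative ELBO: reconstruction term plus a beta-weighted KL regulariser."""
    return bce(link_prob, label) + beta * kl_divergence(mu, logvar)


rng = random.Random(0)
z = reparameterize([0.2, -0.1], [0.0, 0.0], rng)  # latent code for one drug-disease pair
loss = vae_loss(link_prob=0.8, label=1.0, mu=[0.2, -0.1], logvar=[0.0, 0.0])
```

Once trained, decoding latent codes for drug-disease pairs that have no known link is what yields ranked repurposing candidates.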
Privacy-Preserving Aggregator
Implements differential privacy and secure aggregation to protect individual patient records during model training.
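The core idea behind secure aggregation can be illustrated with pairwise additive masking, in the spirit of the Bonawitz et al. protocol: every pair of institutions shares a random vector that one adds and the other subtracts, so the server sees only masked updates but their sum is exact. This toy sketch assumes no client drops out and uses a seeded RNG in place of the key agreement a real protocol would use; it is not the project's implementation, and all names are hypothetical.

```python
import random
from itertools import combinations
from typing import Dict, List


def pairwise_masks(client_ids: List[str], dim: int, seed: int = 0) -> Dict[str, List[float]]:
    """Build per-client masks from pairwise random vectors that cancel in the sum.

    For each pair (a, b) with a < b, client a adds r_ab and client b subtracts it.
    A shared seeded RNG stands in for pairwise key agreement here.
    """
    rng = random.Random(seed)
    masks = {cid: [0.0] * dim for cid in client_ids}
    for a, b in combinations(sorted(client_ids), 2):
        r = [rng.uniform(-1.0, 1.0) for _ in range(dim)]
        masks[a] = [m + x for m, x in zip(masks[a], r)]
        masks[b] = [m - x for m, x in zip(masks[b], r)]
    return masks


def masked_update(update: List[float], mask: List[float]) -> List[float]:
    return [u + m for u, m in zip(update, mask)]


updates = {"hospital_a": [0.1, 0.2], "hospital_b": [0.3, -0.1], "hospital_c": [0.0, 0.5]}
masks = pairwise_masks(list(updates), dim=2)
# The server only receives masked vectors...
masked = [masked_update(updates[c], masks[c]) for c in updates]
# ...yet their element-wise sum equals the sum of the true updates (up to rounding)
aggregate = [sum(col) for col in zip(*masked)]
```

Masking hides individual updates from the server, while differential privacy limits what the aggregate itself reveals; the two address different threats and are typically combined.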
Methodology
- Data collection from 3 medical institutions (anonymized patient records, genomic data, treatment outcomes)
- Knowledge graph construction using biomedical ontologies (DrugBank, KEGG, Gene Ontology, DisGeNET)
- Federated training with differential privacy (ε=1.0, δ=10⁻⁵)
- Drug-disease link prediction using Graph Neural Networks
- Validation through molecular docking simulations and literature review
- Clinical expert review of top candidates
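For the privacy parameters listed above, the classical Gaussian mechanism gives a concrete noise scale, σ = Δ·√(2 ln(1.25/δ)) / ε, where Δ is the L2 sensitivity. The sketch below applies it with ε = 1.0 and δ = 10⁻⁵ from the methodology and Δ = 0.1 from the training loop's `sensitivity` argument. It is one plausible way to implement the loop's `add_gaussian_noise` helper, whose actual code is not shown; clipping each update to norm at most Δ is what makes that sensitivity valid. Note that the classical bound is proven for 0 < ε < 1, so ε = 1.0 sits at the edge of its range.

```python
import math
import random
from typing import List


def gaussian_sigma(sensitivity: float, epsilon: float, delta: float) -> float:
    """Noise scale of the classical Gaussian mechanism.

    sigma = sensitivity * sqrt(2 * ln(1.25 / delta)) / epsilon
    """
    return sensitivity * math.sqrt(2.0 * math.log(1.25 / delta)) / epsilon


def clip_and_noise(update: List[float], clip_norm: float, sigma: float,
                   rng: random.Random) -> List[float]:
    """Clip the update to L2 norm <= clip_norm, then add N(0, sigma^2) noise per coordinate."""
    norm = math.sqrt(sum(u * u for u in update))
    scale = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    return [u * scale + rng.gauss(0.0, sigma) for u in update]


sigma = gaussian_sigma(sensitivity=0.1, epsilon=1.0, delta=1e-5)
print(f"sigma = {sigma:.4f}")  # sigma = 0.4845
noisy = clip_and_noise([0.3, -0.4, 0.2], clip_norm=0.1, sigma=sigma, rng=random.Random(0))
```

If ε = 1.0 is spent in every round, as in the training loop, the cumulative privacy loss over 10 rounds is larger than 1.0 and is usually tracked with a privacy accountant.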
Results & Impact
Key Impact
- Published in a top-tier bioinformatics journal (impact factor 8.7)
- Identified 3 FDA-approved drugs with repurposing potential for rare metabolic disorders
- Framework adopted by a multi-institutional research consortium
- Reduced the drug discovery timeline from years to months
- Enabled cross-institutional collaboration while maintaining HIPAA compliance
Challenges & Solutions
Heterogeneous Data Formats
Developed a standardized, FHIR-compliant ETL pipeline for cross-institutional compatibility
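One step such a pipeline needs is flattening FHIR resources into uniform rows. The sketch below maps a FHIR R4 `Condition` resource (parsed JSON) to a single record using the standard `subject.reference`, `code.coding` and `onsetDateTime` elements. The output column names are hypothetical, since the project's target schema is not described, and only the first coding is kept for simplicity.

```python
from typing import Any, Dict, Optional


def flatten_condition(resource: Dict[str, Any]) -> Dict[str, Optional[str]]:
    """Map a FHIR R4 Condition resource to one flat row (hypothetical schema)."""
    if resource.get("resourceType") != "Condition":
        raise ValueError("expected a FHIR Condition resource")
    codings = resource.get("code", {}).get("coding", [])
    first = codings[0] if codings else {}
    subject = resource.get("subject", {}).get("reference", "")
    return {
        # "Patient/123" -> "123"; other reference forms are left unresolved
        "patient_id": subject.split("/", 1)[1] if subject.startswith("Patient/") else None,
        "code_system": first.get("system"),
        "code": first.get("code"),
        "display": first.get("display"),
        "onset": resource.get("onsetDateTime"),
    }


example = {
    "resourceType": "Condition",
    "subject": {"reference": "Patient/123"},
    "code": {"coding": [{
        "system": "http://hl7.org/fhir/sid/icd-10-cm",
        "code": "E70.0",
        "display": "Classical phenylketonuria",
    }]},
    "onsetDateTime": "2019-04-02",
}
row = flatten_condition(example)
```

Keeping the coding system alongside the code lets institutions that use different terminologies be reconciled later instead of silently mixing codes.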
Communication Overhead
Implemented gradient compression and asynchronous updates, reducing communication bandwidth by 60%
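The project does not say which compression scheme it uses. Top-k sparsification is one common choice: each client sends only the k largest-magnitude gradient entries as (index, value) pairs, and the server rebuilds a dense vector with zeros elsewhere. A minimal sketch, with hypothetical function names:

```python
from typing import List, Tuple


def top_k_sparsify(grad: List[float], k: int) -> Tuple[List[int], List[float]]:
    """Keep the k largest-magnitude entries; return them as (indices, values)."""
    idx = sorted(range(len(grad)), key=lambda i: abs(grad[i]), reverse=True)[:k]
    idx.sort()  # ascending indices are easier to encode compactly
    return idx, [grad[i] for i in idx]


def densify(indices: List[int], values: List[float], size: int) -> List[float]:
    """Server side: rebuild the dense gradient, filling dropped entries with zeros."""
    dense = [0.0] * size
    for i, v in zip(indices, values):
        dense[i] = v
    return dense


grad = [0.1, -0.9, 0.05, 0.4, -0.02]
indices, values = top_k_sparsify(grad, k=2)
print(indices, values)  # [1, 3] [-0.9, 0.4]
print(densify(indices, values, len(grad)))  # [0.0, -0.9, 0.0, 0.4, 0.0]
```

In practice the dropped entries are often accumulated locally and added to the next round's gradient (error feedback), so small but consistent updates are not lost.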
Model Convergence
Used adaptive learning rates and the FedProx algorithm to handle non-IID data distributions across institutions
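FedProx changes each client's local objective to F_k(w) + (μ/2)·‖w − w_global‖², so the proximal term pulls local weights back toward the current global model and limits drift on non-IID data. The sketch below runs plain gradient descent on that objective; `grad_fn` stands in for the gradient of the local loss and, like the toy quadratic loss in the example, is a hypothetical illustration.

```python
from typing import Callable, List

Vector = List[float]


def fedprox_local_update(
    w_global: Vector,
    grad_fn: Callable[[Vector], Vector],
    mu: float,
    lr: float,
    steps: int,
) -> Vector:
    """Gradient descent on F_k(w) + (mu / 2) * ||w - w_global||^2 (FedProx local objective)."""
    w = list(w_global)
    for _ in range(steps):
        g = grad_fn(w)
        # Proximal gradient mu * (w - w_global) pulls w toward the global model
        w = [wi - lr * (gi + mu * (wi - wg)) for wi, gi, wg in zip(w, g, w_global)]
    return w


# Toy local loss F(w) = 0.5 * (w - 10)^2, whose own minimiser is w = 10
local_grad = lambda w: [wi - 10.0 for wi in w]

print(fedprox_local_update([0.0], local_grad, mu=0.0, lr=0.1, steps=500))
print(fedprox_local_update([0.0], local_grad, mu=1.0, lr=0.1, steps=500))
```

With μ = 0 the client converges to its own optimum (10); with μ = 1 it settles at 10/(1 + μ) = 5, halfway between its local optimum and the global model. Larger μ trades local fit for stability across heterogeneous sites.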
Key Implementation
Federated Training Loop
```python
def federated_train(clients, global_model, validation_data, rounds=10,
                    epsilon=1.0, local_epochs=5):
    """
    Federated learning with differential privacy.

    train_local, add_gaussian_noise, federated_averaging and evaluate_model
    are project-specific helpers defined elsewhere in the codebase.
    """
    for round_idx in range(rounds):
        client_updates = []

        # Local training at each institution; raw data never leaves the client
        for client in clients:
            local_model = train_local(
                client_data=client.get_data(),
                global_model=global_model,
                epochs=local_epochs,
                privacy_budget=epsilon,
            )

            # Perturb the local weights with Gaussian noise for differential privacy
            noisy_update = add_gaussian_noise(
                local_model.state_dict(),
                sensitivity=0.1,
                epsilon=epsilon,
            )
            client_updates.append(noisy_update)

        # Secure aggregation of the noisy client updates into a new global model
        global_model = federated_averaging(client_updates)

        # Evaluate on the held-out validation set
        metrics = evaluate_model(global_model, validation_data)
        print(f"Round {round_idx}: AUC={metrics['auc']:.3f}")

    return global_model
```
Knowledge Graph Construction
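```python
import networkx as nx


class KnowledgeGraphBuilder:
    """Builds a heterogeneous drug-disease-gene graph.

    load_drugbank, load_disgenet and the add_*_edges methods are
    defined elsewhere in the project.
    """

    def __init__(self, ontologies):
        # MultiDiGraph allows several typed, directed edges between the same node pair
        self.graph = nx.MultiDiGraph()
        self.ontologies = ontologies

    def build_graph(self):
        # Add drug nodes (DrugBank)
        drugs = self.load_drugbank()
        for drug in drugs:
            self.graph.add_node(
                drug.id,
                type="drug",
                name=drug.name,
                smiles=drug.smiles,
            )

        # Add disease nodes (DisGeNET)
        diseases = self.load_disgenet()
        for disease in diseases:
            self.graph.add_node(
                disease.id,
                type="disease",
                name=disease.name,
            )

        # Add typed relationships between drugs, genes and diseases
        self.add_drug_target_edges()
        self.add_disease_gene_edges()
        self.add_drug_disease_edges()
        return self.graph
```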
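Once the graph is built and node embeddings are learned, drug-disease link prediction needs a decoder that scores candidate edges. The R-GCN paper pairs its encoder with DistMult, which scores a triple as the sum of the element-wise product of head, relation and tail embeddings; the project does not state which decoder it uses, so DistMult is an assumption here. The embeddings, node names and the `treats` relation vector below are hypothetical.

```python
import math
from typing import Dict, List, Tuple

Vector = List[float]


def distmult_score(head: Vector, rel: Vector, tail: Vector) -> float:
    """DistMult triple score: sum_i head_i * rel_i * tail_i."""
    return sum(h * r * t for h, r, t in zip(head, rel, tail))


def rank_drugs(disease: str, drugs: List[str], emb: Dict[str, Vector],
               rel: Vector) -> List[Tuple[str, float]]:
    """Rank candidate drugs for a disease by sigmoid(DistMult score), highest first."""
    scored = [
        (d, 1.0 / (1.0 + math.exp(-distmult_score(emb[d], rel, emb[disease]))))
        for d in drugs
    ]
    return sorted(scored, key=lambda x: x[1], reverse=True)


emb = {"drug_1": [1.0, 0.0], "drug_2": [0.0, 1.0], "disease_x": [2.0, 0.5]}
treats = [1.0, 1.0]
ranking = rank_drugs("disease_x", ["drug_1", "drug_2"], emb, treats)
print([name for name, _ in ranking])  # ['drug_1', 'drug_2']
```

In a repurposing setting, the top of such a ranking (excluding pairs already linked in the graph) would feed the docking simulations and expert review listed under Methodology.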