保姆级教程:用Python和PyTorch Geometric复现一篇GNN交通预测顶会论文(附完整代码)
2026/5/13 11:25:09 网站建设 项目流程

从论文到实践:用PyTorch Geometric实现GNN交通流量预测全流程指南

交通流量预测一直是智慧城市和智能交通系统研究的核心课题。传统的统计方法和机器学习模型在处理复杂的时空依赖性时往往力不从心,而图神经网络(GNN)因其天然的图结构建模能力,正在这一领域展现出革命性的潜力。本文将带您从零开始,完整复现一篇典型的GNN交通预测论文,使用PyTorch Geometric框架实现从数据准备到模型部署的全过程。

1. 环境准备与数据获取

1.1 搭建Python开发环境

首先需要配置适合深度学习的工作环境。推荐使用conda创建独立的Python环境:

conda create -n gnn-traffic python=3.8 conda activate gnn-traffic pip install torch torch-geometric torch-scatter torch-sparse torch-cluster -f https://data.pyg.org/whl/torch-1.10.0+cu113.html pip install pandas numpy matplotlib scikit-learn

对于GPU加速,需确保安装对应CUDA版本的PyTorch。torch-geometric的安装需要额外安装torch-scatter等依赖,版本需严格匹配。

1.2 获取交通数据集

PeMS(Performance Measurement System)是交通预测研究中最常用的公开数据集之一,包含加州高速公路传感器网络采集的交通流量、速度等数据。我们将使用PeMSD4数据集,包含旧金山湾区29个站点的3个月数据。

import os import pandas as pd # 数据下载与解压 data_url = "https://storage.googleapis.com/traffic-prediction-data/PeMSD4.zip" os.system(f"wget {data_url} && unzip PeMSD4.zip") # 加载数据 flow_data = pd.read_csv("PeMSD4/flow.csv", header=None) speed_data = pd.read_csv("PeMSD4/speed.csv", header=None)

数据集包含两个关键文件:

  • flow.csv: 每5分钟记录的交通流量(车辆数)
  • speed.csv: 对应时间点的平均车速(mph)

2. 构建交通图结构

2.1 定义图节点与特征

在GNN中,每个传感器站点将作为图的一个节点。我们需要为每个节点构建特征矩阵:

import numpy as np # 节点数量 num_nodes = 29 # 时间步长(5分钟间隔) timesteps = flow_data.shape[1] # 构建特征矩阵 (num_nodes, timesteps, 2) # 最后一个维度包含流量和速度两个特征 node_features = np.zeros((num_nodes, timesteps, 2)) node_features[:,:,0] = flow_data.values node_features[:,:,1] = speed_data.values

2.2 构建邻接矩阵

邻接矩阵定义节点间的空间关系。常用的构建方法包括:

方法类型计算公式特点
距离矩阵$A_{ij} = \exp(-\frac{d_{ij}^2}{\sigma^2})$基于地理距离,σ控制衰减速率
相关性矩阵$A_{ij} = \text{corr}(X_i, X_j)$基于历史流量模式相似性
混合矩阵$A_{ij} = \alpha A_{ij}^{\text{dist}} + (1-\alpha)A_{ij}^{\text{corr}}$结合多种信息源

以下是基于距离构建邻接矩阵的代码实现:

from sklearn.metrics.pairwise import rbf_kernel # 加载站点坐标 locations = pd.read_csv("PeMSD4/graph_sensor_locations.csv") coords = locations[["latitude", "longitude"]].values # 计算距离矩阵 dist_matrix = np.zeros((num_nodes, num_nodes)) for i in range(num_nodes): for j in range(num_nodes): dist_matrix[i,j] = haversine(coords[i], coords[j]) # 转换为邻接矩阵(RBF核) sigma = 0.1 # 控制衰减速率 adj_matrix = rbf_kernel(dist_matrix, gamma=1./(2.*sigma**2)) np.fill_diagonal(adj_matrix, 0) # 移除自连接

3. 实现GNN预测模型

3.1 设计模型架构

我们将实现一个典型的时空图神经网络(STGNN),包含空间和时间两个维度的建模:

STGNN架构: 1. 空间模块:图卷积网络(GCN)捕获站点间空间依赖 2. 时间模块:门控循环单元(GRU)处理时间序列模式 3. 预测层:全连接网络输出未来流量预测

PyTorch Geometric实现代码如下:

import torch import torch.nn as nn import torch.nn.functional as F from torch_geometric.nn import GCNConv from torch_geometric.utils import dense_to_sparse class STGNN(nn.Module): def __init__(self, num_nodes, input_dim, hidden_dim, output_dim, seq_len): super(STGNN, self).__init__() self.num_nodes = num_nodes self.seq_len = seq_len # 空间卷积层 self.gcn1 = GCNConv(input_dim, hidden_dim) self.gcn2 = GCNConv(hidden_dim, hidden_dim) # 时间循环层 self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True) # 预测层 self.fc = nn.Linear(hidden_dim, output_dim) def forward(self, x, edge_index, edge_weight): # x形状: (batch_size, seq_len, num_nodes, input_dim) batch_size = x.size(0) x = x.permute(0, 2, 1, 3) # (batch, nodes, seq, features) # 空间卷积 h = [] for t in range(self.seq_len): xt = x[:,:,t,:].reshape(-1, x.size(-1)) # (batch*nodes, features) xt = F.relu(self.gcn1(xt, edge_index, edge_weight)) xt = F.relu(self.gcn2(xt, edge_index, edge_weight)) h.append(xt.view(batch_size, self.num_nodes, -1)) # 堆叠时间维度 h = torch.stack(h, dim=1) # (batch, seq, nodes, hidden) # 时间建模 h = h.permute(0, 2, 1, 3) # (batch, nodes, seq, hidden) h = h.reshape(batch_size*self.num_nodes, self.seq_len, -1) _, h = self.gru(h) # 使用最后隐藏状态 h = h.squeeze(0).view(batch_size, self.num_nodes, -1) # 预测 out = self.fc(h) # (batch, nodes, output_dim) return out

3.2 数据预处理与加载

GNN需要特殊的数据加载方式,PyTorch Geometric提供了专用的DataLoader:

from torch_geometric.data import Data, Dataset from torch.utils.data import DataLoader class TrafficDataset(Dataset): def __init__(self, node_features, adj_matrix, seq_len=12, pred_len=3): self.node_features = node_features # (nodes, total_timesteps, 2) self.adj_matrix = adj_matrix self.seq_len = seq_len # 历史时间步数 self.pred_len = pred_len # 预测时间步数 self.edge_index, self.edge_weight = dense_to_sparse( torch.FloatTensor(adj_matrix)) def __len__(self): return self.node_features.shape[1] - self.seq_len - self.pred_len + 1 def __getitem__(self, idx): x = self.node_features[:, idx:idx+self.seq_len, :] # (nodes, seq, 2) y = self.node_features[:, idx+self.seq_len:idx+self.seq_len+self.pred_len, 0] # 预测流量 # 转换为PyG的Data对象 x = torch.FloatTensor(x).permute(1, 0, 2) # (seq, nodes, 2) y = torch.FloatTensor(y).permute(1, 0) # (pred_len, nodes) return x, y

4. 模型训练与调优

4.1 训练流程实现

完整的训练循环需要考虑GNN的特殊性,如邻接矩阵的处理:

def train(model, dataloader, optimizer, device): model.train() total_loss = 0 for x, y in dataloader: x = x.to(device) # (batch, seq, nodes, features) y = y.to(device) # (batch, pred_len, nodes) # 获取边缘索引和权重 edge_index = dataset.edge_index.to(device) edge_weight = dataset.edge_weight.to(device) # 前向传播 pred = model(x, edge_index, edge_weight) # (batch, nodes, pred_len) pred = pred.permute(0, 2, 1) # (batch, pred_len, nodes) # 计算损失 loss = F.mse_loss(pred, y) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step() total_loss += loss.item() return total_loss / len(dataloader)

4.2 超参数优化策略

GNN模型对超参数敏感,建议采用以下调优策略:

  1. 学习率调度:使用ReduceLROnPlateau动态调整学习率
  2. 早停机制:验证集性能不再提升时停止训练
  3. 正则化技术
    • 图Dropout:随机丢弃部分边
    • 权重衰减:L2正则化
  4. 关键超参数范围
参数建议范围影响
GCN层数2-3层过多会导致过平滑
隐藏维度32-256影响模型容量
历史序列长度6-24(对应30-120分钟)捕获时间依赖性
RBF核σ0.05-0.5控制空间影响范围

实现学习率调度和早停的代码示例:

from torch.optim.lr_scheduler import ReduceLROnPlateau # 初始化 model = STGNN(num_nodes=29, input_dim=2, hidden_dim=64, output_dim=3, seq_len=12).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4) scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5) best_val_loss = float('inf') patience = 10 counter = 0 for epoch in range(100): train_loss = train(model, train_loader, optimizer, device) val_loss = evaluate(model, val_loader, device) scheduler.step(val_loss) # 早停逻辑 if val_loss < best_val_loss: best_val_loss = val_loss counter = 0 torch.save(model.state_dict(), 'best_model.pth') else: counter += 1 if counter >= patience: print(f"Early stopping at epoch {epoch}") break

5. 结果分析与可视化

5.1 评估指标计算

交通预测常用三种评估指标:

  1. MAE(平均绝对误差): $$ \text{MAE} = \frac{1}{n}\sum_{i=1}^n |y_i - \hat{y}_i| $$

  2. RMSE(均方根误差): $$ \text{RMSE} = \sqrt{\frac{1}{n}\sum_{i=1}^n (y_i - \hat{y}_i)^2} $$

  3. MAPE(平均绝对百分比误差): $$ \text{MAPE} = \frac{100%}{n}\sum_{i=1}^n \left|\frac{y_i - \hat{y}_i}{y_i}\right| $$

实现代码:

def compute_metrics(y_true, y_pred): mae = torch.mean(torch.abs(y_true - y_pred)) rmse = torch.sqrt(torch.mean((y_true - y_pred)**2)) mape = torch.mean(torch.abs((y_true - y_pred) / (y_true + 1e-5))) * 100 # 避免除零 return mae, rmse, mape

5.2 预测结果可视化

使用Matplotlib绘制真实值与预测值的对比:

import matplotlib.pyplot as plt def plot_predictions(model, dataloader, node_idx=0, timesteps=24): model.eval() x, y_true = next(iter(dataloader)) with torch.no_grad(): y_pred = model(x.to(device), dataset.edge_index.to(device), dataset.edge_weight.to(device)) # 选择特定节点的预测结果 y_true = y_true[:, :, node_idx].cpu().numpy() # (batch, pred_len) y_pred = y_pred[:, node_idx, :].cpu().numpy().T # (pred_len, batch) # 绘制前timesteps个时间点 plt.figure(figsize=(12, 6)) plt.plot(y_true[:timesteps].flatten(), label="True Flow") plt.plot(y_pred[:timesteps].flatten(), label="Predicted Flow") plt.xlabel("Time (5-min intervals)") plt.ylabel("Traffic Flow") plt.title(f"Traffic Flow Prediction at Node {node_idx}") plt.legend() plt.grid() plt.show()

6. 进阶优化技巧

6.1 动态图卷积改进

静态邻接矩阵无法反映交通关系的时变性。我们可以实现动态图卷积:

class DynamicGCNConv(nn.Module): def __init__(self, input_dim, output_dim): super().__init__() self.weight = nn.Parameter(torch.Tensor(input_dim, output_dim)) self.reset_parameters() def reset_parameters(self): nn.init.xavier_uniform_(self.weight) def forward(self, x, adj): # x: (batch, nodes, features) # adj: (batch, nodes, nodes) 动态邻接矩阵 support = torch.matmul(x, self.weight) output = torch.matmul(adj, support) return output

6.2 多任务学习框架

同时预测流量和速度可以提升模型泛化能力:

class MultiTaskSTGNN(nn.Module): def __init__(self, num_nodes, input_dim, hidden_dim, seq_len): super().__init__() # 共享的时空编码器 self.encoder = STGNNEncoder(num_nodes, input_dim, hidden_dim, seq_len) # 任务特定头 self.flow_head = nn.Linear(hidden_dim, 1) self.speed_head = nn.Linear(hidden_dim, 1) def forward(self, x, edge_index, edge_weight): h = self.encoder(x, edge_index, edge_weight) flow = self.flow_head(h) speed = self.speed_head(h) return flow, speed

6.3 部署优化建议

将训练好的模型投入实际应用时需考虑:

  1. 模型轻量化

    • 知识蒸馏:用大模型训练小模型
    • 量化:FP16或INT8量化减少内存占用
  2. 增量学习

    def incremental_update(model, new_data, lr=0.001, steps=100): optimizer = torch.optim.SGD(model.parameters(), lr=lr) for _ in range(steps): loss = train_step(model, new_data, optimizer) if loss < 0.001: break return model
  3. 边缘计算部署

    • 使用TorchScript导出模型
    • 在边缘设备上使用ONNX Runtime推理

7. 常见问题与解决方案

在实际复现过程中,可能会遇到以下典型问题:

问题1:内存不足

  • 现象:训练时GPU内存溢出
  • 解决方案:
    • 减小batch size
    • 使用torch.utils.checkpoint进行梯度检查点
    • 简化模型结构

问题2:过拟合

  • 现象:训练损失下降但验证损失上升
  • 解决方案:
    • 增加图Dropout
    class GraphDropout(nn.Module): def __init__(self, p=0.5): super().__init__() self.p = p def forward(self, edge_index, edge_weight): if self.training: mask = torch.rand(edge_weight.size()) > self.p return edge_index, edge_weight * mask.float() return edge_index, edge_weight
    • 添加更多的训练数据
    • 使用更严格的L2正则化

问题3:预测结果滞后

  • 现象:预测曲线与真实曲线形状相似但存在时移
  • 解决方案:
    • 增加历史序列长度
    • 在损失函数中加入差分惩罚项
    def time_aware_loss(y_true, y_pred, alpha=0.1): mse = F.mse_loss(y_true, y_pred) diff_loss = F.mse_loss(y_pred[:,1:]-y_pred[:,:-1], y_true[:,1:]-y_true[:,:-1]) return mse + alpha * diff_loss

问题4:边缘权重不稳定

  • 现象:模型对邻接矩阵非常敏感
  • 解决方案:
    • 使用注意力机制动态学习边缘权重
    class EdgeLearner(nn.Module): def __init__(self, node_dim): super().__init__() self.attn = nn.Linear(2*node_dim, 1) def forward(self, x): # x: (nodes, features) nodes = x.size(0) x_i = x.unsqueeze(1).expand(-1, nodes, -1) x_j = x.unsqueeze(0).expand(nodes, -1, -1) pair = torch.cat([x_i, x_j], dim=-1) weights = torch.sigmoid(self.attn(pair)).squeeze(-1) return weights

在实际项目中,我们通常需要多次迭代优化才能获得理想效果。建议从简单模型开始,逐步增加复杂度,同时使用版本控制工具记录每次实验的配置和结果。

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询