导读(Introduction)
欢迎来到 Apache Airflow 源码深度解析系列的第34课。
在上一课中,我们构建了一个完整的企业级 ETL 平台,涵盖了多层数据仓库、多团队协作和监控告警。本课将目光转向另一个高价值场景——机器学习管道编排(ML Pipeline Orchestration)。
机器学习工程与传统数据工程有本质区别:数据管道追求的是确定性和幂等性——同样的输入永远产生同样的输出;而 ML 管道充满了不确定性和实验性——同一份数据配合不同超参数,可能产生完全不同的模型质量。这种差异深刻影响着管道的设计方式。
Airflow 3.x 提供了一系列强大的原语来应对 ML 场景的独特挑战:
- Dynamic Task Mapping(动态任务映射)——在运行时动态创建任意数量的并行任务,完美匹配超参搜索的大规模并行实验需求
- TaskGroup——将训练、评估、部署等阶段组织为逻辑分组,保持 DAG 可读性
- XCom——在任务间传递模型指标、最优参数等轻量数据
- Setup/Teardown——管理 GPU 计算资源的生命周期,确保无论实验成功与否,昂贵资源都能被正确释放
- Asset 驱动调度——当训练数据更新时自动触发模型重训练
通过本课的实战,你将掌握如何将 Airflow 打造为一个完整的MLOps 编排平台。
学习目标(Learning Objectives)
完成本课学习后,你将能够:
- 理解 ML Pipeline 与数据管道的核心区别——明确实验性、非确定性对管道设计的影响
- 设计模型训练/评估/部署的工作流架构——合理划分 ML 管道的各个阶段
- 掌握 Dynamic Task Mapping 在超参搜索中的应用——利用
.expand()实现大规模并行实验 - 运用 TaskGroup 组织复杂 ML 工作流——保持管道的结构清晰性和可维护性
- 通过 XCom 传递和比较模型指标——实现自动模型选择和指标追踪
- 利用 Setup/Teardown 管理 GPU 资源——确保昂贵计算资源的生命周期安全
- 实现 Asset 驱动的自动模型重训练——数据变更自动触发 ML Pipeline
正文内容(Main Content)
1. ML Pipeline 与数据管道的本质区别
1.1 传统数据管道的特征
传统 ETL/数据管道具有以下特性:
| 特征 | 表现 |
|---|---|
| 确定性 | 同一输入 → 同一输出 |
| 幂等性 | 重复执行结果相同 |
| 单路径 | 数据沿固定路径流转 |
| 成功标准明确 | 数据完整性 + Schema 校验 |
| 资源可预测 | CPU/内存需求相对稳定 |
1.2 ML 管道的独特挑战
机器学习管道面临着截然不同的工程挑战:
| 特征 | 表现 | 对管道设计的影响 |
|---|---|---|
| 实验性 | 需要尝试多组超参数 | 需要动态并行能力 |
| 非确定性 | 同一超参可能产生不同结果 | 需要指标追踪和比较 |
| 多路径分支 | 不同实验产生不同模型 | 需要动态汇聚和选择 |
| 成功标准模糊 | "够好"取决于业务阈值 | 需要质量门控和人工审核 |
| 资源异构 | GPU/TPU 高成本资源 | 需要精细的资源生命周期管理 |
| 阶段耦合松散 | 训练/评估/部署可独立迭代 | 需要阶段化组织 |
1.3 ML Pipeline 的典型阶段
┌─────────────────────────────────────────────────────────────────────┐ │ ML Pipeline 生命周期 │ ├─────────────────────────────────────────────────────────────────────┤ │ │ │ ┌──────────┐ ┌──────────┐ ┌────────────┐ ┌──────────────┐ │ │ │ 数据准备 │──▶│ 特征工程 │──▶│ 模型训练 │──▶│ 模型评估 │ │ │ │ │ │ │ │ (多组实验) │ │ (指标比较) │ │ │ └──────────┘ └──────────┘ └────────────┘ └──────────────┘ │ │ │ │ │ │ │ N组并行 │ 选最优 │ │ ▼ ▼ │ │ ┌──────────────┐ ┌──────────────┐ │ │ │ 超参搜索空间 │ │ 模型注册 │ │ │ │ (Grid/Random) │ │ (Model Reg.) │ │ │ └──────────────┘ └──────────────┘ │ │ │ │ │ ▼ │ │ ┌──────────────┐ │ │ │ 模型部署 │ │ │ │ (Serving) │ │ │ └──────────────┘ │ │ │ ├─────────────────────────────────────────────────────────────────────┤ │ 横切关注点:GPU 资源管理 │ 实验追踪 │ 数据版本控制 │ 模型血缘 │ └─────────────────────────────────────────────────────────────────────┘1.4 Airflow 组件与 ML 阶段的映射
| ML 阶段 | Airflow 组件 | 选择理由 |
|---|---|---|
| 数据准备 | @task+ Asset | 利用 Asset 事件驱动重训练 |
| 超参搜索 | Dynamic Task Mapping (.expand()) | 运行时动态创建 N 组并行实验 |
| 模型训练 | @task+ GPU Pool | 绑定 GPU 资源池,精细控制并发 |
| 模型评估 | @task+ XCom | 通过 XCom 传递指标并比较 |
| 模型选择 | Reduce 任务 | 汇总所有实验结果,选择最优 |
| 模型部署 | @task+ 质量门控 | 只有超越基线的模型才能部署 |
| 资源管理 | Setup/Teardown | 确保 GPU 资源的可靠释放 |
| 阶段组织 | TaskGroup | 训练/评估/部署分组,保持清晰 |
2. Dynamic Task Mapping:超参搜索的并行引擎
2.1 核心原理
Dynamic Task Mapping 允许在 DAG运行时(而非解析时)动态确定任务实例数量。这是通过MappedOperator实现的——一种特殊的算子代理对象,在 Scheduler 调度时根据上游 XCom 产出的列表长度,动态创建对应数量的 TaskInstance。
从源码层面看,MappedOperator定义在 mappedoperator.py:
classMappedOperator:"""Object representing a mapped operator in a Dag."""operator_class:type[BaseOperator]expand_input:ExpandInput# 存储映射配置partial_kwargs:dict[str,Any]# 非映射参数(固定值)_needs_expansion:bool=True# 标记需要展开def__repr__(self):returnf"<Mapped({self.task_type}):{self.task_id}>"其运行时展开的元数据存储在数据库的task_map表中(taskmap.py):
classTaskMap(TaskInstanceDependencies):"""Model to track dynamic task-mapping information."""__tablename__="task_map"dag_id:Mapped[str]task_id:Mapped[str]run_id:Mapped[str]map_index:Mapped[int]# 每个映射实例的索引length:Mapped[int]# 总映射数量keys:Mapped[list|None]# 可选的命名映射键当上游任务产出一个列表(如超参组合列表),TaskMap记录该列表的长度,Scheduler 随后创建对应数量的 TaskInstance,每个实例接收列表中的一个元素。
2.2 两种映射模式
Airflow 支持两种输入映射方式,定义在 expandinput.py:
模式1:DictOfListsExpandInput —— 参数维度展开
# 对单个参数的值列表进行映射@taskdeftrain_model(learning_rate:float):...# 生成3个 TaskInstance,分别接收 0.01, 0.001, 0.0001train_model.expand(learning_rate=[0.01,0.001,0.0001])模式2:ListOfDictsExpandInput —— 完整参数组展开
# 对完整参数字典列表进行映射@taskdeftrain_model(config:dict):...# 每个字典是一组完整的实验配置train_model.expand_kwargs([{"config":{"lr":0.01,"batch_size":32}},{"config":{"lr":0.001,"batch_size":64}},{"config":{"lr":0.0001,"batch_size":128}},])2.3 实战:超参网格搜索
""" ml_pipeline/dags/hyperparameter_search.py 使用 Dynamic Task Mapping 实现超参网格搜索 """from__future__importannotationsfromdatetimeimportdatetimefromairflow.sdkimportDAG,taskwithDAG(dag_id="ml_hyperparameter_search",schedule=None,start_date=datetime(2024,1,1),catchup=False,tags=["ml","training"],):@taskdefgenerate_search_space():""" 生成超参搜索空间:Grid Search 策略 返回所有超参组合列表(运行时确定映射数量) """importitertools param_grid={"learning_rate":[0.01,0.001,0.0001],"batch_size":[32,64,128],"hidden_dim":[128,256],"dropout":[0.1,0.3,0.5],}# 笛卡尔积生成所有组合keys=param_grid.keys()values=param_grid.values()combinations=[dict(zip(keys,combo))forcomboinitertools.product(*values)]print(f"Generated{len(combinations)}hyperparameter combinations")returncombinations# 返回列表 → Scheduler 创建对应数量的 TaskInstance@taskdeftrain_single_experiment(params:dict):""" 单次训练实验:每个 TaskInstance 处理一组超参 通过 Dynamic Task Mapping,N 组超参自动创建 N 个并行实例 """importrandomimporttime# 模拟训练过程print(f"Training with params:{params}")time.sleep(2)# 模拟训练耗时# 模拟产生训练指标(实际场景中这里是真实的模型训练)metrics={"params":params,"accuracy":random.uniform(0.75,0.95),"f1_score":random.uniform(0.70,0.92),"loss":random.uniform(0.1,0.5),"training_time_seconds":random.uniform(60,300),}print(f"Experiment result: accuracy={metrics['accuracy']:.4f}")returnmetrics@taskdefselect_best_model(all_results:list[dict]):""" 汇总所有实验结果,选择最优模型 这是 Map-Reduce 模式中的 Reduce 步骤 """# 按 accuracy 排序选择最优sorted_results=sorted(all_results,key=lambdax:x["accuracy"],reverse=True)best=sorted_results[0]print(f"Best model: accuracy={best['accuracy']:.4f}, params={best['params']}")print(f"Total experiments:{len(all_results)}")# 返回 Top-3 结果供后续评估return{"best_params":best["params"],"best_accuracy":best["accuracy"],"top_3":sorted_results[:3],"total_experiments":len(all_results),}# 编排:生成搜索空间 → 并行训练 → 汇总选择search_space=generate_search_space()# ⭐ 关键:.expand() 将列表中的每个元素映射为独立的 TaskInstanceexperiment_results=train_single_experiment.expand(params=search_space)# 汇聚所有并行实验的结果select_best_model(experiment_results)在这个 DAG 中:
generate_search_space返回一个包含 54 组超参的列表(3×3×2×3)train_single_experiment.expand(params=search_space)在运行时创建 54 个 TaskInstance- 所有实验并行执行(受限于 Pool 和 Executor 容量)
select_best_model等待所有实验完成后汇总结果
2.4 二阶映射:链式动态任务
从官方示例 example_dynamic_task_mapping.py 可以看到,Dynamic Task Mapping 支持链式调用——一个 Mapped Task 的输出可以作为下一个 Mapped Task 的输入:
@taskdefget_nums():return[1,2,3]@taskdeftimes_2(num):returnnum*2@taskdefadd_10(num):returnnum+10_get_nums=get_nums()_times_2=times_2.expand(num=_get_nums)# 3个实例add_10.expand(num=_times_2)# 同样3个实例这种模式在 ML 场景中非常有用——例如先并行训练多个模型,再并行评估每个模型的多个指标。
3. TaskGroup:组织 ML 工作流的阶段结构
3.1 TaskGroup 的设计哲学
ML Pipeline 通常包含多个清晰的阶段:数据准备 → 特征工程 → 训练 → 评估 → 部署。TaskGroup提供了逻辑分组能力,让复杂的 DAG 在 UI 中保持可读性。
TaskGroup 定义在 taskgroup.py,它是一个容器节点,管理子任务的集合并处理组级别的依赖关系:
classTaskGroup:_group_id:str|None# 组标识符group_display_name:str|None# UI 显示名prefix_group_id:bool# 是否给子任务 ID 加组前缀children:dict[str,DAGNode]# 子节点(Task 或嵌套 TaskGroup)upstream_group_ids:set[str]# 上游组依赖downstream_group_ids:set[str]# 下游组依赖使用@task_group装饰器(task_group.py)可以用函数式风格定义 TaskGroup,并且支持.expand()进行动态映射。
3.2 TaskGroup 与 Dynamic Task Mapping 结合
从示例 example_dynamic_task_mapping.py 可以看到,TaskGroup 也支持.expand():
@task_groupdefop(num):@taskdefadd_1(num):returnnum+1@taskdefmul_2(num):returnnum*2returnmul_2(add_1(num))# 整个 TaskGroup 展开3次——每次包含 add_1 和 mul_2 两个任务op.expand(num=[1,2,3])这意味着我们可以将整个训练流程(数据加载→训练→验证)封装为一个 TaskGroup,然后对不同超参组合并行展开。
3.3 实战:阶段化 ML 管道
""" ml_pipeline/dags/staged_ml_pipeline.py 使用 TaskGroup 组织训练/评估/部署三阶段 """from__future__importannotationsfromdatetimeimportdatetimefromairflow.sdkimportDAG,task,task_groupwithDAG(dag_id="ml_staged_pipeline",schedule=None,start_date=datetime(2024,1,1),catchup=False,tags=["ml","staged"],):# ========================================# 阶段1:数据准备(TaskGroup)# ========================================@task_groupd