AReaL 中的 Off-policyness 与 Partial Rollouts 实现详解 #
概述 #
AReaL 框架通过异步推理和训练分离的架构,实现了高效的分布式 RL 训练。在这个过程中,核心挑战是如何处理:
- Off-policyness(策略过时性):推理引擎使用的模型版本可能落后于训练引擎的版本
- Partial Rollouts(部分轨迹):单个生成序列可能跨越多个模型版本
本文将深入分析 AReaL 如何通过 StalenessManager、版本跟踪机制和 WorkflowExecutor 来管理这些问题。
1. Off-policyness 控制机制 #
1.1 配置参数 #
AReaL 通过 max_head_offpolicyness 参数控制允许的最大版本差:
rollout:
max_head_offpolicyness: 4 # 允许推理版本最多落后训练版本 4 步
关键配置说明:
0:同步 RL(推理和训练完全同步,用于调试)2-8:典型异步范围(根据模型大小和更新频率调整)- 更高值:提高吞吐量,但可能降低训练稳定性
1.2 StalenessManager 核心实现 #
StalenessManager 是控制 off-policyness 的核心组件,位于 areal/core/staleness_manager.py。
1.2.1 容量计算公式 #
def get_capacity(self) -> int:
"""计算可用的新 rollout 槽位数量"""
with self.lock:
current_version = self.version_provider.get_version() # 获取当前训练版本
# 并发限制容量
max_concurrent_rollouts = max(1, self.max_concurrent_rollouts)
concurrency_capacity = max_concurrent_rollouts - self.rollout_stat.running
# 过时性限制容量
ofp = self.max_staleness # max_head_offpolicyness
sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
consumer_bs = max(1, self.consumer_batch_size)
staleness_capacity = (ofp + current_version + 1) * consumer_bs - sample_cnt
# 返回两者的最小值
capacity = min(concurrency_capacity, staleness_capacity)
return capacity
容量计算逻辑:
-
并发容量:
max_concurrent_rollouts - running_rollouts- 限制同时运行的 rollout 数量
-
过时性容量:
(max_staleness + current_version + 1) × batch_size - current_samples- 确保样本被消费时不会超过最大允许的过时性
current_version + 1:预测下一个训练步骤的版本- 乘以
batch_size:每个版本可以有一个完整批次的样本
-
最终容量:取两者最小值,确保同时满足两个约束
1.2.2 Rollout 生命周期跟踪 #
class StalenessManager:
def on_rollout_enqueued(self):
"""Rollout 任务入队"""
with self.lock:
self.rollout_stat.enqueued += 1
def on_rollout_submitted(self):
"""Rollout 任务提交执行"""
with self.lock:
self.rollout_stat.enqueued -= 1
self.rollout_stat.running += 1
def on_rollout_accepted(self):
"""Rollout 完成并被接受"""
with self.lock:
self.rollout_stat.running -= 1
self.rollout_stat.accepted += 1
def on_rollout_rejected(self):
"""Rollout 完成但被拒绝"""
with self.lock:
self.rollout_stat.running -= 1
self.rollout_stat.rejected += 1
状态转换图:
stateDiagram-v2
[*] --> Enqueued: 任务提交
Enqueued --> Running: 开始执行
Running --> Accepted: 完成并通过筛选
Running --> Rejected: 完成但被拒绝/失败
Accepted --> [*]
Rejected --> [*]
2. Partial Rollouts(部分轨迹)实现 #
2.1 版本跟踪机制 #
AReaL 通过在生成过程中记录每个 token 的模型版本来支持 partial rollouts。
2.1.1 ModelResponse 结构 #
@dataclass
class ModelResponse:
input_tokens: list[int] = field(default_factory=list)
output_tokens: list[int] = field(default_factory=list)
output_logprobs: list[float] = field(default_factory=list)
output_versions: list[int] = field(default_factory=list) # 每个 token 的版本号
stop_reason: Literal["length", "stop", "tool_calls", "abort"] = "stop"
关键字段:
output_versions:与output_tokens长度相同的列表,记录每个 token 生成时的模型版本
2.1.2 增量生成与版本记录 #
在 RemoteInfEngine.agenerate 中实现增量生成:
async def agenerate(self, req: ModelRequest) -> ModelResponse:
"""
处理可能被权重更新中断的生成请求。
如果生成过程中发生权重更新,会继续从中断处生成,
并记录每个阶段的模型版本。
"""
output_tokens = []
output_logprobs = []
output_versions = []
stop_reason = None
max_new_tokens = req.gconfig.max_new_tokens
while stop_reason != "stop" and len(output_tokens) < max_new_tokens:
# 如果之前被中断,等待一段时间避免资源竞争
if stop_reason is not None:
await asyncio.sleep(0.5)
# 发送 HTTP 请求到推理服务器
result = await arequest_with_retry(
addr=server_addr,
endpoint="/generate",
payload=payload,
method="POST"
)
# 累积输出
new_tokens = result["output_ids"]
output_tokens.extend(new_tokens)
output_logprobs.extend(result["logprobs"])
# 记录当前版本号(所有新生成的 token 使用相同版本)
current_version = self.get_version()
output_versions.extend([current_version] * len(new_tokens))
# 更新请求以继续生成
payload["input_ids"] += new_tokens
payload["sample_params"]["max_new_tokens"] -= len(new_tokens)
stop_reason = result.get("stop_reason")
return ModelResponse(
input_tokens=req.input_ids,
output_tokens=output_tokens,
output_logprobs=output_logprobs,
output_versions=output_versions, # 版本信息
stop_reason=stop_reason
)
2.2 Workflow 中的版本处理 #
以 MultiTurnWorkflow 为例,展示如何在多轮对话中累积版本信息:
class MultiTurnWorkflow(RolloutWorkflow):
async def arun_episode(self, engine: InferenceEngine, data: dict[str, Any]):
seq, logprobs, loss_mask, versions = [], [], [], []
for t in range(self.max_turns):
# 发送生成请求
req = ModelRequest(...)
resp = await engine.agenerate(req)
# 累积结果,注意版本信息的处理
input_len = len(resp.input_tokens) - len(seq)
# 拼接序列
seq += resp.input_tokens[-input_len:] + resp.output_tokens
logprobs += [0.0] * input_len + resp.output_logprobs
loss_mask += [0] * input_len + [1] * resp.output_len
# 拼接版本信息:prompt 部分标记为 -1,completion 部分使用实际版本
versions += [-1] * input_len + resp.output_versions
# 检查是否需要继续
if reward > 0:
break
# 返回包含版本信息的轨迹
return {
"input_ids": torch.tensor(seq, dtype=torch.int32).unsqueeze(0),
"logprobs": torch.tensor(logprobs, dtype=torch.float32).unsqueeze(0),
"loss_mask": torch.tensor(loss_mask, dtype=torch.int32).unsqueeze(0),
"versions": torch.tensor(versions, dtype=torch.int32).unsqueeze(0), # 版本张量
"rewards": torch.tensor(reward, dtype=torch.float32).unsqueeze(0),
"attention_mask": torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
}
版本标记约定:
-1:prompt 部分(不需要计算损失)>= 0:completion 部分的实际模型版本号
3. 异步任务调度与分发 #
3.1 BatchTaskDispatcher 架构 #
BatchTaskDispatcher 使用生产者-消费者模式管理异步任务:
class BatchTaskDispatcher[TInput: WithTaskID, TResult]:
"""
通用的异步任务分发器,支持过时性控制。
架构:
- 生产者线程:根据过时性容量从 _pending_inputs 提交任务到 AsyncTaskRunner
- 消费者线程:从 AsyncTaskRunner 收集结果到 _pending_results
- 主线程:submit_task_input() 入队,wait_results() 等待结果
"""
def __init__(
self,
max_queue_size: int,
task_factory: Callable[[TInput], Callable[[], Awaitable[TResult | None]]],
staleness_manager: StalenessManager,
enable_tracing: bool = False,
):
self.runner = AsyncTaskRunner(max_queue_size=max_queue_size)
self.staleness_manager = staleness_manager
# 无界队列用于生产者-消费者模式
self._pending_inputs: deque[TInput] = deque()
self._pending_results: dict[int, TimedResult[TResult]] = {}
# 线程同步
self._input_lock = threading.Lock()
self._input_cv = threading.Condition(self._input_lock)
self._result_lock = threading.Lock()
self._result_cv = threading.Condition(self._result_lock)
3.2 生产者线程:容量感知的任务提交 #
def _commit_loop(self) -> None:
"""生产者线程 - 根据容量持续提交任务"""
while not self._shutdown_event.is_set():
try:
# 获取下一个待提交任务
task_input = self._get_next_task_for_submission()
if task_input is None:
continue
# 提交到 AsyncTaskRunner
task_fn = self.task_factory(task_input)
try:
self.runner.submit(task_fn, task_id=task_input.task_id)
self.staleness_manager.on_rollout_submitted()
except TaskQueueFullError:
# 队列满,重新入队并等待容量
with self._input_cv:
self._pending_inputs.appendleft(task_input)
self._input_cv.wait_for(
lambda: self._shutdown_event.is_set() or self._has_runner_capacity()
)
continue
except Exception as e:
self.logger.error("Producer thread failed", exc_info=True)
self._set_thread_exception(e)
break
def _get_next_task_for_submission(self) -> TInput | None:
"""获取下一个可提交的任务,考虑容量限制"""
with self._input_cv:
while not self._shutdown_event.is_set():
# 检查是否有容量且有待处理任务
if (
not self.runner.paused.is_set()
and self.staleness_manager.get_capacity() > 0 # 过时性容量检查
and self._pending_inputs
):
return self._pending_inputs.popleft()
self._input_cv.wait()
return None
3.3 消费者线程:结果收集 #
def _fetch_loop(self) -> None:
"""消费者线程 - 持续从 runner 收集结果"""
while not self._shutdown_event.is_set():
try:
# 轮询可用结果(非阻塞)
output_queue_size = self.runner.get_output_queue_size()
count = max(1, min(output_queue_size, _MAX_FETCH_BATCH_SIZE))
try:
# 使用短超时以保持响应性
results = self.runner.wait(count=count, timeout=0.05, with_timing=True)
except TimeoutError:
continue
# 将结果放入待处理队列
with self._result_cv:
for result in results:
self._pending_results[result.task_id] = result
self._result_cv.notify_all()
# 新容量可用,唤醒生产者
with self._input_cv:
self._input_cv.notify()
except Exception as e:
self.logger.error("Consumer thread failed", exc_info=True)
self._set_thread_exception(e)
break
4. 完整流程图 #
4.1 Off-policyness 控制流程 #
sequenceDiagram
participant Main as 主线程
participant Dispatcher as BatchTaskDispatcher
participant Producer as 生产者线程
participant SM as StalenessManager
participant Runner as AsyncTaskRunner
participant Consumer as 消费者线程
participant Workflow as RolloutWorkflow
Main->>Dispatcher: submit_task_input(task)
Dispatcher->>SM: on_rollout_enqueued()
Note over SM: enqueued++
Dispatcher->>Dispatcher: _pending_inputs.append(task)
Producer->>SM: get_capacity()
SM-->>Producer: capacity = min(concurrency_cap, staleness_cap)
alt capacity > 0
Producer->>Dispatcher: _pending_inputs.popleft()
Producer->>Runner: submit(task_fn)
Producer->>SM: on_rollout_submitted()
Note over SM: enqueued--, running++
Runner->>Workflow: arun_episode(engine, data)
alt 生成中权重更新
Workflow->>Workflow: 记录版本 v1
Note over Workflow: 权重更新中断
Workflow->>Workflow: 等待并继续
Workflow->>Workflow: 记录版本 v2
end
Workflow-->>Runner: trajectory (with versions)
Runner-->>Consumer: result ready
Consumer->>Dispatcher: _pending_results[task_id] = result
Consumer->>SM: on_rollout_accepted()
Note over SM: running--, accepted++
else capacity <= 0
Producer->>Producer: wait for capacity
end
Main->>Dispatcher: wait_results(count)
Dispatcher->>Main: [trajectory1, trajectory2, ...]
4.2 Partial Rollouts 处理流程 #
graph TB
Start([开始生成请求]) --> InitVersion[记录初始版本 v0]
InitVersion --> GenLoop{继续生成?}
GenLoop -->|是| SendReq[发送生成请求到推理服务器]
SendReq --> CheckStop{检查 stop_reason}
CheckStop -->|中断: weight_update| Wait[等待 0.5s
避免资源竞争]
Wait --> UpdateVersion[更新版本到 v1]
UpdateVersion --> RecordV1[为新 tokens 记录版本 v1]
RecordV1 --> UpdatePayload[更新 payload
继续生成]
UpdatePayload --> GenLoop
CheckStop -->|正常| RecordV0[为新 tokens 记录版本 v0]
RecordV0 --> CheckComplete{达到长度限制?}
CheckComplete -->|否| UpdatePayload
CheckComplete -->|是| BuildResp[构建 ModelResponse]
GenLoop -->|否: stop 或达到长度| BuildResp
BuildResp --> ReturnResp[返回包含:
- output_tokens
- output_logprobs
- output_versions]
ReturnResp --> End([结束])
style Wait fill:#ffe6e6
style UpdateVersion fill:#fff4e6
style RecordV1 fill:#fff4e6
style RecordV0 fill:#e6f7ff
4.3 版本信息在轨迹中的流动 #
graph LR
A[推理服务器
版本 v3] -->|agenerate| B[ModelResponse
output_versions: 3,3,3,...]
B -->|workflow.arun_episode| C[Trajectory
versions: -1,-1,3,3,3,...]
D[推理服务器
版本 v4] -->|agenerate
中途更新| E[ModelResponse
output_versions: 4,4,5,5,...]
E -->|workflow.arun_episode| F[Trajectory
versions: -1,-1,4,4,5,5,...]
C --> G[Training Batch
批次聚合]
F --> G
G -->|compute_advantages| H[PPOActor
考虑版本差异]
H -->|ppo_update| I[权重更新
新版本 v5]
style A fill:#e1f5ff
style D fill:#e1f5ff
style I fill:#ffe6e6
5. Decoupled PPO 目标函数 #
当启用 off-policy 训练时,AReaL 使用 Decoupled PPO 损失函数来处理版本差异。
5.1 配置 #
actor:
use_decoupled_loss: true # 启用 Decoupled PPO
recompute_logprobs: true # 训练时重新计算 logprobs
prox_logp_method: "recompute" # 近端策略的计算方法
5.2 三种 Log-Probability #
在 Decoupled PPO 中,需要区分三种对数概率:
- π_behave(行为策略):rollout 时记录的 logprobs(从推理服务器)
- π_prox(近端策略):训练开始时的策略(作为参考点)
- π_θ(当前策略):训练过程中更新的策略
class PPOActor:
def _log_configuration(self):
if not self.config.use_decoupled_loss:
logger.info("Mode: Standard PPO (on-policy)")
if self.config.recompute_logprob:
logger.info(" old_logp (π_old): RECOMPUTED from current policy")
else:
logger.info(" old_logp (π_old): FROM INFERENCE (cached during rollout)")
else:
logger.info("Mode: Decoupled PPO (off-policy)")
logger.info(" log_p_behave (π_behave): FROM INFERENCE (behavior policy)")
method_descriptions = {
"recompute": "RECOMPUTED via forward pass (standard decoupled PPO)",
"loglinear": "LOG-LINEAR APPROXIMATION (no forward pass)",
"metrics": "RECOMPUTED + APPROXIMATION METRICS (for evaluation)",
}
logger.info(f" Proximal policy (π_prox): {method_descriptions[config.prox_logp_method]}")
logger.info(" log_p_theta (π_θ): TRAINING FORWARD PASS (current policy)")
5.3 Decoupled 损失计算 #
def ppo_actor_loss_fn(...):
if use_decoupled_loss:
# 计算行为策略的重要性权重
log_ratio_behav = logp_theta - log_p_behave
ratio_behav = torch.exp(log_ratio_behav)
# 计算近端策略的比率(用于 PPO clip)
log_ratio_prox = logp_theta - prox_logp
ratio_prox = torch.exp(log_ratio_prox)
# Decoupled PPO 目标
surr1 = ratio_behav * advantages
surr2 = torch.clamp(ratio_prox, 1.0 - eps_clip, 1.0 + eps_clip) * advantages
policy_loss = -torch.min(surr1, surr2)
else:
# 标准 PPO
log_ratio = logp_theta - old_logp
ratio = torch.exp(log_ratio)
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip) * advantages
policy_loss = -torch.min(surr1, surr2)
6. 系统整体流程图 #
有语法错误
graph TB
subgraph "1. 初始化阶段"
Init[初始化推理引擎和训练引擎]
Init --> LaunchServer[启动推理服务器
SGLang/vLLM]
LaunchServer --> CreateSM[创建 StalenessManager
max_staleness=4]
CreateSM --> CreateDispatcher[创建 BatchTaskDispatcher]
CreateDispatcher --> StartThreads[启动生产者/消费者线程]
end
subgraph "2. 训练循环"
TrainStep[训练步骤 step]
TrainStep --> PrepareBatch[actor.prepare_batch
收集 rollout 数据]
PrepareBatch --> CheckCapacity{检查容量}
CheckCapacity -->|capacity > 0| SubmitTask[提交 rollout 任务]
CheckCapacity -->|capacity <= 0| WaitCapacity[等待容量释放]
WaitCapacity --> CheckCapacity
SubmitTask --> RunWorkflow[执行 workflow.arun_episode]
RunWorkflow --> GenTokens[生成 tokens
记录版本信息]
GenTokens --> CheckUpdate{权重更新中断?}
CheckUpdate -->|是| WaitUpdate[等待更新完成
继续生成]
WaitUpdate --> GenTokens
CheckUpdate -->|否| CompleteGen[完成生成]
CompleteGen --> BuildTraj[构建轨迹
包含版本张量]
BuildTraj --> AcceptReject{通过筛选?}
AcceptReject -->|是| CollectBatch[收集到批次]
AcceptReject -->|否| DiscardTraj[丢弃轨迹]
CollectBatch --> BatchReady{批次就绪?}
BatchReady -->|否| CheckCapacity
BatchReady -->|是| ComputeAdv[计算 Advantages]
ComputeAdv --> PPOUpdate[PPO 更新
Decoupled Loss]
PPOUpdate --> UpdateLR[更新学习率]
UpdateLR --> SyncWeights[同步权重到推理服务器]
SyncWeights --> IncVersion[增加版本号
version++]
IncVersion --> NextStep[下一步]
NextStep --> TrainStep
end
subgraph "3. 并发控制"
SM[StalenessManager]
SM --> CalcCap[计算容量
min(concurrency, staleness)]
CalcCap --> TrackState[跟踪 rollout 状态
enqueued/running/accepted]
end
StartThreads --> TrainStep
PrepareBatch -.调用.-> SM
IncVersion -.更新.-> SM
style CheckUpdate fill:#ffe6e6
style WaitUpdate fill:#ffe6e6
style GenTokens fill:#e6f7ff
style SyncWeights fill:#fff4e6
style IncVersion fill:#fff4e6
cursor 修复的
graph TB
subgraph "1. 初始化阶段"
Init["初始化推理引擎和训练引擎"]
Init --> LaunchServer["启动推理服务器
SGLang/vLLM"]
LaunchServer --> CreateSM["创建 StalenessManager
max_staleness=4"]
CreateSM --> CreateDispatcher["创建 BatchTaskDispatcher"]
CreateDispatcher --> StartThreads["启动生产者/消费者线程"]
end
subgraph "2. 训练循环"
TrainStep["训练步骤 step"]
TrainStep --> PrepareBatch["actor.prepare_batch
收集 rollout 数据"]
PrepareBatch --> CheckCapacity{"检查容量"}
CheckCapacity -->|"capacity > 0"| SubmitTask["提交 rollout 任务"]
CheckCapacity -->|"capacity <= 0"| WaitCapacity["等待容量释放"]
WaitCapacity --> CheckCapacity
SubmitTask --> RunWorkflow["执行 workflow.arun_episode"]
RunWorkflow --> GenTokens["生成 tokens
记录版本信息"]
GenTokens --> CheckUpdate{"权重更新中断?"}
CheckUpdate -->|"是"| WaitUpdate["等待更新完成
继续生成"]
WaitUpdate --> GenTokens
CheckUpdate -->|"否"| CompleteGen["完成生成"]
CompleteGen --> BuildTraj["构建轨迹
包含版本张量"]
BuildTraj --> AcceptReject{"通过筛选?"}
AcceptReject -->|"是"| CollectBatch["收集到批次"]
AcceptReject -->|"否"| DiscardTraj["丢弃轨迹"]
CollectBatch --> BatchReady{"批次就绪?"}
BatchReady -->|"否"| CheckCapacity
BatchReady -->|"是"| ComputeAdv["计算 Advantages"]
ComputeAdv --> PPOUpdate["PPO 更新
Decoupled Loss"]
PPOUpdate --> UpdateLR["更新学习率"]
UpdateLR --> SyncWeights["同步权重到推理服务器"]
SyncWeights --> IncVersion["增加版本号
version++"]
IncVersion --> NextStep["下一步"]
NextStep --> TrainStep
end
subgraph "3. 并发控制"
SM["StalenessManager"]
SM --> CalcCap["计算容量
min(concurrency, staleness)"]
CalcCap --> TrackState["跟踪 rollout 状态
enqueued/running/accepted"]
end
StartThreads --> TrainStep
PrepareBatch -.调用.-> SM
IncVersion -.更新.-> SM
style CheckUpdate fill:#ffe6e6
style WaitUpdate fill:#ffe6e6
style GenTokens fill:#e6f7ff
style SyncWeights fill:#fff4e6
style IncVersion fill:#fff4e6
7. 关键设计要点总结 #
7.1 Off-policyness 控制 #
-
双重容量限制:
- 并发限制:控制同时运行的 rollout 数量
- 过时性限制:确保样本不会过时到超过可接受范围
-
版本感知调度:
- 生产者线程在提交任务前检查
staleness_manager.get_capacity() - 容量计算考虑当前版本和最大允许的版本差
- 生产者线程在提交任务前检查
-
状态跟踪:
- 使用线程安全的计数器跟踪 rollout 生命周期
- 实时更新
enqueued、running、accepted、rejected状态
7.2 Partial Rollouts 支持 #
-
增量生成:
agenerate方法支持中断和恢复- 每次生成一段 tokens 后检查是否需要继续
-
版本记录:
output_versions列表与output_tokens一一对应- 记录每个 token 生成时的模型版本
-
版本传播:
- 从
ModelResponse→Workflow→Trajectory→Training Batch - 训练时可以利用版本信息调整损失计算
- 从
7.3 异步架构优势 #
-
高吞吐量:
- 推理和训练并行执行
- 典型情况下比同步 RL 快 2 倍
-
GPU 利用率高:
- 推理 GPU 持续生成数据
- 训练 GPU 持续更新模型
-
灵活的容错:
- 支持中断和恢复
- 拒绝的 rollout 不会阻塞训练
8. 配置建议 #
8.1 调试场景 #
rollout:
max_head_offpolicyness: 0 # 完全同步,便于调试
max_concurrent_rollouts: 1
actor:
use_decoupled_loss: false
recompute_logprob: true
8.2 高吞吐量场景 #
rollout:
max_head_offpolicyness: 6 # 允许更多异步
max_concurrent_rollouts: 128
consumer_batch_size: 32
actor:
use_decoupled_loss: true
recompute_logprob: true
prox_logp_method: "loglinear" # 使用近似避免额外前向传播
8.3 稳定性优先场景 #
rollout:
max_head_offpolicyness: 2 # 较小的异步度
max_concurrent_rollouts: 32
actor:
use_decoupled_loss: true
recompute_logprob: true
prox_logp_method: "recompute" # 精确计算近端策略
9. 参考资料 #
-
代码位置:
areal/core/staleness_manager.py: Off-policyness 控制areal/core/workflow_executor.py: 任务调度和执行areal/core/remote_inf_engine.py: 推理引擎和增量生成areal/engine/ppo/actor.py: Decoupled PPO 实现
-
文档:
附录:代码示例 #
A.1 完整的训练循环示例 #
# 初始化
rollout = RemoteSGLangEngine(config.rollout)
rollout.initialize(train_data_parallel_size=parallel_strategy.dp_size)
actor = FSDPPPOActor(config=config.actor)
actor.create_process_group(parallel_strategy=parallel_strategy)
actor.initialize(None, ft_spec)
actor.connect_engine(rollout, weight_update_meta)
# 训练循环
for global_step in range(max_steps):
# 1. 收集 rollout 数据(异步)
batch = actor.prepare_batch(
train_dataloader,
group_size=config.gconfig.n_samples,
workflow=workflow,
should_accept_fn=lambda sample: True,
)
# batch 包含 'versions' 字段,形状 [batch_size, seq_len]
# 2. 重新计算 logprobs(如果需要)
if config.actor.recompute_logprob:
logp = actor.compute_logp(batch)
batch["prox_logp"] = logp
# 3. 计算参考 logprobs(用于 KL 惩罚)
if ref is not None:
batch["ref_logp"] = ref.compute_logp(batch)
# 4. 计算 Advantages
actor.compute_advantages(batch)
# 5. PPO 更新
actor.ppo_update(batch)
actor.step_lr_scheduler()
# 6. 同步权重到推理服务器
rollout.pause() # 暂停 rollout 避免竞争
actor.update_weights(weight_update_meta)
actor.set_version(global_step + 1)
rollout.set_version(global_step + 1)
rollout.resume() # 恢复 rollout
# 7. 日志和评估
if global_step % eval_interval == 0:
evaluator.evaluate(actor, valid_dataloader, global_step)
saver.save(actor, global_step)
A.2 自定义 Workflow 示例 #
class CustomWorkflow(RolloutWorkflow):
async def arun_episode(self, engine: InferenceEngine, data: dict[str, Any]):
# 准备输入
input_ids = self.tokenizer.apply_chat_template(
data["messages"], tokenize=True, add_generation_prompt=True
)
# 生成响应(支持 partial rollouts)
req = ModelRequest(
rid=uuid.uuid4().hex,
input_ids=input_ids,
gconfig=self.gconfig,
tokenizer=self.tokenizer,
)
resp = await engine.agenerate(req)
# 计算奖励
prompt_str = self.tokenizer.decode(resp.input_tokens)
completion_str = self.tokenizer.decode(resp.output_tokens)
reward = await self.async_reward_fn(prompt_str, completion_str, **data)
# 构建轨迹(包含版本信息)
seq = resp.input_tokens + resp.output_tokens
logprobs = [0.0] * resp.input_len + resp.output_logprobs
loss_mask = [0] * resp.input_len + [1] * resp.output_len
versions = [-1] * resp.input_len + resp.output_versions # 关键:传播版本信息
return {
"input_ids": torch.tensor(seq, dtype=torch.int32).unsqueeze(0),
"logprobs": torch.tensor(logprobs, dtype=torch.float32).unsqueeze(0),
"loss_mask": torch.tensor(loss_mask, dtype=torch.int32).unsqueeze(0),
"versions": torch.tensor(versions, dtype=torch.int32).unsqueeze(0),
"rewards": torch.tensor(reward, dtype=torch.float32).unsqueeze(0),
"attention_mask": torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
}
参考 #
AREAL v0.5.3