@@ -53,13 +53,24 @@ def compute_loss(data1, _data2):
5353
def compute_reward(response_ids: torch.Tensor) -> TensorDict:
    """Stand in for a reward model by emitting random scores.

    The call blocks for one second and then draws float32 values
    from a standard normal distribution, one for every element of
    ``response_ids``.

    Args:
        response_ids: Tensor whose shape determines the shape of the scores.
            Its leading dimension is used as the TensorDict batch size.

    Returns:
        A TensorDict holding the scores under the ``"rm_score"`` key.
    """
    # Fixed one-second pause before the scores are generated.
    time.sleep(1)

    num_rows = response_ids.size(0)
    scores = torch.randn_like(response_ids, dtype=torch.float32)
    payload = {"rm_score": scores}
    return TensorDict(payload, batch_size=num_rows)
63+
64+
def compute_advantage(rewards: torch.Tensor) -> TensorDict:
    """Stand in for advantage estimation by emitting random values.

    The call blocks for one second and then draws float32 values
    from a standard normal distribution, one for every element of
    ``rewards``. The reward values themselves are not read; only the
    tensor's shape (and device) carry over to the result.

    Args:
        rewards: Tensor whose shape determines the shape of the advantages.
            Its leading dimension is used as the TensorDict batch size.

    Returns:
        A TensorDict holding the values under the ``"advantage"`` key.
    """
    # Fixed one-second pause before the advantages are generated.
    time.sleep(1)

    num_rows = rewards.size(0)
    simulated = torch.randn_like(rewards, dtype=torch.float32)
    payload = {"advantage": simulated}
    return TensorDict(payload, batch_size=num_rows)
6374
6475
6576class TrainingWorker :
@@ -89,7 +100,7 @@ def infer_batch(self, kv_meta: KVBatchMeta) -> KVBatchMeta:
89100 """Simulate forward-only inference"""
90101 # 1. Pull data from storage
91102 data = tq .kv_batch_get_by_meta (meta = kv_meta )
92- logger .info (f"compute_log_prob : got data { data } " )
103+ logger .info (f"infer_batch : got data { data } " )
93104
94105 # 2. Model forward
95106 output = compute_log_prob (data ["prompt_ids" ], data ["response_ids" ])
@@ -494,6 +505,13 @@ def fit(self):
494505 meta = tq .kv_batch_put (keys = meta .keys , partition_id = meta .partition_id , fields = reward_output )
495506 logger .info (f"demo reward KVBatchMeta: { meta } " )
496507
508+ # ========================= Compute advantage =========================
509+ meta .fields = ["response_ids" , "ref_log_prob" , "old_log_prob" , "rm_score" ]
510+ advantage_data = tq .kv_batch_get_by_meta (meta = meta )
511+ advantage_output = compute_advantage (advantage_data ["rm_score" ])
512+ meta = tq .kv_batch_put (keys = meta .keys , partition_id = meta .partition_id , fields = advantage_output )
513+ logger .info (f"demo advantage KVBatchMeta: { meta } " )
514+
497515 # ========================= Update actor =========================
498516 meta .fields = [
499517 "input_ids" ,
0 commit comments