Skip to content

Commit 1777b9d

Browse files
committed
[WIP] Compute lp during loss execution
ghstack-source-id: b5b186b
Pull Request resolved: #2688
1 parent a0a1f00 commit 1777b9d

File tree

2 files changed

+52
-36
lines changed

2 files changed

+52
-36
lines changed

test/test_cost.py

+7-9
Original file line number | Diff line number | Diff line change
@@ -8158,18 +8158,19 @@ def _create_seq_mock_data_ppo(
81588158
obs = total_obs[:, :T]
81598159
next_obs = total_obs[:, 1:]
81608160
if atoms:
8161-
action = torch.randn(batch, T, atoms, action_dim, device=device).clamp(
8162-
-1, 1
8163-
)
8161+
action_shape = (batch, T, atoms, action_dim)
81648162
else:
8165-
action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
8163+
action_shape = (batch, T, action_dim)
8164+
params_mean = torch.randn(action_shape, device=device) / 10
8165+
params_scale = torch.rand(action_shape, device=device) / 10
8166+
action = (params_mean + params_scale * torch.randn(action_shape, device=device)).clamp(
8167+
-1, 1
8168+
)
81668169
reward = torch.randn(batch, T, 1, device=device)
81678170
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
81688171
terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
81698172
mask = torch.ones(batch, T, dtype=torch.bool, device=device)
81708173
action = action.masked_fill_(~mask.unsqueeze(-1), 0.0)
8171-
params_mean = torch.randn_like(action) / 10
8172-
params_scale = torch.rand_like(action) / 10
81738174
loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0)
81748175
scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0)
81758176
td = TensorDict(
@@ -8184,9 +8185,6 @@ def _create_seq_mock_data_ppo(
81848185
},
81858186
"collector": {"mask": mask},
81868187
action_key: {"action1": action} if composite_action_dist else action,
8187-
sample_log_prob_key: (
8188-
torch.randn_like(action[..., 1]) / 10
8189-
).masked_fill_(~mask, 0.0),
81908188
},
81918189
device=device,
81928190
names=[None, "time"],

torchrl/objectives/ppo.py

+45-27
Original file line number | Diff line number | Diff line change
@@ -510,17 +510,28 @@ def _log_weight(
510510
# current log_prob of actions
511511
action = _maybe_get_or_select(tensordict, self.tensor_keys.action)
512512

513+
is_composite = None
514+
if all(key in tensordict for key in self.actor_network.dist_params_keys):
515+
prev_dist = self.actor_network.build_dist_from_params(tensordict.detach())
516+
kwargs, is_composite = _get_composite_kwargs(prev_dist)
517+
if is_composite:
518+
prev_log_prob = prev_dist.log_prob(tensordict, **kwargs)
519+
else:
520+
prev_log_prob = prev_dist.log_prob(action, **kwargs)
521+
print('prev_log_prob', prev_log_prob)
522+
else:
523+
try:
524+
prev_log_prob = _maybe_get_or_select(
525+
tensordict, self.tensor_keys.sample_log_prob
526+
)
527+
except KeyError as err:
528+
raise _make_lp_get_error(self.tensor_keys, tensordict, err)
529+
513530
with self.actor_network_params.to_module(
514531
self.actor_network
515532
) if self.functional else contextlib.nullcontext():
516-
dist = self.actor_network.get_dist(tensordict)
533+
current_dist = self.actor_network.get_dist(tensordict)
517534

518-
try:
519-
prev_log_prob = _maybe_get_or_select(
520-
tensordict, self.tensor_keys.sample_log_prob
521-
)
522-
except KeyError as err:
523-
raise _make_lp_get_error(self.tensor_keys, tensordict, err)
524535

525536
if prev_log_prob.requires_grad:
526537
raise RuntimeError(
@@ -532,25 +543,11 @@ def _log_weight(
532543
f"tensordict stored {self.tensor_keys.action} requires grad."
533544
)
534545
if isinstance(action, torch.Tensor):
535-
log_prob = dist.log_prob(action)
546+
log_prob = current_dist.log_prob(action)
536547
else:
537-
if isinstance(dist, CompositeDistribution):
538-
is_composite = True
539-
aggregate = dist.aggregate_probabilities
540-
if aggregate is None:
541-
aggregate = False
542-
include_sum = dist.include_sum
543-
if include_sum is None:
544-
include_sum = False
545-
kwargs = {
546-
"inplace": False,
547-
"aggregate_probabilities": aggregate,
548-
"include_sum": include_sum,
549-
}
550-
else:
551-
is_composite = False
552-
kwargs = {}
553-
log_prob: TensorDictBase = dist.log_prob(tensordict, **kwargs)
548+
if is_composite is None:
549+
kwargs, is_composite = _get_composite_kwargs(current_dist)
550+
log_prob: TensorDictBase = current_dist.log_prob(tensordict, **kwargs)
554551
if (
555552
is_composite
556553
and not is_tensor_collection(prev_log_prob)
@@ -564,7 +561,7 @@ def _log_weight(
564561
if is_tensor_collection(kl_approx):
565562
kl_approx = _sum_td_features(kl_approx)
566563

567-
return log_weight, dist, kl_approx
564+
return log_weight, current_dist, kl_approx
568565

569566
def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
570567
"""Returns the critic loss multiplied by ``critic_coef``, if it is not ``None``."""
@@ -640,6 +637,9 @@ def _cached_critic_network_params_detached(self):
640637
@dispatch
641638
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
642639
tensordict = tensordict.clone(False)
640+
641+
log_weight, dist, kl_approx = self._log_weight(tensordict)
642+
643643
advantage = tensordict.get(self.tensor_keys.advantage, None)
644644
if advantage is None:
645645
self.value_estimator(
@@ -653,7 +653,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
653653
scale = advantage.std().clamp_min(1e-6)
654654
advantage = (advantage - loc) / scale
655655

656-
log_weight, dist, kl_approx = self._log_weight(tensordict)
657656
if is_tensor_collection(log_weight):
658657
log_weight = _sum_td_features(log_weight)
659658
log_weight = log_weight.view(advantage.shape)
@@ -1295,3 +1294,22 @@ def _make_lp_get_error(tensor_keys, log_prob, err):
12951294
return KeyError(result)
12961295
result += "This is usually due to a missing call to loss.set_keys(sample_log_prob=<list_of_log_prob_keys>)."
12971296
return KeyError(result)
1297+
1298+
def _get_composite_kwargs(current_dist):
1299+
if isinstance(current_dist, CompositeDistribution):
1300+
is_composite = True
1301+
aggregate = current_dist.aggregate_probabilities
1302+
if aggregate is None:
1303+
aggregate = False
1304+
include_sum = current_dist.include_sum
1305+
if include_sum is None:
1306+
include_sum = False
1307+
kwargs = {
1308+
"inplace": False,
1309+
"aggregate_probabilities": aggregate,
1310+
"include_sum": include_sum,
1311+
}
1312+
else:
1313+
is_composite = False
1314+
kwargs = {}
1315+
return kwargs, is_composite

0 commit comments

Comments
 (0)