
Commit 7c07539

[WIP] Compute lp during loss execution

ghstack-source-id: ba00253
Pull Request resolved: #2688

1 parent c2a622b

2 files changed: +52, -22 lines

test/test_cost.py (+7, -9)
@@ -8176,18 +8176,19 @@ def _create_seq_mock_data_ppo(
         obs = total_obs[:, :T]
         next_obs = total_obs[:, 1:]
         if atoms:
-            action = torch.randn(batch, T, atoms, action_dim, device=device).clamp(
-                -1, 1
-            )
+            action_shape = (batch, T, atoms, action_dim)
         else:
-            action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
+            action_shape = (batch, T, action_dim)
+        params_mean = torch.randn(action_shape, device=device) / 10
+        params_scale = torch.rand(action_shape, device=device) / 10
+        action = (params_mean + params_scale * torch.randn(action_shape, device=device)).clamp(
+            -1, 1
+        )
         reward = torch.randn(batch, T, 1, device=device)
         done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
         terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
         mask = torch.ones(batch, T, dtype=torch.bool, device=device)
         action = action.masked_fill_(~mask.unsqueeze(-1), 0.0)
-        params_mean = torch.randn_like(action) / 10
-        params_scale = torch.rand_like(action) / 10
         loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0)
         scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0)
         if sample_log_prob_key is None:
@@ -8214,9 +8215,6 @@ def _create_seq_mock_data_ppo(
                 },
                 "collector": {"mask": mask},
                 action_key: action,
-                sample_log_prob_key: (
-                    torch.randn_like(action[..., 1]) / 10
-                ).masked_fill_(~mask, 0.0),
             },
             device=device,
             names=[None, "time"],
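
For reference, a minimal standalone sketch (not part of the diff; shapes and names are illustrative) of why the mock data now derives the action from params_mean/params_scale and drops the random sample_log_prob entry: an action sampled via the reparameterization trick lets the loss recompute a reference log-probability from the stored loc/scale at execution time.

import torch
from torch.distributions import Normal

batch, T, action_dim = 4, 10, 3
loc = torch.randn(batch, T, action_dim) / 10
scale = torch.rand(batch, T, action_dim) / 10 + 1e-3  # keep the scale strictly positive

# reparameterized sample; the test additionally clamps to [-1, 1]
action = loc + scale * torch.randn_like(loc)

# the reference log-prob can be recomputed from (loc, scale) during the loss
# instead of being stored alongside the data
log_prob = Normal(loc, scale).log_prob(action).sum(-1)
print(log_prob.shape)  # torch.Size([4, 10])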

torchrl/objectives/ppo.py (+45, -13)
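
The hunks below keep the composite-distribution branch of _log_weight: when the policy is a CompositeDistribution, log_prob comes back as a tensordict of per-key log-probabilities that must be reduced to a single tensor before the importance weight is formed, which is broadly what the _sum_td_features helper takes care of in the diff. A minimal sketch of that reduction using plain tensordict primitives (the keys are made up for illustration):

import torch
from tensordict import TensorDict

# per-key log-probs, as a composite distribution would return them
log_prob_td = TensorDict(
    {"agent0_log_prob": torch.randn(8), "agent1_log_prob": torch.randn(8)},
    batch_size=[8],
)

# collapse the per-key log-probs into one log-prob per sample before the
# (log_prob - prev_log_prob) subtraction
log_prob = sum(log_prob_td.values())
print(log_prob.shape)  # torch.Size([8])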
@@ -517,10 +517,7 @@ def _log_weight(
             self.actor_network
         ) if self.functional else contextlib.nullcontext():
             dist = self.actor_network.get_dist(tensordict)
-            if isinstance(dist, CompositeDistribution):
-                is_composite = True
-            else:
-                is_composite = False
+            is_composite = isinstance(dist, CompositeDistribution)

         # current log_prob of actions
         if is_composite:
@@ -537,6 +534,32 @@ def _log_weight(
         prev_log_prob = _maybe_get_or_select(
             tensordict, self.tensor_keys.sample_log_prob
         )
+        # TODO:
+        # # current log_prob of actions
+        # action = _maybe_get_or_select(tensordict, self.tensor_keys.action)
+        #
+        # is_composite = None
+        # if all(key in tensordict for key in self.actor_network.dist_params_keys):
+        #     prev_dist = self.actor_network.build_dist_from_params(tensordict.detach())
+        #     kwargs, is_composite = _get_composite_kwargs(prev_dist)
+        #     if is_composite:
+        #         prev_log_prob = prev_dist.log_prob(tensordict, **kwargs)
+        #     else:
+        #         prev_log_prob = prev_dist.log_prob(action, **kwargs)
+        #     print('prev_log_prob', prev_log_prob)
+        # else:
+        #     try:
+        #         prev_log_prob = _maybe_get_or_select(
+        #             tensordict, self.tensor_keys.sample_log_prob
+        #         )
+        #     except KeyError as err:
+        #         raise _make_lp_get_error(self.tensor_keys, tensordict, err)
+
+        with self.actor_network_params.to_module(
+            self.actor_network
+        ) if self.functional else contextlib.nullcontext():
+            current_dist = self.actor_network.get_dist(tensordict)
+

         if prev_log_prob.requires_grad:
             raise RuntimeError(
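
The commented-out block above sketches the intended follow-up: rebuild the behaviour (data-collection) distribution from parameters stored in the tensordict and only fall back on the stored sample_log_prob when those parameters are missing. The helpers it references (dist_params_keys, build_dist_from_params, _get_composite_kwargs) belong to that planned work and are not assumed to exist here. A minimal standalone sketch of the fallback pattern, assuming a Gaussian policy with hypothetical "loc"/"scale" keys:

import torch
from tensordict import TensorDict
from torch.distributions import Normal


def get_prev_log_prob(td, action_key="action", lp_key="sample_log_prob"):
    # Prefer rebuilding the behaviour distribution from stored parameters,
    # so the reference log-prob always matches the stored action ...
    if "loc" in td.keys() and "scale" in td.keys():
        prev_dist = Normal(td["loc"].detach(), td["scale"].detach())
        return prev_dist.log_prob(td[action_key]).sum(-1)
    # ... and only fall back on the log-prob written at collection time.
    return td[lp_key]


td = TensorDict(
    {"action": torch.randn(8, 3), "loc": torch.zeros(8, 3), "scale": torch.ones(8, 3)},
    batch_size=[8],
)
print(get_prev_log_prob(td).shape)  # torch.Size([8])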
@@ -558,20 +581,27 @@ def _log_weight(
                     "the beginning of your script to get a proper composite log-prob.",
                     category=UserWarning,
                 )
-                if (
-                    is_composite
-                    and not is_tensor_collection(prev_log_prob)
-                    and is_tensor_collection(log_prob)
-                ):
-                    log_prob = _sum_td_features(log_prob)
-                    log_prob.view_as(prev_log_prob)
+        # TODO:
+        # if isinstance(action, torch.Tensor):
+        #     log_prob = current_dist.log_prob(action)
+        # else:
+        #     if is_composite is None:
+        #         kwargs, is_composite = _get_composite_kwargs(current_dist)
+        #     log_prob: TensorDictBase = current_dist.log_prob(tensordict, **kwargs)
+        if (
+            is_composite
+            and not is_tensor_collection(prev_log_prob)
+            and is_tensor_collection(log_prob)
+        ):
+            log_prob = _sum_td_features(log_prob)
+            log_prob.view_as(prev_log_prob)

         log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
         kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
         if is_tensor_collection(kl_approx):
             kl_approx = _sum_td_features(kl_approx)

-        return log_weight, dist, kl_approx
+        return log_weight, current_dist, kl_approx

     def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         """Returns the critic loss multiplied by ``critic_coef``, if it is not ``None``."""
@@ -647,6 +677,9 @@ def _cached_critic_network_params_detached(self):
     @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         tensordict = tensordict.clone(False)
+
+        log_weight, dist, kl_approx = self._log_weight(tensordict)
+
         advantage = tensordict.get(self.tensor_keys.advantage, None)
         if advantage is None:
             self.value_estimator(
@@ -660,7 +693,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             scale = advantage.std().clamp_min(1e-6)
             advantage = (advantage - loc) / scale

-        log_weight, dist, kl_approx = self._log_weight(tensordict)
         if is_tensor_collection(log_weight):
             log_weight = _sum_td_features(log_weight)
         log_weight = log_weight.view(advantage.shape)
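
Taken together, _log_weight now returns the freshly built current_dist and forward() calls it up front, before the advantage is (re)estimated and normalized. A minimal standalone sketch (plain torch, illustrative Gaussian policies and shapes; not the torchrl implementation) of how the returned quantities feed the clipped objective:

import torch
from torch.distributions import Normal

action = torch.randn(64, 3)
prev_dist = Normal(torch.zeros(64, 3), torch.ones(64, 3))           # behaviour policy
current_dist = Normal(torch.full((64, 3), 0.1), torch.ones(64, 3))  # current policy

prev_log_prob = prev_dist.log_prob(action).sum(-1)
log_prob = current_dist.log_prob(action).sum(-1)

log_weight = log_prob - prev_log_prob          # log importance ratio
kl_approx = (prev_log_prob - log_prob).mean()  # cheap sample-based KL estimate

# advantage handling happens after the log-weight, mirroring the reordering above
advantage = torch.randn(64)
advantage = (advantage - advantage.mean()) / advantage.std().clamp_min(1e-6)

ratio = log_weight.exp()
clip_eps = 0.2
loss_objective = -torch.min(
    ratio * advantage,
    ratio.clamp(1.0 - clip_eps, 1.0 + clip_eps) * advantage,
).mean()
print(loss_objective, kl_approx)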
