@@ -517,10 +517,7 @@ def _log_weight(
517
517
self .actor_network
518
518
) if self .functional else contextlib .nullcontext ():
519
519
dist = self .actor_network .get_dist (tensordict )
520
- if isinstance (dist , CompositeDistribution ):
521
- is_composite = True
522
- else :
523
- is_composite = False
520
+ is_composite = isinstance (dist , CompositeDistribution )
524
521
525
522
# current log_prob of actions
526
523
if is_composite :
@@ -537,6 +534,32 @@ def _log_weight(
537
534
prev_log_prob = _maybe_get_or_select (
538
535
tensordict , self .tensor_keys .sample_log_prob
539
536
)
537
+ # TODO:
538
+ # # current log_prob of actions
539
+ # action = _maybe_get_or_select(tensordict, self.tensor_keys.action)
540
+ #
541
+ # is_composite = None
542
+ # if all(key in tensordict for key in self.actor_network.dist_params_keys):
543
+ # prev_dist = self.actor_network.build_dist_from_params(tensordict.detach())
544
+ # kwargs, is_composite = _get_composite_kwargs(prev_dist)
545
+ # if is_composite:
546
+ # prev_log_prob = prev_dist.log_prob(tensordict, **kwargs)
547
+ # else:
548
+ # prev_log_prob = prev_dist.log_prob(action, **kwargs)
549
+ # print('prev_log_prob', prev_log_prob)
550
+ # else:
551
+ # try:
552
+ # prev_log_prob = _maybe_get_or_select(
553
+ # tensordict, self.tensor_keys.sample_log_prob
554
+ # )
555
+ # except KeyError as err:
556
+ # raise _make_lp_get_error(self.tensor_keys, tensordict, err)
557
+
558
+ with self .actor_network_params .to_module (
559
+ self .actor_network
560
+ ) if self .functional else contextlib .nullcontext ():
561
+ current_dist = self .actor_network .get_dist (tensordict )
562
+
540
563
541
564
if prev_log_prob .requires_grad :
542
565
raise RuntimeError (
@@ -558,20 +581,27 @@ def _log_weight(
558
581
"the beginning of your script to get a proper composite log-prob." ,
559
582
category = UserWarning ,
560
583
)
561
- if (
562
- is_composite
563
- and not is_tensor_collection (prev_log_prob )
564
- and is_tensor_collection (log_prob )
565
- ):
566
- log_prob = _sum_td_features (log_prob )
567
- log_prob .view_as (prev_log_prob )
584
+ # TODO:
585
+ # if isinstance(action, torch.Tensor):
586
+ # log_prob = current_dist.log_prob(action)
587
+ # else:
588
+ # if is_composite is None:
589
+ # kwargs, is_composite = _get_composite_kwargs(current_dist)
590
+ # log_prob: TensorDictBase = current_dist.log_prob(tensordict, **kwargs)
591
+ if (
592
+ is_composite
593
+ and not is_tensor_collection (prev_log_prob )
594
+ and is_tensor_collection (log_prob )
595
+ ):
596
+ log_prob = _sum_td_features (log_prob )
597
+ log_prob .view_as (prev_log_prob )
568
598
569
599
log_weight = (log_prob - prev_log_prob ).unsqueeze (- 1 )
570
600
kl_approx = (prev_log_prob - log_prob ).unsqueeze (- 1 )
571
601
if is_tensor_collection (kl_approx ):
572
602
kl_approx = _sum_td_features (kl_approx )
573
603
574
- return log_weight , dist , kl_approx
604
+ return log_weight , current_dist , kl_approx
575
605
576
606
def loss_critic (self , tensordict : TensorDictBase ) -> torch .Tensor :
577
607
"""Returns the critic loss multiplied by ``critic_coef``, if it is not ``None``."""
@@ -647,6 +677,9 @@ def _cached_critic_network_params_detached(self):
647
677
@dispatch
648
678
def forward (self , tensordict : TensorDictBase ) -> TensorDictBase :
649
679
tensordict = tensordict .clone (False )
680
+
681
+ log_weight , dist , kl_approx = self ._log_weight (tensordict )
682
+
650
683
advantage = tensordict .get (self .tensor_keys .advantage , None )
651
684
if advantage is None :
652
685
self .value_estimator (
@@ -660,7 +693,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
660
693
scale = advantage .std ().clamp_min (1e-6 )
661
694
advantage = (advantage - loc ) / scale
662
695
663
- log_weight , dist , kl_approx = self ._log_weight (tensordict )
664
696
if is_tensor_collection (log_weight ):
665
697
log_weight = _sum_td_features (log_weight )
666
698
log_weight = log_weight .view (advantage .shape )
0 commit comments