@@ -510,17 +510,28 @@ def _log_weight(
510
510
# current log_prob of actions
511
511
action = _maybe_get_or_select (tensordict , self .tensor_keys .action )
512
512
513
+ is_composite = None
514
+ if all (key in tensordict for key in self .actor_network .dist_params_keys ):
515
+ prev_dist = self .actor_network .build_dist_from_params (tensordict .detach ())
516
+ kwargs , is_composite = _get_composite_kwargs (prev_dist )
517
+ if is_composite :
518
+ prev_log_prob = prev_dist .log_prob (tensordict , ** kwargs )
519
+ else :
520
+ prev_log_prob = prev_dist .log_prob (action , ** kwargs )
521
+ print ('prev_log_prob' , prev_log_prob )
522
+ else :
523
+ try :
524
+ prev_log_prob = _maybe_get_or_select (
525
+ tensordict , self .tensor_keys .sample_log_prob
526
+ )
527
+ except KeyError as err :
528
+ raise _make_lp_get_error (self .tensor_keys , tensordict , err )
529
+
513
530
with self .actor_network_params .to_module (
514
531
self .actor_network
515
532
) if self .functional else contextlib .nullcontext ():
516
- dist = self .actor_network .get_dist (tensordict )
533
+ current_dist = self .actor_network .get_dist (tensordict )
517
534
518
- try :
519
- prev_log_prob = _maybe_get_or_select (
520
- tensordict , self .tensor_keys .sample_log_prob
521
- )
522
- except KeyError as err :
523
- raise _make_lp_get_error (self .tensor_keys , tensordict , err )
524
535
525
536
if prev_log_prob .requires_grad :
526
537
raise RuntimeError (
@@ -532,25 +543,11 @@ def _log_weight(
532
543
f"tensordict stored { self .tensor_keys .action } requires grad."
533
544
)
534
545
if isinstance (action , torch .Tensor ):
535
- log_prob = dist .log_prob (action )
546
+ log_prob = current_dist .log_prob (action )
536
547
else :
537
- if isinstance (dist , CompositeDistribution ):
538
- is_composite = True
539
- aggregate = dist .aggregate_probabilities
540
- if aggregate is None :
541
- aggregate = False
542
- include_sum = dist .include_sum
543
- if include_sum is None :
544
- include_sum = False
545
- kwargs = {
546
- "inplace" : False ,
547
- "aggregate_probabilities" : aggregate ,
548
- "include_sum" : include_sum ,
549
- }
550
- else :
551
- is_composite = False
552
- kwargs = {}
553
- log_prob : TensorDictBase = dist .log_prob (tensordict , ** kwargs )
548
+ if is_composite is None :
549
+ kwargs , is_composite = _get_composite_kwargs (current_dist )
550
+ log_prob : TensorDictBase = current_dist .log_prob (tensordict , ** kwargs )
554
551
if (
555
552
is_composite
556
553
and not is_tensor_collection (prev_log_prob )
@@ -564,7 +561,7 @@ def _log_weight(
564
561
if is_tensor_collection (kl_approx ):
565
562
kl_approx = _sum_td_features (kl_approx )
566
563
567
- return log_weight , dist , kl_approx
564
+ return log_weight , current_dist , kl_approx
568
565
569
566
def loss_critic (self , tensordict : TensorDictBase ) -> torch .Tensor :
570
567
"""Returns the critic loss multiplied by ``critic_coef``, if it is not ``None``."""
@@ -640,6 +637,9 @@ def _cached_critic_network_params_detached(self):
640
637
@dispatch
641
638
def forward (self , tensordict : TensorDictBase ) -> TensorDictBase :
642
639
tensordict = tensordict .clone (False )
640
+
641
+ log_weight , dist , kl_approx = self ._log_weight (tensordict )
642
+
643
643
advantage = tensordict .get (self .tensor_keys .advantage , None )
644
644
if advantage is None :
645
645
self .value_estimator (
@@ -653,7 +653,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
653
653
scale = advantage .std ().clamp_min (1e-6 )
654
654
advantage = (advantage - loc ) / scale
655
655
656
- log_weight , dist , kl_approx = self ._log_weight (tensordict )
657
656
if is_tensor_collection (log_weight ):
658
657
log_weight = _sum_td_features (log_weight )
659
658
log_weight = log_weight .view (advantage .shape )
@@ -1295,3 +1294,22 @@ def _make_lp_get_error(tensor_keys, log_prob, err):
1295
1294
return KeyError (result )
1296
1295
result += "This is usually due to a missing call to loss.set_keys(sample_log_prob=<list_of_log_prob_keys>)."
1297
1296
return KeyError (result )
1297
+
1298
+ def _get_composite_kwargs (current_dist ):
1299
+ if isinstance (current_dist , CompositeDistribution ):
1300
+ is_composite = True
1301
+ aggregate = current_dist .aggregate_probabilities
1302
+ if aggregate is None :
1303
+ aggregate = False
1304
+ include_sum = current_dist .include_sum
1305
+ if include_sum is None :
1306
+ include_sum = False
1307
+ kwargs = {
1308
+ "inplace" : False ,
1309
+ "aggregate_probabilities" : aggregate ,
1310
+ "include_sum" : include_sum ,
1311
+ }
1312
+ else :
1313
+ is_composite = False
1314
+ kwargs = {}
1315
+ return kwargs , is_composite
0 commit comments