Skip to content

Commit ef90882

Browse files
committed
Update
[ghstack-poisoned]
2 parents 5a174ec + 29819f6 commit ef90882

File tree

12 files changed

+316
-111
lines changed

12 files changed

+316
-111
lines changed

test/test_actors.py

+51-2
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@
88
import pytest
99
import torch
1010

11-
from tensordict import TensorDict
11+
from tensordict import NonTensorStack, TensorDict
1212
from tensordict.nn import CompositeDistribution, TensorDictModule
1313
from tensordict.nn.distributions import NormalParamExtractor
1414

1515
from torch import distributions as dist, nn
1616
from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot
1717
from torchrl.data.llm.dataset import _has_transformers
18-
from torchrl.modules import MLP, SafeModule, TanhDelta, TanhNormal
18+
from torchrl.modules import from_hf_transformers, MLP, SafeModule, TanhDelta, TanhNormal
1919
from torchrl.modules.tensordict_module.actors import (
2020
_process_action_space_spec,
2121
ActorValueOperator,
@@ -907,6 +907,55 @@ def test_lmhead_actorvalueoperator(device):
907907
) == len(policy_params)
908908

909909

910+
@pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
911+
class TestTransformerActor:
912+
@pytest.mark.parametrize(
913+
"from_text, generate, tokens, attention_mask",
914+
[
915+
(True, True, None, None),
916+
(True, False, None, None),
917+
(
918+
False,
919+
True,
920+
torch.randint(1024, (1, 10)),
921+
torch.ones(1, 10, dtype=torch.int64),
922+
),
923+
(False, True, torch.randint(1024, (1, 10)), None),
924+
],
925+
)
926+
def test_from_hf_transformers(self, from_text, generate, tokens, attention_mask):
927+
from torchrl.data.llm import LLMData
928+
from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
929+
930+
tokenizer = AutoTokenizer.from_pretrained("gpt2")
931+
tokenizer.pad_token = tokenizer.eos_token
932+
model = GPT2LMHeadModel(GPT2Config())
933+
tokenizer.padding_side = "left"
934+
m = from_hf_transformers(
935+
model, tokenizer=tokenizer, from_text=from_text, generate=generate
936+
)
937+
if from_text:
938+
tdin = LLMData(text=NonTensorStack("a text"), batch_size=1)
939+
else:
940+
tdin = LLMData(tokens=tokens, attention_mask=attention_mask, batch_size=1)
941+
td = m(tdin)
942+
assert td is tdin
943+
assert isinstance(td, LLMData)
944+
if from_text and generate:
945+
assert td.text_response is not None
946+
else:
947+
assert td.text_response is None
948+
if attention_mask is not None or from_text:
949+
assert td.attention_mask is not None
950+
else:
951+
assert td.attention_mask is None
952+
if not generate:
953+
assert td.text_response is None
954+
assert td.tokens_response is None
955+
assert td.log_probs is not None
956+
assert td.logits is not None
957+
958+
910959
if __name__ == "__main__":
911960
args, unknown = argparse.ArgumentParser().parse_known_args()
912961
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

test/test_env.py

+11-12
Original file line numberDiff line numberDiff line change
@@ -4699,13 +4699,13 @@ def test_llm_from_dataloader(
46994699
def policy(td):
47004700
if str2str:
47014701
if not td.shape:
4702-
td[LLMEnv._DEFAULT_ACTION_KEY] = "<nothing>"
4702+
td[LLMEnv._DEFAULT_ACTION_STR_KEY] = "<nothing>"
47034703
else:
4704-
td[LLMEnv._DEFAULT_ACTION_KEY] = NonTensorStack(
4704+
td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack(
47054705
*["<nothing>" for _ in range(td.shape[0])]
47064706
)
47074707
else:
4708-
td[LLMEnv._DEFAULT_ACTION_KEY] = torch.ones(
4708+
td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones(
47094709
td.shape + (1,), dtype=torch.int64
47104710
)
47114711
return td
@@ -4720,25 +4720,25 @@ def policy(td):
47204720
assert (
47214721
r[0, 0][LLMEnv._DEFAULT_STR_KEY]
47224722
== r[0, 1][LLMEnv._DEFAULT_STR_KEY][
4723-
: -len(r[0, 0][LLMEnv._DEFAULT_ACTION_KEY])
4723+
: -len(r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
47244724
]
47254725
)
47264726
assert (
47274727
r[0, 1][LLMEnv._DEFAULT_STR_KEY]
47284728
== r[0, 2][LLMEnv._DEFAULT_STR_KEY][
4729-
: -len(r[0, 1][LLMEnv._DEFAULT_ACTION_KEY])
4729+
: -len(r[0, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
47304730
]
47314731
)
47324732
assert (
47334733
r[-1, 0][LLMEnv._DEFAULT_STR_KEY]
47344734
== r[-1, 1][LLMEnv._DEFAULT_STR_KEY][
4735-
: -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_KEY])
4735+
: -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
47364736
]
47374737
)
47384738
assert (
47394739
r[-1, 1][LLMEnv._DEFAULT_STR_KEY]
47404740
== r[-1, 2][LLMEnv._DEFAULT_STR_KEY][
4741-
: -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_KEY])
4741+
: -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
47424742
]
47434743
)
47444744
else:
@@ -4815,13 +4815,13 @@ def test_llm_from_dataloader_repeats(
48154815
def policy(td):
48164816
if str2str:
48174817
if not td.shape:
4818-
td[LLMEnv._DEFAULT_ACTION_KEY] = "<nothing>"
4818+
td[LLMEnv._DEFAULT_ACTION_STR_KEY] = "<nothing>"
48194819
else:
4820-
td[LLMEnv._DEFAULT_ACTION_KEY] = NonTensorStack(
4820+
td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack(
48214821
*["<nothing>" for _ in range(td.shape[0])]
48224822
)
48234823
else:
4824-
td[LLMEnv._DEFAULT_ACTION_KEY] = torch.ones(
4824+
td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones(
48254825
td.shape + (1,), dtype=torch.int64
48264826
)
48274827
return td
@@ -4957,7 +4957,7 @@ def test_done_and_reward(
49574957
env.append_transform(StepCounter(max_steps=10))
49584958

49594959
def policy(td):
4960-
td[LLMEnv._DEFAULT_ACTION_KEY] = torch.ones(
4960+
td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones(
49614961
td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64
49624962
)
49634963
return td
@@ -4974,7 +4974,6 @@ def policy(td):
49744974
if assign_done:
49754975
assert "terminated" in r
49764976
assert "done" in r
4977-
print(r)
49784977

49794978

49804979
if __name__ == "__main__":

torchrl/data/llm/__init__.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,14 @@
1111
)
1212
from .prompt import PromptData, PromptTensorDictTokenizer
1313
from .reward import PairwiseDataset, RewardData
14-
from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel, LLMData, LLMOutput, LLMInput
14+
from .utils import (
15+
AdaptiveKLController,
16+
ConstantKLController,
17+
LLMData,
18+
LLMInput,
19+
LLMOutput,
20+
RolloutFromModel,
21+
)
1522

1623
__all__ = [
1724
"AdaptiveKLController",

torchrl/data/llm/utils.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -543,8 +543,10 @@ def step_scheduler(self):
543543
while len(self._kl_queue):
544544
self._kl_queue.remove(self._kl_queue[0])
545545

546+
546547
LLMInpOut = TypeVar("LLMInpOut")
547548

549+
548550
class LLMInput(TensorClass["nocast"]):
549551
"""Represents the input to a Large Language Model (LLM).
550552
@@ -557,11 +559,13 @@ class LLMInput(TensorClass["nocast"]):
557559
.. seealso:: :class:`~torchrl.data.LLMOutput` and :class:`~torchrl.data.LLMData`.
558560
559561
"""
562+
560563
tokens: torch.Tensor
561564
attention_mask: torch.Tensor | None = None
562565
token_list: list[int] | list[list[int]] | None = None
563566
text: str | list[str] | None = None
564567

568+
565569
class LLMOutput(TensorClass["nocast"]):
566570
"""Represents the output from a Large Language Model (LLM).
567571
@@ -581,6 +585,7 @@ class LLMOutput(TensorClass["nocast"]):
581585
.. seealso:: :class:`~torchrl.data.LLMInput` and :class:`~torchrl.data.LLMData`.
582586
583587
"""
588+
584589
tokens: torch.Tensor
585590
tokens_response: torch.Tensor | None = None
586591
token_list: list[int] | list[list[int]] | None = None
@@ -594,6 +599,7 @@ def from_vllm_output(cls: type[LLMInpOut], vllm_output) -> LLMInpOut:
594599
# placeholder
595600
raise NotImplementedError
596601

602+
597603
class LLMData(TensorClass["nocast"]):
598604
"""Represents the input or output of a Large Language Model (LLM).
599605
@@ -619,11 +625,13 @@ class LLMData(TensorClass["nocast"]):
619625
.. seealso:: :class:`~torchrl.data.LLMInput` and :class:`~torchrl.data.LLMOutput`.
620626
621627
"""
622-
tokens: torch.Tensor
628+
629+
tokens: torch.Tensor | None = None
623630
tokens_response: torch.Tensor | None = None
624631
attention_mask: torch.Tensor | None = None
625632
token_list: list[int] | list[list[int]] | None = None
626633
tokens_response_list: list[list[int]] | None = None
627634
logits: torch.Tensor | None = None
628635
log_probs: torch.Tensor | None = None
629636
text: str | list[str] | None = None
637+
text_response: torch.Tensor | None = None

torchrl/envs/custom/llm.py

+33-15
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,13 @@ class LLMEnv(EnvBase):
4242
4343
Keyword Args:
4444
token_key (NestedKey, optional): The key in the tensordict where the tokens are stored (when `str2str=False`).
45-
Defaults to ``("tokens_in", "input_ids")``.
45+
Defaults to ``"tokens"``.
4646
str_key (NestedKey, optional): The key in the tensordict where the string input is stored (when `str2str=True`).
47-
Defaults to ``"test"``.
47+
Defaults to ``"text"``.
4848
attention_key (NestedKey, optional): The key in the tensordict where the attention mask is stored.
49-
Defaults to ``("tokens_in", "input_ids")``
49+
Defaults to ``"attention_mask"``.
5050
action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to
51-
``("tokens_out", "sequences")``.
51+
``"tokens_response"`` when `str2str=False`, or ``"text_response"`` when `str2str=True`.
5252
reward_key (NestedKey, optional): The key in the tensordict where the reward is stored if `assign_reward=True`.
5353
Defaults to ``"reward"``.
5454
str2str (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``False``.
@@ -71,6 +71,8 @@ class LLMEnv(EnvBase):
7171
batch_size (int or torch.Size, optional): Batch size of the environment. If left empty, the environment
7272
is batchless (or batch-unlocked), meaning that it can accept tensordicts of any batch size.
7373
Defaults to ``None`` (batch-unlocked).
74+
as_llm_data (bool, optional): If ``True``, the data will be of type :class:`~torchrl.data.LLMData`.
75+
Defaults to ``False``. Not supported yet: with ``as_llm_data=True``, ``_step`` raises ``NotImplementedError``.
7476
7577
.. seealso:: :class:`~torchrl.envs.DataLoadingPrimer` for examples.
7678
@@ -79,10 +81,11 @@ class LLMEnv(EnvBase):
7981
8082
"""
8183

82-
_DEFAULT_TOKEN_KEY = ("tokens_in", "input_ids")
84+
_DEFAULT_TOKEN_KEY = "tokens"
8385
_DEFAULT_STR_KEY = "text"
84-
_DEFAULT_ATTENTION_KEY = ("tokens_in", "attention_mask")
85-
_DEFAULT_ACTION_KEY = ("tokens_out", "sequences")
86+
_DEFAULT_ATTENTION_KEY = "attention_mask"
87+
_DEFAULT_ACTION_TOKENS_KEY = "tokens_response"
88+
_DEFAULT_ACTION_STR_KEY = "text_response"
8689

8790
def __init__(
8891
self,
@@ -100,15 +103,20 @@ def __init__(
100103
assign_done: bool = False,
101104
batch_size: int | torch.Size | None = None,
102105
has_attention: bool = True,
106+
as_llm_data: bool = False,
103107
) -> None:
108+
self.as_llm_data = as_llm_data
104109
if token_key is None:
105110
token_key = self._DEFAULT_TOKEN_KEY
106111
if str_key is None:
107112
str_key = self._DEFAULT_STR_KEY
108113
if attention_key is None:
109114
attention_key = self._DEFAULT_ATTENTION_KEY
110115
if action_key is None:
111-
action_key = self._DEFAULT_ACTION_KEY
116+
if str2str:
117+
action_key = self._DEFAULT_ACTION_STR_KEY
118+
else:
119+
action_key = self._DEFAULT_ACTION_TOKENS_KEY
112120
if batch_size is None:
113121
self._batch_locked = False
114122
batch_size = ()
@@ -206,7 +214,7 @@ def __init__(
206214
else:
207215
# Use single done
208216
self.full_done_spec_unbatched = Composite(
209-
tokens=Composite(
217+
tokens_data=Composite(
210218
done=Unbounded(shape=(-1,), dtype=torch.bool),
211219
terminated=Unbounded(shape=(-1,), dtype=torch.bool),
212220
),
@@ -228,6 +236,7 @@ def from_dataloader(
228236
device: torch.device | None = None,
229237
vocab_size: int | None = None,
230238
no_stack: bool = False,
239+
as_llm_data: bool = False,
231240
batch_size: int | torch.Size | None = None,
232241
has_attention: bool = True,
233242
assign_reward: bool = False,
@@ -288,6 +297,8 @@ def from_dataloader(
288297
repeats (int, optional): How many times the same sample needs to appear successively. This can be useful in
289298
situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo
290299
samples (rather than an advantage module).
300+
as_llm_data (bool, optional): If ``True``, the data will be of type :class:`~torchrl.data.LLMData`.
301+
Defaults to ``False``. Not supported yet: with ``as_llm_data=True``, ``_step`` raises ``NotImplementedError``.
291302
292303
Returns:
293304
LLMEnv: The created LLMEnv instance.
@@ -334,6 +345,7 @@ def from_dataloader(
334345
assign_done=assign_done,
335346
batch_size=batch_size,
336347
has_attention=has_attention,
348+
as_llm_data=as_llm_data,
337349
)
338350
return env.append_transform(primer)
339351

@@ -353,6 +365,8 @@ def _step(
353365
self._make_next_obs(tensordict, next_td)
354366
self._maybe_make_reward(tensordict, next_td)
355367
self._maybe_make_done(tensordict, next_td)
368+
if self.as_llm_data:
369+
raise NotImplementedError()
356370
return next_td
357371

358372
def _maybe_make_reward(
@@ -378,14 +392,14 @@ def _maybe_make_done(
378392
)
379393
else:
380394
done = torch.zeros_like(action, dtype=torch.bool)
381-
next_td.set(("tokens", "terminated"), done)
382-
next_td.set(("tokens", "done"), done.clone())
395+
next_td.set(("tokens_data", "terminated"), done)
396+
next_td.set(("tokens_data", "done"), done.clone())
383397
next_td.set(
384-
"terminated", next_td.get(("tokens", "done")).any(-1, keepdim=True)
398+
"terminated", next_td.get(("tokens_data", "done")).any(-1, keepdim=True)
385399
)
386400
next_td.set(
387401
"terminated",
388-
next_td.get(("tokens", "terminated")).any(-1, keepdim=True),
402+
next_td.get(("tokens_data", "terminated")).any(-1, keepdim=True),
389403
)
390404
return next_td
391405

@@ -400,7 +414,8 @@ def _make_next_obs(
400414
if self.has_attention:
401415
attention_mask = tensordict.get(self.attention_key)
402416
n = action.shape[-1] - attention_mask.shape[-1]
403-
if n:
417+
if n > 0:
418+
# It can happen that there's only one action (e.g., rand_action)
404419
attention_mask = torch.cat(
405420
[
406421
attention_mask,
@@ -471,7 +486,10 @@ def check_str():
471486
f"torchrl.envs.DataLoadingPrimer) is appended to the env transforms."
472487
)
473488
td_reset = tensordict.copy()
474-
return self._maybe_make_done(tensordict, td_reset)
489+
tensordict = self._maybe_make_done(tensordict, td_reset)
490+
if self.as_llm_data:
491+
raise NotImplementedError()
492+
return tensordict
475493

476494
def _set_seed(self, seed: int | None):
477495
return seed

torchrl/envs/transforms/llm.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -417,8 +417,6 @@ def __init__(
417417
else:
418418
self.data_keys = list(primers.keys(True, True))
419419

420-
self._reset_key = "_reset"
421-
422420
super().__init__(
423421
primers=primers,
424422
default_value=self._load_from_dataloader,
@@ -427,6 +425,7 @@ def __init__(
427425
single_default_value=True,
428426
call_before_env_reset=True,
429427
)
428+
self._reset_key = "_reset"
430429

431430
@classmethod
432431
def _endless_iter(self, obj):

0 commit comments

Comments
 (0)