Skip to content

Commit c2a622b

Browse files
committed
[Refactor] Use default device instead of CPU in losses
ghstack-source-id: 6139ede Pull Request resolved: #2687
1 parent 6ca4e2c commit c2a622b

File tree

7 files changed

+8
-8
lines changed

7 files changed

+8
-8
lines changed

torchrl/objectives/cql.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def __init__(
323323
try:
324324
device = next(self.parameters()).device
325325
except AttributeError:
326-
device = torch.device("cpu")
326+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
327327
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
328328
if bool(min_alpha) ^ bool(max_alpha):
329329
min_alpha = min_alpha if min_alpha else 0.0

torchrl/objectives/crossq.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def __init__(
306306
try:
307307
device = next(self.parameters()).device
308308
except AttributeError:
309-
device = torch.device("cpu")
309+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
310310
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
311311
if bool(min_alpha) ^ bool(max_alpha):
312312
min_alpha = min_alpha if min_alpha else 0.0

torchrl/objectives/decision_transformer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __init__(
103103
try:
104104
device = next(self.parameters()).device
105105
except AttributeError:
106-
device = torch.device("cpu")
106+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
107107

108108
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
109109
if bool(min_alpha) ^ bool(max_alpha):

torchrl/objectives/deprecated.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def __init__(
203203
try:
204204
device = next(self.parameters()).device
205205
except AttributeError:
206-
device = torch.device("cpu")
206+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
207207

208208
self.register_buffer("alpha_init", torch.as_tensor(alpha_init, device=device))
209209
self.register_buffer(

torchrl/objectives/ppo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def __init__(
387387
try:
388388
device = next(self.parameters()).device
389389
except (AttributeError, StopIteration):
390-
device = torch.device("cpu")
390+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
391391

392392
self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
393393
if critic_coef is not None:

torchrl/objectives/redq.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def __init__(
319319
try:
320320
device = next(self.parameters()).device
321321
except AttributeError:
322-
device = torch.device("cpu")
322+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
323323

324324
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
325325
self.register_buffer(

torchrl/objectives/sac.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ def __init__(
393393
try:
394394
device = next(self.parameters()).device
395395
except AttributeError:
396-
device = torch.device("cpu")
396+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
397397
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
398398
if bool(min_alpha) ^ bool(max_alpha):
399399
min_alpha = min_alpha if min_alpha else 0.0
@@ -1119,7 +1119,7 @@ def __init__(
11191119
try:
11201120
device = next(self.parameters()).device
11211121
except AttributeError:
1122-
device = torch.device("cpu")
1122+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
11231123

11241124
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
11251125
if bool(min_alpha) ^ bool(max_alpha):

0 commit comments

Comments
 (0)