Skip to content

Commit

Permalink
Update (base update)
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 14, 2024
2 parents 957ec6e + 7d7cd95 commit 21467c9
Showing 1 changed file with 21 additions and 22 deletions.
43 changes: 21 additions & 22 deletions torchrl/modules/models/decision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import dataclasses

import importlib
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any

Expand Down Expand Up @@ -92,9 +93,6 @@ def __init__(
config: dict | DTConfig = None,
device: torch.device | None = None,
):
if device is not None:
with torch.device(device):
return self.__init__(state_dim, action_dim, config)

if not _has_transformers:
raise ImportError(
Expand All @@ -117,28 +115,29 @@ def __init__(

super(DecisionTransformer, self).__init__()

gpt_config = transformers.GPT2Config(
n_embd=config["n_embd"],
n_layer=config["n_layer"],
n_head=config["n_head"],
n_inner=config["n_inner"],
activation_function=config["activation"],
n_positions=config["n_positions"],
resid_pdrop=config["resid_pdrop"],
attn_pdrop=config["attn_pdrop"],
vocab_size=1,
)
self.state_dim = state_dim
self.action_dim = action_dim
self.hidden_size = config["n_embd"]
with torch.device(device) if device is not None else nullcontext():
gpt_config = transformers.GPT2Config(
n_embd=config["n_embd"],
n_layer=config["n_layer"],
n_head=config["n_head"],
n_inner=config["n_inner"],
activation_function=config["activation"],
n_positions=config["n_positions"],
resid_pdrop=config["resid_pdrop"],
attn_pdrop=config["attn_pdrop"],
vocab_size=1,
)
self.state_dim = state_dim
self.action_dim = action_dim
self.hidden_size = config["n_embd"]

self.transformer = GPT2Model(config=gpt_config)
self.transformer = GPT2Model(config=gpt_config)

self.embed_return = torch.nn.Linear(1, self.hidden_size)
self.embed_state = torch.nn.Linear(self.state_dim, self.hidden_size)
self.embed_action = torch.nn.Linear(self.action_dim, self.hidden_size)
self.embed_return = torch.nn.Linear(1, self.hidden_size)
self.embed_state = torch.nn.Linear(self.state_dim, self.hidden_size)
self.embed_action = torch.nn.Linear(self.action_dim, self.hidden_size)

self.embed_ln = nn.LayerNorm(self.hidden_size)
self.embed_ln = nn.LayerNorm(self.hidden_size)

def forward(
self,
Expand Down

0 comments on commit 21467c9

Please sign in to comment.