-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #24 from suraj-nair-tri/robomimic-eval
Robomimic Policy Eval Support
- Loading branch information
Showing
8 changed files
with
524 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,259 @@ | ||
import json | ||
import os | ||
import numpy as np | ||
import torch | ||
from collections import OrderedDict | ||
from copy import deepcopy | ||
|
||
from droid.controllers.oculus_controller import VRPolicy | ||
from droid.evaluation.policy_wrapper import PolicyWrapperRobomimic | ||
from droid.robot_env import RobotEnv | ||
from droid.user_interface.data_collector import DataCollecter | ||
from droid.user_interface.gui import RobotGUI | ||
|
||
import robomimic.utils.file_utils as FileUtils | ||
import robomimic.utils.torch_utils as TorchUtils | ||
import robomimic.utils.tensor_utils as TensorUtils | ||
|
||
import cv2 | ||
|
||
def eval_launcher(variant, run_id, exp_id): | ||
# Get Directory # | ||
dir_path = os.path.dirname(os.path.realpath(__file__)) | ||
|
||
# Prepare Log Directory # | ||
variant["exp_name"] = os.path.join(variant["exp_name"], "run{0}/id{1}/".format(run_id, exp_id)) | ||
log_dir = os.path.join(dir_path, "../../evaluation_logs", variant["exp_name"]) | ||
|
||
# Set Random Seeds # | ||
torch.manual_seed(variant["seed"]) | ||
np.random.seed(variant["seed"]) | ||
|
||
# Set Compute Mode # | ||
use_gpu = variant.get("use_gpu", False) | ||
torch.device("cuda:0" if use_gpu else "cpu") | ||
|
||
ckpt_path = variant["ckpt_path"] | ||
|
||
device = TorchUtils.get_torch_device(try_to_use_cuda=True) | ||
ckpt_dict = FileUtils.maybe_dict_from_checkpoint(ckpt_path=ckpt_path) | ||
config = json.loads(ckpt_dict["config"]) | ||
|
||
### infer image size ### | ||
for obs_key in ckpt_dict["shape_metadata"]["all_shapes"].keys(): | ||
if 'camera/image' in obs_key: | ||
imsize = max(ckpt_dict["shape_metadata"]["all_shapes"][obs_key]) | ||
break | ||
|
||
ckpt_dict["config"] = json.dumps(config) | ||
policy, _ = FileUtils.policy_from_checkpoint(ckpt_dict=ckpt_dict, device=device, verbose=True) | ||
policy.goal_mode = config["train"]["goal_mode"] | ||
policy.eval_mode = True | ||
|
||
# determine the action space (relative or absolute) | ||
action_keys = config["train"]["action_keys"] | ||
if "action/rel_pos" in action_keys: | ||
action_space = "cartesian_velocity" | ||
for k in action_keys: | ||
assert not k.startswith("action/abs_") | ||
elif "action/abs_pos" in action_keys: | ||
action_space = "cartesian_position" | ||
for k in action_keys: | ||
assert not k.startswith("action/rel_") | ||
else: | ||
raise ValueError | ||
|
||
# determine the action space for the gripper | ||
if "action/gripper_velocity" in action_keys: | ||
gripper_action_space = "velocity" | ||
elif "action/gripper_position" in action_keys: | ||
gripper_action_space = "position" | ||
else: | ||
raise ValueError | ||
|
||
# determine the action space (relative or absolute) | ||
action_keys = config["train"]["action_keys"] | ||
if "action/rel_pos" in action_keys: | ||
action_space = "cartesian_velocity" | ||
for k in action_keys: | ||
assert not k.startswith("action/abs_") | ||
elif "action/abs_pos" in action_keys: | ||
action_space = "cartesian_position" | ||
for k in action_keys: | ||
assert not k.startswith("action/rel_") | ||
else: | ||
raise ValueError | ||
|
||
# determine the action space for the gripper | ||
if "action/gripper_velocity" in action_keys: | ||
gripper_action_space = "velocity" | ||
elif "action/gripper_position" in action_keys: | ||
gripper_action_space = "position" | ||
else: | ||
raise ValueError | ||
|
||
# Prepare Policy Wrapper # | ||
data_processing_kwargs = dict( | ||
timestep_filtering_kwargs=dict( | ||
action_space=action_space, | ||
gripper_action_space=gripper_action_space, | ||
robot_state_keys=["cartesian_position", "gripper_position", "joint_positions"], | ||
# camera_extrinsics=[], | ||
), | ||
image_transform_kwargs=dict( | ||
remove_alpha=True, | ||
bgr_to_rgb=True, | ||
to_tensor=True, | ||
augment=False, | ||
), | ||
) | ||
timestep_filtering_kwargs = data_processing_kwargs.get("timestep_filtering_kwargs", {}) | ||
image_transform_kwargs = data_processing_kwargs.get("image_transform_kwargs", {}) | ||
|
||
policy_data_processing_kwargs = {} | ||
policy_timestep_filtering_kwargs = policy_data_processing_kwargs.get("timestep_filtering_kwargs", {}) | ||
policy_image_transform_kwargs = policy_data_processing_kwargs.get("image_transform_kwargs", {}) | ||
|
||
policy_timestep_filtering_kwargs.update(timestep_filtering_kwargs) | ||
policy_image_transform_kwargs.update(image_transform_kwargs) | ||
|
||
fs = config["train"]["frame_stack"] | ||
|
||
wrapped_policy = PolicyWrapperRobomimic( | ||
policy=policy, | ||
timestep_filtering_kwargs=policy_timestep_filtering_kwargs, | ||
image_transform_kwargs=policy_image_transform_kwargs, | ||
frame_stack=fs, | ||
eval_mode=True, | ||
) | ||
|
||
camera_kwargs = dict( | ||
hand_camera=dict(image=True, concatenate_images=False, resolution=(imsize, imsize), resize_func="cv2"), | ||
varied_camera=dict(image=True, concatenate_images=False, resolution=(imsize, imsize), resize_func="cv2"), | ||
) | ||
|
||
policy_camera_kwargs = {} | ||
policy_camera_kwargs.update(camera_kwargs) | ||
|
||
env = RobotEnv( | ||
action_space=policy_timestep_filtering_kwargs["action_space"], | ||
gripper_action_space=policy_timestep_filtering_kwargs["gripper_action_space"], | ||
camera_kwargs=policy_camera_kwargs | ||
) | ||
controller = VRPolicy() | ||
|
||
# Launch GUI # | ||
data_collector = DataCollecter( | ||
env=env, | ||
controller=controller, | ||
policy=wrapped_policy, | ||
save_traj_dir=log_dir, | ||
save_data=variant.get("save_data", True), | ||
) | ||
RobotGUI(robot=data_collector) | ||
|
||
|
||
def get_goal_im(variant, run_id, exp_id): | ||
# Get Directory # | ||
dir_path = os.path.dirname(os.path.realpath(__file__)) | ||
|
||
# Prepare Log Directory # | ||
variant["exp_name"] = os.path.join(variant["exp_name"], "run{0}/id{1}/".format(run_id, exp_id)) | ||
log_dir = os.path.join(dir_path, "../../evaluation_logs", variant["exp_name"]) | ||
|
||
# Set Random Seeds # | ||
torch.manual_seed(variant["seed"]) | ||
np.random.seed(variant["seed"]) | ||
|
||
# Set Compute Mode # | ||
use_gpu = variant.get("use_gpu", False) | ||
torch.device("cuda:0" if use_gpu else "cpu") | ||
|
||
ckpt_path = variant["ckpt_path"] | ||
|
||
device = TorchUtils.get_torch_device(try_to_use_cuda=True) | ||
ckpt_dict = FileUtils.maybe_dict_from_checkpoint(ckpt_path=ckpt_path) | ||
config = json.loads(ckpt_dict["config"]) | ||
|
||
### infer image size ### | ||
imsize = max(ckpt_dict["shape_metadata"]["all_shapes"]["camera/image/hand_camera_left_image"]) | ||
|
||
ckpt_dict["config"] = json.dumps(config) | ||
policy, _ = FileUtils.policy_from_checkpoint(ckpt_dict=ckpt_dict, device=device, verbose=True) | ||
|
||
# determine the action space (relative or absolute) | ||
action_keys = config["train"]["action_keys"] | ||
if "action/rel_pos" in action_keys: | ||
action_space = "cartesian_velocity" | ||
for k in action_keys: | ||
assert not k.startswith("action/abs_") | ||
elif "action/abs_pos" in action_keys: | ||
action_space = "cartesian_position" | ||
for k in action_keys: | ||
assert not k.startswith("action/rel_") | ||
else: | ||
raise ValueError | ||
|
||
# determine the action space for the gripper | ||
if "action/gripper_velocity" in action_keys: | ||
gripper_action_space = "velocity" | ||
elif "action/gripper_position" in action_keys: | ||
gripper_action_space = "position" | ||
else: | ||
raise ValueError | ||
|
||
# Prepare Policy Wrapper # | ||
data_processing_kwargs = dict( | ||
timestep_filtering_kwargs=dict( | ||
action_space=action_space, | ||
gripper_action_space=gripper_action_space, | ||
robot_state_keys=["cartesian_position", "gripper_position", "joint_positions"], | ||
# camera_extrinsics=[], | ||
), | ||
image_transform_kwargs=dict( | ||
remove_alpha=True, | ||
bgr_to_rgb=True, | ||
to_tensor=True, | ||
augment=False, | ||
), | ||
) | ||
timestep_filtering_kwargs = data_processing_kwargs.get("timestep_filtering_kwargs", {}) | ||
image_transform_kwargs = data_processing_kwargs.get("image_transform_kwargs", {}) | ||
|
||
policy_data_processing_kwargs = {} | ||
policy_timestep_filtering_kwargs = policy_data_processing_kwargs.get("timestep_filtering_kwargs", {}) | ||
policy_image_transform_kwargs = policy_data_processing_kwargs.get("image_transform_kwargs", {}) | ||
|
||
policy_timestep_filtering_kwargs.update(timestep_filtering_kwargs) | ||
policy_image_transform_kwargs.update(image_transform_kwargs) | ||
|
||
wrapped_policy = PolicyWrapperRobomimic( | ||
policy=policy, | ||
timestep_filtering_kwargs=policy_timestep_filtering_kwargs, | ||
image_transform_kwargs=policy_image_transform_kwargs, | ||
frame_stack=config["train"]["frame_stack"], | ||
eval_mode=True, | ||
) | ||
|
||
camera_kwargs = dict( | ||
hand_camera=dict(image=True, concatenate_images=False, resolution=(imsize, imsize), resize_func="cv2"), | ||
varied_camera=dict(image=True, concatenate_images=False, resolution=(imsize, imsize), resize_func="cv2"), | ||
) | ||
|
||
policy_camera_kwargs = {} | ||
policy_camera_kwargs.update(camera_kwargs) | ||
|
||
env = RobotEnv( | ||
action_space=policy_timestep_filtering_kwargs["action_space"], | ||
gripper_action_space=policy_timestep_filtering_kwargs["gripper_action_space"], | ||
camera_kwargs=policy_camera_kwargs, | ||
do_reset=False | ||
) | ||
|
||
ims = env.read_cameras()[0]["image"] | ||
if not os.path.exists('eval_params'): | ||
os.makedirs('eval_params') | ||
for k in ims.keys(): | ||
image = ims[k] | ||
cv2.imwrite(f'eval_params/{k}.png', image[:, :, :3]) | ||
return ims |
Oops, something went wrong.