-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_collection_ml1.py
325 lines (299 loc) · 13.7 KB
/
data_collection_ml1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
"""
Collect offline trajectories for FOCAL by rolling out noisy scripted Meta-World policies.
"""
import click
import json
import os
import gym, gym.wrappers
from hydra.experimental import compose, initialize
import argparse
import multiprocessing as mp
from multiprocessing import Pool
from itertools import product
from rlkit.envs.wrappers import NormalizedBoxEnv
from rlkit.envs.metaworld_wrapper import MetaWorldWrapper
from rlkit.envs import ENVS
from configs.default import default_config
import metaworld,random
import numpy as np
import metaworld.policies as p
import copy
def deep_update_dict(fr, to):
    """Recursively merge the nested dict ``fr`` into ``to`` in place.

    Leaf values in ``fr`` overwrite the corresponding entries of ``to``;
    nested dicts are merged key by key instead of being replaced wholesale.
    Keys present only in ``to`` are left untouched.

    Unlike a plain ``to[k]`` lookup, a key that exists in ``fr`` but is
    missing from ``to`` (or whose value in ``to`` is not a dict while the
    value in ``fr`` is) does not raise: a fresh dict is created in ``to`` and
    filled, so ``fr`` and ``to`` never end up sharing a nested dict object.

    Args:
        fr: Source mapping holding the new values.
        to: Destination dict, modified in place.

    Returns:
        The same ``to`` object, after the merge.
    """
    for key, value in fr.items():
        if isinstance(value, dict):
            target = to.get(key)
            if not isinstance(target, dict):
                # Nothing to merge into: start from an empty dict rather than
                # failing or aliasing the source's nested dict.
                target = to[key] = {}
            deep_update_dict(value, target)
        else:
            to[key] = value
    return to
# Compose the Hydra configuration once, at import time. The resulting `cfg`
# is bound as the default `cfg` argument of `experiment` and has its
# `is_uniform` field overwritten by `main`, so it must exist before either
# function is defined.
initialize(config_dir="rlkit/torch/sac/pytorch_sac/config/")
cfg = compose("train.yaml")
# Name of the scripted policy class in `metaworld.policies` for each supported
# environment. Names are resolved lazily with getattr so that a class missing
# from the installed Meta-World version only affects its own environment.
_POLICY_CLASS_NAMES = {
    'push-v2': 'SawyerPushV2Policy',
    'basketball-v2': 'SawyerBasketballV2Policy',
    'push-wall-v2': 'SawyerPushWallV2Policy',
    'drawer-close-v2': 'SawyerDrawerCloseV2Policy',
    'handle-pull-side-v2': 'SawyerHandlePullSideV2Policy',
    'sweep-v2': 'SawyerSweepV2Policy',
    'coffee-push-v2': 'SawyerCoffeePushV2Policy',
    # Misspelled alias that the original lookup also accepted.
    'coffeepush-v2': 'SawyerCoffeePushV2Policy',
    'handle-pull-v2': 'SawyerHandlePullV2Policy',
    'lever-pull-v2': 'SawyerLeverPullV2Policy',
    'hammer-v2': 'SawyerHammerV2Policy',
    'assembly-v2': 'SawyerAssemblyV2Policy',
    'bin-picking-v2': 'SawyerBinPickingV2Policy',
    'box-close-v2': 'SawyerBoxCloseV2Policy',
    'button-press-topdown-v2': 'SawyerButtonPressTopdownV2Policy',
    'button-press-topdown-wall-v2': 'SawyerButtonPressTopdownWallV2Policy',
    'button-press-v2': 'SawyerButtonPressV2Policy',
    'button-press-wall-v2': 'SawyerButtonPressWallV2Policy',
    'coffee-button-v2': 'SawyerCoffeeButtonV2Policy',
    'coffee-pull-v2': 'SawyerCoffeePullV2Policy',
    'dial-turn-v2': 'SawyerDialTurnV2Policy',
    'disassemble-v2': 'SawyerDisassembleV2Policy',
    'door-close-v2': 'SawyerDoorCloseV2Policy',
    'door-lock-v2': 'SawyerDoorLockV2Policy',
    'door-open-v2': 'SawyerDoorOpenV2Policy',
    'door-unlock-v2': 'SawyerDoorUnlockV2Policy',
    'hand-insert-v2': 'SawyerHandInsertV2Policy',
    'drawer-open-v2': 'SawyerDrawerOpenV2Policy',
    'faucet-close-v2': 'SawyerFaucetCloseV2Policy',
    'faucet-open-v2': 'SawyerFaucetOpenV2Policy',
    # The original if/elif chain listed this environment twice; the first
    # match (the plain pick-place policy) always won and the original code
    # carried the note "ppwall use other policy", so that mapping is kept.
    'pick-place-wall-v2': 'SawyerPickPlaceV2Policy',
    'handle-press-side-v2': 'SawyerHandlePressSideV2Policy',
    'handle-press-v2': 'SawyerHandlePressV2Policy',
    'peg-insert-side-v2': 'SawyerPegInsertionSideV2Policy',
    'pick-out-of-hole-v2': 'SawyerPickOutOfHoleV2Policy',
    'reach-v2': 'SawyerReachV2Policy',
    'pick-place-v2': 'SawyerPickPlaceV2Policy',
    'plate-slide-v2': 'SawyerPlateSlideV2Policy',
    'plate-slide-side-v2': 'SawyerPlateSlideSideV2Policy',
    'plate-slide-back-v2': 'SawyerPlateSlideBackV2Policy',  # bad data collection
    'plate-slide-back-side-v2': 'SawyerPlateSlideBackSideV2Policy',  # bad data collection
    'peg-unplug-side-v2': 'SawyerPegUnplugSideV2Policy',  # bad data collection
    'soccer-v2': 'SawyerSoccerV2Policy',
    'stick-push-v2': 'SawyerStickPushV2Policy',  # bad data collection
    'stick-pull-v2': 'SawyerStickPullV2Policy',  # bad data collection
    'reach-wall-v2': 'SawyerReachWallV2Policy',
    'shelf-place-v2': 'SawyerShelfPlaceV2Policy',
    'sweep-into-v2': 'SawyerSweepIntoV2Policy',  # bad data collection
    'window-close-v2': 'SawyerWindowCloseV2Policy',
    'window-open-v2': 'SawyerWindowOpenV2Policy',
    'push-back-v2': 'SawyerPushBackV2Policy',  # bad data collection
}


def _make_scripted_policy(env_name):
    """Return an instance of the scripted policy registered for ``env_name``."""
    try:
        class_name = _POLICY_CLASS_NAMES[env_name]
    except KeyError:
        raise NotImplementedError(
            'no scripted policy registered for environment %r' % (env_name,)
        ) from None
    return getattr(p, class_name)()


def experiment(variant, cfg=cfg, goal_idx=0, seed=0, eval=False,
               num_trajectories=45, max_steps=500, noise_std=0.1):
    """Roll out a noisy scripted policy on one Meta-World task and save the episodes.

    Every episode is written to
    ``./data/<env_name>/goal_idx<goal_idx>/trj_evalsample<i>_step49500.npy``
    as an object array whose rows are ``[obs, action, reward, next_obs]``.
    All episodes are kept, whether or not the task was solved.

    Args:
        variant: Experiment configuration; only ``variant['env_name']`` is
            read here.
        cfg: Not used by this function. Kept so existing callers that pass it
            positionally keep working.
        goal_idx: Index into the benchmark's ``train_tasks``; also names the
            output directory.
        seed: Not used by this function (the benchmark is always built with
            seed 1337).
        eval: Not used by this function.
        num_trajectories: Number of episodes to record.
        max_steps: Episode length cap; the episode is ended after this many
            environment steps even if the environment has not reported done.
        noise_std: Standard deviation of the Gaussian noise added to each
            scripted action before clipping to [-1, 1].

    Raises:
        NotImplementedError: If no scripted policy is registered for
            ``variant['env_name']``.
    """
    env_name = variant['env_name']
    # Resolve the policy first so an unsupported environment fails
    # immediately instead of after the environment has been built.
    policy = _make_scripted_policy(env_name)

    out_dir = os.path.join('./data', env_name, 'goal_idx%d' % goal_idx)
    os.makedirs(out_dir, exist_ok=True)

    ml1 = metaworld.MT1(env_name, seed=1337)  # construct the benchmark, sampling tasks
    env = ml1.train_classes[env_name]()
    env.train_tasks = ml1.train_tasks
    env.set_task(ml1.train_tasks[goal_idx])

    # A private generator seeded from OS entropy: pool workers forked from
    # the same parent would otherwise inherit identical global np.random
    # state and replay the same noise sequence.
    rng = np.random.RandomState()

    for episode in range(num_trajectories):
        obs = env.reset()
        done = False
        episode_reward = 0
        success = 0
        step = 0
        trj = []
        while not done:
            action = policy.get_action(obs)
            noise = rng.randn(action.shape[0]) * noise_std
            action = (action + noise).clip(-1, 1)
            new_obs, reward, done, info = env.step(action)
            step += 1
            if step >= max_steps:
                done = True

            # Store copies with the last three observation entries zeroed.
            # NOTE(review): presumably these are the goal coordinates, hidden
            # so the task must be inferred from the transitions - confirm.
            store_obs = copy.deepcopy(obs)
            store_new_obs = copy.deepcopy(new_obs)
            store_obs[-3:] = 0
            store_new_obs[-3:] = 0
            trj.append([store_obs, action, reward, store_new_obs])

            obs = new_obs
            episode_reward += reward
            success += info['success']

        print(episode_reward, success, episode)
        # Rows mix arrays of different lengths with a scalar reward, so the
        # array must be created with dtype=object explicitly; recent NumPy
        # versions refuse to build ragged arrays implicitly.
        # The "step49500" suffix is a fixed part of the file name.
        np.save(
            os.path.join(out_dir, f'trj_evalsample{episode}_step{49500}.npy'),
            np.array(trj, dtype=object),
        )
    return
@click.command()
@click.argument("config", default="./configs/sparse-point-robot.json")
@click.option("--num_gpus", default=8)
@click.option("--docker", is_flag=True, default=False)
@click.option("--debug", is_flag=True, default=False)
@click.option("--eval", is_flag=True, default=False)
@click.option("--is_uniform", is_flag=True, default=False)
def main(config, num_gpus, docker, debug, eval, is_uniform, goal_idx=0, seed=0):
    """Collect scripted-policy trajectories for every task of one environment.

    CONFIG is a JSON file whose values are merged over the default
    configuration. When the merged configuration has more than one task, the
    tasks are processed in parallel worker processes in a random order;
    otherwise a single task (``goal_idx``) is processed in this process.

    The ``docker``, ``debug``, ``eval`` and ``seed`` parameters are accepted
    but not used by this function.
    """
    variant = default_config
    if config:
        with open(config) as f:
            exp_params = json.load(f)
        variant = deep_update_dict(exp_params, variant)
    variant['util_params']['num_gpus'] = num_gpus

    n_tasks = variant['env_params']['n_tasks']
    random_task_id = np.random.permutation(n_tasks).tolist()

    cfg.is_uniform = is_uniform
    print('cfg.is_uniform', cfg.is_uniform)
    print(list(range(n_tasks)))

    os.makedirs('./data/' + variant['env_name'], exist_ok=True)
    if n_tasks > 1:
        # Create the pool only when it is needed, never with more workers
        # than tasks, and release the worker processes once starmap (which
        # blocks until every task has finished) returns.
        n_workers = min(mp.cpu_count(), 50, n_tasks)
        with mp.Pool(n_workers) as pool:
            pool.starmap(experiment, product([variant], [cfg], random_task_id))
    else:
        experiment(variant=variant, cfg=cfg, goal_idx=goal_idx)
if __name__ == '__main__':
    # Run the click command-line entry point only when executed as a script.
    main()