Skip to content

Commit

Permalink
Merge pull request #168 from purdue-arc/abuynits-jcrm1-sim-refactoring
Browse files Browse the repository at this point in the history
Abuynits/jcrm1 sim refactoring
  • Loading branch information
jcrm1 authored Aug 26, 2023
2 parents 7a72132 + 494a002 commit 7cf219a
Show file tree
Hide file tree
Showing 18 changed files with 1,228 additions and 334 deletions.
7 changes: 7 additions & 0 deletions rktl_autonomy/YAMLEditor/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Java
bin/
.classpath
.settings

# macOS metadata file
.DS_Store
17 changes: 17 additions & 0 deletions rktl_autonomy/YAMLEditor/.project
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<projectDescription>
<name>YAMLEditor</name>
<comment></comment>
<projects>
</projects>
<buildSpec>
<buildCommand>
<name>org.eclipse.jdt.core.javabuilder</name>
<arguments>
</arguments>
</buildCommand>
</buildSpec>
<natures>
<nature>org.eclipse.jdt.core.javanature</nature>
</natures>
</projectDescription>
7 changes: 7 additions & 0 deletions rktl_autonomy/YAMLEditor/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# YAMLEditor

## Unpack into a table, edit, and repack two YAML arrays

Intended for use with reward values

Built jar in `jars/`
Binary file added rktl_autonomy/YAMLEditor/jars/YAMLEditor.jar
Binary file not shown.
177 changes: 177 additions & 0 deletions rktl_autonomy/YAMLEditor/src/com/purduearc/yamleditor/Main.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package com.purduearc.yamleditor;

import java.awt.Rectangle;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.util.Arrays;
import java.util.regex.Pattern;

import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JScrollPane;
import javax.swing.JTable;
import javax.swing.JTextField;
import javax.swing.event.TableModelEvent;
import javax.swing.event.TableModelListener;
import javax.swing.table.DefaultTableModel;
/**
* Unpack into a table, edit, and repack two YAML arrays
*
* @author Campbell McClendon
*
* @version 1.0
*/
public class Main {
private static final Rectangle scrollPaneBounds = new Rectangle(2, 44, 496, 426);
private static final String[] titles = new String[] {"ID", "WIN", "LOSE"};
private static String[][] data = null;
public static void main(String[] args) {
JFrame frame = new JFrame("YAML Editor");
frame.setResizable(false);
frame.setSize(500, 500);

JTextField winInputField = new JTextField("[1, 2, 3, 4, 5]", 50);
winInputField.setBounds(2, 2, 100, 20);

JTextField loseInputField = new JTextField("[6, 7, 8, 9, 10]");
loseInputField.setBounds(104, 2, 100, 20);

JLabel inputLabel = new JLabel("Input");
inputLabel.setBounds(106, 24, 100, 20);

data = new String[][] {{"0", "1", "2"}, {"1", "3", "4"}};
DefaultTableModel tableModel = new DefaultTableModel(data, titles);
JTable table = new JTable(tableModel);
table.setFillsViewportHeight(true);
tableModel.addTableModelListener(new TableModelListener() {
@Override
public void tableChanged(TableModelEvent e) {
if (table.isEditing()) {
data[table.getSelectedRow()][table.getSelectedColumn()] = (String) table.getValueAt(table.getSelectedRow(), table.getSelectedColumn());
}
}
});

JLabel errorLabel = new JLabel("");
errorLabel.setBounds(225, 2, 100, 20);

JScrollPane scrollPane = new JScrollPane(table);
scrollPane.setBounds(scrollPaneBounds);

JTextField winOutputField = new JTextField(50);
winOutputField.setBounds(296, 2, 100, 20);
winOutputField.setEditable(false);

JTextField loseOutputField = new JTextField(50);
loseOutputField.setBounds(398, 2, 100, 20);
loseOutputField.setEditable(false);

JLabel outputLabel = new JLabel("Output");
outputLabel.setBounds(298, 24, 100, 20);

JButton unpackButton = new JButton("Unpack Arrays");
unpackButton.setBounds(2, 24, 100, 20);
unpackButton.addActionListener(new ActionListener() {
@Override
public void actionPerformed(ActionEvent e) {
try {
String col1Text = clip(winInputField.getText().replaceAll("\\s+",""));
String[] col1ByID = col1Text.split(Pattern.quote(","));
if (col1ByID.length < 2) {
System.err.println("Malformed win input data");
errorLabel.setText("ERROR");
return;
}
String col2Text = clip(loseInputField.getText().replaceAll("\\s+",""));
String[] col2ByID = col2Text.split(Pattern.quote(","));
if (col2ByID.length < 2) {
System.err.println("Malformed lose input data");
errorLabel.setText("ERROR");
return;
}
if (col1ByID.length != col2ByID.length) {
System.err.println("Win and lose arrays must be of same length");
errorLabel.setText("ERROR");
return;
}
String[][] newData = new String[col1ByID.length][2];
for (int i = 0; i < col1ByID.length; i++) {
newData[i][0] = col1ByID[i];
newData[i][1] = col2ByID[i];
}
String[][] dataWithID = new String[newData.length][newData[0].length + 1];
for (int id = 0; id < dataWithID.length; id++) {
dataWithID[id][0] = "" + id;
for (int i = 1; i < dataWithID[id].length; i++) {
dataWithID[id][i] = newData[id][i - 1];
}
}
data = dataWithID;
tableModel.setDataVector(dataWithID, titles);
errorLabel.setText("");
} catch (Exception exception) {
errorLabel.setText("ERROR");
exception.printStackTrace();
}
}
});

JButton repackButton = new JButton("Repack Arrays");
repackButton.setBounds(398, 24, 100, 20);
repackButton.addActionListener(new ActionListener() {
@Override
public void actionPerformed(ActionEvent e) {
try {
String[][] dataWithoutID = new String[data.length][];
for (int x = 0; x < data.length; x++) {
String[] line = new String[data[x].length - 1];
for (int y = 1; y < data[x].length; y++) {
line[y - 1] = data[x][y];
}
dataWithoutID[x] = line;
}
String[] winValues = new String[dataWithoutID.length];
String[] loseValues = new String[dataWithoutID.length];
for (int i = 0; i < dataWithoutID.length; i++) {
winValues[i] = dataWithoutID[i][0];
loseValues[i] = dataWithoutID[i][1];
}
winOutputField.setText(Arrays.deepToString(winValues).replaceAll("\\s+",""));
loseOutputField.setText(Arrays.deepToString(loseValues).replaceAll("\\s+",""));
} catch (Exception exception) {
errorLabel.setText("ERROR");
exception.printStackTrace();
}
}
});

frame.add(winInputField);
frame.add(loseInputField);
frame.add(inputLabel);
frame.add(unpackButton);
frame.add(errorLabel);
frame.add(winOutputField);
frame.add(loseOutputField);
frame.add(outputLabel);
frame.add(repackButton);
frame.add(scrollPane);

frame.setLayout(null);
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
frame.setVisible(true);
}
/**
* Clips the first and last characters off of a string
*
* @param str input string
* @return str without first and last characters
*/
private static String clip(String str) {
if (str.charAt(0) == (char) '[') {
if (str.charAt(str.length() - 1) == (char) ']') return str.substring(1, str.length() - 1);
else return str.substring(1, str.length());
} else if (str.charAt(str.length() - 1) == (char) ']') return str.substring(0, str.length() - 1);
else return str;
}
}
4 changes: 2 additions & 2 deletions rktl_autonomy/config/rocket_league.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ reward:
# reward given when car changes velocity direction
direction_change: 0
# reward given when the episode ends with the car scoring the proper goal
win: 100
win: [100, 200, 300]
# reward given when the episode ends with the car scoring on the wrong goal
loss: 100
loss: [100, 100, 100]
# reward given each frame when the car is in reverse
# reverse: -25
# # reward given each frame when the car is near the walls
Expand Down
25 changes: 25 additions & 0 deletions rktl_autonomy/scripts/modular_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/usr/bin/env python3

from train_rocket_league import train
import yaml
import sys
import os
from multiprocessing import Process

if __name__ == '__main__':
numEnvsAllowed = 24

if len(sys.argv) == 2:
numEnvsAllowed = int(sys.argv[1])

configFile = os.path.join(os.pardir, 'config', 'rocket_league.yaml')
print(os.path.abspath(configFile))

file = yaml.load(open(configFile), Loader=yaml.FullLoader)
numGroups = len(file["reward"]["win"])

for i in range(numGroups):
args = (int(numEnvsAllowed / numGroups), 100, 240000000, i)
p = Process(target=train, args=args)
p.start()
print(f'Starting thread {i}/{numGroups}')
21 changes: 12 additions & 9 deletions rktl_autonomy/scripts/train_rocket_league.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
from os.path import expanduser
import uuid

if __name__ == '__main__': # this is required due to forking processes

def train(n_envs=24, n_saves=100, n_steps=240000000, env_number=0):
run_id = str(uuid.uuid4()) # ALL running environments must share this
print(f"RUN ID: {run_id}")

# to pass launch args, add to env_kwargs: 'launch_args': ['render:=false', 'plot_log:=true']
env = make_vec_env(RocketLeagueInterface, env_kwargs={'run_id':run_id},
n_envs=24, vec_env_cls=SubprocVecEnv)
env = make_vec_env(RocketLeagueInterface, env_kwargs={'run_id': run_id}, wrapper_kwargs = {'env_number': env_number, 'run_id': run_id},
n_envs=n_envs, vec_env_cls=SubprocVecEnv)

model = PPO("MlpPolicy", env)

Expand All @@ -31,23 +31,26 @@
model.set_parameters(previous_weights)
except:
print('Failed to load from previous weights')

# log training progress as CSV
log_dir = expanduser(f'~/catkin_ws/data/rocket_league/{run_id}')
logger = configure(log_dir, ["stdout", "csv", "log"])
model.set_logger(logger)

# log model weights
freq = 20833 # save 20 times
# freq = steps / (n_saves * n_envs)
freq = n_steps / (n_saves * n_envs)
callback = CheckpointCallback(save_freq=freq, save_path=log_dir)

# run training
steps = 240000000 # 240M (10M sequential)
steps = n_steps
print(f"training on {steps} steps")
model.learn(total_timesteps=steps, callback=callback)

# save final weights
print("done training")
model.save(log_dir + "/final_weights")
env.close() # this must be done to clean up other processes
env.close() # this must be done to clean up other processes


if __name__ == '__main__':
train()
4 changes: 3 additions & 1 deletion rktl_autonomy/src/rktl_autonomy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
from .cartpole_direct_interface import CartpoleDirectInterface
from .snake_interface import SnakeInterface
from .rocket_league_interface import RocketLeagueInterface
from .env_counter import EnvCounter

__all__ = [
"ROSInterface",
"CartpoleInterface",
"CartpoleDirectInterface",
"SnakeInterface",
"RocketLeagueInterface"]
"RocketLeagueInterface",
"EnvCounter"]
9 changes: 9 additions & 0 deletions rktl_autonomy/src/rktl_autonomy/env_counter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class EnvCounter:
def __init__(self):
self.counter = 0

def count_env(self):
self.counter += 1

def get_current_counter(self):
return self.counter
26 changes: 20 additions & 6 deletions rktl_autonomy/src/rktl_autonomy/rocket_league_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@ class CarActions(IntEnum):

class RocketLeagueInterface(ROSInterface):
"""ROS interface for the Rocket League."""
def __init__(self, eval=False, launch_file=['rktl_autonomy', 'rocket_league_train.launch'], launch_args=[], run_id=None):
super().__init__(node_name='rocket_league_agent', eval=eval, launch_file=launch_file, launch_args=launch_args, run_id=run_id)

## Constants
def __init__(self, eval=False, launch_file=('rktl_autonomy', 'rocket_league_train.launch'), launch_args=[],
run_id=None, env_number=0):
super().__init__(node_name='rocket_league_agent', eval=eval, launch_file=launch_file, launch_args=launch_args,
run_id=run_id)
# Constants
self.env_number = env_number
# Actions
self._MIN_VELOCITY = -rospy.get_param('/cars/throttle/max_speed')
self._MAX_VELOCITY = rospy.get_param('/cars/throttle/max_speed')
Expand Down Expand Up @@ -81,8 +83,20 @@ def __init__(self, eval=False, launch_file=['rktl_autonomy', 'rocket_league_trai
self._BALL_DISTANCE_REWARD = rospy.get_param('~reward/ball_dist_sq', 0.0)
self._GOAL_DISTANCE_REWARD = rospy.get_param('~reward/goal_dist_sq', 0.0)
self._DIRECTION_CHANGE_REWARD = rospy.get_param('~reward/direction_change', 0.0)
self._WIN_REWARD = rospy.get_param('~reward/win', 100.0)
self._LOSS_REWARD = rospy.get_param('~reward/loss', 0.0)
if isinstance(rospy.get_param('~reward/win', [100.0]), int):
self._WIN_REWARD = rospy.get_param('~reward/win', [100.0])
else:
if len(rospy.get_param('~reward/win', [100.0])) >= self.env_number:
self._WIN_REWARD = rospy.get_param('~reward/win', [100.0])[0]
else:
self._WIN_REWARD = rospy.get_param('~reward/win', [100.0])[self.env_number]
if isinstance(rospy.get_param('~reward/loss', [100.0]), int):
self._LOSS_REWARD = rospy.get_param('~reward/loss', [100.0])
else:
if len(rospy.get_param('~reward/loss', [100.0])) >= self.env_number:
self._LOSS_REWARD = rospy.get_param('~reward/loss', [100.0])[0]
else:
self._LOSS_REWARD = rospy.get_param('~reward/loss', [100.0])[self.env_number]
self._REVERSE_REWARD = rospy.get_param('~reward/reverse', 0.0)
self._WALL_REWARD = rospy.get_param('~reward/walls/value', 0.0)
self._WALL_THRESHOLD = rospy.get_param('~reward/walls/threshold', 0.0)
Expand Down
8 changes: 8 additions & 0 deletions rktl_launch/config/global_params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
# field dimensions
field:
width: 3
# max: 3.5
# min: 2.5
length: 4.25
# max: 5
# min: 3.5
wall_thickness: 0.25
goal:
width: 1
Expand All @@ -15,8 +19,12 @@ ball:
# car dimensions and physical constants
cars:
length: 0.12 # front to rear wheel center to center, meters
# min: 0.10
# max: 0.30
steering:
max_throw: 0.1826 # center to side, radians
# min: 0.16
# max: 0.20
rate: 0.9128 # rad/s
throttle:
max_speed: 2.3 # m/s
Expand Down
Loading

0 comments on commit 7cf219a

Please sign in to comment.