Skip to content

Commit

Permalink
Develop torch 1.13.1 (#5982)
Browse files Browse the repository at this point in the history
* Bumped PyTorch version to 1.13.1

* Added potential fixes to model overrider TBD at a later date.

* Updated changelog.
  • Loading branch information
miguelalonsojr authored Oct 5, 2023
1 parent 42ab41f commit 6427cdf
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jobs:
python-version: [3.10.x]
include:
- python-version: 3.10.x
pip_constraints: test_constraints_max_version.txt
pip_constraints: test_constraints_version.txt
steps:
- uses: actions/checkout@v2
- name: Set up Python
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jobs:
python-version: [3.10.12]
include:
- python-version: 3.10.12
pip_constraints: test_constraints_max_version.txt
pip_constraints: test_constraints_version.txt
steps:
- uses: actions/checkout@v2
- name: Set up Python
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
using UnityEngine;
using Unity.Sentis;
using System.IO;
using Unity.Sentis.ONNX;
using Unity.MLAgents;
using Unity.MLAgents.Policies;
#if UNITY_EDITOR
Expand Down Expand Up @@ -47,7 +46,6 @@ public class ModelOverrider : MonoBehaviour
// Cached loaded ModelAssets, with the behavior name as the key.
Dictionary<string, ModelAsset> m_CachedModels = new Dictionary<string, ModelAsset>();


// Max episodes to run. Only used if > 0
// Will default to 1 if override models are specified, otherwise 0.
int m_MaxEpisodes;
Expand Down Expand Up @@ -120,6 +118,7 @@ void GetAssetPathFromCommandLine()
{
return;
}

var maxEpisodes = 0;
var timeoutSeconds = 0;

Expand Down Expand Up @@ -148,6 +147,7 @@ void GetAssetPathFromCommandLine()
EditorApplication.isPlaying = false;
#endif
}

m_OverrideExtensions.Add(overrideExtension);
}
else if (args[i] == k_CommandLineQuitAfterEpisodesFlag && i < args.Length - 1)
Expand Down Expand Up @@ -276,11 +276,23 @@ public ModelAsset GetModelForBehaviorName(string behaviorName)
if (rawModel == null)
{
Debug.Log($"Couldn't load model file(s) for {behaviorName} in {m_BehaviorNameOverrideDirectory} (full path: {Path.GetFullPath(m_BehaviorNameOverrideDirectory)}");

// Cache the null so we don't repeatedly try to load a missing file
m_CachedModels[behaviorName] = null;
return null;
}

// TODO enable this when we have a decision on supporting loading/converting an ONNX model directly into a ModelAsset
// ModelAsset asset;
// if (isOnnx)
// {
// var modelName = Path.Combine(m_BehaviorNameOverrideDirectory, $"{behaviorName}.onnx");
// asset = LoadOnnxModel(modelName);
// }
// else
// {
// asset = LoadSentisModel(rawModel);
// }
// var asset = isOnnx ? LoadOnnxModel(rawModel) : LoadSentisModel(rawModel);
var asset = LoadSentisModel(rawModel);
asset.name = assetName;
Expand All @@ -296,6 +308,41 @@ ModelAsset LoadSentisModel(byte[] rawModel)
return asset;
}

// TODO enable this when we have a decision on supporting loading/converting an ONNX model directly into a ModelAsset
// ModelAsset LoadOnnxModel(string modelName)
// {
// Debug.Log($"Loading model for override: {modelName}");
// var converter = new ONNXModelConverter(true);
// var directoryName = Path.GetDirectoryName(modelName);
// var model = converter.Convert(modelName, directoryName);
// var asset = ScriptableObject.CreateInstance<ModelAsset>();
// var assetData = ScriptableObject.CreateInstance<ModelAssetData>();
// var descStream = new MemoryStream();
// ModelWriter.SaveModelDesc(descStream, model);
// assetData.value = descStream.ToArray();
// assetData.name = "Data";
// assetData.hideFlags = HideFlags.HideInHierarchy;
// descStream.Close();
// descStream.Dispose();
// asset.modelAssetData = assetData;
// var weightStreams = new List<MemoryStream>();
// ModelWriter.SaveModelWeights(weightStreams, model);
//
// asset.modelWeightsChunks = new ModelAssetWeightsData[weightStreams.Count];
// for (int i = 0; i < weightStreams.Count; i++)
// {
// var stream = weightStreams[i];
// asset.modelWeightsChunks[i] = ScriptableObject.CreateInstance<ModelAssetWeightsData>();
// asset.modelWeightsChunks[i].value = stream.ToArray();
// asset.modelWeightsChunks[i].name = "Data";
// asset.modelWeightsChunks[i].hideFlags = HideFlags.HideInHierarchy;
// stream.Close();
// stream.Dispose();
// }
//
// return asset;
// }

// TODO this should probably be deprecated since Sentis does not support direct conversion from byte arrays
// ModelAsset LoadOnnxModel(byte[] rawModel)
// {
Expand All @@ -317,7 +364,6 @@ ModelAsset LoadSentisModel(byte[] rawModel)
// return asset;
// }


/// <summary>
/// Load the ModelAsset file from the specified path, and give it to the attached agent.
/// </summary>
Expand Down Expand Up @@ -369,12 +415,12 @@ void OverrideModel()
{
Debug.LogWarning(overrideError);
}

Application.Quit(1);
#if UNITY_EDITOR
EditorApplication.isPlaying = false;
#endif
}

}
}
}
1 change: 1 addition & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to
## [Unreleased]
### Major Changes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
- Updated to PyTorch 1.13.1
- Deprecated support for Python 3.8.x and 3.9.x
- Upgraded ML-Agents to Sentis 1.2.0-exp.2 (#)
- The minimum supported Unity version was updated to 2022.3. (#)
Expand Down
3 changes: 2 additions & 1 deletion ml-agents/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def run(self):
"Pillow>=4.2.1",
"protobuf>=3.6,<3.20",
"pyyaml>=3.1.0",
"torch>=1.8.0,<=1.11.0",
"torch>=1.13.1",
"tensorboard>=2.14",
# adding six explicit dependency since tensorboard needs it but doesn't declare it as a dep
"six>=1.16",
Expand All @@ -72,6 +72,7 @@ def run(self):
"attrs>=19.3.0",
"huggingface_hub>=0.14",
'pypiwin32==223;platform_system=="Windows"',
"onnx==1.12.0",
],
python_requires=">=3.10.1,<=3.10.12",
entry_points={
Expand Down
2 changes: 0 additions & 2 deletions test_constraints_mid_version.txt

This file was deleted.

2 changes: 0 additions & 2 deletions test_constraints_min_version.txt

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# pip constraints to use the *highest* versions allowed in ml-agents/setup.py
# For projects with upper bounds, we should periodically update this list to the latest
torch==1.11.0
torch==1.13.1

0 comments on commit 6427cdf

Please sign in to comment.