Skip to content

Commit

Permalink
resnet_fpn_8_1
Browse files Browse the repository at this point in the history
  • Loading branch information
DaliCHEBBI committed Nov 26, 2024
1 parent 6e44d4d commit 4104afc
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 16 deletions.
12 changes: 12 additions & 0 deletions configs/model/resnet_fpn_mlp_model.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
defaults:
- default.yaml
- override criterion: BinaryEntropyLoss.yaml

lr: 0.001

neural_net_class_name: "ResNetFPN_8_1"
neural_net_hparams:
INITIAL_DIM: 128
BLOC_DIMS: [128,196,256]

mode: "feature+decision"
11 changes: 11 additions & 0 deletions configs/model/resnet_fpn_model.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
defaults:
- default.yaml

lr: 0.001

neural_net_class_name: "ResNetFPN_8_1"
neural_net_hparams:
INITIAL_DIM: 128
BLOC_DIMS: [128,196,256]

mode: "feature"
15 changes: 3 additions & 12 deletions simlearner3d/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from simlearner3d.models.generic_model import Model

from simlearner3d.models.modules.msaff import MSNet,MSNETInferenceGatedAttention
from simlearner3d.models.modules.resnet_fpn import ResNetFPN_8_1,ResNetFPN_8_1_Inference
from simlearner3d.models.modules.unet import UNet,UNetInference
from simlearner3d.models.modules.unetgatedattention import UNetGatedAttention, UNetInferenceGatedAttention
from simlearner3d.models.modules.decision_net import DecisionNetworkOnCube
Expand All @@ -20,8 +21,8 @@

NEURAL_NET_ARCHITECTURE_CONFIG_GROUP = "neural_net"

MODEL_ZOO = [MSNet,UNet,UNetGatedAttention]
MODEL_INFERENCE_ZOO=[MSNETInferenceGatedAttention,UNetInference,UNetInferenceGatedAttention]
MODEL_ZOO = [ResNetFPN_8_1,MSNet,UNet,UNetGatedAttention]
MODEL_INFERENCE_ZOO=[ResNetFPN_8_1_Inference,MSNETInferenceGatedAttention,UNetInference,UNetInferenceGatedAttention]


def get_inference_neural_net_class(class_training: nn.Module) -> nn.Module:
Expand Down Expand Up @@ -90,13 +91,3 @@ def extract(config: DictConfig):

print("Model Decision is saved as : ", out_decision_inference)











3 changes: 2 additions & 1 deletion simlearner3d/models/generic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pytorch_lightning import LightningModule
from torch import nn

from simlearner3d.models.modules.resnet_fpn import ResNetFPN_8_1, ResNetFPN_16_4
from simlearner3d.models.modules.msaff import MSNet
from simlearner3d.models.modules.unet import UNet
from simlearner3d.models.modules.unetgatedattention import UNetGatedAttention
Expand All @@ -11,7 +12,7 @@

log = utils.get_logger(__name__)

MODEL_ZOO = [MSNet,UNet,UNetGatedAttention]
MODEL_ZOO = [ResNetFPN_8_1,MSNet,UNet,UNetGatedAttention]


def get_neural_net_class(class_name: str) -> nn.Module:
Expand Down
6 changes: 3 additions & 3 deletions simlearner3d/qualify.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from simlearner3d.models.generic_model import Model

from simlearner3d.models.modules.msaff import MSNet,MSNETInferenceGatedAttention
from simlearner3d.models.modules.resnet_fpn import ResNetFPN_8_1,ResNetFPN_8_1_Inference
from simlearner3d.models.modules.unet import UNet,UNetInference
from simlearner3d.models.modules.unetgatedattention import UNetGatedAttention, UNetInferenceGatedAttention
from simlearner3d.models.modules.decision_net import DecisionNetworkOnCube
Expand All @@ -29,8 +30,8 @@

NEURAL_NET_ARCHITECTURE_CONFIG_GROUP = "neural_net"

MODEL_ZOO = [MSNet,UNet,UNetGatedAttention]
MODEL_INFERENCE_ZOO=[MSNETInferenceGatedAttention,UNetInference,UNetInferenceGatedAttention]
MODEL_ZOO = [ResNetFPN_8_1,MSNet,UNet,UNetGatedAttention]
MODEL_INFERENCE_ZOO=[ResNetFPN_8_1_Inference,MSNETInferenceGatedAttention,UNetInference,UNetInferenceGatedAttention]

DEFAULT_MODE="feature"

Expand Down Expand Up @@ -97,7 +98,6 @@ def PlotJointDistribution(Simsplus,
for j in range(200):
for i in range(j+1):
if (values[j,i]!='--'):
print(values)
SUM_GOOD+=values[j,i]
pourcent=str("%.2f" % SUM_GOOD)+" %"

Expand Down

0 comments on commit 4104afc

Please sign in to comment.