
Model validation against pytorch? #1

Open · relf opened this issue Nov 13, 2024 · 9 comments

relf commented Nov 13, 2024

First, thank you for sharing your work. It will definitely help me better understand how to use the burn framework and how a pytorch model can be translated.

I would like to know whether you have validated your model against pytorch outputs and, if so, how it compares. Otherwise, have you obtained relevant results/performance using it?
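To be concrete, what I have in mind is something like running the same input through both implementations and checking that the outputs agree within some tolerance. A minimal sketch (the function name and tolerance are mine):

```rust
use burn::prelude::*;
use burn::tensor::ElementConversion;

// Hypothetical check: `torch_out` would be the pytorch output for the same
// input, loaded into a burn tensor; the 1e-4 tolerance is arbitrary.
fn outputs_match<B: Backend>(burn_out: Tensor<B, 4>, torch_out: Tensor<B, 4>) -> bool {
    let max_abs_diff: f32 = (burn_out - torch_out).abs().max().into_scalar().elem();
    max_abs_diff < 1e-4
}
```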

phenylshima (Owner) commented Nov 14, 2024

Thank you for your interest.

Actually, although I have observed that the efficientnet + unet++ architecture generates seemingly meaningful results (not uniformly gray or similar) on a small dataset, I have not measured its performance or tested it on a larger (real-life) dataset.

As for differences from pytorch: if you use pytorch as the backend and the model architecture is the same, I expect there will be no significant difference. However, although I implemented the models so that their architecture matches their pytorch counterparts, it is possible that there are bugs I am not aware of (I have not fully tested them).

Note: While testing the model on a small dataset, I found that efficientnet's batchnorm parameters are critical for the loss to drop, so I changed the default value. If you see issues on a larger dataset, this might be the cause.
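For illustration, overriding batchnorm hyperparameters in burn looks roughly like this; the concrete values below are the pytorch EfficientNet reference settings, shown as an example rather than exactly what this crate uses:

```rust
use burn::nn::{BatchNorm, BatchNormConfig};
use burn::prelude::*;

// Sketch: the pytorch EfficientNet reference uses eps = 1e-3 and
// momentum = 0.01 (pytorch convention) instead of the common defaults;
// the values here are illustrative, not necessarily this crate's.
fn efficientnet_batchnorm<B: Backend>(features: usize, device: &B::Device) -> BatchNorm<B, 2> {
    BatchNormConfig::new(features)
        .with_epsilon(1e-3)
        .with_momentum(0.01)
        .init(device)
}
```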

relf (Author) commented Nov 18, 2024

Thank you for your reply. I am currently trying to use your model, but I am stumbling on dimension errors. Would you share a small example of a working configuration? Thank you in advance for your help.

phenylshima (Owner) commented Nov 19, 2024

I have the code, but I need to clean up its nasty data loader before releasing it...
I plan to do that at some point, though.

In the meantime, I can provide some of the relevant parts.

  • The shape of the input to segmentation_models_burn::Model<B, EfficientNet<B>, UnetPlusPlusDecoder<B>> is Tensor<B, 4>.
  • The output needs to be passed through sigmoid before use.

Here are the relevant parts.

The model definition:

```rust
EfficientNetConfigPreset::new(EfficientNetKind::EfficientnetB0)
    .with_weight(Some(EfficientNetWeightKind::Normal)),
UnetPlusPlusDecoderConfig::new(vec![], vec![256, 128, 64, 32, 16]),
```
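For context, a minimal sketch of how these arguments plug into the ModelConfig wrapper shown below (the backend alias, device, and class count here are placeholders):

```rust
// Assumes the tch backend used later in this thread; `Model`/`ModelConfig`
// refer to the wrapper defined below.
type MyBackend = burn::backend::Autodiff<burn::backend::LibTorch>;

let device = Default::default();
let config = ModelConfig::new::<MyBackend>(
    EfficientNetConfigPreset::new(EfficientNetKind::EfficientnetB0)
        .with_weight(Some(EfficientNetWeightKind::Normal)),
    UnetPlusPlusDecoderConfig::new(vec![], vec![256, 128, 64, 32, 16]),
);
let model: Model<MyBackend> = config.init(1, &device); // e.g. 1 class for binary masks
```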
The model wrapper:

```rust
use burn::{
    module::AutodiffModule,
    prelude::*,
    tensor::{activation::sigmoid, backend::AutodiffBackend},
    train::{
        metric::{Adaptor, LossInput},
        TrainOutput, TrainStep, ValidStep,
    },
};
use segmentation_models_burn::{
    activation::Activation,
    decoder::unetplusplus::{UnetPlusPlusDecoder, UnetPlusPlusDecoderConfig},
    encoder::efficientnet::{EfficientNet, EfficientNetConfigPreset},
    segmentation_head::SegmentationHeadConfig,
    ModelInit,
};

use crate::{data::SegmentationBatch, loss::DiceBCELossConfig};

/// Thin wrapper around the library model so that burn's training traits
/// can be implemented on it.
#[derive(Debug, Module)]
pub struct Model<B: Backend> {
    inner: segmentation_models_burn::Model<B, EfficientNet<B>, UnetPlusPlusDecoder<B>>,
}

impl<B: Backend> Model<B> {
    /// Raw logits, shape NCHW.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        self.inner.forward(x)
    }

    /// Inference output: logits passed through sigmoid.
    pub fn forward_infer(&self, images: Tensor<B, 4>) -> Tensor<B, 4> {
        sigmoid(self.forward(images))
    }

    /// Forward pass plus combined Dice + BCE loss.
    pub fn forward_loss(
        &self,
        images: Tensor<B, 4>,
        targets: Tensor<B, 4, Int>,
    ) -> SegmentationOutput<B> {
        let output = self.forward(images);
        let loss = DiceBCELossConfig::new().init(&output.device()).forward(
            output.clone(),
            targets.clone(),
            1.0,
        );

        SegmentationOutput {
            loss,
            output,
            targets,
        }
    }
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ModelConfig {
    inner:
        segmentation_models_burn::ModelConfig<EfficientNetConfigPreset, UnetPlusPlusDecoderConfig>,
}

impl ModelConfig {
    pub fn new<B: Backend>(
        encoder: EfficientNetConfigPreset,
        decoder: UnetPlusPlusDecoderConfig,
    ) -> ModelConfig {
        ModelConfig {
            inner: segmentation_models_burn::ModelConfig::new::<B>(
                encoder,
                decoder,
                // Identity head: sigmoid is applied separately in forward_infer.
                SegmentationHeadConfig::new().with_activation(Activation::Identity),
            ),
        }
    }

    pub fn init<B: Backend>(&self, classes: usize, device: &B::Device) -> Model<B> {
        Model {
            inner: self.inner.init(classes, device),
        }
    }
}

impl<B: AutodiffBackend> TrainStep<SegmentationBatch<B>, SegmentationOutput<B>> for Model<B>
where
    Self: AutodiffModule<B>,
{
    fn step(&self, batch: SegmentationBatch<B>) -> TrainOutput<SegmentationOutput<B>> {
        let item = self.forward_loss(batch.images, batch.targets);

        TrainOutput::new(self, item.loss.backward(), item)
    }
}

impl<B: Backend> ValidStep<SegmentationBatch<B>, SegmentationOutput<B>> for Model<B> {
    fn step(&self, batch: SegmentationBatch<B>) -> SegmentationOutput<B> {
        self.forward_loss(batch.images, batch.targets)
    }
}

/// Carries the loss, raw output, and targets so metrics can be computed.
pub struct SegmentationOutput<B: Backend> {
    pub loss: Tensor<B, 1>,
    pub output: Tensor<B, 4>,
    pub targets: Tensor<B, 4, Int>,
}

impl<B: Backend> Adaptor<LossInput<B>> for SegmentationOutput<B> {
    fn adapt(&self) -> LossInput<B> {
        LossInput::new(self.loss.clone())
    }
}
```
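For completeness, here is roughly how such a wrapper can be driven with burn's Learner. This is a sketch, not my actual training code: the artifact directory, optimizer, epoch count, and learning rate are placeholders, and `train_loader`/`valid_loader` are assumed to yield SegmentationBatch items.

```rust
use burn::optim::AdamConfig;
use burn::train::{metric::LossMetric, LearnerBuilder};

// Sketch of a training entry point; `model`, `device`, `train_loader`, and
// `valid_loader` come from the surrounding (unreleased) training code.
let learner = LearnerBuilder::new("/tmp/segmentation-artifacts")
    .metric_train_numeric(LossMetric::new())
    .metric_valid_numeric(LossMetric::new())
    .devices(vec![device.clone()])
    .num_epochs(10)
    .build(model, AdamConfig::new().init(), 1e-4);

// `fit` runs the TrainStep/ValidStep implementations above.
let _trained = learner.fit(train_loader, valid_loader);
```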

Also, could you show me the error? I might be able to find the part related to it.

relf (Author) commented Nov 19, 2024

Thanks for this useful information. With the burn tch-cpu backend, I get the following error:

```text
thread 'main' panicked at /stck/rlafage/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tch-0.15.0/src/wrappers/tensor_generated.rs:6683:72:
called `Result::unwrap()` on an `Err` value: Torch("Given groups=1, weight of size [384, 1024, 3, 3], expected input[1, 544, 64, 64] to have 1024 channels, but got 544 channels instead
Exception raised from check_shape_forward at ../aten/src/ATen/native/Convolution.cpp:682 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x14a45d0ccd87 in /tmp_user/juno/rlafage/opt/miniconda3/envs/burn/lib/python3.10/site-packages/torch/lib/libc10.so)
...
```

I guess I have to preprocess my images (currently 1024x1024) and resize them to 600x600, which seems to be the expected size for EfficientnetB7, the EfficientNet kind I wanted to use (btw, smaller images will also improve training time).

Can I specify an image size? I noticed an image_size field in EfficientNetGlobalConfig, which seems to be set automatically to 600x600 when using B7, but there is no way to set it to a given value directly, right?
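For the resizing step, I am thinking of something like this, using the image crate (the path and filter choice are placeholders):

```rust
use image::{imageops::FilterType, DynamicImage, ImageResult};

// Sketch: resize a 1024x1024 input down to the 600x600 expected by B7
// before converting it to a Tensor<B, 4>.
fn resize_for_b7(path: &str) -> ImageResult<DynamicImage> {
    Ok(image::open(path)?.resize_exact(600, 600, FilterType::Triangle))
}
```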

phenylshima (Owner) commented

I used a 512x512 image size for EfficientNetB0... (I didn't know it had a constraint on the input image size...)

Anyway, it seems that the image size has to be 600x600 for B7; it is one of B7's constants.

I'm not sure if it solves the problem, but can you try using a 600x600 input?
If that does not work, there might be some mistakes in the round_filters and round_repeats implementations.
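For reference, the scaling rules in the official EfficientNet implementation look like this; a sketch to check against, not necessarily this crate's exact code:

```rust
// Width scaling: round `filters * width_mult` to the nearest multiple of
// `divisor` (usually 8), never dropping more than 10% below the target.
fn round_filters(filters: usize, width_mult: f64, divisor: usize) -> usize {
    let scaled = filters as f64 * width_mult;
    let mut rounded = ((scaled + divisor as f64 / 2.0) as usize / divisor) * divisor;
    rounded = rounded.max(divisor);
    if (rounded as f64) < 0.9 * scaled {
        rounded += divisor; // avoid rounding down by more than 10%
    }
    rounded
}

// Depth scaling: the repeat count is simply rounded up.
fn round_repeats(repeats: usize, depth_mult: f64) -> usize {
    (repeats as f64 * depth_mult).ceil() as usize
}
```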

relf (Author) commented Nov 20, 2024

I resized my images to 600x600 and, using EfficientNetB7, I get the following:

```text
> cargo run --release --features tch-cpu
[...]
thread 'main' panicked at /stck/rlafage/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tch-0.15.0/src/wrappers/tensor_generated.rs:6683:72:
called `Result::unwrap()` on an `Err` value: Torch("Given groups=1, weight of size [256, 1024, 3, 3], expected input[1, 544, 38, 38] to have 1024 channels, but got 544 channels instead
Exception raised from check_shape_forward at ../aten/src/ATen/native/Convolution.cpp:682 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x14597eaced87 in /tmp_user/juno/rlafage/opt/miniconda3/envs/burn/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x14597ea7f75f in /tmp_user/juno/rlafage/opt/miniconda3/envs/burn/lib/python3.10/site-packages/torch/lib/libc10.so)
[...]
frame #31: <unknown function> + 0x22e985 (0x55c5c2fc9985 in target/release/spot)
frame #32: __libc_start_main + 0xe5 (0x14597e3cfd85 in /lib64/libc.so.6)
frame #33: <unknown function> + 0xc278e (0x55c5c2e5d78e in target/release/spot)
")
stack backtrace:
   0: rust_begin_unwind
   1: core::panicking::panic_fmt
   2: core::result::unwrap_failed
   3: burn_tch::ops::module::<impl burn_tensor::tensor::ops::modules::base::ModuleOps<burn_tch::backend::LibTorch<E,Q>> for burn_tch::backend::LibTorch<E,Q>>::conv2d
   4: burn_tensor::tensor::module::conv2d
   5: burn_core::nn::conv::conv2d::Conv2d<B>::forward
   6: segmentation_models_burn::decoder::unetplusplus::conv2drelu::Conv2dReLU<B>::forward
   7: segmentation_models_burn::decoder::unetplusplus::decoder::DecoderBlock<B>::forward
   8: spot::training::run
   9: spot::main
```

relf (Author) commented Nov 20, 2024

I tried with a 512x512 image and B0, and it passes!
With 512x512 and B7, it fails in the decoder's Conv2dReLU as above, with just a slight difference:
```text
Given groups=1, weight of size [256, 1024, 3, 3], expected input[1, 544, 32, 32] to have 1024 channels, but got 544 channels instead
```

Any idea?

relf (Author) commented Nov 25, 2024

B0 passes with 1024x1024 images, though the loss does not drop (I will try to tweak the parameters as you suggest). Have you tested the other EfficientNet configs (B1, ..., B7) with your small dataset? I still get errors like the above with those other configurations.

phenylshima (Owner) commented Dec 1, 2024

@relf Sorry for the late reply. I think I fixed it in #3, so can you try running the code from the fix-efficientnets branch?

(Unrelated, but since I changed my username, the git remote URL may also need to be changed.)
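Switching an existing clone over would look something like this (the new URL is a placeholder):

```sh
git remote set-url origin https://github.com/<new-username>/<repo>.git
git fetch origin
git checkout fix-efficientnets
```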
