Saving the model in TensorFlow SavedModel format does not work #15

Open
BernradMaillard opened this issue Oct 5, 2024 · 4 comments

@BernradMaillard

BernradMaillard commented Oct 5, 2024

  --training_record ./input/TFRecords/Cameroun/train/ \
  --valid_records ./input/TFRecords/Cameroun/valid/ \
  --logdir ./LOGDIR \
  --model meraner_unet \
  -lr 0.0007 \
  -bt 1 \
  -bv 1 \
  -e 2 \
  --ckpt_dir ./input/meraner/ckpt/ \
  --out_savedmodel ./data/meranercamer/model/
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
2024-10-05 07:46:32 WARNING  There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
2024-10-05 07:46:32 INFO     Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
2024-10-05 07:46:32 INFO     No GPU found, using CPU
2024-10-05 07:46:32 INFO     Learning rate was scaled to 0.0007, effective batch size is 1 (1 workers)
2024-10-05 07:46:32 INFO     Searching TFRecords in ./input/TFRecords/Cameroun/train//*.records...
2024-10-05 07:46:32 INFO     Number of matching TFRecords: 25
2024-10-05 07:46:32 INFO     Reducing number of records to : 25
2024-10-05 07:46:33 INFO     Searching TFRecords in ./input/TFRecords/Cameroun/valid//*.records...
2024-10-05 07:46:33 INFO     Number of matching TFRecords: 1
2024-10-05 07:46:33 INFO     Reducing number of records to : 1
2024-10-05 07:46:33 INFO     Loading model "meraner_unet"
Model: "meraner_unet"
______________________________________________________________________________________________________________________________________________________
 Layer (type)                                    Output Shape                     Param #           Connected to
======================================================================================================================================================
 s1_t (InputLayer)                               [(None, None, None, 2)]          0                 []

 s2_t (InputLayer)                               [(None, None, None, 4)]          0                 []

 tf.cast (TFOpLambda)                            (None, None, None, 2)            0                 ['s1_t[0][0]']

 tf.cast_1 (TFOpLambda)                          (None, None, None, 4)            0                 ['s2_t[0][0]']

 tf.math.multiply (TFOpLambda)                   (None, None, None, 2)            0                 ['tf.cast[0][0]']

 tf.math.multiply_1 (TFOpLambda)                 (None, None, None, 4)            0                 ['tf.cast_1[0][0]']

 dem (InputLayer)                                [(None, None, None, 1)]          0                 []

 tf.concat (TFOpLambda)                          (None, None, None, 6)            0                 ['tf.math.multiply[0][0]',
                                                                                                     'tf.math.multiply_1[0][0]']

 tf.cast_2 (TFOpLambda)                          (None, None, None, 1)            0                 ['dem[0][0]']

 conv1_relu (Conv2D)                             (None, None, None, 64)           9664              ['tf.concat[0][0]']

 tf.math.multiply_2 (TFOpLambda)                 (None, None, None, 1)            0                 ['tf.cast_2[0][0]']

 conv2_bn_relu (Conv2D)                          (None, None, None, 128)          73856             ['conv1_relu[0][0]']

 conv1_dem_relu (Conv2D)                         (None, None, None, 64)           640               ['tf.math.multiply_2[0][0]']

 tf.concat_1 (TFOpLambda)                        (None, None, None, 192)          0                 ['conv2_bn_relu[0][0]',
                                                                                                     'conv1_dem_relu[0][0]']

 conv3_bn_relu (Conv2D)                          (None, None, None, 256)          442624            ['tf.concat_1[0][0]']

 conv4_bn_relu (Conv2D)                          (None, None, None, 512)          1180160           ['conv3_bn_relu[0][0]']

 conv5_bn_relu (Conv2D)                          (None, None, None, 512)          2359808           ['conv4_bn_relu[0][0]']

 conv6_bn_relu (Conv2D)                          (None, None, None, 512)          2359808           ['conv5_bn_relu[0][0]']

 deconv1_bn_relu (Conv2DTranspose)               (None, None, None, 512)          2359808           ['conv6_bn_relu[0][0]']

 tf.concat_2 (TFOpLambda)                        (None, None, None, 1024)         0                 ['conv5_bn_relu[0][0]',
                                                                                                     'deconv1_bn_relu[0][0]']

 deconv2_bn_relu (Conv2DTranspose)               (None, None, None, 512)          4719104           ['tf.concat_2[0][0]']

 tf.concat_3 (TFOpLambda)                        (None, None, None, 1024)         0                 ['conv4_bn_relu[0][0]',
                                                                                                     'deconv2_bn_relu[0][0]']

 deconv3_bn_relu (Conv2DTranspose)               (None, None, None, 256)          2359552           ['tf.concat_3[0][0]']

 tf.concat_4 (TFOpLambda)                        (None, None, None, 512)          0                 ['conv3_bn_relu[0][0]',
                                                                                                     'deconv3_bn_relu[0][0]']

 deconv4_bn_relu (Conv2DTranspose)               (None, None, None, 128)          589952            ['tf.concat_4[0][0]']

 tf.concat_5 (TFOpLambda)                        (None, None, None, 320)          0                 ['tf.concat_1[0][0]',
                                                                                                     'deconv4_bn_relu[0][0]']

 deconv5_bn_relu (Conv2DTranspose)               (None, None, None, 64)           184384            ['tf.concat_5[0][0]']

 tf.concat_6 (TFOpLambda)                        (None, None, None, 128)          0                 ['conv1_relu[0][0]',
                                                                                                     'deconv5_bn_relu[0][0]']

 s2_estim (Conv2D)                               (None, None, None, 4)            12804             ['tf.concat_6[0][0]']

 add (Add)                                       (None, None, None, 4)            0                 ['s2_estim[0][0]',
                                                                                                     'tf.math.multiply_1[0][0]']

 tf.math.multiply_5 (TFOpLambda)                 (None, None, None, 4)            0                 ['add[0][0]']

 tf.math.multiply_6 (TFOpLambda)                 (None, None, None, 4)            0                 ['add[0][0]']

 tf.math.multiply_3 (TFOpLambda)                 (None, None, None, 4)            0                 ['add[0][0]']

 tf.math.multiply_4 (TFOpLambda)                 (None, None, None, 4)            0                 ['add[0][0]']

 add_pad128 (Cropping2D)                         (None, None, None, 4)            0                 ['tf.math.multiply_5[0][0]']

 add_pad256 (Cropping2D)                         (None, None, None, 4)            0                 ['tf.math.multiply_6[0][0]']

 add_pad32 (Cropping2D)                          (None, None, None, 4)            0                 ['tf.math.multiply_3[0][0]']

 add_pad64 (Cropping2D)                          (None, None, None, 4)            0                 ['tf.math.multiply_4[0][0]']
======================================================================================================================================================
Total params: 16,652,164
Trainable params: 16,652,164
Non-trainable params: 0
______________________________________________________________________________________________________________________________________________________
Epoch 1/2
/opt/otbtf/lib/python3/dist-packages/keras/engine/functional.py:639: UserWarning: Input dict contained keys ['s1_ascending_t', 's1_timestamp_t', 's2_timestamp_t', 's2_timestamp_target'] which did not match any model input. They will be ignored by the model.
  inputs = self._flatten_to_reference_inputs(inputs)
2500/2500 - 2743s - loss: 0.0240 - add_loss: 0.0240 - add_pad128_loss: 0.0000e+00 - add_pad256_loss: 0.0000e+00 - add_pad32_loss: 0.0000e+00 - add_pad64_loss: 0.0000e+00 - add_PSNR: 27.8203 - val_loss: 0.0247 - val_add_loss: 0.0247 - val_add_pad128_loss: 0.0000e+00 - val_add_pad256_loss: 0.0000e+00 - val_add_pad32_loss: 0.0000e+00 - val_add_pad64_loss: 0.0000e+00 - val_add_PSNR: 27.3260 - 2743s/epoch - 1s/step
Epoch 2/2
2500/2500 - 2738s - loss: 0.0220 - add_loss: 0.0220 - add_pad128_loss: 0.0000e+00 - add_pad256_loss: 0.0000e+00 - add_pad32_loss: 0.0000e+00 - add_pad64_loss: 0.0000e+00 - add_PSNR: 28.8060 - val_loss: 0.0184 - val_add_loss: 0.0184 - val_add_pad128_loss: 0.0000e+00 - val_add_pad256_loss: 0.0000e+00 - val_add_pad32_loss: 0.0000e+00 - val_add_pad64_loss: 0.0000e+00 - val_add_PSNR: 30.6968 - 2738s/epoch - 1s/step
2024-10-05 09:17:57 INFO     Saving SavedModel in ./data/meranercamer/model/meraner_unet_bt1_bv1_lr0.0007_e2_cpu05-10-24-074632
2024-10-05 09:17:59 WARNING  Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 5 of 14). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: ./data/meranercamer/model/meraner_unet_bt1_bv1_lr0.0007_e2_cpu05-10-24-074632/assets
2024-10-05 09:18:02 INFO     Assets written to: ./data/meranercamer/model/meraner_unet_bt1_bv1_lr0.0007_e2_cpu05-10-24-074632/assets
otbuser@6dc23c62c23f:/home/data/decloud$ meraner_processor \
  --il_s1 /home/data/decloud/input/S1_PREPARE/T33PUK/s1a_33PUK_vvvh_ASC_161_20231029txxxxxx_from-10to3dB.tif \
  --in_s2 /home/data/decloud/input/S2_PREPARE/T33PUK/SENTINEL2A_20231104-093818-876_L2A_T33PUK_C_V3-1 \
  --in_dem /home/data/decloud/input/DEM_PREPARE/T33PUK.tif \
  --output lagdo.tif \
  --savedmodel /home/data/decloud/data/meranercamer/model/meraner_unet_bt1_bv1_lr0.0007_e2_cpu05-10-24-074632/ \
  --pad 256 \
  --ts 1024
2024-10-05 10:26:13 (INFO) [pyOTB] Successfully loaded 126 OTB applications
2024-10-05 10:26:13 INFO     Init. S2_TILED product
2024-10-05 10:26:13 INFO     Init. S1_TILED product
2024-10-05 10:26:13 INFO     10m spacing bands: /home/data/decloud/input/S1_PREPARE/T33PUK/s1a_33PUK_vvvh_ASC_161_20231029txxxxxx_from-10to3dB.tif
2024-10-05 10:26:13 INFO     Date: 2023-10-29 11:11:11
2024-10-05 10:26:14 (INFO): Loading metadata from official product
2024-10-05 10:26:14 (INFO) Mosaic: Temporary files prefix is:
2024-10-05 10:26:14 (INFO) Mosaic: No feathering
2024-10-05 10:26:14 INFO     Setup inference pipeline
2024-10-05 10:26:14 INFO     Input sources: {'s1_t': <pyotb.apps.Mosaic object at 0x7fbb9e3ba4d0>, 's2_t': '/home/data/decloud/input/S2_PREPARE/T33PUK/SENTINEL2A_20231104-093818-876_L2A_T33PUK_C_V3-1/SENTINEL2A_20231104-093818-876_L2A_T33PUK_C_V3-1_FRE_10m.tif', 'dem': '/home/data/decloud/input/DEM_PREPARE/T33PUK.tif'}
2024-10-05 10:26:14 INFO     Input pad: 256
2024-10-05 10:26:14 INFO     Input tile size: 1024
2024-10-05 10:26:14 INFO     SavedModel directory: /home/data/decloud/data/meranercamer/model/meraner_unet_bt1_bv1_lr0.0007_e2_cpu05-10-24-074632/
2024-10-05 10:26:14 INFO     Output tensor name: s2_estim
2024-10-05 10:26:14 INFO     Receptive field: 1536, Expression field: 1024
2024-10-05 10:26:14 INFO     Preparing source 1 for placeholder s1_t
2024-10-05 10:26:14 INFO     Preparing source 2 for placeholder s2_t
2024-10-05 10:26:14 INFO     Preparing source 3 for placeholder dem
2024-10-05 10:26:15 (INFO) TensorflowModelServe: Source info :
2024-10-05 10:26:15 (INFO) TensorflowModelServe: Receptive field  : [1536, 1536]
2024-10-05 10:26:15 (INFO) TensorflowModelServe: Placeholder name : s1_t
2024-10-05 10:26:15 (INFO): Loading metadata from official product
2024-10-05 10:26:15 (INFO) TensorflowModelServe: Source info :
2024-10-05 10:26:15 (INFO) TensorflowModelServe: Receptive field  : [1536, 1536]
2024-10-05 10:26:15 (INFO) TensorflowModelServe: Placeholder name : s2_t
2024-10-05 10:26:15 (INFO): Loading metadata from official product
2024-10-05 10:26:15 (INFO) TensorflowModelServe: Source info :
2024-10-05 10:26:15 (INFO) TensorflowModelServe: Receptive field  : [768, 768]
2024-10-05 10:26:15 (INFO) TensorflowModelServe: Placeholder name : dem
2024-10-05 10:26:15 (INFO) TensorflowModelServe: Output spacing ratio: 1
2024-10-05 10:26:15 (INFO) TensorflowModelServe: The TensorFlow model is used in fully convolutional mode
2024-10-05 10:26:15 (INFO) TensorflowModelServe: Setting background value to 0
2024-10-05 10:26:15 (INFO) TensorflowModelServe: Output field of expression: [1024, 1024]
2024-10-05 10:26:15 (INFO) TensorflowModelServe: Tiling disabled
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/pyotb/core.py", line 643, in execute
    self.app.Execute()
  File "/opt/otbtf/lib/otb/python/otbApplication.py", line 2445, in Execute
    return _otbApplication.Application_Execute(self)
RuntimeError: Exception thrown in otbApplication Application_Execute: /src/otb/otb/Modules/Remote/otbtf/include/otbTensorflowGraphOperations.cxx:178:
itk::ERROR: Tensor name "s2_estim_pad256" not found.
You can list all inputs/outputs of your SavedModel by running:
         `saved_model_cli show --dir your_model_dir --all`

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/bin/meraner_processor", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/decloud/production/meraner_processor.py", line 177, in main
    meraner_processor(params.il_s1, params.in_s2, params.savedmodel, params.in_dem,
  File "/usr/local/lib/python3.10/dist-packages/decloud/production/meraner_processor.py", line 121, in meraner_processor
    processor = inference(sources, sources_scales, pad=pad, ts=ts,
  File "/usr/local/lib/python3.10/dist-packages/decloud/production/inference.py", line 94, in inference
    infer = pyotb.TensorflowModelServe(parameters)
  File "/usr/local/lib/python3.10/dist-packages/pyotb/apps.py", line 122, in __init__
    super().__init__('TensorflowModelServe', *args, n_sources=n_sources, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pyotb/apps.py", line 111, in __init__
    super().__init__(app_name, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pyotb/core.py", line 594, in __init__
    self.execute()
  File "/usr/local/lib/python3.10/dist-packages/pyotb/core.py", line 645, in execute
    raise Exception(f'{self.name}: error during during app execution') from e
Exception: TensorflowModelServe: error during during app execution
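The field sizes reported in the log follow from the two options passed to meraner_processor: with `--ts 1024` and `--pad 256`, the 1024-pixel expression field plus 256 pixels of margin on each side gives 1024 + 2 × 256 = 1536, the receptive field shown for `s1_t` and `s2_t`. The `dem` placeholder reports half of that (768), which is consistent with the DEM being sampled twice as coarsely as the 10 m sources; that factor of 2 is an assumption inferred from the log, not something stated in this thread. A minimal sketch of the arithmetic, using a hypothetical helper name:

```python
def receptive_field(tile_size: int, pad: int, scale: int = 1) -> int:
    """Input patch size (in pixels of the source) needed to produce a
    tile_size x tile_size output with `pad` extra pixels on each side.

    `scale` is how many times coarser the source is than the output grid;
    scale=2 for the DEM is an assumption inferred from the log.
    """
    return (tile_size + 2 * pad) // scale


print(receptive_field(1024, 256))           # s1_t and s2_t: 1536
print(receptive_field(1024, 256, scale=2))  # dem: 768
```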
@remicres
Collaborator

remicres commented Oct 5, 2024

Hi @BernradMaillard,
What does the following command show?

saved_model_cli show --dir /home/data/decloud/data/meranercamer/model/meraner_unet_bt1_bv1_lr0.0007_e2_cpu05-10-24-074632/ --all

@BernradMaillard
Author

2024-10-05 23:06:49.927473: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['dem'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, -1, -1, 1)
        name: serving_default_dem:0
    inputs['s1_t'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, -1, -1, 2)
        name: serving_default_s1_t:0
    inputs['s2_t'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, -1, -1, 4)
        name: serving_default_s2_t:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['add'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, -1, -1, 4)
        name: StatefulPartitionedCall:0
    outputs['add_pad128'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, -1, -1, 4)
        name: StatefulPartitionedCall:1
    outputs['add_pad256'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, -1, -1, 4)
        name: StatefulPartitionedCall:2
    outputs['add_pad32'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, -1, -1, 4)
        name: StatefulPartitionedCall:3
    outputs['add_pad64'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, -1, -1, 4)
        name: StatefulPartitionedCall:4
  Method name is: tensorflow/serving/predict
The MetaGraph with tag set ['serve'] contains the following ops: {'Placeholder', 'RestoreV2', 'Identity', 'Select', 'Const', 'StatefulPartitionedCall', 'ReadVariableOp', 'AssignVariableOp', 'Mul', 'AddV2', 'StringJoin', 'SaveV2', 'VarHandleOp', 'MergeV2Checkpoints', 'StaticRegexFullMatch', 'StridedSlice', 'Conv2DBackpropInput', 'BiasAdd', 'DisableCopyOnRead', 'Pack', 'Relu', 'ConcatV2', 'Conv2D', 'ShardedFilename', 'NoOp', 'Shape'}

Concrete Functions:
  Function Name: '__call__'
    Option #1
      Callable with:
        Argument #1
          DType: dict
          Value: {'s2_t': TensorSpec(shape=(None, None, None, 4), dtype=tf.float32, name='inputs_s2_t'), 'dem': TensorSpec(shape=(None, None, None, 1), dtype=tf.float32, name='inputs_dem'), 's1_t': TensorSpec(shape=(None, None, None, 2), dtype=tf.float32, name='inputs_s1_t')}
        Argument #2
          DType: bool
          Value: True
        Argument #3
          DType: NoneType
          Value: None
    Option #2
      Callable with:
        Argument #1
          DType: dict
          Value: {'s2_t': TensorSpec(shape=(None, None, None, 4), dtype=tf.float32, name='s2_t'), 'dem': TensorSpec(shape=(None, None, None, 1), dtype=tf.float32, name='dem'), 's1_t': TensorSpec(shape=(None, None, None, 2), dtype=tf.float32, name='s1_t')}
        Argument #2
          DType: bool
          Value: True
        Argument #3
          DType: NoneType
          Value: None
    Option #3
      Callable with:
        Argument #1
          DType: dict
          Value: {'s2_t': TensorSpec(shape=(None, None, None, 4), dtype=tf.float32, name='s2_t'), 'dem': TensorSpec(shape=(None, None, None, 1), dtype=tf.float32, name='dem'), 's1_t': TensorSpec(shape=(None, None, None, 2), dtype=tf.float32, name='s1_t')}
        Argument #2
          DType: bool
          Value: False
        Argument #3
          DType: NoneType
          Value: None
    Option #4
      Callable with:
        Argument #1
          DType: dict
          Value: {'s2_t': TensorSpec(shape=(None, None, None, 4), dtype=tf.float32, name='inputs_s2_t'), 'dem': TensorSpec(shape=(None, None, None, 1), dtype=tf.float32, name='inputs_dem'), 's1_t': TensorSpec(shape=(None, None, None, 2), dtype=tf.float32, name='inputs_s1_t')}
        Argument #2
          DType: bool
          Value: False
        Argument #3
          DType: NoneType
          Value: None

  Function Name: '_default_save_signature'
    Option #1
      Callable with:
        Argument #1
          DType: dict
          Value: {'s2_t': TensorSpec(shape=(None, None, None, 4), dtype=tf.float32, name='s2_t'), 'dem': TensorSpec(shape=(None, None, None, 1), dtype=tf.float32, name='dem'), 's1_t': TensorSpec(shape=(None, None, None, 2), dtype=tf.float32, name='s1_t')}

  Function Name: 'call_and_return_all_conditional_losses'
    Option #1
      Callable with:
        Argument #1
          DType: dict
          Value: {'s2_t': TensorSpec(shape=(None, None, None, 4), dtype=tf.float32, name='inputs_s2_t'), 'dem': TensorSpec(shape=(None, None, None, 1), dtype=tf.float32, name='inputs_dem'), 's1_t': TensorSpec(shape=(None, None, None, 2), dtype=tf.float32, name='inputs_s1_t')}
        Argument #2
          DType: bool
          Value: True
        Argument #3
          DType: NoneType
          Value: None
    Option #2
      Callable with:
        Argument #1
          DType: dict
          Value: {'s2_t': TensorSpec(shape=(None, None, None, 4), dtype=tf.float32, name='s2_t'), 'dem': TensorSpec(shape=(None, None, None, 1), dtype=tf.float32, name='dem'), 's1_t': TensorSpec(shape=(None, None, None, 2), dtype=tf.float32, name='s1_t')}
        Argument #2
          DType: bool
          Value: False
        Argument #3
          DType: NoneType
          Value: None
    Option #3
      Callable with:
        Argument #1
          DType: dict
          Value: {'s2_t': TensorSpec(shape=(None, None, None, 4), dtype=tf.float32, name='s2_t'), 'dem': TensorSpec(shape=(None, None, None, 1), dtype=tf.float32, name='dem'), 's1_t': TensorSpec(shape=(None, None, None, 2), dtype=tf.float32, name='s1_t')}
        Argument #2
          DType: bool
          Value: True
        Argument #3
          DType: NoneType
          Value: None
    Option #4
      Callable with:
        Argument #1
          DType: dict
          Value: {'s2_t': TensorSpec(shape=(None, None, None, 4), dtype=tf.float32, name='inputs_s2_t'), 'dem': TensorSpec(shape=(None, None, None, 1), dtype=tf.float32, name='inputs_dem'), 's1_t': TensorSpec(shape=(None, None, None, 2), dtype=tf.float32, name='inputs_s1_t')}
        Argument #2
          DType: bool
          Value: False
        Argument #3
          DType: NoneType
          Value: None
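The `serving_default` outputs listed above are the names TensorflowModelServe can look up: `add` and its cropped variants `add_pad32` to `add_pad256`, but nothing named `s2_estim`. When comparing several exported models, the output keys can be pulled out of the `saved_model_cli show --all` text with a short script. This is a hypothetical helper that only parses the plain-text format shown here; it is not part of decloud or OTBTF.

```python
import re


def signature_outputs(cli_text: str, signature: str = "serving_default") -> list[str]:
    """Return the output keys of one SignatureDef from `saved_model_cli show --all` text.

    Hypothetical helper: it only parses the plain-text listing format shown above.
    """
    start = cli_text.find(f"signature_def['{signature}']")
    if start == -1:
        return []  # signature not present in the listing
    # The block ends where the next SignatureDef begins (or at the end of the text).
    end = cli_text.find("signature_def[", start + 1)
    block = cli_text[start:] if end == -1 else cli_text[start:end]
    return re.findall(r"outputs\['([^']+)'\]", block)


# Excerpt of the listing posted above
sample = """
signature_def['__saved_model_init_op']:
    outputs['__saved_model_init_op'] tensor_info:
signature_def['serving_default']:
    inputs['dem'] tensor_info:
    outputs['add'] tensor_info:
    outputs['add_pad256'] tensor_info:
"""

print(signature_outputs(sample))  # ['add', 'add_pad256']
```

In practice the input would be the saved listing, e.g. the output of `saved_model_cli show --dir <model_dir> --all` redirected to a file and read back with `open(...).read()`.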

@remicres
Collaborator

remicres commented Oct 6, 2024

Hi @BernradMaillard, it looks like a bug. Our CI did not catch it because it only tests pretrained models.

You must modify

s2_out = layers.Add()([net, normalized_inputs["s2_t"]])

from https://github.com/CNES/decloud/blob/master/decloud/models/meraner_unet.py#L95

into

s2_out = layers.Add(name="s2_estim")([net, normalized_inputs["s2_t"]])

Then retrain your model.
Alternatively, to keep using the model you have already trained, change the output tensor name in meraner_processor.py from s2_estim to add.

We will release a patch for this issue as soon as possible.
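Judging from the error message, the processor appends the `--pad` value to its output tensor name (`s2_estim` with pad 256 becomes `s2_estim_pad256`) and asks the SavedModel for that name. The sketch below reproduces that lookup against the output keys listed by `saved_model_cli`, to show why the already-trained model works once the name is changed to `add`. The `<name>_pad<pad>` rule is inferred from the log, and `resolve_output` is a hypothetical function, not decloud's code.

```python
def resolve_output(base_name: str, pad: int, available: list[str]) -> str:
    """Return the padded output name the processor would request.

    Assumes the '<base>_pad<pad>' rule inferred from the error message;
    raises ValueError when the SavedModel does not export that name.
    """
    name = f"{base_name}_pad{pad}"
    if name not in available:
        raise ValueError(f'Tensor name "{name}" not found, available: {sorted(available)}')
    return name


# serving_default outputs of the model trained in this issue (from saved_model_cli)
outputs = ["add", "add_pad128", "add_pad256", "add_pad32", "add_pad64"]

print(resolve_output("add", 256, outputs))  # add_pad256
try:
    resolve_output("s2_estim", 256, outputs)  # what the unpatched processor asks for
except ValueError as err:
    print(err)
```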

@remicres
Copy link
Collaborator

remicres commented Oct 10, 2024

Hi @BernradMaillard,

We found the cause of the issue in the code.
It was not what I suggested in my last reply. In short, you have to remove the last Add() layer of the meraner_unet model (the other model, meraner_unet_20m, works fine).

Just replace

        net = conv_final(net)
        s2_out = layers.Add()([net, normalized_inputs["s2_t"]])

with

        s2_out = conv_final(net)
