diff --git a/.githooks/pre-commit b/.githooks/pre-commit new file mode 100755 index 00000000..f3a5aa98 --- /dev/null +++ b/.githooks/pre-commit @@ -0,0 +1,69 @@ +#!/bin/bash +# +# A hook script to verify what is about to be committed. +# Called by "git commit" with no arguments. The hook should +# exit with non-zero status after issuing an appropriate message if +# it wants to stop the commit. + +# Fail immediately at first issue with the relevant exit status. +set -eo pipefail + +# =================================================================== + +if git rev-parse --verify HEAD >/dev/null 2>&1 +then + against=HEAD +else + # Initial commit: diff against an empty tree object + against=$(git hash-object -t tree /dev/null) +fi + +# =================================================================== + +# Check that ftorch.90 is not modified and staged alone. +git diff --cached --name-only | if grep --quiet "ftorch.f90"; then + git diff --cached --name-only | if ! grep --quiet "ftorch.fypp"; then + cat <<\EOF +Error: File ftorch.f90 has been modified and staged without ftorch.fypp being changed. +ftorch.90 should be generated from ftorch.fypp using fypp. +Please restore ftorch.f90 and make your modifications to ftorch.fypp instead. +EOF + exit 1 + fi +fi + +# Check to see if ftorch.fypp has been modified AND is staged. +git diff --cached --name-only | if grep --quiet "ftorch.fypp"; then + + # Check that ftorch.90 is also modified and staged. + git diff --cached --name-only | if ! grep --quiet "ftorch.f90"; then + cat <<\EOF +Error: File ftorch.fypp has been modified and staged, but ftorch.f90 has not. +ftorch.90 should be generated from ftorch.fypp and both committed together. +Please run fypp on ftorch.fypp to generate ftorch.f90 and commit together. +EOF + exit 1 + else + # Check fypp available, and raise error and exit if not. + if ! command -v fypp &> /dev/null; then + cat <<\EOF +echo "Error: Could not find fypp to run on ftorch.fypp. +Please install fypp using "pip install fypp" and then try committing again. +EOF + exit 1 + fi + + # If fypp is available and both .f90 and .fypp staged, check they match. + fypp src/ftorch.fypp src/ftorch.f90_tmp + if ! diff -q "src/ftorch.f90" "src/ftorch.f90_tmp" &> /dev/null; then + rm src/ftorch.f90_tmp + cat <<\EOF +Error: The code in ftorch.f90 does not match that expected from ftorch.fypp. +Please re-run fypp on ftorch.fypp to ensure consistency before committing. +EOF + exit 1 + else + rm src/ftorch.f90_tmp + fi + fi +fi diff --git a/.github/workflows/fypp.yml b/.github/workflows/fypp.yml new file mode 100644 index 00000000..88f5af8c --- /dev/null +++ b/.github/workflows/fypp.yml @@ -0,0 +1,27 @@ +name: fypp-checks + +on: + # run on every push + push: + +jobs: + various: + name: FYPP checks - runs check on fypp and f90 files + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: "3.11" + - run: pip install fypp + + - name: Check fypp matches f90 + run: | + fypp src/ftorch.fypp src/temp.f90_temp + if ! diff -q src/ftorch.f90 src/temp.f90_temp; then + echo "Error: The code in ftorch.f90 does not match that expected from ftorch.fypp." + echo "Please re-run fypp on ftorch.fypp to ensure consistency and re-commit." + exit 1 + else + exit 0 + fi diff --git a/README.md b/README.md index 7852c53c..5f6c9524 100644 --- a/README.md +++ b/README.md @@ -128,11 +128,10 @@ To use the trained Torch model from within Fortran we need to import the `ftorch A very simple example is given below. For more detailed documentation please consult the API documentation, source code, and examples. -This minimal snippet loads a saved Torch model, creates inputs consisting of two `10x10` matrices (one of ones, and one of zeros), and runs the model to infer output. +This minimal snippet loads a saved Torch model, creates an input consisting of a `10x10` matrix of ones, and runs the model to infer output. +This is illustrative only, and we recommend following the [examples](examples/) before writing your own code to explore more features. ```fortran -! Import any C bindings as required for this code -use, intrinsic :: iso_c_binding, only: c_int, c_int64_t, c_loc ! Import library for interfacing with PyTorch use ftorch @@ -141,34 +140,32 @@ implicit none ! Generate an object to hold the Torch model type(torch_module) :: model -! Set up types of input and output data and the interface with C -integer(c_int), parameter :: dims_input = 2 -integer(c_int64_t) :: shape_input(dims_input) -integer(c_int), parameter :: n_inputs = 2 +! Set up types of input and output data +integer, parameter :: n_inputs = 1 type(torch_tensor), dimension(n_inputs) :: model_input_arr -integer(c_int), parameter :: dims_output = 1 -integer(c_int64_t) :: shape_output(dims_output) type(torch_tensor) :: model_output -! Set up the model inputs as Fortran arrays -real, dimension(10,10), target :: input_1, input_2 +! Set up the model input and output as Fortran arrays +real, dimension(10,10), target :: input real, dimension(5), target :: output +! Set up number of dimensions of input tensor and axis order +integer, parameter :: in_dims = 2 +integer :: in_layout(in_dims) = [1,2] +integer, parameter :: out_dims = 1 +integer :: out_layout(out_dims) = [1] + ! Initialise the Torch model to be used model = torch_module_load("/path/to/saved/model.pt") -! Initialise the inputs as Fortran -input_1 = 0.0 -input_2 = 1.0 +! Initialise the inputs as Fortran array of ones +input = 1.0 ! Wrap Fortran data as no-copy Torch Tensors ! There may well be some reshaping required depending on the ! structure of the model which is not covered here (see examples) -shape_input = (/10, 10/) -shape_output = (/5/) -model_input_arr(1) = torch_tensor_from_blob(c_loc(input_1), dims_input, shape_input, torch_kFloat64, torch_kCPU) -model_input_arr(2) = torch_tensor_from_blob(c_loc(input_2), dims_input, shape_input, torch_kFloat64, torch_kCPU) -model_output = torch_tensor_from_blob(c_loc(output), dims_output, shape_output, torch_kFloat64, torch_kCPU) +model_input_arr(1) = torch_tensor_from_array(input, in_layout, torch_kCPU) +model_output = torch_tensor_from_array(output, out_layout, torch_kCPU) ! Run model and Infer ! Again, there may be some reshaping required depending on model design @@ -180,7 +177,6 @@ write(*,*) output ! Clean up call torch_module_delete(model) call torch_tensor_delete(model_input_arr(1)) -call torch_tensor_delete(model_input_arr(2)) call torch_tensor_delete(model_output) ``` diff --git a/examples/1_SimpleNet/simplenet_infer_fortran.f90 b/examples/1_SimpleNet/simplenet_infer_fortran.f90 index 2f1ff385..199b984c 100644 --- a/examples/1_SimpleNet/simplenet_infer_fortran.f90 +++ b/examples/1_SimpleNet/simplenet_infer_fortran.f90 @@ -1,28 +1,30 @@ program inference - ! Imports primitives used to interface with C - use, intrinsic :: iso_c_binding, only: c_int64_t, c_float, c_char, c_ptr, c_loc + ! Import precision info from iso + use, intrinsic :: iso_fortran_env, only : sp => real32 + ! Import our library for interfacing with PyTorch use ftorch implicit none - + + ! Set precision for reals + integer, parameter :: wp = sp + integer :: num_args, ix character(len=128), dimension(:), allocatable :: args - ! Set up types of input and output data and the interface with C + ! Set up Fortran data structures + real(wp), dimension(5), target :: in_data + real(wp), dimension(5), target :: out_data + integer, parameter :: n_inputs = 1 + integer :: tensor_layout(1) = [1] + + ! Set up Torch data structures type(torch_module) :: model type(torch_tensor), dimension(1) :: in_tensor type(torch_tensor) :: out_tensor - real(c_float), dimension(:), allocatable, target :: in_data - integer(c_int), parameter :: n_inputs = 1 - real(c_float), dimension(:), allocatable, target :: out_data - - integer(c_int), parameter :: tensor_dims = 1 - integer(c_int64_t) :: tensor_shape(tensor_dims) = [5] - integer(c_int) :: tensor_layout(tensor_dims) = [1] - ! Get TorchScript model file as a command line argument num_args = command_argument_count() allocate(args(num_args)) @@ -30,18 +32,14 @@ program inference call get_command_argument(ix,args(ix)) end do - ! Allocate one-dimensional input/output arrays, based on multiplication of all input/output dimension sizes - allocate(in_data(tensor_shape(1))) - allocate(out_data(tensor_shape(1))) - ! Initialise data in_data = [0.0, 1.0, 2.0, 3.0, 4.0] - ! Create input/output tensors from the above arrays - in_tensor(1) = torch_tensor_from_blob(c_loc(in_data), tensor_dims, tensor_shape, torch_kFloat32, torch_kCPU, tensor_layout) - out_tensor = torch_tensor_from_blob(c_loc(out_data), tensor_dims, tensor_shape, torch_kFloat32, torch_kCPU, tensor_layout) + ! Create Torch input/output tensors from the above arrays + in_tensor(1) = torch_tensor_from_array(in_data, tensor_layout, torch_kCPU) + out_tensor = torch_tensor_from_array(out_data, tensor_layout, torch_kCPU) - ! Load ML model (edit this line to use different models) + ! Load ML model model = torch_module_load(args(1)) ! Infer @@ -52,7 +50,5 @@ program inference call torch_module_delete(model) call torch_tensor_delete(in_tensor(1)) call torch_tensor_delete(out_tensor) - deallocate(in_data) - deallocate(out_data) end program inference diff --git a/examples/2_ResNet18/resnet_infer_fortran.f90 b/examples/2_ResNet18/resnet_infer_fortran.f90 index dfc012b1..1af256af 100644 --- a/examples/2_ResNet18/resnet_infer_fortran.f90 +++ b/examples/2_ResNet18/resnet_infer_fortran.f90 @@ -1,18 +1,12 @@ program inference - ! Imports primitives used to interface with C - use, intrinsic :: iso_c_binding, only: c_sp=>c_float, c_dp=>c_double, c_int64_t, c_loc - use, intrinsic :: iso_fortran_env, only : sp => real32, dp => real64 + use, intrinsic :: iso_fortran_env, only : sp => real32 ! Import our library for interfacing with PyTorch use :: ftorch implicit none - ! Define working precision for C primitives - ! Precision must match `wp` in resnet18.py and `wp_torch` in pt2ts.py - integer, parameter :: c_wp = c_sp integer, parameter :: wp = sp - integer, parameter :: torch_wp = torch_kFloat32 call main() @@ -25,21 +19,21 @@ subroutine main() integer :: num_args, ix character(len=128), dimension(:), allocatable :: args - ! Set up types of input and output data and the interface with C + ! Set up types of input and output data type(torch_module) :: model type(torch_tensor), dimension(1) :: in_tensor type(torch_tensor) :: out_tensor - real(c_wp), dimension(:,:,:,:), allocatable, target :: in_data - integer(c_int), parameter :: n_inputs = 1 - real(c_wp), dimension(:,:), allocatable, target :: out_data + real(wp), dimension(:,:,:,:), allocatable, target :: in_data + real(wp), dimension(:,:), allocatable, target :: out_data + integer, parameter :: n_inputs = 1 - integer(c_int), parameter :: in_dims = 4 - integer(c_int64_t) :: in_shape(in_dims) = [1, 3, 224, 224] - integer(c_int) :: in_layout(in_dims) = [1,2,3,4] - integer(c_int), parameter :: out_dims = 2 - integer(c_int64_t) :: out_shape(out_dims) = [1, 1000] - integer(c_int) :: out_layout(out_dims) = [1,2] + integer, parameter :: in_dims = 4 + integer :: in_shape(in_dims) = [1, 3, 224, 224] + integer :: in_layout(in_dims) = [1,2,3,4] + integer, parameter :: out_dims = 2 + integer :: out_shape(out_dims) = [1, 1000] + integer :: out_layout(out_dims) = [1,2] ! Binary file containing input tensor character(len=*), parameter :: filename = '../data/image_tensor.dat' @@ -72,8 +66,9 @@ subroutine main() call load_data(filename, tensor_length, in_data) ! Create input/output tensors from the above arrays - in_tensor(1) = torch_tensor_from_blob(c_loc(in_data), in_dims, in_shape, torch_wp, torch_kCPU, in_layout) - out_tensor = torch_tensor_from_blob(c_loc(out_data), out_dims, out_shape, torch_wp, torch_kCPU, out_layout) + in_tensor(1) = torch_tensor_from_array(in_data, in_layout, torch_kCPU) + + out_tensor = torch_tensor_from_array(out_data, out_layout, torch_kCPU) ! Load ML model (edit this line to use different models) model = torch_module_load(args(1)) @@ -113,9 +108,9 @@ subroutine load_data(filename, tensor_length, in_data) character(len=*), intent(in) :: filename integer, intent(in) :: tensor_length - real(c_wp), dimension(:,:,:,:), intent(out) :: in_data + real(wp), dimension(:,:,:,:), intent(out) :: in_data - real(c_wp) :: flat_data(tensor_length) + real(wp) :: flat_data(tensor_length) integer :: ios character(len=100) :: ioerrmsg @@ -166,7 +161,7 @@ subroutine calc_probs(out_data, probabilities) implicit none - real(c_wp), dimension(:,:), intent(in) :: out_data + real(wp), dimension(:,:), intent(in) :: out_data real(wp), dimension(:,:), intent(out) :: probabilities real(wp) :: prob_sum diff --git a/examples/tensor_tests/CMakeLists.txt b/examples/tensor_tests/CMakeLists.txt deleted file mode 100644 index 6571f1cb..00000000 --- a/examples/tensor_tests/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -cmake_minimum_required(VERSION 3.1 FATAL_ERROR) -#policy CMP0076 - target_sources source files are relative to file where target_sources is run -cmake_policy (SET CMP0076 NEW) - -set(PROJECT_NAME test_tensor) - -project(${PROJECT_NAME} LANGUAGES Fortran) - -# Build in Debug mode if not specified -if(NOT CMAKE_BUILD_TYPE) - set(CMAKE_BUILD_TYPE Debug CACHE STRING "" FORCE) -endif() - -find_package(FTorch) -message(STATUS "Building with Fortran PyTorch coupling") - -# Some tests for tensor generation. -add_executable(test_tensor test_tensor.f90) -target_link_libraries(test_tensor PRIVATE FTorch::ftorch) diff --git a/examples/tensor_tests/test_tensor.f90 b/examples/tensor_tests/test_tensor.f90 deleted file mode 100644 index d7bc1410..00000000 --- a/examples/tensor_tests/test_tensor.f90 +++ /dev/null @@ -1,69 +0,0 @@ -program test_tensor - use, intrinsic :: iso_c_binding, only: c_int64_t, c_float, c_char, c_ptr, c_loc - use ftorch - implicit none - - real(kind=8), dimension(:,:), allocatable, target :: uuu_flattened, vvv_flattened - real(kind=8), dimension(:,:), allocatable, target :: lat_reshaped, psfc_reshaped - real(kind=8), dimension(:,:), allocatable, target :: gwfcng_x_flattened, gwfcng_y_flattened - type(torch_tensor), target :: output_tensor - integer(c_int), parameter :: dims_1D = 2 - integer(c_int), parameter :: dims_2D = 2 - integer(c_int64_t) :: shape_2D_F(dims_2D), shape_2D_C(dims_2D) - integer(c_int64_t) :: shape_1D_F(dims_1D), shape_1D_C(dims_1D) - integer(c_int) :: layout_F(dims_1D), layout_C(dims_1D) - integer :: imax, jmax, kmax, i, j, k - - imax = 1 - jmax = 5 - kmax = 7 - - shape_2D_F = (/ kmax, imax*jmax /) - shape_1D_F = (/ 1, imax*jmax /) - shape_2D_C = (/ imax*jmax, kmax /) - shape_1D_C = (/ imax*jmax, 1 /) - - layout_F = (/ 1, 2 /) - layout_C = (/ 2, 1 /) - - allocate( lat_reshaped(imax*jmax, 1) ) - allocate( uuu_flattened(imax*jmax, kmax) ) - do i = 1, imax*jmax - lat_reshaped(i, 1) = i - do k = 1, kmax - uuu_flattened(i, k) = i + k*100 - end do - end do - - write(*,*) uuu_flattened - - output_tensor = torch_tensor_from_blob(c_loc(uuu_flattened), & - dims_2D, shape_2D_C, torch_kFloat64, torch_kCPU, layout_F) - - call torch_tensor_print(output_tensor) - - output_tensor = torch_tensor_from_blob(c_loc(uuu_flattened), & - dims_2D, shape_2D_F, torch_kFloat64, torch_kCPU, layout_C) - - call torch_tensor_print(output_tensor) - - shape_2D_F = shape(uuu_flattened) - output_tensor = torch_tensor_from_array_c_double(uuu_flattened, shape_2D_F, torch_kCPU) - - call torch_tensor_print(output_tensor) - - output_tensor = torch_tensor_from_array(uuu_flattened, shape_2D_F, torch_kCPU) - - call torch_tensor_print(output_tensor) - - ! output_tensor = torch_tensor_zeros( & - ! dims_2D, shape_2D_C, torch_kFloat64, torch_kCPU) - - ! call torch_tensor_print(output_tensor) - - ! output_tensor = torch_tensor_ones( & - ! dims_2D, shape_2D_C, torch_kFloat64, torch_kCPU) - - ! call torch_tensor_print(output_tensor) - -end program test_tensor diff --git a/src/ftorch.f90 b/src/ftorch.f90 index 3412320a..945673cf 100644 --- a/src/ftorch.f90 +++ b/src/ftorch.f90 @@ -1,305 +1,1007 @@ +!| Main module for FTorch containing types and procedures. +! Generated from `ftorch.fypp` using the [fypp Fortran preprocessor](https://fypp.readthedocs.io/en/stable/index.html). +! +! * License +! FTorch is released under an MIT license. +! See the [LICENSE](https://github.com/Cambridge-ICCS/FTorch/blob/main/LICENSE) +! file for details. + module ftorch - use, intrinsic :: iso_c_binding, only: c_int, c_int8_t, c_int16_t, c_int32_t, c_int64_t, c_int64_t, & - c_float, c_double, c_char, c_ptr, c_null_ptr - implicit none - - type torch_module - type(c_ptr) :: p = c_null_ptr - end type torch_module - - type torch_tensor - type(c_ptr) :: p = c_null_ptr - end type torch_tensor - - ! From c_torch.h (torch_data_t) - enum, bind(c) - enumerator :: torch_kUInt8 = 0 - enumerator :: torch_kInt8 = 1 - enumerator :: torch_kInt16 = 2 - enumerator :: torch_kInt32 = 3 - enumerator :: torch_kInt64 = 4 - enumerator :: torch_kFloat16 = 5 - enumerator :: torch_kFloat32 = 6 - enumerator :: torch_kFloat64 = 7 - end enum - - ! From c_torch.h (torch_device_t) - enum, bind(c) - enumerator :: torch_kCPU = 0 - enumerator :: torch_kCUDA = 1 - end enum - - ! Interface for calculating tensor from array for different possible input types - interface torch_tensor_from_array - module procedure torch_tensor_from_array_c_float - module procedure torch_tensor_from_array_c_double - ! module procedure torch_tensor_from_array_c_int8_t - ! module procedure torch_tensor_from_array_c_int16_t - ! module procedure torch_tensor_from_array_c_int32_t - ! module procedure torch_tensor_from_array_c_int64_t - end interface + use, intrinsic :: iso_c_binding, only: c_int, c_int8_t, c_int16_t, c_int32_t, c_int64_t, c_int64_t, & + c_float, c_double, c_char, c_ptr, c_null_ptr + use, intrinsic :: iso_fortran_env, only: int8, int16, int32, int64, real32, real64 + + implicit none + + !> Type for holding a torch neural net (nn.Module). + type torch_module + type(c_ptr) :: p = c_null_ptr !! pointer to the neural net module in memory + end type torch_module + + !> Type for holding a Torch tensor. + type torch_tensor + type(c_ptr) :: p = c_null_ptr !! pointer to the tensor in memory + end type torch_tensor + + !| Enumerator for Torch data types + ! From c_torch.h (torch_data_t) + ! Note that 0 `torch_kUInt8` and 5 `torch_kFloat16` are not sypported in Fortran + enum, bind(c) + enumerator :: torch_kUInt8 = 0 ! not supported in Fortran + enumerator :: torch_kInt8 = 1 + enumerator :: torch_kInt16 = 2 + enumerator :: torch_kInt32 = 3 + enumerator :: torch_kInt64 = 4 + enumerator :: torch_kFloat16 = 5 ! not supported in Fortran + enumerator :: torch_kFloat32 = 6 + enumerator :: torch_kFloat64 = 7 + end enum + + + !| Enumerator for Torch devices + ! From c_torch.h (torch_device_t) + enum, bind(c) + enumerator :: torch_kCPU = 0 + enumerator :: torch_kCUDA = 1 + end enum + + !> Interface for directing `torch_tensor_from_array` to possible input types and ranks + interface torch_tensor_from_array + module procedure torch_tensor_from_array_int8_1d + module procedure torch_tensor_from_array_int8_2d + module procedure torch_tensor_from_array_int8_3d + module procedure torch_tensor_from_array_int8_4d + module procedure torch_tensor_from_array_int16_1d + module procedure torch_tensor_from_array_int16_2d + module procedure torch_tensor_from_array_int16_3d + module procedure torch_tensor_from_array_int16_4d + module procedure torch_tensor_from_array_int32_1d + module procedure torch_tensor_from_array_int32_2d + module procedure torch_tensor_from_array_int32_3d + module procedure torch_tensor_from_array_int32_4d + module procedure torch_tensor_from_array_int64_1d + module procedure torch_tensor_from_array_int64_2d + module procedure torch_tensor_from_array_int64_3d + module procedure torch_tensor_from_array_int64_4d + module procedure torch_tensor_from_array_real32_1d + module procedure torch_tensor_from_array_real32_2d + module procedure torch_tensor_from_array_real32_3d + module procedure torch_tensor_from_array_real32_4d + module procedure torch_tensor_from_array_real64_1d + module procedure torch_tensor_from_array_real64_2d + module procedure torch_tensor_from_array_real64_3d + module procedure torch_tensor_from_array_real64_4d + end interface + + interface + function torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device) result(tensor_p) & + bind(c, name = 'torch_from_blob') + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr + + ! Arguments + type(c_ptr), value, intent(in) :: data + integer(c_int), value, intent(in) :: ndims + integer(c_int64_t), intent(in) :: tensor_shape(*) + integer(c_int64_t), intent(in) :: strides(*) + integer(c_int), value, intent(in) :: dtype + integer(c_int), value, intent(in) :: device + type(c_ptr) :: tensor_p + end function torch_from_blob_c + end interface contains - ! Torch Tensor API - !> Exposes the given data as a tensor without taking ownership of the original data. - !> This routine will take an (i, j, k) array and return an (k, j, i) tensor. - function torch_tensor_from_blob(data, ndims, tensor_shape, dtype, device, layout) result(tensor) - use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr - type(c_ptr), intent(in) :: data !! Pointer to data - integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor - integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor - integer(c_int), intent(in) :: dtype !! Data type of the tensor - integer(c_int), intent(in) :: device !! Device on which the tensor will live on (torch_kCPU or torch_kCUDA) - integer(c_int), intent(in) :: layout(*) !! Layout for strides for accessing data - type(torch_tensor) :: tensor !! Returned tensor - - integer(c_int) :: i !! loop index - integer(c_int64_t) :: strides(ndims) !! Strides for accessing data - - interface - function torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device) result(tensor) & - bind(c, name='torch_from_blob') - use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr - type(c_ptr), value, intent(in) :: data - integer(c_int), value, intent(in) :: ndims - integer(c_int64_t), intent(in) :: tensor_shape(*) - integer(c_int64_t), intent(in) :: strides(*) - integer(c_int), value, intent(in) :: dtype - integer(c_int), value, intent(in) :: device - type(c_ptr) :: tensor - end function torch_from_blob_c - end interface - - strides(layout(1)) = 1 - do i = 2, ndims - strides(layout(i)) = strides(layout(i-1)) * tensor_shape(layout(i-1)) - end do - tensor%p = torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device) - end function torch_tensor_from_blob - - !> This routine will take an (i, j, k) array and return an (k, j, i) tensor - !> it is invoked from a set of interfaces `torch_tensor_from_array_dtype` - function t_t_from_array(data_arr, tensor_shape, dtype, device) result(tensor) - use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_double, c_loc - type(c_ptr), intent(in) :: data_arr !! Pointer to data - integer(c_int64_t), intent(in) :: tensor_shape(:) !! Shape of the tensor - integer(c_int), intent(in) :: dtype !! Data type of the tensor - integer(c_int), intent(in) :: device !! Device on which the tensor will live on (torch_kCPU or torch_kCUDA) - type(torch_tensor) :: tensor !! Returned tensor - - integer(c_int) :: i !! loop index - integer(c_int64_t), allocatable :: strides(:) !! Strides for accessing data - integer(c_int), allocatable :: layout(:) !! Layout for strides for accessing data - integer(c_int) :: ndims !! Number of dimensions of the tensor - - interface - function torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device) result(tensor) & - bind(c, name='torch_from_blob') - use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr - type(c_ptr), value, intent(in) :: data - integer(c_int), value, intent(in) :: ndims - integer(c_int64_t), intent(in) :: tensor_shape(*) - integer(c_int64_t), intent(in) :: strides(*) - integer(c_int), value, intent(in) :: dtype - integer(c_int), value, intent(in) :: device - type(c_ptr) :: tensor - end function torch_from_blob_c - end interface - - ndims = size(tensor_shape) - - allocate(strides(ndims)) - allocate(layout(ndims)) - - ! Fortran Layout - do i=1, ndims - layout(i) = i - end do - - strides(layout(1)) = 1 - do i = 2, ndims - strides(layout(i)) = strides(layout(i-1)) * tensor_shape(layout(i-1)) - end do - - tensor%p = torch_from_blob_c(data_arr, ndims, tensor_shape, strides, dtype, device) - - deallocate(strides) - deallocate(layout) - - end function t_t_from_array - - !> Returns a tensor filled with the scalar value 1. - function torch_tensor_ones(ndims, tensor_shape, dtype, device) result(tensor) - use, intrinsic :: iso_c_binding, only : c_int, c_int64_t - integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor - integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor - integer(c_int), intent(in) :: dtype !! Data type of the tensor - integer(c_int), intent(in) :: device !! Device on which the tensor will live on (torch_kCPU or torch_kCUDA) - type(torch_tensor) :: tensor !! Returned tensor - - interface - function torch_ones_c(ndims, tensor_shape, dtype, device) result(tensor) & - bind(c, name='torch_ones') - use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr - integer(c_int), value, intent(in) :: ndims - integer(c_int64_t), intent(in) :: tensor_shape(*) - integer(c_int), value, intent(in) :: dtype - integer(c_int), value, intent(in) :: device - type(c_ptr) :: tensor - end function torch_ones_c - end interface - - tensor%p = torch_ones_c(ndims, tensor_shape, dtype, device) - end function torch_tensor_ones - - !> Returns a tensor filled with the scalar value 0. - function torch_tensor_zeros(ndims, tensor_shape, dtype, device) result(tensor) - use, intrinsic :: iso_c_binding, only : c_int, c_int64_t - integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor - integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor - integer(c_int), intent(in) :: dtype !! Data type of the tensor - integer(c_int), intent(in) :: device !! Device on which the tensor will live on (torch_kCPU or torch_kCUDA) - type(torch_tensor) :: tensor !! Returned tensor - - interface - function torch_zeros_c(ndims, tensor_shape, dtype, device) result(tensor) & - bind(c, name='torch_zeros') - use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr - integer(c_int), value, intent(in) :: ndims - integer(c_int64_t), intent(in) :: tensor_shape(*) - integer(c_int), value, intent(in) :: dtype - integer(c_int), value, intent(in) :: device - type(c_ptr) :: tensor - end function torch_zeros_c - end interface - - tensor%p = torch_zeros_c(ndims, tensor_shape, dtype, device) - end function torch_tensor_zeros - - !> Prints the contents of a tensor. - subroutine torch_tensor_print(tensor) - type(torch_tensor), intent(in) :: tensor !! Input tensor - - interface - subroutine torch_tensor_print_c(tensor) & - bind(c, name='torch_tensor_print') - use, intrinsic :: iso_c_binding, only : c_ptr - type(c_ptr), value, intent(in) :: tensor - end subroutine torch_tensor_print_c - end interface - - call torch_tensor_print_c(tensor%p) - end subroutine torch_tensor_print - - !> Deallocates a tensor. - subroutine torch_tensor_delete(tensor) - type(torch_tensor), intent(in) :: tensor !! Input tensor - - interface - subroutine torch_tensor_delete_c(tensor) & - bind(c, name='torch_tensor_delete') - use, intrinsic :: iso_c_binding, only : c_ptr - type(c_ptr), value, intent(in) :: tensor - end subroutine torch_tensor_delete_c - end interface - - call torch_tensor_delete_c(tensor%p) - end subroutine torch_tensor_delete - - ! Torch Module API - !> Loads a Torch Script module (pre-trained PyTorch model saved with Torch Script) - function torch_module_load(filename) result(module) - use, intrinsic :: iso_c_binding, only : c_null_char - character(*), intent(in) :: filename !! Filename of Torch Script module - type(torch_module) :: module !! Returned deserialized module - - interface - function torch_jit_load_c(filename) result(module) & - bind(c, name='torch_jit_load') - use, intrinsic :: iso_c_binding, only : c_char, c_ptr - character(c_char), intent(in) :: filename(*) - type(c_ptr) :: module - end function torch_jit_load_c - end interface - - ! Need to append c_null_char at end of filename - module%p = torch_jit_load_c(trim(adjustl(filename))//c_null_char) - end function torch_module_load - - !> Performs a forward pass of the module with the input tensors - subroutine torch_module_forward(module, input_tensors, n_inputs, output_tensor) - use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_loc - type(torch_module), intent(in) :: module !! Module - type(torch_tensor), intent(in), dimension(:) :: input_tensors !! Array of Input tensors - type(torch_tensor), intent(in) :: output_tensor !! Returned output tensors - integer(c_int) :: n_inputs - - integer :: i - type(c_ptr), dimension(n_inputs), target :: input_ptrs - - interface - subroutine torch_jit_module_forward_c(module, input_tensors, n_inputs, & - output_tensor) & - bind(c, name='torch_jit_module_forward') - use, intrinsic :: iso_c_binding, only : c_ptr, c_int - type(c_ptr), value, intent(in) :: module - type(c_ptr), value, intent(in) :: input_tensors - integer(c_int), value, intent(in) :: n_inputs - type(c_ptr), value, intent(in) :: output_tensor - end subroutine torch_jit_module_forward_c - end interface - - ! Assign array of pointers to the input tensors - do i = 1, n_inputs - input_ptrs(i) = input_tensors(i)%p - end do - - call torch_jit_module_forward_c(module%p, c_loc(input_ptrs), n_inputs, output_tensor%p) - end subroutine torch_module_forward - - !> Deallocates a Torch Script module - subroutine torch_module_delete(module) - type(torch_module), intent(in) :: module !! Module - - interface - subroutine torch_jit_module_delete_c(module) & - bind(c, name='torch_jit_module_delete') - use, intrinsic :: iso_c_binding, only : c_ptr - type(c_ptr), value, intent(in) :: module - end subroutine torch_jit_module_delete_c - end interface - - call torch_jit_module_delete_c(module%p) - end subroutine torch_module_delete - - ! Series of interface functions - function torch_tensor_from_array_c_double(data_arr, tensor_shape, device) result(tensor) - !function torch_tensor_from_array_c_double(data_arr, tensor_shape) result(tensor) - use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_double, c_loc - real(c_double), intent(in), target :: data_arr(*) !! Fortran array of data - ! real(c_double), intent(in), target :: data_arr(*) !! Fortran array of data - integer(c_int64_t), intent(in) :: tensor_shape(:) !! Shape of the tensor - integer(c_int), parameter :: dtype = torch_kFloat64 - integer(c_int), intent(in) :: device !! Device on which the tensor will live on (torch_kCPU or torch_kCUDA) - type(torch_tensor) :: tensor !! Returned tensor - - - tensor = t_t_from_array(c_loc(data_arr), tensor_shape, dtype, device) - - end function torch_tensor_from_array_c_double - - function torch_tensor_from_array_c_float(data_arr, tensor_shape, device) result(tensor) - use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc - real(c_float), intent(in), target :: data_arr(*) !! Fortran array of data - integer(c_int64_t), intent(in) :: tensor_shape(:) !! Shape of the tensor - integer(c_int), parameter :: dtype = torch_kFloat32 - integer(c_int), intent(in) :: device !! Device on which the tensor will live on (torch_kCPU or torch_kCUDA) - type(torch_tensor) :: tensor !! Returned tensor - - tensor = t_t_from_array(c_loc(data_arr), tensor_shape, dtype, device) - - end function torch_tensor_from_array_c_float + ! Torch Tensor API + !| Exposes the given data as a tensor without taking ownership of the original data. + ! This routine will take an (i, j, k) array and return an (k, j, i) tensor. + function torch_tensor_from_blob(data, ndims, tensor_shape, layout, dtype, device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr + type(c_ptr), intent(in) :: data !! Pointer to data + integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor + integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor + integer(c_int), intent(in) :: dtype !! Data type of the tensor + integer(c_int), intent(in) :: device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: layout(*) !! Layout for strides for accessing data + type(torch_tensor) :: tensor !! Returned tensor + + integer(c_int) :: i !! loop index + integer(c_int64_t) :: strides(ndims) !! Strides for accessing data + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) + end do + tensor%p = torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device) + end function torch_tensor_from_blob + + !> Returns a tensor filled with the scalar value 1. + function torch_tensor_ones(ndims, tensor_shape, dtype, device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t + integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor + integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor + integer(c_int), intent(in) :: dtype !! Data type of the tensor + integer(c_int), intent(in) :: device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + type(torch_tensor) :: tensor !! Returned tensor + + interface + function torch_ones_c(ndims, tensor_shape, dtype, device) result(tensor) & + bind(c, name = 'torch_ones') + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr + integer(c_int), value, intent(in) :: ndims + integer(c_int64_t), intent(in) :: tensor_shape(*) + integer(c_int), value, intent(in) :: dtype + integer(c_int), value, intent(in) :: device + type(c_ptr) :: tensor + end function torch_ones_c + end interface + + tensor%p = torch_ones_c(ndims, tensor_shape, dtype, device) + end function torch_tensor_ones + + !> Returns a tensor filled with the scalar value 0. + function torch_tensor_zeros(ndims, tensor_shape, dtype, device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t + integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor + integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor + integer(c_int), intent(in) :: dtype !! Data type of the tensor + integer(c_int), intent(in) :: device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + type(torch_tensor) :: tensor !! Returned tensor + + interface + function torch_zeros_c(ndims, tensor_shape, dtype, device) result(tensor) & + bind(c, name = 'torch_zeros') + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr + integer(c_int), value, intent(in) :: ndims + integer(c_int64_t), intent(in) :: tensor_shape(*) + integer(c_int), value, intent(in) :: dtype + integer(c_int), value, intent(in) :: device + type(c_ptr) :: tensor + end function torch_zeros_c + end interface + + tensor%p = torch_zeros_c(ndims, tensor_shape, dtype, device) + end function torch_tensor_zeros + + !> Prints the contents of a tensor. + subroutine torch_tensor_print(tensor) + type(torch_tensor), intent(in) :: tensor !! Input tensor + + interface + subroutine torch_tensor_print_c(tensor) & + bind(c, name = 'torch_tensor_print') + use, intrinsic :: iso_c_binding, only : c_ptr + type(c_ptr), value, intent(in) :: tensor + end subroutine torch_tensor_print_c + end interface + + call torch_tensor_print_c(tensor%p) + end subroutine torch_tensor_print + + !> Deallocates a tensor. + subroutine torch_tensor_delete(tensor) + type(torch_tensor), intent(in) :: tensor !! Input tensor + + interface + subroutine torch_tensor_delete_c(tensor) & + bind(c, name = 'torch_tensor_delete') + use, intrinsic :: iso_c_binding, only : c_ptr + type(c_ptr), value, intent(in) :: tensor + end subroutine torch_tensor_delete_c + end interface + + call torch_tensor_delete_c(tensor%p) + end subroutine torch_tensor_delete + + ! Torch Module API + !> Loads a Torch Script module (pre-trained PyTorch model saved with Torch Script) + function torch_module_load(filename) result(module) + use, intrinsic :: iso_c_binding, only : c_null_char + character(*), intent(in) :: filename !! Filename of Torch Script module + type(torch_module) :: module !! Returned deserialized module + + interface + function torch_jit_load_c(filename) result(module) & + bind(c, name = 'torch_jit_load') + use, intrinsic :: iso_c_binding, only : c_char, c_ptr + character(c_char), intent(in) :: filename(*) + type(c_ptr) :: module + end function torch_jit_load_c + end interface + + ! Need to append c_null_char at end of filename + module%p = torch_jit_load_c(trim(adjustl(filename))//c_null_char) + end function torch_module_load + + !> Performs a forward pass of the module with the input tensors + subroutine torch_module_forward(module, input_tensors, n_inputs, output_tensor) + use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_loc + type(torch_module), intent(in) :: module !! Module + type(torch_tensor), intent(in), dimension(:) :: input_tensors !! Array of Input tensors + type(torch_tensor), intent(in) :: output_tensor !! Returned output tensors + integer(c_int) :: n_inputs + + integer :: i + type(c_ptr), dimension(n_inputs), target :: input_ptrs + + interface + subroutine torch_jit_module_forward_c(module, input_tensors, n_inputs, & + output_tensor) & + bind(c, name = 'torch_jit_module_forward') + use, intrinsic :: iso_c_binding, only : c_ptr, c_int + type(c_ptr), value, intent(in) :: module + type(c_ptr), value, intent(in) :: input_tensors + integer(c_int), value, intent(in) :: n_inputs + type(c_ptr), value, intent(in) :: output_tensor + end subroutine torch_jit_module_forward_c + end interface + + ! Assign array of pointers to the input tensors + do i = 1, n_inputs + input_ptrs(i) = input_tensors(i)%p + end do + + call torch_jit_module_forward_c(module%p, c_loc(input_ptrs), n_inputs, output_tensor%p) + end subroutine torch_module_forward + + !> Deallocates a Torch Script module + subroutine torch_module_delete(module) + type(torch_module), intent(in) :: module !! Module to deallocate + + interface + subroutine torch_jit_module_delete_c(module) & + bind(c, name = 'torch_jit_module_delete') + use, intrinsic :: iso_c_binding, only : c_ptr + type(c_ptr), value, intent(in) :: module + end subroutine torch_jit_module_delete_c + end interface + + call torch_jit_module_delete_c(module%p) + end subroutine torch_module_delete + + !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `int8` + function torch_tensor_from_array_int8_1d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : int8 + + ! inputs + integer(kind=int8), intent(in), target :: data_in(:) !! Input data that tensor will point at + integer, intent(in) :: layout(1) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(1) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type + integer(c_int64_t) :: strides(1) !! Strides for accessing data + integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_int8_1d + + !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `int8` + function torch_tensor_from_array_int8_2d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : int8 + + ! inputs + integer(kind=int8), intent(in), target :: data_in(:,:) !! Input data that tensor will point at + integer, intent(in) :: layout(2) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(2) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type + integer(c_int64_t) :: strides(2) !! Strides for accessing data + integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_int8_2d + + !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `int8` + function torch_tensor_from_array_int8_3d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : int8 + + ! inputs + integer(kind=int8), intent(in), target :: data_in(:,:,:) !! Input data that tensor will point at + integer, intent(in) :: layout(3) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(3) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type + integer(c_int64_t) :: strides(3) !! Strides for accessing data + integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_int8_3d + + !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `int8` + function torch_tensor_from_array_int8_4d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : int8 + + ! inputs + integer(kind=int8), intent(in), target :: data_in(:,:,:,:) !! Input data that tensor will point at + integer, intent(in) :: layout(4) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(4) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type + integer(c_int64_t) :: strides(4) !! Strides for accessing data + integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_int8_4d + + !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `int16` + function torch_tensor_from_array_int16_1d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : int16 + + ! inputs + integer(kind=int16), intent(in), target :: data_in(:) !! Input data that tensor will point at + integer, intent(in) :: layout(1) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(1) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type + integer(c_int64_t) :: strides(1) !! Strides for accessing data + integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_int16_1d + + !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `int16` + function torch_tensor_from_array_int16_2d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : int16 + + ! inputs + integer(kind=int16), intent(in), target :: data_in(:,:) !! Input data that tensor will point at + integer, intent(in) :: layout(2) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(2) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type + integer(c_int64_t) :: strides(2) !! Strides for accessing data + integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_int16_2d + + !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `int16` + function torch_tensor_from_array_int16_3d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : int16 + + ! inputs + integer(kind=int16), intent(in), target :: data_in(:,:,:) !! Input data that tensor will point at + integer, intent(in) :: layout(3) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(3) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type + integer(c_int64_t) :: strides(3) !! Strides for accessing data + integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_int16_3d + + !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `int16` + function torch_tensor_from_array_int16_4d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : int16 + + ! inputs + integer(kind=int16), intent(in), target :: data_in(:,:,:,:) !! Input data that tensor will point at + integer, intent(in) :: layout(4) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(4) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type + integer(c_int64_t) :: strides(4) !! Strides for accessing data + integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_int16_4d + + !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `int32` + function torch_tensor_from_array_int32_1d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : int32 + + ! inputs + integer(kind=int32), intent(in), target :: data_in(:) !! Input data that tensor will point at + integer, intent(in) :: layout(1) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(1) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type + integer(c_int64_t) :: strides(1) !! Strides for accessing data + integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_int32_1d + + !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `int32` + function torch_tensor_from_array_int32_2d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : int32 + + ! inputs + integer(kind=int32), intent(in), target :: data_in(:,:) !! Input data that tensor will point at + integer, intent(in) :: layout(2) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(2) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type + integer(c_int64_t) :: strides(2) !! Strides for accessing data + integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_int32_2d + + !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `int32` + function torch_tensor_from_array_int32_3d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : int32 + + ! inputs + integer(kind=int32), intent(in), target :: data_in(:,:,:) !! Input data that tensor will point at + integer, intent(in) :: layout(3) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(3) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type + integer(c_int64_t) :: strides(3) !! Strides for accessing data + integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_int32_3d + + !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `int32` + function torch_tensor_from_array_int32_4d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : int32 + + ! inputs + integer(kind=int32), intent(in), target :: data_in(:,:,:,:) !! Input data that tensor will point at + integer, intent(in) :: layout(4) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(4) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type + integer(c_int64_t) :: strides(4) !! Strides for accessing data + integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_int32_4d + + !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `int64` + function torch_tensor_from_array_int64_1d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : int64 + + ! inputs + integer(kind=int64), intent(in), target :: data_in(:) !! Input data that tensor will point at + integer, intent(in) :: layout(1) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(1) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type + integer(c_int64_t) :: strides(1) !! Strides for accessing data + integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_int64_1d + + !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `int64` + function torch_tensor_from_array_int64_2d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : int64 + + ! inputs + integer(kind=int64), intent(in), target :: data_in(:,:) !! Input data that tensor will point at + integer, intent(in) :: layout(2) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(2) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type + integer(c_int64_t) :: strides(2) !! Strides for accessing data + integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_int64_2d + + !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `int64` + function torch_tensor_from_array_int64_3d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : int64 + + ! inputs + integer(kind=int64), intent(in), target :: data_in(:,:,:) !! Input data that tensor will point at + integer, intent(in) :: layout(3) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(3) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type + integer(c_int64_t) :: strides(3) !! Strides for accessing data + integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_int64_3d + + !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `int64` + function torch_tensor_from_array_int64_4d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : int64 + + ! inputs + integer(kind=int64), intent(in), target :: data_in(:,:,:,:) !! Input data that tensor will point at + integer, intent(in) :: layout(4) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(4) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type + integer(c_int64_t) :: strides(4) !! Strides for accessing data + integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_int64_4d + + !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `real32` + function torch_tensor_from_array_real32_1d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : real32 + + ! inputs + real(kind=real32), intent(in), target :: data_in(:) !! Input data that tensor will point at + integer, intent(in) :: layout(1) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(1) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type + integer(c_int64_t) :: strides(1) !! Strides for accessing data + integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_real32_1d + + !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `real32` + function torch_tensor_from_array_real32_2d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : real32 + + ! inputs + real(kind=real32), intent(in), target :: data_in(:,:) !! Input data that tensor will point at + integer, intent(in) :: layout(2) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(2) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type + integer(c_int64_t) :: strides(2) !! Strides for accessing data + integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_real32_2d + + !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `real32` + function torch_tensor_from_array_real32_3d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : real32 + + ! inputs + real(kind=real32), intent(in), target :: data_in(:,:,:) !! Input data that tensor will point at + integer, intent(in) :: layout(3) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(3) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type + integer(c_int64_t) :: strides(3) !! Strides for accessing data + integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_real32_3d + + !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `real32` + function torch_tensor_from_array_real32_4d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : real32 + + ! inputs + real(kind=real32), intent(in), target :: data_in(:,:,:,:) !! Input data that tensor will point at + integer, intent(in) :: layout(4) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(4) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type + integer(c_int64_t) :: strides(4) !! Strides for accessing data + integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_real32_4d + + !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `real64` + function torch_tensor_from_array_real64_1d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : real64 + + ! inputs + real(kind=real64), intent(in), target :: data_in(:) !! Input data that tensor will point at + integer, intent(in) :: layout(1) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(1) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type + integer(c_int64_t) :: strides(1) !! Strides for accessing data + integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_real64_1d + + !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `real64` + function torch_tensor_from_array_real64_2d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : real64 + + ! inputs + real(kind=real64), intent(in), target :: data_in(:,:) !! Input data that tensor will point at + integer, intent(in) :: layout(2) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(2) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type + integer(c_int64_t) :: strides(2) !! Strides for accessing data + integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_real64_2d + + !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `real64` + function torch_tensor_from_array_real64_3d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : real64 + + ! inputs + real(kind=real64), intent(in), target :: data_in(:,:,:) !! Input data that tensor will point at + integer, intent(in) :: layout(3) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(3) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type + integer(c_int64_t) :: strides(3) !! Strides for accessing data + integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_real64_3d + + !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `real64` + function torch_tensor_from_array_real64_4d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : real64 + + ! inputs + real(kind=real64), intent(in), target :: data_in(:,:,:,:) !! Input data that tensor will point at + integer, intent(in) :: layout(4) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(4) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type + integer(c_int64_t) :: strides(4) !! Strides for accessing data + integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_real64_4d + end module ftorch diff --git a/src/ftorch.fypp b/src/ftorch.fypp new file mode 100644 index 00000000..16371e48 --- /dev/null +++ b/src/ftorch.fypp @@ -0,0 +1,296 @@ +#:def ranksuffix(RANK) +$:'' if RANK == 0 else '(' + ':' + ',:' * (RANK - 1) + ')' +#:enddef ranksuffix +#:set PRECISIONS = ['int8', 'int16', 'int32', 'int64', 'real32', 'real64'] +#:set C_PRECISIONS = ['c_int8_t', 'c_int16_t', 'c_int32_t', 'c_int64_t', 'c_float', 'c_double'] +#:set C_PRECISIONS = dict(zip(PRECISIONS, C_PRECISIONS)) +#:set ENUMS = dict(zip(PRECISIONS, ['torch_kInt8', 'torch_kInt16', 'torch_kInt32', 'torch_kInt64', 'torch_kFloat32', 'torch_kFloat64'])) +#:set RANKS = range(1, 5) +#:def enum_from_prec(PRECISION) +$:ENUMS[PRECISION] +#:enddef enum_from_prec +#:def c_prec(PRECISION) +$:C_PRECISIONS[PRECISION] +#:enddef c_prec +#:def f_type(PRECISION) +$:'integer' if PRECISION[:3] == 'int' else 'real' +#:enddef f_type +!| Main module for FTorch containing types and procedures. +! Generated from `ftorch.fypp` using the [fypp Fortran preprocessor](https://fypp.readthedocs.io/en/stable/index.html). +! +! * License +! FTorch is released under an MIT license. +! See the [LICENSE](https://github.com/Cambridge-ICCS/FTorch/blob/main/LICENSE) +! file for details. + +module ftorch + + use, intrinsic :: iso_c_binding, only: c_int, c_int8_t, c_int16_t, c_int32_t, c_int64_t, c_int64_t, & + c_float, c_double, c_char, c_ptr, c_null_ptr + use, intrinsic :: iso_fortran_env, only: int8, int16, int32, int64, real32, real64 + + implicit none + + !> Type for holding a torch neural net (nn.Module). + type torch_module + type(c_ptr) :: p = c_null_ptr !! pointer to the neural net module in memory + end type torch_module + + !> Type for holding a Torch tensor. + type torch_tensor + type(c_ptr) :: p = c_null_ptr !! pointer to the tensor in memory + end type torch_tensor + + !| Enumerator for Torch data types + ! From c_torch.h (torch_data_t) + ! Note that 0 `torch_kUInt8` and 5 `torch_kFloat16` are not sypported in Fortran + enum, bind(c) + enumerator :: torch_kUInt8 = 0 ! not supported in Fortran + enumerator :: torch_kInt8 = 1 + enumerator :: torch_kInt16 = 2 + enumerator :: torch_kInt32 = 3 + enumerator :: torch_kInt64 = 4 + enumerator :: torch_kFloat16 = 5 ! not supported in Fortran + enumerator :: torch_kFloat32 = 6 + enumerator :: torch_kFloat64 = 7 + end enum + + + !| Enumerator for Torch devices + ! From c_torch.h (torch_device_t) + enum, bind(c) + enumerator :: torch_kCPU = 0 + enumerator :: torch_kCUDA = 1 + end enum + + !> Interface for directing `torch_tensor_from_array` to possible input types and ranks + interface torch_tensor_from_array + #:for PREC in PRECISIONS + #:for RANK in RANKS + module procedure torch_tensor_from_array_${PREC}$_${RANK}$d + #:endfor + #:endfor + end interface + + interface + function torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device) result(tensor_p) & + bind(c, name = 'torch_from_blob') + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr + + ! Arguments + type(c_ptr), value, intent(in) :: data + integer(c_int), value, intent(in) :: ndims + integer(c_int64_t), intent(in) :: tensor_shape(*) + integer(c_int64_t), intent(in) :: strides(*) + integer(c_int), value, intent(in) :: dtype + integer(c_int), value, intent(in) :: device + type(c_ptr) :: tensor_p + end function torch_from_blob_c + end interface + +contains + + ! Torch Tensor API + !| Exposes the given data as a tensor without taking ownership of the original data. + ! This routine will take an (i, j, k) array and return an (k, j, i) tensor. + function torch_tensor_from_blob(data, ndims, tensor_shape, layout, dtype, device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr + type(c_ptr), intent(in) :: data !! Pointer to data + integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor + integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor + integer(c_int), intent(in) :: dtype !! Data type of the tensor + integer(c_int), intent(in) :: device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: layout(*) !! Layout for strides for accessing data + type(torch_tensor) :: tensor !! Returned tensor + + integer(c_int) :: i !! loop index + integer(c_int64_t) :: strides(ndims) !! Strides for accessing data + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) + end do + tensor%p = torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device) + end function torch_tensor_from_blob + + !> Returns a tensor filled with the scalar value 1. + function torch_tensor_ones(ndims, tensor_shape, dtype, device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t + integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor + integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor + integer(c_int), intent(in) :: dtype !! Data type of the tensor + integer(c_int), intent(in) :: device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + type(torch_tensor) :: tensor !! Returned tensor + + interface + function torch_ones_c(ndims, tensor_shape, dtype, device) result(tensor) & + bind(c, name = 'torch_ones') + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr + integer(c_int), value, intent(in) :: ndims + integer(c_int64_t), intent(in) :: tensor_shape(*) + integer(c_int), value, intent(in) :: dtype + integer(c_int), value, intent(in) :: device + type(c_ptr) :: tensor + end function torch_ones_c + end interface + + tensor%p = torch_ones_c(ndims, tensor_shape, dtype, device) + end function torch_tensor_ones + + !> Returns a tensor filled with the scalar value 0. + function torch_tensor_zeros(ndims, tensor_shape, dtype, device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t + integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor + integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor + integer(c_int), intent(in) :: dtype !! Data type of the tensor + integer(c_int), intent(in) :: device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + type(torch_tensor) :: tensor !! Returned tensor + + interface + function torch_zeros_c(ndims, tensor_shape, dtype, device) result(tensor) & + bind(c, name = 'torch_zeros') + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr + integer(c_int), value, intent(in) :: ndims + integer(c_int64_t), intent(in) :: tensor_shape(*) + integer(c_int), value, intent(in) :: dtype + integer(c_int), value, intent(in) :: device + type(c_ptr) :: tensor + end function torch_zeros_c + end interface + + tensor%p = torch_zeros_c(ndims, tensor_shape, dtype, device) + end function torch_tensor_zeros + + !> Prints the contents of a tensor. + subroutine torch_tensor_print(tensor) + type(torch_tensor), intent(in) :: tensor !! Input tensor + + interface + subroutine torch_tensor_print_c(tensor) & + bind(c, name = 'torch_tensor_print') + use, intrinsic :: iso_c_binding, only : c_ptr + type(c_ptr), value, intent(in) :: tensor + end subroutine torch_tensor_print_c + end interface + + call torch_tensor_print_c(tensor%p) + end subroutine torch_tensor_print + + !> Deallocates a tensor. + subroutine torch_tensor_delete(tensor) + type(torch_tensor), intent(in) :: tensor !! Input tensor + + interface + subroutine torch_tensor_delete_c(tensor) & + bind(c, name = 'torch_tensor_delete') + use, intrinsic :: iso_c_binding, only : c_ptr + type(c_ptr), value, intent(in) :: tensor + end subroutine torch_tensor_delete_c + end interface + + call torch_tensor_delete_c(tensor%p) + end subroutine torch_tensor_delete + + ! Torch Module API + !> Loads a Torch Script module (pre-trained PyTorch model saved with Torch Script) + function torch_module_load(filename) result(module) + use, intrinsic :: iso_c_binding, only : c_null_char + character(*), intent(in) :: filename !! Filename of Torch Script module + type(torch_module) :: module !! Returned deserialized module + + interface + function torch_jit_load_c(filename) result(module) & + bind(c, name = 'torch_jit_load') + use, intrinsic :: iso_c_binding, only : c_char, c_ptr + character(c_char), intent(in) :: filename(*) + type(c_ptr) :: module + end function torch_jit_load_c + end interface + + ! Need to append c_null_char at end of filename + module%p = torch_jit_load_c(trim(adjustl(filename))//c_null_char) + end function torch_module_load + + !> Performs a forward pass of the module with the input tensors + subroutine torch_module_forward(module, input_tensors, n_inputs, output_tensor) + use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_loc + type(torch_module), intent(in) :: module !! Module + type(torch_tensor), intent(in), dimension(:) :: input_tensors !! Array of Input tensors + type(torch_tensor), intent(in) :: output_tensor !! Returned output tensors + integer(c_int) :: n_inputs + + integer :: i + type(c_ptr), dimension(n_inputs), target :: input_ptrs + + interface + subroutine torch_jit_module_forward_c(module, input_tensors, n_inputs, & + output_tensor) & + bind(c, name = 'torch_jit_module_forward') + use, intrinsic :: iso_c_binding, only : c_ptr, c_int + type(c_ptr), value, intent(in) :: module + type(c_ptr), value, intent(in) :: input_tensors + integer(c_int), value, intent(in) :: n_inputs + type(c_ptr), value, intent(in) :: output_tensor + end subroutine torch_jit_module_forward_c + end interface + + ! Assign array of pointers to the input tensors + do i = 1, n_inputs + input_ptrs(i) = input_tensors(i)%p + end do + + call torch_jit_module_forward_c(module%p, c_loc(input_ptrs), n_inputs, output_tensor%p) + end subroutine torch_module_forward + + !> Deallocates a Torch Script module + subroutine torch_module_delete(module) + type(torch_module), intent(in) :: module !! Module to deallocate + + interface + subroutine torch_jit_module_delete_c(module) & + bind(c, name = 'torch_jit_module_delete') + use, intrinsic :: iso_c_binding, only : c_ptr + type(c_ptr), value, intent(in) :: module + end subroutine torch_jit_module_delete_c + end interface + + call torch_jit_module_delete_c(module%p) + end subroutine torch_module_delete + + #:for PREC in PRECISIONS + #:for RANK in RANKS + !> Return a Torch tensor pointing to data_in array of rank ${RANK}$ containing data of type `${PREC}$` + function torch_tensor_from_array_${PREC}$_${RANK}$d(data_in, layout, c_device) result(tensor) + use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc + use, intrinsic :: iso_fortran_env, only : ${PREC}$ + + ! inputs + ${f_type(PREC)}$(kind=${PREC}$), intent(in), target :: data_in${ranksuffix(RANK)}$ !! Input data that tensor will point at + integer, intent(in) :: layout(${RANK}$) !! Control order of indices + integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + + ! output tensory + type(torch_tensor) :: tensor !! Returned tensor + + ! local data + integer(c_int64_t) :: c_tensor_shape(${RANK}$) !! Shape of the tensor + integer(c_int), parameter :: c_dtype = ${enum_from_prec(PREC)}$ !! Data type + integer(c_int64_t) :: strides(${RANK}$) !! Strides for accessing data + integer(c_int), parameter :: ndims = ${RANK}$ !! Number of dimension of input data + integer :: i + + c_tensor_shape = shape(data_in) + + strides(layout(1)) = 1 + do i = 2, ndims + strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + end do + + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device) + + end function torch_tensor_from_array_${PREC}$_${RANK}$d + + #:endfor + #:endfor + +end module ftorch