Reformat .cc and .h file (#1191)
Change-Id: Id8c933aacdfda9620f05868954047bbf93cee660
frankfliu authored Aug 25, 2021
1 parent 1764533 commit 752de2c
Showing 7 changed files with 74 additions and 67 deletions.
@@ -52,33 +52,33 @@ JNIEXPORT void JNICALL Java_ai_djl_paddlepaddle_jni_PaddleLibrary_loadExtraDir(
}

JNIEXPORT void JNICALL Java_ai_djl_paddlepaddle_jni_PaddleLibrary_analysisConfigEnableMKLDNN(
    JNIEnv* env, jobject jthis, jlong jhandle) {
  auto* config_ptr = reinterpret_cast<paddle::AnalysisConfig*>(jhandle);
  config_ptr->EnableMKLDNN();
}

JNIEXPORT void JNICALL Java_ai_djl_paddlepaddle_jni_PaddleLibrary_analysisConfigDisableGLog(
    JNIEnv* env, jobject jthis, jlong jhandle) {
  auto* config_ptr = reinterpret_cast<paddle::AnalysisConfig*>(jhandle);
  config_ptr->DisableGlogInfo();
}

JNIEXPORT void JNICALL Java_ai_djl_paddlepaddle_jni_PaddleLibrary_analysisConfigCMLNumThreads(
    JNIEnv* env, jobject jthis, jlong jhandle, jint num) {
  auto* config_ptr = reinterpret_cast<paddle::AnalysisConfig*>(jhandle);
  config_ptr->SetCpuMathLibraryNumThreads(num);
}

JNIEXPORT void JNICALL Java_ai_djl_paddlepaddle_jni_PaddleLibrary_analysisConfigSwitchIrOptim(
    JNIEnv* env, jobject jthis, jlong jhandle, jboolean condition) {
  auto* config_ptr = reinterpret_cast<paddle::AnalysisConfig*>(jhandle);
  config_ptr->SwitchIrOptim(condition);
}

JNIEXPORT void JNICALL Java_ai_djl_paddlepaddle_jni_PaddleLibrary_analysisConfigRemovePass(
    JNIEnv* env, jobject jthis, jlong jhandle, jstring jpass) {
  auto* config_ptr = reinterpret_cast<paddle::AnalysisConfig*>(jhandle);
  config_ptr->pass_builder()->DeletePass(djl::utils::jni::GetStringFromJString(env, jpass));
}

JNIEXPORT void JNICALL Java_ai_djl_paddlepaddle_jni_PaddleLibrary_deleteAnalysisConfig(
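Note: all five bindings above share one pattern, unwrap the jlong handle into a paddle::AnalysisConfig* and forward to the matching Paddle call. As an illustration only, a hypothetical sixth binding in the same style might look like this (analysisConfigEnableProfile is not part of this commit, and EnableProfile() should be verified against the Paddle version in use):

// Hypothetical binding, not in this commit; it only illustrates the
// handle-unwrap-and-forward pattern used by the functions above.
JNIEXPORT void JNICALL Java_ai_djl_paddlepaddle_jni_PaddleLibrary_analysisConfigEnableProfile(
    JNIEnv* env, jobject jthis, jlong jhandle) {
  auto* config_ptr = reinterpret_cast<paddle::AnalysisConfig*>(jhandle);
  config_ptr->EnableProfile();
}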
@@ -69,13 +69,13 @@ JNIEXPORT jstring JNICALL Java_ai_djl_paddlepaddle_jni_PaddleLibrary_getTensorName(
}

JNIEXPORT void JNICALL Java_ai_djl_paddlepaddle_jni_PaddleLibrary_setTensorLoD(
    JNIEnv* env, jobject jthis, jlong jhandle, jobjectArray j2dlongarray) {
  auto tensor_ptr = reinterpret_cast<paddle::PaddleTensor*>(jhandle);
  tensor_ptr->lod = djl::utils::jni::Get2DVecFrom2DLongArray(env, j2dlongarray);
}

JNIEXPORT jobjectArray JNICALL Java_ai_djl_paddlepaddle_jni_PaddleLibrary_getTensorLoD(
    JNIEnv* env, jobject jthis, jlong jhandle) {
  auto tensor_ptr = reinterpret_cast<paddle::PaddleTensor*>(jhandle);
  return djl::utils::jni::Get2DLongArrayFrom2DVec(env, tensor_ptr->lod);
}
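Background note (not part of the diff): lod is Paddle's level-of-detail field, a std::vector<std::vector<size_t>> of offsets that segments the rows of a batched tensor; Get2DVecFrom2DLongArray and Get2DLongArrayFrom2DVec convert it to and from a Java long[][]. A minimal sketch, assuming one LoD level covering two sequences of lengths 2 and 3:

// Sketch only: one LoD level with row segments [0, 2) and [2, 5),
// i.e. the native value produced from a Java long[][] {{0, 2, 5}}.
paddle::PaddleTensor tensor;
tensor.lod = {{0, 2, 5}};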
@@ -14,30 +14,30 @@
#ifndef DJL_PADDLE_DJL_PADDLE_JNI_UTILS_H
#define DJL_PADDLE_DJL_PADDLE_JNI_UTILS_H

+#include <jni.h>
 #include <paddle_api.h>
 
-#include <jni.h>
 #include <iostream>
-#include <vector>
 #include <numeric>
+#include <vector>

namespace utils {

-inline void GetZTensorFromTensor(paddle::ZeroCopyTensor* z_tensor, paddle::PaddleTensor* tensor) {
+inline void GetZTensorFromTensor(paddle::ZeroCopyTensor *z_tensor, paddle::PaddleTensor *tensor) {
   z_tensor->Reshape(tensor->shape);
   z_tensor->SetLoD(tensor->lod);
   switch (tensor->dtype) {
     case paddle::PaddleDType::FLOAT32:
-      z_tensor->copy_from_cpu(static_cast<float*>(tensor->data.data()));
+      z_tensor->copy_from_cpu(static_cast<float *>(tensor->data.data()));
       break;
     case paddle::PaddleDType::INT32:
-      z_tensor->copy_from_cpu(static_cast<int32_t*>(tensor->data.data()));
+      z_tensor->copy_from_cpu(static_cast<int32_t *>(tensor->data.data()));
       break;
     case paddle::PaddleDType::INT64:
-      z_tensor->copy_from_cpu(static_cast<int64_t*>(tensor->data.data()));
+      z_tensor->copy_from_cpu(static_cast<int64_t *>(tensor->data.data()));
       break;
     case paddle::PaddleDType::UINT8:
-      z_tensor->copy_from_cpu(static_cast<uint8_t*>(tensor->data.data()));
+      z_tensor->copy_from_cpu(static_cast<uint8_t *>(tensor->data.data()));
       break;
     default:
       // TODO improve the error handling
@@ -75,6 +75,6 @@ inline void GetTensorFromZTensor(paddle::ZeroCopyTensor *z_tensor, paddle::PaddleTensor *tensor) {
   }
 }
 
-}
+}  // namespace utils
 
-#endif //DJL_PADDLE_DJL_PADDLE_JNI_UTILS_H
+#endif  // DJL_PADDLE_DJL_PADDLE_JNI_UTILS_H
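A rough usage sketch for GetZTensorFromTensor above. The real call sites live in inference .cc files this commit does not touch, so the predictor wiring here is assumed, not taken from the repository:

// Sketch, assuming "predictor" came from paddle::CreatePaddlePredictor;
// GetInputTensor returns a std::unique_ptr<paddle::ZeroCopyTensor>.
void SetInput(paddle::PaddlePredictor* predictor, paddle::PaddleTensor* tensor) {
  auto z_tensor = predictor->GetInputTensor(tensor->name);
  utils::GetZTensorFromTensor(z_tensor.get(), tensor);
}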
@@ -28,4 +28,4 @@ extern jmethodID ERROR_METHOD;
 // the highest version Android JNI version is 1.6
 static jint JNI_VERSION = JNI_VERSION_1_6;
 
-#endif //DJL_TORCH_AI_DJL_PYTORCH_JNI_CACHE_H
+#endif  // DJL_TORCH_AI_DJL_PYTORCH_JNI_CACHE_H
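For context, JNI_VERSION is the value a JNI library reports from JNI_OnLoad, which is also where caches such as ERROR_METHOD are typically populated. A minimal sketch, not DJL's actual implementation:

// Minimal JNI_OnLoad sketch: negotiate the version declared above.
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* reserved) {
  JNIEnv* env;
  if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION) != JNI_OK) {
    return JNI_ERR;
  }
  // ... look up and cache jclass/jmethodID globals here ...
  return JNI_VERSION;
}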
25 changes: 12 additions & 13 deletions pytorch/pytorch-native/src/main/native/djl_pytorch_jni_exception.h
@@ -16,9 +16,9 @@

#include "ai_djl_pytorch_jni_cache.h"

#define DJL_CHECK_WITH_MSG(cond, error_msg)           \
  if (!cond) {                                        \
    env->ThrowNew(ENGINE_EXCEPTION_CLASS, error_msg); \
  }

/*
@@ -27,16 +27,15 @@
* and finishes with API_END()
*/
#define API_BEGIN() try {

#define API_END()                                                      \
  }                                                                    \
  catch (const c10::Error& e) {                                        \
    Log log(env);                                                      \
    log.debug(e.what());                                               \
    env->ThrowNew(ENGINE_EXCEPTION_CLASS, e.what_without_backtrace()); \
  }                                                                    \
  catch (const std::exception& e_) {                                   \
    env->ThrowNew(ENGINE_EXCEPTION_CLASS, e_.what());                  \
  }

// TODO refactor all jni functions to c style function which mean
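For readers unfamiliar with these macros, a sketch of how a JNI entry point combines them (the function name is illustrative, not from this commit; the extra parentheses around the condition are deliberate, since DJL_CHECK_WITH_MSG does not parenthesize cond in its expansion):

// Illustrative only: API_BEGIN/API_END wrap the body in a try/catch that
// converts c10::Error and std::exception into a Java engine exception.
JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDeleteExample(
    JNIEnv* env, jobject jthis, jlong jhandle) {
  API_BEGIN()
  auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
  DJL_CHECK_WITH_MSG((tensor_ptr != nullptr), "tensor handle must not be null")
  delete tensor_ptr;
  API_END()
}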
42 changes: 22 additions & 20 deletions pytorch/pytorch-native/src/main/native/djl_pytorch_utils.h
@@ -15,14 +15,13 @@

 #include <c10/util/typeid.h>
 #include <c10/util/variant.h>
+#include <djl/utils.h>
+#include <jni.h>
 #include <torch/csrc/api/include/torch/enum.h>
 #include <torch/script.h>
 
-#include <jni.h>
 #include <iostream>
 
-#include <djl/utils.h>
-
 #include "djl_pytorch_jni_log.h"

// The file is utilities that are used for JNI
@@ -31,13 +30,9 @@ namespace utils {

#if !defined(__ANDROID__)
// for image interpolation
-typedef torch::variant<
-    torch::enumtype::kNearest,
-    torch::enumtype::kLinear,
-    torch::enumtype::kBilinear,
-    torch::enumtype::kBicubic,
-    torch::enumtype::kTrilinear,
-    torch::enumtype::kArea> mode_t;
+typedef torch::variant<torch::enumtype::kNearest, torch::enumtype::kLinear, torch::enumtype::kBilinear,
+    torch::enumtype::kBicubic, torch::enumtype::kTrilinear, torch::enumtype::kArea>
+    mode_t;
#endif

inline jint GetDTypeFromScalarType(const torch::ScalarType& type) {
@@ -101,27 +96,34 @@ inline torch::Device GetDeviceFromJDevice(JNIEnv* env, jintArray jdevice) {
#if !defined(__ANDROID__)
inline mode_t GetInterpolationMode(jint jmode) {
   switch (jmode) {
-    case 0: return torch::kNearest;
-    case 1: return torch::kLinear;
-    case 2: return torch::kBilinear;
-    case 3: return torch::kBicubic;
-    case 4: return torch::kTrilinear;
-    case 5: return torch::kArea;
+    case 0:
+      return torch::kNearest;
+    case 1:
+      return torch::kLinear;
+    case 2:
+      return torch::kBilinear;
+    case 3:
+      return torch::kBicubic;
+    case 4:
+      return torch::kTrilinear;
+    case 5:
+      return torch::kArea;
     default:
       throw;
   }
 }
#endif

-inline std::vector<torch::indexing::TensorIndex> CreateTensorIndex(JNIEnv* env, jlongArray jmin_indices, jlongArray jmax_indices, jlongArray jstep_indices) {
+inline std::vector<torch::indexing::TensorIndex> CreateTensorIndex(
+    JNIEnv* env, jlongArray jmin_indices, jlongArray jmax_indices, jlongArray jstep_indices) {
   const auto min_indices = djl::utils::jni::GetVecFromJLongArray(env, jmin_indices);
   const auto max_indices = djl::utils::jni::GetVecFromJLongArray(env, jmax_indices);
   const auto step_indices = djl::utils::jni::GetVecFromJLongArray(env, jstep_indices);
   std::vector<torch::indexing::TensorIndex> indices;
   indices.reserve(min_indices.size());
   for (size_t i = 0; i < min_indices.size(); ++i) {
     indices.emplace_back(
         torch::indexing::TensorIndex(torch::indexing::Slice(min_indices[i], max_indices[i], step_indices[i])));
   }
   return indices;
 }
@@ -131,8 +133,8 @@ inline torch::TensorOptions CreateTensorOptions(
  // it gets the device and collect jdevice memory
  const auto device = utils::GetDeviceFromJDevice(env, jdevice);
  auto options = torch::TensorOptions()
                     // for tensor creation API, MKLDNN layout is not supported
                     // the workaround is to create with Strided then call to_mkldnn()
                     .layout((jlayout != 1) ? torch::kStrided : torch::kSparse)
                     .memory_format(torch::MemoryFormat::Contiguous)
                     .device(device)
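A short sketch of how the TensorIndex vector built by CreateTensorIndex is consumed; the values are made up, but torch::Tensor::index is the libtorch call these Slice objects feed:

// Sketch: slice rows [0, 2) and columns 1, 3, 5, mirroring the Slice
// objects CreateTensorIndex builds from the three Java long[] arrays.
torch::Tensor t = torch::arange(24).reshape({4, 6});
std::vector<torch::indexing::TensorIndex> indices;
indices.emplace_back(torch::indexing::Slice(0, 2, 1));
indices.emplace_back(torch::indexing::Slice(1, 6, 2));
torch::Tensor sliced = t.index(indices);  // shape {2, 3}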
12 changes: 9 additions & 3 deletions tools/gradle/cpp-formatter.gradle
@@ -13,7 +13,7 @@ class CppFormatterPlugin implements Plugin<Project> {
}

static def formatCpp(File f, File clang) {
-        if (!f.getName().endsWith(".cc") && !f.getName().endsWith(".h")) {
+        if (!f.getName().endsWith(".cc") && !f.getName().endsWith(".cpp") && !f.getName().endsWith(".h")) {
return
}
ProcessBuilder pb = new ProcessBuilder("${clang.absolutePath}",
@@ -34,7 +34,10 @@
Project rootProject = project.getRootProject()
def clang = new File("${rootProject.projectDir}/.clang/clang-format")
checkClang(clang)
-        def files = project.fileTree("src").include("**/*.cc")
+        def files = project.fileTree("src")
+                .include("**/*.cc")
+                .include("**/*.cpp")
+                .include("**/*.h")
for (File f : files) {
if (!f.isFile()) {
continue
@@ -53,7 +56,10 @@
Project rootProject = project.getRootProject()
def clang = new File("${rootProject.projectDir}/.clang/clang-format")
checkClang(clang)
-        def files = project.fileTree("src").include("**/*.cc")
+        def files = project.fileTree("src")
+                .include("**/*.cc")
+                .include("**/*.cpp")
+                .include("**/*.h")
for (File f : files) {
if (!f.isFile()) {
continue