From a69cff58d7f14844288b743a51a23a644e1f6fac Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 17 Dec 2024 12:41:52 -0800 Subject: [PATCH] Fix two non-idiomatic uses of node_type --- src/Derivative.cpp | 6 ++---- src/Serialization.cpp | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/Derivative.cpp b/src/Derivative.cpp index 2520d27e290f..e64fb4ada94b 100644 --- a/src/Derivative.cpp +++ b/src/Derivative.cpp @@ -337,8 +337,7 @@ void ReverseAccumulationVisitor::propagate_adjoints( let_var_mapping.clear(); let_variables.clear(); for (const auto &expr : expr_list) { - if (expr.get()->node_type == IRNodeType::Let) { - const Let *op = expr.as(); + if (const Let *op = expr.as()) { // Assume Let variables are unique internal_assert(let_var_mapping.find(op->name) == let_var_mapping.end()); let_var_mapping[op->name] = op->value; @@ -660,8 +659,7 @@ void ReverseAccumulationVisitor::propagate_adjoints( let_var_mapping.clear(); let_variables.clear(); for (const auto &expr : expr_list) { - if (expr.get()->node_type == IRNodeType::Let) { - const Let *op = expr.as(); + if (const Let *op = expr.as()) { // Assume Let variables are unique internal_assert(let_var_mapping.find(op->name) == let_var_mapping.end()); let_var_mapping[op->name] = op->value; diff --git a/src/Serialization.cpp b/src/Serialization.cpp index 27bea2f5cfa3..15722d878974 100644 --- a/src/Serialization.cpp +++ b/src/Serialization.cpp @@ -410,7 +410,7 @@ std::pair> Serializer::serialize_stmt(FlatBufferBu if (!stmt.defined()) { return std::make_pair(Serialize::Stmt::UndefinedStmt, Serialize::CreateUndefinedStmt(builder).Union()); } - switch (stmt->node_type) { + switch (stmt.node_type()) { case IRNodeType::LetStmt: { const auto *const let_stmt = stmt.as(); const auto name_serialized = serialize_string(builder, let_stmt->name); @@ -681,7 +681,7 @@ std::pair> Serializer::serialize_expr(FlatBufferBu if (!expr.defined()) { return std::make_pair(Serialize::Expr::UndefinedExpr, Serialize::CreateUndefinedExpr(builder).Union()); } - switch (expr->node_type) { + switch (expr.node_type()) { case IRNodeType::IntImm: { const auto *const int_imm = expr.as(); const auto type_serialized = serialize_type(builder, int_imm->type);