Skip to content

Commit

Permalink
fix: local var set value in the internal scope (#1444)
Browse files Browse the repository at this point in the history
* fix: local var set value in the config internal scope

Signed-off-by: peefy <[email protected]>

* fix: config if variable scope set

Signed-off-by: peefy <[email protected]>

---------

Signed-off-by: peefy <[email protected]>
  • Loading branch information
Peefy authored Jun 27, 2024
1 parent 39cdc71 commit 59a5719
Show file tree
Hide file tree
Showing 9 changed files with 167 additions and 60 deletions.
61 changes: 35 additions & 26 deletions kclvm/compiler/src/codegen/llvm/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ pub struct Scope<'ctx> {
pub schema_scalar_idx: RefCell<usize>,
/// Scope normal variables
pub variables: RefCell<IndexMap<String, PointerValue<'ctx>>>,
/// Scope normal initialized variables
pub uninitialized: RefCell<IndexSet<String>>,
/// Scope closures referenced by internal scope.
pub closures: RefCell<IndexMap<String, PointerValue<'ctx>>>,
/// Potential arguments in the current scope, such as schema/lambda arguments.
Expand Down Expand Up @@ -1714,40 +1716,42 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
level
}

/// Append a variable or update the existed local variable.
pub fn add_or_update_local_variable(&self, name: &str, value: BasicValueEnum<'ctx>) {
/// Append a variable or update the existed closure variable within the current scope.
pub fn add_or_update_local_variable_within_scope(
&self,
name: &str,
value: Option<BasicValueEnum<'ctx>>,
) {
let current_pkgpath = self.current_pkgpath();
let mut pkg_scopes = self.pkg_scopes.borrow_mut();
let msg = format!("pkgpath {} is not found", current_pkgpath);
let scopes = pkg_scopes.get_mut(&current_pkgpath).expect(&msg);
let mut existed = false;
// Query the variable in all scopes.
for i in 0..scopes.len() {
let index = scopes.len() - i - 1;
let variables_mut = scopes[index].variables.borrow_mut();
let index = scopes.len() - 1;
if let Some(scope) = scopes.last_mut() {
let mut variables_mut = scope.variables.borrow_mut();
let mut uninitialized = scope.uninitialized.borrow_mut();
if value.is_none() {
uninitialized.insert(name.to_string());
} else {
uninitialized.remove(name);
}
match variables_mut.get(&name.to_string()) {
// If the local variable is found, store the new value for the variable.
// We cannot update rule/lambda/schema arguments because they are read-only.
Some(ptr)
if index > GLOBAL_LEVEL
&& !self.local_vars.borrow().contains(name)
&& !scopes[index].arguments.borrow().contains(name) =>
{
self.builder.build_store(*ptr, value);
existed = true;
Some(ptr) if index > GLOBAL_LEVEL => {
if let Some(value) = value {
self.builder.build_store(*ptr, value);
}
}
_ => {}
}
}
// If not found, alloc a new variable.
if !existed {
let ptr = self.builder.build_alloca(self.value_ptr_type(), name);
self.builder.build_store(ptr, value);
// Store the value for the variable and add the variable into the current scope.
if let Some(last) = scopes.last_mut() {
let mut variables = last.variables.borrow_mut();
variables.insert(name.to_string(), ptr);
}
_ => {
let ptr = self.builder.build_alloca(self.value_ptr_type(), name);
if let Some(value) = value {
self.builder.build_store(ptr, value);
}
// Store the value for the variable and add the variable into the current scope.
variables_mut.insert(name.to_string(), ptr);
}
};
}
}

Expand Down Expand Up @@ -1998,6 +2002,8 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
for i in 0..scopes_len {
let index = scopes_len - i - 1;
let variables = scopes[index].variables.borrow();
// Skip uninitialized pointer value, which may cause NPE.
let uninitialized = scopes[index].uninitialized.borrow();
if let Some(var) = variables.get(&name.to_string()) {
// Closure vars, 2 denotes the builtin scope and the global scope, here is a closure scope.
let value = if i >= 1 && i < scopes_len - 2 {
Expand Down Expand Up @@ -2062,6 +2068,9 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
],
)
} else {
if uninitialized.contains(name) {
continue;
}
self.builder.build_load(*var, name)
}
};
Expand Down
46 changes: 46 additions & 0 deletions kclvm/compiler/src/codegen/llvm/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,52 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
}
}

pub(crate) fn emit_config_if_entry_expr_vars(
&self,
config_if_entry_expr: &'ctx ast::ConfigIfEntryExpr,
) {
self.emit_config_entries_vars(&config_if_entry_expr.items);
if let Some(orelse) = &config_if_entry_expr.orelse {
// Config expr or config if entry expr.
if let ast::Expr::Config(config_expr) = &orelse.node {
self.emit_config_entries_vars(&config_expr.items);
} else if let ast::Expr::ConfigIfEntry(config_if_entry_expr) = &orelse.node {
self.emit_config_if_entry_expr_vars(config_if_entry_expr);
}
}
}

pub(crate) fn emit_config_entries_vars(&self, items: &'ctx [ast::NodeRef<ast::ConfigEntry>]) {
for item in items {
if let ast::Expr::ConfigIfEntry(config_if_entry_expr) = &item.node.value.node {
self.emit_config_if_entry_expr_vars(config_if_entry_expr);
}
if let Some(key) = &item.node.key {
let optional_name = match &key.node {
ast::Expr::Identifier(identifier) => Some(identifier.names[0].node.clone()),
ast::Expr::StringLit(string_lit) => Some(string_lit.value.clone()),
ast::Expr::Subscript(subscript) => {
let mut name = None;
if let ast::Expr::Identifier(identifier) = &subscript.value.node {
if let Some(index_node) = &subscript.index {
if let ast::Expr::NumberLit(number) = &index_node.node {
if let ast::NumberLitValue::Int(_) = number.value {
name = Some(identifier.names[0].node.clone())
}
}
}
}
name
}
_ => None,
};
if let Some(name) = &optional_name {
self.add_or_update_local_variable_within_scope(name, None);
}
}
}
}

/// Compile AST Modules, which requires traversing three times.
/// 1. scan all possible global variables and allocate undefined values to global pointers.
/// 2. build all user-defined schema/rule types.
Expand Down
18 changes: 13 additions & 5 deletions kclvm/compiler/src/codegen/llvm/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1979,6 +1979,7 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {
let else_block = self.append_block("");
let end_block = self.append_block("");
let is_truth = self.value_is_truthy(cond);
self.emit_config_if_entry_expr_vars(config_if_entry_expr);
let tpe = self.value_ptr_type();
self.cond_br(is_truth, then_block, else_block);
self.builder.position_at_end(then_block);
Expand All @@ -1992,7 +1993,13 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {
self.br(end_block);
self.builder.position_at_end(else_block);
let else_value = if let Some(orelse) = &config_if_entry_expr.orelse {
self.walk_expr(orelse).expect(kcl_error::COMPILE_ERROR_MSG)
// Config expr or config if entry expr.
if let ast::Expr::Config(config_expr) = &orelse.node {
self.walk_config_entries(&config_expr.items)
.expect(kcl_error::COMPILE_ERROR_MSG)
} else {
self.walk_expr(orelse).expect(kcl_error::COMPILE_ERROR_MSG)
}
} else {
self.none_value()
};
Expand Down Expand Up @@ -2076,7 +2083,10 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {

fn walk_config_expr(&self, config_expr: &'ctx ast::ConfigExpr) -> Self::Result {
check_backtrack_stop!(self);
self.walk_config_entries(&config_expr.items)
self.enter_scope();
let result = self.walk_config_entries(&config_expr.items);
self.leave_scope();
result
}

fn walk_check_expr(&self, check_expr: &'ctx ast::CheckExpr) -> Self::Result {
Expand Down Expand Up @@ -2830,7 +2840,6 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
items: &'ctx [NodeRef<ConfigEntry>],
) -> CompileResult<'ctx> {
let config_value = self.dict_value();
self.enter_scope();
for item in items {
let value = self.walk_expr(&item.node.value)?;
if let Some(key) = &item.node.key {
Expand Down Expand Up @@ -2869,7 +2878,7 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
if let Some(name) = &optional_name {
let value =
self.dict_get(config_value, self.native_global_string(name, "").into());
self.add_or_update_local_variable(name, value);
self.add_or_update_local_variable_within_scope(name, Some(value));
}
} else {
// If the key does not exist, execute the logic of unpacking expression `**expr` here.
Expand All @@ -2879,7 +2888,6 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
);
}
}
self.leave_scope();
Ok(config_value)
}
}
16 changes: 11 additions & 5 deletions kclvm/evaluator/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,12 @@ impl<'ctx> TypedResultWalker<'ctx> for Evaluator<'ctx> {
Ok(if is_truth {
self.walk_config_entries(&config_if_entry_expr.items)?
} else if let Some(orelse) = &config_if_entry_expr.orelse {
self.walk_expr(orelse)?
// Config expr or config if entry expr.
if let ast::Expr::Config(config_expr) = &orelse.node {
self.walk_config_entries(&config_expr.items)?
} else {
self.walk_expr(orelse)?
}
} else {
self.none_value()
})
Expand Down Expand Up @@ -927,7 +932,10 @@ impl<'ctx> TypedResultWalker<'ctx> for Evaluator<'ctx> {

#[inline]
fn walk_config_expr(&self, config_expr: &'ctx ast::ConfigExpr) -> Self::Result {
self.walk_config_entries(&config_expr.items)
self.enter_scope();
let result = self.walk_config_entries(&config_expr.items);
self.leave_scope();
result
}

fn walk_check_expr(&self, check_expr: &'ctx ast::CheckExpr) -> Self::Result {
Expand Down Expand Up @@ -1509,7 +1517,6 @@ impl<'ctx> Evaluator<'ctx> {

pub(crate) fn walk_config_entries(&self, items: &'ctx [NodeRef<ConfigEntry>]) -> EvalResult {
let mut config_value = self.dict_value();
self.enter_scope();
for item in items {
let value = self.walk_expr(&item.node.value)?;
if let Some(key) = &item.node.key {
Expand Down Expand Up @@ -1547,14 +1554,13 @@ impl<'ctx> Evaluator<'ctx> {
);
if let Some(name) = &optional_name {
let value = self.dict_get_value(&config_value, name);
self.add_or_update_local_variable(name, value);
self.add_or_update_local_variable_within_scope(name, value);
}
} else {
// If the key does not exist, execute the logic of unpacking expression `**expr` here.
config_value.dict_insert_unpack(&mut self.runtime_ctx.borrow_mut(), &value)
}
}
self.leave_scope();
Ok(config_value)
}
}
31 changes: 7 additions & 24 deletions kclvm/evaluator/src/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,34 +301,17 @@ impl<'ctx> Evaluator<'ctx> {
level
}

/// Append a variable or update the existed local variable.
pub fn add_or_update_local_variable(&self, name: &str, value: ValueRef) {
/// Append a variable or update the existed local variable within the current scope.
pub(crate) fn add_or_update_local_variable_within_scope(&self, name: &str, value: ValueRef) {
let current_pkgpath = self.current_pkgpath();
let is_local_var = self.is_local_var(name);
let pkg_scopes = &mut self.pkg_scopes.borrow_mut();
let msg = format!("pkgpath {} is not found", current_pkgpath);
let scopes = pkg_scopes.get_mut(&current_pkgpath).expect(&msg);
let mut existed = false;
// Query the variable in all scopes.
for i in 0..scopes.len() {
let index = scopes.len() - i - 1;
let is_argument = scopes[index].arguments.contains(name);
let variables_mut = &mut scopes[index].variables;
match variables_mut.get(&name.to_string()) {
// If the local variable is found, store the new value for the variable.
// We cannot update rule/lambda/schema arguments because they are read-only.
Some(_) if index > GLOBAL_LEVEL && !is_local_var && !is_argument => {
variables_mut.insert(name.to_string(), value.clone());
existed = true;
}
_ => {}
}
}
// If not found, alloc a new variable.
if !existed {
// Store the value for the variable and add the variable into the current scope.
if let Some(last) = scopes.last_mut() {
last.variables.insert(name.to_string(), value);
let index = scopes.len() - 1;
if let Some(scope) = scopes.last_mut() {
let variables_mut = &mut scope.variables;
if index > GLOBAL_LEVEL {
variables_mut.insert(name.to_string(), value.clone());
}
}
}
Expand Down
24 changes: 24 additions & 0 deletions test/grammar/datatype/dict/mutual_ref_15/main.k
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
render = lambda {
a = {
foo: "bar"
}
b = {
foo2: "bar2"
a: {
b: "c"
}
}
c = [a, b]
}

out = render()
a = {
foo: "bar"
}
b = {
foo2: "bar2"
a: {
b: "c"
}
}
c = [a, b]
16 changes: 16 additions & 0 deletions test/grammar/datatype/dict/mutual_ref_15/stdout.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
out:
- foo: bar
- foo2: bar2
a:
b: c
a:
foo: bar
b:
foo2: bar2
a:
b: c
c:
- foo: bar
- foo2: bar2
a:
b: c
9 changes: 9 additions & 0 deletions test/grammar/datatype/dict/mutual_ref_16/main.k
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
level0 = {
name = "apple"
level1_a = {
name = "orange"
}
level1_b = {
name = "pine" + name
}
}
6 changes: 6 additions & 0 deletions test/grammar/datatype/dict/mutual_ref_16/stdout.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
level0:
name: apple
level1_a:
name: orange
level1_b:
name: pineapple

0 comments on commit 59a5719

Please sign in to comment.