diff --git a/Cargo.toml b/Cargo.toml index f642e964..9c3acabd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,4 +3,6 @@ members = [ "full-moon", "full-moon-derive", + "full-moon-lua-types", + "full-moon-lua-types-derive", ] \ No newline at end of file diff --git a/full-moon-lua-types-derive/Cargo.toml b/full-moon-lua-types-derive/Cargo.toml new file mode 100644 index 00000000..a468da28 --- /dev/null +++ b/full-moon-lua-types-derive/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "full_moon_lua_types_derive" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1.0.43" +quote = "1.0.21" +syn = { version = "1.0.99" } diff --git a/full-moon-lua-types-derive/src/lib.rs b/full-moon-lua-types-derive/src/lib.rs new file mode 100644 index 00000000..d68e7bc1 --- /dev/null +++ b/full-moon-lua-types-derive/src/lib.rs @@ -0,0 +1,10 @@ +use proc_macro::TokenStream; + +extern crate proc_macro; + +mod lua_user_data; + +#[proc_macro_derive(LuaUserData, attributes(lua))] +pub fn derive_lua_user_data(input: TokenStream) -> TokenStream { + lua_user_data::derive(input) +} diff --git a/full-moon-lua-types-derive/src/lua_user_data.rs b/full-moon-lua-types-derive/src/lua_user_data.rs new file mode 100644 index 00000000..4408c276 --- /dev/null +++ b/full-moon-lua-types-derive/src/lua_user_data.rs @@ -0,0 +1,418 @@ +use proc_macro::TokenStream; +use quote::{format_ident, quote}; + +fn match_all proc_macro2::TokenStream>( + input_ident: &syn::Ident, + variants: &syn::punctuated::Punctuated, + case: F, +) -> proc_macro2::TokenStream { + let mut cases = Vec::new(); + + for variant in variants { + let ident = &variant.ident; + let result = case(ident); + + match variant.fields { + syn::Fields::Named(_) => { + cases.push(quote! { + #input_ident::#ident { .. } => { #result } + }); + } + + syn::Fields::Unnamed(_) => { + cases.push(quote! { + #input_ident::#ident(..) => { #result } + }); + } + + syn::Fields::Unit => { + cases.push(quote! { + #input_ident::#ident => { #result } + }); + } + } + } + + quote! { + match this { + #(#cases,)* + } + } +} + +pub fn derive(input: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(input as syn::DeriveInput); + + let input_enum = match &input.data { + syn::Data::Enum(input_enum) => input_enum, + _ => panic!("can only derive for enums"), + }; + + let input_ident = &input.ident; + + let match_kind = match_all(input_ident, &input_enum.variants, |variant| { + quote! { + stringify!(#variant) + } + }); + + let match_to_string = match_all(input_ident, &input_enum.variants, |variant| { + quote! { + format!("{}::{}", stringify!(#input_ident), stringify!(#variant)) + } + }); + + // TODO: Error for invalid names + let match_match = { + let mut cases = Vec::with_capacity(input_enum.variants.len()); + + for variant in &input_enum.variants { + let variant_ident = &variant.ident; + + match &variant.fields { + syn::Fields::Named(fields) => { + let fields = fields + .named + .iter() + .map(|field| &field.ident) + .collect::>(); + + cases.push(quote! { + #input_ident::#variant_ident { #(#fields),* } => { + let mut table = lua.create_table()?; + + #( + table.set(stringify!(#fields), #fields.prepare_for_lua(lua)?)?; + )* + + (stringify!(#variant_ident), table.to_lua_multi(lua)?) + } + }); + } + + syn::Fields::Unnamed(fields) => { + let fields = fields + .unnamed + .iter() + .enumerate() + .map(|(index, _)| format_ident!("_{index}")) + .collect::>(); + + cases.push(quote! { + #input_ident::#variant_ident(#(#fields),*) => { + let mut fields = Vec::new(); + + #( + fields.push(#fields.prepare_for_lua(lua)?); + )* + + (stringify!(#variant_ident), mlua::MultiValue::from_vec(fields)) + } + }); + } + + syn::Fields::Unit => { + cases.push(quote! { + #input_ident::#variant_ident => { + (stringify!(#variant_ident), ().to_lua_multi(lua)?) + } + }); + } + } + } + + quote! { + use mlua::ToLuaMulti; + + let (function_name, args) = match this { + #(#cases,)* + }; + + match table_or_name { + crate::either::Either::A(table) => { + match table.get::<_, Option>(function_name)? { + Some(mlua::Value::Function(function)) => { + function.call::<_, mlua::MultiValue>(args) + } + + Some(other) => { + Err(mlua::Error::external(format!( + "expected function for {}, got {}", + function_name, + other.type_name(), + ))) + } + + None => { + mlua::Value::Nil.to_lua_multi(lua) + } + } + } + + crate::either::Either::B(name) => { + if name == function_name { + Ok(args) + } else { + mlua::Value::Nil.to_lua_multi(lua) + } + } + } + } + }; + + let match_create_ast_node = { + let mut cases = Vec::with_capacity(input_enum.variants.len()); + + for variant in &input_enum.variants { + let variant_ident = &variant.ident; + + if variant_ident == "NonStandard" { + let fields = match &variant.fields { + syn::Fields::Named(_) => quote! { { .. } }, + syn::Fields::Unnamed(_) => quote! { ( .. ) }, + syn::Fields::Unit => quote! {}, + }; + + cases.push(quote! { + #input_ident::#variant_ident #fields => { + None + } + }); + + continue; + } + + match &variant.fields { + syn::Fields::Named(field) => { + let fields = field + .named + .iter() + .filter_map(|field| field.ident.as_ref()) + .map(quote::ToTokens::to_token_stream) + .collect::>(); + + let mut added_fields: Vec = Vec::new(); + + for attr in &variant.attrs { + if !attr.path.is_ident("lua") { + continue; + } + + let name_value = attr + .parse_args::() + .expect("expected name value for lua attribute"); + + if !name_value.path.is_ident("add_field") { + continue; + } + + added_fields.push( + syn::parse_str(&match name_value.lit { + syn::Lit::Str(lit_str) => lit_str.value(), + _ => panic!("expected string for add_field"), + }) + .unwrap(), + ); + } + + cases.push(quote! { + #input_ident::#variant_ident { #(#fields),* } => { + Some(full_moon::ast::#input_ident::#variant_ident { + #(#fields: #fields.create_ast_node()?,)* + #(#added_fields,)* + }) + } + }); + } + + syn::Fields::Unnamed(field) => { + let fields = field + .unnamed + .iter() + .enumerate() + .map(|(index, _)| format_ident!("_{index}")) + .collect::>(); + + let body = match variant + .attrs + .iter() + .filter_map(|attr| { + if !attr.path.is_ident("lua") { + return None; + } + + let name_value = attr + .parse_args::() + .expect("expected name value for #[lua(create_ast_node)]"); + + if name_value.path.is_ident("create_ast_node") { + Some(match name_value.lit { + syn::Lit::Str(lit_str) => lit_str.value(), + _ => panic!("expected string for #[lua(create_ast_node)]"), + }) + } else { + None + } + }) + .next() + { + Some(create_ast_node_attr) => { + syn::parse_str(&create_ast_node_attr).expect("expected valid rust code") + } + + None => { + quote! { + full_moon::ast::#input_ident::#variant_ident( + #(#fields.create_ast_node()?),* + ) + } + } + }; + + cases.push(quote! { + #input_ident::#variant_ident(#(#fields),*) => { + #body.into() + } + }); + } + + syn::Fields::Unit => { + cases.push(quote! { + #input_ident::#variant_ident => { + Some(full_moon::ast::#input_ident::#variant_ident) + } + }); + } + } + } + + quote! { + match self { + #(#cases,)* + } + } + }; + + // TODO: This is copy-paste from match_match + let match_expect = { + let mut cases = Vec::with_capacity(input_enum.variants.len()); + + for variant in &input_enum.variants { + let variant_ident = &variant.ident; + + match &variant.fields { + syn::Fields::Named(fields) => { + let fields = fields + .named + .iter() + .map(|field| &field.ident) + .collect::>(); + + cases.push(quote! { + #input_ident::#variant_ident { #(#fields),* } => { + let mut table = lua.create_table()?; + + #( + table.set(stringify!(#fields), #fields.prepare_for_lua(lua)?)?; + )* + + (stringify!(#variant_ident), table.to_lua_multi(lua)?) + } + }); + } + + syn::Fields::Unnamed(fields) => { + let fields = fields + .unnamed + .iter() + .enumerate() + .map(|(index, _)| format_ident!("_{index}")) + .collect::>(); + + cases.push(quote! { + #input_ident::#variant_ident(#(#fields),*) => { + let mut fields = Vec::new(); + + #( + fields.push(#fields.prepare_for_lua(lua)?); + )* + + (stringify!(#variant_ident), mlua::MultiValue::from_vec(fields)) + } + }); + } + + syn::Fields::Unit => { + cases.push(quote! { + #input_ident::#variant_ident => { + (stringify!(#variant_ident), ().to_lua_multi(lua)?) + } + }); + } + } + } + + quote! { + use mlua::ToLuaMulti; + + let (function_name, args) = match this { + #(#cases,)* + }; + + if function_name == variant { + Ok(args) + } else { + Err(mlua::Error::external(format!( + "expected {}, got {}", + variant, + function_name, + ))) + } + } + }; + + quote! { + impl mlua::UserData for #input_ident { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("kind", |_, this| { + Ok(#match_kind) + }); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + crate::mlua_util::add_core_metamethods_no_tostring(stringify!(#input_ident), methods); + crate::mlua_util::add_create_ast_node_methods(methods); + crate::mlua_util::add_print(methods); + crate::mlua_util::add_visit(methods); + + methods.add_meta_method(mlua::MetaMethod::ToString, |_, this, _: ()| { + Ok(#match_to_string) + }); + + methods.add_method("expect", |lua, this, variant: String| { + #match_expect + }); + + methods.add_method("match", |lua, this, value: mlua::Value| { + let table_or_name = crate::either::take_either::( + lua, + value, + "table of variants to callbacks", + "variant name", + )?; + + #match_match + }); + } + } + + impl crate::ast_traits::CreateAstNode for #input_ident { + type Node = full_moon::ast::#input_ident; + + fn create_ast_node(&self) -> Option { + #match_create_ast_node + } + } + } + .into() +} diff --git a/full-moon-lua-types/Cargo.toml b/full-moon-lua-types/Cargo.toml new file mode 100644 index 00000000..966a89d0 --- /dev/null +++ b/full-moon-lua-types/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "full-moon-lua-types" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +full_moon = { version = "0.16.2", path = "../full-moon" } +full_moon_lua_types_derive = { path = "../full-moon-lua-types-derive" } +mlua = { version = "0.8.3", features = ["luau", "send"] } +paste = "1.0.9" + +[features] +luau = ["full_moon/roblox"] diff --git a/full-moon-lua-types/src/ast_traits.rs b/full-moon-lua-types/src/ast_traits.rs new file mode 100644 index 00000000..28aa5195 --- /dev/null +++ b/full-moon-lua-types/src/ast_traits.rs @@ -0,0 +1,47 @@ +use std::sync::Arc; + +use mlua::ToLua; + +use crate::mlua_util::ArcLocked; + +pub trait CreateAstNode { + type Node; + + fn create_ast_node(&self) -> Option; +} + +impl CreateAstNode for Box { + type Node = Box; + + fn create_ast_node(&self) -> Option { + Some(Box::new((**self).create_ast_node()?)) + } +} + +impl CreateAstNode for Option { + type Node = T::Node; + + fn create_ast_node(&self) -> Option { + self.as_ref().and_then(|value| value.create_ast_node()) + } +} + +impl CreateAstNode for ArcLocked { + type Node = T::Node; + + fn create_ast_node(&self) -> Option { + Arc::clone(self).read().unwrap().create_ast_node() + } +} + +impl CreateAstNode for (T, U) { + type Node = (T::Node, U::Node); + + fn create_ast_node(&self) -> Option { + Some((self.0.create_ast_node()?, self.1.create_ast_node()?)) + } +} + +pub trait AstToLua { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result>; +} diff --git a/full-moon-lua-types/src/core.rs b/full-moon-lua-types/src/core.rs new file mode 100644 index 00000000..eb6c0e05 --- /dev/null +++ b/full-moon-lua-types/src/core.rs @@ -0,0 +1,1869 @@ +use std::sync::{Arc, RwLock}; + +use full_moon::{ast, node::Node, tokenizer}; +use full_moon_lua_types_derive::LuaUserData; +use mlua::{Table, ToLua, UserData}; + +use crate::{ + ast_traits::CreateAstNode, + mlua_util::{ + add_core_meta_methods, add_create_ast_node_methods, add_newindex_block, add_print, + add_visit, ArcLocked, + }, + prepare_for_lua::PrepareForLua, + shared::*, + AstToLua, +}; + +fn l(t: T) -> ArcLocked { + Arc::new(RwLock::new(t)) +} + +pub struct Ast { + nodes: ArcLocked, + eof: ArcLocked, +} + +impl From<&ast::Ast> for Ast { + fn from(ast: &ast::Ast) -> Self { + Ast { + nodes: l(Block::new(ast.nodes())), + eof: l(TokenReference::new(ast.eof())), + } + } +} + +impl UserData for Ast { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("eof", |_, Ast { eof, .. }| Ok(eof.clone())); + fields.add_field_method_get("nodes", |_, Ast { nodes, .. }| Ok(nodes.clone())); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("Ast", methods); + + crate::visitor::add_visit_with_visitor(methods, |ast, visitor| { + use full_moon::visitors::Visitor; + visitor.visit_ast(&ast); + }); + } +} + +impl CreateAstNode for Ast { + type Node = ast::Ast; + + fn create_ast_node(&self) -> Option { + Some( + ast::Ast::from_tokens(vec![tokenizer::Token::new(tokenizer::TokenType::Eof)]) + .unwrap() + .with_eof(self.eof.create_ast_node()?) + .with_nodes(self.nodes.create_ast_node()?), + ) + } +} + +impl AstToLua for ast::Ast { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + Ast::from(self).to_lua(lua) + } +} + +pub struct Assignment { + var_list: ArcLocked>, + equal_token: ArcLocked, + expr_list: ArcLocked>, +} + +impl Assignment { + pub fn new(assignment: &ast::Assignment) -> Self { + Assignment { + var_list: l(Punctuated::map_from_punctuated( + assignment.variables(), + Var::new, + )), + equal_token: l(TokenReference::new(assignment.equal_token())), + expr_list: l(Punctuated::map_from_punctuated( + assignment.expressions(), + Expression::new, + )), + } + } +} + +impl UserData for Assignment { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("variables", |_, Assignment { var_list, .. }| { + Ok(var_list.clone()) + }); + + fields.add_field_method_get("equal_token", |_, Assignment { equal_token, .. }| { + Ok(equal_token.clone()) + }); + + fields.add_field_method_get("expressions", |_, Assignment { expr_list, .. }| { + Ok(expr_list.clone()) + }); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("Assignment", methods); + add_create_ast_node_methods(methods); + add_print(methods); + } +} + +impl CreateAstNode for Assignment { + type Node = ast::Assignment; + + fn create_ast_node(&self) -> Option { + Some( + ast::Assignment::new( + self.var_list.create_ast_node()?, + self.expr_list.create_ast_node()?, + ) + .with_equal_token(self.equal_token.create_ast_node()?), + ) + } +} + +impl AstToLua for ast::Assignment { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + Assignment::new(self).to_lua(lua) + } +} + +#[derive(Clone, LuaUserData)] +pub enum BinOp { + And(ArcLocked), + Caret(ArcLocked), + GreaterThan(ArcLocked), + GreaterThanEqual(ArcLocked), + LessThan(ArcLocked), + LessThanEqual(ArcLocked), + Minus(ArcLocked), + Or(ArcLocked), + Percent(ArcLocked), + Plus(ArcLocked), + Slash(ArcLocked), + Star(ArcLocked), + TildeEqual(ArcLocked), + TwoDots(ArcLocked), + TwoEqual(ArcLocked), +} + +impl BinOp { + pub fn new(bin_op: &ast::BinOp) -> Self { + match bin_op { + ast::BinOp::And(token) => BinOp::And(l(TokenReference::new(&token))), + ast::BinOp::Caret(token) => BinOp::Caret(l(TokenReference::new(&token))), + ast::BinOp::GreaterThan(token) => BinOp::GreaterThan(l(TokenReference::new(&token))), + ast::BinOp::GreaterThanEqual(token) => { + BinOp::GreaterThanEqual(l(TokenReference::new(&token))) + } + ast::BinOp::LessThan(token) => BinOp::LessThan(l(TokenReference::new(&token))), + ast::BinOp::LessThanEqual(token) => { + BinOp::LessThanEqual(l(TokenReference::new(&token))) + } + ast::BinOp::Minus(token) => BinOp::Minus(l(TokenReference::new(&token))), + ast::BinOp::Or(token) => BinOp::Or(l(TokenReference::new(&token))), + ast::BinOp::Percent(token) => BinOp::Percent(l(TokenReference::new(&token))), + ast::BinOp::Plus(token) => BinOp::Plus(l(TokenReference::new(&token))), + ast::BinOp::Slash(token) => BinOp::Slash(l(TokenReference::new(&token))), + ast::BinOp::Star(token) => BinOp::Star(l(TokenReference::new(&token))), + ast::BinOp::TildeEqual(token) => BinOp::TildeEqual(l(TokenReference::new(&token))), + ast::BinOp::TwoDots(token) => BinOp::TwoDots(l(TokenReference::new(&token))), + ast::BinOp::TwoEqual(token) => BinOp::TwoEqual(l(TokenReference::new(&token))), + other => panic!("unimplemented BinOp: {other:?}"), + } + } +} + +impl AstToLua for ast::BinOp { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + BinOp::new(self).to_lua(lua) + } +} + +#[derive(Clone)] +pub struct Block { + stmts: Vec<(ArcLocked, Option>)>, + last_stmt: Option<(ArcLocked, Option>)>, +} + +impl Block { + pub fn new(block: &ast::Block) -> Self { + Block { + stmts: block + .stmts_with_semicolon() + .map(|(stmt, token)| { + ( + l(Stmt::new(stmt)), + token.as_ref().map(TokenReference::new).map(l), + ) + }) + .collect(), + + last_stmt: block + .last_stmt_with_semicolon() + .as_ref() + .map(|(last_stmt, token)| { + ( + l(LastStmt::new(last_stmt)), + token.as_ref().map(TokenReference::new).map(l), + ) + }), + } + } +} + +impl UserData for Block { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("stmts", |_, Block { stmts, .. }| { + Ok(stmts + .iter() + .map(|(stmt, _)| stmt.clone()) + .collect::>()) + }); + + fields.add_field_method_get("last_stmt", |_, Block { last_stmt, .. }| { + Ok(last_stmt.as_ref().map(|(last_stmt, _)| last_stmt.clone())) + }); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("Block", methods); + add_create_ast_node_methods(methods); + add_print(methods); + } +} + +impl CreateAstNode for Block { + type Node = ast::Block; + + fn create_ast_node(&self) -> Option { + Some( + ast::Block::new() + .with_stmts( + self.stmts + .iter() + .map(|(stmt, token)| { + Some((stmt.create_ast_node()?, token.create_ast_node())) + }) + .collect::>>()?, + ) + .with_last_stmt(self.last_stmt.as_ref().and_then(|(last_stmt, token)| { + Some((last_stmt.create_ast_node()?, token.create_ast_node())) + })), + ) + } +} + +impl AstToLua for ast::Block { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + Block::new(self).to_lua(lua) + } +} + +#[derive(Clone, LuaUserData)] +pub enum Call { + AnonymousCall(ArcLocked), + MethodCall(ArcLocked), +} + +impl Call { + pub fn new(call: &ast::Call) -> Self { + match call { + ast::Call::AnonymousCall(function_args) => { + Call::AnonymousCall(l(FunctionArgs::new(function_args))) + } + ast::Call::MethodCall(method_call) => Call::MethodCall(l(MethodCall::new(method_call))), + other => panic!("unimplemented Call: {other:?}"), + } + } +} + +impl AstToLua for ast::Call { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + Call::new(self).to_lua(lua) + } +} + +#[derive(Clone)] +pub struct Do { + do_token: ArcLocked, + block: ArcLocked, + end_token: ArcLocked, +} + +impl Do { + pub fn new(do_: &ast::Do) -> Self { + Do { + do_token: l(TokenReference::new(do_.do_token())), + block: l(Block::new(do_.block())), + end_token: l(TokenReference::new(do_.end_token())), + } + } +} + +impl UserData for Do { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("do_token", |_, Do { do_token, .. }| Ok(do_token.clone())); + + fields.add_field_method_get("block", |_, Do { block, .. }| Ok(block.clone())); + + fields.add_field_method_get("end_token", |_, Do { end_token, .. }| Ok(end_token.clone())); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("Do", methods); + add_create_ast_node_methods(methods); + add_print(methods); + } +} + +impl CreateAstNode for Do { + type Node = ast::Do; + + fn create_ast_node(&self) -> Option { + Some( + ast::Do::new() + .with_block(self.block.create_ast_node()?) + .with_do_token(self.do_token.create_ast_node()?) + .with_end_token(self.end_token.create_ast_node()?), + ) + } +} + +impl AstToLua for ast::Do { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + Do::new(self).to_lua(lua) + } +} + +#[derive(Clone)] +pub struct ElseIf { + else_if_token: ArcLocked, + condition: ArcLocked, + then_token: ArcLocked, + block: ArcLocked, +} + +impl ElseIf { + pub fn new(else_if: &ast::ElseIf) -> Self { + ElseIf { + else_if_token: l(TokenReference::new(else_if.else_if_token())), + condition: l(Expression::new(else_if.condition())), + then_token: l(TokenReference::new(else_if.then_token())), + block: l(Block::new(else_if.block())), + } + } +} + +impl UserData for ElseIf { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("else_if_token", |_, ElseIf { else_if_token, .. }| { + Ok(else_if_token.clone()) + }); + + fields.add_field_method_get("condition", |_, ElseIf { condition, .. }| { + Ok(condition.clone()) + }); + + fields.add_field_method_get("then_token", |_, ElseIf { then_token, .. }| { + Ok(then_token.clone()) + }); + + fields.add_field_method_get("block", |_, ElseIf { block, .. }| Ok(block.clone())); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("ElseIf", methods); + add_create_ast_node_methods(methods); + add_print(methods); + } +} + +impl CreateAstNode for ElseIf { + type Node = ast::ElseIf; + + fn create_ast_node(&self) -> Option { + Some( + ast::ElseIf::new(self.condition.create_ast_node()?) + .with_block(self.block.create_ast_node()?) + .with_else_if_token(self.else_if_token.create_ast_node()?) + .with_then_token(self.then_token.create_ast_node()?), + ) + } +} + +impl AstToLua for ast::ElseIf { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + ElseIf::new(self).to_lua(lua) + } +} + +#[derive(Clone, LuaUserData)] +pub enum Expression { + BinaryOperator { + lhs: Box>, + binop: ArcLocked, + rhs: Box>, + }, + + Parentheses { + contained: ArcLocked, + expression: Box>, + }, + + UnaryOperator { + unop: ArcLocked, + expression: Box>, + }, + + #[cfg_attr(feature = "luau", lua(add_field = "type_assertion: None"))] + Value { value: Box> }, +} + +impl Expression { + pub fn new(expression: &ast::Expression) -> Self { + match expression { + ast::Expression::BinaryOperator { lhs, binop, rhs } => Expression::BinaryOperator { + lhs: Box::new(l(Expression::new(lhs))), + binop: l(BinOp::new(binop)), + rhs: Box::new(l(Expression::new(rhs))), + }, + + ast::Expression::Parentheses { + contained, + expression, + } => Expression::Parentheses { + contained: l(ContainedSpan::new(contained)), + expression: Box::new(l(Expression::new(expression))), + }, + + ast::Expression::UnaryOperator { unop, expression } => Expression::UnaryOperator { + unop: l(UnOp::new(unop)), + expression: Box::new(l(Expression::new(expression))), + }, + + ast::Expression::Value { value, .. } => Expression::Value { + value: Box::new(l(Value::new(value))), + }, + + other => panic!("unimplemented Expression: {other:?}"), + } + } +} + +impl AstToLua for ast::Expression { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + Expression::new(self).to_lua(lua) + } +} + +#[derive(Clone, LuaUserData)] +pub enum FunctionArgs { + Parentheses { + parentheses: ArcLocked, + arguments: ArcLocked>, + }, + + String(ArcLocked), + + TableConstructor(ArcLocked), +} + +impl FunctionArgs { + pub fn new(function_args: &ast::FunctionArgs) -> Self { + match function_args { + ast::FunctionArgs::Parentheses { + parentheses, + arguments, + } => FunctionArgs::Parentheses { + parentheses: l(ContainedSpan::new(parentheses)), + arguments: l(Punctuated::map_from_punctuated(arguments, Expression::new)), + }, + + ast::FunctionArgs::String(token) => FunctionArgs::String(l(TokenReference::new(token))), + + ast::FunctionArgs::TableConstructor(table_constructor) => { + FunctionArgs::TableConstructor(l(TableConstructor::new(table_constructor))) + } + + other => panic!("unimplemented FunctionArgs: {other:?}"), + } + } +} + +impl AstToLua for ast::FunctionArgs { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + FunctionArgs::new(self).to_lua(lua) + } +} + +#[derive(Clone)] +pub struct FunctionBody { + parameters_parentheses: ArcLocked, + parameters: ArcLocked>, + block: ArcLocked, + end_token: ArcLocked, +} + +impl FunctionBody { + pub fn new(function_body: &ast::FunctionBody) -> Self { + FunctionBody { + parameters_parentheses: l(ContainedSpan::new(function_body.parameters_parentheses())), + + parameters: l(Punctuated::map_from_punctuated( + function_body.parameters(), + Parameter::new, + )), + + block: l(Block::new(function_body.block())), + end_token: l(TokenReference::new(function_body.end_token())), + } + } +} + +impl UserData for FunctionBody { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get( + "parameters_parentheses", + |_, + FunctionBody { + parameters_parentheses, + .. + }| { Ok(parameters_parentheses.clone()) }, + ); + + fields.add_field_method_get("parameters", |_, FunctionBody { parameters, .. }| { + Ok(parameters.clone()) + }); + + fields.add_field_method_get("block", |_, FunctionBody { block, .. }| Ok(block.clone())); + + fields.add_field_method_get("end_token", |_, FunctionBody { end_token, .. }| { + Ok(end_token.clone()) + }); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("FunctionBody", methods); + add_create_ast_node_methods(methods); + add_print(methods); + } +} + +impl CreateAstNode for FunctionBody { + type Node = ast::FunctionBody; + + fn create_ast_node(&self) -> Option { + Some( + ast::FunctionBody::new() + .with_block(self.block.create_ast_node()?) + .with_end_token(self.end_token.create_ast_node()?) + .with_parameters(self.parameters.create_ast_node()?) + .with_parameters_parentheses(self.parameters_parentheses.create_ast_node()?), + ) + } +} + +impl AstToLua for ast::FunctionBody { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + FunctionBody::new(self).to_lua(lua) + } +} + +#[derive(Clone, LuaUserData)] +pub enum LastStmt { + Break(ArcLocked), + #[cfg(feature = "luau")] + Continue(ArcLocked), + Return(ArcLocked), +} + +impl LastStmt { + pub fn new(last_stmt: &ast::LastStmt) -> Self { + match last_stmt { + ast::LastStmt::Break(break_token) => { + LastStmt::Break(l(TokenReference::new(break_token))) + } + + #[cfg(feature = "luau")] + ast::LastStmt::Continue(continue_token) => { + LastStmt::Continue(l(TokenReference::new(continue_token))) + } + + ast::LastStmt::Return(return_token) => LastStmt::Return(l(Return::new(return_token))), + + _ => unimplemented!("unexpected LastStmt variant: {last_stmt:#?}"), + } + } +} + +impl AstToLua for ast::LastStmt { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + LastStmt::new(self).to_lua(lua) + } +} + +#[derive(Clone, LuaUserData)] +pub enum Field { + ExpressionKey { + brackets: ArcLocked, + key: ArcLocked, + equal: ArcLocked, + value: ArcLocked, + }, + + NameKey { + key: ArcLocked, + equal: ArcLocked, + value: ArcLocked, + }, + + NoKey(ArcLocked), +} + +impl Field { + pub fn new(field: &ast::Field) -> Self { + match field { + ast::Field::ExpressionKey { + brackets, + key, + equal, + value, + } => Field::ExpressionKey { + brackets: l(ContainedSpan::new(brackets)), + key: l(Expression::new(key)), + equal: l(TokenReference::new(equal)), + value: l(Expression::new(value)), + }, + + ast::Field::NameKey { key, equal, value } => Field::NameKey { + key: l(TokenReference::new(key)), + equal: l(TokenReference::new(equal)), + value: l(Expression::new(value)), + }, + + ast::Field::NoKey(expression) => Field::NoKey(l(Expression::new(expression))), + + other => panic!("unimplemented Field: {other:?}"), + } + } +} + +impl AstToLua for ast::Field { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + Field::new(self).to_lua(lua) + } +} + +#[derive(Clone)] +pub struct FunctionCall { + prefix: ArcLocked, + suffixes: Vec>, +} + +impl FunctionCall { + pub fn new(function_call: &ast::FunctionCall) -> Self { + FunctionCall { + prefix: l(Prefix::new(function_call.prefix())), + suffixes: function_call.suffixes().map(Suffix::new).map(l).collect(), + } + } +} + +impl UserData for FunctionCall { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get( + "prefix", + |_, FunctionCall { prefix, .. }| Ok(prefix.clone()), + ); + + fields.add_field_method_get("suffixes", |_, FunctionCall { suffixes, .. }| { + Ok(suffixes.clone()) + }); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("FunctionCall", methods); + add_create_ast_node_methods(methods); + add_print(methods); + } +} + +impl CreateAstNode for FunctionCall { + type Node = ast::FunctionCall; + + fn create_ast_node(&self) -> Option { + Some( + ast::FunctionCall::new(self.prefix.create_ast_node()?).with_suffixes( + self.suffixes + .iter() + .map(|suffix| suffix.read().unwrap().create_ast_node()) + .collect::>>()?, + ), + ) + } +} + +impl AstToLua for ast::FunctionCall { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + FunctionCall::new(self).to_lua(lua) + } +} + +#[derive(Clone)] +pub struct FunctionDeclaration { + function_token: ArcLocked, + name: ArcLocked, + body: ArcLocked, +} + +impl FunctionDeclaration { + pub fn new(function_declaration: &ast::FunctionDeclaration) -> Self { + FunctionDeclaration { + function_token: l(TokenReference::new(function_declaration.function_token())), + name: l(FunctionName::new(function_declaration.name())), + body: l(FunctionBody::new(function_declaration.body())), + } + } +} + +impl UserData for FunctionDeclaration { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get( + "function_token", + |_, FunctionDeclaration { function_token, .. }| Ok(function_token.clone()), + ); + + fields.add_field_method_get("name", |_, FunctionDeclaration { name, .. }| { + Ok(name.clone()) + }); + + fields.add_field_method_get("body", |_, FunctionDeclaration { body, .. }| { + Ok(body.clone()) + }); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("FunctionDeclaration", methods); + add_create_ast_node_methods(methods); + add_print(methods); + } +} + +impl CreateAstNode for FunctionDeclaration { + type Node = ast::FunctionDeclaration; + + fn create_ast_node(&self) -> Option { + Some( + ast::FunctionDeclaration::new(self.name.create_ast_node()?) + .with_body(self.body.create_ast_node()?) + .with_function_token(self.function_token.create_ast_node()?), + ) + } +} + +impl AstToLua for ast::FunctionDeclaration { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + FunctionDeclaration::new(self).to_lua(lua) + } +} + +#[derive(Clone)] +pub struct FunctionName { + names: ArcLocked>, + colon_name: Option>, +} + +impl FunctionName { + pub fn new(function_name: &ast::FunctionName) -> Self { + FunctionName { + names: l(Punctuated::map_from_punctuated( + function_name.names(), + TokenReference::new, + )), + + colon_name: match (function_name.method_colon(), function_name.method_name()) { + (Some(colon), Some(name)) => { + Some(l((TokenReference::new(colon), TokenReference::new(name)))) + } + + _ => None, + }, + } + } +} + +impl UserData for FunctionName { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("names", |_, FunctionName { names, .. }| Ok(names.clone())); + + fields.add_field_method_get("method_colon", |_, FunctionName { colon_name, .. }| { + Ok(colon_name + .as_ref() + .map(|lock| Arc::clone(lock).read().unwrap().0.clone())) + }); + + fields.add_field_method_get("method_name", |_, FunctionName { colon_name, .. }| { + Ok(colon_name + .as_ref() + .map(|lock| Arc::clone(lock).read().unwrap().1.clone())) + }); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("FunctionName", methods); + add_create_ast_node_methods(methods); + add_print(methods); + } +} + +impl CreateAstNode for FunctionName { + type Node = ast::FunctionName; + + fn create_ast_node(&self) -> Option { + Some( + ast::FunctionName::new(self.names.create_ast_node()?).with_method( + self.colon_name.as_ref().and_then(|colon_name_arc| { + let colon_name_arc = Arc::clone(colon_name_arc); + let lock = colon_name_arc.read().unwrap(); + Some((lock.0.create_ast_node()?, lock.1.create_ast_node()?)) + }), + ), + ) + } +} + +impl AstToLua for ast::FunctionName { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + FunctionName::new(self).to_lua(lua) + } +} + +#[derive(Clone)] +pub struct GenericFor { + for_token: ArcLocked, + names: ArcLocked>, + in_token: ArcLocked, + expr_list: ArcLocked>, + do_token: ArcLocked, + block: ArcLocked, + end_token: ArcLocked, +} + +impl GenericFor { + pub fn new(generic_for: &ast::GenericFor) -> Self { + GenericFor { + for_token: l(TokenReference::new(generic_for.for_token())), + names: l(Punctuated::map_from_punctuated( + generic_for.names(), + TokenReference::new, + )), + in_token: l(TokenReference::new(generic_for.in_token())), + expr_list: l(Punctuated::map_from_punctuated( + generic_for.expressions(), + Expression::new, + )), + do_token: l(TokenReference::new(generic_for.do_token())), + block: l(Block::new(generic_for.block())), + end_token: l(TokenReference::new(generic_for.end_token())), + } + } +} + +impl UserData for GenericFor { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("for_token", |_, GenericFor { for_token, .. }| { + Ok(for_token.clone()) + }); + + fields.add_field_method_get("names", |_, GenericFor { names, .. }| Ok(names.clone())); + + fields.add_field_method_get("in_token", |_, GenericFor { in_token, .. }| { + Ok(in_token.clone()) + }); + + fields.add_field_method_get("expressions", |_, GenericFor { expr_list, .. }| { + Ok(expr_list.clone()) + }); + + fields.add_field_method_get("do_token", |_, GenericFor { do_token, .. }| { + Ok(do_token.clone()) + }); + + fields.add_field_method_get("block", |_, GenericFor { block, .. }| Ok(block.clone())); + + fields.add_field_method_get("end_token", |_, GenericFor { end_token, .. }| { + Ok(end_token.clone()) + }); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("GenericFor", methods); + add_create_ast_node_methods(methods); + add_print(methods); + } +} + +impl CreateAstNode for GenericFor { + type Node = ast::GenericFor; + + fn create_ast_node(&self) -> Option { + Some( + ast::GenericFor::new( + self.names.create_ast_node()?, + self.expr_list.create_ast_node()?, + ) + .with_for_token(self.for_token.create_ast_node()?) + .with_in_token(self.in_token.create_ast_node()?) + .with_do_token(self.do_token.create_ast_node()?) + .with_block(self.block.create_ast_node()?) + .with_end_token(self.end_token.create_ast_node()?), + ) + } +} + +impl AstToLua for ast::GenericFor { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + GenericFor::new(self).to_lua(lua) + } +} + +#[derive(Clone)] +pub struct If { + if_token: ArcLocked, + condition: ArcLocked, + then_token: ArcLocked, + block: ArcLocked, + else_if: Option>>, + else_token: Option>, + else_block: Option>, + end_token: ArcLocked, +} + +impl If { + pub fn new(if_node: &ast::If) -> Self { + If { + if_token: l(TokenReference::new(if_node.if_token())), + condition: l(Expression::new(if_node.condition())), + then_token: l(TokenReference::new(if_node.then_token())), + block: l(Block::new(if_node.block())), + else_if: if_node + .else_if() + .map(|else_if| else_if.iter().map(ElseIf::new).map(l).collect()), + else_token: if_node.else_token().map(TokenReference::new).map(l), + else_block: if_node.else_block().map(Block::new).map(l), + end_token: l(TokenReference::new(if_node.end_token())), + } + } +} + +impl UserData for If { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("if_token", |_, If { if_token, .. }| Ok(if_token.clone())); + + fields.add_field_method_get("condition", |_, If { condition, .. }| Ok(condition.clone())); + + fields.add_field_method_get("then_token", |_, If { then_token, .. }| { + Ok(then_token.clone()) + }); + + fields.add_field_method_get("block", |_, If { block, .. }| Ok(block.clone())); + + fields.add_field_method_get("else_if", |_, If { else_if, .. }| { + Ok(else_if + .as_ref() + .map(|else_if| else_if.iter().map(Arc::clone).collect::>())) + }); + + fields.add_field_method_get("else_token", |_, If { else_token, .. }| { + Ok(else_token.clone()) + }); + + fields.add_field_method_get("else_block", |_, If { else_block, .. }| { + Ok(else_block.clone()) + }); + + fields.add_field_method_get("end_token", |_, If { end_token, .. }| Ok(end_token.clone())); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("If", methods); + add_create_ast_node_methods(methods); + add_print(methods); + } +} + +impl CreateAstNode for If { + type Node = ast::If; + + fn create_ast_node(&self) -> Option { + Some( + ast::If::new(self.condition.create_ast_node()?) + .with_if_token(self.if_token.create_ast_node()?) + .with_then_token(self.then_token.create_ast_node()?) + .with_block(self.block.create_ast_node()?) + .with_else_if(self.else_if.as_ref().and_then(|else_if| { + else_if + .iter() + .map(|else_if| else_if.read().unwrap().create_ast_node()) + .collect::>>() + })) + .with_else_token(self.else_token.create_ast_node()) + .with_else(self.else_block.create_ast_node()) + .with_end_token(self.end_token.create_ast_node()?), + ) + } +} + +impl AstToLua for ast::If { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + If::new(self).to_lua(lua) + } +} + +#[derive(Clone, LuaUserData)] +pub enum Index { + Brackets { + brackets: ArcLocked, + expression: ArcLocked, + }, + + Dot { + dot: ArcLocked, + name: ArcLocked, + }, +} + +impl Index { + pub fn new(index: &ast::Index) -> Self { + match index { + ast::Index::Brackets { + brackets, + expression, + } => Index::Brackets { + brackets: l(ContainedSpan::new(brackets)), + expression: l(Expression::new(expression)), + }, + + ast::Index::Dot { dot, name } => Index::Dot { + dot: l(TokenReference::new(dot)), + name: l(TokenReference::new(name)), + }, + + other => panic!("unimplemented Index: {other:?}"), + } + } +} + +impl AstToLua for ast::Index { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + Index::new(self).to_lua(lua) + } +} + +#[derive(Clone)] +pub struct LocalAssignment { + local_token: ArcLocked, + name_list: ArcLocked>, + equal_token: Option>, + expr_list: ArcLocked>, +} + +impl LocalAssignment { + pub fn new(local_assignment: &ast::LocalAssignment) -> Self { + LocalAssignment { + local_token: l(TokenReference::new(local_assignment.local_token())), + name_list: l(Punctuated::map_from_punctuated( + local_assignment.names(), + TokenReference::new, + )), + equal_token: local_assignment + .equal_token() + .map(TokenReference::new) + .map(l), + expr_list: l(Punctuated::map_from_punctuated( + local_assignment.expressions(), + Expression::new, + )), + } + } +} + +impl UserData for LocalAssignment { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("local_token", |_, LocalAssignment { local_token, .. }| { + Ok(local_token.clone()) + }); + + fields.add_field_method_get("names", |_, LocalAssignment { name_list, .. }| { + Ok(name_list.clone()) + }); + + fields.add_field_method_get("equal_token", |_, LocalAssignment { equal_token, .. }| { + Ok(equal_token.clone()) + }); + + fields.add_field_method_get("expressions", |_, LocalAssignment { expr_list, .. }| { + Ok(expr_list.clone()) + }); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("LocalAssignment", methods); + add_create_ast_node_methods(methods); + add_print(methods); + } +} + +impl CreateAstNode for LocalAssignment { + type Node = ast::LocalAssignment; + + fn create_ast_node(&self) -> Option { + Some( + ast::LocalAssignment::new(self.name_list.create_ast_node()?) + .with_expressions(self.expr_list.create_ast_node()?) + .with_local_token(self.local_token.create_ast_node()?) + .with_equal_token(self.equal_token.create_ast_node()), + ) + } +} + +impl AstToLua for ast::LocalAssignment { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + LocalAssignment::new(self).to_lua(lua) + } +} + +#[derive(Clone)] +pub struct LocalFunction { + local_token: ArcLocked, + function_token: ArcLocked, + name: ArcLocked, + body: ArcLocked, +} + +impl LocalFunction { + pub fn new(local_function: &ast::LocalFunction) -> Self { + LocalFunction { + local_token: l(TokenReference::new(local_function.local_token())), + function_token: l(TokenReference::new(local_function.function_token())), + name: l(TokenReference::new(local_function.name())), + body: l(FunctionBody::new(local_function.body())), + } + } +} + +impl UserData for LocalFunction { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("local_token", |_, LocalFunction { local_token, .. }| { + Ok(local_token.clone()) + }); + + fields.add_field_method_get( + "function_token", + |_, LocalFunction { function_token, .. }| Ok(function_token.clone()), + ); + + fields.add_field_method_get("name", |_, LocalFunction { name, .. }| Ok(name.clone())); + + fields.add_field_method_get("body", |_, LocalFunction { body, .. }| Ok(body.clone())); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("LocalFunction", methods); + add_create_ast_node_methods(methods); + add_print(methods); + } +} + +impl CreateAstNode for LocalFunction { + type Node = ast::LocalFunction; + + fn create_ast_node(&self) -> Option { + Some( + ast::LocalFunction::new(self.name.create_ast_node()?) + .with_body(self.body.create_ast_node()?) + .with_local_token(self.local_token.create_ast_node()?) + .with_function_token(self.function_token.create_ast_node()?), + ) + } +} + +impl AstToLua for ast::LocalFunction { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + LocalFunction::new(self).to_lua(lua) + } +} + +#[derive(Clone)] +pub struct MethodCall { + colon_token: ArcLocked, + name: ArcLocked, + args: ArcLocked, +} + +impl MethodCall { + pub fn new(method_call: &ast::MethodCall) -> Self { + MethodCall { + colon_token: l(TokenReference::new(method_call.colon_token())), + name: l(TokenReference::new(method_call.name())), + args: l(FunctionArgs::new(method_call.args())), + } + } +} + +impl UserData for MethodCall { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("colon_token", |_, MethodCall { colon_token, .. }| { + Ok(colon_token.clone()) + }); + + fields.add_field_method_get("name", |_, MethodCall { name, .. }| Ok(name.clone())); + + fields.add_field_method_get("args", |_, MethodCall { args, .. }| Ok(args.clone())); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("MethodCall", methods); + add_create_ast_node_methods(methods); + add_print(methods); + } +} + +impl CreateAstNode for MethodCall { + type Node = ast::MethodCall; + + fn create_ast_node(&self) -> Option { + Some( + ast::MethodCall::new(self.name.create_ast_node()?, self.args.create_ast_node()?) + .with_colon_token(self.colon_token.create_ast_node()?), + ) + } +} + +impl AstToLua for ast::MethodCall { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + MethodCall::new(self).to_lua(lua) + } +} + +#[derive(Clone)] +pub struct NumericFor { + for_token: ArcLocked, + index_variable: ArcLocked, + equal_token: ArcLocked, + start: ArcLocked, + start_end_comma: ArcLocked, + end: ArcLocked, + end_step_comma: Option>, + step: Option>, + do_token: ArcLocked, + block: ArcLocked, + end_token: ArcLocked, +} + +impl NumericFor { + pub fn new(numeric_for: &ast::NumericFor) -> Self { + NumericFor { + for_token: l(TokenReference::new(numeric_for.for_token())), + index_variable: l(TokenReference::new(numeric_for.index_variable())), + equal_token: l(TokenReference::new(numeric_for.equal_token())), + start: l(Expression::new(numeric_for.start())), + start_end_comma: l(TokenReference::new(numeric_for.start_end_comma())), + end: l(Expression::new(numeric_for.end())), + end_step_comma: numeric_for.end_step_comma().map(TokenReference::new).map(l), + step: numeric_for.step().map(Expression::new).map(l), + do_token: l(TokenReference::new(numeric_for.do_token())), + block: l(Block::new(numeric_for.block())), + end_token: l(TokenReference::new(numeric_for.end_token())), + } + } +} + +impl UserData for NumericFor { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("for_token", |_, NumericFor { for_token, .. }| { + Ok(for_token.clone()) + }); + + fields.add_field_method_get("index_variable", |_, NumericFor { index_variable, .. }| { + Ok(index_variable.clone()) + }); + + fields.add_field_method_get("equal_token", |_, NumericFor { equal_token, .. }| { + Ok(equal_token.clone()) + }); + + fields.add_field_method_get("start", |_, NumericFor { start, .. }| Ok(start.clone())); + + fields.add_field_method_get( + "start_end_comma", + |_, + NumericFor { + start_end_comma, .. + }| Ok(start_end_comma.clone()), + ); + + fields.add_field_method_get("end", |_, NumericFor { end, .. }| Ok(end.clone())); + + fields.add_field_method_get("end_step_comma", |_, NumericFor { end_step_comma, .. }| { + Ok(end_step_comma.clone()) + }); + + fields.add_field_method_get("step", |_, NumericFor { step, .. }| Ok(step.clone())); + + fields.add_field_method_get("do_token", |_, NumericFor { do_token, .. }| { + Ok(do_token.clone()) + }); + + fields.add_field_method_get("block", |_, NumericFor { block, .. }| Ok(block.clone())); + + fields.add_field_method_get("end_token", |_, NumericFor { end_token, .. }| { + Ok(end_token.clone()) + }); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("NumericFor", methods); + add_create_ast_node_methods(methods); + add_print(methods); + } +} + +impl CreateAstNode for NumericFor { + type Node = ast::NumericFor; + + fn create_ast_node(&self) -> Option { + Some( + ast::NumericFor::new( + self.index_variable.create_ast_node()?, + self.start.create_ast_node()?, + self.end.create_ast_node()?, + ) + .with_step(self.step.create_ast_node()) + .with_block(self.block.create_ast_node()?) + .with_end_token(self.end_token.create_ast_node()?) + .with_start_end_comma(self.start_end_comma.create_ast_node()?) + .with_end_step_comma(self.end_step_comma.create_ast_node()) + .with_for_token(self.for_token.create_ast_node()?) + .with_equal_token(self.equal_token.create_ast_node()?) + .with_do_token(self.do_token.create_ast_node()?), + ) + } +} + +impl AstToLua for ast::NumericFor { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + NumericFor::new(self).to_lua(lua) + } +} + +#[derive(Clone, LuaUserData)] +pub enum Parameter { + Ellipse(ArcLocked), + Name(ArcLocked), +} + +impl Parameter { + pub fn new(parameter: &ast::Parameter) -> Self { + match parameter { + ast::Parameter::Ellipse(ellipse_token) => { + Parameter::Ellipse(l(TokenReference::new(ellipse_token))) + } + + ast::Parameter::Name(name_token) => Parameter::Name(l(TokenReference::new(name_token))), + + _ => unimplemented!("unexpected Parameter variant: {parameter:#?}"), + } + } +} + +impl AstToLua for ast::Parameter { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + Parameter::new(self).to_lua(lua) + } +} + +#[derive(Clone, LuaUserData)] +pub enum Prefix { + Expression(ArcLocked), + Name(ArcLocked), +} + +impl Prefix { + pub fn new(prefix: &ast::Prefix) -> Self { + match prefix { + ast::Prefix::Expression(expr) => Prefix::Expression(l(Expression::new(expr))), + ast::Prefix::Name(name) => Prefix::Name(l(TokenReference::new(name))), + other => unimplemented!("unexpected Prefix variant: {other:?}"), + } + } +} + +impl AstToLua for ast::Prefix { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + Prefix::new(self).to_lua(lua) + } +} + +#[derive(Clone)] +pub struct Return { + token: ArcLocked, + returns: ArcLocked>, +} + +impl Return { + pub fn new(return_token: &ast::Return) -> Self { + Return { + token: l(TokenReference::new(return_token.token())), + returns: l(Punctuated::map_from_punctuated( + return_token.returns(), + Expression::new, + )), + } + } +} + +impl UserData for Return { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("token", |_, Return { token, .. }| Ok(token.clone())); + + fields.add_field_method_get("returns", |_, Return { returns, .. }| Ok(returns.clone())); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("Return", methods); + add_create_ast_node_methods(methods); + add_print(methods); + } +} + +impl CreateAstNode for Return { + type Node = ast::Return; + + fn create_ast_node(&self) -> Option { + Some( + ast::Return::new() + .with_token(self.token.create_ast_node()?) + .with_returns(self.returns.create_ast_node()?), + ) + } +} + +impl AstToLua for ast::Return { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + Return::new(self).to_lua(lua) + } +} + +#[derive(Clone)] +pub struct Repeat { + repeat_token: ArcLocked, + block: ArcLocked, + until_token: ArcLocked, + until: ArcLocked, +} + +impl Repeat { + pub fn new(repeat: &ast::Repeat) -> Self { + Repeat { + repeat_token: l(TokenReference::new(repeat.repeat_token())), + block: l(Block::new(repeat.block())), + until_token: l(TokenReference::new(repeat.until_token())), + until: l(Expression::new(repeat.until())), + } + } +} + +impl UserData for Repeat { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("repeat_token", |_, Repeat { repeat_token, .. }| { + Ok(repeat_token.clone()) + }); + + fields.add_field_method_get("block", |_, Repeat { block, .. }| Ok(block.clone())); + + fields.add_field_method_get("until_token", |_, Repeat { until_token, .. }| { + Ok(until_token.clone()) + }); + + fields.add_field_method_get("until", |_, Repeat { until, .. }| Ok(until.clone())); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("Repeat", methods); + add_create_ast_node_methods(methods); + add_print(methods); + } +} + +impl CreateAstNode for Repeat { + type Node = ast::Repeat; + + fn create_ast_node(&self) -> Option { + Some( + ast::Repeat::new(self.until.create_ast_node()?) + .with_repeat_token(self.repeat_token.create_ast_node()?) + .with_block(self.block.create_ast_node()?) + .with_until_token(self.until_token.create_ast_node()?), + ) + } +} + +impl AstToLua for ast::Repeat { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + Repeat::new(self).to_lua(lua) + } +} + +#[derive(Clone, LuaUserData)] +pub enum Stmt { + Assignment(ArcLocked), + Do(ArcLocked), + FunctionCall(ArcLocked), + FunctionDeclaration(ArcLocked), + GenericFor(ArcLocked), + If(ArcLocked), + LocalAssignment(ArcLocked), + LocalFunction(ArcLocked), + NumericFor(ArcLocked), + Repeat(ArcLocked), + While(ArcLocked), + + NonStandard(Vec>), +} + +impl Stmt { + pub fn new(stmt: &ast::Stmt) -> Self { + match stmt { + ast::Stmt::Assignment(assignment) => Stmt::Assignment(l(Assignment::new(assignment))), + ast::Stmt::Do(do_token) => Stmt::Do(l(Do::new(do_token))), + ast::Stmt::FunctionCall(function_call) => { + Stmt::FunctionCall(l(FunctionCall::new(function_call))) + } + ast::Stmt::FunctionDeclaration(function_declaration) => { + Stmt::FunctionDeclaration(l(FunctionDeclaration::new(function_declaration))) + } + ast::Stmt::GenericFor(generic_for) => Stmt::GenericFor(l(GenericFor::new(generic_for))), + ast::Stmt::If(if_token) => Stmt::If(l(If::new(if_token))), + ast::Stmt::LocalAssignment(local_assignment) => { + Stmt::LocalAssignment(l(LocalAssignment::new(local_assignment))) + } + ast::Stmt::LocalFunction(local_function) => { + Stmt::LocalFunction(l(LocalFunction::new(local_function))) + } + ast::Stmt::NumericFor(numeric_for) => Stmt::NumericFor(l(NumericFor::new(numeric_for))), + ast::Stmt::Repeat(repeat_token) => Stmt::Repeat(l(Repeat::new(repeat_token))), + ast::Stmt::While(while_token) => Stmt::While(l(While::new(while_token))), + + // TODO: Support everything, then make this `unimplemented!` + _ => Stmt::NonStandard(stmt.tokens().map(TokenReference::new).map(l).collect()), + } + } +} + +impl AstToLua for ast::Stmt { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + Stmt::new(self).to_lua(lua) + } +} + +#[derive(Clone, LuaUserData)] +pub enum Suffix { + Call(ArcLocked), + Index(ArcLocked), +} + +impl Suffix { + pub fn new(suffix: &ast::Suffix) -> Self { + match suffix { + ast::Suffix::Call(call) => Suffix::Call(l(Call::new(call))), + ast::Suffix::Index(index) => Suffix::Index(l(Index::new(index))), + other => unimplemented!("unexpected Suffix variant: {other:#?}"), + } + } +} + +impl AstToLua for ast::Suffix { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + Suffix::new(self).to_lua(lua) + } +} + +#[derive(Clone)] +pub struct TableConstructor { + braces: ArcLocked, + fields: ArcLocked>, +} + +impl TableConstructor { + pub fn new(table_constructor: &ast::TableConstructor) -> Self { + TableConstructor { + braces: l(ContainedSpan::new(table_constructor.braces())), + fields: l(Punctuated::map_from_punctuated( + table_constructor.fields(), + Field::new, + )), + } + } +} + +impl UserData for TableConstructor { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("braces", |_, TableConstructor { braces, .. }| { + Ok(braces.clone()) + }); + + fields.add_field_method_get("fields", |_, TableConstructor { fields, .. }| { + Ok(fields.clone()) + }); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("TableConstructor", methods); + add_create_ast_node_methods(methods); + add_print(methods); + } +} + +impl CreateAstNode for TableConstructor { + type Node = ast::TableConstructor; + + fn create_ast_node(&self) -> Option { + Some( + ast::TableConstructor::new() + .with_braces(self.braces.create_ast_node()?) + .with_fields(self.fields.create_ast_node()?), + ) + } +} + +impl AstToLua for ast::TableConstructor { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + TableConstructor::new(self).to_lua(lua) + } +} + +#[derive(Clone, LuaUserData)] +pub enum UnOp { + Minus(ArcLocked), + Not(ArcLocked), + Hash(ArcLocked), +} + +impl UnOp { + pub fn new(unop: &ast::UnOp) -> Self { + match unop { + ast::UnOp::Minus(token) => UnOp::Minus(l(TokenReference::new(token))), + ast::UnOp::Not(token) => UnOp::Not(l(TokenReference::new(token))), + ast::UnOp::Hash(token) => UnOp::Hash(l(TokenReference::new(token))), + other => panic!("unimplemented UnOp: {other:?}"), + } + } +} + +impl AstToLua for ast::UnOp { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + UnOp::new(self).to_lua(lua) + } +} + +#[derive(Clone, LuaUserData)] +pub enum Value { + #[lua( + create_ast_node = "ast::Value::Function((_0.create_ast_node()?, _1.create_ast_node()?))" + )] + Function(ArcLocked, ArcLocked), + FunctionCall(ArcLocked), + TableConstructor(ArcLocked), + Number(ArcLocked), + ParenthesesExpression(ArcLocked), + String(ArcLocked), + Symbol(ArcLocked), + Var(ArcLocked), + + NonStandard(Vec>), +} + +impl Value { + pub fn new(value: &ast::Value) -> Self { + match value { + ast::Value::Function((token_reference, function_body)) => Value::Function( + l(TokenReference::new(token_reference)), + l(FunctionBody::new(function_body)), + ), + + ast::Value::FunctionCall(function_call) => { + Value::FunctionCall(l(FunctionCall::new(function_call))) + } + + ast::Value::TableConstructor(table_constructor) => { + Value::TableConstructor(l(TableConstructor::new(table_constructor))) + } + + ast::Value::Number(number) => Value::Number(l(TokenReference::new(number))), + ast::Value::ParenthesesExpression(expression) => { + Value::ParenthesesExpression(l(Expression::new(expression))) + } + ast::Value::String(string) => Value::String(l(TokenReference::new(string))), + ast::Value::Symbol(symbol) => Value::Symbol(l(TokenReference::new(symbol))), + ast::Value::Var(var) => Value::Var(l(Var::new(var))), + + // TODO: implement everything, then `unimplemented!` + other => Value::NonStandard(other.tokens().map(TokenReference::new).map(l).collect()), + } + } +} + +impl AstToLua for ast::Value { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + Value::new(self).to_lua(lua) + } +} + +#[derive(Clone, LuaUserData)] +pub enum Var { + Expression(ArcLocked), + Name(ArcLocked), +} + +impl Var { + pub fn new(var: &ast::Var) -> Self { + match var { + ast::Var::Expression(expression) => Var::Expression(l(VarExpression::new(expression))), + ast::Var::Name(name_token) => Var::Name(l(TokenReference::new(name_token))), + other => unimplemented!("unexpected Var variant: {var:#?}"), + } + } +} + +impl AstToLua for ast::Var { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + Var::new(self).to_lua(lua) + } +} + +#[derive(Clone)] +pub struct VarExpression { + prefix: ArcLocked, + suffixes: Vec>, +} + +impl VarExpression { + pub fn new(var_expression: &ast::VarExpression) -> Self { + VarExpression { + prefix: l(Prefix::new(var_expression.prefix())), + suffixes: var_expression.suffixes().map(Suffix::new).map(l).collect(), + } + } +} + +impl UserData for VarExpression { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("prefix", |_, VarExpression { prefix, .. }| { + Ok(prefix.clone()) + }); + + fields.add_field_method_get("suffixes", |_, VarExpression { suffixes, .. }| { + Ok(suffixes.clone()) + }); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("VarExpression", methods); + add_create_ast_node_methods(methods); + add_print(methods); + } +} + +impl CreateAstNode for VarExpression { + type Node = ast::VarExpression; + + fn create_ast_node(&self) -> Option { + Some( + ast::VarExpression::new(self.prefix.create_ast_node()?).with_suffixes( + self.suffixes + .iter() + .map(|suffix| suffix.create_ast_node()) + .collect::>>()?, + ), + ) + } +} + +impl AstToLua for ast::VarExpression { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + VarExpression::new(self).to_lua(lua) + } +} + +#[derive(Clone)] +pub struct While { + while_token: ArcLocked, + condition: ArcLocked, + do_token: ArcLocked, + block: ArcLocked, + end_token: ArcLocked, +} + +impl While { + pub fn new(while_token: &ast::While) -> Self { + While { + while_token: l(TokenReference::new(while_token.while_token())), + condition: l(Expression::new(while_token.condition())), + do_token: l(TokenReference::new(while_token.do_token())), + block: l(Block::new(while_token.block())), + end_token: l(TokenReference::new(while_token.end_token())), + } + } +} + +impl UserData for While { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("while_token", |_, While { while_token, .. }| { + Ok(while_token.clone()) + }); + + fields.add_field_method_get("condition", |_, While { condition, .. }| { + Ok(condition.clone()) + }); + + fields.add_field_method_get("do_token", |_, While { do_token, .. }| Ok(do_token.clone())); + + fields.add_field_method_get("block", |_, While { block, .. }| Ok(block.clone())); + + fields.add_field_method_get("end_token", |_, While { end_token, .. }| { + Ok(end_token.clone()) + }); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("While", methods); + add_create_ast_node_methods(methods); + add_print(methods); + } +} + +impl CreateAstNode for While { + type Node = ast::While; + + fn create_ast_node(&self) -> Option { + Some( + ast::While::new(self.condition.create_ast_node()?) + .with_block(self.block.create_ast_node()?) + .with_do_token(self.do_token.create_ast_node()?) + .with_end_token(self.end_token.create_ast_node()?) + .with_while_token(self.while_token.create_ast_node()?), + ) + } +} + +impl AstToLua for ast::While { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + While::new(self).to_lua(lua) + } +} diff --git a/full-moon-lua-types/src/either.rs b/full-moon-lua-types/src/either.rs new file mode 100644 index 00000000..32e9e6c3 --- /dev/null +++ b/full-moon-lua-types/src/either.rs @@ -0,0 +1,30 @@ +use mlua::FromLua; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Either { + A(A), + B(B), +} + +pub fn take_either<'lua, A: FromLua<'lua>, B: FromLua<'lua>>( + lua: &'lua mlua::Lua, + lua_value: mlua::Value<'lua>, + a_detail: &str, + b_detail: &str, +) -> mlua::Result> { + let type_name = lua_value.type_name(); + + // Values are cheap to clone, they're mostly just references + if let Ok(a) = A::from_lua(lua_value.clone(), lua) { + return Ok(Either::A(a)); + } + + if let Ok(b) = B::from_lua(lua_value, lua) { + return Ok(Either::B(b)); + } + + Err(mlua::Error::external(format!( + "expected either {a_detail} or {b_detail}, received {}", + type_name, + ))) +} diff --git a/full-moon-lua-types/src/extract_node.rs b/full-moon-lua-types/src/extract_node.rs new file mode 100644 index 00000000..f834a890 --- /dev/null +++ b/full-moon-lua-types/src/extract_node.rs @@ -0,0 +1,82 @@ +use crate::{core, mlua_util::ArcLocked, shared, CreateAstNode}; + +use full_moon::{ast, node::Node, tokenizer, visitors::Visit}; +use mlua::FromLua; + +macro_rules! all_nodes { + (pub enum AnyNode { + $( + $name:ident: $lua_type:ty => $ast_type:ty, + )+ + }) => { + pub enum AnyNode { + $( + $name($ast_type), + )+ + } + + impl<'lua> FromLua<'lua> for AnyNode { + fn from_lua(value: mlua::Value, _: &mlua::Lua) -> mlua::Result { + if let mlua::Value::UserData(user_data) = &value { + $( + if let Ok(lua_node) = user_data.borrow::<$lua_type>() { + return Ok(AnyNode::$name(match lua_node.create_ast_node() { + Some(ast_node) => ast_node, + None => return Err(mlua::Error::external(format!("could not convert {} to an AST node", stringify!($name)))), + })); + } + + if let Ok(lua_node) = user_data.borrow::>() { + return Ok(AnyNode::$name(match lua_node.create_ast_node() { + Some(ast_node) => ast_node, + None => return Err(mlua::Error::external(format!("could not convert {} to an AST node", stringify!($name)))), + })); + } + )+ + } + + Err(mlua::Error::external(format!("Expected a node, received {}", value.type_name()))) + } + } + }; +} + +all_nodes!(pub enum AnyNode { + Ast: core::Ast => ast::Ast, + Assignment: core::Assignment => ast::Assignment, + BinOp: core::BinOp => ast::BinOp, + Block: core::Block => ast::Block, + Call: core::Call => ast::Call, + Do: core::Do => ast::Do, + ElseIf: core::ElseIf => ast::ElseIf, + Expression: core::Expression => ast::Expression, + Field: core::Field => ast::Field, + FunctionArgs: core::FunctionArgs => ast::FunctionArgs, + FunctionBody: core::FunctionBody => ast::FunctionBody, + FunctionCall: core::FunctionCall => ast::FunctionCall, + FunctionDeclaration: core::FunctionDeclaration => ast::FunctionDeclaration, + FunctionName: core::FunctionName => ast::FunctionName, + GenericFor: core::GenericFor => ast::GenericFor, + If: core::If => ast::If, + Index: core::Index => ast::Index, + LastStmt: core::LastStmt => ast::LastStmt, + LocalAssignment: core::LocalAssignment => ast::LocalAssignment, + LocalFunction: core::LocalFunction => ast::LocalFunction, + MethodCall: core::MethodCall => ast::MethodCall, + NumericFor: core::NumericFor => ast::NumericFor, + Parameter: core::Parameter => ast::Parameter, + Prefix: core::Prefix => ast::Prefix, + Repeat: core::Repeat => ast::Repeat, + Return: core::Return => ast::Return, + Stmt: core::Stmt => ast::Stmt, + Suffix: core::Suffix => ast::Suffix, + TableConstructor: core::TableConstructor => ast::TableConstructor, + UnOp: core::UnOp => ast::UnOp, + Value: core::Value => ast::Value, + Var: core::Var => ast::Var, + VarExpression: core::VarExpression => ast::VarExpression, + While: core::While => ast::While, + + ContainedSpan: shared::ContainedSpan => ast::span::ContainedSpan, + TokenReference: shared::TokenReference => tokenizer::TokenReference, +}); diff --git a/full-moon-lua-types/src/lib.rs b/full-moon-lua-types/src/lib.rs new file mode 100644 index 00000000..6920d065 --- /dev/null +++ b/full-moon-lua-types/src/lib.rs @@ -0,0 +1,17 @@ +// fix this later? +#![allow(clippy::large_enum_variant)] + +mod ast_traits; +mod core; +mod either; +mod extract_node; +mod lua; +mod mlua_util; +mod prepare_for_lua; +mod shared; +mod visitor; + +pub use crate::core::Ast; +pub use ast_traits::*; +pub use extract_node::AnyNode; +pub use lua::*; diff --git a/full-moon-lua-types/src/lua.rs b/full-moon-lua-types/src/lua.rs new file mode 100644 index 00000000..4ad65367 --- /dev/null +++ b/full-moon-lua-types/src/lua.rs @@ -0,0 +1,32 @@ +use crate::core; + +pub fn create_lua() -> mlua::Result { + let lua = mlua::Lua::new(); + + assign_globals(&lua)?; + + Ok(lua) +} + +pub fn full_moon_table(lua: &mlua::Lua) -> mlua::Result { + let full_moon = lua.create_table()?; + + full_moon.set( + "parse", + lua.create_function(|_, code: String| { + let ast = full_moon::parse(&code).expect("NYI: Error on failure"); + + Ok(core::Ast::from(&ast)) + })?, + )?; + + Ok(full_moon) +} + +fn assign_globals(lua: &mlua::Lua) -> mlua::Result<()> { + let globals = lua.globals(); + + globals.set("full_moon", full_moon_table(lua)?)?; + + Ok(()) +} diff --git a/full-moon-lua-types/src/mlua_util.rs b/full-moon-lua-types/src/mlua_util.rs new file mode 100644 index 00000000..a1146eda --- /dev/null +++ b/full-moon-lua-types/src/mlua_util.rs @@ -0,0 +1,84 @@ +use std::sync::{Arc, RwLock}; + +use full_moon::node::Node; +use mlua::{ToLua, ToLuaMulti, UserData}; + +use crate::ast_traits::CreateAstNode; + +pub use crate::visitor::add_visit; + +pub type ArcLocked = Arc>; + +pub fn add_core_meta_methods<'lua, T: UserData>( + name: &'static str, + methods: &mut impl mlua::UserDataMethods<'lua, T>, +) { + add_to_string_display(name, methods); + add_newindex_block(name, methods); +} + +pub fn add_core_metamethods_no_tostring<'lua, T: UserData>( + name: &'static str, + methods: &mut impl mlua::UserDataMethods<'lua, T>, +) { + add_newindex_block(name, methods); +} + +pub fn add_create_ast_node_methods<'lua, T, N>(methods: &mut impl mlua::UserDataMethods<'lua, T>) +where + T: UserData + CreateAstNode, + N: Node, +{ + add_range(methods); +} + +pub fn add_range<'lua, T, N>(methods: &mut impl mlua::UserDataMethods<'lua, T>) +where + T: UserData + CreateAstNode, + N: Node, +{ + methods.add_method("range", |lua, this, ()| { + let node = this.create_ast_node().unwrap(); + + match node.range() { + Some((start, end)) => (start.bytes(), end.bytes()).to_lua_multi(lua), + None => mlua::Value::Nil.to_lua_multi(lua), + } + }); +} + +pub fn add_print<'lua, T, N>(methods: &mut impl mlua::UserDataMethods<'lua, T>) +where + T: UserData + CreateAstNode, + N: std::fmt::Display, +{ + methods.add_method("print", |lua, this, ()| match this.create_ast_node() { + Some(node) => node.to_string().to_lua(lua), + None => Ok(mlua::Value::Nil), + }); +} + +pub fn add_to_string_display<'lua, T: UserData>( + name: &'static str, + methods: &mut impl mlua::UserDataMethods<'lua, T>, +) { + methods.add_meta_method(mlua::MetaMethod::ToString, move |_, this, _: ()| { + Ok(format!("{name}({:x})", this as *const _ as usize)) + }); +} + +pub fn add_newindex_block<'lua, T: UserData>( + name: &'static str, + methods: &mut impl mlua::UserDataMethods<'lua, T>, +) { + methods.add_meta_method( + mlua::MetaMethod::NewIndex, + move |_, _, (_, _): (String, mlua::Value)| -> mlua::Result<()> { + // TODO: Detect if withKey exists, and suggest that + + Err(mlua::Error::RuntimeError(format!( + "can't mutate {name} directly", + ))) + }, + ); +} diff --git a/full-moon-lua-types/src/prepare_for_lua.rs b/full-moon-lua-types/src/prepare_for_lua.rs new file mode 100644 index 00000000..77ec4bad --- /dev/null +++ b/full-moon-lua-types/src/prepare_for_lua.rs @@ -0,0 +1,31 @@ +use std::sync::Arc; + +use mlua::{ToLua, UserData}; + +use crate::mlua_util::ArcLocked; + +pub trait PrepareForLua { + fn prepare_for_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result>; +} + +impl PrepareForLua for ArcLocked { + fn prepare_for_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + Arc::clone(self).to_lua(lua) + } +} + +impl PrepareForLua for Vec> { + fn prepare_for_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + self.iter() + .map(Arc::clone) + .map(|x| x.to_lua(lua)) + .collect::>>()? + .to_lua(lua) + } +} + +impl PrepareForLua for Box { + fn prepare_for_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + (**self).clone().to_lua(lua) + } +} diff --git a/full-moon-lua-types/src/shared.rs b/full-moon-lua-types/src/shared.rs new file mode 100644 index 00000000..3aab9ac4 --- /dev/null +++ b/full-moon-lua-types/src/shared.rs @@ -0,0 +1,235 @@ +use std::iter::FromIterator; + +use full_moon::{ + ast::{self, punctuated::Pair}, + node::Node, + tokenizer, +}; + +use mlua::{MetaMethod, ToLua, UserData}; + +use crate::{ + ast_traits::CreateAstNode, + mlua_util::{add_core_meta_methods, add_create_ast_node_methods, add_print}, + AstToLua, +}; + +#[derive(Clone)] +pub struct ContainedSpan { + start: TokenReference, + end: TokenReference, +} + +impl ContainedSpan { + pub fn new(contained_span: &ast::span::ContainedSpan) -> Self { + let (start, end) = contained_span.tokens(); + + ContainedSpan { + start: TokenReference::new(start), + end: TokenReference::new(end), + } + } +} + +impl UserData for ContainedSpan { + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("ContainedSpan", methods); + add_create_ast_node_methods(methods); + } +} + +impl CreateAstNode for ContainedSpan { + type Node = ast::span::ContainedSpan; + + fn create_ast_node(&self) -> Option { + Some(ast::span::ContainedSpan::new( + self.start.create_ast_node()?, + self.end.create_ast_node()?, + )) + } +} + +impl AstToLua for ast::span::ContainedSpan { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + ContainedSpan::new(self).to_lua(lua) + } +} + +#[derive(Clone, Copy)] +pub struct Position(tokenizer::Position); + +#[derive(Clone)] +pub struct Punctuated(ast::punctuated::Punctuated); + +impl Punctuated { + pub fn map_from_punctuated T>( + punctuated: &ast::punctuated::Punctuated, + mut map: F, + ) -> Self { + Punctuated(ast::punctuated::Punctuated::from_iter( + punctuated.pairs().map(|pair| match pair { + Pair::Punctuated(value, punctuation) => { + Pair::Punctuated(map(value), punctuation.clone()) + } + + Pair::End(value) => Pair::End(map(value)), + }), + )) + } +} + +impl Punctuated { + fn values<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + let table = lua.create_table()?; + + for (i, item) in self.0.iter().enumerate() { + table.set(i + 1, item.clone().to_lua(lua)?)?; + } + + Ok(table) + } +} + +impl FromIterator> for Punctuated { + fn from_iter>>(iter: I) -> Self { + Punctuated(ast::punctuated::Punctuated::from_iter(iter)) + } +} + +impl + Send + Sync + 'static> UserData + for Punctuated +{ + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(_fields: &mut F) {} + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_meta_method(MetaMethod::Iter, |lua, this, _: ()| { + Ok(( + lua.globals().get::<_, mlua::Function>("next")?, + this.values(lua)?, + )) + }); + + methods.add_meta_method(MetaMethod::Len, |_, Punctuated(punctuated), _: ()| { + Ok(punctuated.len()) + }); + + methods.add_method("values", |lua, this, _: ()| { + this.values(lua).map_err(mlua::Error::external) + }); + + add_create_ast_node_methods(methods); + } +} + +impl CreateAstNode for Punctuated { + type Node = ast::punctuated::Punctuated; + + fn create_ast_node(&self) -> Option { + Some(ast::punctuated::Punctuated::from_iter( + self.0 + .pairs() + .map(|pair| { + Some(match pair { + Pair::Punctuated(value, punctuation) => { + Pair::Punctuated(value.create_ast_node()?, punctuation.clone()) + } + + Pair::End(value) => Pair::End(value.create_ast_node()?), + }) + }) + .collect::>>()?, + )) + } +} + +#[derive(Clone)] +pub struct TokenType(tokenizer::TokenType); + +#[derive(Clone)] +pub struct Token { + start_position: Position, + end_position: Position, + token_type: TokenType, +} + +impl From<&tokenizer::Token> for Token { + fn from(token: &tokenizer::Token) -> Self { + Token { + start_position: Position(token.start_position()), + end_position: Position(token.end_position()), + token_type: TokenType(token.token_type().clone()), + } + } +} + +impl CreateAstNode for Token { + type Node = tokenizer::Token; + + fn create_ast_node(&self) -> Option { + Some( + tokenizer::Token::new(self.token_type.0.clone()) + .with_start_position(self.start_position.0) + .with_end_position(self.end_position.0), + ) + } +} + +impl UserData for Token { + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("Token", methods); + add_print(methods); + } +} + +#[derive(Clone)] +pub struct TokenReference { + leading_trivia: Vec, + token: Token, + trailing_trivia: Vec, +} + +impl TokenReference { + pub fn new(token_reference: &tokenizer::TokenReference) -> Self { + TokenReference { + leading_trivia: token_reference.leading_trivia().map(Token::from).collect(), + token: Token::from(token_reference.token()), + trailing_trivia: token_reference.trailing_trivia().map(Token::from).collect(), + } + } +} + +impl UserData for TokenReference { + fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("token", |lua, this: &Self| this.token.clone().to_lua(lua)); + } + + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + add_core_meta_methods("TokenReference", methods); + add_create_ast_node_methods(methods); + add_print(methods); + } +} + +impl CreateAstNode for TokenReference { + type Node = tokenizer::TokenReference; + + fn create_ast_node(&self) -> Option { + Some(tokenizer::TokenReference::new( + self.leading_trivia + .iter() + .map(Token::create_ast_node) + .collect::>>()?, + self.token.create_ast_node()?, + self.trailing_trivia + .iter() + .map(Token::create_ast_node) + .collect::>>()?, + )) + } +} + +impl AstToLua for tokenizer::TokenReference { + fn ast_to_lua<'lua>(&self, lua: &'lua mlua::Lua) -> mlua::Result> { + TokenReference::new(self).to_lua(lua) + } +} diff --git a/full-moon-lua-types/src/visitor.rs b/full-moon-lua-types/src/visitor.rs new file mode 100644 index 00000000..66ed1fd7 --- /dev/null +++ b/full-moon-lua-types/src/visitor.rs @@ -0,0 +1,242 @@ +use crate::{ast_traits::CreateAstNode, core, shared}; +use full_moon::{ + ast, + node::Node, + tokenizer, + visitors::{Visit, Visitor}, +}; +use mlua::UserData; +use paste::paste; + +macro_rules! create_visitor { + ( + ast: { + $( + ($name:ident: $type:ty, $converter:expr), + )+ + }, + + tokens: { + $( + $token_name:ident, + )+ + } + ) => { + paste! { + #[derive(Clone, Debug, Default)] + pub struct VisitorTable<'lua> { + $( + $name: Option>, + [< $name _end >]: Option>, + )+ + + $( + $token_name: Option>, + )+ + } + + impl<'lua> mlua::FromLua<'lua> for VisitorTable<'lua> { + fn from_lua(value: mlua::Value<'lua>, _: &'lua mlua::Lua) -> mlua::Result { + let mut visitor_table = VisitorTable::default(); + + let table = match value { + mlua::Value::Table(table) => table, + _ => return Err(mlua::Error::external(format!(":visit() expects a table, received {}", value.type_name()))), + }; + + for pair in table.pairs::>() { + let (key, value) = pair?; + + // TODO: When validating names, have a list of names for 5.2/5.3/Luau only that won't error when the feature is not enabled + $( + let pascal_cased_name = pascal_case_name(stringify!($name)); + + if key == pascal_cased_name { + visitor_table.$name = Some(value); + continue; + } else if key == format!("{pascal_cased_name}End") { + visitor_table.[< $name _end >] = Some(value); + continue; + } + )+ + + $( + let pascal_cased_name = pascal_case_name(stringify!($token_name)); + + if key == pascal_cased_name { + visitor_table.$token_name = Some(value); + continue; + } + )+ + + return Err(mlua::Error::external(format!(":visit() received an unknown key {}", key.to_string_lossy()))); + } + + Ok(visitor_table) + } + } + + pub struct LuaVisitor<'lua> { + existing_error: Option, + visitor_table: VisitorTable<'lua>, + } + + impl<'lua> LuaVisitor<'lua> { + fn ok(self) -> mlua::Result<()> { + match self.existing_error { + Some(error) => Err(error), + None => Ok(()), + } + } + } + + impl<'lua> Visitor for LuaVisitor<'lua> { + $( + fn $name(&mut self, node: &$type) { + if self.existing_error.is_some() { + return; + } + + if let Some(function) = &self.visitor_table.$name { + if let Err(error) = function.call::<_, ()>($converter(node)) { + self.existing_error = Some(error); + } + } + } + + fn [<$name _end>](&mut self, node: &$type) { + if self.existing_error.is_some() { + return; + } + + if let Some(function) = &self.visitor_table.[< $name _end >] { + if let Err(error) = function.call::<_, ()>($converter(node)) { + self.existing_error = Some(error); + } + } + } + )+ + + $( + fn $token_name(&mut self, token: &tokenizer::Token) { + if self.existing_error.is_some() { + return; + } + + if let Some(function) = &self.visitor_table.$token_name { + if let Err(error) = function.call::<_, ()>(shared::Token::from(token)) { + self.existing_error = Some(error); + } + } + } + )+ + } + } + }; +} + +create_visitor!(ast: { + (visit_anonymous_call: ast::FunctionArgs, core::FunctionArgs::new), + (visit_assignment: ast::Assignment, core::Assignment::new), + (visit_block: ast::Block, core::Block::new), + (visit_call: ast::Call, core::Call::new), + (visit_contained_span: ast::span::ContainedSpan, shared::ContainedSpan::new), + (visit_do: ast::Do, core::Do::new), + (visit_else_if: ast::ElseIf, core::ElseIf::new), + (visit_eof: tokenizer::TokenReference, shared::TokenReference::new), + (visit_expression: ast::Expression, core::Expression::new), + (visit_field: ast::Field, core::Field::new), + (visit_function_args: ast::FunctionArgs, core::FunctionArgs::new), + (visit_function_body: ast::FunctionBody, core::FunctionBody::new), + (visit_function_call: ast::FunctionCall, core::FunctionCall::new), + (visit_function_declaration: ast::FunctionDeclaration, core::FunctionDeclaration::new), + (visit_function_name: ast::FunctionName, core::FunctionName::new), + (visit_generic_for: ast::GenericFor, core::GenericFor::new), + (visit_if: ast::If, core::If::new), + (visit_index: ast::Index, core::Index::new), + (visit_local_assignment: ast::LocalAssignment, core::LocalAssignment::new), + (visit_local_function: ast::LocalFunction, core::LocalFunction::new), + (visit_last_stmt: ast::LastStmt, core::LastStmt::new), + (visit_method_call: ast::MethodCall, core::MethodCall::new), + (visit_numeric_for: ast::NumericFor, core::NumericFor::new), + (visit_parameter: ast::Parameter, core::Parameter::new), + (visit_prefix: ast::Prefix, core::Prefix::new), + (visit_return: ast::Return, core::Return::new), + (visit_repeat: ast::Repeat, core::Repeat::new), + (visit_stmt: ast::Stmt, core::Stmt::new), + (visit_suffix: ast::Suffix, core::Suffix::new), + (visit_table_constructor: ast::TableConstructor, core::TableConstructor::new), + (visit_token_reference: tokenizer::TokenReference, shared::TokenReference::new), + (visit_un_op: ast::UnOp, core::UnOp::new), + (visit_value: ast::Value, core::Value::new), + (visit_var: ast::Var, core::Var::new), + (visit_var_expression: ast::VarExpression, core::VarExpression::new), + (visit_while: ast::While, core::While::new), +}, tokens: { + visit_identifier, + visit_multi_line_comment, + visit_number, + visit_single_line_comment, + visit_string_literal, + visit_symbol, + visit_token, + visit_whitespace, +}); + +fn pascal_case_name(name: &str) -> String { + let mut pascal_case_name = String::new(); + + let mut should_capitalize = true; + for character in name.chars().skip("visit_".len()) { + if should_capitalize { + pascal_case_name.push(character.to_ascii_uppercase()); + should_capitalize = false; + } else if character == '_' { + should_capitalize = true; + } else { + pascal_case_name.push(character); + } + } + + pascal_case_name +} + +pub fn add_visit<'lua, T, N>(methods: &mut impl mlua::UserDataMethods<'lua, T>) +where + T: UserData + CreateAstNode, + N: Visit, +{ + methods.add_method_mut("visit", |_, this, visitor: VisitorTable| { + let mut visitor = LuaVisitor { + existing_error: None, + visitor_table: visitor, + }; + + if let Some(ast_node) = this.create_ast_node() { + ast_node.visit(&mut visitor); + } + + visitor.ok() + }); +} + +pub fn add_visit_with_visitor<'lua, T, N, F>( + methods: &mut impl mlua::UserDataMethods<'lua, T>, + mut callback: F, +) where + T: UserData + Send + Sync + CreateAstNode, + F: 'static + Send + FnMut(N, &mut LuaVisitor<'lua>), +{ + methods.add_method_mut("visit", move |_, this, visitor: VisitorTable| { + let mut visitor = LuaVisitor { + existing_error: None, + visitor_table: visitor, + }; + + if let Some(ast_node) = this.create_ast_node() { + callback(ast_node, &mut visitor); + } + + visitor.ok() + }); +} diff --git a/full-moon-lua-types/tests/lua/core.lua b/full-moon-lua-types/tests/lua/core.lua new file mode 100644 index 00000000..00dbf6fb --- /dev/null +++ b/full-moon-lua-types/tests/lua/core.lua @@ -0,0 +1,94 @@ +local function assertEq(x, y) + if x == y then + return + end + + error(("%s ~= %s"):format(tostring(x), tostring(y))) +end + +local ast = full_moon.parse("x, y = 1, 2") +assert(not pcall(function() + ast.nodes = {} +end), "expected ast.nodes to be read-only") + +local stmt = ast.nodes.stmts[1] + +assertEq(tostring(stmt), "Stmt::Assignment") +assertEq(stmt.kind, "Assignment") + +assertEq( + stmt:match({ + Assignment = function(assignment) + local saysAssignment = tostring(assignment):match("^[A-Za-z]+") + + return assert(saysAssignment, "tostring(assignment) didn't match (" .. tostring(assignment) .. ")") + end, + }), + + "Assignment" +) + +local assignments = {} +local assignmentEnds = {} + +ast:visit({ + Assignment = function(assignment) + table.insert(assignments, assignment:print()) + end, + + AssignmentEnd = function(assignmentEnd) + table.insert(assignmentEnds, assignmentEnd:print()) + end, +}) + +assert(not pcall(function() + ast:visit({ + Funky = function() end, + }) +end), "expected :visit to not allow invalid names") + +assert(not pcall(function() + ast:visit({ + Assignment = 3, + }) +end), "expected :visit to not allow invalid values") + +assertEq(#assignments, 1) +assertEq(assignments[1], "x, y = 1, 2") + +assertEq(#assignmentEnds, 1) +assertEq(assignmentEnds[1], "x, y = 1, 2") + +-- Test non-AST visiting +local numbers = {} + +stmt:visit({ + Number = function(token) + table.insert(numbers, token:print()) + end, +}) + +assertEq(#numbers, 2) +assertEq(numbers[1], "1") +assertEq(numbers[2], "2") + +assert(not pcall(function() + return stmt:expect("While") +end), "stmt:expect should have thrown") + +local assignment = stmt:expect("Assignment") +assertEq(#assignment.variables, 2) +assertEq(#assignment.variables:values(), 2) + +assertEq(tostring(stmt:match("Assignment")), tostring(assignment)) +assertEq(stmt:match("While"), nil) + +local iters = {} + +for i, v in assignment.variables do + iters[i] = v:print() +end + +assertEq(#iters, 2) +assertEq(iters[1], "x") +assertEq(iters[2], "y ") diff --git a/full-moon-lua-types/tests/lua_tests.rs b/full-moon-lua-types/tests/lua_tests.rs new file mode 100644 index 00000000..a6835858 --- /dev/null +++ b/full-moon-lua-types/tests/lua_tests.rs @@ -0,0 +1,12 @@ +fn test_lua_code(code: &str) { + let lua = full_moon_lua_types::create_lua().expect("can't create lua"); + + if let Err(error) = lua.load(code).exec() { + panic!("lua error:\n{error}"); + } +} + +#[test] +fn core() { + test_lua_code(include_str!("lua/core.lua")); +} diff --git a/full-moon/src/tokenizer.rs b/full-moon/src/tokenizer.rs index 184ee6f2..f2541fbb 100644 --- a/full-moon/src/tokenizer.rs +++ b/full-moon/src/tokenizer.rs @@ -595,6 +595,16 @@ impl Token { pub fn token_kind(&self) -> TokenKind { self.token_type().kind() } + + pub fn with_start_position(mut self, start_position: Position) -> Self { + self.start_position = start_position; + self + } + + pub fn with_end_position(mut self, end_position: Position) -> Self { + self.end_position = end_position; + self + } } impl fmt::Display for Token {