Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support splitting a single item into multiple files #270

Merged
merged 10 commits into from
Oct 14, 2024
78 changes: 73 additions & 5 deletions pilota-build/src/codegen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
};

use ahash::AHashMap;
use dashmap::DashMap;
use dashmap::{mapref::one::RefMut, DashMap};
use faststr::FastStr;
use itertools::Itertools;
use normpath::PathExt;
Expand All @@ -25,6 +25,7 @@ use crate::{
context::{tls::CUR_ITEM, Mode},
rir,
},
rir::{Item, NodeKind},
symbol::{DefId, EnumRepr, FileId},
Context, Symbol,
};
Expand Down Expand Up @@ -447,8 +448,12 @@ where
ws.write_crates()
}

pub fn write_items(&self, stream: &mut String, items: impl Iterator<Item = CodegenItem>)
where
pub fn write_items(
&self,
stream: &mut String,
items: impl Iterator<Item = CodegenItem>,
base_dir: &Path,
) where
B: Send,
{
let mods = items.into_group_map_by(|CodegenItem { def_id, .. }| {
Expand All @@ -473,8 +478,13 @@ where

let _enter = span.enter();
let mut dup = AHashMap::default();
for def_id in def_ids.iter() {
this.write_item(&mut stream, *def_id, &mut dup)

if this.split {
Self::write_split_mod(this, base_dir, p, def_ids, &mut stream, &mut dup);
} else {
for def_id in def_ids.iter() {
this.write_item(&mut stream, *def_id, &mut dup)
}
}
});

Expand Down Expand Up @@ -514,11 +524,69 @@ where
write_stream(&mut pkgs, stream, &pkg_node);
}

fn write_split_mod(
this: &mut Codegen<B>,
base_dir: &Path,
p: &Arc<[FastStr]>,
def_ids: &Vec<CodegenItem>,
stream: &mut RefMut<Arc<[FastStr]>, String>,
mut dup: &mut AHashMap<FastStr, Vec<DefId>>,
) {
let base_mod_name = p.iter().map(|s| s.to_string()).join("/");
let mod_file_name = format!("{}/mod.rs", base_mod_name);
let mut mod_stream = String::new();

for def_id in def_ids.iter() {
let mut item_stream = String::new();
let node = this.db.node(def_id.def_id).unwrap();
let name_prefix = match node.kind {
NodeKind::Item(ref item) => match item.as_ref() {
Item::Message(_) => "message",
Item::Enum(_) => "enum",
Item::Service(_) => "service",
Item::NewType(_) => "new_type",
Item::Const(_) => "const",
Item::Mod(_) => "mod",
},
NodeKind::Variant(_) => "variant",
NodeKind::Field(_) => "field",
NodeKind::Method(_) => "method",
NodeKind::Arg(_) => "arg",
};

let mod_dir = base_dir.join(base_mod_name.clone());

let file_name = format!("{}_{}.rs", name_prefix, node.name());
this.write_item(&mut item_stream, *def_id, &mut dup);

let full_path = mod_dir.join(file_name.clone());
std::fs::create_dir_all(mod_dir).unwrap();

let mut file =
std::io::BufWriter::new(std::fs::File::create(full_path.clone()).unwrap());
file.write_all(item_stream.as_bytes()).unwrap();
file.flush().unwrap();
fmt_file(full_path);

mod_stream.push_str(format!("include!(\"{}\");\n", file_name).as_str());
}

let mod_path = base_dir.join(&mod_file_name);
let mut mod_file = std::io::BufWriter::new(std::fs::File::create(&mod_path).unwrap());
mod_file.write_all(mod_stream.as_bytes()).unwrap();
mod_file.flush().unwrap();
fmt_file(&mod_path);

stream.push_str(format!("include!(\"{}\");\n", mod_file_name).as_str());
}

pub fn write_file(self, ns_name: Symbol, file_name: impl AsRef<Path>) {
let base_dir = file_name.as_ref().parent().unwrap();
let mut stream = String::default();
self.write_items(
&mut stream,
self.codegen_items.iter().map(|def_id| (*def_id).into()),
base_dir,
);

stream = format! {r#"pub mod {ns_name} {{
Expand Down
4 changes: 4 additions & 0 deletions pilota-build/src/codegen/workspace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ where
None
}
})
.sorted()
.join(",\n");

let mut cargo_toml = toml::from_str::<toml::Value>(&unsafe {
Expand Down Expand Up @@ -184,6 +185,8 @@ where
Command::new("cargo")
.arg("init")
.arg("--lib")
.arg("--vcs")
.arg("none")
.current_dir(base_dir.as_ref())
.arg(&*info.name),
)?;
Expand Down Expand Up @@ -246,6 +249,7 @@ where
def_id,
kind: super::CodegenKind::RePub,
})),
base_dir.as_ref().join(&*info.name).join("src").as_path(),
);
if let Some(main_mod_path) = info.main_mod_path {
gen_rs_stream.push_str(&format!(
Expand Down
2 changes: 1 addition & 1 deletion pilota-build/src/fmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ pub fn fmt_file<P: AsRef<Path>>(file: P) {
Err(e) => eprintln!("{}", e),
Ok(output) => {
if !output.status.success() {
eprintln!("rustfmt failed to format {}", file.display());
std::io::stderr().write_all(&output.stderr).unwrap();
exit(output.status.code().unwrap_or(1))
}
}
}
Expand Down
13 changes: 13 additions & 0 deletions pilota-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ pub struct Builder<MkB, P> {
parser: P,
plugins: Vec<Box<dyn Plugin>>,
ignore_unused: bool,
split: bool,
touches: Vec<(std::path::PathBuf, Vec<String>)>,
change_case: bool,
keep_unknown_fields: Vec<std::path::PathBuf>,
Expand All @@ -103,6 +104,7 @@ impl Builder<MkThriftBackend, ThriftParser> {
dedups: Vec::default(),
special_namings: Vec::default(),
common_crate_name: "common".into(),
split: false,
}
}
}
Expand All @@ -124,6 +126,7 @@ impl Builder<MkProtobufBackend, ProtobufParser> {
dedups: Vec::default(),
special_namings: Vec::default(),
common_crate_name: "common".into(),
split: false,
}
}
}
Expand Down Expand Up @@ -152,6 +155,7 @@ impl<MkB, P> Builder<MkB, P> {
dedups: self.dedups,
special_namings: self.special_namings,
common_crate_name: self.common_crate_name,
split: self.split,
}
}

Expand All @@ -161,6 +165,11 @@ impl<MkB, P> Builder<MkB, P> {
self
}

pub fn split_generated_files(mut self, split: bool) -> Self {
self.split = split;
self
}

pub fn change_case(mut self, change_case: bool) -> Self {
self.change_case = change_case;
self
Expand Down Expand Up @@ -266,6 +275,7 @@ where
dedups: Vec<FastStr>,
special_namings: Vec<FastStr>,
common_crate_name: FastStr,
split: bool,
) -> Context {
let mut db = RootDatabase::default();
parser.inputs(services.iter().map(|s| &s.path));
Expand Down Expand Up @@ -341,6 +351,7 @@ where
dedups,
special_namings,
common_crate_name,
split,
)
}

Expand All @@ -359,6 +370,7 @@ where
self.dedups,
self.special_namings,
self.common_crate_name,
self.split,
);

cx.exec_plugin(BoxedPlugin);
Expand Down Expand Up @@ -441,6 +453,7 @@ where
self.dedups,
self.special_namings,
self.common_crate_name,
self.split,
);

std::thread::scope(|_scope| {
Expand Down
4 changes: 4 additions & 0 deletions pilota-build/src/middle/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ pub struct Context {
pub(crate) codegen_items: Arc<[DefId]>,
pub(crate) path_resolver: Arc<dyn PathResolver>,
pub(crate) mode: Arc<Mode>,
pub(crate) split: bool,
pub(crate) keep_unknown_fields: Arc<FxHashSet<DefId>>,
pub location_map: Arc<FxHashMap<DefId, DefLocation>>,
pub entry_map: Arc<HashMap<DefLocation, Vec<(DefId, DefLocation)>>>,
Expand All @@ -86,6 +87,7 @@ impl Clone for Context {
codegen_items: self.codegen_items.clone(),
path_resolver: self.path_resolver.clone(),
mode: self.mode.clone(),
split: self.split,
services: self.services.clone(),
keep_unknown_fields: self.keep_unknown_fields.clone(),
location_map: self.location_map.clone(),
Expand Down Expand Up @@ -327,6 +329,7 @@ impl ContextBuilder {
dedups: Vec<FastStr>,
special_namings: Vec<FastStr>,
common_crate_name: FastStr,
split: bool,
) -> Context {
SPECIAL_NAMINGS.get_or_init(|| special_namings);
let mut cx = Context {
Expand All @@ -341,6 +344,7 @@ impl ContextBuilder {
Mode::SingleFile { .. } => Arc::new(DefaultPathResolver),
},
mode: Arc::new(self.mode),
split,
keep_unknown_fields: Arc::new(self.keep_unknown_fields),
location_map: Arc::new(self.location_map),
entry_map: Arc::new(self.entry_map),
Expand Down
Loading
Loading