This repository has been archived on 2024-12-20. You can view files and clone it, but cannot push or open issues or pull requests.
YanagiZS/crates/yanagi-proto/build.rs
2024-11-25 02:17:19 +03:00

375 lines
13 KiB
Rust

use std::{
fs::{self, read_to_string, File},
io::{BufRead, BufReader},
path::Path,
};
use quote::quote;
use syn::{Field, Ident, Item, Type, TypePath};
/// Build-script entry point.
///
/// If `nap.proto` exists in the current directory, it is compiled with
/// `prost_build` into `out/` and the generated `out/_.rs` is post-processed
/// by `apply_xor`. Whether or not that step ran, `out/_.rs` is then read
/// together with `../protocol/src/lib.rs` to produce
/// `out/proto_conversion.rs`.
///
/// # Panics
///
/// Panics if proto compilation, the post-processing pass or the conversion
/// generation fails.
pub fn main() {
    let proto_file = "nap.proto";
    let generated = Path::new("out/_.rs");

    if Path::new(&proto_file).exists() {
        println!("cargo:rerun-if-changed={proto_file}");

        // The result is deliberately discarded (presumably because the
        // directory may already exist from a previous build).
        let _ = fs::create_dir("out/");

        let mut config = prost_build::Config::new();
        config
            .out_dir("out/")
            .type_attribute(".", "#[derive(yanagi_proto_derive::CmdID)]")
            .type_attribute(".", "#[derive(yanagi_proto_derive::XorFields)]");
        config.compile_protos(&[proto_file], &["."]).unwrap();

        apply_xor(generated).unwrap();
    }

    // Runs even when the proto file is absent: the existing `out/_.rs` is
    // used as the input in that case.
    let conversions = impl_struct_conversions(
        generated,
        Path::new("../protocol/src/lib.rs"),
        Path::new("out/proto_conversion.rs"),
    );
    conversions.unwrap();
}
/// Reports whether the final path segment of `field`'s type is `Option`.
///
/// Only the last segment's identifier is compared, so `Option<T>` and
/// `::core::option::Option<T>` are treated the same way.
///
/// # Panics
///
/// Panics with "Unsupported field type" when the field's type is not a path
/// type, and panics if the path has no segments.
#[must_use]
fn field_is_optional(field: &Field) -> bool {
    let Type::Path(TypePath { path, .. }) = &field.ty else {
        panic!("Unsupported field type");
    };
    path.segments.last().unwrap().ident == "Option"
}
/// Reports whether the final path segment of `field`'s type is `Vec`.
///
/// Only the last segment's identifier is compared; generic arguments and any
/// leading path segments are ignored.
///
/// # Panics
///
/// Panics with "Unsupported field type" when the field's type is not a path
/// type, and panics if the path has no segments.
#[must_use]
fn field_is_vector(field: &Field) -> bool {
    let Type::Path(TypePath { path, .. }) = &field.ty else {
        panic!("Unsupported field type");
    };
    path.segments.last().unwrap().ident == "Vec"
}
/// Reports whether the final path segment of `field`'s type is `HashMap`.
///
/// Only the last segment's identifier is compared; generic arguments and any
/// leading path segments are ignored.
///
/// # Panics
///
/// Panics with "Unsupported field type" when the field's type is not a path
/// type, and panics if the path has no segments.
#[must_use]
fn field_is_hash_map(field: &Field) -> bool {
    let Type::Path(TypePath { path, .. }) = &field.ty else {
        panic!("Unsupported field type");
    };
    path.segments.last().unwrap().ident == "HashMap"
}
/// Generates conversion code between the structs found in the file at
/// `proto_codegen_path` and the identically-purposed structs found in the
/// file at `qwer_structures_path`, and writes the result to `output_path`.
///
/// The generated file contains:
/// * a pair of `From` impls (one per direction) for every proto struct that
///   has a counterpart in the qwer file;
/// * four exported macros (`decode_and_forward_proto`,
///   `impl_qwer_to_proto_match`, `register_ptc_handlers`,
///   `forward_as_notify`) whose `match` arms / statements are collected while
///   walking the proto structs.
///
/// # Errors
///
/// Returns the I/O error from reading either input file or from writing the
/// output file.
///
/// # Panics
///
/// Panics if either input file, or the generated token stream, fails to parse
/// as Rust, and (through the `field_is_*` helpers) if a matched field has a
/// type that is not a path type.
fn impl_struct_conversions(
proto_codegen_path: &Path,
qwer_structures_path: &Path,
output_path: &Path,
) -> std::io::Result<()> {
// Load both sources; read failures propagate, syntax errors panic.
let proto_file = read_to_string(proto_codegen_path)?;
let qwer_file = read_to_string(qwer_structures_path)?;
let proto_file = syn::parse_file(&proto_file).unwrap();
let qwer_file = syn::parse_file(&qwer_file).unwrap();
// Accumulators, each extended by re-quoting itself plus the new fragment:
// - from_impls:            all generated `From` impls
// - proto_to_qwer_matches: arms of `decode_and_forward_proto`
// - qwer_to_proto_matches: arms of `impl_qwer_to_proto_match`
// - ptc_registers:         statements of `register_ptc_handlers`
// - ptc_matches:           arms of `forward_as_notify`
let mut from_impls = quote! {};
let mut proto_to_qwer_matches = quote! {};
let mut qwer_to_proto_matches = quote! {};
let mut ptc_registers = quote! {};
let mut ptc_matches = quote! {};
// Only top-level struct items of the proto file are considered.
for item in proto_file.items.iter() {
let Item::Struct(proto) = item else {
continue;
};
let proto_ident = &proto.ident;
// Derive the name of the counterpart struct from the proto name's suffix:
//   FooScRsp    -> RpcFooRet
//   FooCsReq    -> RpcFooArg
//   FooScNotify -> PtcFooArg
//   FooNotify   -> PtcFooArg
//   anything else keeps its name.
// `ScNotify` is tested before `Notify` because the latter is a suffix of
// the former.
let qwer_ident = match proto_ident.to_string() {
s if s.ends_with("ScRsp") => Ident::new(
&format!("Rpc{}Ret", &s.as_str()[..s.len() - 5]),
proto_ident.span(),
),
s if s.ends_with("CsReq") => Ident::new(
&format!("Rpc{}Arg", &s.as_str()[..s.len() - 5]),
proto_ident.span(),
),
s if s.ends_with("ScNotify") => Ident::new(
&format!("Ptc{}Arg", &s.as_str()[..s.len() - 8]),
proto_ident.span(),
),
s if s.ends_with("Notify") => Ident::new(
&format!("Ptc{}Arg", &s.as_str()[..s.len() - 6]),
proto_ident.span(),
),
_ => proto_ident.clone(),
};
// Proto structs without a same-named counterpart struct are skipped
// entirely: no conversions and no macro arms are generated for them.
let Some(Item::Struct(qwer)) = qwer_file.items.iter().find(|i| {
if let Item::Struct(s) = i {
s.ident == qwer_ident
} else {
false
}
}) else {
continue;
};
// Requests: only `*CsReq` structs carrying a `cmdid` attribute get an arm
// in `decode_and_forward_proto`.
if let Some(req_base_name) = proto_ident.to_string().strip_suffix("CsReq") {
if proto.attrs.iter().any(|attr| {
attr.path()
.get_ident()
.map(|i| i == "cmdid")
.unwrap_or(false)
}) {
let rpc_ret_ident =
Ident::new(&format!("Rpc{req_base_name}Ret"), proto_ident.span());
let proto_rsp_name = format!("{req_base_name}ScRsp");
// Check whether a response struct with this name is defined in the proto
// file; if it is not, the generated arm sends a null response instead.
let proto_rsp_ident = proto_file
.items
.iter()
.any(|i| {
if let Item::Struct(s) = i {
s.ident == proto_rsp_name
} else {
false
}
})
.then_some(Ident::new(&proto_rsp_name, proto_ident.span()));
match proto_rsp_ident {
// A matching `*ScRsp` exists: convert the RPC return value and send it.
Some(proto_rsp_ident) => {
proto_to_qwer_matches = quote! {
#proto_to_qwer_matches
#proto_ident::CMD_ID => {
let packet = NetPacket::<::yanagi_proto::#proto_ident>::decode($buf)?;
let rpc_arg = ::protocol::#qwer_ident::from(packet.body);
let rpc_ret: ::protocol::#rpc_ret_ident = $point.call_rpc($addr, rpc_arg, $middlewares, $timeout).await?;
let proto_rsp = ::yanagi_proto::#proto_rsp_ident::from(rpc_ret);
$session.send_rsp(packet.head.packet_id, proto_rsp);
},
}
}
// No `*ScRsp` struct: the RPC is still awaited, but its return value is
// not converted and a null response is sent.
None => {
proto_to_qwer_matches = quote! {
#proto_to_qwer_matches
#proto_ident::CMD_ID => {
let packet = NetPacket::<::yanagi_proto::#proto_ident>::decode($buf)?;
let rpc_arg = ::protocol::#qwer_ident::from(packet.body);
let rpc_ret: ::protocol::#rpc_ret_ident = $point.call_rpc($addr, rpc_arg, $middlewares, $timeout).await?;
$session.send_null_rsp(packet.head.packet_id);
},
}
}
}
}
// Notifications: every `*ScNotify` struct gets both a handler registration
// and a `forward_as_notify` arm keyed by the counterpart's `PROTOCOL_ID`.
} else if let Some(_ntf_base_name) = proto_ident.to_string().strip_suffix("ScNotify") {
ptc_registers = quote! {
#ptc_registers
$point.register_rpc_recv(
::protocol::#qwer_ident::PROTOCOL_ID,
move |ctx| async move {
let _ = $tx.get().unwrap().send(Input::Notify($conv, ctx)).await;
}
);
};
ptc_matches = quote! {
#ptc_matches
::protocol::#qwer_ident::PROTOCOL_ID => {
$session.notify(::yanagi_proto::#proto_ident::from(
$ctx.get_arg::<::protocol::#qwer_ident>().unwrap(),
));
},
};
}
// Independently of the suffix, a counterpart struct carrying an `id`
// attribute gets an arm in `impl_qwer_to_proto_match`.
if qwer
.attrs
.iter()
.any(|attr| attr.path().get_ident().map(|i| i == "id").unwrap_or(false))
{
qwer_to_proto_matches = quote! {
#qwer_to_proto_matches
::protocol::#qwer_ident::PROTOCOL_ID => $process_proto_message(#proto_ident::from(qwer)),
};
}
// Field initialisers for the two `From` impls:
// - proto_assignments builds the proto struct from the qwer value
// - qwer_assignments builds the qwer struct from the proto value
let mut proto_assignments = quote! {};
let mut qwer_assignments = quote! {};
// Only fields whose identifier exists in both structs are converted; the
// rest are filled by `..Default::default()` in the generated impls.
// NOTE(review): `f.ident.as_ref().unwrap()` assumes named fields — a tuple
// struct with a counterpart would panic here; confirm none exist.
proto
.fields
.iter()
.filter(|f| qwer.fields.iter().any(|ff| ff.ident == f.ident))
.map(|f| (f, f.ident.as_ref().unwrap()))
.for_each(|(f, ident)| {
let qfield = qwer.fields.iter().find(|ff| ff.ident == f.ident).unwrap();
// The conversion shape is picked from the proto field's type (and, for the
// first case, the qwer field's type as well).
if field_is_optional(&f) && field_is_optional(&qfield) {
// Option on both sides: convert the inner value, keep `None` as `None`.
qwer_assignments = quote! {
#qwer_assignments
#ident: value.#ident.map(|v| v.into()),
};
proto_assignments = quote! {
#proto_assignments
#ident: value.#ident.map(|v| v.into()),
};
} else if field_is_optional(&f) {
// Option only on the proto side: a missing value becomes the qwer default,
// and the qwer value is always wrapped in `Some` going the other way.
qwer_assignments = quote! {
#qwer_assignments
#ident: value.#ident.map(|v| v.into()).unwrap_or_default(),
};
proto_assignments = quote! {
#proto_assignments
#ident: Some(value.#ident.into()),
};
} else if field_is_vector(&f) {
// Vec: convert element by element.
qwer_assignments = quote! {
#qwer_assignments
#ident: value.#ident.into_iter().map(|v| v.into()).collect(),
};
proto_assignments = quote! {
#proto_assignments
#ident: value.#ident.into_iter().map(|v| v.into()).collect(),
};
} else if field_is_hash_map(&f) {
// HashMap: convert both key and value of every entry.
qwer_assignments = quote! {
#qwer_assignments
#ident: value.#ident.into_iter().map(|(k, v)| (k.into(), v.into())).collect(),
};
proto_assignments = quote! {
#proto_assignments
#ident: value.#ident.into_iter().map(|(k, v)| (k.into(), v.into())).collect(),
}
} else {
// Anything else: a plain `Into` conversion.
qwer_assignments = quote! {
#qwer_assignments
#ident: value.#ident.into(),
};
proto_assignments = quote! {
#proto_assignments
#ident: value.#ident.into(),
};
}
});
// qwer -> proto conversion; unmatched proto fields take their default value.
let from_qwer = quote! {
#[allow(unused)]
impl From<::protocol::#qwer_ident> for #proto_ident {
fn from(value: ::protocol::#qwer_ident) -> Self {
Self {
#proto_assignments
..Default::default()
}
}
}
};
// proto -> qwer conversion; unmatched qwer fields take their default value.
let from_proto = quote! {
#[allow(unused)]
impl From<#proto_ident> for ::protocol::#qwer_ident {
fn from(value: #proto_ident) -> Self {
Self {
#qwer_assignments
..Default::default()
}
}
}
};
from_impls = quote! {
#from_impls
#from_qwer
#from_proto
};
}
// Assemble the output file: all `From` impls followed by the four exported
// macros, each wrapping the fragments collected above.
let generated_code = quote! {
#from_impls
#[macro_export]
macro_rules! decode_and_forward_proto {
($cmd_id:expr, $buf:expr, $session:expr, $point:expr, $addr:expr, $middlewares:expr, $timeout:expr) => {
match $cmd_id {
#proto_to_qwer_matches
_ => ::tracing::warn!("unknown cmd_id: {}", $cmd_id),
}
}
}
#[macro_export]
macro_rules! impl_qwer_to_proto_match {
($process_proto_message:ident) => {
match qwer.get_protocol_id() {
#qwer_to_proto_matches
_ => (),
}
}
}
#[macro_export]
macro_rules! register_ptc_handlers {
($point:ident, $conv:ident, $tx:ident) => {
#ptc_registers
}
}
#[macro_export]
macro_rules! forward_as_notify {
($session:ident, $ctx:ident) => {
match $ctx.protocol_id {
#ptc_matches
_ => (),
}
}
}
};
// Re-parse the token stream as a file (panics if it is not valid Rust) so it
// can be pretty-printed before being written out.
let ast = syn::parse2(generated_code).unwrap();
fs::write(output_path, prettyplease::unparse(&ast))
}
/// Rewrites the file at `path` in place, turning marker lines into attributes.
///
/// The file is processed line by line:
/// * a line containing `xor const: ` is replaced by the attribute produced by
///   `make_xor_attr`, unless it contains `xor const: 0`, in which case the
///   line is dropped and nothing is emitted;
/// * a line containing `CmdID: ` is dropped, and the attribute produced by
///   `make_cmd_id_attr` is emitted directly *after* the next line that is not
///   a marker line;
/// * every other line is copied through unchanged.
///
/// The output lines are joined with `\n` (no trailing newline is added).
///
/// # Errors
///
/// Returns any I/O error from opening, reading or writing the file, and an
/// [`std::io::ErrorKind::InvalidData`] error naming the file and the
/// offending line when a marker's value cannot be parsed.
fn apply_xor(path: &Path) -> std::io::Result<()> {
    // Builds the error for a marker whose value could not be parsed, so the
    // failure identifies the exact line instead of surfacing as a bare panic.
    let malformed = |marker: &str, line: &str| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("malformed {marker} marker in {}: {line:?}", path.display()),
        )
    };

    let reader = BufReader::new(File::open(path)?);
    let mut output = Vec::new();
    // Attribute waiting to be emitted after the next ordinary line.
    let mut cmd_id_attr = None;

    for line in reader.lines() {
        let line = line?;
        if line.contains("xor const: ") {
            // Markers whose value starts with `0` produce no attribute.
            if !line.contains("xor const: 0") {
                let attr = make_xor_attr(&line).ok_or_else(|| malformed("xor const", &line))?;
                output.push(attr);
            }
        } else if line.contains("CmdID: ") {
            let attr = make_cmd_id_attr(&line).ok_or_else(|| malformed("CmdID", &line))?;
            cmd_id_attr = Some(attr);
        } else {
            output.push(line);
            // Emit the pending attribute below the line that followed its marker.
            if let Some(attr) = cmd_id_attr.take() {
                output.push(attr);
            }
        }
    }

    // The reader (and its file handle) was consumed by the loop above, so the
    // file is no longer open for reading when it is overwritten here.
    fs::write(path, output.join("\n").as_bytes())?;
    Ok(())
}
/// Builds the `#[xor(N)]` attribute line for a line containing an
/// `xor const: N` marker.
///
/// The text following the first `xor const: ` marker is parsed as a `u32`.
/// Surrounding whitespace is trimmed first, so stray trailing spaces after
/// the number do not cause a valid constant to be rejected.
///
/// Returns `None` when the marker is absent or the value is not a valid
/// `u32`.
fn make_xor_attr(line: &str) -> Option<String> {
    let raw_value = line.split("xor const: ").nth(1)?;
    let xor_value = raw_value.trim().parse::<u32>().ok()?;
    Some(format!(" #[xor({xor_value})]"))
}
/// Builds the `#[cmdid(N)]` attribute line for a line containing a
/// `CmdID: N` marker.
///
/// The text following the first `CmdID: ` marker is parsed as a `u16`.
/// Surrounding whitespace is trimmed first, so stray trailing spaces after
/// the number do not cause a valid id to be rejected.
///
/// Returns `None` when the marker is absent or the value is not a valid
/// `u16`.
fn make_cmd_id_attr(line: &str) -> Option<String> {
    let raw_value = line.split("CmdID: ").nth(1)?;
    let cmd_id = raw_value.trim().parse::<u16>().ok()?;
    Some(format!("#[cmdid({cmd_id})]"))
}