trigger-rs/crates/trigger-protobuf/build.rs
2025-02-21 14:03:43 +03:00

385 lines
12 KiB
Rust

use std::{
collections::HashSet,
fs::{self, File},
io::{BufRead, BufReader},
path::Path,
};
use quote::{quote, ToTokens};
use syn::{Field, GenericArgument, Item, PathArguments, Type, TypePath};
fn main() {
let proto_file = "nap.proto";
if Path::new(&proto_file).exists() {
println!("cargo:rerun-if-changed={proto_file}");
let _ = fs::create_dir("out/");
prost_build::Config::new()
.out_dir("out/")
.type_attribute(".", "#[derive(trigger_protobuf_derive::CmdID)]")
.type_attribute(".", "#[derive(trigger_protobuf_derive::XorFields)]")
.compile_protos(&[proto_file], &["."])
.unwrap();
let pb_out = Path::new("out/_.rs");
remove_obfuscated_messages(pb_out, Path::new("nap.proto")).unwrap();
apply_message_attributes(pb_out).unwrap();
impl_struct_conversions(
pb_out,
Path::new("../trigger-protocol/src/lib.rs"),
Path::new("out/protocol_map.rs"),
)
.unwrap();
}
}
fn impl_struct_conversions(
protobuf_codegen_path: &Path,
protocol_structures_path: &Path,
output_path: &Path,
) -> std::io::Result<()> {
let pb_gen = syn::parse_file(&fs::read_to_string(protobuf_codegen_path)?).unwrap();
let custom_structs = syn::parse_file(&fs::read_to_string(protocol_structures_path)?).unwrap();
let mut from_impls = proc_macro2::TokenStream::new();
let mut pb_to_cp_unit_match_arms = proc_macro2::TokenStream::new();
let mut cp_unit_to_pb_match_arms = proc_macro2::TokenStream::new();
for item in pb_gen.items.iter() {
let Item::Struct(pb) = item else {
continue;
};
let ident = &pb.ident;
// if it has cmdid, then we want a PB <-> Common Protocol Unit conversion
if pb.attrs.iter().any(|attr| {
attr.meta
.path()
.get_ident()
.map(|i| i == "cmdid")
.unwrap_or(false)
// Should be also defined in list of custom Common Protocol structures
}) && custom_structs.items.iter().any(|st| {
if let Item::Struct(cst) = st {
&cst.ident == ident
} else {
false
}
}) {
pb_to_cp_unit_match_arms.extend(quote! {
#ident::CMD_ID => {
let mut pb_message = #ident::decode(pb)?;
pb_message.xor_fields();
let common_protocol_message = ::trigger_protocol::#ident::from(pb_message);
Ok(Some(common_protocol_message.into()))
}
});
cp_unit_to_pb_match_arms.extend(quote! {
::trigger_protocol::#ident::CMD_ID => {
let common_protocol_message = ::trigger_protocol::#ident::decode(&mut ::std::io::Cursor::new(&unit.blob))?;
let mut pb_message = #ident::from(common_protocol_message);
pb_message.xor_fields();
Ok(Some((#ident::CMD_ID, pb_message.encode_to_vec())))
}
});
}
let Some(Item::Struct(c_struct)) = custom_structs.items.iter().find(|i| {
if let Item::Struct(s) = i {
&s.ident == ident
} else {
false
}
}) else {
continue;
};
let mut assignments = proc_macro2::TokenStream::new();
pb
.fields
.iter()
.filter(|pb_field| c_struct.fields.iter().any(|c_field| c_field.ident == pb_field.ident))
.map(|f| (f, f.ident.as_ref().unwrap()))
.for_each(|(f, ident)| {
if field_is_optional(f) {
assignments.extend(quote! {
#ident: value.#ident.map(|v| v.into()),
});
}
else if field_is_vector(f) {
assignments.extend(quote! {
#ident: value.#ident.into_iter().map(|v| v.into()).collect(),
});
}
else if field_is_hash_map(f) {
assignments.extend(quote! {
#ident: value.#ident.into_iter().map(|(k, v)| (k.into(), v.into())).collect(),
});
}
else {
assignments.extend(quote! {
#ident: value.#ident.into(),
});
}
});
from_impls.extend(quote! {
#[allow(unused)]
impl From<#ident> for ::trigger_protocol::#ident {
fn from(value: #ident) -> Self {
Self {
#assignments
..Default::default()
}
}
}
#[allow(unused)]
impl From<::trigger_protocol::#ident> for #ident {
fn from(value: ::trigger_protocol::#ident) -> Self {
Self {
#assignments
..Default::default()
}
}
}
});
}
let generated_code = quote! {
pub fn pb_to_common_protocol_unit(pb_cmd_id: u16, pb: &[u8]) -> Result<Option<::trigger_protocol::util::ProtocolUnit>, crate::ProtobufDecodeError> {
match pb_cmd_id {
#pb_to_cp_unit_match_arms
_ => Ok(None),
}
}
pub fn common_protocol_unit_to_pb(unit: &::trigger_protocol::util::ProtocolUnit) -> ::std::io::Result<Option<(u16, Vec<u8>)>> {
use ::trigger_encoding::Decodeable;
use ::trigger_protocol::ClientCmdID;
match unit.cmd_id {
#cp_unit_to_pb_match_arms
_ => Ok(None),
}
}
#from_impls
};
std::fs::write(
output_path,
&prettyplease::unparse(&syn::parse2(generated_code).unwrap()),
)?;
Ok(())
}
fn remove_obfuscated_messages(codegen_path: &Path, proto_path: &Path) -> std::io::Result<()> {
// Remove unused obfuscated messages to reduce compilation time
let mut obfuscated_names = HashSet::new();
let proto_file = File::open(proto_path)?;
let reader = BufReader::new(proto_file);
for line in reader.lines() {
let line = line?;
let name = if line.starts_with("message ") {
line.split(' ').nth(1).unwrap()
} else if line.contains('=') {
let split = line.trim().split('=').nth(0).unwrap().trim().split(' ');
split.last().unwrap()
} else {
continue;
};
// length=11 & uppercase => obfuscated name
if name.len() == 11 && !name.chars().fold(false, |b, ch| b || ch.is_lowercase()) {
obfuscated_names.insert(name.to_string());
}
}
let codegen = syn::parse_file(&std::fs::read_to_string(codegen_path).unwrap()).unwrap();
let mut optimized = proc_macro2::TokenStream::new();
for item in codegen.items.iter() {
match item {
syn::Item::Struct(st) => {
let name = &st.ident;
let attrs = &st.attrs;
if obfuscated_names.contains(&name.to_string().to_uppercase()) {
continue;
}
let mut struct_fields = proc_macro2::TokenStream::new();
for field in st.fields.iter() {
if obfuscated_names
.contains(&field.ident.as_ref().unwrap().to_string().to_uppercase())
{
continue;
}
let ft = extract_type(&field.ty);
if !obfuscated_names.contains(&ft.to_uppercase()) {
struct_fields.extend(quote::quote! {
#field,
});
}
}
optimized.extend(quote::quote! {
#(#attrs)*
pub struct #name {
#struct_fields
}
});
}
syn::Item::Mod(module) => {
if obfuscated_names.contains(&module.ident.to_string().to_uppercase()) {
continue;
}
let mod_name = &module.ident;
let mut mod_items = proc_macro2::TokenStream::new();
for item in module.content.as_ref().unwrap().1.iter() {
if let Item::Enum(en) = &item {
let attrs = &en.attrs;
let name = &en.ident;
let mut variants = Vec::new();
for variant in en.variants.iter() {
let field = variant.fields.iter().next().unwrap();
if !obfuscated_names.contains(&extract_type(&field.ty).to_uppercase()) {
variants.push(variant);
}
}
mod_items.extend(quote::quote! {
#(#attrs)*
pub enum #name {
#(#variants),*
}
});
}
}
optimized.extend(quote::quote! {
pub mod #mod_name {
#mod_items
}
});
}
_ => {
optimized.extend(item.to_token_stream());
}
}
}
std::fs::write(
codegen_path,
prettyplease::unparse(&syn::parse2(optimized).unwrap()),
)?;
Ok(())
}
fn apply_message_attributes(path: &Path) -> std::io::Result<()> {
let file = File::open(path)?;
let reader = BufReader::new(file);
let mut output = Vec::new();
let mut cmd_id_attr = None;
for line in reader.lines() {
let line = line?;
if line.contains("xor const: ") {
if !line.contains("xor const: 0") {
output.push(make_xor_attr(&line).unwrap());
}
} else if line.contains("CmdID: ") {
cmd_id_attr = Some(make_cmd_id_attr(&line).unwrap());
} else {
output.push(line);
if let Some(attr) = cmd_id_attr.take() {
output.push(attr);
}
}
}
fs::write(path, output.join("\n").as_bytes())?;
Ok(())
}
#[must_use]
fn field_is_optional(field: &Field) -> bool {
match &field.ty {
Type::Path(TypePath { path, .. }) => {
let last_segment = path.segments.last().unwrap();
last_segment.ident == "Option"
}
_ => panic!("Unsupported field type"),
}
}
#[must_use]
fn field_is_vector(field: &Field) -> bool {
match &field.ty {
Type::Path(TypePath { path, .. }) => {
let last_segment = path.segments.last().unwrap();
last_segment.ident == "Vec"
}
_ => panic!("Unsupported field type"),
}
}
#[must_use]
fn field_is_hash_map(field: &Field) -> bool {
match &field.ty {
Type::Path(TypePath { path, .. }) => {
let last_segment = path.segments.last().unwrap();
last_segment.ident == "HashMap"
}
_ => panic!("Unsupported field type"),
}
}
fn make_xor_attr(line: &str) -> Option<String> {
let xor_value = line.split("xor const: ").nth(1)?.parse::<u32>().ok()?;
Some(format!(" #[xor({xor_value})]"))
}
fn make_cmd_id_attr(line: &str) -> Option<String> {
let cmd_id = line.split("CmdID: ").nth(1)?.parse::<u16>().ok()?;
Some(format!("#[cmdid({cmd_id})]"))
}
#[must_use]
fn get_type_name(path: &syn::Path) -> String {
path.segments.last().unwrap().ident.to_string()
}
#[must_use]
fn extract_type(ty: &Type) -> String {
match &ty {
Type::Path(TypePath { path, .. }) => {
let last_segment = path.segments.last().unwrap();
match &last_segment.arguments {
PathArguments::AngleBracketed(args) => {
if let Some(GenericArgument::Type(Type::Path(TypePath { path, .. }))) =
args.args.last()
{
get_type_name(path)
} else {
get_type_name(path)
}
}
_ => get_type_name(path),
}
}
_ => panic!("Unsupported field type"),
}
}