385 lines
12 KiB
Rust
385 lines
12 KiB
Rust
use std::{
|
|
collections::HashSet,
|
|
fs::{self, File},
|
|
io::{BufRead, BufReader},
|
|
path::Path,
|
|
};
|
|
|
|
use quote::{quote, ToTokens};
|
|
use syn::{Field, GenericArgument, Item, PathArguments, Type, TypePath};
|
|
|
|
fn main() {
|
|
let proto_file = "nap.proto";
|
|
if Path::new(&proto_file).exists() {
|
|
println!("cargo:rerun-if-changed={proto_file}");
|
|
let _ = fs::create_dir("out/");
|
|
|
|
prost_build::Config::new()
|
|
.out_dir("out/")
|
|
.type_attribute(".", "#[derive(trigger_protobuf_derive::CmdID)]")
|
|
.type_attribute(".", "#[derive(trigger_protobuf_derive::XorFields)]")
|
|
.compile_protos(&[proto_file], &["."])
|
|
.unwrap();
|
|
|
|
let pb_out = Path::new("out/_.rs");
|
|
remove_obfuscated_messages(pb_out, Path::new("nap.proto")).unwrap();
|
|
apply_message_attributes(pb_out).unwrap();
|
|
impl_struct_conversions(
|
|
pb_out,
|
|
Path::new("../trigger-protocol/src/lib.rs"),
|
|
Path::new("out/protocol_map.rs"),
|
|
)
|
|
.unwrap();
|
|
}
|
|
}
|
|
|
|
fn impl_struct_conversions(
|
|
protobuf_codegen_path: &Path,
|
|
protocol_structures_path: &Path,
|
|
output_path: &Path,
|
|
) -> std::io::Result<()> {
|
|
let pb_gen = syn::parse_file(&fs::read_to_string(protobuf_codegen_path)?).unwrap();
|
|
let custom_structs = syn::parse_file(&fs::read_to_string(protocol_structures_path)?).unwrap();
|
|
|
|
let mut from_impls = proc_macro2::TokenStream::new();
|
|
let mut pb_to_cp_unit_match_arms = proc_macro2::TokenStream::new();
|
|
let mut cp_unit_to_pb_match_arms = proc_macro2::TokenStream::new();
|
|
|
|
for item in pb_gen.items.iter() {
|
|
let Item::Struct(pb) = item else {
|
|
continue;
|
|
};
|
|
|
|
let ident = &pb.ident;
|
|
|
|
// if it has cmdid, then we want a PB <-> Common Protocol Unit conversion
|
|
if pb.attrs.iter().any(|attr| {
|
|
attr.meta
|
|
.path()
|
|
.get_ident()
|
|
.map(|i| i == "cmdid")
|
|
.unwrap_or(false)
|
|
// Should be also defined in list of custom Common Protocol structures
|
|
}) && custom_structs.items.iter().any(|st| {
|
|
if let Item::Struct(cst) = st {
|
|
&cst.ident == ident
|
|
} else {
|
|
false
|
|
}
|
|
}) {
|
|
pb_to_cp_unit_match_arms.extend(quote! {
|
|
#ident::CMD_ID => {
|
|
let mut pb_message = #ident::decode(pb)?;
|
|
pb_message.xor_fields();
|
|
|
|
let common_protocol_message = ::trigger_protocol::#ident::from(pb_message);
|
|
Ok(Some(common_protocol_message.into()))
|
|
}
|
|
});
|
|
|
|
cp_unit_to_pb_match_arms.extend(quote! {
|
|
::trigger_protocol::#ident::CMD_ID => {
|
|
let common_protocol_message = ::trigger_protocol::#ident::decode(&mut ::std::io::Cursor::new(&unit.blob))?;
|
|
let mut pb_message = #ident::from(common_protocol_message);
|
|
pb_message.xor_fields();
|
|
|
|
Ok(Some((#ident::CMD_ID, pb_message.encode_to_vec())))
|
|
}
|
|
});
|
|
}
|
|
|
|
let Some(Item::Struct(c_struct)) = custom_structs.items.iter().find(|i| {
|
|
if let Item::Struct(s) = i {
|
|
&s.ident == ident
|
|
} else {
|
|
false
|
|
}
|
|
}) else {
|
|
continue;
|
|
};
|
|
|
|
let mut assignments = proc_macro2::TokenStream::new();
|
|
|
|
pb
|
|
.fields
|
|
.iter()
|
|
.filter(|pb_field| c_struct.fields.iter().any(|c_field| c_field.ident == pb_field.ident))
|
|
.map(|f| (f, f.ident.as_ref().unwrap()))
|
|
.for_each(|(f, ident)| {
|
|
if field_is_optional(f) {
|
|
assignments.extend(quote! {
|
|
#ident: value.#ident.map(|v| v.into()),
|
|
});
|
|
}
|
|
else if field_is_vector(f) {
|
|
assignments.extend(quote! {
|
|
#ident: value.#ident.into_iter().map(|v| v.into()).collect(),
|
|
});
|
|
}
|
|
else if field_is_hash_map(f) {
|
|
assignments.extend(quote! {
|
|
#ident: value.#ident.into_iter().map(|(k, v)| (k.into(), v.into())).collect(),
|
|
});
|
|
}
|
|
else {
|
|
assignments.extend(quote! {
|
|
#ident: value.#ident.into(),
|
|
});
|
|
}
|
|
});
|
|
|
|
from_impls.extend(quote! {
|
|
#[allow(unused)]
|
|
impl From<#ident> for ::trigger_protocol::#ident {
|
|
fn from(value: #ident) -> Self {
|
|
Self {
|
|
#assignments
|
|
..Default::default()
|
|
}
|
|
}
|
|
}
|
|
|
|
#[allow(unused)]
|
|
impl From<::trigger_protocol::#ident> for #ident {
|
|
fn from(value: ::trigger_protocol::#ident) -> Self {
|
|
Self {
|
|
#assignments
|
|
..Default::default()
|
|
}
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
let generated_code = quote! {
|
|
pub fn pb_to_common_protocol_unit(pb_cmd_id: u16, pb: &[u8]) -> Result<Option<::trigger_protocol::util::ProtocolUnit>, crate::ProtobufDecodeError> {
|
|
match pb_cmd_id {
|
|
#pb_to_cp_unit_match_arms
|
|
_ => Ok(None),
|
|
}
|
|
}
|
|
|
|
pub fn common_protocol_unit_to_pb(unit: &::trigger_protocol::util::ProtocolUnit) -> ::std::io::Result<Option<(u16, Vec<u8>)>> {
|
|
use ::trigger_encoding::Decodeable;
|
|
use ::trigger_protocol::ClientCmdID;
|
|
|
|
match unit.cmd_id {
|
|
#cp_unit_to_pb_match_arms
|
|
_ => Ok(None),
|
|
}
|
|
}
|
|
|
|
#from_impls
|
|
};
|
|
|
|
std::fs::write(
|
|
output_path,
|
|
&prettyplease::unparse(&syn::parse2(generated_code).unwrap()),
|
|
)?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn remove_obfuscated_messages(codegen_path: &Path, proto_path: &Path) -> std::io::Result<()> {
|
|
// Remove unused obfuscated messages to reduce compilation time
|
|
|
|
let mut obfuscated_names = HashSet::new();
|
|
|
|
let proto_file = File::open(proto_path)?;
|
|
let reader = BufReader::new(proto_file);
|
|
|
|
for line in reader.lines() {
|
|
let line = line?;
|
|
let name = if line.starts_with("message ") {
|
|
line.split(' ').nth(1).unwrap()
|
|
} else if line.contains('=') {
|
|
let split = line.trim().split('=').nth(0).unwrap().trim().split(' ');
|
|
split.last().unwrap()
|
|
} else {
|
|
continue;
|
|
};
|
|
|
|
// length=11 & uppercase => obfuscated name
|
|
if name.len() == 11 && !name.chars().fold(false, |b, ch| b || ch.is_lowercase()) {
|
|
obfuscated_names.insert(name.to_string());
|
|
}
|
|
}
|
|
|
|
let codegen = syn::parse_file(&std::fs::read_to_string(codegen_path).unwrap()).unwrap();
|
|
let mut optimized = proc_macro2::TokenStream::new();
|
|
|
|
for item in codegen.items.iter() {
|
|
match item {
|
|
syn::Item::Struct(st) => {
|
|
let name = &st.ident;
|
|
let attrs = &st.attrs;
|
|
if obfuscated_names.contains(&name.to_string().to_uppercase()) {
|
|
continue;
|
|
}
|
|
|
|
let mut struct_fields = proc_macro2::TokenStream::new();
|
|
for field in st.fields.iter() {
|
|
if obfuscated_names
|
|
.contains(&field.ident.as_ref().unwrap().to_string().to_uppercase())
|
|
{
|
|
continue;
|
|
}
|
|
|
|
let ft = extract_type(&field.ty);
|
|
if !obfuscated_names.contains(&ft.to_uppercase()) {
|
|
struct_fields.extend(quote::quote! {
|
|
#field,
|
|
});
|
|
}
|
|
}
|
|
|
|
optimized.extend(quote::quote! {
|
|
#(#attrs)*
|
|
pub struct #name {
|
|
#struct_fields
|
|
}
|
|
});
|
|
}
|
|
syn::Item::Mod(module) => {
|
|
if obfuscated_names.contains(&module.ident.to_string().to_uppercase()) {
|
|
continue;
|
|
}
|
|
|
|
let mod_name = &module.ident;
|
|
let mut mod_items = proc_macro2::TokenStream::new();
|
|
for item in module.content.as_ref().unwrap().1.iter() {
|
|
if let Item::Enum(en) = &item {
|
|
let attrs = &en.attrs;
|
|
let name = &en.ident;
|
|
let mut variants = Vec::new();
|
|
for variant in en.variants.iter() {
|
|
let field = variant.fields.iter().next().unwrap();
|
|
if !obfuscated_names.contains(&extract_type(&field.ty).to_uppercase()) {
|
|
variants.push(variant);
|
|
}
|
|
}
|
|
|
|
mod_items.extend(quote::quote! {
|
|
#(#attrs)*
|
|
pub enum #name {
|
|
#(#variants),*
|
|
}
|
|
});
|
|
}
|
|
}
|
|
|
|
optimized.extend(quote::quote! {
|
|
pub mod #mod_name {
|
|
#mod_items
|
|
}
|
|
});
|
|
}
|
|
_ => {
|
|
optimized.extend(item.to_token_stream());
|
|
}
|
|
}
|
|
}
|
|
|
|
std::fs::write(
|
|
codegen_path,
|
|
prettyplease::unparse(&syn::parse2(optimized).unwrap()),
|
|
)?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn apply_message_attributes(path: &Path) -> std::io::Result<()> {
|
|
let file = File::open(path)?;
|
|
let reader = BufReader::new(file);
|
|
let mut output = Vec::new();
|
|
|
|
let mut cmd_id_attr = None;
|
|
for line in reader.lines() {
|
|
let line = line?;
|
|
if line.contains("xor const: ") {
|
|
if !line.contains("xor const: 0") {
|
|
output.push(make_xor_attr(&line).unwrap());
|
|
}
|
|
} else if line.contains("CmdID: ") {
|
|
cmd_id_attr = Some(make_cmd_id_attr(&line).unwrap());
|
|
} else {
|
|
output.push(line);
|
|
if let Some(attr) = cmd_id_attr.take() {
|
|
output.push(attr);
|
|
}
|
|
}
|
|
}
|
|
|
|
fs::write(path, output.join("\n").as_bytes())?;
|
|
Ok(())
|
|
}
|
|
|
|
#[must_use]
|
|
fn field_is_optional(field: &Field) -> bool {
|
|
match &field.ty {
|
|
Type::Path(TypePath { path, .. }) => {
|
|
let last_segment = path.segments.last().unwrap();
|
|
last_segment.ident == "Option"
|
|
}
|
|
_ => panic!("Unsupported field type"),
|
|
}
|
|
}
|
|
|
|
#[must_use]
|
|
fn field_is_vector(field: &Field) -> bool {
|
|
match &field.ty {
|
|
Type::Path(TypePath { path, .. }) => {
|
|
let last_segment = path.segments.last().unwrap();
|
|
last_segment.ident == "Vec"
|
|
}
|
|
_ => panic!("Unsupported field type"),
|
|
}
|
|
}
|
|
|
|
#[must_use]
|
|
fn field_is_hash_map(field: &Field) -> bool {
|
|
match &field.ty {
|
|
Type::Path(TypePath { path, .. }) => {
|
|
let last_segment = path.segments.last().unwrap();
|
|
last_segment.ident == "HashMap"
|
|
}
|
|
_ => panic!("Unsupported field type"),
|
|
}
|
|
}
|
|
|
|
fn make_xor_attr(line: &str) -> Option<String> {
|
|
let xor_value = line.split("xor const: ").nth(1)?.parse::<u32>().ok()?;
|
|
Some(format!(" #[xor({xor_value})]"))
|
|
}
|
|
|
|
fn make_cmd_id_attr(line: &str) -> Option<String> {
|
|
let cmd_id = line.split("CmdID: ").nth(1)?.parse::<u16>().ok()?;
|
|
Some(format!("#[cmdid({cmd_id})]"))
|
|
}
|
|
|
|
#[must_use]
|
|
fn get_type_name(path: &syn::Path) -> String {
|
|
path.segments.last().unwrap().ident.to_string()
|
|
}
|
|
|
|
#[must_use]
|
|
fn extract_type(ty: &Type) -> String {
|
|
match &ty {
|
|
Type::Path(TypePath { path, .. }) => {
|
|
let last_segment = path.segments.last().unwrap();
|
|
match &last_segment.arguments {
|
|
PathArguments::AngleBracketed(args) => {
|
|
if let Some(GenericArgument::Type(Type::Path(TypePath { path, .. }))) =
|
|
args.args.last()
|
|
{
|
|
get_type_name(path)
|
|
} else {
|
|
get_type_name(path)
|
|
}
|
|
}
|
|
_ => get_type_name(path),
|
|
}
|
|
}
|
|
_ => panic!("Unsupported field type"),
|
|
}
|
|
}
|