// The `quote!` macro requires deep recursion. #![recursion_limit = "4096"] extern crate alloc; extern crate proc_macro; use anyhow::{Error, bail}; use itertools::Itertools; use proc_macro2::{Span, TokenStream}; use quote::{ToTokens, quote}; use syn::{ Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed, FieldsUnnamed, Ident, Index, Meta, MetaList, Variant, parse_macro_input, punctuated::Punctuated, }; mod field; use crate::field::Field; fn try_message(input: TokenStream) -> Result { let input: DeriveInput = syn::parse2(input)?; let ident = input.ident; syn::custom_keyword!(skip_debug); let skip_debug = input .attrs .into_iter() .any(|a| a.path().is_ident("prost") && a.parse_args::().is_ok()); let variant_data = match input.data { Data::Struct(variant_data) => variant_data, Data::Enum(..) => bail!("Message can not be derived for an enum"), Data::Union(..) => bail!("Message can not be derived for a union"), }; let generics = &input.generics; let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let (is_struct, fields) = match variant_data { DataStruct { fields: Fields::Named(FieldsNamed { named: fields, .. }), .. } => (true, fields.into_iter().collect()), DataStruct { fields: Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }), .. } => (false, fields.into_iter().collect()), DataStruct { fields: Fields::Unit, .. } => (false, Vec::new()), }; let mut next_tag: u32 = 1; let mut fields = fields .into_iter() .enumerate() .flat_map(|(i, field)| { let field_ident = field.ident.map(|x| quote!(#x)).unwrap_or_else(|| { let index = Index { index: i as u32, span: Span::call_site(), }; quote!(#index) }); match Field::new(field.attrs, Some(next_tag)) { Ok(Some(field)) => { next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag); Some(Ok((field_ident, field))) } Ok(None) => None, Err(err) => Some(Err( err.context(format!("invalid message field {}.{}", ident, field_ident)) )), } }) .collect::, _>>()?; // We want Debug to be in declaration order let unsorted_fields = fields.clone(); // Sort the fields by tag number so that fields will be encoded in tag order. // TODO: This encodes oneof fields in the position of their lowest tag, // regardless of the currently occupied variant, is that consequential? // See: https://developers.google.com/protocol-buffers/docs/encoding#order fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap_or(0)); let fields = fields; if let Some(duplicate_tag) = fields .iter() .flat_map(|(_, field)| field.tags()) .duplicates() .next() { bail!( "message {} has multiple fields with tag {}", ident, duplicate_tag ) }; let encoded_len = fields .iter() .filter(|(_, field)| !matches!(field, Field::Ignored)) .map(|(field_ident, field)| field.encoded_len(quote!(self.#field_ident))); let encode = fields .iter() .filter(|(_, field)| !matches!(field, Field::Ignored)) .map(|(field_ident, field)| field.encode(quote!(self.#field_ident))); let merge = fields .iter() .filter(|(_, field)| !matches!(field, Field::Ignored)) .map(|(field_ident, field)| { let merge = field.merge(quote!(value)); let tags = field.tags().into_iter().map(|tag| quote!(#tag)); let tags = Itertools::intersperse(tags, quote!(|)); quote! { #(#tags)* => { let mut value = &mut self.#field_ident; #merge.map_err(|mut error| { error.push(STRUCT_NAME, stringify!(#field_ident)); error }) }, } }); let struct_name = if fields.is_empty() { quote!() } else { quote!( const STRUCT_NAME: &'static str = stringify!(#ident); ) }; let clear = fields .iter() .filter(|(_, field)| !matches!(field, Field::Ignored)) .map(|(field_ident, field)| field.clear(quote!(self.#field_ident))); let default = if is_struct { let default = fields.iter().map(|(field_ident, field)| { let value = field.default(); quote!(#field_ident: #value,) }); quote! {#ident { #(#default)* }} } else { let default = fields.iter().map(|(_, field)| { let value = field.default(); quote!(#value,) }); quote! {#ident ( #(#default)* )} }; let methods = fields .iter() .flat_map(|(field_ident, field)| field.methods(field_ident)) .collect::>(); let methods = if methods.is_empty() { quote!() } else { quote! { #[allow(dead_code)] impl #impl_generics #ident #ty_generics #where_clause { #(#methods)* } } }; let expanded = quote! { impl #impl_generics ::prost::Message for #ident #ty_generics #where_clause { #[allow(unused_variables)] fn encode_raw(&self, buf: &mut impl ::prost::bytes::BufMut) { #(#encode)* } #[allow(unused_variables)] fn merge_field( &mut self, tag: u32, wire_type: ::prost::encoding::wire_type::WireType, buf: &mut impl ::prost::bytes::Buf, ctx: ::prost::encoding::DecodeContext, ) -> ::core::result::Result<(), ::prost::DecodeError> { #struct_name match tag { #(#merge)* _ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx), } } #[inline] fn encoded_len(&self) -> usize { 0 #(+ #encoded_len)* } fn clear(&mut self) { #(#clear;)* } } impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause { fn default() -> Self { #default } } }; let expanded = if skip_debug { expanded } else { let debugs = unsorted_fields.iter().map(|(field_ident, field)| { let wrapper = field.debug(quote!(self.#field_ident)); let call = if is_struct { quote!(builder.field(stringify!(#field_ident), &wrapper)) } else { quote!(builder.field(&wrapper)) }; quote! { let builder = { let wrapper = #wrapper; #call }; } }); let debug_builder = if is_struct { quote!(f.debug_struct(stringify!(#ident))) } else { quote!(f.debug_tuple(stringify!(#ident))) }; quote! { #expanded impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause { fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { let mut builder = #debug_builder; #(#debugs;)* builder.finish() } } } }; let expanded = quote! { #expanded #methods }; Ok(expanded) } #[proc_macro_derive(Message, attributes(prost))] pub fn message(input: proc_macro::TokenStream) -> proc_macro::TokenStream { try_message(input.into()).unwrap().into() } fn try_enumeration(input: TokenStream) -> Result { let input: DeriveInput = syn::parse2(input)?; let ident = input.ident; let generics = &input.generics; let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let punctuated_variants = match input.data { Data::Enum(DataEnum { variants, .. }) => variants, Data::Struct(_) => bail!("Enumeration can not be derived for a struct"), Data::Union(..) => bail!("Enumeration can not be derived for a union"), }; // Map the variants into 'fields'. let mut variants: Vec<(Ident, Expr)> = Vec::new(); for Variant { ident, fields, discriminant, .. } in punctuated_variants { match fields { Fields::Unit => (), Fields::Named(_) | Fields::Unnamed(_) => { bail!("Enumeration variants may not have fields") } } match discriminant { Some((_, expr)) => variants.push((ident, expr)), None => bail!("Enumeration variants must have a discriminant"), } } if variants.is_empty() { panic!("Enumeration must have at least one variant"); } let default = variants[0].0.clone(); let is_valid = variants.iter().map(|(_, value)| quote!(#value => true)); let from = variants .iter() .map(|(variant, value)| quote!(#value => ::core::option::Option::Some(#ident::#variant))); let try_from = variants .iter() .map(|(variant, value)| quote!(#value => ::core::result::Result::Ok(#ident::#variant))); let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident); let from_i32_doc = format!( "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.", ident ); let expanded = quote! { impl #impl_generics #ident #ty_generics #where_clause { #[doc=#is_valid_doc] pub fn is_valid(value: i32) -> bool { match value { #(#is_valid,)* _ => false, } } #[deprecated = "Use the TryFrom implementation instead"] #[doc=#from_i32_doc] pub fn from_i32(value: i32) -> ::core::option::Option<#ident> { match value { #(#from,)* _ => ::core::option::Option::None, } } } impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause { fn default() -> #ident { #ident::#default } } impl #impl_generics ::core::convert::From::<#ident> for i32 #ty_generics #where_clause { fn from(value: #ident) -> i32 { value as i32 } } impl #impl_generics ::core::convert::TryFrom:: for #ident #ty_generics #where_clause { type Error = ::prost::UnknownEnumValue; fn try_from(value: i32) -> ::core::result::Result<#ident, ::prost::UnknownEnumValue> { match value { #(#try_from,)* _ => ::core::result::Result::Err(::prost::UnknownEnumValue(value)), } } } }; Ok(expanded) } #[proc_macro_derive(Enumeration, attributes(prost))] pub fn enumeration(input: proc_macro::TokenStream) -> proc_macro::TokenStream { try_enumeration(input.into()).unwrap().into() } fn try_oneof(input: TokenStream) -> Result { let input: DeriveInput = syn::parse2(input)?; let ident = input.ident; syn::custom_keyword!(skip_debug); let skip_debug = input .attrs .into_iter() .any(|a| a.path().is_ident("prost") && a.parse_args::().is_ok()); let variants = match input.data { Data::Enum(DataEnum { variants, .. }) => variants, Data::Struct(..) => bail!("Oneof can not be derived for a struct"), Data::Union(..) => bail!("Oneof can not be derived for a union"), }; let generics = &input.generics; let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); // Map the variants into 'fields'. let mut fields: Vec<(Ident, Field)> = Vec::new(); for Variant { attrs, ident: variant_ident, fields: variant_fields, .. } in variants { let variant_fields = match variant_fields { Fields::Unit => Punctuated::new(), Fields::Named(FieldsNamed { named: fields, .. }) | Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) => fields, }; if variant_fields.len() != 1 { bail!("Oneof enum variants must have a single field"); } match Field::new_oneof(attrs)? { Some(field) => fields.push((variant_ident, field)), None => bail!("invalid oneof variant: oneof variants may not be ignored"), } } // Oneof variants cannot be oneofs themselves, so it's impossible to have a field with multiple // tags. assert!(fields.iter().all(|(_, field)| field.tags().len() <= 1)); if let Some(duplicate_tag) = fields .iter() .flat_map(|(_, field)| field.tags()) .duplicates() .next() { bail!( "invalid oneof {}: multiple variants have tag {}", ident, duplicate_tag ); } let encode = fields .iter() .filter(|(_, field)| !matches!(field, Field::Ignored)) .map(|(variant_ident, field)| { let encode = field.encode(quote!(*value)); quote!(#ident::#variant_ident(ref value) => { #encode }) }); let merge = fields.iter() .filter(|(_, field)| !matches!(field, Field::Ignored)) .map(|(variant_ident, field)| { if !field.tags().is_empty() { let tag = field.tags()[0]; let merge = field.merge(quote!(value)); quote! { #tag => if let ::core::option::Option::Some(#ident::#variant_ident(value)) = field { #merge } else { let mut owned_value = ::core::default::Default::default(); let value = &mut owned_value; #merge.map(|_| *field = ::core::option::Option::Some(#ident::#variant_ident(owned_value))) } } } else { TokenStream::new() } }); let encoded_len = fields .iter() .filter(|(_, field)| !matches!(field, Field::Ignored)) .map(|(variant_ident, field)| { let encoded_len = field.encoded_len(quote!(*value)); quote!(#ident::#variant_ident(ref value) => #encoded_len) }); let expanded = quote! { impl #impl_generics #ident #ty_generics #where_clause { /// Encodes the message to a buffer. pub fn encode(&self, buf: &mut impl ::prost::bytes::BufMut) { match *self { #(#encode,)* } } /// Decodes an instance of the message from a buffer, and merges it into self. pub fn merge( field: &mut ::core::option::Option<#ident #ty_generics>, tag: u32, wire_type: ::prost::encoding::wire_type::WireType, buf: &mut impl ::prost::bytes::Buf, ctx: ::prost::encoding::DecodeContext, ) -> ::core::result::Result<(), ::prost::DecodeError> { match tag { #(#merge,)* _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag), } } /// Returns the encoded length of the message without a length delimiter. #[inline] pub fn encoded_len(&self) -> usize { match *self { #(#encoded_len,)* } } } }; let expanded = if skip_debug { expanded } else { let debug = fields .iter() .filter(|(_, field)| !matches!(field, Field::Ignored)) .map(|(variant_ident, field)| { let wrapper = field.debug(quote!(*value)); quote!(#ident::#variant_ident(ref value) => { let wrapper = #wrapper; f.debug_tuple(stringify!(#variant_ident)) .field(&wrapper) .finish() }) }); quote! { #expanded impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause { fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { match *self { #(#debug,)* } } } } }; Ok(expanded) } #[proc_macro_derive(Oneof, attributes(prost))] pub fn oneof(input: proc_macro::TokenStream) -> proc_macro::TokenStream { try_oneof(input.into()).unwrap().into() } #[proc_macro_derive(NetCmd, attributes(cmd_id))] pub fn derive_net_cmd(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(input as DeriveInput); let struct_name = input.ident; let id = match input .attrs .iter() .find(|attr| attr.path().is_ident("cmd_id")) { Some(attr) => match attr.meta { Meta::List(MetaList { ref tokens, .. }) => tokens.into_token_stream(), _ => panic!("Invalid cmd_id attribute value"), }, _ => 0u16.into_token_stream(), }; quote! { impl crate::NetCmd for #struct_name { const CMD_ID: u16 = #id; } } .into() } #[proc_macro_derive(NetResponse)] pub fn derive_net_response(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(input as DeriveInput); let struct_name = input.ident; quote! { impl crate::NetResponse for #struct_name { fn set_retcode(&mut self, retcode: i32) { self.retcode = retcode; } fn get_retcode(&self) -> i32 { self.retcode } } } .into() } #[cfg(test)] mod test { use crate::{try_message, try_oneof}; use quote::quote; #[test] fn test_rejects_colliding_message_fields() { let output = try_message(quote!( struct Invalid { #[prost(bool, tag = "1")] a: bool, #[prost(oneof = "super::Whatever", tags = "4, 5, 1")] b: Option, } )); assert_eq!( output .expect_err("did not reject colliding message fields") .to_string(), "message Invalid has multiple fields with tag 1" ); } #[test] fn test_rejects_colliding_oneof_variants() { let output = try_oneof(quote!( pub enum Invalid { #[prost(bool, tag = "1")] A(bool), #[prost(bool, tag = "3")] B(bool), #[prost(bool, tag = "1")] C(bool), } )); assert_eq!( output .expect_err("did not reject colliding oneof variants") .to_string(), "invalid oneof Invalid: multiple variants have tag 1" ); } #[test] fn test_rejects_multiple_tags_oneof_variant() { let output = try_oneof(quote!( enum What { #[prost(bool, tag = "1", tag = "2")] A(bool), } )); assert_eq!( output .expect_err("did not reject multiple tags on oneof variant") .to_string(), "duplicate tag attributes: 1 and 2" ); let output = try_oneof(quote!( enum What { #[prost(bool, tag = "3")] #[prost(tag = "4")] A(bool), } )); assert!(output.is_err()); assert_eq!( output .expect_err("did not reject multiple tags on oneof variant") .to_string(), "duplicate tag attributes: 3 and 4" ); let output = try_oneof(quote!( enum What { #[prost(bool, tags = "5,6")] A(bool), } )); assert!(output.is_err()); assert_eq!( output .expect_err("did not reject multiple tags on oneof variant") .to_string(), "unknown attribute(s): #[prost(tags = \"5,6\")]" ); } }