use std::sync::{atomic::AtomicU32, Arc, LazyLock};
use thiserror::Error;

use byteorder::{BE, LE};
use common::{
    cryptography::{Ec2b, MhyXorpad},
    util,
};
use proto::{CmdID, NapMessage, PacketHead};
use tokio::{
    io::AsyncWriteExt,
    net::TcpStream,
    sync::{Mutex, MutexGuard, OnceCell},
};

use crate::{
    handlers::{self, PacketHandlingError},
    ServerState,
};

use super::packet::DecodeError;
use super::NetPacket;

// Process-wide fallback XOR key, built on first access from the Ec2b data returned by
// `util::open_secret_key()` (the key is constructed from `Ec2b::derive_seed`).
//
// Panics (via `expect`) on first access if the key file cannot be opened or its
// contents cannot be parsed as Ec2b data.
static SECRET_KEY: LazyLock = LazyLock::new(|| {
    let ec2b = Ec2b::read(&mut util::open_secret_key().expect("Failed to open secret key file"))
        .expect("Failed to read Ec2b data");
    MhyXorpad::new::(ec2b.derive_seed())
});

/// State for a single client connection.
pub struct NetSession {
    /// Identifier passed to `session_mgr.remove` when the connection ends.
    id: u64,
    /// Underlying socket. The async mutex is shared by the read loop in
    /// `NetSession::run` and by outbound writes in `NetSession::send`.
    stream: Mutex,
    /// Per-session XOR key, set at most once via `NetSession::set_session_key`.
    session_key: OnceCell,
    /// Source of outbound `packet_id` values; starts at 0.
    packet_id_counter: AtomicU32,
    /// Current protocol stage; decides which commands are accepted and how they are routed.
    state: AtomicNetSessionState,
    /// Account identifier, settable only once.
    account_uid: OnceCell,
    /// Player identifier, settable only once. When present, enables the periodic save in
    /// `run` and the save-and-remove in `on_disconnect`.
    player_uid: OnceCell,
}

/// Stages of a connection's login flow.
///
/// Variant order is significant: `PartialOrd` is derived, so comparisons follow
/// declaration order, and `is_command_allowed` / `is_auth` rely on that ordering.
/// `#[atomic_enum]` generates the `AtomicNetSessionState` wrapper stored in `NetSession`.
#[atomic_enum::atomic_enum]
#[derive(PartialEq, Eq, PartialOrd)]
pub enum NetSessionState {
    StartEnterGameWorld,
    PlayerGetTokenCsReq,
    PlayerGetTokenScRsp,
    PlayerLoginCsReq,
    PlayerLoginScRsp,
    StartBasicsReq,
    EndBasicsReq,
    EnterWorldScRsp,
}

impl NetSessionState {
    /// Returns whether a packet with `cmd_id` may be processed in this state.
    ///
    /// `PlayerGetTokenCsReq` is accepted only in `StartEnterGameWorld`, and
    /// `PlayerLoginCsReq` only in `PlayerGetTokenScRsp`. Every other command requires
    /// the state to be at or past `StartBasicsReq`.
    pub fn is_command_allowed(&self, cmd_id: u16) -> bool {
        match cmd_id {
            proto::PlayerGetTokenCsReq::CMD_ID => *self == NetSessionState::StartEnterGameWorld,
            proto::PlayerLoginCsReq::CMD_ID => *self == NetSessionState::PlayerGetTokenScRsp,
            _ => *self >= NetSessionState::StartBasicsReq,
        }
    }

    /// Returns `true` for every state ordered before `PlayerLoginScRsp`.
    ///
    /// While this holds, `NetSession::handle_packet` routes packets to
    /// `handlers::handle_auth_request` instead of the regular handlers.
    pub fn is_auth(&self) -> bool {
        *self < NetSessionState::PlayerLoginScRsp
    }
}

/// Reasons `NetSession::run` can end with an error.
#[derive(Error, Debug)]
pub enum SessionError {
    /// An inbound packet failed to decode. `DecodeError::IoError` is not reported here;
    /// `run` treats it as a disconnect and returns `Ok(())`.
    #[error("NetPacket decode failed: {0}")]
    PacketDecode(#[from] DecodeError),
    /// A handler failed. `PacketHandlingError::Logout` is not reported here; `run`
    /// treats it as a clean shutdown.
    #[error("failed to handle packet: {0}")]
    PacketHandling(#[from] PacketHandlingError),
}

impl NetSession {
    /// Wraps `stream` in a new session.
    ///
    /// The session starts in `StartEnterGameWorld` with no session key, account, or
    /// player bound, and its packet-id counter at 0.
    pub fn new(id: u64, stream: TcpStream) -> Self {
        Self {
            id,
            stream: Mutex::new(stream),
            session_key: OnceCell::new(),
            packet_id_counter: AtomicU32::new(0),
            state: AtomicNetSessionState::new(NetSessionState::StartEnterGameWorld),
            account_uid: OnceCell::new(),
            player_uid: OnceCell::new(),
        }
    }

    /// Reads and dispatches packets until the connection ends, then runs disconnect cleanup.
    ///
    /// An I/O error while reading, or a `PacketHandlingError::Logout` from a handler,
    /// ends the session with `Ok(())`. `on_disconnect` runs on every exit path.
    ///
    /// # Errors
    /// Returns `SessionError::PacketDecode` for any other decode failure, and
    /// `SessionError::PacketHandling` for any other handler failure.
    pub async fn run(&self, state: Arc) -> Result<(), SessionError> {
        let mut last_save_time = util::cur_timestamp();

        let result = loop {
            // The stream guard is a temporary of this `let` statement. The lock is therefore
            // held for the whole `read` await and released before the packet is handled,
            // which lets the handlers' `send` calls acquire it.
            // NOTE(review): while a read is pending, a `send` from any other task sharing
            // this session waits until the next inbound packet arrives. Confirm this is
            // acceptable, or consider separate read/write halves.
            let packet = match NetPacket::read(&mut *self.stream().await).await {
                Ok(packet) => packet,
                // Treated as the peer going away rather than as a session error.
                Err(DecodeError::IoError(_)) => break Ok(()),
                Err(err) => break Err(SessionError::PacketDecode(err)),
            };

            match self.handle_packet(packet, &state).await {
                Ok(()) => (),
                Err(PacketHandlingError::Logout) => break Ok(()),
                Err(err) => break Err(SessionError::PacketHandling(err)),
            }

            // Periodic player save. This check runs only after a packet has been handled,
            // so an idle connection is not saved until its next packet (or disconnect).
            // NOTE(review): if `cur_timestamp` returns an unsigned type, a backwards clock
            // step would underflow this subtraction. Confirm the type.
            if let Some(uid) = self.player_uid.get() {
                if (util::cur_timestamp() - last_save_time) >= state.config.player_save_period_seconds {
                    state.player_mgr.save(*uid).await;
                    last_save_time = util::cur_timestamp();
                }
            }
        };

        self.on_disconnect(&state).await;
        result
    }

    /// Decrypts `packet.body` in place and routes the packet by the current session state.
    ///
    /// - Packets not allowed in the current state are logged and dropped, and the
    ///   function returns `Ok(())`.
    /// - In auth states the packet goes to `handle_auth_request`.
    /// - Otherwise `handle_request` is tried first, and `handle_notify` only if it
    ///   returns `false`.
    ///
    /// Packets that no handler accepts are logged with a hex dump of the decrypted body.
    ///
    /// # Errors
    /// Propagates any `PacketHandlingError` returned by the handlers.
    async fn handle_packet(
        &self,
        mut packet: NetPacket,
        state: &ServerState,
    ) -> Result<(), PacketHandlingError> {
        self.xor_payload(packet.cmd_id, &mut packet.body);

        // Snapshot the state once so the allow-check and the routing decision agree.
        let net_state = self.state.load(std::sync::atomic::Ordering::SeqCst);
        if !net_state.is_command_allowed(packet.cmd_id) {
            // NOTE(review): this re-loads `self.state` instead of logging `net_state`. The
            // reported state can differ from the one the check used if the state changed
            // in between.
            tracing::warn!(
                "received cmd_id ({}) is not allowed in current state ({:?})",
                packet.cmd_id,
                self.state.load(std::sync::atomic::Ordering::SeqCst)
            );
        } else if net_state.is_auth() {
            if !handlers::handle_auth_request(self, &packet, state).await? {
                tracing::warn!(
                    "[LOGIN] packet with cmd_id={} wasn't handled, body: {}",
                    packet.cmd_id,
                    hex::encode(&packet.body)
                );
            }
        } else if !handlers::handle_request(self, &packet, state).await? {
            // Not a request the handlers recognized; fall back to notify handlers.
            if !handlers::handle_notify(self, &packet, state).await? {
                tracing::warn!(
                    "packet with cmd_id={} wasn't handled, body: {}",
                    packet.cmd_id,
                    hex::encode(&packet.body)
                );
            }
        }

        Ok(())
    }

    /// Cleanup for every exit path of `run`.
    ///
    /// Deregisters the session from `session_mgr` and, if a player was bound, calls
    /// `player_mgr.save_and_remove` for it.
    async fn on_disconnect(&self, state: &ServerState) {
        state.session_mgr.remove(self.id);

        if let Some(player_uid) = self.player_uid.get() {
            state.player_mgr.save_and_remove(*player_uid).await;
        }
    }

    /// Sends `ntf` as an unsolicited packet.
    ///
    /// Calls `xor_fields` on the message, encodes it, and sends it with a fresh packet id
    /// and an otherwise default `PacketHead`.
    ///
    /// # Errors
    /// Returns the I/O error from writing to the socket.
    pub async fn notify(&self, mut ntf: impl NapMessage) -> Result<(), std::io::Error> {
        ntf.xor_fields();

        self.send(NetPacket {
            cmd_id: ntf.get_cmd_id(),
            head: PacketHead {
                packet_id: self.next_packet_id(),
                ..Default::default()
            },
            body: ntf.encode_to_vec().into_boxed_slice(),
        })
        .await
    }

    /// Sends `rsp` as the response to the request identified by `request_id`.
    ///
    /// Same as `notify`, except that `request_id` is also written into the `PacketHead`.
    ///
    /// # Errors
    /// Returns the I/O error from writing to the socket.
    pub async fn send_rsp(
        &self,
        request_id: u32,
        mut rsp: impl NapMessage,
    ) -> Result<(), std::io::Error> {
        rsp.xor_fields();

        self.send(NetPacket {
            cmd_id: rsp.get_cmd_id(),
            head: PacketHead {
                packet_id: self.next_packet_id(),
                request_id,
                ..Default::default()
            },
            body: rsp.encode_to_vec().into_boxed_slice(),
        })
        .await
    }

    /// XOR-encrypts the body, encodes the packet, and writes the whole buffer while
    /// holding the stream lock.
    async fn send(&self, mut packet: NetPacket) -> Result<(), std::io::Error> {
        self.xor_payload(packet.cmd_id, &mut packet.body);

        let buf = packet.encode();
        self.stream().await.write_all(&buf).await
    }

    /// Waits for and returns exclusive access to the socket.
    async fn stream(&self) -> MutexGuard<'_, TcpStream> {
        self.stream.lock().await
    }

    /// Returns the identifier this session was created with.
    pub fn id(&self) -> u64 {
        self.id
    }

    /// Derives the per-session XOR key from `seed`.
    ///
    /// Only the first call takes effect. Later calls are ignored, because the
    /// `OnceCell::set` error is discarded.
    pub fn set_session_key(&self, seed: u64) {
        let _ = self.session_key.set(MhyXorpad::new::(seed));
    }

    /// Returns the bound account identifier, if one has been set.
    pub fn account_uid(&self) -> Option<&String> {
        self.account_uid.get()
    }

    /// Binds the account identifier.
    ///
    /// Returns `false` (and leaves the existing value unchanged) if one was already set.
    pub fn set_account_uid(&self, uid: String) -> bool {
        self.account_uid.set(uid).is_ok()
    }

    /// Returns the bound player identifier, if one has been set.
    pub fn player_uid(&self) -> Option<&u32> {
        self.player_uid.get()
    }

    /// Binds the player identifier.
    ///
    /// Returns `false` (and leaves the existing value unchanged) if one was already set.
    pub fn set_player_uid(&self, uid: u32) -> bool {
        self.player_uid.set(uid).is_ok()
    }

    /// XORs `buf` in place with the key selected for `cmd_id`.
    ///
    /// Used for both inbound (`handle_packet`) and outbound (`send`) payloads.
    /// `PlayerGetTokenScRsp` always uses `SECRET_KEY`. Any other command uses the session
    /// key once it is set, and falls back to `SECRET_KEY` before then.
    fn xor_payload(&self, cmd_id: u16, buf: &mut [u8]) {
        let key = match self.session_key.get() {
            // This guard arm is checked first, so it wins even when a session key is set.
            // Presumably the peer can only derive the session key after decoding this
            // response. Confirm against the handler that calls `set_session_key`.
            _ if cmd_id == proto::PlayerGetTokenScRsp::CMD_ID => &*SECRET_KEY,
            Some(key) => key,
            None => &*SECRET_KEY,
        };

        key.xor(buf);
    }

    /// Returns the next outbound packet id.
    ///
    /// The first packet gets 0, and ids increase from there. `fetch_add` wraps around on
    /// `u32` overflow.
    fn next_packet_id(&self) -> u32 {
        self.packet_id_counter
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
    }

    /// Unconditionally replaces the session's protocol state.
    ///
    /// No transition validation is performed here.
    pub fn set_state(&self, state: NetSessionState) {
        self.state.store(state, std::sync::atomic::Ordering::SeqCst);
    }
}