diff --git a/repak/src/data.rs b/repak/src/data.rs
index 42095c0..c1390fa 100644
--- a/repak/src/data.rs
+++ b/repak/src/data.rs
@@ -1,3 +1,5 @@
+use std::io::Write;
+
 use crate::{
     entry::{Block, Entry},
     Compression, Error, Hash, Version, VersionMajor,
@@ -5,17 +7,21 @@ use crate::{
 
 type Result<T, E = Error> = std::result::Result<T, E>;
 
-pub(crate) struct PartialEntry {
+pub(crate) struct PartialEntry<D: AsRef<[u8]>> {
     compression: Option<Compression>,
     compressed_size: u64,
     uncompressed_size: u64,
     compression_block_size: u32,
-    pub(crate) blocks: Vec<PartialBlock>,
+    data: PartialEntryData<D>,
     hash: Hash,
 }
 pub(crate) struct PartialBlock {
     uncompressed_size: usize,
-    pub(crate) data: Vec<u8>,
+    data: Vec<u8>,
+}
+pub(crate) enum PartialEntryData<D> {
+    Slice(D),
+    Blocks(Vec<PartialBlock>),
 }
 
 #[cfg(feature = "compression")]
@@ -55,7 +61,7 @@ fn get_compression_slot(
     } as u32)
 }
 
-impl PartialEntry {
+impl<D: AsRef<[u8]>> PartialEntry<D> {
     pub(crate) fn build_entry(
         &self,
         version: Version,
@@ -70,25 +76,30 @@ impl PartialEntry {
         #[cfg(not(feature = "compression"))]
         let compression_slot = None;
 
-        let blocks = (!self.blocks.is_empty()).then(|| {
-            let entry_size =
-                Entry::get_serialized_size(version, compression_slot, self.blocks.len() as u32);
+        let blocks = match &self.data {
+            PartialEntryData::Slice(_) => None,
+            PartialEntryData::Blocks(blocks) => {
+                let entry_size =
+                    Entry::get_serialized_size(version, compression_slot, blocks.len() as u32);
 
-            let mut offset = entry_size;
-            if version.version_major() < VersionMajor::RelativeChunkOffsets {
-                offset += file_offset;
-            };
+                let mut offset = entry_size;
+                if version.version_major() < VersionMajor::RelativeChunkOffsets {
+                    offset += file_offset;
+                };
 
-            self.blocks
-                .iter()
-                .map(|block| {
-                    let start = offset;
-                    offset += block.data.len() as u64;
-                    let end = offset;
-                    Block { start, end }
-                })
-                .collect()
-        });
+                Some(
+                    blocks
+                        .iter()
+                        .map(|block| {
+                            let start = offset;
+                            offset += block.data.len() as u64;
+                            let end = offset;
+                            Block { start, end }
+                        })
+                        .collect(),
+                )
+            }
+        };
 
         Ok(Entry {
             offset: file_offset,
@@ -102,22 +113,38 @@ impl PartialEntry {
             compression_block_size: self.compression_block_size,
         })
     }
+    pub(crate) fn write_data<S: Write>(&self, stream: &mut S) -> Result<()> {
+        match &self.data {
+            PartialEntryData::Slice(data) => {
+                stream.write_all(data.as_ref())?;
+            }
+            PartialEntryData::Blocks(blocks) => {
+                for block in blocks {
+                    stream.write_all(&block.data)?;
+                }
+            }
+        }
+        Ok(())
+    }
 }
 
-pub(crate) fn build_partial_entry(
+pub(crate) fn build_partial_entry<D>(
     allowed_compression: &[Compression],
-    data: &[u8],
-) -> Result<PartialEntry> {
+    data: D,
+) -> Result<PartialEntry<D>>
+where
+    D: AsRef<[u8]>,
+{
     // TODO hash needs to be post-compression/encryption
     use sha1::{Digest, Sha1};
     let mut hasher = Sha1::new();
 
     // TODO possibly select best compression based on some criteria instead of picking first
     let compression = allowed_compression.first().cloned();
-    let uncompressed_size = data.len() as u64;
+    let uncompressed_size = data.as_ref().len() as u64;
     let compression_block_size;
 
-    let (blocks, compressed_size) = match compression {
+    let (data, compressed_size) = match compression {
         #[cfg(not(feature = "compression"))]
         Some(_) => {
             unreachable!("should not be able to reach this point without compression feature")
@@ -129,7 +156,7 @@ pub(crate) fn build_partial_entry(
             compression_block_size = 0x10000;
             let mut compressed_size = 0;
             let mut blocks = vec![];
-            for chunk in data.chunks(compression_block_size as usize) {
+            for chunk in data.as_ref().chunks(compression_block_size as usize) {
                 let data = compress(compression, chunk)?;
                 compressed_size += data.len() as u64;
                 hasher.update(&data);
@@ -139,12 +166,12 @@ pub(crate) fn build_partial_entry(
                 })
             }
 
-            (blocks, compressed_size)
+            (PartialEntryData::Blocks(blocks), compressed_size)
         }
         None => {
             compression_block_size = 0;
-            hasher.update(data);
-            (vec![], uncompressed_size)
+            hasher.update(data.as_ref());
+            (PartialEntryData::Slice(data), uncompressed_size)
         }
     };
 
@@ -153,7 +180,7 @@ pub(crate) fn build_partial_entry(
         compressed_size,
         uncompressed_size,
         compression_block_size,
-        blocks,
+        data,
         hash: Hash(hasher.finalize().into()),
     })
 }
diff --git a/repak/src/entry.rs b/repak/src/entry.rs
index fd7d7dd..1a12469 100644
--- a/repak/src/entry.rs
+++ b/repak/src/entry.rs
@@ -109,13 +109,7 @@ impl Entry {
         let stream_position = writer.stream_position()?;
         let entry = partial_entry.build_entry(version, compression_slots, stream_position)?;
         entry.write(writer, version, crate::entry::EntryLocation::Data)?;
-        if partial_entry.blocks.is_empty() {
-            writer.write_all(data)?;
-        } else {
-            for block in partial_entry.blocks {
-                writer.write_all(&block.data)?;
-            }
-        }
+        partial_entry.write_data(writer)?;
 
         Ok(entry)
     }
diff --git a/repak/src/pak.rs b/repak/src/pak.rs
index e14fba2..030e8a9 100644
--- a/repak/src/pak.rs
+++ b/repak/src/pak.rs
@@ -7,7 +7,6 @@ use super::{Version, VersionMajor};
 use byteorder::{ReadBytesExt, WriteBytesExt, LE};
 use std::collections::BTreeMap;
 use std::io::{self, Read, Seek, Write};
-use std::sync::Arc;
 
 #[derive(Default, Clone, Copy)]
 pub(crate) struct Hash(pub(crate) [u8; 20]);
@@ -88,10 +87,6 @@ pub struct PakWriter<W: Write + Seek> {
     allowed_compression: Vec<Compression>,
 }
 
-pub struct ParallelPakWriter {
-    tx: std::sync::mpsc::SyncSender<(String, bool, Arc<Vec<u8>>)>,
-}
-
 #[derive(Debug)]
 pub(crate) struct Pak {
     version: Version,
@@ -147,8 +142,8 @@ impl Index {
         self.entries
     }
 
-    fn add_entry(&mut self, path: &str, entry: super::entry::Entry) {
-        self.entries.insert(path.to_string(), entry);
+    fn add_entry(&mut self, path: String, entry: super::entry::Entry) {
+        self.entries.insert(path, entry);
     }
 }
@@ -280,7 +275,7 @@ impl<W: Write + Seek> PakWriter<W> {
         data: impl AsRef<[u8]>,
     ) -> Result<(), super::Error> {
         self.pak.index.add_entry(
-            path,
+            path.to_string(),
             Entry::write_file(
                 &mut self.writer,
                 self.pak.version,
@@ -297,75 +292,56 @@ impl<W: Write + Seek> PakWriter<W> {
         Ok(())
     }
 
-    pub fn parallel<F, E>(&mut self, f: F) -> Result<&mut Self, E>
+    pub fn parallel<'scope, F, E>(&mut self, f: F) -> Result<&mut Self, E>
     where
-        F: Send + Sync + FnOnce(&mut ParallelPakWriter) -> Result<(), E>,
+        F: Send + Sync + FnOnce(&mut ParallelPakWriter<'scope>) -> Result<(), E>,
         E: From<Error> + Send,
     {
-        {
-            use pariter::IteratorExt as _;
+        use pariter::IteratorExt as _;
 
+        let allowed_compression = self.allowed_compression.as_slice();
+        pariter::scope(|scope: &pariter::Scope<'_>| -> Result<(), E> {
             let (tx, rx) = std::sync::mpsc::sync_channel(0);
 
-            pariter::scope(|scope| -> Result<(), E> {
-                let handle = scope.spawn(|_| -> Result<(), E> {
-                    f(&mut ParallelPakWriter { tx })?;
+            let handle = scope.spawn(|_| f(&mut ParallelPakWriter { tx }));
+
+            let result = rx
+                .into_iter()
+                .parallel_map_scoped(scope, |(path, compress, data)| -> Result<_, Error> {
+                    let compression = compress.then_some(allowed_compression).unwrap_or_default();
+                    let partial_entry = build_partial_entry(compression, data)?;
+                    Ok((path, partial_entry))
+                })
+                .try_for_each(|message| -> Result<(), Error> {
+                    let stream_position = self.writer.stream_position()?;
+                    let (path, partial_entry) = message?;
+
+                    let entry = partial_entry.build_entry(
+                        self.pak.version,
+                        &mut self.pak.compression,
+                        stream_position,
+                    )?;
+
+                    entry.write(
+                        &mut self.writer,
+                        self.pak.version,
+                        crate::entry::EntryLocation::Data,
+                    )?;
+
+                    self.pak.index.add_entry(path, entry);
+                    partial_entry.write_data(&mut self.writer)?;
                     Ok(())
                 });
 
-                let result = rx
-                    .into_iter()
-                    .parallel_map_scoped(
-                        scope,
-                        |(path, allow_compress, data): (String, bool, Arc<Vec<u8>>)| -> Result<_, Error> {
-                            let allowed_compression = if allow_compress {
-                                self.allowed_compression.as_slice()
-                            } else {
-                                &[]
-                            };
-                            let partial_entry = build_partial_entry(allowed_compression, &data)?;
-                            let data = partial_entry.blocks.is_empty().then(|| Arc::new(data));
-                            Ok((path, data, partial_entry))
-                        },
-                    )
-                    .try_for_each(|message| -> Result<(), Error> {
-                        let stream_position = self.writer.stream_position()?;
-                        let (path, data, partial_entry) = message?;
-
-                        let entry = partial_entry.build_entry(
-                            self.pak.version,
-                            &mut self.pak.compression,
-                            stream_position,
-                        )?;
-
-                        entry.write(
-                            &mut self.writer,
-                            self.pak.version,
-                            crate::entry::EntryLocation::Data,
-                        )?;
-
-                        self.pak.index.add_entry(&path, entry);
-
-                        if let Some(data) = data {
-                            self.writer.write_all(&data)?;
-                        } else {
-                            for block in partial_entry.blocks {
-                                self.writer.write_all(&block.data)?;
-                            }
-                        }
-                        Ok(())
-                    });
-
-                if let Err(err) = handle.join().unwrap() {
-                    Err(err) // prioritize error from user code
-                } else if let Err(err) = result {
-                    Err(err.into()) // user code was successful, check pak writer error
-                } else {
-                    Ok(()) // neither returned error so return success
-                }
-            })
-            .unwrap()?;
-        }
+            if let Err(err) = handle.join().unwrap() {
+                Err(err) // prioritize error from user code
+            } else if let Err(err) = result {
+                Err(err.into()) // user code was successful, check pak writer error
+            } else {
+                Ok(()) // neither returned error so return success
+            }
+        })
+        .unwrap()?;
 
         Ok(self)
     }
@@ -375,13 +351,30 @@ impl<W: Write + Seek> PakWriter<W> {
     }
 }
 
-impl ParallelPakWriter {
-    pub fn write_file(&self, path: String, compress: bool, data: Vec<u8>) -> Result<(), Error> {
-        self.tx.send((path, compress, Arc::new(data))).unwrap();
+pub struct ParallelPakWriter<'scope> {
+    tx: std::sync::mpsc::SyncSender<(String, bool, Data<'scope>)>,
+}
+impl<'scope> ParallelPakWriter<'scope> {
+    pub fn write_file<D: AsRef<[u8]> + Send + Sync + 'scope>(
+        &self,
+        path: String,
+        compress: bool,
+        data: D,
+    ) -> Result<(), Error> {
+        self.tx
+            .send((path, compress, Data(Box::new(data))))
+            .unwrap();
         Ok(())
     }
 }
 
+struct Data<'d>(Box<dyn AsRef<[u8]> + Send + Sync + 'd>);
+impl AsRef<[u8]> for Data<'_> {
+    fn as_ref(&self) -> &[u8] {
+        self.0.as_ref().as_ref()
+    }
+}
+
 impl Pak {
     fn read<R: Read + Seek>(
         reader: &mut R,
diff --git a/repak/tests/test.rs b/repak/tests/test.rs
index e951c02..e817fdd 100644
--- a/repak/tests/test.rs
+++ b/repak/tests/test.rs
@@ -88,6 +88,33 @@ mod test {
     }
 }
 
+#[test]
+fn test_parallel_writer() -> Result<(), repak::Error> {
+    let mut cur = Cursor::new(vec![]);
+    let mut writer = repak::PakBuilder::new().writer(
+        &mut cur,
+        repak::Version::V11,
+        "../../../".to_string(),
+        Some(0x12345678),
+    );
+
+    let outside_scope1 = vec![1, 2, 3];
+    let outside_scope2 = vec![4, 5, 6];
+
+    writer.parallel(|writer| -> Result<(), repak::Error> {
+        let inside_scope = vec![7, 8, 9];
+
+        writer.write_file("pass/takes/ownership".to_string(), true, outside_scope1)?;
+        writer.write_file("pass/outlives/scope".to_string(), true, &outside_scope2)?;
+
+        writer.write_file("pass/takes/ownership".to_string(), true, inside_scope)?;
+        // writer.write_file("fail/doesnt/outlive/scope".to_string(), true, &inside_scope)?;
+        Ok(())
+    })?;
+
+    Ok(())
+}
+
"lNJbw660IOC+kU7cnVQ1oeqrXyhk4J6UAZrCBbcnp94="; fn test_read(version: repak::Version, _file_name: &str, bytes: &[u8]) {