Do not require moving data into parallel writer

Truman Kilen 2025-01-20 17:39:58 -06:00
parent bdeb0df8c7
commit 194e800270
4 changed files with 150 additions and 109 deletions
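In effect, `ParallelPakWriter::write_file` now accepts any `D: AsRef<[u8]> + Send + Sync + 'scope` instead of requiring an owned `Vec<u8>`, so callers can hand over owned buffers or borrowed slices without copying them into the writer. A minimal caller-side sketch, mirroring the test added at the bottom of this commit (the builder arguments are taken from that test and are illustrative):

use std::io::Cursor;

fn main() -> Result<(), repak::Error> {
    let mut cur = Cursor::new(vec![]);
    let mut writer = repak::PakBuilder::new().writer(
        &mut cur,
        repak::Version::V11,
        "../../../".to_string(),
        Some(0x12345678),
    );

    let owned = vec![1, 2, 3];
    let borrowed = vec![4, 5, 6];
    writer.parallel(|w| -> Result<(), repak::Error> {
        // Ownership can still be handed over, as before.
        w.write_file("by/value".to_string(), true, owned)?;
        // Borrowed data only has to outlive the parallel scope; no Arc or clone needed.
        w.write_file("by/reference".to_string(), true, &borrowed)?;
        Ok(())
    })?;
    Ok(())
}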

View file

@@ -1,3 +1,5 @@
+use std::io::Write;
+
 use crate::{
     entry::{Block, Entry},
     Compression, Error, Hash, Version, VersionMajor,
@@ -5,17 +7,21 @@ use crate::{

 type Result<T, E = Error> = std::result::Result<T, E>;

-pub(crate) struct PartialEntry {
+pub(crate) struct PartialEntry<D: AsRef<[u8]>> {
     compression: Option<Compression>,
     compressed_size: u64,
     uncompressed_size: u64,
     compression_block_size: u32,
-    pub(crate) blocks: Vec<PartialBlock>,
+    data: PartialEntryData<D>,
     hash: Hash,
 }

 pub(crate) struct PartialBlock {
     uncompressed_size: usize,
-    pub(crate) data: Vec<u8>,
+    data: Vec<u8>,
+}
+
+pub(crate) enum PartialEntryData<D> {
+    Slice(D),
+    Blocks(Vec<PartialBlock>),
 }

 #[cfg(feature = "compression")]
@@ -55,7 +61,7 @@ fn get_compression_slot(
     } as u32)
 }

-impl PartialEntry {
+impl<D: AsRef<[u8]>> PartialEntry<D> {
     pub(crate) fn build_entry(
         &self,
         version: Version,
@@ -70,25 +76,30 @@ impl PartialEntry {
         #[cfg(not(feature = "compression"))]
         let compression_slot = None;

-        let blocks = (!self.blocks.is_empty()).then(|| {
-            let entry_size =
-                Entry::get_serialized_size(version, compression_slot, self.blocks.len() as u32);
-            let mut offset = entry_size;
-            if version.version_major() < VersionMajor::RelativeChunkOffsets {
-                offset += file_offset;
-            };
-            self.blocks
-                .iter()
-                .map(|block| {
-                    let start = offset;
-                    offset += block.data.len() as u64;
-                    let end = offset;
-                    Block { start, end }
-                })
-                .collect()
-        });
+        let blocks = match &self.data {
+            PartialEntryData::Slice(_) => None,
+            PartialEntryData::Blocks(blocks) => {
+                let entry_size =
+                    Entry::get_serialized_size(version, compression_slot, blocks.len() as u32);
+                let mut offset = entry_size;
+                if version.version_major() < VersionMajor::RelativeChunkOffsets {
+                    offset += file_offset;
+                };
+                Some(
+                    blocks
+                        .iter()
+                        .map(|block| {
+                            let start = offset;
+                            offset += block.data.len() as u64;
+                            let end = offset;
+                            Block { start, end }
+                        })
+                        .collect(),
+                )
+            }
+        };

         Ok(Entry {
             offset: file_offset,
@@ -102,22 +113,38 @@ impl PartialEntry {
             compression_block_size: self.compression_block_size,
         })
     }
+
+    pub(crate) fn write_data<S: Write>(&self, stream: &mut S) -> Result<()> {
+        match &self.data {
+            PartialEntryData::Slice(data) => {
+                stream.write_all(data.as_ref())?;
+            }
+            PartialEntryData::Blocks(blocks) => {
+                for block in blocks {
+                    stream.write_all(&block.data)?;
+                }
+            }
+        }
+        Ok(())
+    }
 }

-pub(crate) fn build_partial_entry(
+pub(crate) fn build_partial_entry<D>(
     allowed_compression: &[Compression],
-    data: &[u8],
-) -> Result<PartialEntry> {
+    data: D,
+) -> Result<PartialEntry<D>>
+where
+    D: AsRef<[u8]>,
+{
     // TODO hash needs to be post-compression/encryption
     use sha1::{Digest, Sha1};
     let mut hasher = Sha1::new();

     // TODO possibly select best compression based on some criteria instead of picking first
     let compression = allowed_compression.first().cloned();
-    let uncompressed_size = data.len() as u64;
+    let uncompressed_size = data.as_ref().len() as u64;
     let compression_block_size;
-    let (blocks, compressed_size) = match compression {
+    let (data, compressed_size) = match compression {
         #[cfg(not(feature = "compression"))]
         Some(_) => {
             unreachable!("should not be able to reach this point without compression feature")
@@ -129,7 +156,7 @@ pub(crate) fn build_partial_entry(
             compression_block_size = 0x10000;
             let mut compressed_size = 0;
             let mut blocks = vec![];
-            for chunk in data.chunks(compression_block_size as usize) {
+            for chunk in data.as_ref().chunks(compression_block_size as usize) {
                 let data = compress(compression, chunk)?;
                 compressed_size += data.len() as u64;
                 hasher.update(&data);
@@ -139,12 +166,12 @@ pub(crate) fn build_partial_entry(
                 })
             }
-            (blocks, compressed_size)
+            (PartialEntryData::Blocks(blocks), compressed_size)
         }
         None => {
             compression_block_size = 0;
-            hasher.update(data);
-            (vec![], uncompressed_size)
+            hasher.update(data.as_ref());
+            (PartialEntryData::Slice(data), uncompressed_size)
         }
     };
@@ -153,7 +180,7 @@ pub(crate) fn build_partial_entry(
         compressed_size,
         uncompressed_size,
         compression_block_size,
-        blocks,
+        data,
         hash: Hash(hasher.finalize().into()),
     })
 }
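The net effect in this file: `PartialEntry` is now generic over the caller's buffer type and keeps that buffer untouched when no compression runs (`PartialEntryData::Slice`), only allocating `PartialBlock`s when bytes are actually transformed; `write_data` then streams whichever variant is present. A standalone sketch of the same pattern (the names `Payload` and `Pending` are illustrative, not repak's API):

use std::io::{self, Write};

// Either the caller's untouched buffer or freshly produced chunks.
enum Payload<D> {
    Slice(D),             // input passed through unchanged, no copy made
    Blocks(Vec<Vec<u8>>), // transformed (e.g. compressed) output that had to be allocated
}

struct Pending<D> {
    data: Payload<D>,
}

impl<D: AsRef<[u8]>> Pending<D> {
    // Stream whichever representation we ended up with.
    fn write_data<W: Write>(&self, out: &mut W) -> io::Result<()> {
        match &self.data {
            Payload::Slice(d) => out.write_all(d.as_ref()),
            Payload::Blocks(blocks) => blocks.iter().try_for_each(|b| out.write_all(b)),
        }
    }
}

fn main() -> io::Result<()> {
    // No transformation: the original buffer is written directly.
    let passthrough = Pending { data: Payload::Slice(&b"hello"[..]) };
    // Transformation: only now do we own newly produced bytes.
    let transformed = Pending::<&[u8]> {
        data: Payload::Blocks(vec![b"he".to_vec(), b"llo".to_vec()]),
    };

    let mut out = Vec::new();
    passthrough.write_data(&mut out)?;
    transformed.write_data(&mut out)?;
    assert_eq!(out, b"hellohello");
    Ok(())
}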

View file

@@ -109,13 +109,7 @@ impl Entry {
         let stream_position = writer.stream_position()?;
         let entry = partial_entry.build_entry(version, compression_slots, stream_position)?;
         entry.write(writer, version, crate::entry::EntryLocation::Data)?;
-        if partial_entry.blocks.is_empty() {
-            writer.write_all(data)?;
-        } else {
-            for block in partial_entry.blocks {
-                writer.write_all(&block.data)?;
-            }
-        }
+        partial_entry.write_data(writer)?;

         Ok(entry)
     }

View file

@@ -7,7 +7,6 @@ use super::{Version, VersionMajor};
 use byteorder::{ReadBytesExt, WriteBytesExt, LE};
 use std::collections::BTreeMap;
 use std::io::{self, Read, Seek, Write};
-use std::sync::Arc;

 #[derive(Default, Clone, Copy)]
 pub(crate) struct Hash(pub(crate) [u8; 20]);
@@ -88,10 +87,6 @@ pub struct PakWriter<W: Write + Seek> {
     allowed_compression: Vec<Compression>,
 }

-pub struct ParallelPakWriter {
-    tx: std::sync::mpsc::SyncSender<(String, bool, Arc<Vec<u8>>)>,
-}
-
 #[derive(Debug)]
 pub(crate) struct Pak {
     version: Version,
@@ -147,8 +142,8 @@ impl Index {
         self.entries
     }

-    fn add_entry(&mut self, path: &str, entry: super::entry::Entry) {
-        self.entries.insert(path.to_string(), entry);
+    fn add_entry(&mut self, path: String, entry: super::entry::Entry) {
+        self.entries.insert(path, entry);
     }
 }
@@ -280,7 +275,7 @@ impl<W: Write + Seek> PakWriter<W> {
         data: impl AsRef<[u8]>,
     ) -> Result<(), super::Error> {
         self.pak.index.add_entry(
-            path,
+            path.to_string(),
             Entry::write_file(
                 &mut self.writer,
                 self.pak.version,
@@ -297,75 +292,56 @@ impl<W: Write + Seek> PakWriter<W> {
         Ok(())
     }

-    pub fn parallel<F, E>(&mut self, f: F) -> Result<&mut Self, E>
+    pub fn parallel<'scope, F, E>(&mut self, f: F) -> Result<&mut Self, E>
     where
-        F: Send + Sync + FnOnce(&mut ParallelPakWriter) -> Result<(), E>,
+        F: Send + Sync + FnOnce(&mut ParallelPakWriter<'scope>) -> Result<(), E>,
         E: From<Error> + Send,
     {
-        {
-            use pariter::IteratorExt as _;
-
-            let (tx, rx) = std::sync::mpsc::sync_channel(0);
-
-            pariter::scope(|scope| -> Result<(), E> {
-                let handle = scope.spawn(|_| -> Result<(), E> {
-                    f(&mut ParallelPakWriter { tx })?;
-                    Ok(())
-                });
-
-                let result = rx
-                    .into_iter()
-                    .parallel_map_scoped(
-                        scope,
-                        |(path, allow_compress, data): (String, bool, Arc<Vec<u8>>)| -> Result<_, Error> {
-                            let allowed_compression = if allow_compress {
-                                self.allowed_compression.as_slice()
-                            } else {
-                                &[]
-                            };
-                            let partial_entry = build_partial_entry(allowed_compression, &data)?;
-                            let data = partial_entry.blocks.is_empty().then(|| Arc::new(data));
-                            Ok((path, data, partial_entry))
-                        },
-                    )
-                    .try_for_each(|message| -> Result<(), Error> {
-                        let stream_position = self.writer.stream_position()?;
-                        let (path, data, partial_entry) = message?;
-                        let entry = partial_entry.build_entry(
-                            self.pak.version,
-                            &mut self.pak.compression,
-                            stream_position,
-                        )?;
-                        entry.write(
-                            &mut self.writer,
-                            self.pak.version,
-                            crate::entry::EntryLocation::Data,
-                        )?;
-                        self.pak.index.add_entry(&path, entry);
-                        if let Some(data) = data {
-                            self.writer.write_all(&data)?;
-                        } else {
-                            for block in partial_entry.blocks {
-                                self.writer.write_all(&block.data)?;
-                            }
-                        }
-                        Ok(())
-                    });
-
-                if let Err(err) = handle.join().unwrap() {
-                    Err(err) // prioritize error from user code
-                } else if let Err(err) = result {
-                    Err(err.into()) // user code was successful, check pak writer error
-                } else {
-                    Ok(()) // neither returned error so return success
-                }
-            })
-            .unwrap()?;
-        }
+        use pariter::IteratorExt as _;
+
+        let allowed_compression = self.allowed_compression.as_slice();
+        pariter::scope(|scope: &pariter::Scope<'_>| -> Result<(), E> {
+            let (tx, rx) = std::sync::mpsc::sync_channel(0);
+            let handle = scope.spawn(|_| f(&mut ParallelPakWriter { tx }));
+
+            let result = rx
+                .into_iter()
+                .parallel_map_scoped(scope, |(path, compress, data)| -> Result<_, Error> {
+                    let compression = compress.then_some(allowed_compression).unwrap_or_default();
+                    let partial_entry = build_partial_entry(compression, data)?;
+                    Ok((path, partial_entry))
+                })
+                .try_for_each(|message| -> Result<(), Error> {
+                    let stream_position = self.writer.stream_position()?;
+                    let (path, partial_entry) = message?;
+                    let entry = partial_entry.build_entry(
+                        self.pak.version,
+                        &mut self.pak.compression,
+                        stream_position,
+                    )?;
+                    entry.write(
+                        &mut self.writer,
+                        self.pak.version,
+                        crate::entry::EntryLocation::Data,
+                    )?;
+                    self.pak.index.add_entry(path, entry);
+                    partial_entry.write_data(&mut self.writer)?;
+                    Ok(())
+                });
+
+            if let Err(err) = handle.join().unwrap() {
+                Err(err) // prioritize error from user code
+            } else if let Err(err) = result {
+                Err(err.into()) // user code was successful, check pak writer error
+            } else {
+                Ok(()) // neither returned error so return success
+            }
+        })
+        .unwrap()?;
+
         Ok(self)
     }
@@ -375,13 +351,30 @@ impl<W: Write + Seek> PakWriter<W> {
     }
 }

-impl ParallelPakWriter {
-    pub fn write_file(&self, path: String, compress: bool, data: Vec<u8>) -> Result<(), Error> {
-        self.tx.send((path, compress, Arc::new(data))).unwrap();
+pub struct ParallelPakWriter<'scope> {
+    tx: std::sync::mpsc::SyncSender<(String, bool, Data<'scope>)>,
+}
+
+impl<'scope> ParallelPakWriter<'scope> {
+    pub fn write_file<D: AsRef<[u8]> + Send + Sync + 'scope>(
+        &self,
+        path: String,
+        compress: bool,
+        data: D,
+    ) -> Result<(), Error> {
+        self.tx
+            .send((path, compress, Data(Box::new(data))))
+            .unwrap();
         Ok(())
     }
 }

+struct Data<'d>(Box<dyn AsRef<[u8]> + Send + Sync + 'd>);
+impl AsRef<[u8]> for Data<'_> {
+    fn as_ref(&self) -> &[u8] {
+        self.0.as_ref().as_ref()
+    }
+}
+
 impl Pak {
     fn read<R: Read + Seek>(
         reader: &mut R,
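Because each `write_file` call may pass a different concrete buffer type, the channel payload is type-erased behind a boxed `AsRef<[u8]>` trait object (the `Data` struct above). A minimal standalone illustration of that technique, with hypothetical names:

// A single payload type that can carry any byte-like buffer.
struct Bytes<'a>(Box<dyn AsRef<[u8]> + Send + Sync + 'a>);

impl AsRef<[u8]> for Bytes<'_> {
    fn as_ref(&self) -> &[u8] {
        // Deref the Box to the trait object, then borrow its byte slice.
        self.0.as_ref().as_ref()
    }
}

fn main() {
    let local = [7u8, 8, 9];
    // Vec<u8>, &'static [u8], and a borrowed array all fit behind the same wrapper,
    // so a channel (or Vec) only needs one element type.
    let items: Vec<Bytes> = vec![
        Bytes(Box::new(vec![1u8, 2, 3])),
        Bytes(Box::new(&b"abc"[..])),
        Bytes(Box::new(&local)),
    ];
    let total: usize = items.iter().map(|b| b.as_ref().len()).sum();
    assert_eq!(total, 9);
}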

View file

@@ -88,6 +88,33 @@ mod test {
         }
     }

+    #[test]
+    fn test_parallel_writer() -> Result<(), repak::Error> {
+        let mut cur = Cursor::new(vec![]);
+        let mut writer = repak::PakBuilder::new().writer(
+            &mut cur,
+            repak::Version::V11,
+            "../../../".to_string(),
+            Some(0x12345678),
+        );
+
+        let outside_scope1 = vec![1, 2, 3];
+        let outside_scope2 = vec![4, 5, 6];
+        writer.parallel(|writer| -> Result<(), repak::Error> {
+            let inside_scope = vec![7, 8, 9];
+
+            writer.write_file("pass/takes/ownership".to_string(), true, outside_scope1)?;
+            writer.write_file("pass/outlives/scope".to_string(), true, &outside_scope2)?;
+            writer.write_file("pass/takes/ownership".to_string(), true, inside_scope)?;
+            // writer.write_file("fail/doesnt/outlive/scope".to_string(), true, &inside_scope)?;
+
+            Ok(())
+        })?;
+        Ok(())
+    }
+
     static AES_KEY: &str = "lNJbw660IOC+kU7cnVQ1oeqrXyhk4J6UAZrCBbcnp94=";

     fn test_read(version: repak::Version, _file_name: &str, bytes: &[u8]) {