rust/hwphysics/src/data.rs
author unC0Rr
Thu, 16 May 2024 14:11:16 +0200
changeset 16014 e66feb918d71
parent 15945 343b8819b051
permissions -rw-r--r--
Improve AI aiming with minigun

use super::common::GearId;
use std::{
    any::TypeId,
    fmt::{Debug, Error, Formatter},
    marker::PhantomData,
    mem::{align_of, size_of, MaybeUninit},
    num::NonZeroU16,
    ptr::{copy_nonoverlapping, null_mut, NonNull},
    slice,
};

/// Upper bound on how many component (or tag) types one query tuple may name.
const MAX_TYPES: usize = 8;

/// A tuple of references whose pointee types can be listed as `TypeId`s.
pub trait TypeTuple: Sized {
    /// Writes the `TypeId` of each tuple element into `types`, in element order,
    /// and returns how many entries were written.
    fn get_types(types: &mut [TypeId; MAX_TYPES]) -> usize;
}

impl TypeTuple for () {
    /// The empty tuple names no types.
    fn get_types(_: &mut [TypeId; MAX_TYPES]) -> usize {
        0
    }
}

impl<T: 'static> TypeTuple for &T {
    /// A single reference names exactly its pointee type, stored in slot 0.
    fn get_types(types: &mut [TypeId; MAX_TYPES]) -> usize {
        match types.first_mut() {
            Some(slot) => {
                *slot = TypeId::of::<T>();
                1
            }
            None => 0,
        }
    }
}

/// A [`TypeTuple`] that can be driven over column-wise component storage.
pub trait TypeIter: TypeTuple {
    /// Calls `visit` once per row, passing the row's gear ID and a tuple of
    /// references into the component columns at that row.
    ///
    /// # Safety
    ///
    /// `columns[0]` must point to at least `count` initialized `GearId`s, and
    /// `columns[k + 1]` must point to at least `count` initialized, properly
    /// aligned values of the `k`-th tuple element's type. For tuples of `&mut`
    /// references the columns must not alias one another.
    unsafe fn iter<Visit: FnMut(GearId, Self)>(columns: &[*mut u8], count: usize, visit: Visit);
}

// Implements `TypeTuple` and `TypeIter` for a tuple of shared references
// `(&A, &B, ...)` and for the matching tuple of mutable references
// `(&mut A, &mut B, ...)`.
//
// Each element is written as `index: Type`, where `index` is the element's
// position in the tuple. In the generated `iter`, `slices[0]` is read as the
// `GearId` column and `slices[index + 1]` as the column for element `index`;
// row `i` of every column is handed to `f` together.
macro_rules! type_tuple_impl {
    ($($n: literal: $t: ident),+) => {
        impl<$($t: 'static),+> TypeTuple for ($(&$t),+,) {
            fn get_types(types: &mut [TypeId; MAX_TYPES]) -> usize {
                let mut count = 0;
                // The invocations below never exceed MAX_TYPES elements, so the
                // guard is always true for them; it keeps `get_unchecked_mut`
                // in bounds should MAX_TYPES ever shrink.
                $({
                    if MAX_TYPES > $n {
                        unsafe {
                            *types.get_unchecked_mut($n) = TypeId::of::<$t>();
                        }
                        count = $n + 1;
                    }
                });+
                count
            }
        }

        impl<$($t: 'static),+> TypeIter for ($(&$t),+,) {
            unsafe fn iter<FI: FnMut(GearId, Self)>(slices: &[*mut u8], count: usize, mut f: FI) {
                for i in 0..count {
                    f(*(*slices.get_unchecked(0) as *const GearId).add(i),
                      ($(&*(*slices.get_unchecked($n + 1) as *mut $t).add(i)),+,));
                }
            }
        }

        impl<$($t: 'static),+> TypeTuple for ($(&mut $t),+,) {
            fn get_types(types: &mut [TypeId; MAX_TYPES]) -> usize {
                let mut count = 0;
                $({
                    if MAX_TYPES > $n {
                        unsafe {
                            *types.get_unchecked_mut($n) = TypeId::of::<$t>();
                        }
                        count = $n + 1;
                    }
                });+
                count
            }
        }

        // Handing out several `&mut` at once relies on every column being a
        // different type; `GearDataManager::iter` panics on duplicate types.
        impl<$($t: 'static),+> TypeIter for ($(&mut $t),+,) {
            unsafe fn iter<FI: FnMut(GearId, Self)>(slices: &[*mut u8], count: usize, mut f: FI) {
                for i in 0..count {
                    f(*(*slices.get_unchecked(0) as *const GearId).add(i),
                      ($(&mut *(*slices.get_unchecked($n + 1) as *mut $t).add(i)),+,));
                }
            }
        }
    }
}

// Tuples of one to eight elements (eight == MAX_TYPES).
type_tuple_impl!(0: A);
type_tuple_impl!(0: A, 1: B);
type_tuple_impl!(0: A, 1: B, 2: C);
type_tuple_impl!(0: A, 1: B, 2: C, 3: D);
type_tuple_impl!(0: A, 1: B, 2: C, 3: D, 4: E);
type_tuple_impl!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F);
type_tuple_impl!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G);
type_tuple_impl!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H);

/// Size in bytes of the backing storage of every [`DataBlock`] (32 KiB).
const BLOCK_SIZE: usize = 32 * 1024;

/// Fixed-size storage for gears that share the same set of component types.
///
/// `data` starts with a column of gear IDs, followed by one column per stored
/// component type; `component_blocks[i]` points at the column for type index
/// `i`, or is `None` when this block does not store that type.
struct DataBlock {
    /// Capacity of every column, in rows.
    max_elements: u16,
    /// Number of occupied rows (never more than `max_elements`).
    elements_count: u16,
    /// Backing storage; all column pointers point into this allocation.
    data: Box<[u8; BLOCK_SIZE]>,
    /// Start of each component column, indexed by type index.
    component_blocks: [Option<NonNull<u8>>; 64],
    /// Byte size of each registered type (by type index) when the block was created.
    element_sizes: Box<[u16]>,
}

// The column pointers refer to the boxed heap storage, not to the struct itself,
// so moving a `DataBlock` never invalidates them.
impl Unpin for DataBlock {}

impl Debug for DataBlock {
    /// Dumps the block's occupancy, its stored gear IDs and the raw bytes of
    /// every component column present in the block.
    fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
        writeln!(f, "Block ({}/{}) {{", self.elements_count, self.max_elements)?;

        // SAFETY: the block starts with the GearId column, whose first
        // `elements_count` entries are occupied.
        let ids = unsafe {
            slice::from_raw_parts(
                self.data.as_ptr() as *const GearId,
                self.elements_count as usize,
            )
        };
        write!(f, "\tIDs: [")?;
        for id in ids.iter() {
            write!(f, "{}, ", id)?;
        }
        writeln!(f, "]")?;

        let present_columns = self
            .element_sizes
            .iter()
            .enumerate()
            .filter_map(|(type_index, &size)| {
                self.component_blocks[type_index].map(|column| (type_index, size, column))
            });
        for (type_index, size, column) in present_columns {
            write!(f, "\tC{}: [", type_index)?;
            // SAFETY: the column holds `elements_count` occupied elements of
            // `size` bytes each.
            let bytes = unsafe {
                slice::from_raw_parts(column.as_ptr(), (self.elements_count * size) as usize)
            };
            for byte in bytes {
                write!(f, "{}, ", byte)?;
            }
            writeln!(f, "]")?;
        }
        writeln!(f, "}}")
    }
}

impl DataBlock {
    /// Creates an empty block laid out for the component types selected by `mask`.
    ///
    /// The `BLOCK_SIZE` bytes start with a column of `max_elements` `GearId`s.
    /// Then, for every bit `i` set in `mask` (lowest bit first), comes a column of
    /// `max_elements` elements of `element_sizes[i]` bytes, aligned to
    /// `element_alignments[i]`. `component_blocks[i]` points at that column.
    ///
    /// `element_sizes` and `element_alignments` are indexed by type index and must
    /// cover every bit set in `mask`.
    fn new(mask: u64, element_sizes: &[u16], element_alignments: &[u8]) -> Self {
        let is_selected = |type_index: usize| mask & (1 << type_index as u64) != 0;

        // Aligning a column skips fewer than `align` bytes, so the sum of the
        // selected types' alignments always covers the padding actually needed.
        let total_padding: usize = element_alignments
            .iter()
            .enumerate()
            .filter(|&(i, _)| is_selected(i))
            .map(|(_, &align)| align as usize)
            .sum();
        let total_size: usize = element_sizes
            .iter()
            .enumerate()
            .filter(|&(i, _)| is_selected(i))
            .map(|(_, &size)| size as usize)
            .sum();
        let max_elements =
            ((BLOCK_SIZE - total_padding) / (total_size + size_of::<GearId>())) as u16;

        // Allocate the storage typed as GearId slots so the ID column at offset 0
        // is aligned for `GearId`. The slots are `MaybeUninit` and zero-filled, so
        // no invalid `GearId` value is ever created and every byte is initialized.
        let id_storage: Box<[MaybeUninit<GearId>; BLOCK_SIZE / size_of::<GearId>()]> =
            Box::new(unsafe { MaybeUninit::zeroed().assume_init() });
        // NOTE(review): the box is later freed as `[u8; BLOCK_SIZE]` (alignment 1)
        // although it was allocated with `GearId`'s alignment. The allocator
        // contract requires identical layouts. A proper fix needs a different type
        // for `DataBlock::data`.
        let mut data: Box<[u8; BLOCK_SIZE]> =
            unsafe { Box::from_raw(Box::into_raw(id_storage) as *mut [u8; BLOCK_SIZE]) };

        let mut component_blocks = [None; 64];
        let base = data.as_mut_ptr();
        // Component columns start right after the GearId column.
        let mut offset = size_of::<GearId>() * max_elements as usize;

        let mut remaining = mask;
        while remaining != 0 {
            let type_index = remaining.trailing_zeros() as usize;

            // SAFETY: `max_elements` and `total_padding` were chosen so that
            // `offset` never exceeds BLOCK_SIZE; the pointer stays within (or one
            // past the end of) the allocation.
            let unaligned = unsafe { base.add(offset) };
            offset += unaligned.align_offset(element_alignments[type_index] as usize);
            component_blocks[type_index] = NonNull::new(unsafe { base.add(offset) });
            offset += element_sizes[type_index] as usize * max_elements as usize;
            debug_assert!(offset <= BLOCK_SIZE);

            remaining &= remaining - 1;
        }

        Self {
            max_elements,
            elements_count: 0,
            data,
            component_blocks,
            element_sizes: Box::from(element_sizes),
        }
    }

    /// All `max_elements` ID slots of the block. Only the first `elements_count`
    /// slots hold IDs of stored gears.
    ///
    /// NOTE(review): slots past `elements_count` contain placeholder bytes, not
    /// real IDs. Callers must only read indices below `elements_count`.
    #[inline]
    fn gear_ids(&self) -> &[GearId] {
        // SAFETY: the block starts with `max_elements` GearId-aligned slots.
        unsafe {
            slice::from_raw_parts(
                self.data.as_ptr() as *const GearId,
                self.max_elements as usize,
            )
        }
    }

    /// Mutable view of all `max_elements` ID slots; see [`DataBlock::gear_ids`].
    #[inline]
    fn gear_ids_mut(&mut self) -> &mut [GearId] {
        // SAFETY: as in `gear_ids`; `&mut self` guarantees exclusive access.
        unsafe {
            slice::from_raw_parts_mut(
                self.data.as_mut_ptr() as *mut GearId,
                self.max_elements as usize,
            )
        }
    }

    /// Whether every row of the block is occupied.
    #[inline]
    fn is_full(&self) -> bool {
        self.elements_count == self.max_elements
    }
}

/// Location of a gear's data: which block it lives in and at which row.
#[derive(Clone, Copy, Debug, Default)]
struct LookupEntry {
    /// Row inside the block, stored as `row + 1`; `None` means the gear has no data.
    index: Option<NonZeroU16>,
    /// Index into `GearDataManager::blocks`.
    block_index: u16,
}

impl LookupEntry {
    /// Creates an entry for row `index` of block `block_index`.
    ///
    /// # Panics
    ///
    /// Panics if `index == u16::MAX`, because the row is stored off by one.
    fn new(block_index: u16, index: u16) -> Self {
        // Rows are stored as `index + 1` so that `None` can mean "no data".
        // Checked arithmetic avoids the release-mode wrap to zero that made the
        // previous `new_unchecked` call undefined behavior.
        let stored = index
            .checked_add(1)
            .and_then(NonZeroU16::new)
            .expect("block row index out of range");
        Self {
            index: Some(stored),
            block_index,
        }
    }
}

/// The exact set of component types (`type_mask`) and tags (`tag_mask`) a block
/// stores. Bit `i` stands for the type or tag registered at index `i`.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
struct BlockMask {
    type_mask: u64,
    tag_mask: u64,
}

impl BlockMask {
    /// Builds a mask from raw type and tag bit sets.
    #[inline]
    fn new(type_mask: u64, tag_mask: u64) -> Self {
        BlockMask {
            type_mask,
            tag_mask,
        }
    }

    /// Returns a copy with the bit(s) of `type_bit` added to the type set.
    #[inline]
    fn with_type(&self, type_bit: u64) -> Self {
        BlockMask {
            type_mask: self.type_mask | type_bit,
            ..*self
        }
    }

    /// Returns a copy with the bit(s) of `type_bit` removed from the type set.
    #[inline]
    fn without_type(&self, type_bit: u64) -> Self {
        BlockMask {
            type_mask: self.type_mask & !type_bit,
            ..*self
        }
    }

    /// Returns a copy with the bit(s) of `tag_bit` added to the tag set.
    #[inline]
    fn with_tag(&self, tag_bit: u64) -> Self {
        BlockMask {
            tag_mask: self.tag_mask | tag_bit,
            ..*self
        }
    }
}

/// Stores per-gear component data. Gears with exactly the same set of component
/// types and tags share fixed-size [`DataBlock`]s.
pub struct GearDataManager {
    /// Registered component types; a type's position here is its type index.
    types: Vec<TypeId>,
    /// Registered zero-sized tag types; a tag's position here is its tag index.
    tags: Vec<TypeId>,
    /// Every allocated block.
    blocks: Vec<DataBlock>,
    /// `block_masks[i]` is the type/tag set stored by `blocks[i]`.
    block_masks: Vec<BlockMask>,
    /// Byte size of each registered component type, indexed by type index.
    element_sizes: Box<[u16; 64]>,
    /// Alignment of each registered component type, indexed by type index.
    element_alignments: Box<[u8; 64]>,
    /// Where each gear's data lives; slot `id - 1` belongs to gear `id`.
    lookup: Box<[LookupEntry]>,
}

impl GearDataManager {
    /// Creates an empty manager with no registered component types or tags.
    pub fn new() -> Self {
        Self {
            types: Vec::with_capacity(64),
            tags: Vec::with_capacity(64),
            blocks: vec![],
            block_masks: vec![],
            element_sizes: Box::new([0; 64]),
            element_alignments: Box::new([0; 64]),
            // Gear IDs are non-zero, so the entry for gear `id` lives at `id - 1`.
            lookup: vec![LookupEntry::default(); u16::MAX as usize].into_boxed_slice(),
        }
    }

    /// Returns the type index of component type `T`, or `None` if it is not registered.
    #[inline]
    fn get_type_index<T: 'static>(&self) -> Option<usize> {
        let type_id = TypeId::of::<T>();
        self.types.iter().position(|id| *id == type_id)
    }

    /// Returns the tag index of tag type `T`, or `None` if it is not registered.
    #[inline]
    fn get_tag_index<T: 'static>(&self) -> Option<usize> {
        let type_id = TypeId::of::<T>();
        self.tags.iter().position(|id| *id == type_id)
    }

    /// Moves the gear at row `src_index` of block `src_block_index` to the end of
    /// block `dest_block_index` and returns its new row.
    ///
    /// Effects:
    /// - Components stored by both blocks are copied.
    /// - Components the destination does not store are discarded.
    /// - Components only the destination stores are left unset for the caller
    ///   to write.
    /// - The vacated source row is filled with the source block's last row.
    /// - The lookup entries of both affected gears are updated.
    fn move_between_blocks(
        &mut self,
        src_block_index: u16,
        src_index: u16,
        dest_block_index: u16,
    ) -> u16 {
        debug_assert!(src_block_index != dest_block_index);
        let src_mask = self.block_masks[src_block_index as usize];
        let dest_mask = self.block_masks[dest_block_index as usize];
        debug_assert!(src_mask.type_mask & dest_mask.type_mask != 0);

        let src_block = &self.blocks[src_block_index as usize];
        let dest_block = &self.blocks[dest_block_index as usize];
        debug_assert!(src_index < src_block.elements_count);
        debug_assert!(!dest_block.is_full());

        let dest_index = dest_block.elements_count;

        // Visit every component type of the source block, lowest bit first.
        let mut type_mask = src_mask.type_mask;
        while type_mask != 0 {
            let i = type_mask.trailing_zeros() as usize;

            let size = self.element_sizes[i];
            let src_ptr = src_block.component_blocks[i].unwrap().as_ptr();
            if let Some(dest_ptr) = dest_block.component_blocks[i] {
                let dest_ptr = dest_ptr.as_ptr();
                // SAFETY: both rows are within their blocks' capacity, and the
                // blocks are distinct, separately boxed allocations.
                unsafe {
                    copy_nonoverlapping(
                        src_ptr.add((src_index * size) as usize),
                        dest_ptr.add((dest_index * size) as usize),
                        size as usize,
                    );
                }
            }
            // Swap-remove: copy the source block's last row into the vacated row.
            // SAFETY: both rows are occupied and distinct, so the ranges are in
            // bounds and do not overlap.
            unsafe {
                if src_index < src_block.elements_count - 1 {
                    copy_nonoverlapping(
                        src_ptr.add((size * (src_block.elements_count - 1)) as usize),
                        src_ptr.add((size * src_index) as usize),
                        size as usize,
                    );
                }
            }

            type_mask &= type_mask - 1;
        }

        let src_block = &mut self.blocks[src_block_index as usize];
        let gear_id = src_block.gear_ids()[src_index as usize];

        if src_index + 1 < src_block.elements_count {
            let relocated_index = src_block.elements_count as usize - 1;
            let gear_ids = src_block.gear_ids_mut();
            let relocated_id = gear_ids[relocated_index];

            gear_ids[src_index as usize] = relocated_id;
            self.lookup[relocated_id.get() as usize - 1] =
                LookupEntry::new(src_block_index, src_index);
        }
        src_block.elements_count -= 1;

        let dest_block = &mut self.blocks[dest_block_index as usize];
        let dest_index = dest_block.elements_count;

        dest_block.gear_ids_mut()[dest_index as usize] = gear_id;
        self.lookup[gear_id.get() as usize - 1] = LookupEntry::new(dest_block_index, dest_index);
        dest_block.elements_count += 1;
        dest_index
    }

    /// Appends a gear that has no data yet to `block_index`. That block must
    /// store exactly one component type, which `value` provides.
    fn add_to_block<T: Clone>(&mut self, gear_id: GearId, block_index: u16, value: &T) {
        let type_mask = self.block_masks[block_index as usize].type_mask;
        debug_assert!(type_mask.count_ones() == 1);
        // Columns are indexed by type index, so the block's only column is at the
        // position of its single type bit, which is not necessarily slot 0.
        let type_index = type_mask.trailing_zeros() as usize;

        let block = &mut self.blocks[block_index as usize];
        debug_assert!(block.elements_count < block.max_elements);
        let index = block.elements_count;

        let column = block.component_blocks[type_index]
            .expect("block is missing the column for its own type")
            .as_ptr() as *mut T;
        // SAFETY: `index < max_elements`, and the column has room for
        // `max_elements` values of `T`, aligned for `T`. `write` does not read or
        // drop the unset previous contents.
        unsafe { column.add(index as usize).write(value.clone()) };

        self.lookup[gear_id.get() as usize - 1] = LookupEntry::new(block_index, index);
        block.gear_ids_mut()[index as usize] = gear_id;
        block.elements_count += 1;
    }

    /// Removes row `index` from block `block_index`. The block's last row is moved
    /// into its place and the removed gear's lookup entry is cleared.
    fn remove_from_block(&mut self, block_index: u16, index: u16) {
        let block = &mut self.blocks[block_index as usize];
        debug_assert!(index < block.elements_count);
        let last = block.elements_count - 1;

        if index < last {
            for (i, size) in self.element_sizes.iter().cloned().enumerate() {
                if let Some(ptr) = block.component_blocks[i] {
                    // SAFETY: rows `index` and `last` are distinct occupied rows of
                    // this column, so the ranges are in bounds and don't overlap.
                    unsafe {
                        copy_nonoverlapping(
                            ptr.as_ptr().add((size * last) as usize),
                            ptr.as_ptr().add((size * index) as usize),
                            size as usize,
                        );
                    }
                }
            }
        }

        let gear_ids = block.gear_ids_mut();
        self.lookup[gear_ids[index as usize].get() as usize - 1] = LookupEntry::default();
        if index < last {
            let relocated_id = gear_ids[last as usize];
            gear_ids[index as usize] = relocated_id;
            self.lookup[relocated_id.get() as usize - 1] = LookupEntry::new(block_index, index);
        }
        block.elements_count -= 1;
    }

    /// Writes `value` into the column of `type_index` at row `index` of block
    /// `block_index`, which must store that type.
    fn write_component<T: Clone>(
        &mut self,
        block_index: u16,
        index: u16,
        type_index: usize,
        value: &T,
    ) {
        debug_assert!(type_index < self.types.len());
        let block = &mut self.blocks[block_index as usize];
        debug_assert!(index < block.elements_count);

        let column = block.component_blocks[type_index]
            .expect("block does not store this component type")
            .as_ptr() as *mut T;
        // SAFETY: `index < elements_count <= max_elements`, and the column holds
        // `max_elements` aligned `T` slots. `write` does not drop the previous
        // contents, which may be unset; `register` debug-asserts that `T` needs no drop.
        unsafe { column.add(index as usize).write(value.clone()) };
    }

    /// Returns the index of a non-full block storing exactly `mask`, allocating a
    /// new block if none exists.
    #[inline]
    fn ensure_block(&mut self, mask: BlockMask) -> u16 {
        if let Some(index) = self
            .block_masks
            .iter()
            .enumerate()
            .position(|(i, m)| *m == mask && !self.blocks[i].is_full())
        {
            index as u16
        } else {
            self.blocks.push(DataBlock::new(
                mask.type_mask,
                &self.element_sizes[0..self.types.len()],
                &self.element_alignments[0..self.types.len()],
            ));
            self.block_masks.push(mask);
            (self.blocks.len() - 1) as u16
        }
    }

    /// Attaches component `value` of type `T` to `gear_id`.
    ///
    /// - If the gear has no data yet, it is placed in a block that stores only `T`.
    /// - If it has data but no `T`, it moves to a block that also stores `T`.
    /// - If it already has a `T`, that value is overwritten.
    ///
    /// # Panics
    ///
    /// Panics if `T` has not been registered.
    pub fn add<T: Clone + 'static>(&mut self, gear_id: GearId, value: &T) {
        if let Some(type_index) = self.get_type_index::<T>() {
            let type_bit = 1 << type_index as u64;
            let entry = self.lookup[gear_id.get() as usize - 1];

            if let Some(index) = entry.index {
                let mask = self.block_masks[entry.block_index as usize];
                let new_mask = mask.with_type(type_bit);

                if new_mask != mask {
                    let dest_block_index = self.ensure_block(new_mask);
                    let dest_index = self.move_between_blocks(
                        entry.block_index,
                        index.get() - 1,
                        dest_block_index,
                    );
                    self.write_component(dest_block_index, dest_index, type_index, value);
                } else {
                    // The gear already stores a `T`: update it in place.
                    self.write_component(entry.block_index, index.get() - 1, type_index, value);
                }
            } else {
                let dest_block_index = self.ensure_block(BlockMask::new(type_bit, 0));
                self.add_to_block(gear_id, dest_block_index, value);
            }
        } else {
            panic!("Unregistered type")
        }
    }

    /// Adds tag `T` to `gear_id`, moving the gear to a block with the extended tag
    /// set. Has no effect if the gear already has the tag.
    ///
    /// # Panics
    ///
    /// Panics if `T` is not a registered tag or the gear has no component data.
    pub fn add_tag<T: 'static>(&mut self, gear_id: GearId) {
        if let Some(tag_index) = self.get_tag_index::<T>() {
            let tag_bit = 1 << tag_index as u64;
            let entry = self.lookup[gear_id.get() as usize - 1];

            if let Some(index) = entry.index {
                let mask = self.block_masks[entry.block_index as usize];
                let new_mask = mask.with_tag(tag_bit);

                if new_mask != mask {
                    let dest_block_index = self.ensure_block(new_mask);
                    self.move_between_blocks(entry.block_index, index.get() - 1, dest_block_index);
                }
            } else {
                panic!("Cannot tag a gear with no data")
            }
        } else {
            panic!("Unregistered tag")
        }
    }

    /// Detaches component `T` from `gear_id`. If it was the gear's last component,
    /// the gear's data (tags included) is removed entirely. Has no effect if the
    /// gear has no `T`.
    ///
    /// # Panics
    ///
    /// Panics if `T` has not been registered.
    pub fn remove<T: 'static>(&mut self, gear_id: GearId) {
        if let Some(type_index) = self.get_type_index::<T>() {
            let type_bit = 1 << type_index as u64;
            let entry = self.lookup[gear_id.get() as usize - 1];

            if let Some(index) = entry.index {
                let mask = self.block_masks[entry.block_index as usize];
                let new_mask = mask.without_type(type_bit);

                if new_mask != mask {
                    if new_mask.type_mask == 0 {
                        self.remove_from_block(entry.block_index, index.get() - 1);
                    } else {
                        let dest_block_index = self.ensure_block(new_mask);
                        self.move_between_blocks(
                            entry.block_index,
                            index.get() - 1,
                            dest_block_index,
                        );
                    }
                }
            }
        } else {
            panic!("Unregistered type")
        }
    }

    /// Removes every component and tag of `gear_id`; no-op if it has none.
    pub fn remove_all(&mut self, gear_id: GearId) {
        let entry = self.lookup[gear_id.get() as usize - 1];
        if let Some(index) = entry.index {
            self.remove_from_block(entry.block_index, index.get() - 1);
        }
    }

    /// Registers `T` as a component type, or as a tag when `T` is zero-sized.
    /// Registering an already known type has no effect.
    ///
    /// # Panics
    ///
    /// Panics if 64 component types (or 64 tags) are already registered, because
    /// block masks are 64-bit.
    pub fn register<T: 'static>(&mut self) {
        // Component bytes are moved with raw copies and are never dropped.
        debug_assert!(!std::mem::needs_drop::<T>());
        debug_assert!(size_of::<T>() <= u16::MAX as usize);
        debug_assert!(align_of::<T>() <= u8::MAX as usize);

        let id = TypeId::of::<T>();
        if size_of::<T>() == 0 {
            if !self.tags.contains(&id) {
                assert!(self.tags.len() < 64, "Too many tags registered");
                self.tags.push(id)
            }
        } else if !self.types.contains(&id) {
            let type_index = self.types.len();
            assert!(type_index < 64, "Too many types registered");
            self.element_sizes[type_index] = size_of::<T>() as u16;
            self.element_alignments[type_index] = align_of::<T>() as u8;
            self.types.push(id);
        }
    }

    /// Calls `f` for every row of every block that stores all types in
    /// `type_selector` and carries all tags in `included_tags`.
    ///
    /// `type_indices[k]` is the type index of the `k`-th element of tuple `T`.
    fn run_impl<T: TypeIter + 'static, F: FnMut(GearId, T)>(
        &mut self,
        type_selector: u64,
        included_tags: u64,
        type_indices: &[i8],
        mut f: F,
    ) {
        let mut slices = [null_mut(); MAX_TYPES + 1];

        for (block_index, mask) in self.block_masks.iter().enumerate() {
            if mask.type_mask & type_selector == type_selector
                && mask.tag_mask & included_tags == included_tags
            {
                let block = &mut self.blocks[block_index];
                slices[0] = block.data.as_mut_ptr();

                for (arg_index, type_index) in type_indices.iter().cloned().enumerate() {
                    slices[arg_index + 1] = block.component_blocks[type_index as usize]
                        .unwrap()
                        .as_ptr()
                }

                // SAFETY: slice 0 is the block's GearId column. Every other slice is
                // the column of a type the block stores, so each holds
                // `elements_count` initialized rows. Type indices are distinct
                // (`iter` rejects duplicates), so `&mut` columns never alias.
                unsafe {
                    T::iter(
                        &slices[0..=type_indices.len()],
                        block.elements_count as usize,
                        |id, x| f(id, x),
                    );
                }
            }
        }
    }

    /// Returns `gear_id`'s component of type `T`. Returns `None` if the gear has
    /// no such component or `T` is not registered.
    pub fn get<T: 'static>(&self, gear_id: GearId) -> Option<&T> {
        let entry = self.lookup[gear_id.get() as usize - 1];
        match (entry.index, self.get_type_index::<T>()) {
            (Some(index), Some(type_index)) => {
                let block = &self.blocks[entry.block_index as usize];
                // SAFETY: the lookup entry points at an occupied row of this block.
                block.component_blocks[type_index].map(|ptr| unsafe {
                    &*(ptr.as_ptr() as *const T).add(index.get() as usize - 1)
                })
            }
            _ => None,
        }
    }

    /// Starts a query over all gears that have every component type named by
    /// the tuple `T`.
    ///
    /// # Panics
    ///
    /// Panics if `T` names an unregistered type or names the same type twice.
    pub fn iter<T: TypeIter + 'static>(&mut self) -> DataIterator<T> {
        // Placeholder contents; only the first `types_count` entries are read.
        // (The array must hold valid `TypeId`s, so it cannot be left uninitialized.)
        let mut arg_types = [TypeId::of::<()>(); MAX_TYPES];
        let types_count = T::get_types(&mut arg_types);
        let mut type_indices = [-1; MAX_TYPES];
        let mut selector = 0u64;

        for (arg_index, type_id) in arg_types[0..types_count].iter().enumerate() {
            match self.types.iter().position(|t| t == type_id) {
                Some(i) if selector & (1 << i as u64) != 0 => panic!("Duplicate type"),
                Some(i) => {
                    type_indices[arg_index] = i as i8;
                    selector |= 1 << i as u64;
                }
                None => panic!("Unregistered type"),
            }
        }
        DataIterator::new(self, selector, type_indices)
    }
}

/// A pending query over a [`GearDataManager`] for gears that have all component
/// types of the tuple `T` and, optionally, a set of tags.
pub struct DataIterator<'a, T> {
    /// The manager being queried. It is borrowed mutably so that tuples of
    /// `&mut` components can be handed out.
    data: &'a mut GearDataManager,
    /// Bit set of the required component type indices.
    types: u64,
    /// Type index of each element of `T`, in order, padded with `-1`.
    type_indices: [i8; MAX_TYPES],
    /// Bit set of the required tag indices.
    tags: u64,
    /// Marks the query's tuple type without storing one.
    phantom_types: PhantomData<T>,
}

impl<'a, T: TypeIter + 'static> DataIterator<'a, T> {
    /// Creates a query over `data`.
    ///
    /// - `types` is the bit set of required type indices.
    /// - `type_indices` lists the type index of each element of `T`, padded
    ///   with `-1`.
    ///
    /// Initially no tags are required.
    fn new(
        data: &'a mut GearDataManager,
        types: u64,
        type_indices: [i8; MAX_TYPES],
    ) -> DataIterator<'a, T> {
        Self {
            data,
            types,
            type_indices,
            tags: 0,
            phantom_types: PhantomData,
        }
    }

    /// Restricts the query to gears carrying every tag named by the tuple `U`,
    /// e.g. `with_tags::<&MyTag>()`. Replaces any previously set tag filter.
    ///
    /// # Panics
    ///
    /// Panics if `U` names a tag that has not been registered.
    pub fn with_tags<U: TypeTuple + 'static>(self) -> Self {
        // Placeholder contents; only the first `tags_count` entries are read.
        // (The array must hold valid `TypeId`s, so it cannot be left uninitialized.)
        let mut tag_types = [TypeId::of::<()>(); MAX_TYPES];
        let tags_count = U::get_types(&mut tag_types);

        let mut tags = 0u64;
        for tag_type in &tag_types[..tags_count] {
            // An unknown tag must not be dropped silently: that would turn
            // "only gears with this tag" into "all gears".
            match self.data.tags.iter().position(|t| t == tag_type) {
                Some(tag_index) => tags |= 1 << tag_index as u64,
                None => panic!("Unregistered tag"),
            }
        }
        Self { tags, ..self }
    }

    /// Runs `f` on the component tuple of every matching gear.
    #[inline]
    pub fn run<F: FnMut(T)>(&mut self, mut f: F) {
        self.run_id(|_, x| f(x))
    }

    /// Runs `f` on the gear ID and component tuple of every matching gear.
    #[inline]
    pub fn run_id<F: FnMut(GearId, T)>(&mut self, f: F) {
        // `type_indices` is padded with -1 after the last tuple element.
        let types_count = self
            .type_indices
            .iter()
            .position(|i| *i == -1)
            .unwrap_or(self.type_indices.len());
        self.data
            .run_impl(self.types, self.tags, &self.type_indices[0..types_count], f);
    }
}

#[cfg(test)]
mod test {
    use super::{super::common::GearId, GearDataManager};

    #[derive(Clone)]
    struct DatumA {
        value: u32,
    }

    #[derive(Clone)]
    struct DatumB {
        value: u32,
    }

    #[derive(Clone)]
    struct Tag;

    /// Builds the gear ID for test gear number `n` (`n` must be non-zero).
    fn gear(n: u32) -> GearId {
        GearId::new(n as u16).unwrap()
    }

    #[test]
    fn direct_access() {
        let mut manager = GearDataManager::new();
        manager.register::<DatumA>();

        let squares = (1..=5u32).map(|n| (n, n * n));
        for (n, square) in squares.clone() {
            manager.add(gear(n), &DatumA { value: square });
        }
        for (n, square) in squares {
            let stored = manager.get::<DatumA>(gear(n)).unwrap();
            assert_eq!(stored.value, square);
        }
    }

    #[test]
    fn single_component_iteration() {
        let mut manager = GearDataManager::new();
        manager.register::<DatumA>();
        (1..=5u32).for_each(|n| manager.add(gear(n), &DatumA { value: n }));

        let mut total = 0;
        manager.iter().run(|(datum,): (&DatumA,)| total += datum.value);
        assert_eq!(total, 15);

        manager.iter().run(|(datum,): (&mut DatumA,)| datum.value += 1);
        manager.iter().run(|(datum,): (&DatumA,)| total += datum.value);
        assert_eq!(total, 35);
    }

    #[test]
    fn tagged_component_iteration() {
        let mut manager = GearDataManager::new();
        manager.register::<DatumA>();
        manager.register::<Tag>();

        for n in 1..=10u32 {
            manager.add(gear(n), &DatumA { value: n });
        }
        for n in (1..=10u32).filter(|n| n % 2 == 0) {
            manager.add_tag::<Tag>(gear(n));
        }

        let mut total = 0;
        manager.iter().run(|(datum,): (&DatumA,)| total += datum.value);
        assert_eq!(total, 55);

        let mut tagged_total = 0;
        manager
            .iter()
            .with_tags::<&Tag>()
            .run(|(datum,): (&DatumA,)| tagged_total += datum.value);
        assert_eq!(tagged_total, 30);
    }

    #[test]
    fn removal() {
        let mut manager = GearDataManager::new();
        manager.register::<DatumA>();
        manager.register::<DatumB>();

        for n in 1..=10u32 {
            manager.add(gear(n), &DatumA { value: n });
            manager.add(gear(n), &DatumB { value: n });
        }
        for n in (1..=10u32).filter(|n| n % 2 == 1) {
            manager.remove::<DatumA>(gear(n));
        }

        let mut total_a = 0;
        manager.iter().run(|(datum,): (&DatumA,)| total_a += datum.value);
        assert_eq!(total_a, 30);

        let mut total_b = 0;
        manager.iter().run(|(datum,): (&DatumB,)| total_b += datum.value);
        assert_eq!(total_b, 55);
    }
}