Skip to content

Commit 93cde03

Browse files
committed
feat(virtq): introduce pool alloc for managing allocation lifetime
1 parent e4fd644 commit 93cde03

2 files changed

Lines changed: 78 additions & 88 deletions

File tree

src/hyperlight_common/src/virtq/buffer.rs

Lines changed: 62 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -97,52 +97,26 @@ impl<T: BufferProvider> BufferProvider for Arc<T> {
9797

9898
/// The owner of a mapped buffer, ensuring its lifetime.
9999
///
100-
/// Holds a pool allocation and provides direct access to the underlying
100+
/// Holds a [`PoolAlloc`] and provides direct access to the underlying
101101
/// shared memory via [`MemOps::as_slice`]. Implements `AsRef<[u8]>` so it
102102
/// can be used with [`Bytes::from_owner`](bytes::Bytes::from_owner) for
103103
/// zero-copy `Bytes` backed by shared memory.
104104
///
105105
/// When dropped, the allocation is returned to the pool.
106106
#[derive(Debug)]
107107
pub struct BufferOwner<P: BufferProvider, M: MemOps> {
108-
pub(crate) pool: P,
108+
pub(crate) alloc: PoolAlloc<P>,
109109
pub(crate) mem: M,
110-
pub(crate) alloc: Allocation,
111110
pub(crate) written: usize,
112111
}
113112

114-
impl<P: BufferProvider, M: MemOps> Drop for BufferOwner<P, M> {
115-
fn drop(&mut self) {
116-
let _ = self.pool.dealloc(self.alloc);
117-
}
118-
}
119-
120-
impl<P: BufferProvider, M: MemOps> BufferOwner<P, M> {
121-
pub(crate) fn try_new(
122-
pool: P,
123-
mem: M,
124-
alloc: Allocation,
125-
written: usize,
126-
) -> Result<Self, M::Error> {
127-
// Pre check direct access before handing the owner to Bytes::from_owner
128-
let len = written.min(alloc.len);
129-
let _ = unsafe { mem.as_slice(alloc.addr, len) }?;
130-
131-
Ok(Self {
132-
pool,
133-
mem,
134-
alloc,
135-
written,
136-
})
137-
}
138-
}
139-
140113
impl<P: BufferProvider, M: MemOps> AsRef<[u8]> for BufferOwner<P, M> {
141114
fn as_ref(&self) -> &[u8] {
142-
let len = self.written.min(self.alloc.len);
115+
let alloc = self.alloc.allocation();
116+
let len = self.written.min(alloc.len);
143117
// Safety: BufferOwner keeps both the pool allocation and the M alive,
144118
// so the memory region is valid.
145-
match unsafe { self.mem.as_slice(self.alloc.addr, len) } {
119+
match unsafe { self.mem.as_slice(alloc.addr, len) } {
146120
Ok(slice) => slice,
147121
Err(_) => {
148122
debug_assert!(false, "BufferOwner direct slice failed");
@@ -152,41 +126,74 @@ impl<P: BufferProvider, M: MemOps> AsRef<[u8]> for BufferOwner<P, M> {
152126
}
153127
}
154128

155-
/// A guard that runs a cleanup function when dropped, unless dismissed.
156-
pub struct AllocGuard<F: FnOnce(Allocation)>(Option<(Allocation, F)>);
129+
/// Pool-owned allocation that is returned to the pool on drop.
130+
///
131+
/// Use [`into_raw`](Self::into_raw) to transfer ownership to a descriptor
132+
/// state that will deallocate the raw [`Allocation`] through another path.
133+
#[derive(Debug)]
134+
pub struct PoolAlloc<P: BufferProvider> {
135+
inner: Option<PoolAllocInner<P>>,
136+
}
137+
138+
#[derive(Debug)]
139+
struct PoolAllocInner<P: BufferProvider> {
140+
pool: P,
141+
alloc: Allocation,
142+
}
157143

158-
impl<F: FnOnce(Allocation)> AllocGuard<F> {
159-
pub fn new(alloc: Allocation, cleanup: F) -> Self {
160-
Self(Some((alloc, cleanup)))
144+
impl<P: BufferProvider> PoolAlloc<P> {
145+
/// Wrap an existing allocation with its owning pool.
146+
pub fn new(pool: P, alloc: Allocation) -> Self {
147+
Self {
148+
inner: Some(PoolAllocInner { pool, alloc }),
149+
}
150+
}
151+
152+
/// Allocate from `pool` and return an owning guard.
153+
pub fn allocate(pool: P, len: usize) -> Result<Self, AllocError> {
154+
let alloc = pool.alloc(len)?;
155+
Ok(Self::new(pool, alloc))
156+
}
157+
158+
/// The raw allocation currently owned by this guard.
159+
pub fn allocation(&self) -> Allocation {
160+
self.inner
161+
.as_ref()
162+
.map(|inner| inner.alloc)
163+
.unwrap_or_else(|| {
164+
unreachable!("PoolAlloc::allocation called after ownership transfer")
165+
})
161166
}
162167

163-
pub fn release(mut self) -> Allocation {
164-
// Safety: AllocGuard is always constructed with Some, and release is only called once
165-
self.0
168+
/// Release ownership and return the raw allocation.
169+
pub fn into_raw(mut self) -> Allocation {
170+
self.inner
166171
.take()
167-
.map(|(alloc, _)| alloc)
168-
.unwrap_or_else(|| unreachable!("AllocGuard::release called on dismissed guard"))
172+
.map(|inner| inner.alloc)
173+
.unwrap_or_else(|| unreachable!("PoolAlloc::into_raw called after ownership transfer"))
169174
}
170-
}
171175

172-
impl<F: FnOnce(Allocation)> core::ops::Deref for AllocGuard<F> {
173-
type Target = Allocation;
176+
pub(crate) fn into_buffer_owner<M: MemOps>(
177+
self,
178+
mem: M,
179+
written: usize,
180+
) -> Result<BufferOwner<P, M>, M::Error> {
181+
let alloc = self.allocation();
182+
let len = written.min(alloc.len);
183+
let _ = unsafe { mem.as_slice(alloc.addr, len) }?;
174184

175-
fn deref(&self) -> &Allocation {
176-
// Safety: AllocGuard is always constructed with Some, and the inner value is only
177-
// taken by release() or Drop.
178-
&self
179-
.0
180-
.as_ref()
181-
.unwrap_or_else(|| unreachable!("AllocGuard::deref called on dismissed guard"))
182-
.0
185+
Ok(BufferOwner {
186+
alloc: self,
187+
mem,
188+
written,
189+
})
183190
}
184191
}
185192

186-
impl<F: FnOnce(Allocation)> Drop for AllocGuard<F> {
193+
impl<P: BufferProvider> Drop for PoolAlloc<P> {
187194
fn drop(&mut self) {
188-
if let Some((alloc, cleanup)) = self.0.take() {
189-
cleanup(alloc)
195+
if let Some(PoolAllocInner { pool, alloc }) = self.inner.take() {
196+
let _ = pool.dealloc(alloc);
190197
}
191198
}
192199
}

src/hyperlight_common/src/virtq/producer.rs

Lines changed: 16 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -253,30 +253,21 @@ where
253253
self.pool.dealloc(entry)?;
254254
}
255255

256-
let completion_guard = inf.completion().map(|buf| {
257-
let pool = self.pool.clone();
258-
AllocGuard::new(buf, move |a| {
259-
let _ = pool.dealloc(a);
260-
})
261-
});
256+
let completion_guard = inf
257+
.completion()
258+
.map(|buf| PoolAlloc::new(self.pool.clone(), buf));
262259

263260
// Read completion data
264261
let has_completion = completion_guard.is_some();
265262
let data = match completion_guard {
266-
Some(buf) => {
267-
if written > buf.len {
268-
return Err(VirtqError::InvalidState);
269-
}
270-
let owner = BufferOwner::try_new(
271-
self.pool.clone(),
272-
self.inner.mem().clone(),
273-
*buf,
274-
written,
275-
)
276-
.map_err(|_| VirtqError::MemoryReadError)?;
277-
let _ = buf.release();
278-
Bytes::from_owner(owner)
263+
Some(buf) if written > buf.allocation().len => {
264+
// This is a protocol violation
265+
return Err(VirtqError::InvalidState);
279266
}
267+
Some(buf) => Bytes::from_owner(
268+
buf.into_buffer_owner(self.inner.mem().clone(), written)
269+
.map_err(|_| VirtqError::MemoryReadError)?,
270+
),
280271
None => Bytes::new(),
281272
};
282273

@@ -642,16 +633,8 @@ impl<M: MemOps, P: BufferProvider + Clone> ChainBuilder<M, P> {
642633
}
643634
}
644635

645-
fn alloc(
646-
&self,
647-
size: usize,
648-
) -> Result<AllocGuard<impl FnOnce(Allocation) + use<M, P>>, VirtqError> {
649-
let alloc = self.pool.alloc(size)?;
650-
let pool = self.pool.clone();
651-
652-
Ok(AllocGuard::new(alloc, move |a| {
653-
let _ = pool.dealloc(a);
654-
}))
636+
fn alloc(&self, size: usize) -> Result<PoolAlloc<P>, VirtqError> {
637+
Ok(PoolAlloc::allocate(self.pool.clone(), size)?)
655638
}
656639

657640
/// Request an entry buffer of `cap` bytes.
@@ -688,14 +671,14 @@ impl<M: MemOps, P: BufferProvider + Clone> ChainBuilder<M, P> {
688671

689672
let inflight = match (entry_alloc, completion_alloc) {
690673
(Some(entry), Some(cqe)) => Inflight::ReadWrite {
691-
entry: entry.release(),
692-
completion: cqe.release(),
674+
entry: entry.into_raw(),
675+
completion: cqe.into_raw(),
693676
},
694677
(Some(entry), None) => Inflight::ReadOnly {
695-
entry: entry.release(),
678+
entry: entry.into_raw(),
696679
},
697680
(None, Some(cqe)) => Inflight::WriteOnly {
698-
completion: cqe.release(),
681+
completion: cqe.into_raw(),
699682
},
700683
(None, None) => unreachable!(),
701684
};

0 commit comments

Comments
 (0)