Skip to content

Commit 8bb35a3

Browse files
committed
Split standard_b64encode_impl
1 parent c9deee6 commit 8bb35a3

1 file changed

Lines changed: 28 additions & 15 deletions

File tree

Modules/_base64/src/lib.rs

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ struct BorrowedBuffer {
8383
}
8484

8585
impl BorrowedBuffer {
86-
unsafe fn from_object(obj: *mut PyObject) -> Result<Self, ()> {
86+
fn from_object(obj: &mut PyObject) -> Result<Self, ()> {
8787
let mut view = MaybeUninit::<Py_buffer>::uninit();
8888
if unsafe { PyObject_GetBuffer(obj, view.as_mut_ptr(), PYBUF_SIMPLE) } != 0 {
8989
return Err(());
@@ -110,6 +110,9 @@ impl Drop for BorrowedBuffer {
110110
}
111111
}
112112

113+
/// # Safety
114+
/// `module` must be a valid pointer of PyObject representing the module.
115+
/// `args` must be a valid pointer to an array of valid PyObject pointers with length `nargs`.
113116
#[unsafe(no_mangle)]
114117
pub unsafe extern "C" fn standard_b64encode(
115118
_module: *mut PyObject,
@@ -123,61 +126,71 @@ pub unsafe extern "C" fn standard_b64encode(
123126
c"standard_b64encode() takes exactly one argument".as_ptr(),
124127
);
125128
}
126-
return ptr::null_mut();
127129
}
128130

129-
let source = unsafe { *args };
130-
let buffer = match unsafe { BorrowedBuffer::from_object(source) } {
131+
let source = unsafe { &mut **args };
132+
133+
// Safe cast by Safety
134+
match standard_b64encode_impl(source) {
135+
Ok(result) => result,
136+
Err(_) => {
137+
ptr::null_mut()
138+
}
139+
}
140+
}
141+
142+
fn standard_b64encode_impl(
143+
source: &mut PyObject,
144+
) -> Result<*mut PyObject, ()> {
145+
let buffer = match BorrowedBuffer::from_object(source) {
131146
Ok(buf) => buf,
132-
Err(_) => return ptr::null_mut(),
147+
Err(_) => return Err(()),
133148
};
134149

135150
let view_len = buffer.len();
136151
if view_len < 0 {
137152
unsafe {
138-
PyErr_SetString(
139-
PyExc_TypeError,
140-
c"standard_b64encode() argument has negative length".as_ptr(),
141-
);
153+
PyErr_SetString( PyExc_TypeError , c"standard_b64encode() argument has negative length".as_ptr());
142154
}
143-
return ptr::null_mut();
155+
return Err(());
144156
}
157+
145158
let input_len = view_len as usize;
146159
let input = unsafe { slice::from_raw_parts(buffer.as_ptr(), input_len) };
147160

148161
let Some(output_len) = encoded_output_len(input_len) else {
149162
unsafe {
150163
PyErr_NoMemory();
151164
}
152-
return ptr::null_mut();
165+
return Err(());
153166
};
154167

155168
if output_len > isize::MAX as usize {
156169
unsafe {
157170
PyErr_NoMemory();
158171
}
159-
return ptr::null_mut();
172+
return Err(());
160173
}
161174

162175
let result = unsafe {
163176
PyBytes_FromStringAndSize(ptr::null(), output_len as Py_ssize_t)
164177
};
165178
if result.is_null() {
166-
return ptr::null_mut();
179+
return Err(());
167180
}
168181

169182
let dest_ptr = unsafe { PyBytes_AsString(result) };
170183
if dest_ptr.is_null() {
171184
unsafe {
172185
Py_DecRef(result);
173186
}
174-
return ptr::null_mut();
187+
return Err(());
175188
}
176189
let dest = unsafe { slice::from_raw_parts_mut(dest_ptr.cast::<u8>(), output_len) };
177190

178191
let written = encode_into(input, dest);
179192
debug_assert_eq!(written, output_len);
180-
result
193+
Ok(result)
181194
}
182195

183196
#[unsafe(no_mangle)]

0 commit comments

Comments
 (0)