Skip to content
This repository was archived by the owner on Feb 26, 2019. It is now read-only.

Commit daadc59

Browse files
committed
wip pr comments
1 parent ecc9d5f commit daadc59

File tree

5 files changed

+53
-55
lines changed

5 files changed

+53
-55
lines changed

src/errors.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ error_chain!{
2222
description("type mismatch!")
2323
display("expected type `{}`, but found `{}`", expected, found)
2424
}
25-
NoneError {
26-
description("called `Option::unwrap()` on a `None` value")
25+
MissingShapeError {
26+
description("ndarray `shape()` returns `None`")
27+
display("called `Option::unwrap()` on a `None` value")
2728
}
2829

2930
}
@@ -36,6 +37,6 @@ error_chain!{
3637

3738
impl From<option::NoneError> for Error {
3839
fn from(err: option::NoneError) -> Self {
39-
ErrorKind::NoneError.into()
40+
ErrorKind::MissingShapeError.into()
4041
}
4142
}

src/function.rs

+44-42
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ impl<'a> FnOnce<((),)> for Builder<'a> {
224224
let mut ret_val = unsafe { mem::uninitialized::<ts::TVMValue>() };
225225
let mut ret_type_code = 0 as c_int;
226226
if self.arg_buf.is_some() {
227-
let arg_buf = self.arg_buf.unwrap();
227+
let arg_buf = self.arg_buf?;
228228
let mut num_args = arg_buf.len();
229229
let mut values = arg_buf
230230
.iter()
@@ -236,15 +236,15 @@ impl<'a> FnOnce<((),)> for Builder<'a> {
236236
.collect::<Vec<_>>();
237237
if self.ret_buf.is_some() {
238238
num_args = num_args + 1;
239-
ret_val = *self.ret_buf.clone().unwrap()[0].value;
240-
ret_type_code = self.ret_buf.clone().unwrap()[0].type_code as c_int;
239+
ret_val = *self.ret_buf.clone()?[0].value;
240+
ret_type_code = self.ret_buf.clone()?[0].type_code as c_int;
241241
values.append(&mut vec![ret_val]);
242242
tcodes.append(&mut vec![ret_type_code]);
243243
}
244244
values.truncate(num_args);
245245
tcodes.truncate(num_args);
246246
check_call!(ts::TVMFuncCall(
247-
self.func.unwrap().handle,
247+
self.func?.handle,
248248
values.as_mut_ptr(),
249249
tcodes.as_mut_ptr(),
250250
num_args as c_int,
@@ -253,7 +253,7 @@ impl<'a> FnOnce<((),)> for Builder<'a> {
253253
));
254254
} else {
255255
check_call!(ts::TVMFuncCall(
256-
self.func.unwrap().handle,
256+
self.func?.handle,
257257
ptr::null_mut(),
258258
ptr::null_mut(),
259259
0 as c_int,
@@ -292,44 +292,47 @@ unsafe extern "C" fn tvm_callback(
292292
fhandle: *mut c_void,
293293
) -> c_int {
294294
let len = num_args as usize;
295-
let args_list = unsafe { slice::from_raw_parts_mut(args, len).to_vec() };
296-
let type_codes_list = unsafe { slice::from_raw_parts_mut(type_codes, len).to_vec() };
295+
let args_list = unsafe { slice::from_raw_parts_mut(args, len) };
296+
let type_codes_list = unsafe { slice::from_raw_parts_mut(type_codes, len) };
297297
let mut local_args: Vec<TVMArgValue> = Vec::new();
298-
let mut value = unsafe { mem::uninitialized::<ts::TVMValue>() };
299-
let mut tcode = unsafe { mem::uninitialized::<c_int>() };
300-
let rust_fn = unsafe {
301-
mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue<'static>>>(fhandle)
302-
};
303-
for i in 0..len {
304-
value = args_list[i];
305-
tcode = type_codes_list[i];
306-
if tcode == TypeCode::kNodeHandle as c_int
307-
|| tcode == TypeCode::kFuncHandle as c_int
308-
|| tcode == TypeCode::kModuleHandle as c_int
309-
{
310-
check_call!(ts::TVMCbArgToReturn(&mut value as *mut _, tcode));
298+
unsafe {
299+
let mut value = mem::uninitialized::<ts::TVMValue>();
300+
let mut tcode = mem::uninitialized::<c_int>();
301+
let rust_fn = mem::transmute::<
302+
*mut c_void,
303+
fn(&[TVMArgValue]) -> Result<TVMRetValue<'static>>,
304+
>(fhandle);
305+
for i in 0..len {
306+
value = args_list[i];
307+
tcode = type_codes_list[i];
308+
if tcode == TypeCode::kNodeHandle as c_int
309+
|| tcode == TypeCode::kFuncHandle as c_int
310+
|| tcode == TypeCode::kModuleHandle as c_int
311+
{
312+
check_call!(ts::TVMCbArgToReturn(&mut value as *mut _, tcode));
313+
}
314+
local_args.push(TVMArgValue::new(
315+
TVMValue::new(ValueKind::Handle, value),
316+
tcode.into(),
317+
));
311318
}
312-
local_args.push(TVMArgValue::new(
313-
TVMValue::new(ValueKind::Handle, value),
314-
tcode.into(),
319+
320+
let rv = match rust_fn(local_args.as_slice()) {
321+
Ok(v) => v,
322+
Err(msg) => {
323+
::set_last_error(&msg);
324+
return -1;
325+
}
326+
};
327+
let mut ret_val = *rv.value;
328+
let mut ret_type_code = rv.type_code as c_int;
329+
check_call!(ts::TVMCFuncSetReturn(
330+
ret,
331+
&mut ret_val as *mut _,
332+
&mut ret_type_code as *mut _,
333+
1 as c_int
315334
));
316335
}
317-
318-
let rv = match rust_fn(local_args.as_slice()) {
319-
Ok(v) => v,
320-
Err(msg) => {
321-
::set_last_error(&msg);
322-
return -1;
323-
}
324-
};
325-
let mut ret_val = *rv.value;
326-
let mut ret_type_code = rv.type_code as c_int;
327-
check_call!(ts::TVMCFuncSetReturn(
328-
ret,
329-
&mut ret_val as *mut _,
330-
&mut ret_type_code as *mut _,
331-
1 as c_int
332-
));
333336
0
334337
}
335338

@@ -384,12 +387,11 @@ pub fn register(
384387
override_: bool,
385388
) -> Result<()> {
386389
let func = convert_to_tvm_func(f);
387-
let ovd = if override_ { 1 } else { 0 };
388-
let name = CString::new(name).unwrap();
390+
let name = CString::new(name)?;
389391
check_call!(ts::TVMFuncRegisterGlobal(
390392
name.as_ptr() as *const c_char,
391393
func.handle(),
392-
ovd
394+
override_ as c_int
393395
));
394396
mem::forget(name);
395397
Ok(())

src/module.rs

+1-3
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,15 @@ impl Module {
4242
}
4343

4444
/// Sets the entry function of a module.
45-
pub fn entry_func(mut self) -> Self {
45+
pub fn entry_func(&mut self) {
4646
if self.entry.is_none() {
4747
self.entry = self.get_function(ENTRY_FUNC, false).ok();
4848
}
49-
self
5049
}
5150

5251
/// Gets a function by name from a registered module.
5352
pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function> {
5453
let name = CString::new(name)?;
55-
let query_import = if query_import == true { 1 } else { 0 };
5654
let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle;
5755
check_call!(ts::TVMModGetFunction(
5856
self.handle,

src/ndarray.rs

+2-5
Original file line numberDiff line numberDiff line change
@@ -204,11 +204,8 @@ impl NDArray {
204204
) -> Result<Self> {
205205
let mut shape = rnd.shape().to_vec();
206206
let mut nd = empty(&mut shape, ctx, dtype);
207-
let mut buf = Array::from_iter(rnd.iter())
208-
.iter()
209-
.map(|e| **e)
210-
.collect::<Vec<T>>();
211-
nd.copy_from_buffer(&mut buf);
207+
let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T));
208+
nd.copy_from_buffer(buf.as_slice_mut()?);
212209
Ok(nd)
213210
}
214211
}

tests/basics/src/main.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ fn main() {
2020
let path = Path::new("add_cpu.so");
2121
let mut fadd = Module::load(&path).unwrap();
2222
assert!(fadd.enabled("cpu".to_owned()));
23-
fadd = fadd.entry_func();
23+
fadd.entry_func();
2424
function::Builder::from(&mut fadd)
2525
.arg(&arr)
2626
.arg(&arr)
@@ -45,7 +45,7 @@ fn main() {
4545
let fadd_dep = Module::load(ptx).unwrap();
4646
assert!(fadd.enabled("gpu".to_owned()), "GPU is not enabled!");
4747
fadd.import_module(fadd_dep);
48-
fadd = fadd.entry_func();
48+
fadd.entry_func();
4949
function::Builder::from(&mut fadd)
5050
.arg(&arr)
5151
.arg(&arr)

0 commit comments

Comments
 (0)