Skip to content

Commit 2014a62

Browse files
committed
Update resnet example
1 parent 443537f commit 2014a62

File tree

5 files changed

+27
-17
lines changed

5 files changed

+27
-17
lines changed

rust/common/src/packed_func.rs

+16-2
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ macro_rules! ensure_type {
5959
macro_rules! impl_prim_tvm_arg {
6060
($type_code:ident, $field:ident, $field_type:ty, [ $( $type:ty ),+ ] ) => {
6161
$(
62-
impl<'a> From<$type> for TVMArgValue<'a> {
62+
impl From<$type> for TVMArgValue<'static> {
6363
fn from(val: $type) -> Self {
6464
TVMArgValue {
6565
value: TVMValue { $field: val as $field_type },
@@ -112,11 +112,25 @@ impl_prim_tvm_arg!(
112112
[u8, u16, u32, u64, usize]
113113
);
114114

115+
#[cfg(feature = "bindings")]
116+
// only allow this in bindings because pure-rust can't take ownership of leaked CString
117+
impl<'a> From<&String> for TVMArgValue<'a> {
118+
fn from(string: &String) -> Self {
119+
TVMArgValue {
120+
value: TVMValue {
121+
v_str: std::ffi::CString::new(string.clone()).unwrap().into_raw(),
122+
},
123+
type_code: TVMTypeCode_kStr as i64,
124+
_lifetime: PhantomData,
125+
}
126+
}
127+
}
128+
115129
impl<'a> From<&std::ffi::CString> for TVMArgValue<'a> {
116130
fn from(string: &std::ffi::CString) -> Self {
117131
TVMArgValue {
118132
value: TVMValue {
119-
v_handle: string.as_ptr() as *const _ as *mut c_void,
133+
v_str: string.as_ptr(),
120134
},
121135
type_code: TVMTypeCode_kStr as i64,
122136
_lifetime: PhantomData,

rust/common/src/value.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ impl FromStr for TVMType {
4545
_ => return Err(format_err!("Unknown type {}", type_name)),
4646
};
4747

48-
Ok(dbg!(TVMType::new(type_code, bits, lanes)))
48+
Ok(TVMType::new(type_code, bits, lanes))
4949
}
5050
}
5151

rust/frontend/examples/resnet/src/main.rs

+2-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
#![feature(try_from)]
2-
31
extern crate csv;
42
extern crate image;
53
extern crate ndarray;
@@ -55,10 +53,8 @@ fn main() {
5553
"input size is {:?}",
5654
input.shape().expect("cannot get the input shape")
5755
);
58-
let graph = std::ffi::CString::new(
59-
fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")).unwrap(),
60-
)
61-
.unwrap();
56+
let graph =
57+
fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")).unwrap();
6258
// load the built module
6359
let lib = Module::load(&Path::new(concat!(
6460
env!("CARGO_MANIFEST_DIR"),

rust/frontend/src/context.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -158,15 +158,15 @@ pub struct TVMContext {
158158
/// Supported device types
159159
pub device_type: TVMDeviceType,
160160
/// Device id
161-
pub device_id: usize,
161+
pub device_id: i32,
162162
}
163163

164164
impl TVMContext {
165165
/// Creates context from device type and id.
166-
pub fn new(device_type: TVMDeviceType, device_id: usize) -> Self {
166+
pub fn new(device_type: TVMDeviceType, device_id: i32) -> Self {
167167
TVMContext {
168-
device_type: device_type,
169-
device_id: device_id,
168+
device_type,
169+
device_id,
170170
}
171171
}
172172
}
@@ -175,7 +175,7 @@ macro_rules! impl_ctxs {
175175
($(($ctx:ident, $dldevt:expr));+) => {
176176
$(
177177
impl TVMContext {
178-
pub fn $ctx(device_id: usize) -> Self {
178+
pub fn $ctx(device_id: i32) -> Self {
179179
Self::new(TVMDeviceType($dldevt), device_id)
180180
}
181181
}
@@ -238,7 +238,7 @@ macro_rules! impl_device_attrs {
238238
// `unwrap` is ok here because if there is any error,
239239
// if would occur in function call.
240240
function::Builder::from(func)
241-
.args(&[dt, self.device_id, $attr_kind])
241+
.args(&[dt, self.device_id as usize, $attr_kind])
242242
.invoke()
243243
.unwrap()
244244
.try_into()
@@ -262,7 +262,7 @@ impl From<ffi::DLContext> for TVMContext {
262262
fn from(ctx: ffi::DLContext) -> Self {
263263
TVMContext {
264264
device_type: TVMDeviceType::from(ctx.device_type),
265-
device_id: ctx.device_id as usize,
265+
device_id: ctx.device_id,
266266
}
267267
}
268268
}

rust/frontend/src/ndarray.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ impl NDArray {
188188

189189
/// Copies the NDArray to another target NDArray.
190190
pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray, Error> {
191-
if dbg!(self.dtype()) != dbg!(target.dtype()) {
191+
if self.dtype() != target.dtype() {
192192
bail!(
193193
"{}",
194194
errors::TypeMismatchError {

0 commit comments

Comments
 (0)