Skip to content
This repository was archived by the owner on Feb 26, 2019. It is now read-only.

Commit 9614d24

Browse files
committedDec 15, 2018
wip pr comments
1 parent feb6d6c commit 9614d24

16 files changed

+144
-160
lines changed
 

‎.gitignore

+4-5
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
/target
22
**/*.rs.bk
33
Cargo.lock
4-
/tvm-sys/Cargo.lock
5-
/tvm-sys/target
6-
/tests/callback/target
7-
/tests/basics/target
4+
/tvm-sys/target/
5+
/tvm-sys/src/bindgen.rs
86
/tests/basics/add_*
7+
/tests/callback/target
98
/examples/resnet/target
109
/examples/resnet/deploy_*
1110
/examples/resnet/*.png
12-
/examples/resnet/synset.*
11+
/examples/resnet/synset.*

‎.rustfmt.toml

+2-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ indent_style = "Block"
77
wrap_comments = false
88
comment_width = 80
99
normalize_comments = false
10-
license_template_path = ""
1110
format_strings = false
1211
format_macro_matchers = false
1312
format_macro_bodies = true
@@ -41,12 +40,12 @@ blank_lines_upper_bound = 1
4140
blank_lines_lower_bound = 0
4241
edition = "2015"
4342
merge_derives = true
44-
use_try_shorthand = false
43+
use_try_shorthand = true
4544
use_field_init_shorthand = false
4645
force_explicit_abi = true
4746
condense_wildcard_suffixes = false
4847
color = "Auto"
49-
required_version = "0.99.6"
48+
required_version = "0.99.5"
5049
unstable_features = false
5150
disable_all_formatting = false
5251
skip_children = false

‎.travis.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@ language: rust
22
rust:
33
- nightly
44
matrix:
5-
fast_finish: true
5+
fast_finish: true

‎Cargo.toml

+1-9
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,7 @@ crate-type = ["dylib"]
1616

1717
[dependencies]
1818
tvm-sys = { version = "0.1.0", path = "tvm-sys" }
19-
ndarray = { version = "0.12.1", features = ["blas"] }
20-
blas-src = { version = "0.1.2", default-features = false, features = ["openblas"] }
21-
openblas-src = { version = "0.5.6", default-features = false, features = ["cblas", "system"] }
19+
ndarray = { version = "0.12.1" }
2220
lazy_static = "1.1.0"
2321
num-traits = "0.2"
2422
custom_error = "1.3.0"
25-
26-
[profile.release]
27-
debug = true
28-
29-
[package.metadata.release]
30-
no-dev-version = true

‎examples/resnet/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ build = "build.rs"
99
ndarray = "0.12.1"
1010
tvm-frontend = { path = "../../" }
1111
image = "0.20.1"
12-
csv = "1"
12+
csv = "1"

‎examples/resnet/src/build_resnet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,4 @@ def test_build(target_dir):
8585

8686
logger.info("testing the build")
8787
test_build(sys.argv[1])
88-
logger.info("test was successful")
88+
logger.info("test was successful")

‎examples/resnet/src/main.rs

+16-21
Original file line numberDiff line numberDiff line change
@@ -38,23 +38,24 @@ fn main() -> Result<(), Box<Error>> {
3838

3939
let arr = Array::from_shape_vec((224, 224, 3), pixels)?;
4040
let arr: ArrayD<f32> = arr.permuted_axes([2, 0, 1]).into_dyn();
41-
let arr = arr.insert_axis(Axis(0));
41+
let mut arr = arr.insert_axis(Axis(0));
4242
// create input tensor from rust's ndarray
43-
let input = NDArray::from_rust_ndarray(&arr, TVMContext::cpu(0), TVMType::from("float"))?;
43+
let input = NDArray::from_rust_ndarray(&mut arr, TVMContext::cpu(0), TVMType::from("float"))?;
4444
println!("input size is {:?}", input.shape().unwrap());
4545
let graph = fs::read_to_string("deploy_graph.json")?;
4646
// load module
4747
let lib = Module::load(&Path::new("deploy_lib.so"))?;
4848
// get the global TVM graph runtime function
4949
let runtime_create_fn =
5050
Function::get_function("tvm.graph_runtime.create", true, false).unwrap();
51-
// create runtime function from Rust
52-
let runtime_create_fn_ret = function::Builder::from(runtime_create_fn)
53-
.arg(&graph)
54-
.arg(&lib)
55-
.arg(&ctx.device_type)
56-
.arg(&ctx.device_id)
57-
.invoke()?;
51+
52+
let runtime_create_fn_ret = tvm_call!(
53+
runtime_create_fn,
54+
&graph,
55+
&lib,
56+
&ctx.device_type,
57+
&ctx.device_id
58+
)?;
5859
// get graph runtime module
5960
let graph_runtime_module = runtime_create_fn_ret.to_module();
6061
// get the registered `load_params` from runtime module
@@ -65,20 +66,17 @@ fn main() -> Result<(), Box<Error>> {
6566
let params: Vec<u8> = fs::read("deploy_param.params")?;
6667
let barr = TVMByteArray::from(&params);
6768
// load the parameters
68-
function::Builder::from(load_param_fn).arg(&barr).invoke()?;
69+
tvm_call!(load_param_fn, &barr)?;
6970
// get the set_input function
7071
let set_input_fn = graph_runtime_module
7172
.get_function("set_input", false)
7273
.unwrap();
73-
// set the input via set_input function
74-
function::Builder::from(set_input_fn)
75-
.arg("data")
76-
.arg(&input)
77-
.invoke()?;
74+
75+
tvm_call!(set_input_fn, "data", &input)?;
7876
// get `run` function from runtime module
7977
let run_fn = graph_runtime_module.get_function("run", false).unwrap();
80-
// execute the run function
81-
function::Builder::from(run_fn).invoke()?;
78+
// execute the run function. Note that it has no argument.
79+
tvm_call!(run_fn,)?;
8280
// prepare to get the output
8381
let mut output_shape = vec![1, 1000];
8482
let output = empty(
@@ -91,10 +89,7 @@ fn main() -> Result<(), Box<Error>> {
9189
.get_function("get_output", false)
9290
.unwrap();
9391
// execute the get output function
94-
function::Builder::from(get_output_fn)
95-
.arg(&0)
96-
.arg(&output)
97-
.invoke()?;
92+
tvm_call!(get_output_fn, &0, &output)?;
9893
// flatten the output as Vec<f32>
9994
let output = output.to_vec::<f32>()?;
10095
// find the maximum entry in the output and its index

‎src/bytearray.rs

+6-11
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
//! This function can be obtained from a graph runtime module loading the model.
55
//! For more detail, please see the example `resnet` in `examples` repository.
66
7-
use std::ffi::CString;
8-
use std::mem;
7+
use std::os::raw::c_char;
98

109
use ts;
1110

@@ -48,15 +47,11 @@ impl TVMByteArray {
4847

4948
impl<'a> From<&'a Vec<u8>> for TVMByteArray {
5049
fn from(arg: &Vec<u8>) -> Self {
51-
unsafe {
52-
let data = CString::from_vec_unchecked(arg.to_vec());
53-
let barr = ts::TVMByteArray {
54-
data: data.as_ptr(),
55-
size: arg.len(),
56-
};
57-
mem::forget(data);
58-
TVMByteArray::new(barr)
59-
}
50+
let barr = ts::TVMByteArray {
51+
data: arg.as_ptr() as *const c_char,
52+
size: arg.len(),
53+
};
54+
TVMByteArray::new(barr)
6055
}
6156
}
6257

‎src/context.rs

+18-42
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,9 @@ use std::fmt::{self, Display, Formatter};
2020
use std::os::raw::c_void;
2121
use std::ptr;
2222

23-
use ts;
24-
2523
use function;
2624
use internal_api;
25+
use ts;
2726
use Result;
2827

2928
/// Device type which can be from a supported device name. See the supported devices
@@ -81,16 +80,20 @@ impl From<ts::DLDeviceType> for TVMDeviceType {
8180

8281
impl Display for TVMDeviceType {
8382
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
84-
match self {
85-
TVMDeviceType(1) => write!(f, "cpu"),
86-
TVMDeviceType(2) => write!(f, "gpu"),
87-
TVMDeviceType(3) => write!(f, "cpu_pinned"),
88-
TVMDeviceType(4) => write!(f, "opencl"),
89-
TVMDeviceType(8) => write!(f, "meta"),
90-
TVMDeviceType(9) => write!(f, "vpi"),
91-
TVMDeviceType(10) => write!(f, "rocm"),
92-
TVMDeviceType(_) => write!(f, "rpc"),
93-
}
83+
write!(
84+
f,
85+
"{}",
86+
match self {
87+
TVMDeviceType(1) => "cpu",
88+
TVMDeviceType(2) => "gpu",
89+
TVMDeviceType(3) => "cpu_pinned",
90+
TVMDeviceType(4) => "opencl",
91+
TVMDeviceType(8) => "meta",
92+
TVMDeviceType(9) => "vpi",
93+
TVMDeviceType(10) => "rocm",
94+
TVMDeviceType(_) => "rpc",
95+
}
96+
)
9497
}
9598
}
9699

@@ -145,11 +148,6 @@ impl TVMContext {
145148
device_id: device_id,
146149
}
147150
}
148-
149-
/// Gets the currect context.
150-
pub fn current_context(&self) -> &Self {
151-
self
152-
}
153151
}
154152

155153
macro_rules! impl_ctxs {
@@ -186,12 +184,7 @@ impl TVMContext {
186184
pub fn exist(&self) -> bool {
187185
let func = internal_api::get_api("_GetDeviceAttr".to_owned());
188186
let dt = self.device_type.0 as usize;
189-
let ret = function::Builder::from(func)
190-
.arg(&dt)
191-
.arg(&self.device_id)
192-
.arg(&0)
193-
.invoke()
194-
.unwrap();
187+
let ret = tvm_call!(func, &dt, &self.device_id, &0).unwrap();
195188
ret.to_int() != 0
196189
}
197190

@@ -264,11 +257,11 @@ mod tests {
264257
let ctx = TVMContext::cpu(0);
265258
println!("ctx: {}", ctx);
266259
let default_ctx = TVMContext::new(TVMDeviceType(1), 0);
267-
assert_eq!(ctx.current_context().clone(), default_ctx);
260+
assert_eq!(ctx.clone(), default_ctx);
268261
assert_ne!(ctx, TVMContext::gpu(0));
269262

270263
let str_ctx = TVMContext::new(TVMDeviceType::from("gpu"), 0);
271-
assert_eq!(str_ctx.current_context().clone(), str_ctx);
264+
assert_eq!(str_ctx.clone(), str_ctx);
272265
assert_ne!(str_ctx, TVMContext::new(TVMDeviceType::from("cpu"), 0));
273266
}
274267

@@ -277,21 +270,4 @@ mod tests {
277270
let ctx = TVMContext::cpu(0);
278271
assert!(ctx.sync().is_ok())
279272
}
280-
281-
#[test]
282-
fn dev_attributes() {
283-
let ctx = TVMContext::cpu(0);
284-
assert!(ctx.exist());
285-
println!("max thread per block: {}", ctx.max_threads_per_block());
286-
println!("warp size: {}", ctx.warp_size());
287-
println!(
288-
"max shared memory per block: {}",
289-
ctx.max_shared_memory_per_block()
290-
);
291-
println!("compute version: {}", ctx.compute_version());
292-
println!("device name: {}", ctx.device_name());
293-
println!("max clock rate: {}", ctx.max_clock_rate());
294-
println!("multi processor count: {}", ctx.multi_processor_count());
295-
println!("max thread dimensions: {}", ctx.max_thread_dimensions());
296-
}
297273
}

0 commit comments

Comments
 (0)
This repository has been archived.