Skip to content
This repository was archived by the owner on Feb 26, 2019. It is now read-only.

Commit 8ea9c84

Browse files
committed
enhance docs
1 parent af668bd commit 8ea9c84

12 files changed

+99
-60
lines changed

README.md

+12-15
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
# TVM Runtime Frontend Support
22

3-
This crate provides idiomatic Rust API for [TVM](https://github.com/dmlc/tvm) runtime frontend as part of [ongoing RFC 1601](https://github.com/dmlc/tvm/issues/1601). Currently this requires **Nightly Rust**.
3+
This crate provides an idiomatic Rust API for [TVM](https://github.com/dmlc/tvm) runtime frontend as part of the [ongoing RFC](https://github.com/dmlc/tvm/issues/1601). Currently this requires **Nightly Rust**.
44

55
Checkout the [docs](https://ehsanmok.github.io/tvm_frontend/tvm_frontend/index.html).
66

77
## What Does This Crate Offer?
88

99
Here is a major workflow
1010

11-
1. Train your **Deep Learning** model using any major framework [PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.incubator.apache.org/) and [TensorFlow](https://www.tensorflow.org/)
12-
2. Use **TVM** to build optimized model artifacts for a given supported TVM context such as CPU, GPU, OpenCL, Vulkan, VPI, ROCM, etc.
11+
1. Train your **Deep Learning** model using any major framework such as [PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.incubator.apache.org/) or [TensorFlow](https://www.tensorflow.org/)
12+
2. Use **TVM** to build optimized model artifacts on a supported context such as CPU, GPU, OpenCL, Vulkan, VPI, ROCM, etc.
1313
3. Deploy your models using **Rust** :heart:
1414

1515
### Example: Deploy Image Classification from Pretrained Resnet18 on ImageNet1k
1616

1717
Please checkout [examples/resnet](https://github.com/ehsanmok/tvm-rust/tree/master/examples/resnet) for the complete end-to-end example.
1818

19-
Here's python snippet for download and building a pretrained Resnet18 via MXNet and TVM
19+
Here's a Python snippet for downloading and building a pretrained Resnet18 via MXNet and TVM
2020

2121
```python
2222
block = get_model('resnet18_v1', pretrained=True)
@@ -39,7 +39,7 @@ with open(os.path.join(target_dir,"deploy_param.params"), "wb") as fo:
3939
fo.write(nnvm.compiler.save_param_dict(params))
4040
```
4141

42-
Now, we can read input the artifacts to create and run the Graph Runtime to detect our cat image
42+
Now, we need to input the artifacts to create and run the *Graph Runtime* to detect our input cat image
4343

4444
![cat](https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true)
4545

@@ -94,11 +94,11 @@ call_packed!(get_output_fn, &0, &output)?;
9494
let output = output.to_vec::<f32>()?;
9595
```
9696

97-
## Installation
97+
## Installations
9898

99-
Please follow the TVM [installation](https://docs.tvm.ai/install/index.html), `export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`.
99+
Please follow TVM [installations](https://docs.tvm.ai/install/index.html), `export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`.
100100

101-
*Note:* To run the end-to-end examples and tests, `tvm`, `nnvm` and `topi` need to be added to your `PYTHONPATH`.
101+
*Note:* To run the end-to-end examples and tests, `tvm`, `nnvm` and `topi` need to be added to your `PYTHONPATH` or it's automatic via an Anaconda environment when install individually.
102102

103103
## Supported TVM Functionalities
104104

@@ -108,11 +108,9 @@ One can use the following Python snippet to generate `add_gpu.so` which add two
108108

109109
```python
110110
import os
111-
112111
import tvm
113112
from tvm.contrib import cc
114113

115-
116114
def test_add(target_dir):
117115
if not tvm.module.enabled("cuda"):
118116
print(f"skip {__file__} because cuda is not enabled...")
@@ -121,9 +119,7 @@ def test_add(target_dir):
121119
A = tvm.placeholder((n,), name='A')
122120
B = tvm.placeholder((n,), name='B')
123121
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
124-
125122
s = tvm.create_schedule(C.op)
126-
127123
bx, tx = s[C].split(C.op.axis[0], factor=64)
128124
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
129125
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
@@ -144,7 +140,7 @@ if __name__ == "__main__":
144140

145141
### Run the Generated Shared Library
146142

147-
The following code snippet demonstrate how to load generated shared library (`add_gpu.so`).
143+
The following code snippet demonstrates how to load and test the generated shared library (`add_gpu.so`) in Rust.
148144

149145
```rust
150146
extern crate tvm_frontend as tvm;
@@ -174,14 +170,15 @@ fn main() {
174170
assert_eq!(ret.to_vec::<f32>().unwrap(), vec![6f32, 8.0]);
175171
}
176172
```
173+
177174
**Note:** it is required to instruct the `rustc` to link to the generated `add_gpu.so` in runtime, for example by
178175
`cargo:rustc-link-search=native=add_gpu`.
179176

180177
See the tests and examples custom `build.rs` for more details.
181178

182179
### Convert and Register a Rust Function as a TVM Packed Function
183180

184-
One can you the `register_global_func!` macro to convert and register a Rust's
181+
One can use `register_global_func!` macro to convert and register a Rust
185182
function of type `fn(&[TVMArgValue]) -> Result<TVMRetValue>` to a global TVM **packed function** as follows
186183

187184
```rust
@@ -210,12 +207,12 @@ fn main() {
210207
let mut data = vec![3f32, 4.0];
211208
let mut arr = empty(shape, TVMContext::cpu(0), TVMType::from("float"));
212209
arr.copy_from_buffer(data.as_mut_slice());
213-
214210
let mut registered = function::Builder::default();
215211
registered
216212
.get_function("sum", true)
217213
.arg(&arr)
218214
.arg(&arr);
215+
219216
assert_eq!(registered.invoke().unwrap().to_float(), 14f64);
220217
}
221218
```

run_tests.sh

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
source activate py36
2+
3+
cargo build
4+
cargo test
5+
6+
cd tests/basics
7+
cargo build --features cpu
8+
cargo run --features cpu
9+
if [ $(which nvcc) ]; then
10+
cargo build --features gpu
11+
cargo run --features gpu
12+
fi
13+
cd -
14+
15+
cd tests/callback
16+
cargo build
17+
cargo run --bin int
18+
cargo run --bin float
19+
cargo run --bin array
20+
cargo run --bin string
21+
cargo run --bin error
22+
cd -
23+
24+
cd examples/resnet
25+
cargo build
26+
cargo run
27+
cd -
28+

src/bytearray.rs

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
1-
//! Provides [`TVMByteArray`] which is used for passing the model parameters
1+
//! Provides [`TVMByteArray`] used for passing the model parameters
22
//! (stored as byte-array) to a runtime module.
33
//!
4-
//! This function can be obtained from a graph runtime module loading the model.
54
//! For more detail, please see the example `resnet` in `examples` repository.
65
76
use std::os::raw::c_char;
87

98
use ts;
109

11-
/// A struct holding the TVM byte-array.
10+
/// A struct holding TVM byte-array.
1211
///
1312
/// ## Example
1413
///

src/context.rs

+7-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//! Provides [`TVMContext`] and related device specific informations.
1+
//! Provides [`TVMContext`] and related device specific queries.
22
//!
33
//! Create a new context by device type (cpu is 1) and device id.
44
//!
@@ -9,7 +9,8 @@
99
//! let cpu0 = TVMContext::cpu(0);
1010
//! assert_eq!(ctx, cpu0);
1111
//! ```
12-
//! Or from supported device name.
12+
//!
13+
//! Or from a supported device name.
1314
//!
1415
//! ```
1516
//! let cpu0 = TVMContext::from("cpu");
@@ -27,8 +28,8 @@ use internal_api;
2728
use ts;
2829
use Result;
2930

30-
/// Device type which can be from a supported device name. See the supported devices
31-
/// in [TVM](https://github.com/dmlc/tvm) project.
31+
/// Device type can be from a supported device name. See the supported devices
32+
/// in [TVM](https://github.com/dmlc/tvm).
3233
///
3334
/// ## Example
3435
///
@@ -128,7 +129,8 @@ impl<'a> From<&'a str> for TVMDeviceType {
128129
/// assert!(ctx.exist());
129130
///
130131
/// ```
131-
/// It is possible to query the context and get information such as
132+
///
133+
/// It is possible to query the underlying context as follows
132134
///
133135
/// ```
134136
/// println!("maximun threads per block: {}", ctx.max_threads_per_block());

src/errors.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use rust_ndarray;
77
error_chain!{
88
errors {
99
EmptyArray {
10-
description("cannot convert from empty array")
10+
description("cannot convert from an empty array")
1111
}
1212

1313
NullHandle(name: String) {
@@ -31,6 +31,7 @@ error_chain!{
3131
}
3232

3333
}
34+
3435
foreign_links {
3536
ShapeError(rust_ndarray::ShapeError);
3637
NulError(ffi::NulError);

src/function.rs

+24-16
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
//! This module provides idiomatic Rust API for creating and working with TVM functions.
1+
//! This module provides an idiomatic Rust API for creating and working with TVM functions.
22
//!
3-
//! For calling an already registered TVM function use [`function::Builder`] and to register
4-
//! a TVM packed function from Rust either use [`function::register`] or the
5-
//! macro [`register_global_func`].
3+
//! For calling an already registered TVM function use [`function::Builder`]
4+
//! To register a TVM packed function from Rust side either
5+
//! use [`function::register`] or the macro [`register_global_func`].
66
//!
7-
//! See the `tests` and `examples` repository for usage examples.
7+
//! See the tests and examples repository for more examples.
88
99
use std::{
1010
ffi::{CStr, CString},
@@ -60,8 +60,9 @@ pub fn get_global_func(name: &str, is_global: bool) -> Option<Function> {
6060
}
6161

6262
/// Wrapper around TVM function handle which includes `is_global`
63-
/// indicating whether the function is global or not and `is_released`
64-
/// to help for dropping the function handle.
63+
/// indicating whether the function is global or not, `is_released`
64+
/// to hint dropping the function handle and `is_cloned` showing
65+
/// not to drop a cloned function from Rust side.
6566
/// The value of these fields can be accessed through their respective methods.
6667
#[derive(Debug, Hash)]
6768
pub struct Function {
@@ -106,6 +107,12 @@ impl Function {
106107
pub fn is_released(&self) -> bool {
107108
self.is_released
108109
}
110+
111+
/// Returns `true` if the underlying TVM function has been cloned
112+
/// from the frontend and `false` otherwise.
113+
pub fn is_cloned(&self) -> bool {
114+
self.is_cloned
115+
}
109116
}
110117

111118
impl Clone for Function {
@@ -133,7 +140,8 @@ impl Drop for Function {
133140
}
134141

135142
/// Function builder in order to create and call functions.
136-
/// *Note:* Currently TVM functions accept at most one return value.
143+
///
144+
/// *Note:* Currently TVM functions accept *at most* one return value.
137145
#[derive(Debug, Clone, Default)]
138146
pub struct Builder<'a> {
139147
pub func: Option<Function>,
@@ -159,7 +167,7 @@ impl<'a> Builder<'a> {
159167
self
160168
}
161169

162-
/// Pushes a [`TVMArgValue`] into the function.
170+
/// Pushes a [`TVMArgValue`] into the function argument buffer.
163171
pub fn arg<'b, T: ?Sized>(&mut self, arg: &'b T) -> &mut Self
164172
where
165173
TVMValue: From<&'b T>,
@@ -181,7 +189,7 @@ impl<'a> Builder<'a> {
181189
self
182190
}
183191

184-
/// Pushes multiple [`TVMArgValue`]s into the function.
192+
/// Pushes multiple [`TVMArgValue`]s into the function argument buffer.
185193
pub fn args<'b, T: 'b + ?Sized, I>(&mut self, args: I) -> &mut Self
186194
where
187195
I: IntoIterator<Item = &'b T>,
@@ -195,7 +203,7 @@ impl<'a> Builder<'a> {
195203
}
196204

197205
/// Sets an output for a function that requirs a mutable output to be provided.
198-
/// See the `basics` in `tests` for an example.
206+
/// See the `basics` in tests for an example.
199207
pub fn set_output<'b, T: 'b + ?Sized>(&mut self, arg: &'b mut T) -> &mut Self
200208
where
201209
TVMValue: From<&'b T>,
@@ -215,7 +223,7 @@ impl<'a> Builder<'a> {
215223
self
216224
}
217225

218-
/// Calls the function that created from builder.
226+
/// Calls the function that created from `Builder`.
219227
pub fn invoke(&mut self) -> Result<TVMRetValue> {
220228
self.clone()(())
221229
}
@@ -356,7 +364,7 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue>) -> Function
356364
}
357365

358366
/// Registers a Rust function with signature
359-
/// `fn(&[TVMArgValue]) -> Result<TVMRetValue<'static>>`
367+
/// `fn(&[TVMArgValue]) -> Result<TVMRetValue>`
360368
/// as a **global TVM packed function** from frontend to TVM backend.
361369
///
362370
/// Use [`register_global_func`] if overriding an existing global TVM function
@@ -365,7 +373,7 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue>) -> Function
365373
/// ## Example
366374
///
367375
/// ```
368-
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue<'static>> {
376+
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
369377
/// let mut ret = 0;
370378
/// for arg in args.iter() {
371379
/// ret += arg.to_int();
@@ -438,8 +446,8 @@ macro_rules! register_global_func {
438446
}}
439447
}
440448

441-
/// Convenient macro for calling TVM packed functions by providing
442-
/// function identifier and the arguments. This macro outputs a `Result`
449+
/// Convenient macro for calling TVM packed functions by providing a
450+
/// function identifier and some arguments. This macro outputs a `Result` type
443451
/// and let user to perform proper error handling.
444452
///
445453
/// **Note**: this macro does *not* expect an outside mutable output. To

src/internal_api.rs

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::{cell::RefCell, collections::HashMap};
22

33
use Function;
44

5+
// access TVM internal API
56
thread_local! {
67
pub(crate) static API: RefCell<HashMap<String, Function>> = RefCell::new(HashMap::new());
78
}

src/lib.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
//! [TVM](https://github.com/dmlc/tvm) is a compiler stack for deep learning systems.
22
//!
3-
//! This crate provides idiomatic Rust API for TVM runtime frontend.
3+
//! This crate provides an idiomatic Rust API for TVM runtime frontend.
44
//!
5-
//! One particular usage is that given an optimized deep learning model,
6-
//! compiled with TVM, one can load the model artifacts which includes a shared library
7-
//! `lib.so`, `graph.json` and byte-array `param.params`
8-
//! in Rust to create a runtime, run the model for some inputs and get the
5+
//! One particular use case is that given optimized deep learning model artifacts,
6+
//! already compiled with TVM, which include a shared library
7+
//! `lib.so`, `graph.json` and a byte-array `param.params`, one can load them
8+
//! in Rust to create a graph runtime. Then, run the model for some inputs and get the
99
//! desired predictions all in Rust.
1010
//!
1111
//! Checkout the `examples` repository for more details.

src/module.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//! Provides [`Module`] type and methods for working with runtime TVM modules.
1+
//! Provides the [`Module`] type and methods for working with runtime TVM modules.
22
33
use std::{
44
ffi::CString,
@@ -17,9 +17,9 @@ use Result;
1717

1818
const ENTRY_FUNC: &'static str = "__tvm_main__";
1919

20-
/// Wrapper around TVM module handle which contains an entry function
21-
/// which can be applied to an imported module through [`entry_func`]
22-
/// and to check whether the module has be dropped use [`is_released`].
20+
/// Wrapper around TVM module handle which contains an entry function.
21+
/// The entry function can be applied to an imported module through [`entry_func`].
22+
/// Also [`is_released`] shows whether the module is dropped or not.
2323
///
2424
/// [`entry_func`]:struct.Module.html#method.entry_func
2525
/// [`is_released`]:struct.Module.html#method.is_released

0 commit comments

Comments
 (0)