Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RUST][FRONTEND] Add rust frontend v0.1 #2292

Merged
merged 20 commits into from
Feb 3, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions rust/.rustfmt.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
max_width = 100
hard_tabs = false
tab_spaces = 2
tab_spaces = 4
newline_style = "Auto"
use_small_heuristics = "Default"
indent_style = "Block"
Expand Down Expand Up @@ -38,7 +38,7 @@ trailing_comma = "Vertical"
match_block_trailing_comma = false
blank_lines_upper_bound = 1
blank_lines_lower_bound = 0
edition = "2015"
edition = "2018"
merge_derives = true
use_try_shorthand = true
use_field_init_shorthand = false
Expand All @@ -50,8 +50,8 @@ unstable_features = false
disable_all_formatting = false
skip_children = false
hide_parse_errors = false
error_on_line_overflow = false
error_on_unformatted = false
error_on_line_overflow = true
error_on_unformatted = true
report_todo = "Never"
report_fixme = "Never"
ignore = []
Expand Down
39 changes: 11 additions & 28 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,28 +1,11 @@
[package]
name = "tvm"
version = "0.1.0"
license = "Apache-2.0"
description = "TVM Rust runtime"
repository = "https://github.com/dmlc/tvm"
readme = "README.md"
keywords = ["tvm", "nnvm"]
categories = ["api-bindings", "science"]
authors = ["TVM Contributors"]

[features]
default = ["nom/std"]
sgx = ["nom/alloc"]

[dependencies]
bounded-spsc-queue = "0.4.0"
error-chain = { version = "0.12.0", default-features = false }
itertools = "0.7.8"
lazy_static = "1.1.0"
ndarray = "0.11.2"
nom = {version = "4.0.0", default-features = false }
serde = "1.0.59"
serde_derive = "1.0.79"
serde_json = "1.0.17"

[target.'cfg(not(target_env = "sgx"))'.dependencies]
num_cpus = "1.8.0"
[workspace]
members = [
"common",
"runtime",
"runtime/tests/test_tvm_basic",
"runtime/tests/test_nnvm",
"frontend",
"frontend/tests/basics",
"frontend/tests/callback",
"frontend/examples/resnet"
]
4 changes: 4 additions & 0 deletions rust/common/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
target
**/*.rs.bk
Cargo.lock
/tvm-sys/src/bindgen.rs
13 changes: 13 additions & 0 deletions rust/common/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[package]
name = "tvm-common"
version = "0.1.0"
authors = ["TVM Contributors"]
license = "Apache-2.0"

[features]
runtime = []
frontend = ["tvm-sys"]

[dependencies]
error-chain = { version = "0.12.0", default-features = false }
tvm-sys = { version = "0.1.0", path = "tvm-sys", optional = true }
File renamed without changes.
15 changes: 15 additions & 0 deletions rust/common/src/errors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
//! Error types for `TVMArgValue` and `TVMRetValue` conversions.

error_chain! {
errors {
TryFromTVMArgValueError(expected: String, actual: String) {
description("mismatched types while converting from TVMArgValue")
display("expected `{}` but given `{}`", expected, actual)
}

TryFromTVMRetValueError(expected: String, actual: String) {
description("mismatched types while downcasting TVMRetValue")
display("invalid downcast: expected `{}` but given `{}`", expected, actual)
}
}
}
39 changes: 39 additions & 0 deletions rust/common/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//! This crate contains the refactored basic components required
//! for `runtime` and `frontend` TVM crates.

#![crate_name = "tvm_common"]
#![recursion_limit = "1024"]
#![allow(non_camel_case_types, unused_imports)]
#![feature(box_syntax, try_from)]

#[macro_use]
extern crate error_chain;

/// Unified ffi module for both runtime and frontend crates.
pub mod ffi {
#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)]

#[cfg(feature = "frontend")]
pub extern crate tvm_sys as ts;

#[cfg(feature = "runtime")]
pub mod runtime {
use std::os::raw::{c_char, c_int, c_void};

include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));

pub type BackendPackedCFunc = extern "C" fn(
args: *const TVMValue,
type_codes: *const c_int,
num_args: c_int,
) -> c_int;
}
}

pub mod errors;
pub mod ty;
pub mod value;

pub use errors::*;
pub use ty::TVMTypeCode;
pub use value::{TVMArgValue, TVMRetValue, TVMValue};
144 changes: 144 additions & 0 deletions rust/common/src/ty.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
//! This module containes `TVMTypeCode` and `TVMType` with some conversion methods.
//!
//! # Example
//!
//! ```
//! let dtype = TVMType::from("float");
//! println!("dtype is: {}", dtype);
//! ```

use std::{
ffi::{CStr, CString},
fmt::{self, Display, Formatter},
};

/// TVM type codes.
#[repr(u32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum TVMTypeCode {
kDLInt = 0,
kDLUInt = 1,
kDLFloat = 2,
kHandle = 3,
kNull = 4,
kTVMType = 5,
kTVMContext = 6,
kArrayHandle = 7,
kNodeHandle = 8,
kModuleHandle = 9,
kFuncHandle = 10,
kStr = 11,
kBytes = 12,
kNDArrayContainer = 13,
}

impl Default for TVMTypeCode {
fn default() -> Self {
TVMTypeCode::kDLInt
}
}

impl From<TVMTypeCode> for i64 {
fn from(arg: TVMTypeCode) -> i64 {
match arg {
TVMTypeCode::kDLInt => 0,
TVMTypeCode::kDLUInt => 1,
TVMTypeCode::kDLFloat => 2,
TVMTypeCode::kHandle => 3,
TVMTypeCode::kNull => 4,
TVMTypeCode::kTVMType => 5,
TVMTypeCode::kTVMContext => 6,
TVMTypeCode::kArrayHandle => 7,
TVMTypeCode::kNodeHandle => 8,
TVMTypeCode::kModuleHandle => 9,
TVMTypeCode::kFuncHandle => 10,
TVMTypeCode::kStr => 11,
TVMTypeCode::kBytes => 12,
TVMTypeCode::kNDArrayContainer => 13,
}
}
}

impl Into<TVMTypeCode> for i64 {
fn into(self) -> TVMTypeCode {
match self {
0 => TVMTypeCode::kDLInt,
1 => TVMTypeCode::kDLUInt,
2 => TVMTypeCode::kDLFloat,
3 => TVMTypeCode::kHandle,
4 => TVMTypeCode::kNull,
5 => TVMTypeCode::kTVMType,
6 => TVMTypeCode::kTVMContext,
7 => TVMTypeCode::kArrayHandle,
8 => TVMTypeCode::kNodeHandle,
9 => TVMTypeCode::kModuleHandle,
10 => TVMTypeCode::kFuncHandle,
11 => TVMTypeCode::kStr,
12 => TVMTypeCode::kBytes,
13 => TVMTypeCode::kNDArrayContainer,
_ => unreachable!(),
}
}
}

impl Display for TVMTypeCode {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(
f,
"{}",
match self {
TVMTypeCode::kDLInt => "int",
TVMTypeCode::kDLUInt => "uint",
TVMTypeCode::kDLFloat => "float",
TVMTypeCode::kHandle => "handle",
TVMTypeCode::kNull => "null",
TVMTypeCode::kTVMType => "TVM type",
TVMTypeCode::kTVMContext => "TVM context",
TVMTypeCode::kArrayHandle => "Array handle",
TVMTypeCode::kNodeHandle => "Node handle",
TVMTypeCode::kModuleHandle => "Module handle",
TVMTypeCode::kFuncHandle => "Function handle",
TVMTypeCode::kStr => "string",
TVMTypeCode::kBytes => "bytes",
TVMTypeCode::kNDArrayContainer => "ndarray container",
}
)
}
}

macro_rules! impl_prim_type {
($type:ty, $variant:ident) => {
impl<'a> From<&'a $type> for TVMTypeCode {
fn from(_arg: &$type) -> Self {
TVMTypeCode::$variant
}
}

impl<'a> From<&'a mut $type> for TVMTypeCode {
fn from(_arg: &mut $type) -> Self {
TVMTypeCode::$variant
}
}
};
}

impl_prim_type!(usize, kDLInt);
impl_prim_type!(i64, kDLInt);
impl_prim_type!(i32, kDLInt);
impl_prim_type!(i16, kDLInt);
impl_prim_type!(i8, kDLInt);

impl_prim_type!(u64, kDLUInt);
impl_prim_type!(u32, kDLUInt);
impl_prim_type!(u16, kDLUInt);
impl_prim_type!(u8, kDLUInt);

impl_prim_type!(f64, kDLFloat);
impl_prim_type!(f32, kDLFloat);

impl_prim_type!(str, kStr);
impl_prim_type!(CStr, kStr);
impl_prim_type!(String, kStr);
impl_prim_type!(CString, kStr);

impl_prim_type!([u8], kBytes);
Loading