Skip to content

Commit 4968279

Browse files
nhynestqchen
authored andcommitted
[Rust] Unify types between bindings and pure Rust impl (#2616)
1 parent 71abe36 commit 4968279

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+1371
-2741
lines changed

rust/.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
target/
2+
*.rs.bk
3+
Cargo.lock
4+
c_runtime_api.rs

rust/common/.gitignore

-4
This file was deleted.

rust/common/Cargo.toml

+6-4
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@ authors = ["TVM Contributors"]
55
license = "Apache-2.0"
66

77
[features]
8-
runtime = []
9-
frontend = ["tvm-sys"]
8+
bindings = []
109

1110
[dependencies]
12-
error-chain = { version = "0.12.0", default-features = false }
13-
tvm-sys = { version = "0.1.0", path = "tvm-sys", optional = true }
11+
failure = "0.1.5"
12+
ndarray = "0.12.1"
13+
14+
[build-dependencies]
15+
bindgen = "0.37.4"

rust/common/build.rs

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
extern crate bindgen;
2+
3+
use std::path::PathBuf;
4+
5+
fn main() {
6+
if cfg!(feature = "bindings") {
7+
println!("cargo:rerun-if-env-changed=TVM_HOME");
8+
println!("cargo:rustc-link-lib=dylib=tvm_runtime");
9+
println!("cargo:rustc-link-search={}/build", env!("TVM_HOME"));
10+
}
11+
12+
// @see rust-bindgen#550 for `blacklist_type`
13+
bindgen::Builder::default()
14+
.header(format!(
15+
"{}/include/tvm/runtime/c_runtime_api.h",
16+
env!("TVM_HOME")
17+
))
18+
.header(format!(
19+
"{}/include/tvm/runtime/c_backend_api.h",
20+
env!("TVM_HOME")
21+
))
22+
.clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME")))
23+
.blacklist_type("max_align_t")
24+
.layout_tests(false)
25+
.derive_partialeq(true)
26+
.derive_eq(true)
27+
.generate()
28+
.expect("unable to generate bindings")
29+
.write_to_file(PathBuf::from("src/c_runtime_api.rs"))
30+
.expect("can not write the bindings!");
31+
}

rust/common/src/array.rs

+128
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
use std::{
2+
any::TypeId,
3+
mem,
4+
os::raw::{c_int, c_void},
5+
};
6+
7+
use crate::ffi::{
8+
DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt,
9+
DLDeviceType_kDLCPU, DLTensor,
10+
};
11+
12+
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
13+
pub struct DataType {
14+
pub code: usize,
15+
pub bits: usize,
16+
pub lanes: usize,
17+
}
18+
19+
impl DataType {
20+
/// Returns the number of bytes occupied by an element of this `DataType`.
21+
pub fn itemsize(&self) -> usize {
22+
(self.bits * self.lanes) >> 3
23+
}
24+
25+
/// Returns whether this `DataType` represents primitive type `T`.
26+
pub fn is_type<T: 'static>(&self) -> bool {
27+
if self.lanes != 1 {
28+
return false;
29+
}
30+
let typ = TypeId::of::<T>();
31+
(typ == TypeId::of::<i32>() && self.code == 0 && self.bits == 32)
32+
|| (typ == TypeId::of::<i64>() && self.code == 0 && self.bits == 64)
33+
|| (typ == TypeId::of::<u32>() && self.code == 1 && self.bits == 32)
34+
|| (typ == TypeId::of::<u64>() && self.code == 1 && self.bits == 64)
35+
|| (typ == TypeId::of::<f32>() && self.code == 2 && self.bits == 32)
36+
|| (typ == TypeId::of::<f64>() && self.code == 2 && self.bits == 64)
37+
}
38+
39+
pub fn code(&self) -> usize {
40+
self.code
41+
}
42+
43+
pub fn bits(&self) -> usize {
44+
self.bits
45+
}
46+
47+
pub fn lanes(&self) -> usize {
48+
self.lanes
49+
}
50+
}
51+
52+
impl<'a> From<&'a DataType> for DLDataType {
53+
fn from(dtype: &'a DataType) -> Self {
54+
Self {
55+
code: dtype.code as u8,
56+
bits: dtype.bits as u8,
57+
lanes: dtype.lanes as u16,
58+
}
59+
}
60+
}
61+
62+
impl From<DLDataType> for DataType {
63+
fn from(dtype: DLDataType) -> Self {
64+
Self {
65+
code: dtype.code as usize,
66+
bits: dtype.bits as usize,
67+
lanes: dtype.lanes as usize,
68+
}
69+
}
70+
}
71+
72+
#[derive(Debug, Clone, Copy, PartialEq)]
73+
pub struct TVMContext {
74+
pub device_type: usize,
75+
pub device_id: usize,
76+
}
77+
78+
impl<'a> From<&'a TVMContext> for DLContext {
79+
fn from(ctx: &'a TVMContext) -> Self {
80+
Self {
81+
device_type: ctx.device_type as u32,
82+
device_id: ctx.device_id as i32,
83+
}
84+
}
85+
}
86+
87+
impl Default for TVMContext {
88+
fn default() -> Self {
89+
Self {
90+
device_type: DLDeviceType_kDLCPU as usize,
91+
device_id: 0,
92+
}
93+
}
94+
}
95+
96+
/// `From` conversions to `DLTensor` for `ndarray::Array`.
97+
/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
98+
macro_rules! impl_dltensor_from_ndarray {
99+
($type:ty, $typecode:expr) => {
100+
impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
101+
fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
102+
DLTensor {
103+
data: arr.as_mut_ptr() as *mut c_void,
104+
ctx: DLContext {
105+
device_type: DLDeviceType_kDLCPU,
106+
device_id: 0,
107+
},
108+
ndim: arr.ndim() as c_int,
109+
dtype: DLDataType {
110+
code: $typecode as u8,
111+
bits: 8 * mem::size_of::<$type>() as u8,
112+
lanes: 1,
113+
},
114+
shape: arr.shape().as_ptr() as *const i64 as *mut i64,
115+
strides: arr.strides().as_ptr() as *const isize as *mut i64,
116+
byte_offset: 0,
117+
}
118+
}
119+
}
120+
};
121+
}
122+
123+
impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
124+
impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
125+
impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
126+
impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
127+
impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
128+
impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);

0 commit comments

Comments
 (0)