Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Rust] Unify types between bindings and pure Rust impl #2616

Merged
merged 17 commits into from
Apr 3, 2019
4 changes: 4 additions & 0 deletions rust/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
target/
*.rs.bk
Cargo.lock
c_runtime_api.rs
4 changes: 0 additions & 4 deletions rust/common/.gitignore

This file was deleted.

10 changes: 6 additions & 4 deletions rust/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ authors = ["TVM Contributors"]
license = "Apache-2.0"

[features]
runtime = []
frontend = ["tvm-sys"]
bindings = []

[dependencies]
error-chain = { version = "0.12.0", default-features = false }
tvm-sys = { version = "0.1.0", path = "tvm-sys", optional = true }
failure = "0.1.5"
ndarray = "0.12.1"

[build-dependencies]
bindgen = "0.37.4"
31 changes: 31 additions & 0 deletions rust/common/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
extern crate bindgen;

use std::path::PathBuf;

fn main() {
if cfg!(feature = "bindings") {
println!("cargo:rerun-if-env-changed=TVM_HOME");
println!("cargo:rustc-link-lib=dylib=tvm_runtime");
println!("cargo:rustc-link-search={}/build", env!("TVM_HOME"));
}

// @see rust-bindgen#550 for `blacklist_type`
bindgen::Builder::default()
.header(format!(
"{}/include/tvm/runtime/c_runtime_api.h",
env!("TVM_HOME")
))
.header(format!(
"{}/include/tvm/runtime/c_backend_api.h",
env!("TVM_HOME")
))
.clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME")))
.blacklist_type("max_align_t")
.layout_tests(false)
.derive_partialeq(true)
.derive_eq(true)
.generate()
.expect("unable to generate bindings")
.write_to_file(PathBuf::from("src/c_runtime_api.rs"))
.expect("can not write the bindings!");
}
128 changes: 128 additions & 0 deletions rust/common/src/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
use std::{
any::TypeId,
mem,
os::raw::{c_int, c_void},
};

use crate::ffi::{
DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt,
DLDeviceType_kDLCPU, DLTensor,
};

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct DataType {
pub code: usize,
pub bits: usize,
pub lanes: usize,
}

impl DataType {
/// Returns the number of bytes occupied by an element of this `DataType`.
pub fn itemsize(&self) -> usize {
(self.bits * self.lanes) >> 3
}

/// Returns whether this `DataType` represents primitive type `T`.
pub fn is_type<T: 'static>(&self) -> bool {
if self.lanes != 1 {
return false;
}
let typ = TypeId::of::<T>();
(typ == TypeId::of::<i32>() && self.code == 0 && self.bits == 32)
|| (typ == TypeId::of::<i64>() && self.code == 0 && self.bits == 64)
|| (typ == TypeId::of::<u32>() && self.code == 1 && self.bits == 32)
|| (typ == TypeId::of::<u64>() && self.code == 1 && self.bits == 64)
|| (typ == TypeId::of::<f32>() && self.code == 2 && self.bits == 32)
|| (typ == TypeId::of::<f64>() && self.code == 2 && self.bits == 64)
}

pub fn code(&self) -> usize {
self.code
}

pub fn bits(&self) -> usize {
self.bits
}

pub fn lanes(&self) -> usize {
self.lanes
}
}

impl<'a> From<&'a DataType> for DLDataType {
fn from(dtype: &'a DataType) -> Self {
Self {
code: dtype.code as u8,
bits: dtype.bits as u8,
lanes: dtype.lanes as u16,
}
}
}

impl From<DLDataType> for DataType {
fn from(dtype: DLDataType) -> Self {
Self {
code: dtype.code as usize,
bits: dtype.bits as usize,
lanes: dtype.lanes as usize,
}
}
}

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TVMContext {
pub device_type: usize,
pub device_id: usize,
}

impl<'a> From<&'a TVMContext> for DLContext {
fn from(ctx: &'a TVMContext) -> Self {
Self {
device_type: ctx.device_type as u32,
device_id: ctx.device_id as i32,
}
}
}

impl Default for TVMContext {
fn default() -> Self {
Self {
device_type: DLDeviceType_kDLCPU as usize,
device_id: 0,
}
}
}

/// `From` conversions to `DLTensor` for `ndarray::Array`.
/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
macro_rules! impl_dltensor_from_ndarray {
($type:ty, $typecode:expr) => {
impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
DLTensor {
data: arr.as_mut_ptr() as *mut c_void,
ctx: DLContext {
device_type: DLDeviceType_kDLCPU,
device_id: 0,
},
ndim: arr.ndim() as c_int,
dtype: DLDataType {
code: $typecode as u8,
bits: 8 * mem::size_of::<$type>() as u8,
lanes: 1,
},
shape: arr.shape().as_ptr() as *const i64 as *mut i64,
strides: arr.strides().as_ptr() as *const isize as *mut i64,
byte_offset: 0,
}
}
}
};
}

impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
Loading