diff --git a/.gitignore b/.gitignore index 4b53d51eff..c9bb1f7484 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,3 @@ *~ Cargo.lock -/target -/libtensorflow-sys/target +target diff --git a/Cargo.toml b/Cargo.toml index 344401bebe..7634e290e9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,9 +9,8 @@ readme = "README.md" repository = "https://github.com/google/tensorflow-rust" [dependencies] -libc = "^0.2" -libtensorflow-sys = { version = "0.0.1", path = "libtensorflow-sys" } +libc = "0.2" +tensorflow-sys = { version = "0.4", path = "tensorflow-sys" } [features] -default = [] tensorflow_unstable = [] diff --git a/src/lib.rs b/src/lib.rs index e960532d43..5245facffc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,21 +3,19 @@ #![cfg(feature = "tensorflow_unstable")] extern crate libc; -extern crate libtensorflow_sys; +extern crate tensorflow_sys as tf; +use libc::{c_char, c_int, c_uint, c_void, size_t}; use std::ffi::CStr; use std::ffi::CString; use std::ffi::NulError; -use std::fmt; use std::fmt::Debug; use std::fmt::Display; use std::fmt::Formatter; +use std::fmt; use std::marker; use std::mem; use std::ops::Drop; -use std::os::raw; - -use libtensorflow_sys as tf; mod buffer; pub use buffer::Buffer; @@ -69,13 +67,13 @@ macro_rules! c_enum { #[doc = $doc] #[derive(PartialEq,Eq,PartialOrd,Ord,Debug)] pub enum $enum_name { - UnrecognizedEnumValue(raw::c_uint), + UnrecognizedEnumValue(c_uint), $($name),* } impl $enum_name { #[allow(dead_code)] - fn from_int(value: raw::c_uint) -> $enum_name { + fn from_int(value: c_uint) -> $enum_name { match value { $($num => $enum_name::$name,)* c => $enum_name::UnrecognizedEnumValue(c), @@ -83,7 +81,7 @@ macro_rules! c_enum { } #[allow(dead_code)] - fn to_int(&self) -> raw::c_uint { + fn to_int(&self) -> c_uint { match self { &$enum_name::UnrecognizedEnumValue(c) => c, $(&$enum_name::$name => $num),* @@ -262,7 +260,7 @@ impl SessionOptions { pub fn set_config(&mut self, config: &[u8]) -> Result<()> { let status = Status::new(); unsafe { - tf::TF_SetConfig(self.inner, config.as_ptr() as *const raw::c_void, config.len(), status.inner); + tf::TF_SetConfig(self.inner, config.as_ptr() as *const _, config.len(), status.inner); } if status.is_ok() { Ok(()) @@ -309,7 +307,7 @@ impl Session { pub fn extend_graph(&mut self, proto: &[u8]) -> Result<()> { let status = Status::new(); unsafe { - tf::TF_ExtendGraph(self.inner, proto.as_ptr() as *const raw::c_void, proto.len(), status.inner); + tf::TF_ExtendGraph(self.inner, proto.as_ptr() as *const _, proto.len(), status.inner); } status.as_result() } @@ -323,16 +321,15 @@ impl Session { unsafe { let mut dims = Vec::with_capacity(tf::TF_NumDims(input_tensor) as usize); for i in 0..dims.capacity() { - dims.push(tf::TF_Dim(input_tensor, i as i32)); + dims.push(tf::TF_Dim(input_tensor, i as c_int)); } - input_tensors.push(tf::TF_NewTensor( - tf::TF_TensorType(input_tensor), - dims.as_ptr() as *mut i64, - dims.len() as libc::c_int, - tf::TF_TensorData(input_tensor), - tf::TF_TensorByteSize(input_tensor), - Some(noop_deallocator), - std::ptr::null_mut())); + input_tensors.push(tf::TF_NewTensor(tf::TF_TensorType(input_tensor), + dims.as_mut_ptr(), + dims.len() as c_int, + tf::TF_TensorData(input_tensor), + tf::TF_TensorByteSize(input_tensor), + Some(noop_deallocator), + std::ptr::null_mut())); } } @@ -346,12 +343,12 @@ impl Session { std::ptr::null(), step.input_name_ptrs.as_mut_ptr(), input_tensors.as_mut_ptr(), - input_tensors.len() as raw::c_int, + input_tensors.len() as c_int, step.output_name_ptrs.as_mut_ptr(), step.output_tensors.as_mut_ptr(), - step.output_tensors.len() as raw::c_int, + step.output_tensors.len() as c_int, step.target_name_ptrs.as_mut_ptr(), - step.target_name_ptrs.len() as raw::c_int, + step.target_name_ptrs.len() as c_int, std::ptr::null_mut(), status.inner); }; @@ -375,15 +372,15 @@ impl Drop for Session { /// adding some inputs to it, requesting some outputs, passing it to `Session::run` /// and then taking the outputs out of it. pub struct Step<'l> { - input_name_ptrs: Vec<*const raw::c_char>, + input_name_ptrs: Vec<*const c_char>, input_name_c_strings: Vec, input_tensors: Vec<*mut tf::TF_Tensor>, - output_name_ptrs: Vec<*const raw::c_char>, + output_name_ptrs: Vec<*const c_char>, output_name_c_strings: Vec, output_tensors: Vec<*mut tf::TF_Tensor>, - target_name_ptrs: Vec<*const raw::c_char>, + target_name_ptrs: Vec<*const c_char>, target_name_c_strings: Vec, phantom: marker::PhantomData<&'l ()>, @@ -579,10 +576,7 @@ pub struct Tensor { dims: Vec, } -unsafe extern "C" fn noop_deallocator(_data: *mut raw::c_void, - _len: ::libc::size_t, - _arg: *mut raw::c_void)-> () { -} +unsafe extern "C" fn noop_deallocator(_: *mut c_void, _: size_t, _: *mut c_void) -> () {} // TODO: Replace with Iterator::product once that's stable fn product(values: &[u64]) -> u64 { @@ -614,9 +608,9 @@ impl Tensor { } let inner = unsafe { tf::TF_NewTensor(mem::transmute(T::data_type().to_int()), - dims.as_ptr() as *mut i64, - dims.len() as i32, - data.as_ptr() as *mut raw::c_void, + dims.as_ptr() as *mut _, + dims.len() as c_int, + data.as_ptr() as *mut _, data.len(), Some(noop_deallocator), std::ptr::null_mut()) @@ -650,9 +644,9 @@ impl Tensor { } let mut dims = Vec::with_capacity(tf::TF_NumDims(tensor) as usize); for i in 0..dims.capacity() { - dims.push(tf::TF_Dim(tensor, i as raw::c_int) as u64); + dims.push(tf::TF_Dim(tensor, i as c_int) as u64); } - let data = Buffer::from_ptr(tf::TF_TensorData(tensor) as *mut T, product(&dims) as usize); + let data = Buffer::from_ptr(tf::TF_TensorData(tensor) as *mut _, product(&dims) as usize); Some(Tensor { inner: tensor, data: data, diff --git a/tensorflow-sys/Cargo.toml b/tensorflow-sys/Cargo.toml new file mode 100644 index 0000000000..c6bad4e704 --- /dev/null +++ b/tensorflow-sys/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "tensorflow-sys" +version = "0.4.0" +license = "Apache-2.0" +authors = [ + "Adam Crume ", + "Ivan Ukhov ", +] +description = "The package provides bindings to TensorFlow." +documentation = "https://google.github.io/tensorflow-rust" +homepage = "https://github.com/google/tensorflow-rust/tensorflow-sys" +repository = "https://github.com/google/tensorflow-rust" +build = "build.rs" +links = "tensorflow" + +[dependencies] +libc = "0.2" + +[build-dependencies] +pkg-config = "0.3" diff --git a/tensorflow-sys/README.md b/tensorflow-sys/README.md new file mode 100644 index 0000000000..ef500b437c --- /dev/null +++ b/tensorflow-sys/README.md @@ -0,0 +1,24 @@ +# tensorflow-sys [![Version][version-icon]][version-page] + +The package provides bindings to [TensorFlow][tensorflow]. + +## Requirements + +The build prerequisites can be found on the [corresponding +page][tensorflow-setup] of TensorFlow’s documentation. In particular, +[Bazel][bazel], [NumPy][numpy], and [SWIG][swig] are assumed to be installed. + +## Configuration + +The compilation process is configured via a number of environment variables, all +of which can be found in TensorFlow’s [configure][tensorflow-configure] script. +In particular, `TF_NEED_CUDA` is used to indicate if GPU support is needed. + +[bazel]: http://www.bazel.io +[numpy]: http://www.numpy.org +[swig]: http://www.swig.org +[tensorflow]: https://www.tensorflow.org +[tensorflow-configure]: https://github.com/tensorflow/tensorflow/blob/r0.9/configure +[tensorflow-setup]: https://www.tensorflow.org/versions/r0.9/get_started/os_setup.html +[version-icon]: https://img.shields.io/crates/v/tensorflow-sys.svg +[version-page]: https://crates.io/crates/tensorflow-sys diff --git a/tensorflow-sys/build.rs b/tensorflow-sys/build.rs new file mode 100644 index 0000000000..1e1073b03d --- /dev/null +++ b/tensorflow-sys/build.rs @@ -0,0 +1,48 @@ +extern crate pkg_config; + +use std::path::{Path, PathBuf}; +use std::process::Command; +use std::{env, fs}; + +const LIBRARY: &'static str = "tensorflow"; +const REPOSITORY: &'static str = "https://github.com/tensorflow/tensorflow.git"; +const TARGET: &'static str = "libtensorflow.so"; +const VERSION: &'static str = "0.9.0"; + +macro_rules! get(($name:expr) => (ok!(env::var($name)))); +macro_rules! ok(($expression:expr) => ($expression.unwrap())); + +fn main() { + if pkg_config::find_library(LIBRARY).is_ok() { + return; + } + + let output = PathBuf::from(&get!("OUT_DIR")); + if !output.join(TARGET).exists() { + let source = PathBuf::from(&get!("CARGO_MANIFEST_DIR")).join("target/source"); + if !Path::new(&source.join(".git")).exists() { + run("git", |command| command.arg("clone") + .arg(format!("--branch=v{}", VERSION)) + .arg("--recursive") + .arg(REPOSITORY) + .arg(&source)); + } + run("./configure", |command| command.current_dir(&source)); + run("bazel", |command| command.current_dir(&source) + .arg("build") + .arg(format!("--jobs={}", get!("NUM_JOBS"))) + .arg("--compilation_mode=opt") + .arg(format!("{}:{}", LIBRARY, TARGET))); + ok!(fs::copy(source.join("bazel-bin").join(LIBRARY).join(TARGET), output.join(TARGET))); + } + + println!("cargo:rustc-link-lib=dylib={}", LIBRARY); + println!("cargo:rustc-link-search={}", output.display()); +} + +fn run(name: &str, mut configure: F) where F: FnMut(&mut Command) -> &mut Command { + let mut command = Command::new(name); + if !ok!(configure(&mut command).status()).success() { + panic!("failed to execute {:?}", command); + } +} diff --git a/tensorflow-sys/examples/assets/multiplication.pb b/tensorflow-sys/examples/assets/multiplication.pb new file mode 100644 index 0000000000..83d4ff76d2 Binary files /dev/null and b/tensorflow-sys/examples/assets/multiplication.pb differ diff --git a/tensorflow-sys/examples/assets/multiplication.py b/tensorflow-sys/examples/assets/multiplication.py new file mode 100644 index 0000000000..c478f6a06d --- /dev/null +++ b/tensorflow-sys/examples/assets/multiplication.py @@ -0,0 +1,10 @@ +import os +import tensorflow as tf + +a = tf.placeholder(tf.float32, name='a') +b = tf.placeholder(tf.float32, name='b') +c = tf.mul(a, b, name='c') + +definition = tf.Session().graph_def +directory = os.path.dirname(os.path.realpath(__file__)) +tf.train.write_graph(definition, directory, 'multiplication.pb', as_text=False) diff --git a/tensorflow-sys/examples/multiplication.rs b/tensorflow-sys/examples/multiplication.rs new file mode 100644 index 0000000000..49e2cb4666 --- /dev/null +++ b/tensorflow-sys/examples/multiplication.rs @@ -0,0 +1,102 @@ +extern crate libc; +extern crate tensorflow_sys as ffi; + +use libc::{c_int, c_longlong, c_void, size_t}; +use std::ffi::{CStr, CString}; +use std::path::Path; + +macro_rules! nonnull( + ($pointer:expr) => ({ + let pointer = $pointer; + assert!(!pointer.is_null()); + pointer + }); +); + +macro_rules! ok( + ($status:expr) => ({ + if ffi::TF_GetCode($status) != ffi::TF_OK { + panic!(CStr::from_ptr(ffi::TF_Message($status)).to_string_lossy().into_owned()); + } + }); +); + +fn main() { + use std::mem::size_of; + use std::ptr::{null, null_mut}; + use std::slice::from_raw_parts; + + unsafe { + let options = nonnull!(ffi::TF_NewSessionOptions()); + let status = nonnull!(ffi::TF_NewStatus()); + let session = nonnull!(ffi::TF_NewSession(options, status)); + + let graph = read("examples/assets/multiplication.pb"); // c = a * b + ffi::TF_ExtendGraph(session, graph.as_ptr() as *const _, graph.len() as size_t, status); + ok!(status); + + let mut input_names = vec![]; + let mut inputs = vec![]; + + let name = CString::new("a").unwrap(); + let mut data = vec![1f32, 2.0, 3.0]; + let mut dims = vec![data.len() as c_longlong]; + let tensor = nonnull!(ffi::TF_NewTensor(ffi::TF_FLOAT, dims.as_mut_ptr(), + dims.len() as c_int, data.as_mut_ptr() as *mut _, + data.len() as size_t, Some(noop), null_mut())); + + input_names.push(name.as_ptr()); + inputs.push(tensor); + + let name = CString::new("b").unwrap(); + let mut data = vec![4f32, 5.0, 6.0]; + let mut dims = vec![data.len() as c_longlong]; + let tensor = nonnull!(ffi::TF_NewTensor(ffi::TF_FLOAT, dims.as_mut_ptr(), + dims.len() as c_int, data.as_mut_ptr() as *mut _, + data.len() as size_t, Some(noop), null_mut())); + + input_names.push(name.as_ptr()); + inputs.push(tensor); + + let mut output_names = vec![]; + let mut outputs = vec![]; + + let name = CString::new("c").unwrap(); + + output_names.push(name.as_ptr()); + outputs.push(null_mut()); + + let mut target_names = vec![]; + + ffi::TF_Run(session, null(), input_names.as_mut_ptr(), inputs.as_mut_ptr(), + input_names.len() as c_int, output_names.as_mut_ptr(), outputs.as_mut_ptr(), + output_names.len() as c_int, target_names.as_mut_ptr(), + target_names.len() as c_int, null_mut(), status); + ok!(status); + + let tensor = nonnull!(outputs[0]); + let data = nonnull!(ffi::TF_TensorData(tensor)) as *const f32; + let data = from_raw_parts(data, ffi::TF_TensorByteSize(tensor) / size_of::()); + + assert_eq!(data, &[1.0 * 4.0, 2.0 * 5.0, 3.0 * 6.0]); + + ffi::TF_CloseSession(session, status); + + ffi::TF_DeleteTensor(tensor); + ffi::TF_DeleteSession(session, status); + ffi::TF_DeleteStatus(status); + ffi::TF_DeleteSessionOptions(options); + } + + unsafe extern "C" fn noop(_: *mut c_void, _: size_t, _: *mut c_void) {} +} + +fn read>(path: T) -> Vec { + use std::fs::File; + use std::io::Read; + + let mut buffer = vec![]; + let mut file = File::open(path).unwrap(); + file.read_to_end(&mut buffer).unwrap(); + buffer +} diff --git a/tensorflow-sys/src/lib.rs b/tensorflow-sys/src/lib.rs new file mode 100644 index 0000000000..960b6b2f6b --- /dev/null +++ b/tensorflow-sys/src/lib.rs @@ -0,0 +1,159 @@ +//! Binding to [TensorFlow][1]. +//! +//! [1]: https://www.tensorflow.org + +#![allow(non_camel_case_types)] + +extern crate libc; + +use libc::{c_char, c_int, c_longlong, c_void, size_t}; + +#[repr(C)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum TF_DataType { + TF_FLOAT = 1, + TF_DOUBLE = 2, + TF_INT32 = 3, + TF_UINT8 = 4, + TF_INT16 = 5, + TF_INT8 = 6, + TF_STRING = 7, + TF_COMPLEX64 = 8, + TF_INT64 = 9, + TF_BOOL = 10, + TF_QINT8 = 11, + TF_QUINT8 = 12, + TF_QINT32 = 13, + TF_BFLOAT16 = 14, + TF_QINT16 = 15, + TF_QUINT16 = 16, + TF_UINT16 = 17, + TF_COMPLEX128 = 18, + TF_HALF = 19, +} +pub use TF_DataType::*; + +#[repr(C)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum TF_Code { + TF_OK = 0, + TF_CANCELLED = 1, + TF_UNKNOWN = 2, + TF_INVALID_ARGUMENT = 3, + TF_DEADLINE_EXCEEDED = 4, + TF_NOT_FOUND = 5, + TF_ALREADY_EXISTS = 6, + TF_PERMISSION_DENIED = 7, + TF_UNAUTHENTICATED = 16, + TF_RESOURCE_EXHAUSTED = 8, + TF_FAILED_PRECONDITION = 9, + TF_ABORTED = 10, + TF_OUT_OF_RANGE = 11, + TF_UNIMPLEMENTED = 12, + TF_INTERNAL = 13, + TF_UNAVAILABLE = 14, + TF_DATA_LOSS = 15, +} +pub use TF_Code::*; + +#[derive(Clone, Copy, Debug)] +pub enum TF_Status {} + +#[repr(C)] +#[derive(Clone, Copy, Debug)] +pub struct TF_Buffer { + pub data: *const c_void, + pub length: size_t, + pub data_deallocator: Option, +} + +#[derive(Clone, Copy, Debug)] +pub enum TF_Library {} + +#[derive(Clone, Copy, Debug)] +pub enum TF_Tensor {} + +#[derive(Clone, Copy, Debug)] +pub enum TF_SessionOptions {} + +#[derive(Clone, Copy, Debug)] +pub enum TF_Session {} + +extern "C" { + pub fn TF_NewBufferFromString(proto: *const c_void, proto_len: size_t) -> *mut TF_Buffer; + + pub fn TF_NewBuffer() -> *mut TF_Buffer; + + pub fn TF_DeleteBuffer(buffer: *mut TF_Buffer); + + pub fn TF_GetBuffer(buffer: *mut TF_Buffer) -> TF_Buffer; + + pub fn TF_NewStatus() -> *mut TF_Status; + + pub fn TF_DeleteStatus(status: *mut TF_Status); + + pub fn TF_SetStatus(status: *mut TF_Status, code: TF_Code, msg: *const c_char); + + pub fn TF_GetCode(status: *const TF_Status) -> TF_Code; + + pub fn TF_Message(status: *const TF_Status) -> *const c_char; + + pub fn TF_NewTensor(datatype: TF_DataType, dims: *mut c_longlong, num_dims: c_int, + data: *mut c_void, len: size_t, + deallocator: Option, + deallocator_arg: *mut c_void) -> *mut TF_Tensor; + + pub fn TF_DeleteTensor(tensor: *mut TF_Tensor); + + pub fn TF_TensorType(tensor: *const TF_Tensor) -> TF_DataType; + + pub fn TF_NumDims(tensor: *const TF_Tensor) -> c_int; + + pub fn TF_Dim(tensor: *const TF_Tensor, dim_index: c_int) -> c_longlong; + + pub fn TF_TensorByteSize(tensor: *const TF_Tensor) -> size_t; + + pub fn TF_TensorData(tensor: *const TF_Tensor) -> *mut c_void; + + pub fn TF_NewSessionOptions() -> *mut TF_SessionOptions; + + pub fn TF_SetTarget(options: *mut TF_SessionOptions, target: *const c_char); + + pub fn TF_SetConfig(options: *mut TF_SessionOptions, proto: *const c_void, proto_len: size_t, + status: *mut TF_Status); + + pub fn TF_DeleteSessionOptions(options: *mut TF_SessionOptions); + + pub fn TF_NewSession(options: *const TF_SessionOptions, status: *mut TF_Status) + -> *mut TF_Session; + + pub fn TF_CloseSession(session: *mut TF_Session, status: *mut TF_Status); + + pub fn TF_DeleteSession(session: *mut TF_Session, status: *mut TF_Status); + + pub fn TF_ExtendGraph(session: *mut TF_Session, proto: *const c_void, proto_len: size_t, + status: *mut TF_Status); + + pub fn TF_Run(session: *mut TF_Session, run_options: *const TF_Buffer, + input_names: *mut *const c_char, inputs: *mut *mut TF_Tensor, ninputs: c_int, + output_tensor_names: *mut *const c_char, outputs: *mut *mut TF_Tensor, + noutputs: c_int, target_node_names: *mut *const c_char, ntargets: c_int, + run_metadata: *mut TF_Buffer, status: *mut TF_Status); + + pub fn TF_PRunSetup(session: *mut TF_Session, input_names: *mut *const c_char, ninputs: c_int, + output_tensor_names: *mut *const c_char, noutputs: c_int, + target_node_names: *mut *const c_char, ntargets: c_int, + handle: *mut *mut c_char, status: *mut TF_Status); + + pub fn TF_PRun(session: *mut TF_Session, handle: *const c_char, + input_names: *mut *const c_char, inputs: *mut *mut TF_Tensor, ninputs: c_int, + output_tensor_names: *mut *const c_char, outputs: *mut *mut TF_Tensor, + noutputs: c_int, target_node_names: *mut *const c_char, ntargets: c_int, + status: *mut TF_Status); + + pub fn TF_LoadLibrary(library_filename: *const c_char, status: *mut TF_Status) + -> *mut TF_Library; + + pub fn TF_GetOpList(lib_handle: *mut TF_Library) -> TF_Buffer; +} diff --git a/tensorflow-sys/tests/lib.rs b/tensorflow-sys/tests/lib.rs new file mode 100644 index 0000000000..deb39ba4d3 --- /dev/null +++ b/tensorflow-sys/tests/lib.rs @@ -0,0 +1,10 @@ +extern crate tensorflow_sys as ffi; + +#[test] +fn linkage() { + unsafe { + let buffer = ffi::TF_NewBuffer(); + assert!(!buffer.is_null()); + ffi::TF_DeleteBuffer(buffer); + } +}