@@ -10,7 +10,7 @@ use std::fmt;
use std::cmp;
use std::collections::{HashMap, HashSet};
use std::fs::File;
- use std::io;
+ use std::io::{self, BufReader};
use std::path::Path;
use std::rc::Rc;
use std::sync::{Arc, RwLock};
@@ -455,7 +455,7 @@ impl<B: IBackend> Layer<B> {
            // reshape input tensor to the reshaped shape
            let old_shape = self.input_blobs_data[input_i].read().unwrap().desc().clone();
            if old_shape.size() != reshaped_shape.size() {
-                 panic!("The provided input does not have the expected shape");
+                 panic!("The provided input does not have the expected shape of {:?}", reshaped_shape);
            }
            self.input_blobs_data[input_i].write().unwrap().reshape(&reshaped_shape).unwrap();
        }
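Including the expected shape in the panic message makes shape mismatches much easier to diagnose. A minimal illustration of the resulting failure mode, with a hypothetical `[2, 3]` shape standing in for `reshaped_shape` (the values here are illustrative, not from the crate):

```rust
fn main() {
    // Hypothetical expected shape, standing in for `reshaped_shape` above.
    let reshaped_shape = vec![2usize, 3];
    let provided_size = 5; // deliberately wrong total element count

    if provided_size != reshaped_shape.iter().product::<usize>() {
        // Panics with: "The provided input does not have the expected shape of [2, 3]"
        panic!("The provided input does not have the expected shape of {:?}", reshaped_shape);
    }
}
```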
@@ -583,6 +583,39 @@ impl<B: IBackend> Layer<B> {
    /// Serialize the Layer and its weights to a Cap'n Proto file at the specified path.
    ///
    /// You can find the capnp schema [here](../../../../capnp/leaf.capnp).
+     ///
+     /// ```
+     /// # #[cfg(feature = "native")]
+     /// # mod native {
+     /// # use std::rc::Rc;
+     /// # use leaf::layer::*;
+     /// # use leaf::layers::*;
+     /// # use leaf::util;
+     /// # pub fn test() {
+     /// #
+     /// let mut net_cfg = SequentialConfig::default();
+     /// // ... set up network ...
+     /// let cfg = LayerConfig::new("network", net_cfg);
+     ///
+     /// let native_backend = Rc::new(util::native_backend());
+     /// let mut layer = Layer::from_config(native_backend, &cfg);
+     /// // ... do stuff with the layer ...
+     /// // ... and save it
+     /// layer.save("mynetwork").unwrap();
+     /// #
+     /// # }}
+     /// #
+     /// # #[cfg(not(feature = "native"))]
+     /// # mod native {
+     /// # pub fn test() {}
+     /// # }
+     /// #
+     /// # fn main() {
+     /// # if cfg!(feature = "native") {
+     /// # ::native::test();
+     /// # }
+     /// # }
+     /// ```
    pub fn save<P: AsRef<Path>>(&mut self, path: P) -> io::Result<()> {
        let path = path.as_ref();
        let ref mut out = try!(File::create(path));
@@ -597,6 +630,92 @@ impl<B: IBackend> Layer<B> {
        Ok(())
    }

+     /// Read a Cap'n Proto file at the specified path and deserialize the Layer inside it.
+     ///
+     /// You can find the capnp schema [here](../../../../capnp/leaf.capnp).
+     ///
+     /// ```
+     /// # extern crate leaf;
+     /// # extern crate collenchyma;
+     /// # #[cfg(feature = "native")]
+     /// # mod native {
+     /// # use std::rc::Rc;
+     /// # use leaf::layer::*;
+     /// # use leaf::layers::*;
+     /// # use leaf::util;
+     /// use collenchyma::prelude::*;
+     /// # pub fn test() {
+     ///
+     /// let native_backend = Rc::new(util::native_backend());
+     /// # let mut net_cfg = SequentialConfig::default();
+     /// # let cfg = LayerConfig::new("network", net_cfg);
+     /// # let mut layer = Layer::from_config(native_backend.clone(), &cfg);
+     /// # layer.save("mynetwork").unwrap();
+     /// // Load layer from file "mynetwork"
+     /// let layer = Layer::<Backend<Native>>::load(native_backend, "mynetwork").unwrap();
+     /// #
+     /// # }}
+     /// #
+     /// # #[cfg(not(feature = "native"))]
+     /// # mod native {
+     /// # pub fn test() {}
+     /// # }
+     /// #
+     /// # fn main() {
+     /// # if cfg!(feature = "native") {
+     /// # ::native::test();
+     /// # }
+     /// # }
+     /// ```
+     pub fn load<LB: IBackend + LayerOps<f32> + 'static, P: AsRef<Path>>(backend: Rc<LB>, path: P) -> io::Result<Layer<LB>> {
+         let path = path.as_ref();
+         let ref mut file = try!(File::open(path));
+         let mut reader = BufReader::new(file);
+
+         let message_reader = ::capnp::serialize_packed::read_message(&mut reader,
+                                                                      ::capnp::message::ReaderOptions::new()).unwrap();
+         let read_layer = message_reader.get_root::<capnp_layer::Reader>().unwrap();
+
+         let name = read_layer.get_name().unwrap().to_owned();
+         let layer_config = LayerConfig::read_capnp(read_layer.get_config().unwrap());
+         let mut layer = Layer::from_config(backend, &layer_config);
+         layer.name = name;
+
+         let read_weights = read_layer.get_weights_data().unwrap();
+
+         let names = layer.learnable_weights_names();
+         let weights_data = layer.learnable_weights_data();
+
+         let native_backend = Backend::<Native>::default().unwrap();
+         for (name, weight) in names.iter().zip(weights_data) {
+             // Scan the serialized weights for the one with a matching name.
+             for j in 0..read_weights.len() {
+                 let capnp_weight = read_weights.get(j);
+                 if capnp_weight.get_name().unwrap() != name {
+                     continue
+                 }
+
+                 let mut weight_lock = weight.write().unwrap();
+                 weight_lock.sync(native_backend.device()).unwrap();
+
+                 let capnp_tensor = capnp_weight.get_tensor().unwrap();
+                 let mut shape = Vec::new();
+                 let capnp_shape = capnp_tensor.get_shape().unwrap();
+                 for k in 0..capnp_shape.len() {
+                     shape.push(capnp_shape.get(k) as usize)
+                 }
+                 weight_lock.reshape(&shape).unwrap();
+
+                 let native_slice = weight_lock.get_mut(native_backend.device()).unwrap().as_mut_native().unwrap().as_mut_slice::<f32>();
+                 let data = capnp_tensor.get_data().unwrap();
+                 for k in 0..data.len() {
+                     native_slice[k as usize] = data.get(k);
+                 }
+             }
+         }
+
+         Ok(layer)
+     }
+
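`Layer::load` rebuilds the layer from its serialized `LayerConfig`, then restores weights by matching each serialized weight to the live layer's weights by name, reshaping each tensor to its stored shape, and copying the raw values back in through the native backend. A minimal standalone sketch of that matching strategy, using plain vectors in place of the crate's tensor types (all names and values here are illustrative):

```rust
// Illustrative stand-ins for the serialized and in-memory weight records.
struct SavedWeight { name: String, shape: Vec<usize>, data: Vec<f32> }
struct LiveWeight { name: String, data: Vec<f32> }

fn restore(saved: &[SavedWeight], live: &mut [LiveWeight]) {
    for weight in live.iter_mut() {
        // Match by name, as Layer::load does with the capnp weight list.
        if let Some(s) = saved.iter().find(|s| s.name == weight.name) {
            // Resize to the stored shape's element count, then copy the values.
            // (Assumes the saved data actually matches the saved shape.)
            weight.data.resize(s.shape.iter().product(), 0.0);
            weight.data.copy_from_slice(&s.data);
        }
    }
}

fn main() {
    let saved = vec![SavedWeight { name: "linear_weights".into(), shape: vec![2, 2], data: vec![1.0, 2.0, 3.0, 4.0] }];
    let mut live = vec![LiveWeight { name: "linear_weights".into(), data: vec![0.0; 4] }];
    restore(&saved, &mut live);
    assert_eq!(live[0].data, vec![1.0, 2.0, 3.0, 4.0]);
}
```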
    /// Sets whether the layer should compute gradients w.r.t. a
    /// weight at a particular index given by `weight_id`.
    ///
@@ -672,6 +791,9 @@ impl<B: IBackend> Layer<B> {
    }
}

+ #[allow(unsafe_code)]
+ unsafe impl<B: IBackend> Send for Layer<B> {}
+
impl<'a, B: IBackend> CapnpWrite<'a> for Layer<B> {
    type Builder = capnp_layer::Builder<'a>;
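The new `unsafe impl Send` asserts that a `Layer<B>` may be moved across threads even though some of its fields are not automatically `Send`; the `#[allow(unsafe_code)]` silences the crate's lint for it. A hedged sketch of what this unlocks, assuming `Layer` and `IBackend` are in scope as they are in `layer.rs` (the helper function is hypothetical, not from the diff):

```rust
use std::thread;

// Because Layer<B> is now Send, a constructed layer can be handed to a
// worker thread, e.g. to run training off the main thread.
fn train_in_background<B: IBackend + 'static>(layer: Layer<B>) -> thread::JoinHandle<Layer<B>> {
    thread::spawn(move || {
        // ... run forward/backward passes here ...
        layer
    })
}
```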
@@ -1269,6 +1391,31 @@ impl<'a> CapnpWrite<'a> for LayerType {
    }
}

+ impl<'a> CapnpRead<'a> for LayerType {
+     type Reader = capnp_layer_type::Reader<'a>;
+
+     fn read_capnp(reader: Self::Reader) -> Self {
+         match reader.which().unwrap() {
+             #[cfg(all(feature = "cuda", not(feature = "native")))]
+             capnp_layer_type::Which::Convolution(read_config) => { let config = ConvolutionConfig::read_capnp(read_config.unwrap()); LayerType::Convolution(config) },
+             #[cfg(not(all(feature = "cuda", not(feature = "native"))))]
+             capnp_layer_type::Which::Convolution(_) => { panic!("Cannot load the network because the Convolution layer is not supported with the enabled feature flags.") },
+             capnp_layer_type::Which::Linear(read_config) => { let config = LinearConfig::read_capnp(read_config.unwrap()); LayerType::Linear(config) },
+             capnp_layer_type::Which::LogSoftmax(_) => { LayerType::LogSoftmax },
+             #[cfg(all(feature = "cuda", not(feature = "native")))]
+             capnp_layer_type::Which::Pooling(read_config) => { let config = PoolingConfig::read_capnp(read_config.unwrap()); LayerType::Pooling(config) },
+             #[cfg(not(all(feature = "cuda", not(feature = "native"))))]
+             capnp_layer_type::Which::Pooling(_) => { panic!("Cannot load the network because the Pooling layer is not supported with the enabled feature flags.") },
+             capnp_layer_type::Which::Sequential(read_config) => { let config = SequentialConfig::read_capnp(read_config.unwrap()); LayerType::Sequential(config) },
+             capnp_layer_type::Which::Softmax(_) => { LayerType::Softmax },
+             capnp_layer_type::Which::Relu(_) => { LayerType::ReLU },
+             capnp_layer_type::Which::Sigmoid(_) => { LayerType::Sigmoid },
+             capnp_layer_type::Which::NegativeLogLikelihood(read_config) => { let config = NegativeLogLikelihoodConfig::read_capnp(read_config.unwrap()); LayerType::NegativeLogLikelihood(config) },
+             capnp_layer_type::Which::Reshape(read_config) => { let config = ReshapeConfig::read_capnp(read_config.unwrap()); LayerType::Reshape(config) },
+         }
+     }
+ }
+
impl LayerConfig {
    /// Creates a new LayerConfig
    pub fn new<L: Into<LayerType>>(name: &str, layer_type: L) -> LayerConfig {
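For orientation: the `CapnpRead` trait that these impls target is not shown in the diff. Inferred from its use sites, the trait pair is presumably shaped roughly like this (a sketch reconstructed from the impls above, not the crate's exact definition):

```rust
// Presumed shape of the serialization trait pair, inferred from this diff.
pub trait CapnpWrite<'a> {
    type Builder;
    // Serialize self into a Cap'n Proto builder.
    fn write_capnp(&self, builder: &mut Self::Builder);
}

pub trait CapnpRead<'a> {
    type Reader;
    // Deserialize a value from a Cap'n Proto reader.
    fn read_capnp(reader: Self::Reader) -> Self;
}
```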
@@ -1338,9 +1485,13 @@ impl LayerConfig {
            Err("propagate_down config must be specified either 0 or inputs_len times")
        }
    }
+ }
+
+ impl<'a> CapnpWrite<'a> for LayerConfig {
+     type Builder = capnp_layer_config::Builder<'a>;

    /// Write the LayerConfig into a capnp message.
-     pub fn write_capnp(&self, builder: &mut capnp_layer_config::Builder) {
+     fn write_capnp(&self, builder: &mut Self::Builder) {
        builder.set_name(&self.name);
        {
            let mut layer_type = builder.borrow().init_layer_type();
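Moving `write_capnp` out of the inherent `impl LayerConfig` and into the `CapnpWrite` trait means serialization code no longer needs to name the concrete config type. A small hypothetical helper illustrating the benefit (the function is ours, not the crate's):

```rust
// Hypothetical generic helper: works for Layer, LayerType, LayerConfig, or
// any other type that implements CapnpWrite.
fn write_into<'a, T: CapnpWrite<'a>>(value: &T, builder: &mut T::Builder) {
    value.write_capnp(builder);
}
```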
@@ -1373,3 +1524,44 @@ impl LayerConfig {
        }
    }
}
+
+ impl<'a> CapnpRead<'a> for LayerConfig {
+     type Reader = capnp_layer_config::Reader<'a>;
+
+     fn read_capnp(reader: Self::Reader) -> Self {
+         let name = reader.get_name().unwrap().to_owned();
+         let layer_type = LayerType::read_capnp(reader.get_layer_type());
+
+         let read_outputs = reader.get_outputs().unwrap();
+         let mut outputs = Vec::new();
+         for i in 0..read_outputs.len() {
+             outputs.push(read_outputs.get(i).unwrap().to_owned())
+         }
+         let read_inputs = reader.get_inputs().unwrap();
+         let mut inputs = Vec::new();
+         for i in 0..read_inputs.len() {
+             inputs.push(read_inputs.get(i).unwrap().to_owned())
+         }
+
+         let read_params = reader.get_params().unwrap();
+         let mut params = Vec::new();
+         for i in 0..read_params.len() {
+             params.push(WeightConfig::read_capnp(read_params.get(i)))
+         }
+
+         let read_propagate_down = reader.get_propagate_down().unwrap();
+         let mut propagate_down = Vec::new();
+         for i in 0..read_propagate_down.len() {
+             propagate_down.push(read_propagate_down.get(i))
+         }
+
+         LayerConfig {
+             name: name,
+             layer_type: layer_type,
+             outputs: outputs,
+             inputs: inputs,
+             params: params,
+             propagate_down: propagate_down,
+         }
+     }
+ }