@@ -66,8 +66,10 @@ mod layer_spec {
         let loaded_weights = loaded_layer.learnable_weights_data();
         let loaded_weight_lock = loaded_weights[0].read().unwrap();

-        let original_weight = original_weight_lock.get(native_backend().device()).unwrap().as_native().unwrap().as_slice::<f32>();
-        let loaded_weight = loaded_weight_lock.get(native_backend().device()).unwrap().as_native().unwrap().as_slice::<f32>();
+        let original_weight = original_weight_lock.read(native_backend().device())
+            .unwrap().as_native().unwrap().as_slice::<f32>();
+        let loaded_weight = loaded_weight_lock.read(native_backend().device())
+            .unwrap().as_native().unwrap().as_slice::<f32>();

         assert_eq!(original_weight, loaded_weight);
     }
@@ -131,27 +133,28 @@ mod layer_spec {
         let mut reshape_network = Layer::from_config(cuda_backend.clone(), &LayerConfig::new("reshape_model", LayerType::Sequential(reshape_model)));

         let input = vec![1f32, 1f32, 2f32];
-        let mut normal_tensor = SharedTensor::<f32>::new(native_backend.device(), &(3)).unwrap();
+        let mut normal_tensor = SharedTensor::<f32>::new(&[3]);
         // let mut normal_tensor_output = SharedTensor::<f32>::new(native_backend.device(), &(3)).unwrap();
-        let mut reshape_tensor = SharedTensor::<f32>::new(native_backend.device(), &(3)).unwrap();
+        let mut reshape_tensor = SharedTensor::<f32>::new(&[3]);
         // let mut reshape_tensor_output = SharedTensor::<f32>::new(native_backend.device(), &(3)).unwrap();
-        write_to_memory(normal_tensor.get_mut(native_backend.device()).unwrap(), &input);
-        write_to_memory(reshape_tensor.get_mut(native_backend.device()).unwrap(), &input);
+        write_to_memory(normal_tensor.write_only(native_backend.device()).unwrap(), &input);
+        write_to_memory(reshape_tensor.write_only(native_backend.device()).unwrap(), &input);

         let normal_tensor_output = normal_network.forward(&[Arc::new(RwLock::new(normal_tensor))])[0].clone();
-        let _ = normal_tensor_output.write().unwrap().add_device(native_backend.device());
-        normal_tensor_output.write().unwrap().sync(native_backend.device()).unwrap();
         let normal_tensor_output_native_ = normal_tensor_output.read().unwrap();
-        let normal_tensor_output_native = normal_tensor_output_native_.get(native_backend.device()).unwrap().as_native().unwrap();
-        assert_eq!(&[0.7310585786f32, 0.7310586f32, 0.880797f32], normal_tensor_output_native.as_slice::<f32>());
+        let normal_tensor_output_native = normal_tensor_output_native_
+            .read(native_backend.device()).unwrap().as_native().unwrap();
+        assert_eq!(&[0.7310585786f32, 0.7310586f32, 0.880797f32],
+                   normal_tensor_output_native.as_slice::<f32>());

         let reshape_tensor_output = reshape_network.forward(&[Arc::new(RwLock::new(reshape_tensor))])[0].clone();
-        let _ = reshape_tensor_output.write().unwrap().add_device(native_backend.device());
-        reshape_tensor_output.write().unwrap().sync(native_backend.device()).unwrap();
         let reshape_tensor_output_native_ = reshape_tensor_output.read().unwrap();
-        let reshape_tensor_output_native = reshape_tensor_output_native_.get(native_backend.device()).unwrap().as_native().unwrap();
-        assert_eq!(&[0.7310585786f32, 0.7310586f32, 0.880797f32], reshape_tensor_output_native.as_slice::<f32>());
-        assert_eq!(normal_tensor_output_native.as_slice::<f32>(), reshape_tensor_output_native.as_slice::<f32>());
+        let reshape_tensor_output_native = reshape_tensor_output_native_
+            .read(native_backend.device()).unwrap().as_native().unwrap();
+        assert_eq!(&[0.7310585786f32, 0.7310586f32, 0.880797f32],
+                   reshape_tensor_output_native.as_slice::<f32>());
+        assert_eq!(normal_tensor_output_native.as_slice::<f32>(),
+                   reshape_tensor_output_native.as_slice::<f32>());
     }
 }