@@ -27,31 +27,30 @@ import {assert} from '../util';
27
27
import { getModelArtifactsForJSON , getModelArtifactsInfoForJSON , getModelJSONForModelArtifacts , getWeightSpecs } from './io_utils' ;
28
28
import { CompositeArrayBuffer } from './composite_array_buffer' ;
29
29
import { IORouter , IORouterRegistry } from './router_registry' ;
30
- import { IOHandler , LoadOptions , ModelArtifacts , ModelJSON , OnProgressCallback , SaveResult , WeightData , WeightsManifestConfig , WeightsManifestEntry } from './types' ;
31
- import { loadWeightsAsArrayBuffer } from './weights_loader' ;
30
+ import { IOHandler , LoadOptions , ModelArtifacts , ModelJSON , SaveResult , WeightData , WeightsManifestConfig , WeightsManifestEntry } from './types' ;
31
+ import { loadWeightsAsArrayBuffer , streamWeights } from './weights_loader' ;
32
32
33
33
const OCTET_STREAM_MIME_TYPE = 'application/octet-stream' ;
34
34
const JSON_TYPE = 'application/json' ;
35
35
export class HTTPRequest implements IOHandler {
36
36
protected readonly path : string ;
37
37
protected readonly requestInit : RequestInit ;
38
38
39
- private readonly fetch : Function ;
39
+ private readonly fetch : typeof fetch ;
40
40
private readonly weightUrlConverter : ( weightName : string ) => Promise < string > ;
41
41
42
42
readonly DEFAULT_METHOD = 'POST' ;
43
43
44
44
static readonly URL_SCHEME_REGEX = / ^ h t t p s ? : \/ \/ / ;
45
45
46
46
private readonly weightPathPrefix : string ;
47
- private readonly onProgress : OnProgressCallback ;
47
+ private readonly loadOptions : LoadOptions ;
48
48
49
49
constructor ( path : string , loadOptions ?: LoadOptions ) {
50
50
if ( loadOptions == null ) {
51
51
loadOptions = { } ;
52
52
}
53
53
this . weightPathPrefix = loadOptions . weightPathPrefix ;
54
- this . onProgress = loadOptions . onProgress ;
55
54
this . weightUrlConverter = loadOptions . weightUrlConverter ;
56
55
57
56
if ( loadOptions . fetchFunc != null ) {
@@ -84,6 +83,7 @@ export class HTTPRequest implements IOHandler {
84
83
'requestInit is expected to have no pre-existing body, but has one.' ) ;
85
84
}
86
85
this . requestInit = loadOptions . requestInit || { } ;
86
+ this . loadOptions = loadOptions ;
87
87
}
88
88
89
89
async save ( modelArtifacts : ModelArtifacts ) : Promise < SaveResult > {
@@ -135,15 +135,7 @@ export class HTTPRequest implements IOHandler {
135
135
}
136
136
}
137
137
138
- /**
139
- * Load model artifacts via HTTP request(s).
140
- *
141
- * See the documentation to `tf.io.http` for details on the saved
142
- * artifacts.
143
- *
144
- * @returns The loaded model artifacts (if loading succeeds).
145
- */
146
- async load ( ) : Promise < ModelArtifacts > {
138
+ private async loadModelJSON ( ) : Promise < ModelJSON > {
147
139
const modelConfigRequest = await this . fetch ( this . path , this . requestInit ) ;
148
140
149
141
if ( ! modelConfigRequest . ok ) {
@@ -182,18 +174,45 @@ export class HTTPRequest implements IOHandler {
182
174
`topology or manifest for weights.` ) ;
183
175
}
184
176
177
+ return modelJSON ;
178
+ }
179
+
180
+ /**
181
+ * Load model artifacts via HTTP request(s).
182
+ *
183
+ * See the documentation to `tf.io.http` for details on the saved
184
+ * artifacts.
185
+ *
186
+ * @returns The loaded model artifacts (if loading succeeds).
187
+ */
188
+ async load ( ) : Promise < ModelArtifacts > {
189
+ if ( this . loadOptions . streamWeights ) {
190
+ return this . loadStream ( ) ;
191
+ }
192
+ const modelJSON = await this . loadModelJSON ( ) ;
185
193
return getModelArtifactsForJSON (
186
194
modelJSON , ( weightsManifest ) => this . loadWeights ( weightsManifest ) ) ;
187
195
}
188
196
189
- private async loadWeights ( weightsManifest : WeightsManifestConfig ) :
190
- Promise < [ WeightsManifestEntry [ ] , WeightData ] > {
197
+ private async loadStream ( ) : Promise < ModelArtifacts > {
198
+ const modelJSON = await this . loadModelJSON ( ) ;
199
+ const fetchURLs = await this . getWeightUrls ( modelJSON . weightsManifest ) ;
200
+ const weightSpecs = getWeightSpecs ( modelJSON . weightsManifest ) ;
201
+ const stream = ( ) => streamWeights ( fetchURLs , this . loadOptions ) ;
202
+
203
+ return {
204
+ ...modelJSON ,
205
+ weightSpecs,
206
+ getWeightStream : stream ,
207
+ } ;
208
+ }
209
+
210
+ private async getWeightUrls ( weightsManifest : WeightsManifestConfig ) :
211
+ Promise < string [ ] > {
191
212
const weightPath = Array . isArray ( this . path ) ? this . path [ 1 ] : this . path ;
192
213
const [ prefix , suffix ] = parseUrl ( weightPath ) ;
193
214
const pathPrefix = this . weightPathPrefix || prefix ;
194
215
195
- const weightSpecs = getWeightSpecs ( weightsManifest ) ;
196
-
197
216
const fetchURLs : string [ ] = [ ] ;
198
217
const urlPromises : Array < Promise < string > > = [ ] ;
199
218
for ( const weightsGroup of weightsManifest ) {
@@ -209,12 +228,15 @@ export class HTTPRequest implements IOHandler {
209
228
if ( this . weightUrlConverter ) {
210
229
fetchURLs . push ( ...await Promise . all ( urlPromises ) ) ;
211
230
}
231
+ return fetchURLs ;
232
+ }
233
+
234
+ private async loadWeights ( weightsManifest : WeightsManifestConfig ) :
235
+ Promise < [ WeightsManifestEntry [ ] , WeightData ] > {
236
+ const fetchURLs = await this . getWeightUrls ( weightsManifest ) ;
237
+ const weightSpecs = getWeightSpecs ( weightsManifest ) ;
212
238
213
- const buffers = await loadWeightsAsArrayBuffer ( fetchURLs , {
214
- requestInit : this . requestInit ,
215
- fetchFunc : this . fetch ,
216
- onProgress : this . onProgress
217
- } ) ;
239
+ const buffers = await loadWeightsAsArrayBuffer ( fetchURLs , this . loadOptions ) ;
218
240
return [ weightSpecs , buffers ] ;
219
241
}
220
242
}
0 commit comments