Skip to content

Commit e2ba43c

Browse files
Stream weights to the GPU when loading a model (#7994)
When downloading model weight data, slice it into weight tensors and push them to the GPU eagerly. This avoids storing an extra copy of the weights on CPU, allowing for larger models (1.3B to possibly ~6.7B or larger) to be loaded without causing a V8 OOM crash. When streaming the weights, check CPU_HANDOFF_SIZE_THRESHOLD or WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD to determine whether the weight should be sent to GPU or remain on CPU. This feature is guarded by the streamWeights option in LoadOptions. Since most of TFJS's graph model saving relies on the CPU copy of the model, model saving is disabled when the model was streamed (i.e. it will throw an error since the weights ArrayBuffer is missing).
1 parent 929b35d commit e2ba43c

11 files changed

+532
-284
lines changed

tfjs-converter/src/executor/graph_model.ts

+27-3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ import {OperationMapper} from '../operations/operation_mapper';
2323

2424
import {GraphExecutor} from './graph_executor';
2525
import {ResourceManager} from './resource_manager';
26+
// tslint:disable-next-line: no-imports-from-dist
27+
import {decodeWeightsStream} from '@tensorflow/tfjs-core/dist/io/io_utils';
2628

2729
export const TFHUB_SEARCH_PARAM = '?tfjs-format=file';
2830
export const DEFAULT_MODEL_NAME = 'model.json';
@@ -154,7 +156,12 @@ export class GraphModel<ModelURL extends Url = string | io.IOHandler> implements
154156

155157
const loadResult = this.handler.load() as ReturnType<IOHandler['load']>;
156158
if (util.isPromise(loadResult)) {
157-
return loadResult.then(artifacts => this.loadSync(artifacts)) as Result;
159+
return loadResult.then(artifacts => {
160+
if (artifacts.getWeightStream == null) {
161+
return this.loadSync(artifacts);
162+
}
163+
return this.loadStreaming(artifacts);
164+
}) as Result;
158165
}
159166

160167
return this.loadSync(loadResult) as Result;
@@ -167,6 +174,25 @@ export class GraphModel<ModelURL extends Url = string | io.IOHandler> implements
167174
* @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
168175
*/
169176
loadSync(artifacts: io.ModelArtifacts) {
177+
const weightMap = this.io.decodeWeights(
178+
artifacts.weightData, artifacts.weightSpecs);
179+
180+
return this.loadWithWeightMap(artifacts, weightMap);
181+
}
182+
183+
private async loadStreaming(artifacts: io.ModelArtifacts): Promise<boolean> {
184+
if (artifacts.getWeightStream == null) {
185+
throw new Error('Model artifacts missing streamWeights function');
186+
}
187+
188+
const weightMap = await decodeWeightsStream(
189+
artifacts.getWeightStream(), artifacts.weightSpecs);
190+
191+
return this.loadWithWeightMap(artifacts, weightMap);
192+
}
193+
194+
private loadWithWeightMap(artifacts: io.ModelArtifacts,
195+
weightMap: NamedTensorMap) {
170196
this.artifacts = artifacts;
171197
const graph = this.artifacts.modelTopology as tensorflow.IGraphDef;
172198

@@ -184,8 +210,6 @@ export class GraphModel<ModelURL extends Url = string | io.IOHandler> implements
184210
this.signature = signature;
185211

186212
this.version = `${graph.versions.producer}.${graph.versions.minConsumer}`;
187-
const weightMap = this.io.decodeWeights(
188-
this.artifacts.weightData, this.artifacts.weightSpecs);
189213
this.executor = new GraphExecutor(
190214
OperationMapper.Instance.transformGraph(graph, this.signature));
191215
this.executor.weightMap = this.convertTensorMapToTensorsMap(weightMap);

tfjs-converter/src/executor/graph_model_test.ts

+36-2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ import {GraphNode} from '../operations/types';
2525
import {GraphModel, loadGraphModel, loadGraphModelSync} from './graph_model';
2626
import {HASH_TABLE_MODEL_V2} from './test_data/hash_table_v2_model_loader';
2727
import {STRUCTURED_OUTPUTS_MODEL} from './test_data/structured_outputs_model_loader';
28+
// tslint:disable-next-line: no-imports-from-dist
29+
import {expectArrayBuffersEqual} from '@tensorflow/tfjs-core/dist/test_util';
2830

2931
const HOST = 'http://example.org';
3032
const MODEL_URL = `${HOST}/model.json`;
@@ -125,6 +127,24 @@ const SIMPLE_HTTP_MODEL_LOADER = {
125127
}
126128
};
127129

130+
const SIMPLE_STREAMING_MODEL_LOADER = {
131+
load: async () => {
132+
return {
133+
modelTopology: SIMPLE_MODEL,
134+
weightSpecs: weightsManifest,
135+
getWeightStream: () => {
136+
const data = bias.dataSync();
137+
const blob = new Blob([data]);
138+
return blob.stream();
139+
},
140+
format: 'tfjs-graph-model',
141+
generatedBy: '1.15',
142+
convertedBy: '1.3.1',
143+
userDefinedMetadata: {signature: SIGNATURE}
144+
};
145+
}
146+
};
147+
128148
const NO_INPUT_SIGNATURE_MODEL_LOADER = {
129149
load: async () => {
130150
return {
@@ -438,7 +458,7 @@ describe('loadGraphModel', () => {
438458
});
439459

440460
it('Pass a fetchFunc', async () => {
441-
const fetchFunc = () => {};
461+
const fetchFunc = (() => {}) as unknown as typeof fetch;
442462
spyIo.getLoadHandlers.and.returnValue([CUSTOM_HTTP_MODEL_LOADER]);
443463
await loadGraphModel(MODEL_URL, {fetchFunc}, spyIo);
444464
expect(spyIo.getLoadHandlers).toHaveBeenCalledWith(MODEL_URL, {fetchFunc});
@@ -594,7 +614,13 @@ describe('Model', () => {
594614

595615
describe('simple model', () => {
596616
beforeEach(() => {
597-
spyIo.getLoadHandlers.and.returnValue([SIMPLE_HTTP_MODEL_LOADER]);
617+
spyIo.getLoadHandlers.and.callFake((_url: string|string[],
618+
loadOptions?: io.LoadOptions) => {
619+
if (loadOptions.streamWeights) {
620+
return [SIMPLE_STREAMING_MODEL_LOADER];
621+
}
622+
return [SIMPLE_HTTP_MODEL_LOADER];
623+
});
598624
spyIo.browserHTTPRequest.and.returnValue(SIMPLE_HTTP_MODEL_LOADER);
599625
});
600626
it('load', async () => {
@@ -776,6 +802,14 @@ describe('Model', () => {
776802
expect(model).toBeDefined();
777803
});
778804

805+
it('should stream graph model weights', async () => {
806+
const model = await loadGraphModel(MODEL_URL, {streamWeights: true},
807+
spyIo);
808+
expect(model).toBeDefined();
809+
expectArrayBuffersEqual(model.weights['Const'][0].dataSync(),
810+
bias.dataSync());
811+
});
812+
779813
describe('InferenceModel interface', () => {
780814
it('should expose inputs', async () => {
781815
await model.load();

tfjs-core/src/io/http.ts

+45-23
Original file line numberDiff line numberDiff line change
@@ -27,31 +27,30 @@ import {assert} from '../util';
2727
import {getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts, getWeightSpecs} from './io_utils';
2828
import {CompositeArrayBuffer} from './composite_array_buffer';
2929
import {IORouter, IORouterRegistry} from './router_registry';
30-
import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types';
31-
import {loadWeightsAsArrayBuffer} from './weights_loader';
30+
import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types';
31+
import {loadWeightsAsArrayBuffer, streamWeights} from './weights_loader';
3232

3333
const OCTET_STREAM_MIME_TYPE = 'application/octet-stream';
3434
const JSON_TYPE = 'application/json';
3535
export class HTTPRequest implements IOHandler {
3636
protected readonly path: string;
3737
protected readonly requestInit: RequestInit;
3838

39-
private readonly fetch: Function;
39+
private readonly fetch: typeof fetch;
4040
private readonly weightUrlConverter: (weightName: string) => Promise<string>;
4141

4242
readonly DEFAULT_METHOD = 'POST';
4343

4444
static readonly URL_SCHEME_REGEX = /^https?:\/\//;
4545

4646
private readonly weightPathPrefix: string;
47-
private readonly onProgress: OnProgressCallback;
47+
private readonly loadOptions: LoadOptions;
4848

4949
constructor(path: string, loadOptions?: LoadOptions) {
5050
if (loadOptions == null) {
5151
loadOptions = {};
5252
}
5353
this.weightPathPrefix = loadOptions.weightPathPrefix;
54-
this.onProgress = loadOptions.onProgress;
5554
this.weightUrlConverter = loadOptions.weightUrlConverter;
5655

5756
if (loadOptions.fetchFunc != null) {
@@ -84,6 +83,7 @@ export class HTTPRequest implements IOHandler {
8483
'requestInit is expected to have no pre-existing body, but has one.');
8584
}
8685
this.requestInit = loadOptions.requestInit || {};
86+
this.loadOptions = loadOptions;
8787
}
8888

8989
async save(modelArtifacts: ModelArtifacts): Promise<SaveResult> {
@@ -135,15 +135,7 @@ export class HTTPRequest implements IOHandler {
135135
}
136136
}
137137

138-
/**
139-
* Load model artifacts via HTTP request(s).
140-
*
141-
* See the documentation to `tf.io.http` for details on the saved
142-
* artifacts.
143-
*
144-
* @returns The loaded model artifacts (if loading succeeds).
145-
*/
146-
async load(): Promise<ModelArtifacts> {
138+
private async loadModelJSON(): Promise<ModelJSON> {
147139
const modelConfigRequest = await this.fetch(this.path, this.requestInit);
148140

149141
if (!modelConfigRequest.ok) {
@@ -182,18 +174,45 @@ export class HTTPRequest implements IOHandler {
182174
`topology or manifest for weights.`);
183175
}
184176

177+
return modelJSON;
178+
}
179+
180+
/**
181+
* Load model artifacts via HTTP request(s).
182+
*
183+
* See the documentation to `tf.io.http` for details on the saved
184+
* artifacts.
185+
*
186+
* @returns The loaded model artifacts (if loading succeeds).
187+
*/
188+
async load(): Promise<ModelArtifacts> {
189+
if (this.loadOptions.streamWeights) {
190+
return this.loadStream();
191+
}
192+
const modelJSON = await this.loadModelJSON();
185193
return getModelArtifactsForJSON(
186194
modelJSON, (weightsManifest) => this.loadWeights(weightsManifest));
187195
}
188196

189-
private async loadWeights(weightsManifest: WeightsManifestConfig):
190-
Promise<[WeightsManifestEntry[], WeightData]> {
197+
private async loadStream(): Promise<ModelArtifacts> {
198+
const modelJSON = await this.loadModelJSON();
199+
const fetchURLs = await this.getWeightUrls(modelJSON.weightsManifest);
200+
const weightSpecs = getWeightSpecs(modelJSON.weightsManifest);
201+
const stream = () => streamWeights(fetchURLs, this.loadOptions);
202+
203+
return {
204+
...modelJSON,
205+
weightSpecs,
206+
getWeightStream: stream,
207+
};
208+
}
209+
210+
private async getWeightUrls(weightsManifest: WeightsManifestConfig):
211+
Promise<string[]> {
191212
const weightPath = Array.isArray(this.path) ? this.path[1] : this.path;
192213
const [prefix, suffix] = parseUrl(weightPath);
193214
const pathPrefix = this.weightPathPrefix || prefix;
194215

195-
const weightSpecs = getWeightSpecs(weightsManifest);
196-
197216
const fetchURLs: string[] = [];
198217
const urlPromises: Array<Promise<string>> = [];
199218
for (const weightsGroup of weightsManifest) {
@@ -209,12 +228,15 @@ export class HTTPRequest implements IOHandler {
209228
if (this.weightUrlConverter) {
210229
fetchURLs.push(...await Promise.all(urlPromises));
211230
}
231+
return fetchURLs;
232+
}
233+
234+
private async loadWeights(weightsManifest: WeightsManifestConfig):
235+
Promise<[WeightsManifestEntry[], WeightData]> {
236+
const fetchURLs = await this.getWeightUrls(weightsManifest);
237+
const weightSpecs = getWeightSpecs(weightsManifest);
212238

213-
const buffers = await loadWeightsAsArrayBuffer(fetchURLs, {
214-
requestInit: this.requestInit,
215-
fetchFunc: this.fetch,
216-
onProgress: this.onProgress
217-
});
239+
const buffers = await loadWeightsAsArrayBuffer(fetchURLs, this.loadOptions);
218240
return [weightSpecs, buffers];
219241
}
220242
}

tfjs-core/src/io/io.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import './local_storage';
2222

2323
import {browserFiles} from './browser_files';
2424
import {browserHTTPRequest, http, isHTTPScheme} from './http';
25-
import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsForJSON, getModelArtifactsForJSONSync, getModelArtifactsInfoForJSON, getWeightSpecs} from './io_utils';
25+
import {concatenateArrayBuffers, decodeWeights, decodeWeightsStream, encodeWeights, getModelArtifactsForJSON, getModelArtifactsForJSONSync, getModelArtifactsInfoForJSON, getWeightSpecs} from './io_utils';
2626
import {fromMemory, fromMemorySync, withSaveHandler, withSaveHandlerSync} from './passthrough';
2727
import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry';
2828
import {IOHandler, IOHandlerSync, LoadHandler, LoadOptions, ModelArtifacts, ModelArtifactsInfo, ModelJSON, ModelStoreManager, OnProgressCallback, RequestDetails, SaveConfig, SaveHandler, SaveResult, TrainingConfig, WeightGroup, WeightsManifestConfig, WeightsManifestEntry, WeightData} from './types';
@@ -36,6 +36,7 @@ export {
3636
CompositeArrayBuffer,
3737
concatenateArrayBuffers,
3838
decodeWeights,
39+
decodeWeightsStream,
3940
encodeWeights,
4041
fromMemory,
4142
fromMemorySync,

0 commit comments

Comments
 (0)