Skip to content

Commit 0717d2e

Browse files
committed
Use axios in checkpointLoader instead of xhr.
1 parent afe336c commit 0717d2e

File tree

1 file changed

+94
-61
lines changed

1 file changed

+94
-61
lines changed

src/utils/checkpointLoader.js

Lines changed: 94 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -4,86 +4,119 @@
44
// https://opensource.org/licenses/MIT
55

66
import * as tf from '@tensorflow/tfjs';
7+
import axios from 'axios';
78

89
const MANIFEST_FILE = 'manifest.json';
910

11+
/**
12+
* @typedef {Record<string, { filename: string, shape: Array<number> }>} Manifest
13+
*/
14+
/**
15+
* Loads all of the variables of a model from a directory
16+
* which contains a `manifest.json` file and individual variable data files.
17+
* The `manifest.json` contains the `filename` and `shape` for each data file.
18+
*
19+
* @class
20+
* @property {string} urlPath
21+
* @property {Manifest} [checkpointManifest]
22+
* @property {Record<string, tf.Tensor>} variables
23+
*/
1024
export default class CheckpointLoader {
25+
/**
26+
* @param {string} urlPath - the directory URL
27+
*/
1128
constructor(urlPath) {
12-
this.urlPath = urlPath;
13-
if (this.urlPath.charAt(this.urlPath.length - 1) !== '/') {
14-
this.urlPath += '/';
15-
}
29+
this.urlPath = urlPath.endsWith('/') ? urlPath : `${urlPath}/`;
30+
this.variables = {};
1631
}
1732

33+
/**
34+
* @private
35+
* Executes the request to load the manifest.json file.
36+
*
37+
* @return {Promise<Manifest>}
38+
*/
1839
async loadManifest() {
19-
return new Promise((resolve, reject) => {
20-
const xhr = new XMLHttpRequest();
21-
xhr.open('GET', this.urlPath + MANIFEST_FILE);
22-
23-
xhr.onload = () => {
24-
this.checkpointManifest = JSON.parse(xhr.responseText);
25-
resolve();
26-
};
27-
xhr.onerror = (error) => {
28-
reject();
29-
throw new Error(`${MANIFEST_FILE} not found at ${this.urlPath}. ${error}`);
30-
};
31-
xhr.send();
32-
});
40+
try {
41+
const response = await axios.get(this.urlPath + MANIFEST_FILE);
42+
return response.data;
43+
} catch (error) {
44+
throw new Error(`${MANIFEST_FILE} not found at ${this.urlPath}. ${error}`);
45+
}
3346
}
3447

48+
/**
49+
* @private
50+
* Executes the request to load the file for a variable.
51+
*
52+
* @param {string} varName
53+
* @return {Promise<tf.Tensor>}
54+
*/
55+
async loadVariable(varName) {
56+
const manifest = await this.getCheckpointManifest();
57+
if (!(varName in manifest)) {
58+
throw new Error(`Cannot load non-existent variable ${varName}`);
59+
}
60+
const { filename, shape } = manifest[varName];
61+
const url = this.urlPath + filename;
62+
try {
63+
const response = await axios.get(url, { responseType: 'arraybuffer' });
64+
const values = new Float32Array(response.data);
65+
return tf.tensor(values, shape);
66+
} catch (error) {
67+
throw new Error(`Error loading variable ${varName} from URL ${url}: ${error}`);
68+
}
69+
}
3570

71+
/**
72+
* @public
73+
* Lazy-load the contents of the manifest.json file.
74+
*
75+
* @return {Promise<Manifest>}
76+
*/
3677
async getCheckpointManifest() {
37-
if (this.checkpointManifest == null) {
38-
await this.loadManifest();
78+
if (!this.checkpointManifest) {
79+
this.checkpointManifest = await this.loadManifest();
3980
}
4081
return this.checkpointManifest;
4182
}
4283

84+
/**
85+
* @public
86+
* Get the property names for each variable in the manifest.
87+
*
88+
* @return {Promise<string[]>}
89+
*/
90+
async getKeys() {
91+
const manifest = await this.getCheckpointManifest();
92+
return Object.keys(manifest);
93+
}
94+
95+
/**
96+
* @public
97+
* Get a dictionary with the tensors for all variables in the manifest.
98+
*
99+
* @return {Promise<Record<string, tf.Tensor>>}
100+
*/
43101
async getAllVariables() {
44-
if (this.variables != null) {
45-
return Promise.resolve(this.variables);
46-
}
47-
await this.getCheckpointManifest();
48-
const variableNames = Object.keys(this.checkpointManifest);
102+
// Ensure that all keys are loaded and then return the dictionary.
103+
const variableNames = await this.getKeys();
49104
const variablePromises = variableNames.map(v => this.getVariable(v));
50-
return Promise.all(variablePromises).then((variables) => {
51-
this.variables = {};
52-
for (let i = 0; i < variables.length; i += 1) {
53-
this.variables[variableNames[i]] = variables[i];
54-
}
55-
return this.variables;
56-
});
105+
await Promise.all(variablePromises);
106+
return this.variables;
57107
}
58-
getVariable(varName) {
59-
if (!(varName in this.checkpointManifest)) {
60-
throw new Error(`Cannot load non-existent variable ${varName}`);
61-
}
62-
const variableRequestPromiseMethod = (resolve) => {
63-
const xhr = new XMLHttpRequest();
64-
xhr.responseType = 'arraybuffer';
65-
const fname = this.checkpointManifest[varName].filename;
66-
xhr.open('GET', this.urlPath + fname);
67-
xhr.onload = () => {
68-
if (xhr.status === 404) {
69-
throw new Error(`Not found variable ${varName}`);
70-
}
71-
const values = new Float32Array(xhr.response);
72-
const tensor = tf.tensor(values, this.checkpointManifest[varName].shape);
73-
resolve(tensor);
74-
};
75-
xhr.onerror = (error) => {
76-
throw new Error(`Could not fetch variable ${varName}: ${error}`);
77-
};
78-
xhr.send();
79-
};
80-
if (this.checkpointManifest == null) {
81-
return new Promise((resolve) => {
82-
this.loadManifest().then(() => {
83-
new Promise(variableRequestPromiseMethod).then(resolve);
84-
});
85-
});
108+
109+
/**
110+
* @public
111+
* Access a single variable from its key. Will load only if not previously loaded.
112+
*
113+
* @param {string} varName
114+
* @return {Promise<tf.Tensor>}
115+
*/
116+
async getVariable(varName) {
117+
if (!this.variables[varName]) {
118+
this.variables[varName] = await this.loadVariable(varName);
86119
}
87-
return new Promise(variableRequestPromiseMethod);
120+
return this.variables[varName];
88121
}
89122
}

0 commit comments

Comments
 (0)