Home Manual Reference Source Test

test/unit/model_interface_spec.mjs

import chai from 'chai';
// import sinon from 'sinon';
import sinonChai from 'sinon-chai';
import chaiAsPromised from 'chai-as-promised';
import 'babel-polyfill';
import { TensorScriptModelInterface, } from '../../index.mjs';
const expect = chai.expect;
chai.use(sinonChai);
chai.use(chaiAsPromised);

/** @test {TensorScriptModelInterface} */
describe('TensorScriptModelInterface', () => {
  /** @test {TensorScriptModelInterface#constructor} */
  describe('constructor', () => {
    // The class must be constructible with and without a settings object,
    // and a setting passed to the constructor must be readable from the
    // instance's `settings` property.
    it('should export a named module class', () => {
      const customSettings = { test: 'prop', };
      const defaultModel = new TensorScriptModelInterface();
      const configuredModel = new TensorScriptModelInterface(customSettings);

      expect(TensorScriptModelInterface).to.be.a('function');
      expect(defaultModel).to.be.instanceOf(TensorScriptModelInterface);
      expect(configuredModel.settings.test).to.eql('prop');
    });
  });
  /** @test {TensorScriptModelInterface#reshape} */
  describe('reshape', () => {
    it('should export a static method', () => {
      expect(TensorScriptModelInterface.reshape).to.be.a('function');
    });
    // A flat array of four values reshaped with [2, 2] must yield a 2x2 matrix.
    it('should reshape an array into a matrix', () => {
      const array = [1, 0, 0, 1, ];
      const shape = [2, 2, ];
      const matrix = [
        [1, 0,],
        [0, 1,],
      ];
      const result = TensorScriptModelInterface.reshape(array, shape);
      expect(result).to.eql(matrix);
      // NOTE(review): the incompatible-shape case below is disabled; confirm the
      // expected error message before re-enabling it.
      // expect(TensorScriptModelInterface.reshape.bind(null, array, [1, 2, ])).to.throw(/specified shape/);
    });
    // A flat array of six values reshaped with [2, 3, 1] must yield two rows of
    // three single-element arrays.
    it('should reshape multiple dimensions', () => {
      const array = [ 1, 1, 0, 1, 1, 0, ];
      const shape = [ 2, 3, 1, ];
      const matrix = [
        [ [1], [1], [0], ],
        [ [1], [1], [0], ],
      ];
      const result = TensorScriptModelInterface.reshape(array, shape);
      // Compare against the expected nested structure; without this assertion
      // the test could never fail on wrong output.
      expect(result).to.eql(matrix);
    });
  });
  /** @test {TensorScriptModelInterface#getInputShape} */
  describe('getInputShape', () => {
    it('should export a static method', () => {
      expect(TensorScriptModelInterface.getInputShape).to.be.a('function');
    });
    it('should return the shape of a matrix', () => {
      // 2 rows x 2 columns.
      const matrix = [
        [1, 0,],
        [0, 1,],
      ];
      // 7 rows x 2 columns.
      const matrix2 = [
        [1, 0,],
        [1, 0,],
        [1, 0,],
        [1, 0,],
        [1, 0,],
        [1, 0,],
        [0, 1,],
      ];
      // 5 rows x 4 columns.
      const matrix3 = [
        [1, 0, 4, 5,],
        [1, 0, 4, 5,],
        [1, 0, 4, 5,],
        [1, 0, 4, 5,],
        [1, 0, 4, 5,],
      ];
      // Ragged input: the middle row is shorter than the others.
      const matrix4 = [
        [1, 0, 4, 5,],
        [1, 0, 4,],
        [1, 0, 4, 5,],
      ];
      // 6 rows x 3 columns x 1 value per cell.
      const matrix5 = [
        [[1, ], [1, ], [0, ], ],
        [[1, ], [1, ], [0, ], ],
        [[1, ], [1, ], [0, ], ],
        [[1, ], [1, ], [0, ], ],
        [[1, ], [1, ], [0, ], ],
        [[1, ], [1, ], [0, ], ],
      ];
      expect(TensorScriptModelInterface.getInputShape(matrix)).to.eql([2, 2,]);
      expect(TensorScriptModelInterface.getInputShape(matrix2)).to.eql([7, 2,]);
      expect(TensorScriptModelInterface.getInputShape(matrix3)).to.eql([5, 4,]);
      expect(TensorScriptModelInterface.getInputShape.bind(null, matrix4)).to.throw(/input must have the same length in each row/);
      expect(TensorScriptModelInterface.getInputShape(matrix5)).to.eql([6, 3, 1,]);
    });
    it('should throw an error if input is not a matrix', () => {
      // Bound with no arguments, so the method is invoked without any input.
      expect(TensorScriptModelInterface.getInputShape.bind()).to.throw(/must be a matrix/);
    });
  });
  /** @test {TensorScriptModelInterface#train} */
  describe('train', () => {
    it('should throw an error if train method is not implemented', () => {
      // Minimal subclass that supplies its own train implementation.
      class MLR extends TensorScriptModelInterface {
        train() {
          return true;
        }
      }
      const TSM = new TensorScriptModelInterface();
      const TSMMLR = new MLR();
      expect(TSM.train).to.be.a('function');
      expect(TSM.train.bind(null)).to.throw('train method is not implemented');
      expect(TSMMLR.train).to.be.a('function');
      // Actually invoke the override: asserting that a bound function is "ok"
      // is always true and would never detect a broken subclass method.
      expect(TSMMLR.train()).to.be.true;
    });
  });
  /** @test {TensorScriptModelInterface#calculate} */
  describe('calculate', () => {
    it('should throw an error if calculate method is not implemented', () => {
      // Minimal subclass that supplies its own calculate implementation.
      class MLR extends TensorScriptModelInterface {
        calculate() {
          return true;
        }
      }
      const TSM = new TensorScriptModelInterface();
      const TSMMLR = new MLR();
      expect(TSM.calculate).to.be.a('function');
      expect(TSM.calculate.bind(null)).to.throw('calculate method is not implemented');
      expect(TSMMLR.calculate).to.be.a('function');
      // Actually invoke the override: asserting that a bound function is "ok"
      // is always true and would never detect a broken subclass method.
      expect(TSMMLR.calculate()).to.be.true;
    });
  });
  /** @test {TensorScriptModelInterface#predict} */
  describe('predict', () => {
    // Stub model: calculate ignores its input, sets yShape, and returns an
    // object whose data() resolves to four fixed Float32 predictions.
    class MLR extends TensorScriptModelInterface {
      calculate(x) {
        this.yShape = [100, 2,];
        return {
          data: () => Promise.resolve(new Float32Array([21.41, 31.74, 41.01, 51.53,])),
        };
      }
    }

    /**
     * Runs `run` and returns whatever it throws or rejects with.
     * Returns undefined when `run` completes without an error, so the caller's
     * assertions on the returned value fail instead of being skipped.
     * @param {Function} run - function invoking the code under test
     * @returns {Promise<*>} the captured error, or undefined
     */
    async function captureError(run) {
      try {
        await run();
      } catch (e) {
        return e;
      }
      return undefined;
    }

    it('should throw an error if input is invalid', async function () {
      const TSMMLR = new MLR();
      // Assertions live outside the try/catch so a predict call that resolves
      // (instead of failing) cannot make this test pass silently.
      const missingInputError = await captureError(() => TSMMLR.predict());
      expect(missingInputError).to.be.an('error');
      expect(missingInputError).to.match(/invalid input matrix/);

      const mismatchError = await captureError(() => TSMMLR.predict([1,]));
      expect(mismatchError).to.be.an('error');
      expect(mismatchError).to.match(/Dimension mismatch/);
    });
    it('should return preductions', async function () {
      const TSMMLR = new MLR();
      const input = [
        [1, 2, ],
        [1, 2, ],
      ];
      const predictions = await TSMMLR.predict(input);
      const predictionsRounded = await TSMMLR.predict(input, { probability:false, });
      const predictionsRaw = await TSMMLR.predict(input, { json: false, });
      expect(predictions).to.have.lengthOf(2);
      expect(predictionsRaw).to.be.a('Float32Array');
      // With probability disabled every returned value must be an integer.
      predictionsRounded.forEach(predRow => {
        predRow.forEach(pred => {
          expect(Number.isInteger(pred)).to.be.true;
        });
      });
    });
  });
  /** @test {TensorScriptModelInterface#loadModel} */
  describe('loadModel', () => {
    it('should call tensorflow load model and store it', async function () {
      // Stand-in passed under the `tf` key; its loadModel resolves to `true`.
      const tfStub = {
        loadModel: () => Promise.resolve(true),
      };
      const model = new TensorScriptModelInterface({}, { tf: tfStub, });
      // NOTE(review): only the resolved value is checked here; where the
      // loaded model is stored on the instance is not asserted.
      const result = await model.loadModel();
      expect(result).to.be.true;
    });
  });
});