JSONStackModel

DeepLearningClassification

      
// Feature columns read from the iris CSV fixture.
const independentVariables = [
  'sepal_length_cm',
  'sepal_width_cm',
  'petal_length_cm',
  'petal_width_cm',
];

// Label columns: one `plant_<species>` column per iris species, i.e. the
// column names that one-hot encoding the `plant` column yields.
const dependentVariables = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
  .map((species) => `plant_${species}`);

// Every feature column followed by every label column (a new array).
const columns = [...independentVariables, ...dependentVariables];

// Shared state, assigned by the suite's setup hook.
let housingDataCSV;
let DataSet;
let x_matrix;
let y_matrix;
let nnClassification;
let nnClassificationModel;

// Training options passed to the classifier constructor.
const fit = { epochs: 100, batchSize: 5 };

// One-hot vector for each species label.
const encodedAnswers = {
  'Iris-setosa': [1, 0, 0],
  'Iris-versicolor': [0, 1, 0],
  'Iris-virginica': [0, 0, 1],
};

// Sample rows to predict on: four feature values each, presumably in
// independentVariables order.
const input_x = [
  [5.1, 3.5, 1.4, 0.2],
  [6.3, 3.3, 6.0, 2.5],
  [5.6, 3.0, 4.5, 1.5],
  [5.0, 3.2, 1.2, 0.2],
  [4.5, 2.3, 1.3, 0.3],
];
/**
 * Builds a `fitColumns` column descriptor that applies the 'scale' strategy
 * with 'standard' scale options to a single column.
 *
 * @param {string} columnName - Name of the column to describe.
 * @returns {{name: string, options: {strategy: string, scaleOptions: {strategy: string}}}}
 *   A freshly allocated descriptor object (nothing is shared between calls).
 */
function scaleColumnMap(columnName) {
  const scaleOptions = { strategy: 'standard' };
  const options = { strategy: 'scale', scaleOptions };
  return { name: columnName, options };
}
/** @test {DeepLearningClassification} */
describe('DeepLearningClassification', function () {
/**
 * Suite setup, run once before the tests. It loads the iris CSV fixture,
 * one-hot encodes its `plant` column, builds the feature/label matrices and
 * trains the classifier.
 * Results are stored in the module-level `let` bindings (housingDataCSV,
 * DataSet, x_matrix, y_matrix, nnClassification, nnClassificationModel),
 * presumably so the tests that follow can reuse the trained model.
 * The trailing `120000` is the hook timeout in milliseconds (Jest/Jasmine
 * `beforeAll(fn, timeout)`). It is raised because training is awaited inside
 * the hook.
 */
beforeAll(async function () {
  // NOTE(review): despite its name, housingDataCSV holds the iris fixture
  // (iris_data.csv), not housing data.
  housingDataCSV = await ms.csv.loadCSV(path.join(__dirname,'/test/mock/data/iris_data.csv'));
  DataSet = new ms.DataSet(housingDataCSV);
  // Optional step, currently disabled: request the 'scale'/'standard' strategy
  // for every name in `columns` via scaleColumnMap.
  // NOTE(review): `columns` includes the plant_* label columns, which only exist
  // after the one-hot fitColumns call below. Verify ordering before re-enabling.
  // DataSet.fitColumns({
  //   columns: columns.map(scaleColumnMap),
  //   returnData:false,
  // });
  /**
   * One-hot encode the `plant` column. With `returnData: true` the encoded rows
   * are returned. Each row keeps its original `plant` label and gains one 0/1
   * `plant_<species>` column per class, e.g.:
   *
   * encodedData = [
   *   { sepal_length_cm: 5.1,
   *     sepal_width_cm: 3.5,
   *     petal_length_cm: 1.4,
   *     petal_width_cm: 0.2,
   *     plant: 'Iris-setosa',
   *     'plant_Iris-setosa': 1,
   *     'plant_Iris-versicolor': 0,
   *     'plant_Iris-virginica': 0 },
   *   ...
   *   { sepal_length_cm: 5.9,
   *     sepal_width_cm: 3,
   *     petal_length_cm: 4.2,
   *     petal_width_cm: 1.5,
   *     plant: 'Iris-versicolor',
   *     'plant_Iris-setosa': 0,
   *     'plant_Iris-versicolor': 1,
   *     'plant_Iris-virginica': 0 },
   * ];
   *
   * NOTE(review): encodedData is not read again anywhere in this hook.
   */
  const encodedData = DataSet.fitColumns({
    columns: [
      {
        name: 'plant',
        options: {
          strategy: 'onehot',
        },
      },
    ],
    returnData:true,
  });
  // These must run after the one-hot fitColumns above, because the
  // dependentVariables (plant_*) columns are created by that call.
  x_matrix = DataSet.columnMatrix(independentVariables); 
  y_matrix = DataSet.columnMatrix(dependentVariables);
  /*
  Example shape of the matrices (one inner array per data row):
  x_matrix = [
    [ 5.1, 3.5, 1.4, 0.2 ],
    [ 4.9, 3, 1.4, 0.2 ],
    [ 4.7, 3.2, 1.3, 0.2 ],
    ...
  ]; 
  y_matrix = [
    [ 1, 0, 0 ],
    [ 1, 0, 0 ],
    [ 1, 0, 0 ],
    ...
  ] 
  */
  // console.log({ x_matrix, y_matrix, });

  // Build the classifier with the shared `fit` options (epochs / batchSize),
  // then train it on the matrices built above.
  nnClassification = new DeepLearningClassification({ fit, });
  nnClassificationModel = await nnClassification.train(x_matrix, y_matrix);
},120000);
// Usage excerpt for the trained classifier. `await` needs an enclosing async
// function (e.g. a test callback), which is not shown here.
// Predict on the sample rows in input_x with default options.
const predictions = await nnClassification.predict(input_x);
      // Same inputs with `probability: false`. This presumably yields class
      // answers rather than probabilities (TODO confirm against
      // DeepLearningClassification.predict).
      const answers = await nnClassification.predict(input_x, {
        probability:false,
      });
      // NOTE(review): getInputShape is applied to the prediction output, not
      // to input_x. Confirm this is intended.
      const shape = nnClassification.getInputShape(predictions);
    

    

TextEmbedding

      // Usage excerpt: construct a TextEmbedding from JSONStackModel with no
      // options, train it, then call predict on a list of sentences.
      // `await` needs an enclosing async function, which is not shown here.
      const TextEmbedder = new JSONStackModel.TextEmbedding();
      // train() is awaited with no arguments. It presumably loads/prepares a
      // pretrained model rather than fitting user data (TODO confirm).
      await TextEmbedder.train();
      const sentences = [
        'Hello.',
        'How are you?',
      ];
      // Presumably one embedding per input sentence (TODO confirm).
      // NOTE(review): `predictions` is also declared by the classification
      // excerpt earlier in this file. The two snippets cannot share one scope
      // as written.
      const predictions = await TextEmbedder.predict(sentences);