image-classifier.service.ts 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import { Injectable, Logger } from '@nestjs/common';
  2. import * as mobilenet from '@tensorflow-models/mobilenet';
  3. import * as cocoSsd from '@tensorflow-models/coco-ssd';
  4. import * as tf from '@tensorflow/tfjs-node';
  5. import * as fs from 'fs';
  6. @Injectable()
  7. export class ImageClassifierService {
  8. private readonly MOBILENET_VERSION = 2;
  9. private readonly MOBILENET_ALPHA = 1.0;
  10. private mobileNetModel: mobilenet.MobileNet;
  11. constructor() {
  12. Logger.log(
  13. `Running Node TensorFlow Version : ${tf.version['tfjs']}`,
  14. 'ImageClassifier',
  15. );
  16. mobilenet
  17. .load({
  18. version: this.MOBILENET_VERSION,
  19. alpha: this.MOBILENET_ALPHA,
  20. })
  21. .then((mobilenetModel) => (this.mobileNetModel = mobilenetModel));
  22. }
  23. async tagImage(thumbnailPath: string) {
  24. try {
  25. const isExist = fs.existsSync(thumbnailPath);
  26. if (isExist) {
  27. const tags = [];
  28. const image = fs.readFileSync(thumbnailPath);
  29. const decodedImage = tf.node.decodeImage(image, 3) as tf.Tensor3D;
  30. const predictions = await this.mobileNetModel.classify(decodedImage);
  31. for (const prediction of predictions) {
  32. if (prediction.probability >= 0.1) {
  33. tags.push(...prediction.className.split(',').map((e) => e.trim()));
  34. }
  35. }
  36. tf.dispose(decodedImage);
  37. return tags;
  38. }
  39. } catch (e) {
  40. console.log('Error reading file ', e);
  41. }
  42. }
  43. }