Integrate batch inference API

This commit is contained in:
vishnukvmd 2023-09-23 23:58:19 +05:30
parent b87f076673
commit dc6d3e2111
2 changed files with 51 additions and 25 deletions

View file

@ -18,6 +18,8 @@ class SemanticSearchService {
SemanticSearchService._privateConstructor();
static final Computer _computer = Computer.shared();
static const int batchSize = 4;
bool hasLoaded = false;
final _logger = Logger("SemanticSearchService");
Future<List<EnteFile>>? _ongoingRequest;
@ -123,7 +125,7 @@ class SemanticSearchService {
Future<void> _loadModel() async {
const modelPath =
"assets/models/clip/openai_clip-vit-base-patch32.ggmlv0.q4_0.bin";
"assets/models/clip/openai_clip-vit-base-patch32.ggmlv0.f16.bin";
final path = await _getAccessiblePathForAsset(modelPath, "model.bin");
final startTime = DateTime.now();
@ -155,51 +157,75 @@ class SemanticSearchService {
Future<void> _computeMissingEmbeddings() async {
final files = await FilesDB.instance.getFilesWithoutEmbeddings();
_logger.info(files.length.toString() + " pending to be embedded");
int counter = 0;
final List<EnteFile> batch = [];
for (final file in files) {
await _computeImageEmbedding(file);
if (counter < batchSize) {
batch.add(file);
counter++;
} else {
await _computeImageEmbeddings(batch);
counter = 0;
batch.clear();
}
}
}
Future<void> _computeImageEmbedding(EnteFile file) async {
Future<void> _computeImageEmbeddings(List<EnteFile> files) async {
if (!hasLoaded) {
return;
}
// _logger.info("Running clip");
final imagePath = (await getThumbnailFile(file))!.path;
_logger.info("Running clip over " + files.length.toString() + " items");
final List<EnteFile> filesToBeIndexed = [];
final List<String> filePaths = [];
final startTime = DateTime.now();
// ignore: prefer_typing_uninitialized_variables
var imageEmbedding;
final embeddings = await FilesDB.instance.getAllEmbeddings();
bool hasCachedEmbedding = false;
for (final embedding in embeddings) {
if (embedding.id == file.generatedID) {
imageEmbedding = embedding.embedding;
hasCachedEmbedding = true;
_logger.info("Found cached embedding");
for (final file in files) {
bool hasCachedEmbedding = false;
for (final embedding in embeddings) {
if (embedding.id == file.generatedID) {
_logger.info("Found cached embedding");
hasCachedEmbedding = true;
}
}
if (!hasCachedEmbedding) {
filesToBeIndexed.add(file);
filePaths.add((await getThumbnailFile(file))!.path);
}
}
if (!hasCachedEmbedding) {
imageEmbedding ??= await _computer.compute(
createImageEmbedding,
param: {
"imagePath": imagePath,
},
taskName: "createImageEmbedding",
final imageEmbeddings = await _computer.compute(
createImageEmbeddings,
param: {
"imagePaths": filePaths,
},
taskName: "createImageEmbedding",
) as List<List<double>>;
for (int i = 0; i < imageEmbeddings.length; i++) {
await FilesDB.instance.insertEmbedding(
Embedding(
files[i].generatedID!,
imageEmbeddings[i],
-1,
),
);
await FilesDB.instance
.insertEmbedding(Embedding(file.generatedID!, imageEmbedding, -1));
}
final endTime = DateTime.now();
_logger.info(
"createImageEmbedding took: " +
"createImageEmbeddings took: " +
(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)
.toString() +
"ms",
"ms for " +
imageEmbeddings.length.toString() +
" items",
);
}
}
List<List<double>> createImageEmbeddings(Map args) {
return CLIP.createBatchImageEmbedding(args["imagePaths"]);
}
List<double> createImageEmbedding(Map args) {
return CLIP.createImageEmbedding(args["imagePath"]);
}

@ -1 +1 @@
Subproject commit 84c5a499e700cc72e546c42d771d4a92a7a21ec2
Subproject commit c15bc340878fbc679027083c962f1fcbfe68b263