Integrate batch inference API
This commit is contained in:
parent
b87f076673
commit
dc6d3e2111
2 changed files with 51 additions and 25 deletions
|
@ -18,6 +18,8 @@ class SemanticSearchService {
|
|||
SemanticSearchService._privateConstructor();
|
||||
static final Computer _computer = Computer.shared();
|
||||
|
||||
static const int batchSize = 4;
|
||||
|
||||
bool hasLoaded = false;
|
||||
final _logger = Logger("SemanticSearchService");
|
||||
Future<List<EnteFile>>? _ongoingRequest;
|
||||
|
@ -123,7 +125,7 @@ class SemanticSearchService {
|
|||
|
||||
Future<void> _loadModel() async {
|
||||
const modelPath =
|
||||
"assets/models/clip/openai_clip-vit-base-patch32.ggmlv0.q4_0.bin";
|
||||
"assets/models/clip/openai_clip-vit-base-patch32.ggmlv0.f16.bin";
|
||||
|
||||
final path = await _getAccessiblePathForAsset(modelPath, "model.bin");
|
||||
final startTime = DateTime.now();
|
||||
|
@ -155,51 +157,75 @@ class SemanticSearchService {
|
|||
Future<void> _computeMissingEmbeddings() async {
|
||||
final files = await FilesDB.instance.getFilesWithoutEmbeddings();
|
||||
_logger.info(files.length.toString() + " pending to be embedded");
|
||||
int counter = 0;
|
||||
final List<EnteFile> batch = [];
|
||||
for (final file in files) {
|
||||
await _computeImageEmbedding(file);
|
||||
if (counter < batchSize) {
|
||||
batch.add(file);
|
||||
counter++;
|
||||
} else {
|
||||
await _computeImageEmbeddings(batch);
|
||||
counter = 0;
|
||||
batch.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Future<void> _computeImageEmbedding(EnteFile file) async {
|
||||
Future<void> _computeImageEmbeddings(List<EnteFile> files) async {
|
||||
if (!hasLoaded) {
|
||||
return;
|
||||
}
|
||||
// _logger.info("Running clip");
|
||||
final imagePath = (await getThumbnailFile(file))!.path;
|
||||
_logger.info("Running clip over " + files.length.toString() + " items");
|
||||
final List<EnteFile> filesToBeIndexed = [];
|
||||
final List<String> filePaths = [];
|
||||
|
||||
final startTime = DateTime.now();
|
||||
// ignore: prefer_typing_uninitialized_variables
|
||||
var imageEmbedding;
|
||||
final embeddings = await FilesDB.instance.getAllEmbeddings();
|
||||
bool hasCachedEmbedding = false;
|
||||
for (final embedding in embeddings) {
|
||||
if (embedding.id == file.generatedID) {
|
||||
imageEmbedding = embedding.embedding;
|
||||
hasCachedEmbedding = true;
|
||||
_logger.info("Found cached embedding");
|
||||
for (final file in files) {
|
||||
bool hasCachedEmbedding = false;
|
||||
for (final embedding in embeddings) {
|
||||
if (embedding.id == file.generatedID) {
|
||||
_logger.info("Found cached embedding");
|
||||
hasCachedEmbedding = true;
|
||||
}
|
||||
}
|
||||
if (!hasCachedEmbedding) {
|
||||
filesToBeIndexed.add(file);
|
||||
filePaths.add((await getThumbnailFile(file))!.path);
|
||||
}
|
||||
}
|
||||
if (!hasCachedEmbedding) {
|
||||
imageEmbedding ??= await _computer.compute(
|
||||
createImageEmbedding,
|
||||
param: {
|
||||
"imagePath": imagePath,
|
||||
},
|
||||
taskName: "createImageEmbedding",
|
||||
final imageEmbeddings = await _computer.compute(
|
||||
createImageEmbeddings,
|
||||
param: {
|
||||
"imagePaths": filePaths,
|
||||
},
|
||||
taskName: "createImageEmbedding",
|
||||
) as List<List<double>>;
|
||||
for (int i = 0; i < imageEmbeddings.length; i++) {
|
||||
await FilesDB.instance.insertEmbedding(
|
||||
Embedding(
|
||||
files[i].generatedID!,
|
||||
imageEmbeddings[i],
|
||||
-1,
|
||||
),
|
||||
);
|
||||
await FilesDB.instance
|
||||
.insertEmbedding(Embedding(file.generatedID!, imageEmbedding, -1));
|
||||
}
|
||||
final endTime = DateTime.now();
|
||||
_logger.info(
|
||||
"createImageEmbedding took: " +
|
||||
"createImageEmbeddings took: " +
|
||||
(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)
|
||||
.toString() +
|
||||
"ms",
|
||||
"ms for " +
|
||||
imageEmbeddings.length.toString() +
|
||||
" items",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
List<List<double>> createImageEmbeddings(Map args) {
|
||||
return CLIP.createBatchImageEmbedding(args["imagePaths"]);
|
||||
}
|
||||
|
||||
List<double> createImageEmbedding(Map args) {
|
||||
return CLIP.createImageEmbedding(args["imagePath"]);
|
||||
}
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 84c5a499e700cc72e546c42d771d4a92a7a21ec2
|
||||
Subproject commit c15bc340878fbc679027083c962f1fcbfe68b263
|
Loading…
Add table
Reference in a new issue