-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.cpp
More file actions
80 lines (78 loc) · 3.82 KB
/
test.cpp
File metadata and controls
80 lines (78 loc) · 3.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include <chrono>
#include <cstddef>
#include <iostream>
#include <string>
#include <thread>
#include <vector>

#include <gflags/gflags.h>
#include <opencv2/opencv.hpp>

#include "sam3.h"
// Command-line flags (gflags). Each flag is exposed as FLAGS_<name> and is
// read by main() below.

// Paths to the model files handed to Sam3::loadModel.
DEFINE_string(vision_encoder, "sam3/vision-encoder.onnx", "Path to the vision encoder model");
DEFINE_string(text_encoder, "sam3/text-encoder.onnx", "Path to the text encoder model");
DEFINE_string(geometry_encoder, "sam3/geometry-encoder.onnx", "Path to the geometry encoder model");
DEFINE_string(decoder, "sam3/decoder.onnx", "Path to the decoder model");
DEFINE_string(tokenizer, "sam3/tokenizer.json", "Path to the tokenizer");

// Prompts. main() splits the text prompt on ',' and hands the boxes prompt to
// parse_box_list_prompts; the exact boxes syntax is defined by that helper.
DEFINE_string(text, "", "Text prompt");
DEFINE_string(boxes, "", "Boxes prompt");

// Threshold forwarded to Sam3::decodeBatch.
DEFINE_double(threshold, 0.5, "Threshold for detections");

// Input image and the device string forwarded to Sam3::loadModel.
DEFINE_string(image, "david-tomaseti-Vw2HZQ1FGjU-unsplash.jpg", "Path to the image");
DEFINE_string(device, "cpu", "cpu or cuda:0(1,2,3...)");
int main(int argc, char** argv) {
gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
Sam3 sam3;
std::cout<<"loadModel started"<<std::endl;
std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
bool successLoadModel = sam3.loadModel(FLAGS_vision_encoder, FLAGS_text_encoder, FLAGS_geometry_encoder, FLAGS_decoder, FLAGS_tokenizer, std::thread::hardware_concurrency(), FLAGS_device);
std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
std::cout << "sec = " << (std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count()) / 1000000.0 <<std::endl;
if(!successLoadModel){
std::cout<<"loadModel error"<<std::endl;
return 1;
}
std::cout<<"preprocessImage started"<<std::endl;
begin = std::chrono::steady_clock::now();
cv::Mat image = cv::imread(FLAGS_image, cv::IMREAD_COLOR);
cv::Size imageSize = cv::Size(image.cols, image.rows);
cv::Size inputSize = sam3.getInputSize();
cv::resize(image, image, inputSize);
bool successPreprocessImage = sam3.preprocessImage(image);
end = std::chrono::steady_clock::now();
std::cout << "sec = " << (std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count()) / 1000000.0 <<std::endl;
if(!successPreprocessImage){
std::cout<<"preprocessImage error"<<std::endl;
return 1;
}
std::cout<<"Encode text tarted"<<std::endl;
begin = std::chrono::steady_clock::now();
std::vector<std::string> text_list = split(FLAGS_text, ',');
auto [rects_list, labels_list] = parse_box_list_prompts(FLAGS_boxes, imageSize);
sam3.alignTextsAndBoxesBatchSize(&text_list, &rects_list, &labels_list);
bool successEncodeText = sam3.encodeTextBatch(text_list);
end = std::chrono::steady_clock::now();
std::cout << "sec = " << (std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count()) / 1000000.0 <<std::endl;
if(!successEncodeText){
std::cout<<"Encode text error"<<std::endl;
return 1;
}
std::cout<<"Encode boxes started"<<std::endl;
begin = std::chrono::steady_clock::now();
bool successEncodeBoxes = sam3.encodeBoxesBatch(rects_list, labels_list);
end = std::chrono::steady_clock::now();
std::cout << "sec = " << (std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count()) / 1000000.0 <<std::endl;
if(!successEncodeBoxes){
std::cout<<"Encode boxes error"<<std::endl;
return 1;
}
std::cout<<"Decode started"<<std::endl;
begin = std::chrono::steady_clock::now();
float threshold = FLAGS_threshold;
bool skipDecode = false;
auto [masks, boxes] = sam3.decodeBatch(threshold, imageSize, skipDecode);
end = std::chrono::steady_clock::now();
std::cout << "sec = " << (std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count()) / 1000000.0 <<std::endl;
if(masks.size() == 0){
std::cout<<"Decode error"<<std::endl;
return 1;
}
std::cout<<"Found "<<masks.size()<<std::endl;
for(int i = 0; i < masks.size(); i++){
std::string fileName = "mask" + std::to_string(i) + ".png";
cv::imwrite(fileName, masks[i]);
}
return 0;
}