Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions core/foundation_stereo.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,23 +372,20 @@ def forward(self, features_left_04, features_left_08, features_left_16, features
return disp_up


class TrtRunner(nn.Module):
def __init__(self, args, feature_runner_engine_path, post_runner_engine_path):
class _BaseTrtRunner(nn.Module):
    """Shared TensorRT plumbing for engine-backed runners.

    Owns the TRT logger and knows how to deserialize a serialized engine
    file into an (engine, execution_context) pair. Subclasses decide how
    many engines they load and how they wire inputs/outputs.

    NOTE: the merged-diff residue (leftover lines from the pre-refactor
    TrtRunner.__init__ referencing feature_runner_engine_path /
    post_runner_engine_path, which are not parameters here) is removed.
    """

    def __init__(self, args):
        super().__init__()
        # Import locally so the module can still be imported on machines
        # without TensorRT installed.
        import tensorrt as trt
        self.args = args
        self.trt_logger = trt.Logger(trt.Logger.WARNING)

    def load_engine(self, engine_path):
        """Deserialize a TRT engine file.

        Args:
            engine_path: path to a serialized .engine/.plan file.
        Returns:
            (engine, context): the deserialized ICudaEngine and a fresh
            IExecutionContext created from it.
        """
        import tensorrt as trt
        with open(engine_path, 'rb') as file:
            engine_data = file.read()
        engine = trt.Runtime(self.trt_logger).deserialize_cuda_engine(engine_data)
        context = engine.create_execution_context()
        return engine, context

def trt_dtype_to_torch(self, dt):
import tensorrt as trt
Expand Down Expand Up @@ -429,6 +426,15 @@ def run_trt(self, engine, context, inputs_by_name:dict):
assert ok
return outputs


class TrtRunner(_BaseTrtRunner):
    """Two-engine TRT runner: a feature-extraction engine followed by a
    post-processing engine that produces disparity."""

    def __init__(self, args, feature_runner_engine_path, post_runner_engine_path):
        super().__init__(args)
        # Cost-volume configuration; cv_group falls back to 8 when absent.
        self.max_disp = args.max_disp
        self.cv_group = args.get('cv_group', 8)
        # Deserialize both engines eagerly so failures surface at startup.
        self.feature_engine, self.feature_context = self.load_engine(feature_runner_engine_path)
        self.post_engine, self.post_context = self.load_engine(post_runner_engine_path)

def forward(self, image1, image2):
import tensorrt as trt
feat_out = self.run_trt(self.feature_engine, self.feature_context, {'left': image1, 'right': image2})
Expand All @@ -442,4 +448,14 @@ def forward(self, image1, image2):
del post_inputs[k]
out = self.run_trt(self.post_engine, self.post_context, post_inputs)
disp = out['disp']
return disp
return disp


class SingleTrtRunner(_BaseTrtRunner):
    """Runs one end-to-end TensorRT engine mapping a stereo pair straight
    to disparity (no feature/post split)."""

    def __init__(self, args, engine_path):
        super().__init__(args)
        # A single engine covers the whole network.
        self.engine, self.context = self.load_engine(engine_path)

    def forward(self, image1, image2):
        # Bindings are named 'left'/'right' on input and 'disp' on output.
        results = self.run_trt(self.engine, self.context, {'left': image1, 'right': image2})
        disp = results['disp']
        return disp
8 changes: 2 additions & 6 deletions core/submodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,17 +609,14 @@ def __init__(self, in_planes, ratio=16):
"""From selective-IGEV
"""
super(ChannelAttentionEnhancement, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)

self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False),
nn.ReLU(),
nn.Conv2d(in_planes // 16, in_planes, 1, bias=False))
self.sigmoid = nn.Sigmoid()

def forward(self, x):
    """Channel attention: squeeze spatial dims via mean- and max-pooling,
    pass each through the shared MLP `self.fc`, sum, and gate with sigmoid.

    Args:
        x: feature map, assumed (N, C, H, W) — TODO confirm with caller.
    Returns:
        (N, C, 1, 1) per-channel attention weights in (0, 1).
    """
    # Use torch.mean/torch.amax over dims (2, 3) — equivalent to the removed
    # AdaptiveAvgPool2d(1)/AdaptiveMaxPool2d(1) modules. The stale pre-refactor
    # lines calling self.avg_pool/self.max_pool (deleted from __init__) are
    # dropped: they double-computed the outputs and would raise AttributeError.
    avg_out = self.fc(torch.mean(x, dim=(2, 3), keepdim=True))
    max_out = self.fc(torch.amax(x, dim=(2, 3), keepdim=True))
    out = avg_out + max_out
    return self.sigmoid(out)

Expand Down Expand Up @@ -672,4 +669,3 @@ def forward(self, x):

x = input + x
return x

85 changes: 58 additions & 27 deletions scripts/make_onnx.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings, argparse, logging, os, sys,zipfile
import torch.nn as nn
os.environ['TORCH_COMPILE_DISABLE'] = '1'
os.environ['TORCHDYNAMO_DISABLE'] = '1'
code_dir = os.path.dirname(os.path.abspath(__file__))
Expand All @@ -21,6 +22,23 @@ def forward(self, left, right):
return disp


class SingleOnnxRunner(nn.Module):
    """Wraps the full model for a single-file ONNX export, driving it with
    the pure-PyTorch cost-volume builder ('pytorch1')."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    @torch.no_grad()
    def forward(self, left, right):
        # Run under CUDA autocast so the exported graph captures the
        # mixed-precision behavior (dtype comes from U.AMP_DTYPE).
        autocast_ctx = torch.amp.autocast('cuda', enabled=True, dtype=U.AMP_DTYPE)
        with autocast_ctx:
            disp = self.model(
                left,
                right,
                iters=self.model.args.valid_iters,
                test_mode=True,
                optimize_build_volume='pytorch1',
            )
        return disp



if __name__ == '__main__':
parser = argparse.ArgumentParser()
Expand All @@ -37,14 +55,17 @@ def forward(self, left, right):
parser.add_argument('--n_gru_layers', type=int, default=1, help="number of hidden GRU levels")
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
parser.add_argument('--low_memory', type=int, default=1, help='reduce memory usage')
parser.add_argument('--single_onnx', action='store_true', help='Export the full model to a single ONNX file using the pure PyTorch volume builder')
parser.add_argument('--single_onnx_name', type=str, default='foundation_stereo.onnx', help='Filename for the single-model ONNX export')
args = parser.parse_args()
# save_path is treated as an output directory (the stale pre-refactor
# os.makedirs(os.path.dirname(args.save_path), ...) residue is removed).
os.makedirs(args.save_path, exist_ok=True)

torch.autograd.set_grad_enabled(False)

# Full-object checkpoint load; weights_only=False is required but only safe
# for trusted checkpoints (pickle executes arbitrary code).
model = torch.load(args.model_dir, map_location='cpu', weights_only=False)
# Override the checkpoint's config with the export-time settings so the
# traced graph matches the requested resolution and disparity range.
model.args.max_disp = args.max_disp
model.args.valid_iters = args.valid_iters
model.args.image_size = [args.height, args.width]
model.cuda().eval()

feature_runner = TrtFeatureRunner(model)
Expand All @@ -56,31 +77,41 @@ def forward(self, left, right):
left_img = torch.randn(1, 3, args.height, args.width).cuda().float()*255
right_img = torch.randn(1, 3, args.height, args.width).cuda().float()*255

# Export either one monolithic ONNX graph (--single_onnx) or the split
# feature-runner / post-runner pair. The pre-refactor unconditional export
# lines (merged-diff residue duplicating the else-branch) are removed.
if args.single_onnx:
    single_runner = SingleOnnxRunner(model).cuda().eval()
    torch.onnx.export(
        single_runner,
        (left_img, right_img),
        os.path.join(args.save_path, args.single_onnx_name),
        opset_version=17,
        input_names=['left', 'right'],
        output_names=['disp'],
        do_constant_folding=True,
    )
else:
    torch.onnx.export(
        feature_runner,
        (left_img, right_img),
        args.save_path+'/feature_runner.onnx',
        opset_version=17,
        input_names = ['left', 'right'],
        output_names = ['features_left_04', 'features_left_08', 'features_left_16', 'features_left_32', 'features_right_04', 'stem_2x'],
        do_constant_folding=True
    )

    # Trace sample activations through the feature runner to get concretely
    # shaped inputs for exporting the post runner.
    features_left_04, features_left_08, features_left_16, features_left_32, features_right_04, stem_2x = feature_runner(left_img, right_img)
    gwc_volume = build_gwc_volume_triton(features_left_04.half(), features_right_04.half(), args.max_disp//4, model.cv_group)
    disp = post_runner(features_left_04.float(), features_left_08.float(), features_left_16.float(), features_left_32.float(), features_right_04.float(), stem_2x.float(), gwc_volume.float())

    torch.onnx.export(
        post_runner,
        (features_left_04, features_left_08, features_left_16, features_left_32, features_right_04, stem_2x, gwc_volume),
        args.save_path+'/post_runner.onnx',
        opset_version=17,
        input_names = ['features_left_04', 'features_left_08', 'features_left_16', 'features_left_32', 'features_right_04', 'stem_2x', 'gwc_volume'],
        output_names = ['disp'],
        do_constant_folding=True
    )

with open(f'{args.save_path}/onnx.yaml', 'w') as f:
    # Persist the export configuration, recording the fixed image size used
    # for tracing so downstream engine builds match. The stale pre-refactor
    # single-line safe_dump (merged-diff residue, which would dump a second
    # YAML document without image_size) is removed.
    cfg = OmegaConf.to_container(model.args)
    cfg['image_size'] = [args.height, args.width]
    yaml.safe_dump(cfg, f)