0
0
mirror of https://github.com/opencv/opencv.git synced 2026-01-18 17:21:42 +01:00

Merge pull request #28282 from abhishek-gola:conv_kernel_size_fix

Support Conv kernel inference from initializer weights
This commit is contained in:
Alexander Smorkalov
2025-12-23 15:04:51 +03:00
committed by GitHub

View File

@@ -2015,6 +2015,25 @@ void ONNXImporter::parseConv(LayerParams& layerParams, const opencv_onnx::NodePr
layerParams.blobs.push_back(getBlob(node_proto, j));
}
}
// ONNX allows omitting 'kernel_shape' attribute for Conv. In that case, it should be inferred from weights.
// See: https://onnx.ai/onnx/operators/onnx__Conv.html
if (!layerParams.has("kernel_size"))
{
Mat weights;
if (!layerParams.blobs.empty())
weights = layerParams.blobs[0];
else if (constBlobs.find(node_proto.input(1)) != constBlobs.end())
weights = getBlob(node_proto, 1);
if (!weights.empty() && weights.dims >= 3)
{
const int kDims = weights.dims - 2;
std::vector<int32_t> kernel(kDims);
for (int i = 0; i < kDims; ++i)
kernel[i] = weights.size[2 + i];
layerParams.set("kernel_size", DictValue::arrayInt(kernel.data(), static_cast<int>(kernel.size())));
}
}
int outCn = layerParams.blobs.empty() ? outShapes[node_proto.input(1)][0] : layerParams.blobs[0].size[0];
layerParams.set("num_output", outCn);
@@ -2031,6 +2050,20 @@ void ONNXImporter::parseConvTranspose(LayerParams& layerParams, const opencv_onn
layerParams.set("num_output", layerParams.blobs[0].size[1] * layerParams.get<int>("group", 1));
layerParams.set("bias_term", node_proto.input_size() == 3);
// ONNX allows omitting 'kernel_shape' attribute for ConvTranspose. Infer it from weights if needed.
if (!layerParams.has("kernel_size"))
{
const Mat& weights = layerParams.blobs[0];
if (!weights.empty() && weights.dims >= 3)
{
const int kDims = weights.dims - 2;
std::vector<int32_t> kernel(kDims);
for (int i = 0; i < kDims; ++i)
kernel[i] = weights.size[2 + i];
layerParams.set("kernel_size", DictValue::arrayInt(kernel.data(), static_cast<int>(kernel.size())));
}
}
if (!layerParams.has("kernel_size"))
CV_Error(Error::StsNotImplemented,
"Required attribute 'kernel_size' is not present.");