fix: issue 483 (#484)

This commit is contained in:
Jianzhu Guo 2025-02-02 17:06:01 +08:00 committed by GitHub
parent 1df473f067
commit 6c4a883a9e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 8 deletions

View File

@ -43,7 +43,7 @@ class XPoseRunner(object):
args = Config.fromfile(model_config_path)
args.device = device
model = build_model(args)
checkpoint = torch.load(model_checkpoint_path, map_location=lambda storage, loc: storage)
checkpoint = torch.load(model_checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False)
load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
model.eval()
return model

View File

@ -18,7 +18,7 @@
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
@ -50,7 +50,7 @@ at::Tensor ms_deform_attn_cuda_forward(
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
const int batch_n = im2col_step_;
@ -61,7 +61,7 @@ at::Tensor ms_deform_attn_cuda_forward(
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
@ -81,7 +81,7 @@ at::Tensor ms_deform_attn_cuda_forward(
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
@ -127,11 +127,11 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
@ -150,4 +150,4 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
return {
grad_value, grad_sampling_loc, grad_attn_weight
};
}
}