mirror of
https://github.com/KwaiVGI/LivePortrait.git
synced 2025-02-05 18:08:13 +00:00
fix: issue 483 (#484)
This commit is contained in:
parent
1df473f067
commit
6c4a883a9e
@ -43,7 +43,7 @@ class XPoseRunner(object):
|
|||||||
args = Config.fromfile(model_config_path)
|
args = Config.fromfile(model_config_path)
|
||||||
args.device = device
|
args.device = device
|
||||||
model = build_model(args)
|
model = build_model(args)
|
||||||
checkpoint = torch.load(model_checkpoint_path, map_location=lambda storage, loc: storage)
|
checkpoint = torch.load(model_checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False)
|
||||||
load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
|
load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
|
||||||
model.eval()
|
model.eval()
|
||||||
return model
|
return model
|
||||||
|
@ -61,7 +61,7 @@ at::Tensor ms_deform_attn_cuda_forward(
|
|||||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||||
{
|
{
|
||||||
auto columns = output_n.select(0, n);
|
auto columns = output_n.select(0, n);
|
||||||
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
|
||||||
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
||||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||||
spatial_shapes.data<int64_t>(),
|
spatial_shapes.data<int64_t>(),
|
||||||
@ -131,7 +131,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
|||||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||||
{
|
{
|
||||||
auto grad_output_g = grad_output_n.select(0, n);
|
auto grad_output_g = grad_output_n.select(0, n);
|
||||||
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
|
||||||
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
||||||
grad_output_g.data<scalar_t>(),
|
grad_output_g.data<scalar_t>(),
|
||||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||||
|
Loading…
Reference in New Issue
Block a user