From 2105657f742c2f06a507dd046d06942f63923edf Mon Sep 17 00:00:00 2001 From: guojianzhu Date: Sun, 29 Dec 2024 17:37:15 +0800 Subject: [PATCH] chore: updare the animals model version --- inference.py | 3 ++- inference_animals.py | 3 ++- src/config/inference_config.py | 9 +++++---- src/utils/animal_landmark_runner.py | 2 +- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/inference.py b/inference.py index 5c80818..97c1436 100644 --- a/inference.py +++ b/inference.py @@ -1,6 +1,7 @@ # coding: utf-8 + """ -for human +The entrance of humans """ import os diff --git a/inference_animals.py b/inference_animals.py index 8fddf7b..20f8452 100644 --- a/inference_animals.py +++ b/inference_animals.py @@ -1,6 +1,7 @@ # coding: utf-8 + """ -for animal +The entrance of animal """ import os diff --git a/src/config/inference_config.py b/src/config/inference_config.py index c9ed197..485679c 100644 --- a/src/config/inference_config.py +++ b/src/config/inference_config.py @@ -26,10 +26,11 @@ class InferenceConfig(PrintableConfig): checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint to S and R_eyes, R_lip # ANIMAL MODEL CONFIG, NOT EXPORTED PARAMS - checkpoint_F_animal: str = make_abs_path('../../pretrained_weights/liveportrait_animals/base_models/appearance_feature_extractor.pth') # path to checkpoint of F - checkpoint_M_animal: str = make_abs_path('../../pretrained_weights/liveportrait_animals/base_models/motion_extractor.pth') # path to checkpoint pf M - checkpoint_G_animal: str = make_abs_path('../../pretrained_weights/liveportrait_animals/base_models/spade_generator.pth') # path to checkpoint of G - checkpoint_W_animal: str = make_abs_path('../../pretrained_weights/liveportrait_animals/base_models/warping_module.pth') # path to checkpoint of W + version_animals = "_v1.1" # set it to "" for the previous version + checkpoint_F_animal: str = make_abs_path(f'../../pretrained_weights/liveportrait_animals/base_models{version_animals}/appearance_feature_extractor.pth') # path to checkpoint of F + checkpoint_M_animal: str = make_abs_path(f'../../pretrained_weights/liveportrait_animals/base_models{version_animals}/motion_extractor.pth') # path to checkpoint pf M + checkpoint_G_animal: str = make_abs_path(f'../../pretrained_weights/liveportrait_animals/base_models{version_animals}/spade_generator.pth') # path to checkpoint of G + checkpoint_W_animal: str = make_abs_path(f'../../pretrained_weights/liveportrait_animals/base_models{version_animals}/warping_module.pth') # path to checkpoint of W checkpoint_S_animal: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint to S and R_eyes, R_lip, NOTE: use human temporarily! # EXPORTED PARAMS diff --git a/src/utils/animal_landmark_runner.py b/src/utils/animal_landmark_runner.py index c66efe4..dd91aa5 100644 --- a/src/utils/animal_landmark_runner.py +++ b/src/utils/animal_landmark_runner.py @@ -60,7 +60,7 @@ class XPoseRunner(object): def get_unipose_output(self, image, instance_text_prompt, keypoint_text_prompt, box_threshold, IoU_threshold): instance_list = instance_text_prompt.split(',') - + if len(keypoint_text_prompt) == 9: # torch.Size([1, 512]) torch.Size([9, 512]) ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_9, self.kpt_text_embeddings_9