fix: clean up

This commit is contained in:
mbukeRepo 2024-07-08 08:24:19 +02:00
parent eec7aa3337
commit 8536882471

View File

@ -6,7 +6,6 @@ from src.config.crop_config import CropConfig
from src.live_portrait_pipeline import LivePortraitPipeline from src.live_portrait_pipeline import LivePortraitPipeline
import requests import requests
class Predictor(BasePredictor): class Predictor(BasePredictor):
def setup(self) -> None: def setup(self) -> None:
"""Load the model into memory to make running multiple predictions efficient""" """Load the model into memory to make running multiple predictions efficient"""
@ -17,18 +16,27 @@ class Predictor(BasePredictor):
def predict( def predict(
self, self,
image: Path = Input(description="Portrait image") input_image_path: Path = Input(description="Portrait image"),
input_video_path: Path = Input(description="Driving video"),
flag_relative_input: bool = Input(description="relative motion", default=True),
flag_do_crop_input: bool = Input(description="We recommend checking the do crop option when facial areas occupy a relatively small portion of your image.", default=True),
flag_pasteback: bool = Input(description="paste-back", default=True),
) -> Path: ) -> Path:
"""Run a single prediction on the model""" """Run a single prediction on the model"""
user_args = ArgumentConfig(
flag_relative=flag_relative_input,
flag_do_crop=flag_do_crop_input,
flag_pasteback=flag_pasteback,
source_image=input_image_path,
driving_info=str(input_video_path),
output_dir="/tmp/"
)
self.live_portrait_pipeline.cropper.update_config(user_args.__dict__)
self.live_portrait_pipeline.live_portrait_wrapper.update_config(user_args.__dict__)
video_path, _ = self.live_portrait_pipeline.execute( video_path, _ = self.live_portrait_pipeline.execute(
ArgumentConfig( user_args
source_image=image,
driving_info="assets/examples/driving/d0.mp4",
output_dir="/tmp/",
flag_pasteback=False,
flag_do_crop=False,
)
) )
return Path(video_path) return Path(video_path)