LivePortrait/predict.py

35 lines
1.1 KiB
Python
Raw Normal View History

2024-07-07 13:46:28 +00:00
from cog import BasePredictor, Input, Path, File
2024-07-06 11:13:39 +00:00
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig
from src.live_portrait_pipeline import LivePortraitPipeline
2024-07-07 13:46:28 +00:00
import requests
2024-07-06 11:13:39 +00:00
class Predictor(BasePredictor):
    """Cog predictor that runs a portrait image through a LivePortraitPipeline."""

    def setup(self) -> None:
        """Build the pipeline once so that later predictions can reuse it."""
        inference_settings = InferenceConfig()
        crop_settings = CropConfig()
        self.live_portrait_pipeline = LivePortraitPipeline(
            inference_cfg=inference_settings,
            crop_cfg=crop_settings,
        )

    def predict(
        self,
        image: Path = Input(description="Portrait image"),
    ) -> Path:
        """Execute the pipeline for one portrait and return the result path.

        Args:
            image: Portrait image supplied by the caller; passed to the
                pipeline as ``source_image``.

        Returns:
            The first element of the pair returned by the pipeline's
            ``execute`` call, wrapped in a ``Path`` (presumably the rendered
            video, going by the original variable name -- confirm against
            the pipeline).
        """
        # The driving input and output directory are fixed for every request;
        # paste-back and cropping are both switched off.
        run_args = ArgumentConfig(
            source_image=image,
            driving_info="assets/examples/driving/d0.mp4",
            output_dir="/tmp/",
            flag_pasteback=False,
            flag_do_crop=False,
        )
        pipeline_result = self.live_portrait_pipeline.execute(run_args)
        # Only the first of the two returned values is used.
        video_path, _ = pipeline_result
        return Path(video_path)