mirror of
https://github.com/KwaiVGI/LivePortrait.git
synced 2025-03-14 21:22:43 +00:00
add server service
This commit is contained in:
parent
50fa540d9f
commit
c2c099f541
41
server/Dockerfile
Normal file
41
server/Dockerfile
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
ARG CUDA_VERSION=12.4.1
|
||||||
|
# **NOTE**: For 4090, change to 12.1.0
|
||||||
|
# For H100, change to 12.4.1
|
||||||
|
|
||||||
|
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
|
||||||
|
|
||||||
|
LABEL authors="luchentech"
|
||||||
|
|
||||||
|
# ====== Install base dependencies ===================
|
||||||
|
ARG PYTHON_VERSION=3.10
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
RUN apt-get update -y && \
|
||||||
|
apt-get install -y software-properties-common && \
|
||||||
|
add-apt-repository ppa:deadsnakes/ppa && \
|
||||||
|
apt-get update -y && \
|
||||||
|
apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv && \
|
||||||
|
update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 && \
|
||||||
|
update-alternatives --install /usr/bin/python python /usr/bin/python${PYTHON_VERSION} 1
|
||||||
|
|
||||||
|
# install libgl1 for import error described in https://stackoverflow.com/questions/55313610/
|
||||||
|
# Alternative solution: pip install opencv-python-headless
|
||||||
|
RUN apt-get update -y && \
|
||||||
|
apt-get install -y python3-pip git curl ffmpeg && \
|
||||||
|
apt-get clean && \
|
||||||
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
RUN python3 --version && python --version
|
||||||
|
|
||||||
|
# Install pip s.t. it will be compatible with our PYTHON_VERSION
|
||||||
|
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION}
|
||||||
|
|
||||||
|
WORKDIR /workspace
|
||||||
|
|
||||||
|
RUN python3 -m pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121
|
||||||
|
|
||||||
|
RUN git clone https://github.com/KwaiVGI/LivePortrait.git && \
|
||||||
|
cd LivePortrait && \
|
||||||
|
python3 -m pip install -r requirements.txt
|
||||||
|
|
||||||
|
RUN python3 -m pip install --no-cache-dir onnxruntime-gpu==1.17.1 --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple
|
12
server/client.py
Normal file
12
server/client.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
mapping = {
|
||||||
|
"src_key": "../assets/examples/source/s5.jpg",
|
||||||
|
"driving_key": "../assets/examples/driving/d13.mp4",
|
||||||
|
}
|
||||||
|
|
||||||
|
import requests, json
|
||||||
|
|
||||||
|
base_url = "http://localhost:8005"
|
||||||
|
headers = {"accept": "application/json", "Content-Type": "application/json"}
|
||||||
|
|
||||||
|
response = requests.post(f"{base_url}/submit", headers=headers, json=mapping)
|
||||||
|
print(response)
|
190
server/consumer.py
Normal file
190
server/consumer.py
Normal file
@ -0,0 +1,190 @@
|
|||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import redis
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import sys
|
||||||
|
|
||||||
|
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
sys.path.insert(0, parent_dir)
|
||||||
|
|
||||||
|
from src.config.argument_config import ArgumentConfig
|
||||||
|
from src.config.inference_config import InferenceConfig
|
||||||
|
from src.config.crop_config import CropConfig
|
||||||
|
from src.live_portrait_pipeline import LivePortraitPipeline
|
||||||
|
|
||||||
|
OUTPUT_LOCAL_PATH = os.getenv("OUTPUT_LOCAL_PATH", "./tmp/outputs/")
|
||||||
|
|
||||||
|
|
||||||
|
# aioredis
|
||||||
|
# redis
|
||||||
|
|
||||||
|
|
||||||
|
import tyro
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
TASK_PREFIX = "lp-task-" # Prefix for task keys, used for retrieving data/status
|
||||||
|
TASK_TAG = "lp"
|
||||||
|
STREAM_NAME_PREFIX = os.getenv("REDIS_STREAM_PREFIX", "task_stream") + "_"
|
||||||
|
CONSUMER_GROUP_PREFIX = os.getenv("REDIS_GROUP_PREFIX", "task_group") + "_"
|
||||||
|
|
||||||
|
LOCK_KEY_PREFIX = os.getenv("LOCK_KEY_PREFIX", "lock_key") + "_" + TASK_TAG
|
||||||
|
CONSUMER_NAME = os.getenv("HOSTNAME", "test_consumer_lp")
|
||||||
|
|
||||||
|
live_portrait_engine = None
|
||||||
|
r = None
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="LivePortrait Redis Task Worker")
|
||||||
|
parser.add_argument("--redis-host", type=str, default="43.156.39.249", help="Redis host")
|
||||||
|
parser.add_argument("--redis-port", type=int, default=31317, help="Redis port")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def partial_fields(target_class, kwargs):
|
||||||
|
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
|
||||||
|
|
||||||
|
|
||||||
|
def init_live_portrait_pipeline():
|
||||||
|
# fast_check_args(args)
|
||||||
|
|
||||||
|
# specify configs for inference
|
||||||
|
inference_cfg = partial_fields(InferenceConfig, args.__dict__)
|
||||||
|
crop_cfg = partial_fields(CropConfig, args.__dict__)
|
||||||
|
|
||||||
|
global live_portrait_pipeline
|
||||||
|
live_portrait_pipeline = LivePortraitPipeline(
|
||||||
|
inference_cfg=inference_cfg,
|
||||||
|
crop_cfg=crop_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
return live_portrait_pipeline
|
||||||
|
|
||||||
|
|
||||||
|
def task_worker(
|
||||||
|
stream_name: str,
|
||||||
|
consumer_group: str,
|
||||||
|
consumer_name: str,
|
||||||
|
task_prefix: str = "lp-task-",
|
||||||
|
lock_key_prefix: str = "lock_key_lp",
|
||||||
|
lock_timeout: int = 30,
|
||||||
|
):
|
||||||
|
logger.info(f"Starting task worker {consumer_name} in {consumer_group} for stream {stream_name}...")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
logger.info(f"Listening for tasks in {stream_name}...")
|
||||||
|
|
||||||
|
tasks = r.xreadgroup(consumer_group, consumer_name, {stream_name: ">"}, count=1, block=0)
|
||||||
|
|
||||||
|
if tasks:
|
||||||
|
for stream_name, task_list in tasks:
|
||||||
|
for task_id, task_data in task_list:
|
||||||
|
|
||||||
|
serving_task_id = task_data.get("request_id")
|
||||||
|
lock_key = lock_key_prefix + serving_task_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 目前部署的redis - keydb是集群模式兼容单服务,存在多次消费,需要加锁
|
||||||
|
# Try to acquire the lock
|
||||||
|
lock_acquired = r.set(lock_key, consumer_name, nx=True, ex=lock_timeout)
|
||||||
|
if not lock_acquired:
|
||||||
|
logger.info(f"Failed to acquire lock {lock_key}. Skipping...")
|
||||||
|
continue
|
||||||
|
logger.info(
|
||||||
|
f"Acquired lock {lock_key} (expiring in {lock_timeout} secs), processing tasks from stream {stream_name}"
|
||||||
|
f"on {r.connection_pool.connection_kwargs['host']}:{r.connection_pool.connection_kwargs['port']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
occupied_by_pod = r.hmget(task_prefix + serving_task_id, "occupied_by_pod")
|
||||||
|
if not occupied_by_pod:
|
||||||
|
logger.warning(
|
||||||
|
f"Task occupied_by_pod of {serving_task_id} not found. Processing the task anyway."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
occupied_by_pod = occupied_by_pod[0]
|
||||||
|
if occupied_by_pod and occupied_by_pod != consumer_name:
|
||||||
|
logger.error(
|
||||||
|
f"Task {serving_task_id} is already being processed by another consumer: {occupied_by_pod}. Skipping..."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
# NOTE Important: mark the task as occupied by the current consumer
|
||||||
|
r.hmset(task_prefix + serving_task_id, {"occupied_by_pod": consumer_name})
|
||||||
|
logger.info(f"Set task {serving_task_id} as occupied by {consumer_name}")
|
||||||
|
|
||||||
|
task_payload = r.hgetall(task_prefix + serving_task_id)
|
||||||
|
logger.info(f"Processing task {serving_task_id}, task payload: {task_payload}")
|
||||||
|
if not task_payload:
|
||||||
|
raise RuntimeError(f"Task payload of {serving_task_id} not found")
|
||||||
|
|
||||||
|
# # decoded_data = {key.decode('utf-8'): value.decode('utf-8') for key, value in task_payload.items()}
|
||||||
|
# task_payload = task_payload.get("payload") # actual payload from the request
|
||||||
|
# # res = await process_task(task_payload)
|
||||||
|
# # if not res:
|
||||||
|
# # raise RuntimeError(f"Error processing task {serving_task_id}")
|
||||||
|
# logger.info(f"Processing task {serving_task_id} with payload {task_payload}...")
|
||||||
|
args = tyro.cli(ArgumentConfig)
|
||||||
|
args.source = task_payload.get("src_key")
|
||||||
|
args.driving = task_payload.get("driving_key")
|
||||||
|
args.output_dir = OUTPUT_LOCAL_PATH
|
||||||
|
wfp, wfp_concat = live_portrait_pipeline.execute(args)
|
||||||
|
print(wfp, wfp_concat)
|
||||||
|
live_portrait_engine
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{e}")
|
||||||
|
# Callback with failed status
|
||||||
|
finally:
|
||||||
|
# NOTE: messages are always acknowledged and deleted after the task finished or failed;
|
||||||
|
# We won't add back the task to the stream if it's failed
|
||||||
|
r.xack(stream_name, consumer_group, task_id)
|
||||||
|
r.xdel(stream_name, task_id)
|
||||||
|
|
||||||
|
if r.get(lock_key) == consumer_name:
|
||||||
|
r.delete(lock_key)
|
||||||
|
logger.info(f"Released lock {lock_key} by consumer {consumer_name}")
|
||||||
|
|
||||||
|
time.sleep(1.0)
|
||||||
|
|
||||||
|
def run_worker(args):
|
||||||
|
|
||||||
|
global live_portrait_engine, r
|
||||||
|
|
||||||
|
redis_host, redis_port = args.redis_host, args.redis_port
|
||||||
|
# task_group_lp
|
||||||
|
consumer_group = CONSUMER_GROUP_PREFIX + TASK_TAG
|
||||||
|
# task_stream_lp
|
||||||
|
stream_name = STREAM_NAME_PREFIX + TASK_TAG
|
||||||
|
|
||||||
|
r = redis.Redis(host=redis_host, port=redis_port, decode_responses=True)
|
||||||
|
logger.info(f"Creating consumer group {consumer_group} for stream {stream_name}")
|
||||||
|
try:
|
||||||
|
r.xgroup_create(stream_name, consumer_group, id="0", mkstream=True)
|
||||||
|
except redis.exceptions.ResponseError as e:
|
||||||
|
if "BUSYGROUP" in str(e):
|
||||||
|
logger.warning(f"Consumer group {consumer_group} already exists")
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
logger.info("Initializing LivePortrait engine...")
|
||||||
|
live_portrait_engine = init_live_portrait_pipeline()
|
||||||
|
logger.info("LivePortrait engine initialized")
|
||||||
|
|
||||||
|
task_worker(
|
||||||
|
stream_name=stream_name,
|
||||||
|
consumer_group=consumer_group,
|
||||||
|
consumer_name=CONSUMER_NAME,
|
||||||
|
task_prefix=TASK_PREFIX,
|
||||||
|
lock_key_prefix=LOCK_KEY_PREFIX)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.run_until_complete(run_worker(args))
|
114
server/proxy_server.py
Normal file
114
server/proxy_server.py
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from fastapi.security import APIKeyHeader
|
||||||
|
from fastapi import FastAPI, Depends
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
import aioredis
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, root_validator
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
|
TASK_PREFIX = "lp-task-" # Prefix for task keys, used for retrieving data/status
|
||||||
|
TASK_TAG = "lp"
|
||||||
|
STREAM_NAME_PREFIX = os.getenv("REDIS_STREAM_PREFIX", "task_stream") + "_"
|
||||||
|
CONSUMER_GROUP_PREFIX = os.getenv("REDIS_GROUP_PREFIX", "task_group") + "_"
|
||||||
|
TASK_STATUS_EXPIRE = int(os.getenv("TASK_STATUS_EXPIRE", 43200)) # 12 hours by default
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
r = None
|
||||||
|
|
||||||
|
async def _init_consumer_group(stream_name: str, group_name: str):
|
||||||
|
try:
|
||||||
|
await r.xgroup_create(stream_name, group_name, id="0", mkstream=True)
|
||||||
|
except aioredis.exceptions.ResponseError as e:
|
||||||
|
if "BUSYGROUP" in str(e):
|
||||||
|
logger.warning(f"Consumer group {group_name} already exists in {stream_name}")
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
global r
|
||||||
|
r = aioredis.Redis(host=args.redis_host, port=args.redis_port, decode_responses=True)
|
||||||
|
|
||||||
|
await _init_consumer_group(STREAM_NAME_PREFIX + TASK_TAG, CONSUMER_GROUP_PREFIX + TASK_TAG)
|
||||||
|
|
||||||
|
yield
|
||||||
|
await r.close()
|
||||||
|
logger.info("Redis connection closed")
|
||||||
|
logger.info("Proxy server shut down")
|
||||||
|
|
||||||
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
|
class LivePortraitRequestProto(BaseModel):
|
||||||
|
'''Define the request schema for the Live Portrait API.
|
||||||
|
Live Portrait needs two inputs to generate the final output:
|
||||||
|
1. The source image/video (src_key or src_local_path)
|
||||||
|
2. The driving image/video (driving_key or driving_local_path)
|
||||||
|
This server will process source input according to the driving input.
|
||||||
|
Some characteristics of the driving input will be transferred and applied to the source input.
|
||||||
|
'''
|
||||||
|
request_id: str = Field('test_request_id_0', description="Unique request ID", min_length=1)
|
||||||
|
src_key: Optional[str] = Field(None, description="Object key of the source image/video from COS", min_length=1)
|
||||||
|
driving_key: Optional[str] = Field(None, description="Object key of the driving image/video from COS", min_length=1)
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def check_either_image_key_or_local_path(cls, values):
|
||||||
|
src_key = values.get("src_key")
|
||||||
|
driving_key = values.get("driving_key")
|
||||||
|
if not src_key or not driving_key:
|
||||||
|
raise ValueError("The 'src_key' and 'driving_key' must be provided")
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/submit")
|
||||||
|
async def submit(req: LivePortraitRequestProto):
|
||||||
|
mapping = {
|
||||||
|
"src_key": req.src_key,
|
||||||
|
"driving_key": req.driving_key,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Task data/status must be set first before adding to the queue
|
||||||
|
task_name = TASK_PREFIX + req.request_id
|
||||||
|
|
||||||
|
logger.info(f"Recving task {task_name} with payload {mapping}")
|
||||||
|
|
||||||
|
n = await r.hset(task_name, mapping=mapping)
|
||||||
|
await r.expire(task_name, TASK_STATUS_EXPIRE)
|
||||||
|
logger.info(f"{n} fields successfully added for {task_name} (expire in {TASK_STATUS_EXPIRE} secs)")
|
||||||
|
|
||||||
|
# Example: task_stream_lp
|
||||||
|
queue_name = STREAM_NAME_PREFIX + TASK_TAG
|
||||||
|
redis_task_id = await r.xadd(queue_name, {"request_id": req.request_id})
|
||||||
|
logger.info(f"Redis task {redis_task_id} added to Redis Stream {queue_name}")
|
||||||
|
|
||||||
|
# trim the Redis stream
|
||||||
|
await r.xtrim(queue_name, maxlen=2000, approximate=True)
|
||||||
|
|
||||||
|
return JSONResponse(content={"request_id": req.request_id, "message_id": redis_task_id})
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="Proxy server for OpenSora serving")
|
||||||
|
parser.add_argument("--host", type=str, default="localhost")
|
||||||
|
parser.add_argument("--port", type=int, default=8000)
|
||||||
|
parser.add_argument("--redis-host", type=str, default="43.156.39.249", help="Redis host")
|
||||||
|
parser.add_argument("--redis-port", type=int, default=31317, help="Redis port")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
uvicorn.run(app, host=args.host, port=args.port)
|
||||||
|
logger.info(f"Proxy server is running on {args.host}:{args.port})")
|
2
server/requirements.txt
Normal file
2
server/requirements.txt
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
aioredis
|
||||||
|
redis
|
156
server/server.py
Normal file
156
server/server.py
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from fastapi.security import APIKeyHeader
|
||||||
|
import os
|
||||||
|
from fastapi import FastAPI, Depends
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
from src.config.argument_config import ArgumentConfig
|
||||||
|
from src.config.inference_config import InferenceConfig
|
||||||
|
from src.config.crop_config import CropConfig
|
||||||
|
from src.live_portrait_pipeline import LivePortraitPipeline
|
||||||
|
from pydantic import BaseModel, Field, root_validator
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
|
import tyro
|
||||||
|
|
||||||
|
# Initialize logger
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# NOTE: For now, we use an API key to authenticate the requests.
|
||||||
|
# This method should only be used for testing purposes.
|
||||||
|
# For actual deployment, we should use OAuth2 or other more secure methods.
|
||||||
|
header_scheme = APIKeyHeader(name="X-API-Key")
|
||||||
|
SERVING_API_KEY = os.getenv("SERVING_API_KEY", "oLjQD5hWDYN5DeAQ4cx5CL3vJYOTXf0c")
|
||||||
|
OUTPUT_LOCAL_PATH = os.getenv("OUTPUT_LOCAL_PATH", "/tmp/outputs/")
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
live_portrait_pipeline = None
|
||||||
|
|
||||||
|
|
||||||
|
class LivePortraitRequestProto(BaseModel):
|
||||||
|
'''Define the request schema for the Live Portrait API.
|
||||||
|
Live Portrait needs two inputs to generate the final output:
|
||||||
|
1. The source image/video (src_key or src_local_path)
|
||||||
|
2. The driving image/video (driving_key or driving_local_path)
|
||||||
|
This server will process source input according to the driving input.
|
||||||
|
Some characteristics of the driving input will be transferred and applied to the source input.
|
||||||
|
'''
|
||||||
|
src_key: Optional[str] = Field(None, description="Object key of the source image/video from COS", min_length=1)
|
||||||
|
src_local_path: Optional[str] = Field(
|
||||||
|
None, description="Local file path of the source image/video to be processed on the server", min_length=1
|
||||||
|
)
|
||||||
|
|
||||||
|
driving_key: Optional[str] = Field(None, description="Object key of the driving image/video from COS", min_length=1)
|
||||||
|
driving_local_path: Optional[str] = Field(
|
||||||
|
None, description="Local file path of the driving image/video to be processed on the server", min_length=1
|
||||||
|
)
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def check_either_image_key_or_local_path(cls, values):
|
||||||
|
src_key = values.get("src_key")
|
||||||
|
src_local_path = values.get("src_local_path")
|
||||||
|
driving_key = values.get("driving_key")
|
||||||
|
driving_local_path = values.get("driving_local_path")
|
||||||
|
if (src_key and src_local_path) or (not src_key and not src_local_path):
|
||||||
|
raise ValueError("Either 'src_key' or 'src_local_path' must be provided, but not both.")
|
||||||
|
if (driving_key and driving_local_path) or (not driving_key and not driving_local_path):
|
||||||
|
raise ValueError("Either 'driving_key' or 'driving_local_path' must be provided, but not both.")
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
def init_live_portrait_pipeline():
|
||||||
|
# fast_check_args(args)
|
||||||
|
|
||||||
|
# specify configs for inference
|
||||||
|
inference_cfg = partial_fields(InferenceConfig, args.__dict__)
|
||||||
|
crop_cfg = partial_fields(CropConfig, args.__dict__)
|
||||||
|
|
||||||
|
global live_portrait_pipeline
|
||||||
|
live_portrait_pipeline = LivePortraitPipeline(
|
||||||
|
inference_cfg=inference_cfg,
|
||||||
|
crop_cfg=crop_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def partial_fields(target_class, kwargs):
|
||||||
|
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
|
||||||
|
|
||||||
|
|
||||||
|
def extract_inputs_from_request(request: LivePortraitRequestProto) -> Tuple[str, str]:
|
||||||
|
|
||||||
|
src_input = None
|
||||||
|
driving_input = None
|
||||||
|
|
||||||
|
# 1. Download the image/video from COS to local storage
|
||||||
|
if request.src_key:
|
||||||
|
storage_dir = "/tmp/sources"
|
||||||
|
filename = os.path.basename(request.src_key)
|
||||||
|
|
||||||
|
if request.driving_key:
|
||||||
|
storage_dir = "/tmp/driving"
|
||||||
|
filename = os.path.basename(request.driving_key)
|
||||||
|
|
||||||
|
os.makedirs(storage_path, exist_ok=True) # Ensure the directory exists
|
||||||
|
storage_path = os.path.join(storage_dir, filename)
|
||||||
|
# Download the image/video from COS
|
||||||
|
# download_success = await cos_download_file(request.src_key, src_local_path)
|
||||||
|
download_success = True
|
||||||
|
if not download_success:
|
||||||
|
logger.error(f"Failed to download image: {request.image_key}")
|
||||||
|
return (None, None)
|
||||||
|
|
||||||
|
# 2. Process the local image/video
|
||||||
|
# This step is just for local tests
|
||||||
|
if request.src_local_path:
|
||||||
|
logger.warning("Using local files is not intended for deployment. This is for testing purposes only.")
|
||||||
|
local_image_path = request.local_path
|
||||||
|
if not os.path.exists(local_image_path):
|
||||||
|
logger.error(f"Local image path does not exist: {local_image_path}")
|
||||||
|
return JSONResponse({"message": "Local image path does not exist", "faces_detected": 0}, status_code=500)
|
||||||
|
|
||||||
|
if request.driving_local_path:
|
||||||
|
logger.warning("Using local files is not intended for deployment. This is for testing purposes only.")
|
||||||
|
local_image_path = request.local_path
|
||||||
|
if not os.path.exists(local_image_path):
|
||||||
|
logger.error(f"Local image path does not exist: {local_image_path}")
|
||||||
|
return JSONResponse({"message": "Local image path does not exist", "faces_detected": 0}, status_code=500)
|
||||||
|
|
||||||
|
@app.post("/live-portrait")
|
||||||
|
def live_portrait(request: LivePortraitRequestProto, api_key: str = Depends(header_scheme)):
|
||||||
|
if live_portrait_pipeline is None:
|
||||||
|
return JSONResponse({"message": "Server not ready", "output_path": ''}, status_code=503)
|
||||||
|
if api_key != SERVING_API_KEY:
|
||||||
|
return JSONResponse({"message": "Invalid API key", "output_path": ''}, status_code=401)
|
||||||
|
|
||||||
|
src_input = None
|
||||||
|
driving_input = None
|
||||||
|
|
||||||
|
src_input, driving_input = extract_inputs_from_request(request)
|
||||||
|
if src_input is None or driving_input is None:
|
||||||
|
return JSONResponse({"message": "Failed to process inputs", "output_path": ''}, status_code=500)
|
||||||
|
|
||||||
|
args = tyro.cli(ArgumentConfig)
|
||||||
|
args.source = src_input
|
||||||
|
args.driving = driving_input
|
||||||
|
args.output_dir = OUTPUT_LOCAL_PATH
|
||||||
|
wfp, wfp_concat = live_portrait_pipeline.execute(src_input, driving_input)
|
||||||
|
print(wfp, wfp_concat)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Live Portrait Server")
|
||||||
|
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
|
||||||
|
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# init_cos_client()
|
||||||
|
# logger.info("COS client initialized")
|
||||||
|
|
||||||
|
init_live_portrait_pipeline()
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
uvicorn.run(app, host=args.host, port=args.port)
|
Loading…
Reference in New Issue
Block a user