add docstrings, refactor logger

This commit is contained in:
hiyouga
2024-09-08 00:56:56 +08:00
parent 8eac1b929f
commit 54c6905937
30 changed files with 334 additions and 57 deletions

View File

@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from contextlib import asynccontextmanager
from functools import partial
from typing import Optional
from typing_extensions import Annotated
@@ -50,15 +52,24 @@ if is_uvicorn_available():
import uvicorn
async def sweeper() -> None:
    """Call ``torch_gc()`` in an endless loop, pausing 300 seconds between calls.

    The coroutine never returns on its own; it runs until the task wrapping it
    is cancelled.
    """
    pause_seconds = 300
    while True:
        torch_gc()
        await asyncio.sleep(pause_seconds)
@asynccontextmanager
async def lifespan(app: "FastAPI", chat_model: "ChatModel"):  # collects GPU memory
    """Application lifespan hook that manages periodic and final ``torch_gc()`` calls.

    On startup, a background ``sweeper()`` task is launched when
    ``chat_model.engine_type`` is ``"huggingface"``. On shutdown, that task
    (if any) is cancelled and ``torch_gc()`` is called once more.

    Args:
        app: The application instance handed in by the framework (unused here).
        chat_model: The chat model whose ``engine_type`` decides whether the
            background sweeper is started.
    """
    sweeper_task = None
    if chat_model.engine_type == "huggingface":
        # Hold a strong reference: the event loop only keeps weak references to
        # tasks, so an unreferenced task may be garbage-collected mid-execution.
        sweeper_task = asyncio.create_task(sweeper())

    try:
        yield
    finally:
        # Stop the endless sweeper so it is not left pending when the loop closes.
        if sweeper_task is not None:
            sweeper_task.cancel()

        torch_gc()
def create_app(chat_model: "ChatModel") -> "FastAPI":
root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
app = FastAPI(lifespan=lifespan, root_path=root_path)
app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
@@ -66,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
allow_methods=["*"],
allow_headers=["*"],
)
api_key = os.environ.get("API_KEY")
api_key = os.environ.get("API_KEY", None)
security = HTTPBearer(auto_error=False)
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
@@ -80,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
dependencies=[Depends(verify_api_key)],
)
async def list_models():
model_card = ModelCard(id="gpt-3.5-turbo")
model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
return ModelList(data=[model_card])
@app.post(