Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ venv/
ENV/
env.bak/
venv.bak/
.env

# Spyder project settings
.spyderproject
Expand Down
99 changes: 92 additions & 7 deletions app/api/routers/generative.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from fastapi import APIRouter, Depends, Request, Body, Query
from fastapi.encoders import jsonable_encoder
from fastapi.responses import PlainTextResponse, StreamingResponse, JSONResponse
from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR
from starlette.status import (
HTTP_200_OK,
HTTP_400_BAD_REQUEST,
HTTP_500_INTERNAL_SERVER_ERROR,
HTTP_404_NOT_FOUND,
)
from app.domain import (
Tags,
TagsGenerative,
Expand All @@ -35,6 +40,7 @@
PATH_CHAT_COMPLETIONS = "/v1/chat/completions"
PATH_COMPLETIONS = "/v1/completions"
PATH_EMBEDDINGS = "/v1/embeddings"
PATH_MODELS = "/v1/models"

router = APIRouter()
config = get_settings()
Expand Down Expand Up @@ -200,7 +206,12 @@ def generate_chat_completions(
max_tokens = request_data.max_tokens
temperature = request_data.temperature
top_p = request_data.top_p
stop_sequences = request_data.stop_sequences
if isinstance(request_data.stop, str):
stop_sequences = [request_data.stop]
elif isinstance(request_data.stop, list):
stop_sequences = request_data.stop
else:
stop_sequences = []
tracking_id = tracking_id or str(uuid.uuid4())

if not messages:
Expand Down Expand Up @@ -337,12 +348,11 @@ def generate_text_completions(
max_tokens = request_data.max_tokens
temperature = request_data.temperature
top_p = request_data.top_p
stop = request_data.stop

if isinstance(stop, str):
stop_sequences = [stop]
elif isinstance(stop, list):
stop_sequences = stop
if isinstance(request_data.stop, str):
stop_sequences = [request_data.stop]
elif isinstance(request_data.stop, list):
stop_sequences = request_data.stop
else:
stop_sequences = []

Expand Down Expand Up @@ -534,6 +544,81 @@ def embed_texts(
)


@router.get(
    PATH_MODELS,
    tags=[Tags.OpenAICompatible],
    dependencies=[Depends(cms_globals.props.current_active_user)],
    description="List available models, similar to OpenAI's /v1/models endpoint",
)
def list_models(
    model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
) -> JSONResponse:
    """
    Lists all available models, mimicking OpenAI's /v1/models endpoint.

    Args:
        model_service (AbstractModelService): The model service dependency.

    Returns:
        JSONResponse: A response containing the list of models.
    """
    # Spaces are replaced with underscores so the ID is URL-safe and matches
    # the exact-name comparison done by the /v1/models/{model_name} endpoint.
    model_id = model_service.model_name.replace(" ", "_")
    response = {
        "object": "list",
        "data": [
            {
                "id": model_id,
                "object": "model",
                "created": 0,
                "owned_by": "cms",
                # Extra fields kept consistent with the single-model endpoint
                # so clients see the same model shape in both places.
                "permission": [],
                "root": model_id,
                "parent": None,
            }
        ],
    }
    return JSONResponse(content=response)


@router.get(
    PATH_MODELS + "/{model_name}",
    tags=[Tags.OpenAICompatible],
    dependencies=[Depends(cms_globals.props.current_active_user)],
    description="Get a specific model, similar to OpenAI's /v1/models/{model_id} endpoint",
)
def get_model(
    model_name: str,
    model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
) -> JSONResponse:
    """
    Gets a specific model by ID, mimicking OpenAI's /v1/models/{model_id} endpoint.

    Args:
        model_name (str): The model name to retrieve.
        model_service (AbstractModelService): The model service dependency.

    Returns:
        JSONResponse: A response containing the model details, or a 404 error
        payload when the requested name does not match the served model.
    """
    # The served model is identified by its name with spaces replaced by
    # underscores, matching the IDs exposed by the /v1/models listing.
    served_model_id = model_service.model_name.replace(" ", "_")

    if model_name == served_model_id:
        model_card = {
            "id": model_name,
            "object": "model",
            "created": 0,
            "owned_by": "cms",
            "permission": [],
            "root": model_name,
            "parent": None,
        }
        return JSONResponse(content=model_card)

    # Unknown model: mirror OpenAI's error envelope and status code.
    not_found = {
        "error": {
            "message": f"The model `{model_name}` does not exist",
            "type": "invalid_request_error",
            "param": None,
            "code": "model_not_found",
        }
    }
    return JSONResponse(content=not_found, status_code=HTTP_404_NOT_FOUND)


def _empty_prompt_error() -> Iterable[str]:
yield "ERROR: No prompt text provided\n"

Expand Down
1 change: 1 addition & 0 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class Settings(BaseSettings): # type: ignore
HF_PIPELINE_AGGREGATION_STRATEGY: str = "simple" # the strategy used for aggregating the predictions of the Hugging Face NER model
LOG_PER_CONCEPT_ACCURACIES: str = "false" # if "true", per-concept accuracies will be exposed to the metrics scrapper. Switch this on with caution due to the potentially high number of concepts
MEDCAT2_MAPPED_ONTOLOGIES: str = "" # the comma-separated names of ontologies for MedCAT2 to map to
ENABLE_SPDA_ATTN: str = "true" # if "true", attempt to use SPDA attention for HuggingFace LLM loading
DEBUG: str = "false" # if "true", the debug mode is switched on

class Config:
Expand Down
7 changes: 5 additions & 2 deletions app/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,10 @@ class OpenAIChatCompletionsRequest(BaseModel):
model: str = Field(..., description="The name of the model used for generating the completion")
temperature: float = Field(0.7, description="The temperature of the generated text", ge=0.0, le=1.0)
top_p: float = Field(0.9, description="The top-p value for nucleus sampling", ge=0.0, le=1.0)
stop_sequences: Optional[List[str]] = Field(default=None, description="The list of sequences used to stop the generation")
stop: Optional[Union[str, List[str]]] = Field(
default=None,
description="The single sequence or the list of sequences used to stop the generation",
)


class OpenAIChatCompletionsResponse(BaseModel):
Expand All @@ -242,7 +245,7 @@ class OpenAICompletionsRequest(BaseModel):
top_p: float = Field(0.9, description="The top-p value for nucleus sampling", ge=0.0, le=1.0)
stop: Optional[Union[str, List[str]]] = Field(
default=None,
description="The list of sequences used to stop the generation",
description="The single sequence or the list of sequences used to stop the generation",
)


Expand Down
3 changes: 3 additions & 0 deletions app/envs/.env
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,8 @@ TRAINING_HF_TAGGING_SCHEME=flat
# The comma-separated names of ontologies for MedCAT2 to map to
MEDCAT2_MAPPED_ONTOLOGIES=opcs4,icd10

# If "true", attempt to use SDPA (scaled dot-product attention) for Hugging Face LLM loading (NOTE: the variable name misspells "SDPA" as "SPDA")
ENABLE_SPDA_ATTN=true

# If "true", the debug mode is switched on
DEBUG=false
6 changes: 3 additions & 3 deletions app/mcp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ cms mcp run --transport sse
"mcp-remote",
"http://127.0.0.1:8080/sse",
"--header",
"X-API-Key:${AUTH_HEADER}"
"AUTHORIZATION:${AUTH_HEADER}"
],
"env": {
"AUTH_HEADER": "Bearer <ACCESS_TOKEN>"
Expand Down Expand Up @@ -123,7 +123,7 @@ cms mcp run --transport sse
| `CMS_ACCESS_TOKEN` | Empty | Bearer token for ModelServe API |
| `CMS_API_KEY` | `Bearer` | API key for ModelServe API |
| `CMS_MCP_API_KEYS` | None | Comma-separated API keys for authentication |
| `CMS_MCP_OAUTH_ENABLED` | `false` | Enable OAuth authentication |
| `CMS_MCP_OAUTH_PROVIDER` | Empty | Enable OAuth authentication if set to "github" or "google" |
| `CMS_MCP_BASE_URL` | `http://<host>:<port>` | Base URL for OAuth callback |
| `CMS_MCP_DEV` | `0` | Run in development mode |

Expand All @@ -137,7 +137,7 @@ When `CMS_MCP_API_KEYS` is set, clients must authenticate using:
- **Header**: `X-API-Key: your-key`

### 2. OAuth Authentication (SSE Transport)
When `CMS_MCP_OAUTH_ENABLED=true`, the server provides a built-in OAuth 2.0 login flow for SSE transport.
When `CMS_MCP_OAUTH_PROVIDER` is set, the server provides a built-in OAuth 2.0 login flow for SSE transport.

**OAuth Endpoints:**
- `/oauth/login` - Login page with Google and GitHub options
Expand Down
Loading
Loading