diff --git a/medcat-trainer/docs/installation.md b/medcat-trainer/docs/installation.md index 47fbed688..c97975135 100644 --- a/medcat-trainer/docs/installation.md +++ b/medcat-trainer/docs/installation.md @@ -74,6 +74,7 @@ Host-level Compose variables (for example port overrides) can be set by copying | `SOLR_PORT` | Host port for Solr admin (default `8983`). | | `MEDCAT_CONFIG_FILE` | MedCAT config file path inside the container. | | `LOAD_EXAMPLES` | Load example model pack + dataset + project on startup (`1`/`0`). | +| `PROVISIONING_CONFIG_PATH` | Path to a YAML file defining the projects to create on startup. | | `REMOTE_MODEL_SERVICE_TIMEOUT` | Timeout (seconds) for remote model-service calls. | | `MCTRAINER_BOOTSTRAP_ADMIN_USERNAME` | Bootstrap admin username (default `admin`). | | `MCTRAINER_BOOTSTRAP_ADMIN_EMAIL` | Bootstrap admin email. | diff --git a/medcat-trainer/docs/provisioning.md b/medcat-trainer/docs/provisioning.md new file mode 100644 index 000000000..1ff9433f4 --- /dev/null +++ b/medcat-trainer/docs/provisioning.md @@ -0,0 +1,54 @@ +# Provisioning guide + +On startup, MedCAT Trainer can create example projects, datasets, and (optionally) model packs from a YAML config. The provisioner runs after the API is up. + +!!! warning + Provisioning only takes place if there are no preexisting projects/datasets/model packs. + +## Environment variables + +| Variable | Required | Description | +|----------|----------|-------------| +| `LOAD_EXAMPLES` | Yes (to enable) | Set to `1` (`true`, `t` and `y` are also accepted) to provision on startup. | +| `API_URL` | No | Base API URL (e.g. `http://localhost:8001/api/`). Default: `http://localhost:8001/api/`. | +| `PROVISIONING_CONFIG_PATH` | No | Path to the provisioning YAML file. Default: `scripts/provisioning/example_projects.provisioning.yaml`. | + + +## YAML format + +Top-level key is `projects`, a list of project specs. Each item is either a model-pack project or a remote-model-service project. 
+ +### Option 1: Model pack (upload a .zip) + +```yaml +projects: + - modelPack: + name: "Example Model Pack" + url: "https://example.com/path/to/model_pack.zip" + dataset: + name: "My Dataset" + url: "https://example.com/dataset.csv" + description: "Short description of the dataset" + project: + name: "Example Project" + description: "Project description" + annotationGuidelineLink: "https://example.com/guidelines" +``` + +### Option 2: Remote MedCAT service (no model pack) + +Use a remote MedCAT service API for document processing instead of uploading a model pack. Set `useModelService` and `modelServiceUrl` on the **project** object; do **not** set `modelPack` on the spec. + +```yaml +projects: + - dataset: + name: "My Dataset" + url: "https://example.com/dataset.csv" + description: "Short description" + project: + name: "Example Project - Remote" + description: "Uses remote MedCAT service" + annotationGuidelineLink: "https://example.com/guidelines" + useModelService: true + modelServiceUrl: "http://medcat-service:8000" +``` diff --git a/medcat-trainer/envs/env b/medcat-trainer/envs/env index 57dcfb08f..3206d1c42 100644 --- a/medcat-trainer/envs/env +++ b/medcat-trainer/envs/env @@ -18,7 +18,7 @@ DEBUG=1 ### Load example CDB, Vocab ### LOAD_EXAMPLES=1 # URL that examples will be sent to -API_URL=http://localhost:8001/api/ +API_URL=http://localhost:8000/api/ ### Dataset conf ### UNIQUE_DOC_NAMES_IN_DATASETS=True diff --git a/medcat-trainer/mkdocs.yml b/medcat-trainer/mkdocs.yml index feb812191..6846ea62d 100644 --- a/medcat-trainer/mkdocs.yml +++ b/medcat-trainer/mkdocs.yml @@ -46,6 +46,7 @@ nav: - Reference: - Advanced usage: advanced_usage.md - Maintenance: maintenance.md + - Provisioning: provisioning.md - Client API: client.md plugins: diff --git a/medcat-trainer/webapp/api/api/tests/test_load_examples.py b/medcat-trainer/webapp/api/api/tests/test_load_examples.py new file mode 100644 index 000000000..9d4eeec09 --- /dev/null +++ 
b/medcat-trainer/webapp/api/api/tests/test_load_examples.py @@ -0,0 +1,214 @@ +import os +import sys +import tempfile +from contextlib import contextmanager +from pathlib import Path +import requests +from django.contrib.auth.models import User +from django.test import LiveServerTestCase, TestCase + +# Allow importing webapp/scripts +WEBAPP_DIR = Path(__file__).resolve().parents[2].parent # api/tests -> api -> api -> webapp +if str(WEBAPP_DIR) not in sys.path: + sys.path.insert(0, str(WEBAPP_DIR)) + +# GitHub permalinks for test data (raw content). During CI this test runs in a Docker container, so it doesn't have local access to these files. +CARDIO_CSV_URL = "https://raw.githubusercontent.com/CogStack/cogstack-nlp/051edf6cbd94fa83436fab807aff49d78dd68e59/medcat-trainer/notebook_docs/example_data/cardio.csv" +MODEL_PACK_ZIP_URL = "https://raw.githubusercontent.com/CogStack/cogstack-nlp/051edf6cbd94fa83436fab807aff49d78dd68e59/medcat-service/models/examples/example-medcat-v2-model-pack.zip" + +from scripts.load_examples import main, run_provisioning # noqa: E402 +from scripts.provisioning.model import ( # noqa: E402 + DatasetSpec, + ModelPackSpec, + ProjectSpec, + ProvisioningConfig, + ProvisioningProjectSpec, +) + + +def get_medcat_trainer_token(api_url: str, username: str = "admin", password: str = "admin") -> str: + """Get a DRF token for the MedCAT trainer API.""" + resp = requests.post( + f"{api_url}api-token-auth/", + json={"username": username, "password": password}, + ) + resp.raise_for_status() + return resp.json()["token"] + + +def get_project_list(api_url: str) -> list[dict]: + """Return list of projects from project-annotate-entities.""" + token = get_medcat_trainer_token(api_url) + resp = requests.get( + f"{api_url}project-annotate-entities/", + headers={"Authorization": f"Token {token}"}, + ) + resp.raise_for_status() + return resp.json()["results"] + + +@contextmanager +def provisioning_temp_files(): + """Yield (model_pack_path, dataset_path) and unlink both 
on exit.""" + mp = tempfile.NamedTemporaryFile(suffix=".zip", delete=False) + mp.close() + ds = tempfile.NamedTemporaryFile(suffix=".csv", delete=False) + ds.close() + try: + yield mp.name, ds.name + finally: + Path(mp.name).unlink(missing_ok=True) + Path(ds.name).unlink(missing_ok=True) + + +@contextmanager +def env_set(**kwargs: str): + """Set os.environ keys; restore previous values on exit.""" + orig = {k: os.environ.get(k) for k in kwargs} + try: + for k, v in kwargs.items(): + os.environ[k] = v + yield + finally: + for k in orig: + prev = orig[k] + if prev is None: + os.environ.pop(k, None) + else: + os.environ[k] = prev + + +class LoadExamplesTestCase(TestCase): + """Minimal test that load_examples.main can be imported and run.""" + + def test_main_returns_when_load_examples_disabled(self): + with env_set(LOAD_EXAMPLES="0"): + main() + + +class LoadExamplesLiveAPITestCase(LiveServerTestCase): + """ + Run the live server and call load_examples.main against it. + Sets API_URL to self.live_server_url + '/api/' so the script hits this test's server. 
+ """ + + def setUp(self): + super().setUp() + User.objects.create_user(username="admin", password="admin", is_staff=True) + + def test_main_calls_live_api(self): + api_url = self.live_server_url + "/api/" + # Use a temp YAML that points at GitHub permalinks so main() downloads without mocking + config = ProvisioningConfig( + projects=[ + ProvisioningProjectSpec( + model_pack=ModelPackSpec(name="Example Model Pack", url=MODEL_PACK_ZIP_URL), + dataset=DatasetSpec( + name="M-IV_NeuroNotes", + url=CARDIO_CSV_URL, + description="Clinical texts from MIMIC-IV", + ), + project=ProjectSpec( + name="Example Project - Model Pack (Diseases / Symptoms / Findings)", + description="Example project", + annotation_guideline_link="https://example.com/guide", + ), + ), + ], + ) + spec = config.projects[0] + assert spec.model_pack is not None + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + # Write YAML with GitHub permalink URLs (camelCase keys) + f.write("projects:\n") + f.write(" - modelPack:\n") + f.write(f' name: "{spec.model_pack.name}"\n') + f.write(f' url: "{MODEL_PACK_ZIP_URL}"\n') + f.write(" dataset:\n") + f.write(f' name: "{spec.dataset.name}"\n') + f.write(f' url: "{CARDIO_CSV_URL}"\n') + f.write(f' description: "{spec.dataset.description}"\n') + f.write(" project:\n") + f.write(f' name: "{spec.project.name}"\n') + f.write(f' description: "{spec.project.description}"\n') + f.write(f' annotationGuidelineLink: "{spec.project.annotation_guideline_link}"\n') + config_path = f.name + try: + with env_set(API_URL=api_url, LOAD_EXAMPLES="1", PROVISIONING_CONFIG_PATH=config_path): + with provisioning_temp_files() as (mp_path, ds_path): + main(model_pack_tmp_file=mp_path, dataset_tmp_file=ds_path) + + projects = get_project_list(api_url) + self.assertIn( + spec.project.name, + [p["name"] for p in projects], + f"Project list: {[p['name'] for p in projects]}", + ) + finally: + Path(config_path).unlink(missing_ok=True) + + +def 
_spec_with_model_pack(project_name: str, model_pack_url: str, dataset_url: str) -> ProvisioningProjectSpec: + return ProvisioningProjectSpec( + model_pack=ModelPackSpec(name="Test Model Pack", url=model_pack_url), + dataset=DatasetSpec(name="TestDataset", url=dataset_url, description="Test dataset"), + project=ProjectSpec( + name=project_name, + description="Created from unit test (model pack).", + annotation_guideline_link="https://example.com/guide", + ), + ) + + +def _spec_with_remote_service(project_name: str, model_service_url: str, dataset_url: str) -> ProvisioningProjectSpec: + return ProvisioningProjectSpec( + dataset=DatasetSpec(name="RemoteDataset", url=dataset_url, description="Dataset for remote model test"), + project=ProjectSpec( + name=project_name, + description="Created from unit test (remote model service).", + annotation_guideline_link="https://example.com/guide", + use_model_service=True, + model_service_url=model_service_url, + ), + ) + + +class RunProvisioningWithConfigTestCase(LiveServerTestCase): + """ + Tests that call run_provisioning() with a programmatic ProvisioningConfig + (no YAML file). Use the live server; nothing is mocked, so the model pack and dataset are downloaded from the GitHub permalinks. 
+ """ + + def setUp(self): + super().setUp() + User.objects.create_user(username="admin", password="admin", is_staff=True) + + def test_run_provisioning_with_model_pack_creates_project(self): + """ProvisioningConfig with model pack: download from GitHub permalinks, assert project is created.""" + api_url = self.live_server_url + "/api/" + project_name = "Unit Test Project (Model Pack)" + + config = ProvisioningConfig(projects=[_spec_with_model_pack(project_name, MODEL_PACK_ZIP_URL, CARDIO_CSV_URL)]) + with provisioning_temp_files() as (mp_path, ds_path): + run_provisioning(config, api_url, model_pack_tmp_file=mp_path, dataset_tmp_file=ds_path) + + projects = get_project_list(api_url) + self.assertIn(project_name, [p["name"] for p in projects], f"Project list: {[p['name'] for p in projects]}") + + def test_run_provisioning_with_model_service_url_creates_project(self): + """ProvisioningConfig with use_model_service=True: dataset from GitHub permalink, assert project is created.""" + api_url = self.live_server_url + "/api/" + project_name = "Unit Test Project (Remote Model Service)" + model_service_url = "http://medcat-service:8000" + + config = ProvisioningConfig( + projects=[_spec_with_remote_service(project_name, model_service_url, CARDIO_CSV_URL)] + ) + with provisioning_temp_files() as (mp_path, ds_path): + run_provisioning(config, api_url, model_pack_tmp_file=mp_path, dataset_tmp_file=ds_path) + + projects = get_project_list(api_url) + self.assertIn(project_name, [p["name"] for p in projects], f"Project list: {[p['name'] for p in projects]}") + created = next(p for p in projects if p["name"] == project_name) + self.assertTrue(created.get("use_model_service"), "Project should have use_model_service=True") + self.assertEqual(created.get("model_service_url"), model_service_url) diff --git a/medcat-trainer/webapp/scripts/load_examples.py b/medcat-trainer/webapp/scripts/load_examples.py index df2b04657..10607ce7b 100644 --- 
a/medcat-trainer/webapp/scripts/load_examples.py +++ b/medcat-trainer/webapp/scripts/load_examples.py @@ -1,23 +1,35 @@ import os import sys import logging +from pathlib import Path import pandas as pd import requests from time import sleep import json +# Ensure the parent directory of the `scripts` package is on sys.path so that +# `from scripts....` imports work both when running as a module +# (python -m scripts.load_examples) and when executing the file directly +# (python /path/to/scripts/load_examples.py). +# SCRIPTS_PARENT = Path(__file__).resolve().parent.parent +# if str(SCRIPTS_PARENT) not in sys.path: +# sys.path.insert(0, str(SCRIPTS_PARENT)) + +from scripts.provisioning import load_example_projects_config, ProvisioningConfig +from scripts.provisioning.model import ProvisioningProjectSpec + # Set up logging with prefix including process ID pid = os.getpid() -logging.basicConfig( - level=logging.INFO, - format=f'[load_examples.py pid:{pid}] %(message)s' -) +logging.basicConfig(level=logging.INFO, format=f"[load_examples.py pid:{pid}] %(message)s") logger = logging.getLogger(__name__) +# Default path to provisioning YAML (when PROVISIONING_CONFIG_PATH is not set). 
+_DEFAULT_PROVISIONING_PATH = Path(__file__).resolve().parent / "provisioning" / "example_projects.provisioning.yaml" + def get_keycloak_access_token(): - logger.info('Getting Keycloak access token...') + logger.info("Getting Keycloak access token...") keycloak_url = os.environ.get("KEYCLOAK_URL", "http://keycloak.cogstack.localhost") realm = os.environ.get("KEYCLOAK_REALM", "cogstack-realm") client_id = os.environ.get("KEYCLOAK_CLIENT_ID", "cogstack-medcattrainer-frontend") @@ -31,7 +43,7 @@ def get_keycloak_access_token(): "client_id": client_id, "username": username, "password": password, - "scope": "openid profile email" + "scope": "openid profile email", } resp = requests.post(token_url, data=data) @@ -39,136 +51,189 @@ def get_keycloak_access_token(): return resp.json()["access_token"] -def main(port=8001, - model_pack_tmp_file='/home/model_pack.zip', - dataset_tmp_file='/home/ds.csv', - initial_wait=15): - - logger.info('Checking for environment variable LOAD_EXAMPLES...') - val = os.environ.get('LOAD_EXAMPLES') - if val is not None and val not in ('1', 'true', 't', 'y'): - logger.info('Found Env Var LOAD_EXAMPLES is False, not loading example data, cdb, vocab and project') - return - - logger.info('Found Env Var LOAD_EXAMPLES, waiting 15 seconds for API to be ready...') - URL = os.environ.get('API_URL', f'http://localhost:{port}/api/') - sleep(initial_wait) - - logger.info('Checking for default projects / datasets / CDBs / Vocabs') +def wait_for_api_ready(api_url: str, max_wait_seconds: int = 300, interval: int = 5) -> None: + """Poll api_url/health/ready/?format=json until 200 or max_wait_seconds. 
Exits with 1 on timeout.""" + health_ready_url = f"{api_url}health/ready/?format=json" + waited = 0 + while waited < max_wait_seconds: + try: + if requests.get(health_ready_url).status_code == 200: + logger.info("API health/ready returned 200") + return + except (ConnectionRefusedError, requests.exceptions.ConnectionError): + pass + logger.info( + f"API {health_ready_url} not ready yet, retrying in {interval}s ({waited + interval}/{max_wait_seconds})") + sleep(interval) + waited += interval + logger.error(f"FATAL - API {health_ready_url} did not return 200 within {max_wait_seconds}s. Exiting.") + sys.exit(1) + + +def get_headers(url: str) -> dict: + """ + Return auth headers for the API: Bearer token (OIDC) if USE_OIDC is "1", + otherwise Token from DRF api-token-auth. Raises RuntimeError if DRF auth fails. + """ + use_oidc = os.environ.get("USE_OIDC") + logger.debug("Checking for environment variable USE_OIDC...") + if use_oidc is not None and use_oidc == "1": + logger.info("Found environment variable USE_OIDC is set to truthy value. Will load data using JWT") + token = get_keycloak_access_token() + return {"Authorization": f"Bearer {token}"} + logger.info("Getting DRF auth token ...") + payload = {"username": "admin", "password": "admin"} + resp = requests.post(f"{url}api-token-auth/", json=payload) + if resp.status_code != 200: + raise RuntimeError(f"Failed to get DRF auth token: {resp.status_code} {resp.text}") + return {"Authorization": f"Token {json.loads(resp.text)['token']}"} + + +def run_provisioning( + provisioning_config: ProvisioningConfig, + api_url: str, + model_pack_tmp_file: str = "/home/model_pack.zip", + dataset_tmp_file: str = "/home/ds.csv", +) -> None: + """ + Wait for the API, then create projects from provisioning_config. + Exits with code 1 on max retries or API not ready. Unit tests can call this + with a ProvisioningConfig instance instead of reading from file. 
+ """ + wait_for_api_ready(api_url) + + logger.info("Checking for default projects / datasets / CDBs / Vocabs") max_retries = 60 # 60 retries = 5 minutes retry_count = 0 while retry_count < max_retries: try: - # check API is available - if requests.get(URL).status_code == 200: - - use_oidc = os.environ.get('USE_OIDC') - logger.info('Checking for environment variable USE_OIDC...') - if use_oidc is not None and use_oidc in '1': - logger.info('Found environment variable USE_OIDC is set to truthy value. Will load data using JWT') - token = get_keycloak_access_token() - headers = { - 'Authorization': f'Bearer {token}', - } - else: - # check API default username and pass are available. - logger.info('Getting DRF auth token ...') - payload = {"username": "admin", "password": "admin"} - resp = requests.post(f"{URL}api-token-auth/", json=payload) - if resp.status_code != 200: - break - - headers = { - 'Authorization': f'Token {json.loads(resp.text)["token"]}', - } - - # check concepts DB, vocab, datasets and projects are empty - resp_model_packs = requests.get(f'{URL}modelpacks/', headers=headers) - resp_ds = requests.get(f'{URL}datasets/', headers=headers) - resp_projs = requests.get(f'{URL}project-annotate-entities/', headers=headers) - all_resps = [resp_model_packs, resp_ds, resp_projs] - - codes = [r.status_code == 200 for r in all_resps] - - if all(codes) and all(len(r.text) > 0 and json.loads(r.text)['count'] == 0 for r in all_resps): - logger.info("Found No Objects. 
Populating Example: Model Pack, Dataset and Project...") - # download example model pack and dataset - logger.info("Downloading example model pack...") - model_pack_file = requests.get( - 'https://trainer-example-data.s3.eu-north-1.amazonaws.com/medcat2_model_pack_0f66077250cc2957.zip') - with open(model_pack_tmp_file, 'wb') as f: + headers = get_headers(api_url) + + resp_model_packs = requests.get(f"{api_url}modelpacks/", headers=headers) + resp_ds = requests.get(f"{api_url}datasets/", headers=headers) + resp_projs = requests.get(f"{api_url}project-annotate-entities/", headers=headers) + all_resps = [resp_model_packs, resp_ds, resp_projs] + codes = [r.status_code == 200 for r in all_resps] + + if not (all(codes) and all(len(r.text) > 0 and json.loads(r.text)["count"] == 0 for r in all_resps)): + logger.info( + "Found at least one object amongst model packs, datasets & projects. Skipping example creation" + ) + break + + logger.info("Found No Objects. Populating Example: Model Pack, Dataset and Project...") + for spec in provisioning_config.projects: + if not spec.project.use_model_service: + logger.info(f"Downloading example model pack from {spec.model_pack.url}") + model_pack_file = requests.get(spec.model_pack.url) + with open(model_pack_tmp_file, "wb") as f: f.write(model_pack_file.content) - logger.info("Downloading example dataset") - ds = requests.get('https://trainer-example-data.s3.eu-north-1.amazonaws.com/dr_notes.csv') - with open(dataset_tmp_file, 'w') as f: - f.write(ds.text) + logger.info(f"Downloading example dataset from {spec.dataset.url}") + ds = requests.get(spec.dataset.url) + with open(dataset_tmp_file, "w") as f: + f.write(ds.text) - ds_dict = pd.read_csv(dataset_tmp_file).loc[:, ['name', 'text']].to_dict() - create_example_project(URL, headers, model_pack_tmp_file, 'M-IV_NeuroNotes', ds_dict, - 'Example Project - Model Pack (Diseases / Symptoms / Findings)') + ds_dict = pd.read_csv(dataset_tmp_file).loc[:, ["name", "text"]].to_dict() + 
create_example_project(api_url, headers, spec, model_pack_tmp_file, ds_dict) - # clean up temp files + if not spec.project.use_model_service: os.remove(model_pack_tmp_file) - os.remove(dataset_tmp_file) - break - else: - logger.info('Found at least one object amongst model packs, datasets & projects. Skipping example creation') - break + os.remove(dataset_tmp_file) + break + except ConnectionRefusedError: retry_count += 1 if retry_count < max_retries: logger.info( - f'Loading examples - Connection refused to {URL}. Retrying in 5 seconds... (attempt {retry_count}/{max_retries})') + f"Loading examples - Connection refused to {api_url}. Retrying in 5 seconds... (attempt {retry_count}/{max_retries})" + ) sleep(5) continue except requests.exceptions.ConnectionError: retry_count += 1 if retry_count < max_retries: logger.info( - f'Loading examples - Connection error to {URL}. Retrying in 5 seconds... (attempt {retry_count}/{max_retries})') + f"Loading examples - Connection error to {api_url}. Retrying in 5 seconds... (attempt {retry_count}/{max_retries})" + ) sleep(5) continue - # If we exited the loop due to max retries, exit with error code if retry_count >= max_retries: - logger.error(f'FATAL - Error loading examples. Max retries ({max_retries}) reached. Exiting with code 1.') + logger.error(f"FATAL - Error loading examples. Max retries ({max_retries}) reached. Exiting with code 1.") sys.exit(1) - logger.info('Successfully loaded examples') - + logger.info("Successfully loaded examples") + + +def create_example_project(url, headers, spec: ProvisioningProjectSpec, model_pack_tmp_file, ds_dict): + """Create dataset and project. 
Branch only on spec.project.use_model_service.""" + if not spec.project.use_model_service: + logger.info("Creating Model Pack / Dataset / Project in the Trainer") + res_model_pack_mk = requests.post( + f"{url}modelpacks/", + headers=headers, + data={"name": spec.model_pack.name}, + files={"model_pack": open(model_pack_tmp_file, "rb")}, + ) + model_pack_id = json.loads(res_model_pack_mk.text)["id"] + else: + logger.info("Creating Dataset / Project (remote model service) in the Trainer") + model_pack_id = None -def create_example_project(url, headers, model_pack, ds_name, ds_dict, project_name): - logger.info('Creating Model Pack / Dataset / Project in the Trainer') - res_model_pack_mk = requests.post(f'{url}modelpacks/', headers=headers, - data={'name': 'Example Model Pack'}, - files={'model_pack': open(model_pack, 'rb')}) - model_pack_id = json.loads(res_model_pack_mk.text)['id'] - - # Upload the dataset payload = { - 'dataset_name': ds_name, - 'dataset': ds_dict, - 'description': 'Clinical texts from MIMIC-IV' + "dataset_name": spec.dataset.name, + "dataset": ds_dict, + "description": spec.dataset.description, } - resp = requests.post(f'{url}create-dataset/', json=payload, headers=headers) - ds_id = json.loads(resp.text)['dataset_id'] + resp = requests.post(f"{url}create-dataset/", json=payload, headers=headers) + ds_id = json.loads(resp.text)["dataset_id"] - user_id = json.loads(requests.get(f'{url}users/', headers=headers).text)['results'][0]['id'] + user_id = json.loads(requests.get(f"{url}users/", headers=headers).text)["results"][0]["id"] - # Create the project payload = { - 'name': project_name, - 'description': 'Example projects using example psychiatric clinical notes from ' - 'https://www.mtsamples.com/', - 'cuis': '', - 'annotation_guideline_link': 'https://docs.google.com/document/d/1xxelBOYbyVzJ7vLlztP2q1Kw9F5Vr1pRwblgrXPS7QM/edit?usp=sharing', - 'dataset': ds_id, - 'model_pack': model_pack_id, - 'members': [user_id] + "name": spec.project.name, + 
"description": spec.project.description, + "cuis": "", + "annotation_guideline_link": spec.project.annotation_guideline_link, + "dataset": ds_id, + "members": [user_id], } - requests.post(f'{url}project-annotate-entities/', json=payload, headers=headers) - logger.info('Successfully created the example project') + if not spec.project.use_model_service: + payload["model_pack"] = model_pack_id + else: + payload["use_model_service"] = True + payload["model_service_url"] = spec.project.model_service_url + + requests.post(f"{url}project-annotate-entities/", json=payload, headers=headers) + logger.info("Successfully created the example project") + + +def main( + port: int = 8001, + model_pack_tmp_file: str = "/home/model_pack.zip", + dataset_tmp_file: str = "/home/ds.csv", +) -> None: + """Entrypoint: check LOAD_EXAMPLES, load config from file, then run_provisioning.""" + logger.info("Checking for environment variable LOAD_EXAMPLES...") + val = os.environ.get("LOAD_EXAMPLES") + if val is not None and val not in ("1", "true", "t", "y"): + logger.info("Found Env Var LOAD_EXAMPLES is False, not loading example data, cdb, vocab and project") + return + + config_path = Path(os.environ.get("PROVISIONING_CONFIG_PATH") or _DEFAULT_PROVISIONING_PATH) + if not config_path.is_file(): + logger.error( + f"FATAL - Provisioning config not found: {config_path}. Set PROVISIONING_CONFIG_PATH or add the YAML file." 
+ ) + sys.exit(1) + provisioning_config = load_example_projects_config(config_path) + logger.info(f"Loaded provisioning config from {config_path} ({len(provisioning_config.projects)} project(s))") + + api_url = os.environ.get("API_URL") or f"http://localhost:{port}/api/" + logger.info("Found Env Var LOAD_EXAMPLES, waiting for API to be ready...") + + run_provisioning(provisioning_config, api_url, model_pack_tmp_file, dataset_tmp_file) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/medcat-trainer/webapp/scripts/provisioning/__init__.py b/medcat-trainer/webapp/scripts/provisioning/__init__.py new file mode 100644 index 000000000..635caf02b --- /dev/null +++ b/medcat-trainer/webapp/scripts/provisioning/__init__.py @@ -0,0 +1,16 @@ +"""Provisioning: YAML config for example projects to load on startup.""" + +from pathlib import Path + +import yaml +from .model import ProvisioningConfig + + +def load_example_projects_config(path: str | Path) -> ProvisioningConfig: + """Load and validate the example-projects YAML config from a file path.""" + + with open(path) as f: + data = yaml.safe_load(f) + if data is None: + raise ValueError(f"Empty or invalid YAML: {path}") + return ProvisioningConfig.model_validate(data) diff --git a/medcat-trainer/webapp/scripts/provisioning/example_projects.provisioning.yaml b/medcat-trainer/webapp/scripts/provisioning/example_projects.provisioning.yaml new file mode 100644 index 000000000..693f876ef --- /dev/null +++ b/medcat-trainer/webapp/scripts/provisioning/example_projects.provisioning.yaml @@ -0,0 +1,28 @@ +# Example config for projects to load on startup. 
+ +projects: + # Example using an uploaded model pack (useModelService defaults to false on project) + - modelPack: + name: "Example Model Pack" + url: "https://trainer-example-data.s3.eu-north-1.amazonaws.com/medcat2_model_pack_0f66077250cc2957.zip" + dataset: + name: "M-IV_NeuroNotes" + url: "https://trainer-example-data.s3.eu-north-1.amazonaws.com/dr_notes.csv" + description: "Clinical texts from MIMIC-IV" + project: + name: "Example Project - Model Pack (Diseases / Symptoms / Findings)" + description: "Example projects using example psychiatric clinical notes from https://www.mtsamples.com/" + annotationGuidelineLink: "https://docs.google.com/document/d/1xxelBOYbyVzJ7vLlztP2q1Kw9F5Vr1pRwblgrXPS7QM/edit?usp=sharing" + + # Example using remote MedCAT service API (no model pack upload). + # Uncomment and set modelServiceUrl under project to your MedCAT service (e.g. http://medcat-service:8000). + # - dataset: + # name: "M-IV_NeuroNotes-Remote" + # url: "https://trainer-example-data.s3.eu-north-1.amazonaws.com/dr_notes.csv" + # description: "Clinical texts from MIMIC-IV (remote model)" + # project: + # name: "Example Project - Remote Model Service" + # description: "Uses remote MedCAT service for document processing (no local model pack)." + # annotationGuidelineLink: "https://docs.google.com/document/d/1xxelBOYbyVzJ7vLlztP2q1Kw9F5Vr1pRwblgrXPS7QM/edit?usp=sharing" + # useModelService: true + # modelServiceUrl: "http://medcat-service:8000" diff --git a/medcat-trainer/webapp/scripts/provisioning/model.py b/medcat-trainer/webapp/scripts/provisioning/model.py new file mode 100644 index 000000000..bad869bf7 --- /dev/null +++ b/medcat-trainer/webapp/scripts/provisioning/model.py @@ -0,0 +1,90 @@ +""" +Pydantic models for the provisioning YAML config. + +The user provides a YAML file that describes which projects to load on startup. 
+Nested structure (modelPack, dataset, project) matches where fields are used +when POSTing to the API (modelpacks/, create-dataset/, project-annotate-entities/). + +Python fields are snake_case; YAML keys are camelCase via alias_generator. +""" + +from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic.alias_generators import to_camel + +_common_config = ConfigDict( + alias_generator=to_camel, + validate_by_name=True, + validate_by_alias=True, +) + + +class ModelPackSpec(BaseModel): + """Model pack to upload to modelpacks/.""" + + name: str = Field(description="Display name for the model pack") + url: str = Field(description="URL of the model pack .zip to download") + + +class DatasetSpec(BaseModel): + """Dataset to upload via create-dataset/ (name, description, data from url).""" + + name: str = Field(description="Display name for the dataset") + url: str = Field(description="URL of the dataset CSV to download") + description: str = Field(description="Dataset description") + + +class ProjectSpec(BaseModel): + """Project to create via project-annotate-entities/.""" + + model_config = _common_config + + name: str = Field(description="Name of the created project") + description: str = Field(description="Project description") + annotation_guideline_link: str = Field(description="URL to annotation guidelines") + use_model_service: bool = Field( + default=False, + description="Use remote MedCAT service API for document processing instead of local models.", + ) + model_service_url: str | None = Field( + default=None, + description="URL of the remote MedCAT service API (e.g. http://medcat-service:8000). Required when use_model_service is True.", + ) + + +class ProvisioningProjectSpec(BaseModel): + """ + Spec for one example project to be loaded on startup. + Either provide model_pack (project uses uploaded model), or set project.use_model_service=True + and project.model_service_url (remote MedCAT service API for document processing). 
+ """ + + model_config = _common_config + + model_pack: ModelPackSpec | None = Field( + default=None, + description="Model pack to upload (name + url). Required when project.use_model_service is False.", + ) + dataset: DatasetSpec = Field() + project: ProjectSpec = Field() + + @model_validator(mode="after") + def exactly_one_model_source(self): + if self.project.use_model_service: + if not self.project.model_service_url or not self.project.model_service_url.strip(): + raise ValueError("model_service_url is required when use_model_service is True") + if self.model_pack is not None: + raise ValueError("Do not set model_pack when use_model_service is True") + else: + if self.model_pack is None: + raise ValueError("model_pack is required when use_model_service is False") + return self + + +class ProvisioningConfig(BaseModel): + """Root config: list of example projects to load on startup.""" + + model_config = _common_config + + projects: list[ProvisioningProjectSpec] = Field( + description="List of example project specs to load", + ) diff --git a/medcat-trainer/webapp/scripts/run.sh b/medcat-trainer/webapp/scripts/run.sh index bb156cddd..9e343148c 100755 --- a/medcat-trainer/webapp/scripts/run.sh +++ b/medcat-trainer/webapp/scripts/run.sh @@ -35,7 +35,8 @@ if not User.objects.filter(username=admin_username).exists(): if [ $LOAD_EXAMPLES ]; then echo "Loading examples..." - uv run python /home/scripts/load_examples.py >> /dev/stdout 2>> /dev/stderr & + cd /home + uv run python -m scripts.load_examples >> /dev/stdout 2>> /dev/stderr & fi # Creating a default user group that can manage projects and annotate but not delete