Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@
get_fleet_placement_group_models,
get_placement_group_model_for_job,
placement_group_model_to_placement_group_optional,
schedule_fleet_placement_groups_deletion,
)
from dstack._internal.server.services.runs import (
run_model_to_run,
Expand Down Expand Up @@ -481,17 +480,15 @@ async def _process_submitted_job(
logger.info("%s: provisioned %s new instance(s)", fmt(job_model), len(provisioned_jobs))
provisioned_job_models = _get_job_models_for_jobs(run_model.jobs, provisioned_jobs)
instance = None # Instance for attaching volumes in case of single job provisioned
# FIXME: Fleet is not locked which may lead to duplicate instance_num.
# This is currently hard to fix without locking the fleet for entire provisioning duration.
# Processing should be done in multiple steps so that
# InstanceModel is created before provisioning.
taken_instance_nums = await _get_taken_instance_nums(session, fleet_model)
for provisioned_job_model, jpd in zip(provisioned_job_models, jpds):
provisioned_job_model.job_provisioning_data = jpd.json()
switch_job_status(session, provisioned_job_model, JobStatus.PROVISIONING)
# FIXME: Fleet is not locked which may lead to duplicate instance_num.
# This is currently hard to fix without locking the fleet for entire provisioning duration.
# Processing should be done in multiple steps so that
# InstanceModel is created before provisioning.
instance_num = await _get_next_instance_num(
session=session,
fleet_model=fleet_model,
)
instance_num = get_next_instance_num(taken_instance_nums)
instance = _create_instance_model_for_job(
project=project,
fleet_model=fleet_model,
Expand All @@ -502,6 +499,7 @@ async def _process_submitted_job(
instance_num=instance_num,
profile=effective_profile,
)
taken_instance_nums.add(instance_num)
provisioned_job_model.job_runtime_data = _prepare_job_runtime_data(
offer, multinode
).json()
Expand Down Expand Up @@ -847,15 +845,9 @@ async def _run_jobs_on_new_instances(
finally:
if fleet_model is not None and len(fleet_model.instances) == 0:
# Clean up placement groups that did not end up being used.
# Flush to update still uncommitted placement groups.
await session.flush()
await schedule_fleet_placement_groups_deletion(
session=session,
fleet_id=fleet_model.id,
except_placement_group_ids=(
[placement_group_model.id] if placement_group_model is not None else []
),
)
for pg in placement_group_models:
if placement_group_model is None or pg.id != placement_group_model.id:
pg.fleet_deleted = True
return None


Expand Down Expand Up @@ -906,15 +898,14 @@ async def _create_fleet_model_for_job(
return fleet_model


async def _get_next_instance_num(session: AsyncSession, fleet_model: FleetModel) -> int:
async def _get_taken_instance_nums(session: AsyncSession, fleet_model: FleetModel) -> set[int]:
res = await session.execute(
select(InstanceModel.instance_num).where(
InstanceModel.fleet_id == fleet_model.id,
InstanceModel.deleted.is_(False),
)
)
taken_instance_nums = set(res.scalars().all())
return get_next_instance_num(taken_instance_nums)
return set(res.scalars().all())


def _create_instance_model_for_job(
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def __init__(self, url: str, engine: Optional[AsyncEngine] = None):
self.session_maker = async_sessionmaker(
bind=self.engine, # type: ignore[assignment]
expire_on_commit=False,
# Disable autoflush to avoid accidental long write transactions on SQLite.
autoflush=False,
class_=AsyncSession,
)

Expand Down
4 changes: 2 additions & 2 deletions src/tests/_internal/server/services/test_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def test_includes_termination_reason_in_event_messages_only_once(
instance.termination_reason_message = "Some err"
instances_services.switch_instance_status(session, instance, InstanceStatus.TERMINATING)
instances_services.switch_instance_status(session, instance, InstanceStatus.TERMINATED)

await session.commit()
events = await list_events(session)
assert len(events) == 2
assert {e.message for e in events} == {
Expand All @@ -61,7 +61,7 @@ async def test_includes_termination_reason_in_event_message_when_switching_direc
instance.termination_reason = InstanceTerminationReason.ERROR
instance.termination_reason_message = "Some err"
instances_services.switch_instance_status(session, instance, InstanceStatus.TERMINATED)

await session.commit()
events = await list_events(session)
assert len(events) == 1
assert events[0].message == (
Expand Down