From f5d8655ddfa2d62455dd3f8eddbfb8b87416c4cf Mon Sep 17 00:00:00 2001 From: Avinash Kumar Deepak Date: Sun, 1 Mar 2026 16:37:49 +0530 Subject: [PATCH] fix: atomic writes, secrets ID gen, MAX_CONTENT_LENGTH, expand tests (#354) --- server/controller/workflow.py | 4 +-- server/model/workflows.py | 36 ++++++++++----------- server/server.py | 1 + server/tests/test_workflow_controller.py | 40 ++++++++++++++++++++++++ 4 files changed, 61 insertions(+), 20 deletions(-) diff --git a/server/controller/workflow.py b/server/controller/workflow.py index 2e03e99..ed79375 100644 --- a/server/controller/workflow.py +++ b/server/controller/workflow.py @@ -30,7 +30,7 @@ def getAllActionHash(root): def postWorkflow(): try: lastestHash = getLasteshActionHash(ET.fromstring(request.data)) - except: + except Exception: return "Invalid GraphML", 400 graphML = request.data.decode('utf') return workFlowModel.insert(graphML, lastestHash) @@ -60,7 +60,7 @@ def updateWorkflow(serverID): latestHash = getLasteshActionHash(root) if(not forceUpdate): allHash = getAllActionHash(root) - except: + except Exception: return "Invalid GraphML", 400 graphML = request.data.decode('utf') if(forceUpdate): diff --git a/server/model/workflows.py b/server/model/workflows.py index bfff7f0..0fbba2f 100644 --- a/server/model/workflows.py +++ b/server/model/workflows.py @@ -1,11 +1,11 @@ from pymongo import MongoClient -from pymongo import MongoClient +from pymongo.errors import DuplicateKeyError import time from bson.objectid import ObjectId from bson.errors import InvalidId import os import xml.etree.ElementTree as ET -import random +import secrets import string from dotenv import load_dotenv load_dotenv() @@ -14,20 +14,21 @@ class WorkFlowModel: def __init__(self) -> None: self.collection = MongoClient(os.getenv('MongoURL'))[ os.getenv('dbName')][os.getenv('tableName')] + self.collection.create_index('serverID', unique=True) def get_random_string(self, length): letters = string.ascii_letters+string.digits - return ''.join(random.choice(letters) for i in range(length)) + return ''.join(secrets.choice(letters) for i in range(length)) def insert(self, graphml, latestHash): - serverID = "" while(True): serverID = self.get_random_string(6) - if(not self.collection.find_one({'serverID': serverID})): - break - self.collection.insert_one( - {'graphml': graphml, 'latestHash': latestHash, 'serverID': serverID}) - return serverID + try: + self.collection.insert_one( + {'graphml': graphml, 'latestHash': latestHash, 'serverID': serverID}) + return serverID + except DuplicateKeyError: + continue def get(self, serverID): cl = self.collection.find_one({'serverID': serverID}) @@ -36,20 +37,19 @@ def get(self, serverID): return cl['graphml'] def update(self, serverID, graphml, latestHash, allHash): - existingRecord = self.collection.find_one({'serverID': serverID}) + existingRecord = self.collection.find_one_and_update( + {'serverID': serverID, 'latestHash': {'$in': allHash}}, + {"$set": {'graphml': graphml, 'latestHash': latestHash}}) if existingRecord is None: - return False, 'serverID do not exists.' - latestExistingHash = existingRecord['latestHash'] - if latestExistingHash not in allHash: + if not self.collection.find_one({'serverID': serverID}): + return False, 'serverID do not exists.' return False, 'Can not update as provided graph do not has latest changes.' - self.collection.update_one({'serverID': serverID}, { - "$set": {'graphml': graphml, 'latestHash': latestHash}}) return True, latestHash def forceUpdate(self, serverID, graphml, latestHash): - existingRecord = self.collection.find_one({'serverID': serverID}) + existingRecord = self.collection.find_one_and_update( + {'serverID': serverID}, + {"$set": {'graphml': graphml, 'latestHash': latestHash}}) if existingRecord is None: return False, 'serverID do not exists.' - self.collection.update_one({'serverID': serverID}, - {"$set": {'graphml': graphml, 'latestHash': latestHash}}) return True, latestHash diff --git a/server/server.py b/server/server.py index 721874f..4aa2b3e 100644 --- a/server/server.py +++ b/server/server.py @@ -6,6 +6,7 @@ load_dotenv() app = Flask(__name__) +app.config['MAX_CONTENT_LENGTH'] = 2 * 1024 * 1024 CORS(app) app.register_blueprint(workFlow, url_prefix='/workflow') diff --git a/server/tests/test_workflow_controller.py b/server/tests/test_workflow_controller.py index 4966d60..38c5ab0 100644 --- a/server/tests/test_workflow_controller.py +++ b/server/tests/test_workflow_controller.py @@ -22,6 +22,15 @@ def __init__(self, graph_response): def get(self, _server_id): return self.graph_response + def insert(self, graphml, latestHash): + return 'test01' + + def update(self, serverID, graphml, latestHash, allHash): + return (True, latestHash) + + def forceUpdate(self, serverID, graphml, latestHash): + return (True, latestHash) + class WorkflowControllerTests(unittest.TestCase): @classmethod @@ -77,6 +86,37 @@ def test_hash_header_returns_200_for_matching_history(self): self.assertEqual(response.status_code, 200) self.assertEqual(response.get_data(as_text=True), VALID_GRAPHML) + def test_post_workflow_returns_server_id(self): + client = self.make_client(None) + response = client.post('/workflow/', data=VALID_GRAPHML, + content_type='application/xml') + self.assertEqual(response.status_code, 200) + self.assertEqual(response.get_data(as_text=True), 'test01') + + def test_post_workflow_invalid_xml_returns_400(self): + client = self.make_client(None) + response = client.post('/workflow/', data=b'not xml', + content_type='application/xml') + self.assertEqual(response.status_code, 400) + + def test_update_workflow_returns_200(self): + client = self.make_client(None) + response = client.post('/workflow/test01', data=VALID_GRAPHML, + content_type='application/xml') + self.assertEqual(response.status_code, 200) + + def test_update_workflow_invalid_xml_returns_400(self): + client = self.make_client(None) + response = client.post('/workflow/test01', data=b'not xml', + content_type='application/xml') + self.assertEqual(response.status_code, 400) + + def test_force_update_workflow_returns_200(self): + client = self.make_client(None) + response = client.post('/workflow/test01?force=true', data=VALID_GRAPHML, + content_type='application/xml') + self.assertEqual(response.status_code, 200) + if __name__ == '__main__': unittest.main()