Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions sqlparse/engine/statement_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,12 @@ def _change_splitlevel(self, ttype, value):
self._in_case = False
return -1

if (unified in ('IF', 'FOR', 'WHILE', 'CASE')
if unified == 'CASE':
self._in_case = True
return 1

if (unified in ('IF', 'FOR', 'WHILE')
and self._is_create and self._begin_depth > 0):
if unified == 'CASE':
self._in_case = True
return 1

if unified in ('END IF', 'END FOR', 'END WHILE'):
Expand Down
31 changes: 31 additions & 0 deletions tests/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,3 +284,34 @@ def test_split_begin_transaction_formatted(): # issue826
assert stmts[1].startswith('DELETE')
assert stmts[2].startswith('INSERT')
assert stmts[3] == 'END\nTRANSACTION;'


def test_splitlevel_case_end():
# CASE in a plain SELECT did not increment the level, but its matching END
# decremented it unconditionally. This led to levels being wrong after the
# CASE WHEN ... END block.
s = sqlparse.engine.statement_splitter.StatementSplitter()
level = s.level

token_stream = [
(sqlparse.tokens.Keyword.DML, 'SELECT'),
(sqlparse.tokens.Keyword, 'CASE'),
(sqlparse.tokens.Keyword, 'WHEN'),
(sqlparse.tokens.Name, 'foo'),
(sqlparse.tokens.Keyword, 'THEN'),
(sqlparse.tokens.Number, '1'),
(sqlparse.tokens.Keyword, 'END'),
(sqlparse.tokens.Keyword, 'FROM'),
(sqlparse.tokens.Name, 't'),
]

for ttype, value in token_stream:
level += s._change_splitlevel(ttype, value)

assert level == 0

# This issue could lead to incorrectly treating a semicolon inside a text
# literal as a statement terminator and incorrectly splitting the query.
assert len(sqlparse.parse(
"SELECT CASE WHEN 1 THEN 2 END, test IN ('foo \\', 'foo;') FROM t"
)) == 1