From 95103778dd6d6d279d9b3c9f83ff49ea5920c6e5 Mon Sep 17 00:00:00 2001 From: Oleg Broytman Date: Sun, 19 Mar 2017 19:17:54 +0300 Subject: [PATCH] Separate split INSERTs with newlines --- sqlconvert/process_mysql.py | 28 ++++++++++++++++++---------- tests/test_process_mysql.py | 4 ++-- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/sqlconvert/process_mysql.py b/sqlconvert/process_mysql.py index c375f79..5c2924a 100644 --- a/sqlconvert/process_mysql.py +++ b/sqlconvert/process_mysql.py @@ -1,5 +1,6 @@ -from sqlparse.sql import Comment, Function, Identifier, Parenthesis, Statement +from sqlparse.sql import Comment, Function, Identifier, Parenthesis, \ + Statement, Token from sqlparse import tokens as T from .process_tokens import escape_strings, is_comment_or_space @@ -81,11 +82,14 @@ def split_ext_insert(statement): """Split extended INSERT into multiple standard INSERTs""" insert_tokens = [] values_tokens = [] - last_token = None + end_tokens = [] expected = 'INSERT' for token in statement.tokens: if is_comment_or_space(token): - insert_tokens.append(token) + if expected == 'END': + end_tokens.append(token) + else: + insert_tokens.append(token) continue elif expected == 'INSERT': if (token.ttype is T.DML) and (token.normalized == 'INSERT'): @@ -115,18 +119,22 @@ def split_ext_insert(statement): if token.value == ',': continue elif token.value == ';': - last_token = token - break + end_tokens.append(token) + expected = 'END' + continue raise ValueError( 'SQL syntax error: expected "%s", got %s "%s"' % ( expected, token.ttype, token.normalized)) - for values in values_tokens: + new_line = Token(T.Newline, '\n') + new_lines = [new_line] # Insert newlines between split statements + for i, values in enumerate(values_tokens): + if i == len(values_tokens) - 1: # Last but one statement + # Insert newlines only between split statements but not after + new_lines = [] # The statemnt sets `parent` attribute of the every token to self # but we don't care. - vl = [values] - if last_token: - vl.append(last_token) - statement = Statement(insert_tokens + vl) + statement = Statement(insert_tokens + [values] + + end_tokens + new_lines) yield statement diff --git a/tests/test_process_mysql.py b/tests/test_process_mysql.py index c8b3b8c..df60a64 100644 --- a/tests/test_process_mysql.py +++ b/tests/test_process_mysql.py @@ -78,7 +78,7 @@ def test_split_ext_insert(): stiter = process_statement(parsed) statement = next(stiter) query = tlist2str(statement) - assert query == u"INSERT INTO test VALUES (1, 2);" + assert query == u"INSERT INTO test VALUES (1, 2);\n" statement = next(stiter) query = tlist2str(statement) assert query == u"INSERT INTO test VALUES (3, 4);" @@ -87,7 +87,7 @@ def test_split_ext_insert(): stiter = process_statement(parsed) statement = next(stiter) query = tlist2str(statement) - assert query == u"INSERT INTO test (age, salary) VALUES (1, 2)" + assert query == u"INSERT INTO test (age, salary) VALUES (1, 2)\n" statement = next(stiter) query = tlist2str(statement) assert query == u"INSERT INTO test (age, salary) VALUES (3, 4)" -- 2.39.2