From d830f4bcd21deb078a89d59e4d98a6406ce5661d Mon Sep 17 00:00:00 2001 From: Oleg Broytman Date: Fri, 17 Mar 2017 19:19:41 +0300 Subject: [PATCH] Split extended INSERTs --- sqlconvert/process_mysql.py | 62 +++++++++++++++++++++++++++++++++++-- tests/test_process_mysql.py | 31 +++++++++++++++++++ 2 files changed, 90 insertions(+), 3 deletions(-) diff --git a/sqlconvert/process_mysql.py b/sqlconvert/process_mysql.py index 9e342c2..c375f79 100644 --- a/sqlconvert/process_mysql.py +++ b/sqlconvert/process_mysql.py @@ -1,5 +1,5 @@ -from sqlparse.sql import Comment +from sqlparse.sql import Comment, Function, Identifier, Parenthesis, Statement from sqlparse import tokens as T from .process_tokens import escape_strings, is_comment_or_space @@ -77,10 +77,66 @@ def is_insert(statement): return (token.ttype is T.DML) and (token.normalized == 'INSERT') +def split_ext_insert(statement): + """Split extended INSERT into multiple standard INSERTs""" + insert_tokens = [] + values_tokens = [] + last_token = None + expected = 'INSERT' + for token in statement.tokens: + if is_comment_or_space(token): + insert_tokens.append(token) + continue + elif expected == 'INSERT': + if (token.ttype is T.DML) and (token.normalized == 'INSERT'): + insert_tokens.append(token) + expected = 'INTO' + continue + elif expected == 'INTO': + if (token.ttype is T.Keyword) and (token.normalized == 'INTO'): + insert_tokens.append(token) + expected = 'TABLE_NAME' + continue + elif expected == 'TABLE_NAME': + if isinstance(token, (Function, Identifier)): + insert_tokens.append(token) + expected = 'VALUES' + continue + elif expected == 'VALUES': + if (token.ttype is T.Keyword) and (token.normalized == 'VALUES'): + insert_tokens.append(token) + expected = 'VALUES_OR_SEMICOLON' + continue + elif expected == 'VALUES_OR_SEMICOLON': + if isinstance(token, Parenthesis): + values_tokens.append(token) + continue + elif token.ttype is T.Punctuation: + if token.value == ',': + continue + elif token.value == ';': + last_token = token + break + raise ValueError( + 'SQL syntax error: expected "%s", got %s "%s"' % ( + expected, token.ttype, token.normalized)) + for values in values_tokens: + # The statemnt sets `parent` attribute of the every token to self + # but we don't care. + vl = [values] + if last_token: + vl.append(last_token) + statement = Statement(insert_tokens + vl) + yield statement + + def process_statement(statement, quoting_style='sqlite'): requote_names(statement) unescape_strings(statement) remove_directive_tokens(statement) escape_strings(statement, quoting_style) - yield statement - return + if is_insert(statement): + for statement in split_ext_insert(statement): + yield statement + else: + yield statement diff --git a/tests/test_process_mysql.py b/tests/test_process_mysql.py index 45e5020..c8b3b8c 100644 --- a/tests/test_process_mysql.py +++ b/tests/test_process_mysql.py @@ -57,11 +57,42 @@ def test_is_insert(): parsed = parse("select /*! test */ * from /* test */ `T`")[0] statement = next(process_statement(parsed)) assert not is_insert(statement) + parsed = parse("insert into test values ('\"te\\'st\\\"\\n')")[0] statement = next(process_statement(parsed)) assert is_insert(statement) +def test_split_ext_insert(): + parsed = parse("insert into test values (1, 2)")[0] + statement = next(process_statement(parsed)) + query = tlist2str(statement) + assert query == u"INSERT INTO test VALUES (1, 2)" + + parsed = parse("insert into test (age, salary) values (1, 2);")[0] + statement = next(process_statement(parsed)) + query = tlist2str(statement) + assert query == u"INSERT INTO test (age, salary) VALUES (1, 2);" + + parsed = parse("insert into test values (1, 2), (3, 4);")[0] + stiter = process_statement(parsed) + statement = next(stiter) + query = tlist2str(statement) + assert query == u"INSERT INTO test VALUES (1, 2);" + statement = next(stiter) + query = tlist2str(statement) + assert query == u"INSERT INTO test VALUES (3, 4);" + + parsed = parse("insert into test (age, salary) values (1, 2), (3, 4)")[0] + stiter = process_statement(parsed) + statement = next(stiter) + query = tlist2str(statement) + assert query == u"INSERT INTO test (age, salary) VALUES (1, 2)" + statement = next(stiter) + query = tlist2str(statement) + assert query == u"INSERT INTO test (age, salary) VALUES (3, 4)" + + def test_process(): parsed = parse("select /*! test */ * from /* test */ `T`")[0] statement = next(process_statement(parsed)) -- 2.39.5