git.phdru.name Git - sqlconvert.git/commitdiff
Split extended INSERTs
author Oleg Broytman <phd@phdru.name>
Fri, 17 Mar 2017 16:19:41 +0000 (19:19 +0300)
committer Oleg Broytman <phd@phdru.name>
Sat, 18 Mar 2017 20:56:34 +0000 (23:56 +0300)
sqlconvert/process_mysql.py
tests/test_process_mysql.py

index 9e342c2e04ec7c5ad4ba1796731932320762e5f3..c375f797a91d12b7a34e9fcde629ea988f2d9252 100644 (file)
@@ -1,5 +1,5 @@
 
-from sqlparse.sql import Comment
+from sqlparse.sql import Comment, Function, Identifier, Parenthesis, Statement
 from sqlparse import tokens as T
 from .process_tokens import escape_strings, is_comment_or_space
 
@@ -77,10 +77,66 @@ def is_insert(statement):
         return (token.ttype is T.DML) and (token.normalized == 'INSERT')
 
 
+def split_ext_insert(statement):
+    """Split extended INSERT into multiple standard INSERTs.
+
+    Walk the top-level tokens of ``statement`` expecting the sequence
+    ``INSERT INTO <table> VALUES (...)[, (...)]...[;]`` and yield one new
+    ``Statement`` per parenthesized group of values.  Every yielded
+    statement is built from the common prefix (the INSERT, INTO, table
+    name and VALUES tokens plus every token accepted by
+    ``is_comment_or_space`` seen before the semicolon), one group of
+    values and the terminating semicolon if the original had one.
+
+    This is a generator, so errors are reported when it is first advanced.
+
+    :param statement: parsed statement; only ``statement.tokens`` is read.
+    :raises ValueError: if an unexpected token is found, or if the
+        statement ends without a single group of values.
+    """
+    insert_tokens = []
+    values_tokens = []
+    last_token = None
+    expected = 'INSERT'
+    for token in statement.tokens:
+        if is_comment_or_space(token):
+            # Goes to the common prefix wherever it occurs.
+            insert_tokens.append(token)
+            continue
+        elif expected == 'INSERT':
+            if (token.ttype is T.DML) and (token.normalized == 'INSERT'):
+                insert_tokens.append(token)
+                expected = 'INTO'
+                continue
+        elif expected == 'INTO':
+            if (token.ttype is T.Keyword) and (token.normalized == 'INTO'):
+                insert_tokens.append(token)
+                expected = 'TABLE_NAME'
+                continue
+        elif expected == 'TABLE_NAME':
+            # A table name followed by a column list, like
+            # `test (age, salary)`, arrives as a Function group
+            # (see the tests).
+            if isinstance(token, (Function, Identifier)):
+                insert_tokens.append(token)
+                expected = 'VALUES'
+                continue
+        elif expected == 'VALUES':
+            if (token.ttype is T.Keyword) and (token.normalized == 'VALUES'):
+                insert_tokens.append(token)
+                expected = 'VALUES_OR_SEMICOLON'
+                continue
+        elif expected == 'VALUES_OR_SEMICOLON':
+            if isinstance(token, Parenthesis):
+                values_tokens.append(token)
+                continue
+            elif token.ttype is T.Punctuation:
+                if token.value == ',':
+                    # Separators between groups are dropped.
+                    continue
+                elif token.value == ';':
+                    # End of statement; anything after it is ignored.
+                    last_token = token
+                    break
+        # Every accepted token has hit `continue` or `break` above.
+        raise ValueError(
+            'SQL syntax error: expected "%s", got %s "%s"' % (
+                expected, token.ttype, token.normalized))
+    if not values_tokens:
+        # Yielding nothing would silently drop the statement.
+        raise ValueError(
+            'SQL syntax error: INSERT without values, stopped at "%s"'
+            % expected)
+    for values in values_tokens:
+        # The statement sets the `parent` attribute of every token to
+        # itself but we don't care.
+        tokens = insert_tokens + [values]
+        if last_token is not None:
+            tokens.append(last_token)
+        yield Statement(tokens)
+
+
 def process_statement(statement, quoting_style='sqlite'):
+    """Process one parsed statement and yield the resulting statement(s).
+
+    Calls requote_names(), unescape_strings(), remove_directive_tokens()
+    and escape_strings() on the statement, in that order; their return
+    values are ignored, so they presumably modify the statement in place.
+    If is_insert() then accepts the statement it is passed through
+    split_ext_insert() and every statement produced is yielded; any other
+    statement is yielded as is.
+
+    :param statement: parsed statement to process.
+    :param quoting_style: passed to escape_strings(); default 'sqlite'.
+    """
     requote_names(statement)
     unescape_strings(statement)
     remove_directive_tokens(statement)
     escape_strings(statement, quoting_style)
-    yield statement
-    return
+    if is_insert(statement):
+        for statement in split_ext_insert(statement):
+            yield statement
+    else:
+        yield statement
index 45e5020de20fdc91f8b80f1fad67cd1e57f02939..c8b3b8cd48f3557abb6c2b6832a13aa550b2a654 100644 (file)
@@ -57,11 +57,42 @@ def test_is_insert():
     parsed = parse("select /*! test */ * from /* test */ `T`")[0]
     statement = next(process_statement(parsed))
     assert not is_insert(statement)
+
     parsed = parse("insert into test values ('\"te\\'st\\\"\\n')")[0]
     statement = next(process_statement(parsed))
     assert is_insert(statement)
 
 
+def test_split_ext_insert():
+    """Extended INSERTs must come out as one standard INSERT per group."""
+    # Pairs of (source SQL, queries expected from process_statement(), in
+    # order).  Note the double space after VALUES in the split statements:
+    # the expected output pins it.
+    cases = (
+        ("insert into test values (1, 2)",
+         (u"INSERT INTO test VALUES (1, 2)",)),
+        ("insert into test (age, salary) values (1, 2);",
+         (u"INSERT INTO test (age, salary) VALUES (1, 2);",)),
+        ("insert into test values (1, 2), (3, 4);",
+         (u"INSERT INTO test VALUES  (1, 2);",
+          u"INSERT INTO test VALUES  (3, 4);")),
+        ("insert into test (age, salary) values (1, 2), (3, 4)",
+         (u"INSERT INTO test (age, salary) VALUES  (1, 2)",
+          u"INSERT INTO test (age, salary) VALUES  (3, 4)")),
+    )
+    for sql, queries in cases:
+        stiter = process_statement(parse(sql)[0])
+        # Only the expected number of statements is pulled from the
+        # generator; it is not checked for exhaustion.
+        for query in queries:
+            assert tlist2str(next(stiter)) == query
+
+
 def test_process():
     parsed = parse("select /*! test */ * from /* test */ `T`")[0]
     statement = next(process_statement(parsed))