]> git.phdru.name Git - sqlconvert.git/commitdiff
Feat(process_mysql): Get DML type instead of just testing for INSERT
authorOleg Broytman <phd@phdru.name>
Thu, 26 Oct 2017 19:13:33 +0000 (22:13 +0300)
committerOleg Broytman <phd@phdru.name>
Thu, 26 Oct 2017 19:33:10 +0000 (22:33 +0300)
sqlconvert/process_mysql.py
tests/test_process_mysql.py

index 1834855dc9b59aee5261cf065afcab1c58310fac..8dc43e2fdde46cdfb74e494bb450530da7aae108 100644 (file)
@@ -71,11 +71,14 @@ def unescape_strings(token_list):
             token.normalized = token.value = value
 
 
-def is_insert(statement):
+def get_DML_type(statement):
     for token in statement.tokens:
         if is_comment_or_space(token):
             continue
-        return (token.ttype is T.DML) and (token.normalized == 'INSERT')
+        if (token.ttype is T.DML):
+            return token.normalized
+        break
+    raise ValueError("Not a DML statement")
 
 
 def split_ext_insert(statement):
@@ -143,7 +146,11 @@ def process_statement(statement, quoting_style='sqlite'):
     unescape_strings(statement)
     remove_directive_tokens(statement)
     escape_strings(statement, quoting_style)
-    if is_insert(statement):
+    try:
+        is_insert = get_DML_type(statement) == 'INSERT'
+    except ValueError:
+        is_insert = False
+    if is_insert:
         for statement in split_ext_insert(statement):
             yield statement
     else:
index df60a647d5de00a997bbc7248303af33b2ff0b55..656292b2f646ce1a70d9e8f3ecd1552ce6f9bbf8 100644 (file)
@@ -1,11 +1,12 @@
 # -*- coding: utf-8 -*-
 
+import pytest
 from sqlparse import parse
 
 from sqlconvert.print_tokens import tlist2str
 from sqlconvert.process_mysql import remove_directive_tokens, \
         is_directive_statement, requote_names, unescape_strings, \
-        is_insert, process_statement
+        get_DML_type, process_statement
 from sqlconvert.process_tokens import escape_strings
 
 
@@ -53,14 +54,19 @@ def test_escape_string_sqlite():
     assert query == u"INSERT INTO test VALUES ('\"te''st\"\n')"
 
 
-def test_is_insert():
+def test_DML_type():
+    parsed = parse("create table test ();")[0]
+    statement = next(process_statement(parsed))
+    with pytest.raises(ValueError):
+        get_DML_type(statement)
+
     parsed = parse("select /*! test */ * from /* test */ `T`")[0]
     statement = next(process_statement(parsed))
-    assert not is_insert(statement)
+    assert get_DML_type(statement) == "SELECT"
 
     parsed = parse("insert into test values ('\"te\\'st\\\"\\n')")[0]
     statement = next(process_statement(parsed))
-    assert is_insert(statement)
+    assert get_DML_type(statement) == "INSERT"
 
 
 def test_split_ext_insert():