token.normalized = token.value = value
-def is_insert(statement):
+def get_DML_type(statement):
for token in statement.tokens:
if is_comment_or_space(token):
continue
- return (token.ttype is T.DML) and (token.normalized == 'INSERT')
+ if (token.ttype is T.DML):
+ return token.normalized
+ break
+ raise ValueError("Not a DML statement")
def split_ext_insert(statement):
unescape_strings(statement)
remove_directive_tokens(statement)
escape_strings(statement, quoting_style)
- if is_insert(statement):
+ try:
+ is_insert = get_DML_type(statement) == 'INSERT'
+ except ValueError:
+ is_insert = False
+ if is_insert:
for statement in split_ext_insert(statement):
yield statement
else:
# -*- coding: utf-8 -*-
+import pytest
from sqlparse import parse
from sqlconvert.print_tokens import tlist2str
from sqlconvert.process_mysql import remove_directive_tokens, \
is_directive_statement, requote_names, unescape_strings, \
- is_insert, process_statement
+ get_DML_type, process_statement
from sqlconvert.process_tokens import escape_strings
assert query == u"INSERT INTO test VALUES ('\"te''st\"\n')"
-def test_is_insert():
+def test_DML_type():
+ parsed = parse("create table test ();")[0]
+ statement = next(process_statement(parsed))
+ with pytest.raises(ValueError):
+ get_DML_type(statement)
+
parsed = parse("select /*! test */ * from /* test */ `T`")[0]
statement = next(process_statement(parsed))
- assert not is_insert(statement)
+ assert get_DML_type(statement) == "SELECT"
parsed = parse("insert into test values ('\"te\\'st\\\"\\n')")[0]
statement = next(process_statement(parsed))
- assert is_insert(statement)
+ assert get_DML_type(statement) == "INSERT"
def test_split_ext_insert():