-from sqlparse.sql import Comment
+from sqlparse.sql import Comment, Function, Identifier, Parenthesis, Statement
from sqlparse import tokens as T
from .process_tokens import escape_strings, is_comment_or_space
return (token.ttype is T.DML) and (token.normalized == 'INSERT')
+def split_ext_insert(statement):
+ """Split extended INSERT into multiple standard INSERTs"""
+ insert_tokens = []
+ values_tokens = []
+ last_token = None
+ expected = 'INSERT'
+ for token in statement.tokens:
+ if is_comment_or_space(token):
+ insert_tokens.append(token)
+ continue
+ elif expected == 'INSERT':
+ if (token.ttype is T.DML) and (token.normalized == 'INSERT'):
+ insert_tokens.append(token)
+ expected = 'INTO'
+ continue
+ elif expected == 'INTO':
+ if (token.ttype is T.Keyword) and (token.normalized == 'INTO'):
+ insert_tokens.append(token)
+ expected = 'TABLE_NAME'
+ continue
+ elif expected == 'TABLE_NAME':
+ if isinstance(token, (Function, Identifier)):
+ insert_tokens.append(token)
+ expected = 'VALUES'
+ continue
+ elif expected == 'VALUES':
+ if (token.ttype is T.Keyword) and (token.normalized == 'VALUES'):
+ insert_tokens.append(token)
+ expected = 'VALUES_OR_SEMICOLON'
+ continue
+ elif expected == 'VALUES_OR_SEMICOLON':
+ if isinstance(token, Parenthesis):
+ values_tokens.append(token)
+ continue
+ elif token.ttype is T.Punctuation:
+ if token.value == ',':
+ continue
+ elif token.value == ';':
+ last_token = token
+ break
+ raise ValueError(
+ 'SQL syntax error: expected "%s", got %s "%s"' % (
+ expected, token.ttype, token.normalized))
+ for values in values_tokens:
+ # The statemnt sets `parent` attribute of the every token to self
+ # but we don't care.
+ vl = [values]
+ if last_token:
+ vl.append(last_token)
+ statement = Statement(insert_tokens + vl)
+ yield statement
+
+
def process_statement(statement, quoting_style='sqlite'):
requote_names(statement)
unescape_strings(statement)
remove_directive_tokens(statement)
escape_strings(statement, quoting_style)
- yield statement
- return
+ if is_insert(statement):
+ for statement in split_ext_insert(statement):
+ yield statement
+ else:
+ yield statement
parsed = parse("select /*! test */ * from /* test */ `T`")[0]
statement = next(process_statement(parsed))
assert not is_insert(statement)
+
parsed = parse("insert into test values ('\"te\\'st\\\"\\n')")[0]
statement = next(process_statement(parsed))
assert is_insert(statement)
+def test_split_ext_insert():
+ parsed = parse("insert into test values (1, 2)")[0]
+ statement = next(process_statement(parsed))
+ query = tlist2str(statement)
+ assert query == u"INSERT INTO test VALUES (1, 2)"
+
+ parsed = parse("insert into test (age, salary) values (1, 2);")[0]
+ statement = next(process_statement(parsed))
+ query = tlist2str(statement)
+ assert query == u"INSERT INTO test (age, salary) VALUES (1, 2);"
+
+ parsed = parse("insert into test values (1, 2), (3, 4);")[0]
+ stiter = process_statement(parsed)
+ statement = next(stiter)
+ query = tlist2str(statement)
+ assert query == u"INSERT INTO test VALUES (1, 2);"
+ statement = next(stiter)
+ query = tlist2str(statement)
+ assert query == u"INSERT INTO test VALUES (3, 4);"
+
+ parsed = parse("insert into test (age, salary) values (1, 2), (3, 4)")[0]
+ stiter = process_statement(parsed)
+ statement = next(stiter)
+ query = tlist2str(statement)
+ assert query == u"INSERT INTO test (age, salary) VALUES (1, 2)"
+ statement = next(stiter)
+ query = tlist2str(statement)
+ assert query == u"INSERT INTO test (age, salary) VALUES (3, 4)"
+
+
def test_process():
parsed = parse("select /*! test */ * from /* test */ `T`")[0]
statement = next(process_statement(parsed))