X-Git-Url: https://git.phdru.name/?p=sqlconvert.git;a=blobdiff_plain;f=sqlconvert%2Fprocess_mysql.py;h=1834855dc9b59aee5261cf065afcab1c58310fac;hp=0320b7f30e794886624ebbb6d4a0df5993d2f671;hb=72abf4a136b1a2d164259a4ac300e6a5a4762432;hpb=a8a222e1ecd38370eeb2a23d9761eacd959a2e2b

diff --git a/sqlconvert/process_mysql.py b/sqlconvert/process_mysql.py
index 0320b7f..1834855 100644
--- a/sqlconvert/process_mysql.py
+++ b/sqlconvert/process_mysql.py
@@ -1,7 +1,8 @@
-from sqlparse.sql import Comment
+from sqlparse.sql import (Comment, Function, Identifier, Parenthesis,
+                          Statement, Token)
 from sqlparse import tokens as T
 
-from .process_tokens import escape_strings
+from .process_tokens import escape_strings, is_comment_or_space
 
 
 def _is_directive_token(token):
@@ -28,7 +29,7 @@ def is_directive_statement(statement):
 
 
 def remove_directive_tokens(statement):
-    """Remove /*! directives */ from the first-level"""
+    r"""Remove /\*! directives \*/ from the first level"""
     new_tokens = []
     for token in statement.tokens:
         if _is_directive_token(token):
@@ -70,10 +71,80 @@ def unescape_strings(token_list):
             token.normalized = token.value = value
 
 
+def is_insert(statement):
+    for token in statement.tokens:
+        if is_comment_or_space(token):
+            continue
+        return (token.ttype is T.DML) and (token.normalized == 'INSERT')
+
+
+def split_ext_insert(statement):
+    """Split an extended INSERT into multiple standard INSERTs"""
+    insert_tokens = []
+    values_tokens = []
+    end_tokens = []
+    expected = 'INSERT'
+    for token in statement.tokens:
+        if is_comment_or_space(token):
+            if expected == 'END':
+                end_tokens.append(token)
+            else:
+                insert_tokens.append(token)
+            continue
+        elif expected == 'INSERT':
+            if (token.ttype is T.DML) and (token.normalized == 'INSERT'):
+                insert_tokens.append(token)
+                expected = 'INTO'
+                continue
+        elif expected == 'INTO':
+            if (token.ttype is T.Keyword) and (token.normalized == 'INTO'):
+                insert_tokens.append(token)
+                expected = 'TABLE_NAME'
+                continue
+        elif expected == 'TABLE_NAME':
+            if isinstance(token, (Function, Identifier)):
+                insert_tokens.append(token)
+                expected = 'VALUES'
+                continue
+        elif expected == 'VALUES':
+            if (token.ttype is T.Keyword) and (token.normalized == 'VALUES'):
+                insert_tokens.append(token)
+                expected = 'VALUES_OR_SEMICOLON'
+                continue
+        elif expected == 'VALUES_OR_SEMICOLON':
+            if isinstance(token, Parenthesis):
+                values_tokens.append(token)
+                continue
+            elif token.ttype is T.Punctuation:
+                if token.value == ',':
+                    continue
+                elif token.value == ';':
+                    end_tokens.append(token)
+                    expected = 'END'
+                    continue
+        raise ValueError(
+            'SQL syntax error: expected "%s", got %s "%s"' % (
+                expected, token.ttype, token.normalized))
+    new_line = Token(T.Newline, '\n')
+    new_lines = [new_line]  # Insert newlines between split statements
+    for i, values in enumerate(values_tokens):
+        if i == len(values_tokens) - 1:  # The last statement
+            # Emit newlines between split statements but not after the last
+            new_lines = []
+        # Statement() sets the `parent` attribute of every token to
+        # itself, but we don't care.
+        statement = Statement(insert_tokens + [values] +
+                              end_tokens + new_lines)
+        yield statement
+
+
 def process_statement(statement, quoting_style='sqlite'):
     requote_names(statement)
     unescape_strings(statement)
     remove_directive_tokens(statement)
     escape_strings(statement, quoting_style)
-    yield statement
-    return
+    if is_insert(statement):
+        for statement in split_ext_insert(statement):
+            yield statement
+    else:
+        yield statement
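
Below, a minimal usage sketch of the new splitting path. This is not part of the commit: the table `t` and the row values are invented for illustration, and it assumes an older sqlparse in which the VALUES clause is not grouped into a single token (newer sqlparse releases do group it, and this top-level token scan would then raise ValueError).

import sqlparse

from sqlconvert.process_mysql import process_statement

# A hypothetical extended INSERT with two rows of values.
sql = "INSERT INTO `t` (a, b) VALUES (1, 'x'), (2, 'y');"
parsed = sqlparse.parse(sql)[0]

# is_insert() recognizes the leading DML token, so process_statement()
# delegates to split_ext_insert(), which yields one single-row INSERT
# per parenthesized values group.
for statement in process_statement(parsed):
    print(str(statement).strip())

Note that split_ext_insert() reuses the same INSERT ... VALUES prefix tokens in every yielded Statement, so splitting a large extended INSERT costs one Statement object per row rather than a re-parse.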