X-Git-Url: https://git.phdru.name/?a=blobdiff_plain;f=sqlconvert%2Fprocess_mysql.py;h=c375f797a91d12b7a34e9fcde629ea988f2d9252;hb=3b0e3272032c10fe531c322536673c450697328d;hp=f91f516d2d9236f53dd320f5eb2a958545d54c9d;hpb=bc473dfbbebeea2e2f4e0f368408632e047949ac;p=sqlconvert.git diff --git a/sqlconvert/process_mysql.py b/sqlconvert/process_mysql.py index f91f516..c375f79 100644 --- a/sqlconvert/process_mysql.py +++ b/sqlconvert/process_mysql.py @@ -1,6 +1,7 @@ -from sqlparse.sql import Comment +from sqlparse.sql import Comment, Function, Identifier, Parenthesis, Statement from sqlparse import tokens as T +from .process_tokens import escape_strings, is_comment_or_space def _is_directive_token(token): @@ -49,6 +50,93 @@ def requote_names(token_list): token.normalized = token.value = '"%s"' % value -def process_statement(statement): - remove_directive_tokens(statement) +def unescape_strings(token_list): + """Unescape strings""" + for token in token_list.flatten(): + if token.ttype is T.String.Single: + value = token.value + for orig, repl in ( + ('\\"', '"'), + ("\\'", "'"), + ("''", "'"), + ('\\b', '\b'), + ('\\n', '\n'), + ('\\r', '\r'), + ('\\t', '\t'), + ('\\\032', '\032'), + ('\\\\', '\\'), + ): + value = value.replace(orig, repl) + token.normalized = token.value = value + + +def is_insert(statement): + for token in statement.tokens: + if is_comment_or_space(token): + continue + return (token.ttype is T.DML) and (token.normalized == 'INSERT') + + +def split_ext_insert(statement): + """Split extended INSERT into multiple standard INSERTs""" + insert_tokens = [] + values_tokens = [] + last_token = None + expected = 'INSERT' + for token in statement.tokens: + if is_comment_or_space(token): + insert_tokens.append(token) + continue + elif expected == 'INSERT': + if (token.ttype is T.DML) and (token.normalized == 'INSERT'): + insert_tokens.append(token) + expected = 'INTO' + continue + elif expected == 'INTO': + if (token.ttype is T.Keyword) and (token.normalized == 'INTO'): + insert_tokens.append(token) + expected = 'TABLE_NAME' + continue + elif expected == 'TABLE_NAME': + if isinstance(token, (Function, Identifier)): + insert_tokens.append(token) + expected = 'VALUES' + continue + elif expected == 'VALUES': + if (token.ttype is T.Keyword) and (token.normalized == 'VALUES'): + insert_tokens.append(token) + expected = 'VALUES_OR_SEMICOLON' + continue + elif expected == 'VALUES_OR_SEMICOLON': + if isinstance(token, Parenthesis): + values_tokens.append(token) + continue + elif token.ttype is T.Punctuation: + if token.value == ',': + continue + elif token.value == ';': + last_token = token + break + raise ValueError( + 'SQL syntax error: expected "%s", got %s "%s"' % ( + expected, token.ttype, token.normalized)) + for values in values_tokens: + # The statemnt sets `parent` attribute of the every token to self + # but we don't care. + vl = [values] + if last_token: + vl.append(last_token) + statement = Statement(insert_tokens + vl) + yield statement + + +def process_statement(statement, quoting_style='sqlite'): requote_names(statement) + unescape_strings(statement) + remove_directive_tokens(statement) + escape_strings(statement, quoting_style) + if is_insert(statement): + for statement in split_ext_insert(statement): + yield statement + else: + yield statement