"""Convert sqlparse statements from MySQL dumps to a more portable form.

Reconstructed from the post-image of the diff of
sqlconvert/process_mysql.py (blob 5c63d0d -> c89d6d1).
"""

import re

from sqlparse.sql import Comment, Function, Identifier, Parenthesis, \
    Statement, Token, Values
from sqlparse import tokens as T

from .process_tokens import escape_strings, is_comment_or_space


# Escape sequences recognised inside a MySQL single-quoted string token,
# mapped to the text they stand for.  Every key is exactly two characters.
_STRING_ESCAPES = {
    '\\"': '"',
    "\\'": "'",
    "''": "'",
    '\\b': '\b',
    '\\n': '\n',
    '\\r': '\r',
    '\\t': '\t',
    '\\\032': '\032',
    '\\\\': '\\',
}
_STRING_ESCAPE_RE = re.compile(
    '|'.join(re.escape(sequence) for sequence in _STRING_ESCAPES))


def _is_directive_token(token):
    """Return True if `token` is a comment group starting with '/*!'."""
    if not isinstance(token, Comment):
        return False
    subtokens = token.tokens
    if not subtokens:
        return False
    first = subtokens[0]
    return first.ttype is T.Comment.Multiline and \
        first.value.startswith('/*!')


def is_directive_statement(statement):
    """Return True if `statement` is only a /*! directive */ and a ';'.

    Whitespace and newlines are allowed between the directive and the
    terminating semicolon; anything else makes the result False.
    A statement without tokens is not a directive statement.
    """
    tokens = statement.tokens
    # Guard against an empty token list before indexing into it.
    if not tokens or not _is_directive_token(tokens[0]):
        return False
    last = tokens[-1]
    if last.ttype is not T.Punctuation or last.value != ';':
        return False
    return all(token.ttype in (T.Newline, T.Whitespace)
               for token in tokens[1:-1])


def remove_directive_tokens(statement):
    """Remove /*! directives */ from the first-level"""
    statement.tokens = [token for token in statement.tokens
                        if not _is_directive_token(token)]


# NOTE(review): requote_names(token_list) belongs here.  The diff shows only
# its `def` line and its last line as hunk context, so its body is not
# reproduced in this reconstruction; restore it from the original module.


def unescape_strings(token_list):
    """Unescape strings

    Rewrites, in place, the value of every single-quoted string token found
    by `token_list.flatten()`, decoding the sequences in `_STRING_ESCAPES`.
    """
    for token in token_list.flatten():
        if token.ttype is T.String.Single:
            # Decode in a single left-to-right pass.  Chaining str.replace()
            # calls lets the result of one replacement (or the second half
            # of an escaped backslash) be picked up by a later rule, e.g.
            # "\\'" + "'" collapsed twice, or backslash-backslash-n turned
            # into backslash-newline.
            value = _STRING_ESCAPE_RE.sub(
                lambda match: _STRING_ESCAPES[match.group(0)], token.value)
            token.normalized = token.value = value


def get_DML_type(statement):
    """Return the normalized DML keyword that starts `statement`.

    Leading comments and whitespace are skipped.  Raises ValueError if the
    first significant token is not a DML keyword or there is no such token.
    """
    for token in statement.tokens:
        if is_comment_or_space(token):
            continue
        if token.ttype is T.DML:
            return token.normalized
        # Only the first significant token decides the statement type.
        break
    raise ValueError("Not a DML statement")


def split_ext_insert(statement):
    """Split extended INSERT into multiple standard INSERTs

    Yields one new Statement per parenthesized values tuple.  Raises
    ValueError when the tokens do not follow the expected
    INSERT INTO <table> VALUES (...)[, (...)]...[;] sequence.
    """
    insert_tokens = []  # everything up to and including VALUES
    values_tokens = []  # one Parenthesis per row
    end_tokens = []     # the semicolon and whatever follows it
    expected = 'INSERT'
    for token in statement.tokens:
        if is_comment_or_space(token):
            if expected == 'END':
                end_tokens.append(token)
            else:
                insert_tokens.append(token)
            continue
        elif expected == 'INSERT':
            if (token.ttype is T.DML) and (token.normalized == 'INSERT'):
                insert_tokens.append(token)
                expected = 'INTO'
                continue
        elif expected == 'INTO':
            if (token.ttype is T.Keyword) and (token.normalized == 'INTO'):
                insert_tokens.append(token)
                expected = 'TABLE_NAME'
                continue
        elif expected == 'TABLE_NAME':
            if isinstance(token, (Function, Identifier)):
                insert_tokens.append(token)
                expected = 'VALUES'
                continue
        elif expected == 'VALUES':
            if isinstance(token, Values):
                # The parser grouped VALUES and its tuples into one token:
                # collect the tuples and emit a fresh VALUES keyword.
                for subtoken in token.tokens:
                    if isinstance(subtoken, Parenthesis):
                        values_tokens.append(subtoken)
                insert_tokens.append(Token(T.Keyword, 'VALUES'))
                insert_tokens.append(Token(T.Whitespace, ' '))
                expected = 'VALUES_OR_SEMICOLON'
                continue
            if (token.ttype is T.Keyword) and (token.normalized == 'VALUES'):
                insert_tokens.append(token)
                expected = 'VALUES_OR_SEMICOLON'
                continue
        elif expected == 'VALUES_OR_SEMICOLON':
            if isinstance(token, Parenthesis):
                values_tokens.append(token)
                continue
            elif token.ttype is T.Punctuation:
                if token.value == ',':
                    continue
                elif token.value == ';':
                    end_tokens.append(token)
                    expected = 'END'
                    continue
        # Any token that did not `continue` above is unexpected here.
        raise ValueError(
            'SQL syntax error: expected "%s", got %s "%s"' % (
                expected, token.ttype, token.normalized))
    new_lines = [Token(T.Newline, '\n')]  # Separator between statements
    last_index = len(values_tokens) - 1
    for index, values in enumerate(values_tokens):
        if index == last_index:  # Last statement
            # Insert newlines only between split statements but not after
            new_lines = []
        # The statement sets `parent` attribute of the every token to self
        # but we don't care.
        yield Statement(insert_tokens + [values] + end_tokens + new_lines)


def process_statement(statement, dbname='sqlite'):
    """Convert one MySQL statement for `dbname` and yield the results.

    Names are requoted, strings are unescaped and re-escaped for the target
    database, and first-level /*! directives */ are removed.  An INSERT is
    split into one statement per values tuple; any other statement is
    yielded once.
    """
    requote_names(statement)
    unescape_strings(statement)
    remove_directive_tokens(statement)
    escape_strings(statement, dbname)
    try:
        dml_type = get_DML_type(statement)
    except ValueError:
        dml_type = 'UNKNOWN'
    if dml_type == 'INSERT':
        for single_insert in split_ext_insert(statement):
            yield single_insert
    else:
        yield statement