X-Git-Url: https://git.phdru.name/?a=blobdiff_plain;f=sqlconvert%2Fprocess_mysql.py;h=fad7e4cc291a04241834557bbb90aa4c36283e67;hb=b72239847f13152061182973ea573dd2e835a89e;hp=c375f797a91d12b7a34e9fcde629ea988f2d9252;hpb=d830f4bcd21deb078a89d59e4d98a6406ce5661d;p=sqlconvert.git diff --git a/sqlconvert/process_mysql.py b/sqlconvert/process_mysql.py index c375f79..fad7e4c 100644 --- a/sqlconvert/process_mysql.py +++ b/sqlconvert/process_mysql.py @@ -1,5 +1,6 @@ -from sqlparse.sql import Comment, Function, Identifier, Parenthesis, Statement +from sqlparse.sql import Comment, Function, Identifier, Parenthesis, \ + Statement, Token from sqlparse import tokens as T from .process_tokens import escape_strings, is_comment_or_space @@ -28,7 +29,7 @@ def is_directive_statement(statement): def remove_directive_tokens(statement): - """Remove /*! directives */ from the first-level""" + """Remove /\*! directives \*/ from the first-level""" new_tokens = [] for token in statement.tokens: if _is_directive_token(token): @@ -70,22 +71,28 @@ def unescape_strings(token_list): token.normalized = token.value = value -def is_insert(statement): +def get_DML_type(statement): for token in statement.tokens: if is_comment_or_space(token): continue - return (token.ttype is T.DML) and (token.normalized == 'INSERT') + if (token.ttype is T.DML): + return token.normalized + break + raise ValueError("Not a DML statement") def split_ext_insert(statement): """Split extended INSERT into multiple standard INSERTs""" insert_tokens = [] values_tokens = [] - last_token = None + end_tokens = [] expected = 'INSERT' for token in statement.tokens: if is_comment_or_space(token): - insert_tokens.append(token) + if expected == 'END': + end_tokens.append(token) + else: + insert_tokens.append(token) continue elif expected == 'INSERT': if (token.ttype is T.DML) and (token.normalized == 'INSERT'): @@ -115,27 +122,35 @@ def split_ext_insert(statement): if token.value == ',': continue elif token.value == ';': - last_token = token - break + end_tokens.append(token) + expected = 'END' + continue raise ValueError( 'SQL syntax error: expected "%s", got %s "%s"' % ( expected, token.ttype, token.normalized)) - for values in values_tokens: - # The statemnt sets `parent` attribute of the every token to self + new_line = Token(T.Newline, '\n') + new_lines = [new_line] # Insert newlines between split statements + for i, values in enumerate(values_tokens): + if i == len(values_tokens) - 1: # Last but one statement + # Insert newlines only between split statements but not after + new_lines = [] + # The statement sets `parent` attribute of the every token to self # but we don't care. - vl = [values] - if last_token: - vl.append(last_token) - statement = Statement(insert_tokens + vl) + statement = Statement(insert_tokens + [values] + + end_tokens + new_lines) yield statement -def process_statement(statement, quoting_style='sqlite'): +def process_statement(statement, dbname='sqlite'): requote_names(statement) unescape_strings(statement) remove_directive_tokens(statement) - escape_strings(statement, quoting_style) - if is_insert(statement): + escape_strings(statement, dbname) + try: + dml_type = get_DML_type(statement) + except ValueError: + dml_type = 'UNKNOWN' + if dml_type == 'INSERT': for statement in split_ext_insert(statement): yield statement else: