X-Git-Url: https://git.phdru.name/?p=sqlconvert.git;a=blobdiff_plain;f=sqlconvert%2Fprocess_mysql.py;h=ad4d422b4c13fa0ca2f1821009038b0e512c9214;hp=5c2924a1b5e7fe269f353e2c78e77eb0c9e8bd67;hb=HEAD;hpb=95103778dd6d6d279d9b3c9f83ff49ea5920c6e5 diff --git a/sqlconvert/process_mysql.py b/sqlconvert/process_mysql.py index 5c2924a..c89d6d1 100644 --- a/sqlconvert/process_mysql.py +++ b/sqlconvert/process_mysql.py @@ -1,6 +1,6 @@ from sqlparse.sql import Comment, Function, Identifier, Parenthesis, \ - Statement, Token + Statement, Token, Values from sqlparse import tokens as T from .process_tokens import escape_strings, is_comment_or_space @@ -71,11 +71,14 @@ def unescape_strings(token_list): token.normalized = token.value = value -def is_insert(statement): +def get_DML_type(statement): for token in statement.tokens: if is_comment_or_space(token): continue - return (token.ttype is T.DML) and (token.normalized == 'INSERT') + if (token.ttype is T.DML): + return token.normalized + break + raise ValueError("Not a DML statement") def split_ext_insert(statement): @@ -107,6 +110,14 @@ def split_ext_insert(statement): expected = 'VALUES' continue elif expected == 'VALUES': + if isinstance(token, Values): + for subtoken in token.tokens: + if isinstance(subtoken, Parenthesis): + values_tokens.append(subtoken) + insert_tokens.append(Token(T.Keyword, 'VALUES')) + insert_tokens.append(Token(T.Whitespace, ' ')) + expected = 'VALUES_OR_SEMICOLON' + continue if (token.ttype is T.Keyword) and (token.normalized == 'VALUES'): insert_tokens.append(token) expected = 'VALUES_OR_SEMICOLON' @@ -131,19 +142,23 @@ def split_ext_insert(statement): if i == len(values_tokens) - 1: # Last but one statement # Insert newlines only between split statements but not after new_lines = [] - # The statemnt sets `parent` attribute of the every token to self + # The statement sets `parent` attribute of the every token to self # but we don't care. statement = Statement(insert_tokens + [values] + end_tokens + new_lines) yield statement -def process_statement(statement, quoting_style='sqlite'): +def process_statement(statement, dbname='sqlite'): requote_names(statement) unescape_strings(statement) remove_directive_tokens(statement) - escape_strings(statement, quoting_style) - if is_insert(statement): + escape_strings(statement, dbname) + try: + dml_type = get_DML_type(statement) + except ValueError: + dml_type = 'UNKNOWN' + if dml_type == 'INSERT': for statement in split_ext_insert(statement): yield statement else: