From d73e4faa556e337eba2ffece77ae3d114bfc3da0 Mon Sep 17 00:00:00 2001 From: Oleg Broytman Date: Thu, 26 Oct 2017 22:13:33 +0300 Subject: [PATCH] Feat(process_mysql): Get DML type instead of just testing for INSERT --- sqlconvert/process_mysql.py | 13 ++++++++++--- tests/test_process_mysql.py | 14 ++++++++++---- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/sqlconvert/process_mysql.py b/sqlconvert/process_mysql.py index 1834855..8dc43e2 100644 --- a/sqlconvert/process_mysql.py +++ b/sqlconvert/process_mysql.py @@ -71,11 +71,14 @@ def unescape_strings(token_list): token.normalized = token.value = value -def is_insert(statement): +def get_DML_type(statement): for token in statement.tokens: if is_comment_or_space(token): continue - return (token.ttype is T.DML) and (token.normalized == 'INSERT') + if (token.ttype is T.DML): + return token.normalized + break + raise ValueError("Not a DML statement") def split_ext_insert(statement): @@ -143,7 +146,11 @@ def process_statement(statement, quoting_style='sqlite'): unescape_strings(statement) remove_directive_tokens(statement) escape_strings(statement, quoting_style) - if is_insert(statement): + try: + is_insert = get_DML_type(statement) == 'INSERT' + except ValueError: + is_insert = False + if is_insert: for statement in split_ext_insert(statement): yield statement else: diff --git a/tests/test_process_mysql.py b/tests/test_process_mysql.py index df60a64..656292b 100644 --- a/tests/test_process_mysql.py +++ b/tests/test_process_mysql.py @@ -1,11 +1,12 @@ # -*- coding: utf-8 -*- +import pytest from sqlparse import parse from sqlconvert.print_tokens import tlist2str from sqlconvert.process_mysql import remove_directive_tokens, \ is_directive_statement, requote_names, unescape_strings, \ - is_insert, process_statement + get_DML_type, process_statement from sqlconvert.process_tokens import escape_strings @@ -53,14 +54,19 @@ def test_escape_string_sqlite(): assert query == u"INSERT INTO test VALUES ('\"te''st\"\n')" -def test_is_insert(): +def test_DML_type(): + parsed = parse("create table test ();")[0] + statement = next(process_statement(parsed)) + with pytest.raises(ValueError): + get_DML_type(statement) + parsed = parse("select /*! test */ * from /* test */ `T`")[0] statement = next(process_statement(parsed)) - assert not is_insert(statement) + assert get_DML_type(statement) == "SELECT" parsed = parse("insert into test values ('\"te\\'st\\\"\\n')")[0] statement = next(process_statement(parsed)) - assert is_insert(statement) + assert get_DML_type(statement) == "INSERT" def test_split_ext_insert(): -- 2.39.2