X-Git-Url: https://git.phdru.name/?a=blobdiff_plain;ds=sidebyside;f=mysql2sql%2Fprocess_tokens.py;h=37752e7c146a8d76ad63f9b7ccbbe395066eec4e;hb=96d6304e4c08ef94bb7dfac3e3069fcb8c982747;hp=7f30a5589b7fad13265b4a9e147fe6308afba545;hpb=3151032d036f9a66bad633dc1018395a14f46bac;p=sqlconvert.git diff --git a/mysql2sql/process_tokens.py b/mysql2sql/process_tokens.py index 7f30a55..37752e7 100644 --- a/mysql2sql/process_tokens.py +++ b/mysql2sql/process_tokens.py @@ -1,4 +1,5 @@ +from sqlparse import parse from sqlparse.sql import Statement from sqlparse.tokens import Name, Error, Punctuation, Comment, Newline, \ Whitespace @@ -27,21 +28,29 @@ def find_error(token_list): class StatementGrouper(object): def __init__(self): - self.tokens = [] self.statements = [] + self.tokens = [] + self.lines = [] - def get_statements(self): - for statement in self.statements: - yield statement - self.statements = [] + def process_line(self, line): + lines = self.lines + lines.append(line) + tokens = parse('\n'.join(lines))[0] + self.process_tokens(tokens) + self.lines = [] - def process(self, tokens): + def process_tokens(self, tokens): for token in tokens: self.tokens.append(token) if (token.ttype == Punctuation) and (token.value == ';'): self.statements.append(Statement(self.tokens)) self.tokens = [] + def get_statements(self): + for statement in self.statements: + yield statement + self.statements = [] + def close(self): for token in self.tokens: if (token.ttype not in (Comment.Single, Comment.Multiline,