diff --git a/.gitignore b/.gitignore index 1e8dd82..31fc4d2 100644 --- a/.gitignore +++ b/.gitignore @@ -172,3 +172,4 @@ cython_debug/ #.idea/ testcases/** +.vscode \ No newline at end of file diff --git a/.gitmodules b/.gitmodules index 43367c1..f216884 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,3 +10,15 @@ [submodule "benchmark/Go/sally"] path = benchmark/Go/sally url = https://github.com/uber-go/sally.git +[submodule "benchmark/Javascript/microlight"] + path = benchmark/Javascript/microlight + url = https://github.com/asvd/microlight.git +[submodule "benchmark/Javascript/mocha"] + path = benchmark/Javascript/mocha + url = https://github.com/mochajs/mocha.git +[submodule "squish"] + path = squish + url = https://github.com/shgysk8zer0/squish.git +[submodule "benchmark/Javascript/squish"] + path = benchmark/Javascript/squish + url = https://github.com/shgysk8zer0/squish.git diff --git a/benchmark/Javascript/toy/NPD/case01.js b/benchmark/Javascript/toy/NPD/case01.js new file mode 100644 index 0000000..6fc2c93 --- /dev/null +++ b/benchmark/Javascript/toy/NPD/case01.js @@ -0,0 +1,15 @@ +var myname = "daniel"; +myname = null; + +function test2_process(data) { + let current = myname; + let value = data[0]; + console.log(current.length) + return value; +} + + +function test2_caller() { + let data = null; + return test2_process(data) +} diff --git a/benchmark/Javascript/toy/NPD/case02.js b/benchmark/Javascript/toy/NPD/case02.js new file mode 100644 index 0000000..f4d20ba --- /dev/null +++ b/benchmark/Javascript/toy/NPD/case02.js @@ -0,0 +1,14 @@ +function func_generator(value) { + let fn = null; + if (value % 3 == 0) { + fn = console.log; + } else if (value % 3 == 1) { + fn = console.error; + } + return fn; +} + +const print = () => { + func_generator(8)("Hello world!"); + console.log("Done"); +} diff --git a/benchmark/Javascript/toy/NPD/case03.js b/benchmark/Javascript/toy/NPD/case03.js new file mode 100644 index 0000000..b620a91 --- /dev/null +++ b/benchmark/Javascript/toy/NPD/case03.js @@ -0,0 +1,11 @@ +function getLength2(value) { + if (!value) { + return 0; + } + return value.length; +} + +const print2 = () => { + let a = getLength2(null); + console.log(); +} \ No newline at end of file diff --git a/benchmark/Javascript/toy/NPD/case04.js b/benchmark/Javascript/toy/NPD/case04.js new file mode 100644 index 0000000..5af5d93 --- /dev/null +++ b/benchmark/Javascript/toy/NPD/case04.js @@ -0,0 +1,17 @@ +function hello3() { + let output = []; + + for (let i = 0; i < 5; i++) { + output.push(null); + } + return output; +} + +function hello4() { + let output = hello3(); + for (let i = 0; i < 4; i++) { + output[i] = i.toString(); + } + return output[4] ? output[4].length : 0; +} + diff --git a/benchmark/Javascript/toy/NPD/case05.js b/benchmark/Javascript/toy/NPD/case05.js new file mode 100644 index 0000000..c06a3ba --- /dev/null +++ b/benchmark/Javascript/toy/NPD/case05.js @@ -0,0 +1,6 @@ +var a = console.error; +delete a.error; +const exec = function () { + a.error(); +} +exec() \ No newline at end of file diff --git a/benchmark/Javascript/toy/NPD/case06.js b/benchmark/Javascript/toy/NPD/case06.js new file mode 100644 index 0000000..1bb5f1a --- /dev/null +++ b/benchmark/Javascript/toy/NPD/case06.js @@ -0,0 +1,25 @@ +const obj = { + greet() { + let obj = 1; + console.log("hello"); + } +}; + + +const a = obj; + +function call(items) { + a = items; +} + +const exec = function () { + var b = null; + let c = 1; + call(b); + + for (let i = 0; i < 5; i++) { + a.greet(); + } +} + +exec(); \ No newline at end of file diff --git a/lib/build.py b/lib/build.py index f563469..6ed1ca6 100644 --- a/lib/build.py +++ b/lib/build.py @@ -37,10 +37,16 @@ os.system( f'git clone https://github.com/tree-sitter/tree-sitter-python.git {cwd / "vendor/tree-sitter-python"}' ) + # Checkout to specific commit for language version 14 compatibility os.system( f'cd {cwd / "vendor/tree-sitter-python"} && git checkout 710796b8b877a970297106e5bbc8e2afa47f86ec' ) + +if not (cwd / "vendor/tree-sitter-javascript/grammar.js").exists(): + os.system( + f'git clone https://github.com/tree-sitter/tree-sitter-javascript.git {cwd / "vendor/tree-sitter-javascript"}' + ) if not (cwd / "vendor/tree-sitter-go/grammar.js").exists(): os.system( @@ -61,6 +67,7 @@ str(cwd / "vendor/tree-sitter-cpp"), str(cwd / "vendor/tree-sitter-java"), str(cwd / "vendor/tree-sitter-python"), + str(cwd / "vendor/tree-sitter-javascript"), str(cwd / "vendor/tree-sitter-go"), ], ) diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 78dfb30..0000000 --- a/requirements.txt +++ /dev/null @@ -1,19 +0,0 @@ -black -tree-sitter>=0.20.0,<0.22.0 -transformers -torch -tiktoken -replicate -openai -google-generativeai -tqdm -networkx -streamlit -botocore -boto3 -black -anthropic -mypy -types-networkx -types-tqdm -boto3-stubs[essential] diff --git a/src/agent/dfbscan.py b/src/agent/dfbscan.py index 41bc244..22757bd 100644 --- a/src/agent/dfbscan.py +++ b/src/agent/dfbscan.py @@ -19,6 +19,7 @@ from tstool.dfbscan_extractor.Cpp.Cpp_UAF_extractor import * from tstool.dfbscan_extractor.Java.Java_NPD_extractor import * from tstool.dfbscan_extractor.Python.Python_NPD_extractor import * +from tstool.dfbscan_extractor.Javascript.Javascript_NPD_extractor import * from tstool.dfbscan_extractor.Go.Go_NPD_extractor import * from llmtool.LLM_utils import * @@ -109,9 +110,13 @@ def __obtain_extractor(self) -> DFBScanExtractor: elif self.language == "Python": if self.bug_type == "NPD": return Python_NPD_Extractor(self.ts_analyzer) + elif self.language == "Javascript": + if self.bug_type == "NPD": + return Javascript_NPD_Extractor(self.ts_analyzer) elif self.language == "Go": if self.bug_type == "NPD": return Go_NPD_Extractor(self.ts_analyzer) + raise NotImplementedError( f"Unsupported bug type: {self.bug_type} in {self.language}" ) @@ -187,6 +192,77 @@ def __update_worklist( set({(para, new_call_context)}), ) + if value.label == ValueLabel.NONLOCAL: + # Consider side effect. + # Example: the non local variable g is used in the function g = null; + # We need to consider the side effect of g. + caller_functions = self.ts_analyzer.get_all_caller_functions(function) + + if caller_functions: + for caller_function in caller_functions: + new_call_context = copy.deepcopy(call_context) + + top_unmatched_context_label = ( + new_call_context.get_top_unmatched_context_label() + ) + + call_site_nodes = self.ts_analyzer.get_callsites_by_callee_name( + caller_function, function.function_name + ) + for call_site_node in call_site_nodes: + caller_function_file_name = self.ts_analyzer.functionToFile[ + caller_function.function_id + ] + file_content = self.ts_analyzer.code_in_files[ + caller_function_file_name + ] + call_site_lower_line_number = ( + file_content[: call_site_node.start_byte].count("\n") + + 1 + ) + + if top_unmatched_context_label is not None: + if ( + top_unmatched_context_label.parenthesis + == Parenthesis.LEFT_PAR + ): + if ( + call_site_lower_line_number + != top_unmatched_context_label.line_number + or caller_function_file_name + != top_unmatched_context_label.file_name + or top_unmatched_context_label.function_id + != function.function_id + ): + continue + + append_context_label = ContextLabel( + caller_function_file_name, + call_site_lower_line_number, + function.function_id, + Parenthesis.RIGHT_PAR, + ) + new_value = Value( + value.name, + call_site_node.start_point[0] + 1, + ValueLabel.NONLOCAL, + value.file, + ) + + new_call_context.add_and_check_context(append_context_label) + + delta_worklist.append( + ( + new_value, + caller_function, + new_call_context, + ) + ) + self.state.update_external_value_match( + (value, call_context), + set({(new_value, new_call_context)}), + ) + if value.label == ValueLabel.PARA: # Consider side-effect. # Example: the parameter *p is used in the function: p->f = null; @@ -357,14 +433,24 @@ def __collect_potential_buggy_paths( if value.label == ValueLabel.SINK: # For NPD-style bug types if self.is_reachable: - self.state.update_potential_buggy_paths( - src_value, path_with_unknown_status + [value] - ) + + # Checks if the sink is a called to a predefined function + is_defined_function = False + for func in self.ts_analyzer.function_env.values(): + if value.name == func.function_name: + is_defined_function = True + break + + if not is_defined_function: + self.state.update_potential_buggy_paths( + src_value, path_with_unknown_status + [value] + ) elif value.label in { ValueLabel.PARA, ValueLabel.RET, ValueLabel.ARG, ValueLabel.OUT, + ValueLabel.NONLOCAL, }: # For other propagation types, check further external matches. if (value, ctx) in external_match_snapshot: @@ -458,6 +544,7 @@ def start_scan_sequential(self) -> None: sink_values, call_statements, ret_values, + non_locals=[], ) # Invoke the intra-procedural data-flow analysis @@ -554,6 +641,27 @@ def start_scan(self) -> None: # Total number of source values total_src_values = len(self.src_values) + total_global_src_values = len(self.ts_analyzer.globals_env) + + with tqdm( + total=total_global_src_values, + desc="Processing Global Source Values", + unit="src", + ) as pbar: + with ThreadPoolExecutor(max_workers=self.max_neural_workers) as executor: + futures = [ + executor.submit(self.__process_global_value, global_value) + for _, global_value in self.ts_analyzer.globals_env.items() + ] + for future in as_completed(futures): + try: + future.result() + except Exception as e: + self.logger.print_log("Error processing source value:", e) + finally: + # Update the progress bar after each source value is processed + pbar.update(1) + # Process each source value in parallel with a progress bar with tqdm( total=total_src_values, desc="Processing Source Values", unit="src" @@ -586,19 +694,27 @@ def start_scan(self) -> None: return def __process_src_value(self, src_value: Value) -> None: + """ + Perform data-flow analysis starting from a local source value. + + 1. Locates the function containing the source. + 2. Performs intra-procedural data-flow analysis to find reachable values. + 3. Collects and validates potential buggy paths, creating bug reports + when confirmed. + """ worklist = [] src_function = self.ts_analyzer.get_function_from_localvalue(src_value) if src_function is None: return - initial_context = CallContext(False) + initial_context = CallContext(False) worklist.append((src_value, src_function, initial_context)) - while len(worklist) > 0: - (start_value, start_function, call_context) = worklist.pop(0) + + while worklist: + start_value, start_function, call_context = worklist.pop(0) if len(call_context.context) > self.call_depth: continue - # Construct the input for intra-procedural data-flow analysis sinks_in_function = self.__obtain_extractor().extract_sinks(start_function) sink_values = [ (sink.name, sink.line_number - start_function.start_line_number + 1) @@ -618,70 +734,74 @@ def __process_src_value(self, src_value: Value) -> None: ret_values = [ (ret.name, ret.line_number - start_function.start_line_number + 1) - for ret in ( - start_function.retvals if start_function.retvals is not None else [] - ) + for ret in (start_function.retvals if start_function.retvals else []) ] + + non_local_list = [] + + if ( + start_function.parse_tree_root_node + in self.ts_analyzer.function_root_to_scope_id + ): + function_scope_id = self.ts_analyzer.function_root_to_scope_id[ + start_function.parse_tree_root_node + ] + if function_scope_id in self.ts_analyzer.child_scope_id_to_non_locals: + non_local_list = [ + (value.name, value.line_number) + for value in self.ts_analyzer.child_scope_id_to_non_locals[ + function_scope_id + ] + ] + df_input = IntraDataFlowAnalyzerInput( - start_function, start_value, sink_values, call_statements, ret_values + start_function, + start_value, + sink_values, + call_statements, + ret_values, + non_local_list, ) - - # Invoke the intra-procedural data-flow analysis df_output = self.intra_dfa.invoke(df_input, IntraDataFlowAnalyzerOutput) - if df_output is None: continue for path_index in range(len(df_output.reachable_values)): - reachable_values_in_single_path = set([]) - for value in df_output.reachable_values[path_index]: - reachable_values_in_single_path.add((value, call_context)) + reachable_values_in_single_path = { + (value, call_context) + for value in df_output.reachable_values[path_index] + } self.state.update_reachable_values_per_path( (start_value, call_context), reachable_values_in_single_path ) - delta_worklist = self.__update_worklist( df_input, df_output, call_context, path_index ) worklist.extend(delta_worklist) - # Collect potential buggy paths + # Collect and validate buggy paths self.__collect_potential_buggy_paths(src_value, (src_value, CallContext(False))) - - # If no potential buggy paths are found, return early if src_value not in self.state.potential_buggy_paths: return - # Validate buggy paths and generate bug reports for buggy_path in self.state.potential_buggy_paths[src_value].values(): values_to_functions = { value: self.ts_analyzer.get_function_from_localvalue(value) for value in buggy_path } - - functions: Set[Function] = set() - for func in values_to_functions.values(): - if func is not None: - functions.add(func) - + functions = {func for func in values_to_functions.values() if func} if self.state.check_existence(src_value, functions): continue pv_input = PathValidatorInput( - self.bug_type, - buggy_path, - values_to_functions, + self.bug_type, buggy_path, values_to_functions ) pv_output = self.path_validator.invoke(pv_input, PathValidatorOutput) - - if pv_output is None: - continue - - if pv_output.is_reachable: + if pv_output and pv_output.is_reachable: relevant_functions = {} for value in buggy_path: function = self.ts_analyzer.get_function_from_localvalue(value) - if function is not None: + if function: relevant_functions[function.function_id] = function bug_report = BugReport( @@ -691,16 +811,176 @@ def __process_src_value(self, src_value: Value) -> None: pv_output.explanation_str, ) self.state.update_bug_report(bug_report) + bug_report_dict = { bug_report_id: bug.to_dict() for bug_report_id, bug in self.state.bug_reports.items() } + with open(self.res_dir_path + "/detect_info.json", "w") as f: + json.dump(bug_report_dict, f, indent=4) - with open( - self.res_dir_path + "/detect_info.json", "w" - ) as bug_info_file: - json.dump(bug_report_dict, bug_info_file, indent=4) - return + def __process_global_value(self, global_value): + """ + Perform data-flow analysis starting from a global variable. + + 1. Finds all functions referencing the global variable. + 2. Runs intra-procedural data-flow analysis to discover reachable values. + 3. If the global is marked as a source (SRC), collects potential buggy paths and + reports them if confirmed. + """ + worklist = [] + + reference_in_funcs = self.ts_analyzer.get_function_global_value_reference( + global_value + ) + if len(reference_in_funcs) == 0: + return + + initial_context = CallContext(False) + + # Seed worklist with all function references to the global. + for func, global_references in reference_in_funcs.items(): + for global_reference in global_references: + worklist.append((global_reference, func, initial_context)) + + # Worklist-driven intra-procedural analysis + while worklist: + start_value, start_function, call_context = worklist.pop(0) + if len(call_context.context) > self.call_depth: + continue + + sinks_in_function = self.__obtain_extractor().extract_sinks(start_function) + sink_values = [ + (sink.name, sink.line_number - start_function.start_line_number + 1) + for sink in sinks_in_function + ] + + call_statements = [] + for call_site_node in start_function.function_call_site_nodes: + file_content = self.ts_analyzer.code_in_files[start_function.file_path] + call_site_line_number = ( + file_content[: call_site_node.start_byte].count("\n") + 1 + ) + call_site_name = file_content[ + call_site_node.start_byte : call_site_node.end_byte + ] + call_statements.append((call_site_name, call_site_line_number)) + + ret_values = [ + (ret.name, ret.line_number - start_function.start_line_number + 1) + for ret in (start_function.retvals if start_function.retvals else []) + ] + + non_local_list = [] + + if ( + start_function.parse_tree_root_node + in self.ts_analyzer.function_root_to_scope_id + ): + function_scope_id = self.ts_analyzer.function_root_to_scope_id[ + start_function.parse_tree_root_node + ] + if function_scope_id in self.ts_analyzer.child_scope_id_to_non_locals: + non_local_list = [ + (value.name, value.line_number) + for value in self.ts_analyzer.child_scope_id_to_non_locals[ + function_scope_id + ] + ] + + df_input = IntraDataFlowAnalyzerInput( + start_function, + start_value, + sink_values, + call_statements, + ret_values, + non_local_list, + ) + + df_output = self.intra_dfa.invoke(df_input, IntraDataFlowAnalyzerOutput) + if df_output is None: + continue + + for path_index in range(len(df_output.reachable_values)): + reachable_values_in_single_path = { + (value, call_context) + for value in df_output.reachable_values[path_index] + } + self.state.update_reachable_values_per_path( + (start_value, call_context), reachable_values_in_single_path + ) + delta_worklist = self.__update_worklist( + df_input, df_output, call_context, path_index + ) + worklist.extend(delta_worklist) + + found_potential_buggy_paths = False + for func, global_references in reference_in_funcs.items(): + for global_reference in global_references: + self.__collect_potential_buggy_paths( + global_reference, (global_reference, CallContext(False)) + ) + if global_reference in self.state.potential_buggy_paths: + found_potential_buggy_paths = True + + if not found_potential_buggy_paths: + return + + # Validate each potential buggy path + for start_value, buggy_paths in self.state.potential_buggy_paths.items(): + for buggy_path in buggy_paths.values(): + values_to_functions = { + value: self.ts_analyzer.get_function_from_localvalue(value) + for value in buggy_path + } + + functions = set() + relevant_global_exprs = [] + for func in values_to_functions.values(): + if func: + functions.add(func) + + current = func.parse_tree_root_node + while current.parent: + current = current.parent + + relevant_global_exprs.extend( + self.ts_analyzer.get_global_expressions_by_identifier( + global_value.name, current + ) + ) + + if self.state.check_existence(start_value, functions): + continue + + pv_input = PathValidatorInput( + self.bug_type, + buggy_path, + values_to_functions, + relevant_global_exprs, + ) + pv_output = self.path_validator.invoke(pv_input, PathValidatorOutput) + if pv_output and pv_output.is_reachable: + relevant_functions = {} + for value in buggy_path: + function = self.ts_analyzer.get_function_from_localvalue(value) + if function: + relevant_functions[function.function_id] = function + + bug_report = BugReport( + self.bug_type, + start_value, + relevant_functions, + pv_output.explanation_str, + ) + self.state.update_bug_report(bug_report) + + bug_report_dict = { + bug_report_id: bug.to_dict() + for bug_report_id, bug in self.state.bug_reports.items() + } + with open(self.res_dir_path + "/detect_info.json", "w") as f: + json.dump(bug_report_dict, f, indent=4) def get_agent_state(self) -> DFBScanState: return self.state diff --git a/src/llmtool/LLM_utils.py b/src/llmtool/LLM_utils.py index 843c2db..4976710 100644 --- a/src/llmtool/LLM_utils.py +++ b/src/llmtool/LLM_utils.py @@ -92,7 +92,7 @@ def run_with_timeout(self, func, timeout): def infer_with_gemini(self, message: str) -> str: """Infer using the Gemini model from Google Generative AI""" - gemini_model = genai.GenerativeModel("gemini-pro") + gemini_model = genai.GenerativeModel(self.online_model_name) def call_api(): message_with_role = self.systemRole + "\n" + message diff --git a/src/llmtool/dfbscan/intra_dataflow_analyzer.py b/src/llmtool/dfbscan/intra_dataflow_analyzer.py index b26b54b..ee263ad 100644 --- a/src/llmtool/dfbscan/intra_dataflow_analyzer.py +++ b/src/llmtool/dfbscan/intra_dataflow_analyzer.py @@ -19,12 +19,14 @@ def __init__( sink_values: List[Tuple[str, int]], call_statements: List[Tuple[str, int]], ret_values: List[Tuple[str, int]], + non_locals: List[Tuple[str, int]], ) -> None: self.function = function self.summary_start = summary_start self.sink_values = sink_values self.call_statements = call_statements self.ret_values = ret_values + self.non_locals = non_locals return def __hash__(self) -> int: @@ -108,6 +110,15 @@ def _get_prompt(self, input: LLMToolInput) -> str: for ret_val in input.ret_values: rets_str += f"- {ret_val[0]} at line {ret_val[1]}\n" prompt = prompt.replace("", rets_str) + + if input.non_locals: + non_local_str = "Non local variables relevant to this function:\n" + for non_local in input.non_locals: + non_local_str += f"- {non_local[0]} at line {non_local[1]}\n" + prompt = prompt.replace("", non_local_str) + else: + prompt = prompt.replace("", "") + return prompt def _parse_response( @@ -223,6 +234,12 @@ def _parse_response( reachable_values_per_path.add( Value(detail["name"], line_number, ValueLabel.SINK, file_path) ) + elif detail["type"] == "Nonlocal": + reachable_values_per_path.add( + Value( + detail["name"], line_number, ValueLabel.NONLOCAL, file_path + ) + ) reachable_values.append(reachable_values_per_path) output = IntraDataFlowAnalyzerOutput(reachable_values) diff --git a/src/llmtool/dfbscan/path_validator.py b/src/llmtool/dfbscan/path_validator.py index 069edfc..cee6e88 100644 --- a/src/llmtool/dfbscan/path_validator.py +++ b/src/llmtool/dfbscan/path_validator.py @@ -16,10 +16,12 @@ def __init__( bug_type: str, values: List[Value], values_to_functions: Dict[Value, Optional[Function]], + relevant_global_exprs: List[Node] = [], ) -> None: self.bug_type = bug_type self.values = values self.values_to_functions = values_to_functions + self.relevant_global_exprs = relevant_global_exprs return def __hash__(self) -> int: @@ -66,7 +68,7 @@ def _get_prompt(self, input: LLMToolInput) -> str: prompt = prompt_template_dict["task"] prompt += "\n" + "\n".join(prompt_template_dict["analysis_rules"]) prompt += "\n" + "\n".join(prompt_template_dict["analysis_examples"]) - prompt += "\n" + "".join(prompt_template_dict["meta_prompts"]) + prompt += "\n" + "\n".join(prompt_template_dict["meta_prompts"]) prompt = prompt.replace( "", "\n".join(prompt_template_dict["answer_format"]) ).replace("", "\n".join(prompt_template_dict["question_template"])) @@ -87,11 +89,25 @@ def _get_prompt(self, input: LLMToolInput) -> str: prompt = prompt.replace("", "\n".join(value_lines)) prompt = prompt.replace("", input.bug_type) - program = "\n".join( - [ - "```\n" + func.lined_code + "\n```\n" if func is not None else "\n" - for func in input.values_to_functions.values() - ] + functions: Set[Function] = set() + for func in input.values_to_functions.values(): + if func is not None: + functions.add(func) + + program = "\n" + if len(input.relevant_global_exprs) > 0: + program = ( + "\n".join( + [ + "```\n" + expr.text.decode() + "\n```\n" + for expr in input.relevant_global_exprs + ] + ) + + "\n" + ) + + program += "\n".join( + ["```\n" + func.lined_code + "\n```\n" for func in functions] ) prompt = prompt.replace("", program) return prompt diff --git a/src/memory/syntactic/value.py b/src/memory/syntactic/value.py index 61e439f..f46d816 100644 --- a/src/memory/syntactic/value.py +++ b/src/memory/syntactic/value.py @@ -15,7 +15,8 @@ class ValueLabel(Enum): NON_BUF_ACCESS_EXPR = 8 # non-buffer access LOCAL = 9 - GLOBAL = 10 + NONLOCAL = 10 + GLOBAL = 11 def __str__(self) -> str: mapping = { @@ -28,6 +29,7 @@ def __str__(self) -> str: ValueLabel.BUF_ACCESS_EXPR: "ValueLabel.BUF_ACCESS_EXPR", ValueLabel.NON_BUF_ACCESS_EXPR: "ValueLabel.NON_BUF_ACCESS_EXPR", ValueLabel.LOCAL: "ValueLabel.LOCAL", + ValueLabel.NONLOCAL: "ValueLabel.NONLOCAL", ValueLabel.GLOBAL: "ValueLabel.GLOBAL", } return mapping[self] @@ -44,6 +46,7 @@ def from_str(s: str): "ValueLabel.BUF_ACCESS_EXPR": ValueLabel.BUF_ACCESS_EXPR, "ValueLabel.NON_BUF_ACCESS_EXPR": ValueLabel.NON_BUF_ACCESS_EXPR, "ValueLabel.LOCAL": ValueLabel.LOCAL, + "ValueLabel.NONLOCAL": ValueLabel.NONLOCAL, "ValueLabel.GLOBAL": ValueLabel.GLOBAL, } try: diff --git a/src/prompt/Javascript/dfbscan/intra_dataflow_analyzer.json b/src/prompt/Javascript/dfbscan/intra_dataflow_analyzer.json new file mode 100644 index 0000000..9b20e83 --- /dev/null +++ b/src/prompt/Javascript/dfbscan/intra_dataflow_analyzer.json @@ -0,0 +1,135 @@ +{ + "model_role_name": "Intra-procedural Data Flow Analyzer", + "user_role_name": "Intra-procedural Data Flow Analyzer", + "system_role": "You are a Javascript programmer and very good at analyzing Javascript code. Particularly, you excel at understanding individual Javascript functions and their data flow relationships.", + "task": "Given a specific source variable/parameter/expression (denoted as SRC) at a specific line (denoted as L1), analyze the execution flows of the given function and determine the variables to which SRC can propagate.", + "analysis_rules": [ + "The key principle for answering this question is to extract all execution paths related to SRC and simulate the function's execution along each path to determine where SRC propagates. In Javascript, SRC can propagate to five possible locations:", + "1. Function Calls: SRC propagates to a call site where it is passed as an argument to a callee function within the current function.", + "2. Return Statements: SRC propagates to a return statement, returning a value to the caller of the current function.", + "3. Function Parameters: SRC propagates to a parameter of the current function and can be referenced in the caller function, since objects are passed by reference.", + "4. Sink variables: SRC reaches one of the predefined sink variables provided in the input.", + "5. Non local variable assignment: SRC propagates its value to a predefined non local variable.", + "If SRC is referenced by function parameters, it can propagate beyond the function scope after the function exits, due to object references being shared between caller and callee. For example, if function goo passes an object base to its callee function foo, and foo(obj: Base) { obj = SRC; }, then the caller function goo can access the updated state of SRC through the object base.", + "To conduct the analysis, follow these three steps:", + "", + "- Step 1: Identify SRC and its aliases within the current function. Extract key points, including:", + " 1. Sink Statements: Where SRC is used or assigned to predefined sink variables.", + " 2. Function Invocations: Call sites where SRC is passed as an argument.", + " 3. Return Statements: Points where the function returns, possibly propagating SRC.", + " 4. Parameter Assignments: Assignments where SRC is assigned to a parameter or an object field that is accessible outside the function.", + " 5. Non local variable assignment: Assignments where SRC is assigned to a predefined non local variable.", + "- Step 2: Identify all execution paths relevant to the key points found in Step 1. For each path:", + " - Identify every potential execution path;", + " - Verify whether the key points are executed along each path;", + " - Expand execution paths affected by conditional branches (if-else, switch), loops, and exception-handling blocks.", + "- Step 3: For each execution path extracted in Step 2, simulate function execution line by line and determine where SRC propagates based on the four possible propagation locations.", + "", + "Additional guidelines:", + "- Different propagation paths can exist due to control flow constructs (if-else, loops, race conditions); identify all execution paths before analysis;", + "- If there are potential race conditions, for example, calling an async function without await, differentiate the control paths to consider the race condition;", + "- Expand the first iteration of loops to analyze nested execution paths;", + "- Treat each conditional branch (if, switch) as a separate execution path;", + "- Expand nested conditions and loops to ensure all paths are analyzed.", + "- For non local variables, they can be defined anywhere outside the function and you should not assume any information about its declaration (including line numbers of their declaration or their initial values)." + ], + "analysis_examples": [ + "Example 1: Propagation via Sink, Function call, and Return", + "User:", + "Now I will give you a target function with the source point 'p' at line 1:", + "```javascript", + "function process(p) {", + " let a = p; // SRC: a", + " if (a > 0) {", + " let b = a + 1;", + " logValue(a); // Sink value: a", + " } else {", + " let c = a - 1;", + " updateValue(a); // Function call", + " }", + " return a; // Return statement: 'a' is returned", + "}", + "```", + "Where does the source variable 'p' at line 1 propagate within this function?", + "System:", + "Explanation:", + "Step 1: Identify SRC and its alias;", + "SRC: The variable p at line 1 is assigned to variable a;", + "Step 2: Identify key points and execution paths:", + "Path 1 (if branch): When a > 0, variable a is used in the sink statement at line 5 (logValue(a)) and returned at line 10;", + "Path 2 (else branch): When a <= 0, variable a is passed as an argument to updateValue at line 8 and then returned at line 10;", + "Step 3: Simulate each execution path:", + "In Path 1, propagation is from p -> a -> sink at logValue(a) and p -> a -> return a;", + "In Path 2, propagation is from p -> a -> updateValue(a) and p -> a -> return a;", + "Answer:", + "Path 1: Lines 1 -> 2 -> 4 -> 5 -> 10;", + "- Type: Sink; Name: a; Function: None; Index: None; Line: 5; Dependency: SRC (p) is assigned to 'a', then used in logValue;", + "- Type: Return; Name: a; Function: None; Index: 0; Line: 10; Dependency: SRC (p) is assigned to 'a', then returned.", + "Path 2: Lines 1 -> 2 -> 6 -> 8 -> 10;", + "- Type: Argument; Name: a; Function: updateValue; Index: 0; Line: 8; Dependency: SRC (p) is assigned to 'a', then passed as an argument to updateValue;", + "- Type: Return; Name: a; Function: None; Index: 0; Line: 10; Dependency: SRC (p) is assigned to 'a', then returned.", + "", + "Example 2: Propagation via Parameter Passing and Return", + "User:", + "Now I will give you a target function with the source point 'src' at line 1:", + "```javascript", + "function modifyValue(src, flag) {", + " if (flag > 0) {", + " flag = src; // source value: src", + " } else if (flag == 0) {", + " return src; // return statement", + " }", + " return -1; // Default return value ", + "}", + "```", + "Where does the source variable 'src' at line 1 propagate within this function?", + "System:", + "Explanation:", + "Step 1: Identify SRC and its alias;", + "SRC: The variable src is defined at line 1;", + "Step 2: Identify key points and execution paths:", + "Path 1 (flag > 0): src is assigned to flag at line 3, making it accessible outside the function if flag is referenced after the call;", + "Path 2 (flag == 0): src is returned at line 5, propagating to the caller;", + "Path 3 (flag < 0): Function returns -1, so SRC does not propagate in this path;", + "Step 3: Simulate the execution paths:", + "Path 1: When flag > 0, src is assigned to flag, allowing potential propagation outside the function through the parameter reference;", + "Path 2: When flag == 0, src is returned to the caller;", + "Path 3: When flag < 0, src does not propagate, as the function returns -1;", + "Answer:", + "Path 1: Lines 1 -> 3;", + "- Type: Parameter; Name: flag; Function: None; Index: 1; Line: 3; Dependency: SRC (src) is assigned to parameter 'flag', which may be referenced by the caller;", + "Path 2: Lines 1 -> 5;", + "- Type: Return; Name: src; Function: None; Index: 0; Line: 5; Dependency: SRC (src) is returned to the caller;", + "Path 3: Lines 1 -> 6;", + "- No propagation; Dependency: Default return value -1 is unrelated to SRC." + ], + "question_template": "- Where does the source at line in this function propagate?", + "answer_format_cot": [ + "(1) First, provide a detailed step-by-step reasoning process, following the explanation format used in the examples;", + "(2) Once the reasoning is complete, begin the final answer section with 'Answer:';", + "(3) For each execution path, list the propagation details using the following format:", + "- Path : ;", + " - For a function argument propagation: 'Type: Argument; Name: {argument name}; Function: {callee function name}; Index: {argument index}; Line: {call site line number}; Dependency: {summary of dependency from SRC to argument}';", + " - For a return propagation: 'Type: Return; Name: {return name}; Function: None; Index: {return value index}; Line: {return statement line number}; Dependency: {summary of dependency from SRC to return value}';", + " - For parameter propagation: 'Type: Parameter; Name: {parameter name}; Function: None; Index: {parameter index}; Line: {assignment line number}; Dependency: {summary of dependency from SRC to parameter}';", + " - For sink propagation: 'Type: Sink; Name: {sink name}; Function: None; Index: None; Line: {sink statement line number}; Dependency: {summary of dependency from SRC to sink}';", + " - For non local variable assignment: 'Type: Nonlocal; Name: {non local name}; Function: None; Index: None; Line: {assignment statement line number}; Dependency: {summary of dependency from SRC to assignment}';", + "(4) If there is no propagation along a path, provide a brief explanation of why SRC does not propagate in that path as follows:", + "- Path : ;", + " - No propagation; Dependency: {reason for no propagation};", + "(5) Each Execution Path should start with the word \"Lines\", with each line number separated by \" -> \" and ended with a semicolon.", + "(6) Remember: All the indexes start from 0 instead of 1. If there is only one return value, the index is 0.", + "(7) Remember: For non local variable assignment, only list 'Type: Nonlocal; Name: {non local name}; Function: None; Index: None; Line: {assignment statement line number}; Dependency: {summary of dependency from SRC to assignment}' if assignments are explicitly stated in the function code (Eg. non_local = something)." + ], + "meta_prompts": [ + "Now I will give you a target function with the source point `` at line : \n```\n\n``` \n\n", + "You may see the following expressions at these line as sink points. Identify which of these are related to SRC and its aliases;\n", + "\n", + "Here are the Function call sites and return statements within the function, which can be used in Step 1;\n", + "\n", + "\n", + "", + "Now, please answer the following question:\n\n", + "Your response should strictly follow the format:\n\n" + ] +} diff --git a/src/prompt/Javascript/dfbscan/path_validator.json b/src/prompt/Javascript/dfbscan/path_validator.json new file mode 100644 index 0000000..989ae21 --- /dev/null +++ b/src/prompt/Javascript/dfbscan/path_validator.json @@ -0,0 +1,93 @@ +{ + "model_role_name": "Path Validator", + "user_role_name": "Path Validator", + "system_role": "You are a Javascript programmer and very good at analyzing Javascript code. In particular, you are skilled at understanding how data flows across multiple functions.", + "task": "You will be provided with an interprocedural data-flow path along with a specified . Your task is to decide whether the given propagation path is reachable – that is, whether its path condition is satisfiable. For example, for NPD (null-pointer dereference) detection, if the dereferenced object is guarded by a branch condition such as 'p !== null', then the path should be deemed unreachable.", + "analysis_rules": [ + "Keep the following guidelines in mind:", + "- If the source in the first function flows to the sink in the last function without any interference, then the path is reachable and your answer should be Yes.", + "- For NPD detection, if the source value is modified or its null/undefined state is verified (for example, via an explicit check like 'p !== null') before reaching the sink, then the path is unreachable and you should answer No.", + "- If a function exits or returns before the sink or other propagation sites (such as function calls) are reached, the path is unreachable; answer No in such cases.", + "- If a sink is a call to an object or a function that is builtin in Javascript or defined in the scope, then the path is unreachable; answer No in such cases.", + "- Analyze conditions within each function: infer the outcome of branch statements and then verify whether the conditions across different sub-paths conflict. If conflicts exist, the overall path is unreachable.", + "- If the data flow propagation path only one element, and the element is a sink, then the path is reachable and you should say Yes.", + "- Consider the values of relevant variables; if those values contradict the necessary branch conditions for triggering the bug, the path is unreachable and you should answer No.", + "In summary, assess the conditions in every sub-path, check for conflicts, and decide whether the entire propagation path is reachable." + ], + "question_template": [ + "When these functions are executed, does the following data-flow propagation path cause the bug?", + "```", + "", + "```", + "Provide your detailed explanation for this propagation path:", + "", + "" + ], + "analysis_examples": [ + "Example 1:", + "User:", + "Here is the Javascript program:", + "```javascript", + "function getArray(length) {", + " let array = null;", + " if (length > 0) {", + " array = new Array(length);", + " }", + " return array;", + "}", + "", + "function getElement(array, index) {", + " return array[index];", + "}", + "```", + "Does the following propagation path cause the NPD bug?", + "Propagation Path: 'array' at line 2 in getArray --> 'array' used at line 2 in getElement", + "Explanation: In getArray, if length <= 0, array remains null and is returned. In getElement, a null array would trigger a TypeError (null dereference) when accessed at line 10. However, when length > 0, the array is non-null. Since the conditions for array being null and non-null conflict, this propagation path is unreachable and does not cause the NPD bug.", + "Answer: No.", + "", + "Example 2:", + "User:", + "Here is the Javascript program:", + "```javascript", + "function foo(obj) {", + " if (obj === null) {", + " return null;", + " }", + " return obj;", + "}", + "", + "function bar() {", + " const myObj = foo(null);", + " myObj.toString();", + "}", + "```", + "Does the following propagation path cause the NPD bug?", + "Parameter 'obj' in foo --> foo returns null --> myObj assigned null in bar, which then gets dereferenced causing a method call on null", + "Explanation: The function foo returns null when passed a null input. In bar, this leads to myObj being null, which in turn causes a TypeError when calling toString(). As there is no conflicting branch condition preventing this case, the propagation path is reachable and causes the NPD bug.", + "Answer: Yes." + ], + "additional_fact": [ + "Additional details may include whether specific lines fall within if-statements and the corresponding line numbers for those conditions.", + "For each line in the provided path, follow this reasoning:", + "- Indicate whether line {line_number} is inside the 'true' or 'else' branch of an if-statement.", + "- State whether, given the variable values, the branch condition will always be evaluated as true, always as false, or is indeterminate.", + "- Conclude whether line {line_number} is reachable.", + "After analyzing each line, decide if the overall path's condition is satisfiable (reachable) or not." + ], + "answer_format": [ + "(1) In the first line, provide your detailed reasoning and explanation.", + "(2) In the second line, simply state Yes or No.", + "Example:", + "Explanation: {Your detailed explanation.}", + "Answer: Yes" + ], + "meta_prompts": [ + "Now I will provide you with the program:", + "", + "Please answer the following question:", + "", + "Your answer should follow this format:", + "", + "Remember: Do not assume the behavior or return values of external methods not provided in the program. Only evaluate the conditions present in the given code." + ] +} diff --git a/src/repoaudit.py b/src/repoaudit.py index 24d3639..042e6ff 100644 --- a/src/repoaudit.py +++ b/src/repoaudit.py @@ -10,6 +10,7 @@ from tstool.analyzer.Go_TS_analyzer import * from tstool.analyzer.Java_TS_analyzer import * from tstool.analyzer.Python_TS_analyzer import * +from tstool.analyzer.Javascript_TS_analyzer import * from typing import List @@ -17,6 +18,7 @@ "Cpp": ["MLK", "NPD", "UAF"], "Java": ["NPD"], "Python": ["NPD"], + "Javascript": ["NPD"], "Go": ["NPD"], } @@ -59,6 +61,8 @@ def __init__( suffixs = ["java"] elif self.language == "Python": suffixs = ["py"] + elif self.language == "Javascript": + suffixs = ["js", "jsx"] else: raise ValueError("Invalid language setting") @@ -82,6 +86,10 @@ def __init__( self.ts_analyzer = Python_TSAnalyzer( self.code_in_files, self.language, self.max_symbolic_workers ) + elif self.language == "Javascript": + self.ts_analyzer = Javascript_TSAnalyzer( + self.code_in_files, self.language, self.max_symbolic_workers + ) return def start_repo_auditing(self) -> None: diff --git a/src/tstool/analyzer/Cpp_TS_analyzer.py b/src/tstool/analyzer/Cpp_TS_analyzer.py index 244e82f..796d87a 100644 --- a/src/tstool/analyzer/Cpp_TS_analyzer.py +++ b/src/tstool/analyzer/Cpp_TS_analyzer.py @@ -16,6 +16,21 @@ class Cpp_TSAnalyzer(TSAnalyzer): Implements language-specific parsing and analysis. """ + def extract_scope_info(self, tree: tree_sitter.Tree) -> None: + """ + Parse source code to extract scope topography. + Currently Not implemented. + :param tree: Parsed syntax tree + """ + pass + + def extract_nonlocal_info(self) -> None: + """ + Traverse the scopes to identify declarations of non locals. + Currently Not implemented. + """ + pass + def extract_function_info( self, file_path: str, source_code: str, tree: tree_sitter.Tree ) -> None: @@ -416,3 +431,15 @@ def get_loop_statements( loop_body_end_line, ) return loop_statements + + def get_global_expressions_by_identifier( + self, identifier: str, program_root: Node + ) -> List[Node]: + """ + Extracts all expressions related to a specific identifier in the global scope. + Currently not implemented. + :param identifier: The identifier + :param program_root: Program root node + :return: A list of extracted nodes + """ + return [] diff --git a/src/tstool/analyzer/Go_TS_analyzer.py b/src/tstool/analyzer/Go_TS_analyzer.py index 13ceb2a..05b248c 100644 --- a/src/tstool/analyzer/Go_TS_analyzer.py +++ b/src/tstool/analyzer/Go_TS_analyzer.py @@ -16,6 +16,21 @@ class Go_TSAnalyzer(TSAnalyzer): Implements Go-specific parsing and analysis. """ + def extract_scope_info(self, tree: tree_sitter.Tree) -> None: + """ + Parse source code to extract scope topography. + Currently Not implemented. + :param tree: Parsed syntax tree + """ + pass + + def extract_nonlocal_info(self) -> None: + """ + Traverse the scopes to identify declarations of non locals. + Currently Not implemented. + """ + pass + def extract_function_info( self, file_path: str, source_code: str, tree: tree_sitter.Tree ) -> None: @@ -349,3 +364,15 @@ def get_loop_statements( loop_body_end_line, ) return loop_statements + + def get_global_expressions_by_identifier( + self, identifier: str, program_root: Node + ) -> List[Node]: + """ + Extracts all expressions related to a specific identifier in the global scope. + Currently not implemented. + :param identifier: The identifier + :param program_root: Program root node + :return: A list of extracted nodes + """ + return [] diff --git a/src/tstool/analyzer/Java_TS_analyzer.py b/src/tstool/analyzer/Java_TS_analyzer.py index 4464bea..3a25281 100644 --- a/src/tstool/analyzer/Java_TS_analyzer.py +++ b/src/tstool/analyzer/Java_TS_analyzer.py @@ -16,6 +16,21 @@ class Java_TSAnalyzer(TSAnalyzer): Implements Java-specific parsing and analysis. """ + def extract_scope_info(self, tree: tree_sitter.Tree) -> None: + """ + Parse source code to extract scope topography. + Currently Not implemented. + :param tree: Parsed syntax tree + """ + pass + + def extract_nonlocal_info(self) -> None: + """ + Traverse the scopes to identify declarations of non locals. + Currently Not implemented. + """ + pass + def extract_function_info( self, file_path: str, source_code: str, tree: tree_sitter.Tree ) -> None: @@ -361,3 +376,15 @@ def get_loop_statements( loop_body_end_line, ) return loop_statements + + def get_global_expressions_by_identifier( + self, identifier: str, program_root: Node + ) -> List[Node]: + """ + Extracts all expressions related to a specific identifier in the global scope. + Currently not implemented. + :param identifier: The identifier + :param program_root: Program root node + :return: A list of extracted nodes + """ + return [] diff --git a/src/tstool/analyzer/Javascript_TS_analyzer.py b/src/tstool/analyzer/Javascript_TS_analyzer.py new file mode 100644 index 0000000..d27c5f9 --- /dev/null +++ b/src/tstool/analyzer/Javascript_TS_analyzer.py @@ -0,0 +1,554 @@ +import sys +from os import path +from typing import List, Tuple, Dict, Set +import tree_sitter + +sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__))))) + +from .TS_analyzer import * +from memory.syntactic.function import * +from memory.syntactic.value import * + + +class Javascript_TSAnalyzer(TSAnalyzer): + """ + TSAnalyzer for Javascript source files using tree-sitter. + Implements Javascript-specific parsing and analysis. + """ + + def extract_scope_info(self, tree: tree_sitter.Tree) -> None: + """ + Parse source code to extract scope topography + :param tree: Parsed syntax tree + """ + scope_stack: List[int] = [] + scope_id: int = 0 + + def search(root: Node) -> None: + nonlocal scope_id + + for child in root.children: + if child.type == "statement_block": + if len(scope_stack) > 0: + self.scope_env[scope_stack[-1]][1].add(scope_id) + + self.scope_env[scope_id] = (child, set()) + self.scope_root_to_scope_id[child] = scope_id + scope_stack.append(scope_id) + + if child.parent: + if child.parent.type == "function_declaration": + self.function_root_to_scope_id[child.parent] = scope_id + elif ( + child.parent.type == "arrow_function" + or child.parent.type == "function_expression" + ): + if child.parent.parent: + self.function_root_to_scope_id[child.parent.parent] = ( + scope_id + ) + + scope_id += 1 + search(child) + scope_stack.pop() + else: + search(child) + + return + + self.scope_env[scope_id] = (tree.root_node, set()) + self.scope_root_to_scope_id[tree.root_node] = scope_id + scope_stack.append(scope_id) + scope_id += 1 + search(tree.root_node) + return + + def extract_nonlocal_info(self) -> None: + identifiers_per_scope: Dict[int, List[Node]] = {} + + for _, scope_data in self.scope_env.items(): + scope_root, child_scope_ids = scope_data + + for scope_child in scope_root.children: + # Only process lexical/variable declarations + if scope_child.type not in ( + "lexical_declaration", + "variable_declaration", + ): + continue + + decl_child = scope_child.child(1) + if decl_child is None: + continue + + name_node = decl_child.child_by_field_name("name") + if name_node is None or name_node.text is None: + continue + + variable_name: str = name_node.text.decode("utf-8") + + label = ( + ValueLabel.GLOBAL + if scope_root.type == "program" + else ValueLabel.LOCAL + ) + + non_local_value = Value( + variable_name, + scope_child.start_point[0] + 1, + label, + file="", + index=-1, + ) + + effective_child_scope_ids = child_scope_ids + if scope_child.type == "variable_declaration": + function_root: Optional[Node] = scope_root + + # Find closest parent function + while ( + function_root is not None and function_root.parent is not None + ): + parent = function_root.parent + if parent.type in ( + "arrow_function", + "function_declaration", + "function_expression", + ): + break + function_root = parent + + if ( + function_root is None + or function_root not in self.scope_root_to_scope_id + ): + continue + + function_scope_id = self.scope_root_to_scope_id[function_root] + effective_child_scope_ids = self.scope_env[function_scope_id][1] + + # Process child scopes + for child_scope_id in effective_child_scope_ids: + child_scope_root, _ = self.scope_env[child_scope_id] + + # Must be inside a function-like construct + parent_node: Optional[Node] = child_scope_root.parent + if parent_node is None or parent_node.type not in ( + "arrow_function", + "function_declaration", + "function_expression", + ): + continue + + # Cache identifiers per scope + if child_scope_id not in identifiers_per_scope: + identifiers_per_scope[child_scope_id] = find_nodes_by_type( + child_scope_root, "identifier" + ) + + for candidate_node in identifiers_per_scope[child_scope_id]: + if candidate_node: + continue + + # Name mismatch + if candidate_node.text is None: + continue + if candidate_node.text.decode("utf-8") != variable_name: + continue + + # Skip if this identifier declares a new variable in this scope with the same name + candidate_parent = candidate_node.parent + if ( + candidate_parent is not None + and candidate_parent.type == "variable_declarator" + and candidate_parent.child_by_field_name("name") + is candidate_node + ): + continue + + self.child_scope_id_to_non_locals.setdefault( + child_scope_id, set() + ).add(non_local_value) + + self.non_local_to_child_scopes.setdefault( + non_local_value, set() + ).add(child_scope_id) + + def extract_function_info( + self, file_path: str, source_code: str, tree: tree_sitter.Tree + ) -> None: + """ + Parse the function information in a source file. + :param file_path: The path of the source file. + :param source_code: The content of the source file. + :param tree: The parse tree of the source file. + """ + all_function_header_nodes = find_nodes_by_type( + tree.root_node, "function_declaration" + ) + all_variable_declarator_nodes = find_nodes_by_type( + tree.root_node, "variable_declarator" + ) + + for node in all_function_header_nodes: + function_name = "" + for sub_node in node.children: + if sub_node.type == "identifier": + function_name = source_code[sub_node.start_byte : sub_node.end_byte] + break + + if function_name == "": + continue + + start_line_number = source_code[: node.start_byte].count("\n") + 1 + end_line_number = source_code[: node.end_byte].count("\n") + 1 + function_id = len(self.functionRawDataDic) + 1 + + self.functionRawDataDic[function_id] = ( + function_name, + start_line_number, + end_line_number, + node, + ) + self.functionToFile[function_id] = file_path + + if function_name not in self.functionNameToId: + self.functionNameToId[function_name] = set([]) + self.functionNameToId[function_name].add(function_id) + + for node in all_variable_declarator_nodes: + name_node = node.child_by_field_name("name") + value_node = node.child_by_field_name("value") + + if not name_node or not value_node: + continue + + if ( + value_node.type != "arrow_function" + and value_node.type != "function_expression" + ): + continue + + function_name = source_code[name_node.start_byte : name_node.end_byte] + start_line = source_code[: node.start_byte].count("\n") + 1 + end_line = source_code[: node.end_byte].count("\n") + 1 + function_id = len(self.functionRawDataDic) + 1 + + self.functionRawDataDic[function_id] = ( + function_name, + start_line, + end_line, + node, + ) + self.functionToFile[function_id] = file_path + self.functionNameToId.setdefault(function_name, set()).add(function_id) + + return + + def extract_global_info( + self, file_path: str, source_code: str, tree: tree_sitter.Tree + ) -> None: + """ + Parse global variable information from a Javascript source file. + For Javascript, this may include module-level variables. + Currently not implemented. + """ + declaration_types = ["lexical_declaration", "variable_declaration"] + for child in tree.root_node.children: + if child.type not in declaration_types: + continue + + declarator_node = child.child(1) + if ( + declarator_node is not None + and declarator_node.type == "variable_declarator" + ): + name_node = declarator_node.child_by_field_name("name") + value_node = declarator_node.child_by_field_name("value") + + if not name_node or not value_node: + continue + + if ( + value_node.type == "arrow_function" + or value_node.type == "function_expression" + ): + continue + + global_name = source_code[name_node.start_byte : name_node.end_byte] + line = source_code[: name_node.start_byte].count("\n") + 1 + global_id = len(self.globalsRawDataDic) + 1 + self.globalsRawDataDic[global_id] = (global_name, line, child) + self.globalsToFile[global_id] = file_path + + return + + def get_callee_name_at_call_site( + self, node: tree_sitter.Node, source_code: str + ) -> str: + """ + Get the callee name at the call site. + :param node: the node of the call site + :param source_code: the content of the file + """ + function_name = "" + for sub_node in node.children: + if sub_node.type == "identifier": + function_name = source_code[sub_node.start_byte : sub_node.end_byte] + break + if sub_node.type == "member_expression": + for sub_sub_node in sub_node.children: + if sub_sub_node.type == "identifier": + function_name = source_code[ + sub_sub_node.start_byte : sub_sub_node.end_byte + ] + break + return function_name + + def get_callsites_by_callee_name( + self, current_function: Function, callee_name: str + ) -> List[tree_sitter.Node]: + """ + Find the call sites by the callee function name. + :param current_function: the function to be analyzed + :param callee_name: the callee function name + """ + results = [] + file_content = self.code_in_files[current_function.file_path] + call_site_nodes = find_nodes_by_type( + current_function.parse_tree_root_node, "call_expression" + ) + for call_site in call_site_nodes: + if ( + self.get_callee_name_at_call_site(call_site, file_content) + == callee_name + ): + results.append(call_site) + return results + + def get_arguments_at_callsite( + self, current_function: Function, call_site_node: tree_sitter.Node + ) -> Set[Value]: + """ + Get arguments from a call site in a function. + :param current_function: the function to be analyzed + :param call_site_node: the node of the call site + :return: the arguments + """ + arguments: Set[Value] = set([]) + file_name = current_function.file_path + source_code = self.code_in_files[file_name] + for sub_node in call_site_node.children: + if sub_node.type == "arguments": + arg_list = sub_node.children[1:-1] + for element in arg_list: + if element.type != ",": + line_number = source_code[: element.start_byte].count("\n") + 1 + arguments.add( + Value( + source_code[element.start_byte : element.end_byte], + line_number, + ValueLabel.ARG, + file_name, + len(arguments), + ) + ) + return arguments + + def get_parameters_in_single_function( + self, current_function: Function + ) -> Set[Value]: + """ + Find the parameters of a function. + :param current_function: The function to be analyzed. + :return: A set of parameters as values + """ + if current_function.paras is not None: + return current_function.paras + current_function.paras = set([]) + file_content = self.code_in_files[current_function.file_path] + parameters = find_nodes_by_type( + current_function.parse_tree_root_node, "formal_parameters" + ) + + index = 0 + for parameter_node in parameters: + parameter_name = "" + for sub_node in parameter_node.children: + for sub_sub_node in find_nodes_by_type(sub_node, "identifier"): + parameter_name = file_content[ + sub_sub_node.start_byte : sub_sub_node.end_byte + ] + if parameter_name != "" and parameter_name != "self": + line_number = ( + file_content[: sub_node.start_byte].count("\n") + 1 + ) + current_function.paras.add( + Value( + parameter_name, + line_number, + ValueLabel.PARA, + current_function.file_path, + index, + ) + ) + index += 1 + return current_function.paras + + def get_return_values_in_single_function( + self, current_function: Function + ) -> Set[Value]: + """ + Find the return values of a Go function + :param current_function: The function to be analyzed. + :return: A set of return values + """ + if current_function.retvals is not None: + return current_function.retvals + + current_function.retvals = set([]) + file_content = self.code_in_files[current_function.file_path] + retnodes = find_nodes_by_type( + current_function.parse_tree_root_node, "return_statement" + ) + for retnode in retnodes: + line_number = file_content[: retnode.start_byte].count("\n") + 1 + restmts_str = file_content[retnode.start_byte : retnode.end_byte] + returned_value = restmts_str.replace("return", "").strip() + current_function.retvals.add( + Value( + returned_value, + line_number, + ValueLabel.RET, + current_function.file_path, + 0, + ) + ) + return current_function.retvals + + def get_if_statements( + self, function: Function, source_code: str + ) -> Dict[Tuple, Tuple]: + """ + Identify if-statements in the Javascript function. + This is a simplified analysis for illustrative purposes. + """ + if_statement_nodes = find_nodes_by_type( + function.parse_tree_root_node, "if_statement" + ) + if_statements = {} + for if_node in if_statement_nodes: + condition_str = "" + condition_start_line = 0 + condition_end_line = 0 + true_branch_start_line = 0 + true_branch_end_line = 0 + else_branch_start_line = 0 + else_branch_end_line = 0 + + block_num = 0 + for sub_target in if_node.children: + if sub_target.type == "parenthesized_expression": + condition_start_line = ( + source_code[: sub_target.start_byte].count("\n") + 1 + ) + condition_end_line = ( + source_code[: sub_target.end_byte].count("\n") + 1 + ) + condition_str = source_code[ + sub_target.start_byte : sub_target.end_byte + ] + if sub_target.type == "statement_block": + lower_lines = [] + upper_lines = [] + for sub_sub in sub_target.children: + if sub_sub.type not in {"{", "}"}: + lower_lines.append( + source_code[: sub_sub.start_byte].count("\n") + 1 + ) + upper_lines.append( + source_code[: sub_sub.end_byte].count("\n") + 1 + ) + if lower_lines and upper_lines: + if block_num == 0: + true_branch_start_line = min(lower_lines) + true_branch_end_line = max(upper_lines) + block_num += 1 + elif block_num == 1: + else_branch_start_line = min(lower_lines) + else_branch_end_line = max(upper_lines) + block_num += 1 + if sub_target.type == "expression_statement": + true_branch_start_line = ( + source_code[: sub_target.start_byte].count("\n") + 1 + ) + true_branch_end_line = ( + source_code[: sub_target.end_byte].count("\n") + 1 + ) + + if_statement_start_line = source_code[: if_node.start_byte].count("\n") + 1 + if_statement_end_line = source_code[: if_node.end_byte].count("\n") + 1 + line_scope = (if_statement_start_line, if_statement_end_line) + info = ( + condition_start_line, + condition_end_line, + condition_str, + (true_branch_start_line, true_branch_end_line), + (else_branch_start_line, else_branch_end_line), + ) + if_statements[line_scope] = info + return if_statements + + def get_loop_statements( + self, function: Function, source_code: str + ) -> Dict[Tuple, Tuple]: + """ + Identify loop statements (for and while) in the Javascript function. + """ + loops = {} + loop_nodes = find_nodes_by_type(function.parse_tree_root_node, "for_statement") + loop_nodes.extend( + find_nodes_by_type(function.parse_tree_root_node, "for_in_statement") + ) + loop_nodes.extend( + find_nodes_by_type(function.parse_tree_root_node, "while_statement") + ) + for node in loop_nodes: + start_line = source_code[: node.start_byte].count("\n") + 1 + end_line = source_code[: node.end_byte].count("\n") + 1 + # Simplified header and body analysis. + loops[(start_line, end_line)] = ( + start_line, + start_line, + "", + start_line, + end_line, + ) + return loops + + def get_global_expressions_by_identifier( + self, identifier: str, program_root: Node + ) -> List[Node]: + """ + Extracts all expressions related to a specific identifier in the global scope + :param identifier: The identifier + :param program_root: Program root node + :return: A list of extracted nodes + """ + + output_nodes = [] + children = program_root.children + global_expression_types = [ + "variable_declaration", + "lexical_declaration", + "expression_statement", + ] + + for child in children: + if child.type not in global_expression_types: + continue + + if find_nodes_by_type(child, "identifier")[0].text.decode() == identifier: + output_nodes.append(child) + + return output_nodes diff --git a/src/tstool/analyzer/Python_TS_analyzer.py b/src/tstool/analyzer/Python_TS_analyzer.py index 24e0f02..0954385 100644 --- a/src/tstool/analyzer/Python_TS_analyzer.py +++ b/src/tstool/analyzer/Python_TS_analyzer.py @@ -16,6 +16,21 @@ class Python_TSAnalyzer(TSAnalyzer): Implements Python-specific parsing and analysis. """ + def extract_scope_info(self, tree: tree_sitter.Tree) -> None: + """ + Parse source code to extract scope topography + :param tree: Parsed syntax tree + """ + # TODO: Add scope extraction if needed + pass + + def extract_nonlocal_info(self) -> None: + """ + Traverse the scopes to identify declarations of non locals + """ + # TODO: add non local variable extraction if needed + pass + def extract_function_info( self, file_path: str, source_code: str, tree: tree_sitter.Tree ) -> None: @@ -279,3 +294,15 @@ def get_loop_statements( end_line, ) return loops + + def get_global_expressions_by_identifier( + self, identifier: str, program_root: Node + ) -> List[Node]: + """ + Extracts all expressions related to a specific identifier in the global scope + :param identifier: The identifier + :param program_root: Program root node + :return: A list of extracted nodes + """ + # TODO: implement if needed + return [] diff --git a/src/tstool/analyzer/TS_analyzer.py b/src/tstool/analyzer/TS_analyzer.py index 31118ab..42e4f94 100644 --- a/src/tstool/analyzer/TS_analyzer.py +++ b/src/tstool/analyzer/TS_analyzer.py @@ -156,6 +156,8 @@ def __init__( self.language = Language(str(language_path), "java") elif language_name == "Python": self.language = Language(str(language_path), "python") + elif language_name == "Javascript": + self.language = Language(str(language_path), "javascript") elif language_name == "Go": self.language = Language(str(language_path), "go") else: @@ -168,10 +170,26 @@ def __init__( self.functionToFile: Dict[int, str] = {} self.fileContentDic: Dict[str, str] = {} self.glb_var_map: Dict[str, str] = {} # global var info + self.globalsRawDataDic: Dict[int, Tuple[str, int, Node]] = {} + self.globalsToFile: Dict[int, str] = {} self.function_env: Dict[int, Function] = {} + self.globals_env: Dict[int, Value] = {} + self.scope_env: Dict[int, Tuple[Node, Set[int]]] = {} self.api_env: Dict[int, API] = {} + # Dictionary storing mapping from the root node of the scope to its scope id + self.scope_root_to_scope_id: Dict[Node, int] = {} + + # Dictionary storing mapping from function root node to its scope id + self.function_root_to_scope_id: Dict[Node, int] = {} + + # Dictionary storing mapping from a scope id to all the non locals it is depended on + self.child_scope_id_to_non_locals: Dict[int, Set[Value]] = {} + + # Dictionary storing mapping from a non local value to its child scopes + self.non_local_to_child_scopes: Dict[Value, Set[int]] = {} + # Results of call graph analysis ## Caller-callee relationship between user-defined functions self.function_caller_callee_map: Dict[int, Set[int]] = {} @@ -201,6 +219,7 @@ def _parse_single_file(self, file_path: str, source_code: str) -> Tuple[str, str # Call user-defined processing. self.extract_function_info(file_path, source_code, tree) self.extract_global_info(file_path, source_code, tree) + self.extract_scope_info(tree) return file_path, source_code def _analyze_single_function( @@ -229,6 +248,7 @@ def parse_project(self) -> None: """ Parse all project files using tree-sitter. """ + # Parses files in the project with concurrent.futures.ThreadPoolExecutor( max_workers=self.max_symbolic_workers_num ) as executor: @@ -247,6 +267,9 @@ def parse_project(self) -> None: pbar.update(1) pbar.close() + self.extract_nonlocal_info() + + # Analyzes extracted functions with concurrent.futures.ThreadPoolExecutor( max_workers=self.max_symbolic_workers_num ) as executor: @@ -265,6 +288,25 @@ def parse_project(self) -> None: self.function_env[func_id] = current_function pbar.update(1) pbar.close() + + # Analyzes extracted global variables + pbar = tqdm( + total=len(self.globalsRawDataDic), desc="Analyzing Global Variables" + ) + for global_id, global_var_tuple in self.globalsRawDataDic.items(): + name = global_var_tuple[0] + line = global_var_tuple[1] + value = Value( + name=name, + line_number=line, + label=ValueLabel.GLOBAL, + file=self.globalsToFile[global_id], + ) + + self.globals_env[global_id] = value + pbar.update(1) + pbar.close() + return def analyze_call_graph(self) -> None: @@ -294,6 +336,21 @@ def analyze_call_graph(self) -> None: ########################################### # Helper function for project AST parsing # ########################################### + @abstractmethod + def extract_scope_info(self, tree: tree_sitter.Tree) -> None: + """ + Parse source code to extract scope topography + :param tree: Parsed syntax tree + """ + pass + + @abstractmethod + def extract_nonlocal_info(self) -> None: + """ + Traverse the scopes to identify declarations of non locals + """ + pass + @abstractmethod def extract_function_info( self, file_path: str, source_code: str, tree: Tree @@ -354,7 +411,11 @@ def extract_call_graph_edges(self, current_function: Function) -> None: file_content = self.fileContentDic[file_name] call_node_type = None - if self.language_name == "C" or self.language_name == "Cpp": + if ( + self.language_name == "C" + or self.language_name == "Cpp" + or self.language_name == "Javascript" + ): call_node_type = "call_expression" elif self.language_name == "Java": call_node_type = "method_invocation" @@ -397,7 +458,7 @@ def extract_call_graph_edges(self, current_function: Function) -> None: tmp_api = API(-1, callee_name, len(arguments)) # Insert the API into the API environment if it does not exist previously - for single_api_id in self.api_env: + for single_api_id in list(self.api_env): if self.api_env[single_api_id] == tmp_api: api_id = single_api_id if api_id == None: @@ -670,6 +731,18 @@ def get_loop_statements( """ pass + @abstractmethod + def get_global_expressions_by_identifier( + self, identifier: str, program_root: Node + ) -> List[Node]: + """ + Extracts all expressions related to a specific identifier in the global scope + :param identifier: The identifier + :param program_root: Program root node + :return: A list of extracted nodes + """ + pass + def check_control_order( self, function: Function, src_line_number: int, sink_line_number: int ) -> bool: @@ -761,6 +834,44 @@ def get_node_by_line_number(self, line_number: int) -> List[Tuple[str, Node]]: code_node_list.append((function.function_code, node)) return code_node_list + def get_function_global_value_reference( + self, global_value: Value + ) -> Dict[Function, List[Value]]: + """ + Find references to a given global value in all functions + belonging to the same source file. + + Args: + global_value: The global Value to search for. + + Returns: + A dictionary mapping each Function to a list of Value + references where the global is used. + """ + file_name = global_value.file + references: Dict[Function, List[Value]] = {} + + for _, function in self.function_env.items(): + if function.file_path != file_name: + continue + + identifiers = find_nodes_by_type( + function.parse_tree_root_node, "identifier" + ) + for identifier in identifiers: + if global_value.name == identifier.text.decode(): + line_number = identifier.start_point[0] + 1 + ref_value = Value( + global_value.name, + line_number, + ValueLabel.GLOBAL, + function.file_path, + -1, + ) + references.setdefault(function, []).append(ref_value) + + return references + def get_function_from_localvalue(self, value: Value) -> Optional[Function]: """ Retrieve the function corresponding to a local value. diff --git a/src/tstool/dfbscan_extractor/Cpp/Cpp_MLK_extractor.py b/src/tstool/dfbscan_extractor/Cpp/Cpp_MLK_extractor.py index b3c8367..f9294c8 100644 --- a/src/tstool/dfbscan_extractor/Cpp/Cpp_MLK_extractor.py +++ b/src/tstool/dfbscan_extractor/Cpp/Cpp_MLK_extractor.py @@ -4,6 +4,13 @@ class Cpp_MLK_Extractor(DFBScanExtractor): + def is_global_source(self, global_declarator_node: Node) -> bool: + """ + Determines whether the global variable is initially a source. + Currently not implemented. + """ + return False + def extract_sources(self, function: Function) -> List[Value]: """ Extract the sources that can cause the memory leak bugs from C/C++ programs. diff --git a/src/tstool/dfbscan_extractor/Cpp/Cpp_NPD_extractor.py b/src/tstool/dfbscan_extractor/Cpp/Cpp_NPD_extractor.py index b7fe94f..38f4528 100644 --- a/src/tstool/dfbscan_extractor/Cpp/Cpp_NPD_extractor.py +++ b/src/tstool/dfbscan_extractor/Cpp/Cpp_NPD_extractor.py @@ -6,6 +6,13 @@ class Cpp_NPD_Extractor(DFBScanExtractor): + def is_global_source(self, global_declarator_node: Node) -> bool: + """ + Determines whether the global variable is initially a source. + Currently not implemented. + """ + return False + def extract_sources(self, function: Function) -> List[Value]: root_node = function.parse_tree_root_node source_code = self.ts_analyzer.code_in_files[function.file_path] diff --git a/src/tstool/dfbscan_extractor/Cpp/Cpp_UAF_extractor.py b/src/tstool/dfbscan_extractor/Cpp/Cpp_UAF_extractor.py index a18bf18..7ad8ac5 100644 --- a/src/tstool/dfbscan_extractor/Cpp/Cpp_UAF_extractor.py +++ b/src/tstool/dfbscan_extractor/Cpp/Cpp_UAF_extractor.py @@ -6,6 +6,13 @@ class Cpp_UAF_Extractor(DFBScanExtractor): + def is_global_source(self, global_declarator_node: Node) -> bool: + """ + Determines whether the global variable is initially a source. + Currently not implemented. + """ + return False + def extract_sources(self, function: Function) -> List[Value]: """ Extract the sources that can cause the use-after-free bugs from C/C++ programs. diff --git a/src/tstool/dfbscan_extractor/Go/Go_NPD_extractor.py b/src/tstool/dfbscan_extractor/Go/Go_NPD_extractor.py index 93a666b..747dcc0 100644 --- a/src/tstool/dfbscan_extractor/Go/Go_NPD_extractor.py +++ b/src/tstool/dfbscan_extractor/Go/Go_NPD_extractor.py @@ -6,6 +6,13 @@ class Go_NPD_Extractor(DFBScanExtractor): + def is_global_source(self, global_declarator_node: Node) -> bool: + """ + Determines whether the global variable is initially a source. + Currently not implemented. + """ + return False + def extract_sources(self, function: Function) -> List[Value]: root_node = function.parse_tree_root_node source_code = self.ts_analyzer.code_in_files[function.file_path] diff --git a/src/tstool/dfbscan_extractor/Java/Java_NPD_extractor.py b/src/tstool/dfbscan_extractor/Java/Java_NPD_extractor.py index 91a5201..e5fe58f 100644 --- a/src/tstool/dfbscan_extractor/Java/Java_NPD_extractor.py +++ b/src/tstool/dfbscan_extractor/Java/Java_NPD_extractor.py @@ -6,6 +6,13 @@ class Java_NPD_Extractor(DFBScanExtractor): + def is_global_source(self, global_declarator_node: Node) -> bool: + """ + Determines whether the global variable is initially a source. + Currently not implemented. + """ + return False + def extract_sources(self, function: Function) -> List[Value]: root_node = function.parse_tree_root_node source_code = self.ts_analyzer.code_in_files[function.file_path] diff --git a/src/tstool/dfbscan_extractor/Javascript/Javascript_NPD_extractor.py b/src/tstool/dfbscan_extractor/Javascript/Javascript_NPD_extractor.py new file mode 100644 index 0000000..25340a9 --- /dev/null +++ b/src/tstool/dfbscan_extractor/Javascript/Javascript_NPD_extractor.py @@ -0,0 +1,147 @@ +from tstool.analyzer.TS_analyzer import * +from tstool.analyzer.Javascript_TS_analyzer import * +from ..dfbscan_extractor import * + + +class Javascript_NPD_Extractor(DFBScanExtractor): + NULLISH_VALUES = {"null", "undefined"} + + def is_expression_delete(self, expr: Node) -> bool: + if expr.type == "unary_expression": + operator = expr.child(0) + if operator and operator.type == "delete": + return True + + return False + + def is_expression_null(self, expr: Node) -> bool: + if expr.type != "assignment_expression": + return False + + value_node = expr.child(2) + value_type = value_node.type if value_node else "" + + # Nullish constant (e.g. null/undefined) + if value_type in self.NULLISH_VALUES: + return True + + return False + + def is_global_source(self, global_declaration_node: Node) -> bool: + # global_name is usually bytes, decode for safe string comparison + name_node = global_declaration_node.child(1) + if name_node is None: + return False + + name_field = name_node.child_by_field_name("name") + if name_field is None or name_field.text is None: + return False + + global_name = name_field.text.decode("utf-8") + + sibling: Optional[Node] = global_declaration_node.next_sibling + + while sibling is not None: + # Skip empty siblings + if not sibling.children: + sibling = sibling.next_sibling + continue + + expr = sibling.child(0) + if expr is None: + sibling = sibling.next_sibling + continue + + # Handle deletion of property + if self.is_expression_delete(expr): + second_child = expr.child(1) + if second_child is not None: + obj_node = second_child.child_by_field_name("object") + if obj_node is not None and obj_node.text is not None: + if obj_node.text.decode("utf-8") == global_name: + return True + + # Handle nullish assignment + if self.is_expression_null(expr): + lhs = expr.child(0) + if lhs is not None and lhs.text is not None: + if lhs.text.decode("utf-8") == global_name: + return True + + sibling = sibling.next_sibling + + return False + + def extract_sources(self, function: Function) -> List[Value]: + """ + Extract the potential null/undefined values as sources from the source code. + 1. variable = null; + 2. return null; + 3. delete obj.prop; + 4. func(null); + """ + + root_node = function.parse_tree_root_node + source_code = self.ts_analyzer.code_in_files[function.file_path] + file_path = function.file_path + + nodes = find_nodes_by_type(root_node, "variable_declarator") + nodes.extend(find_nodes_by_type(root_node, "assignment_expression")) + nodes.extend(find_nodes_by_type(root_node, "return_statement")) + nodes.extend(find_nodes_by_type(root_node, "arguments")) + + sources = [] + + # Look for nullish value nodes + for node in nodes: + is_seed_node = False + + for child in node.children: + if child.type in self.NULLISH_VALUES: + is_seed_node = True + + if is_seed_node: + line_number = source_code[: node.start_byte].count("\n") + 1 + name = source_code[node.start_byte : node.end_byte] + sources.append(Value(name, line_number, ValueLabel.SRC, file_path)) + + unary_expressions = find_nodes_by_type(root_node, "unary_expression") + + # Look for delete expressions + for unary_expression in unary_expressions: + operator = unary_expression.child(0) + if operator is not None and operator.type == "delete": + line_number = source_code[: unary_expression.start_byte].count("\n") + 1 + name = source_code[ + unary_expression.start_byte : unary_expression.end_byte + ] + sources.append(Value(name, line_number, ValueLabel.SRC, file_path)) + + return sources + + def extract_sinks(self, function: Function) -> List[Value]: + """ + Extract the sinks that can cause the null pointer dereferences from Javascript programs. + 1. null_obj.prop; + 2. null_obj[1]; + 3. null_obj(); + + :param: function: Function object. + :return: List of sink values + """ + + root_node = function.parse_tree_root_node + source_code = self.ts_analyzer.code_in_files[function.file_path] + file_path = function.file_path + + nodes = find_nodes_by_type(root_node, "member_expression") + nodes.extend(find_nodes_by_type(root_node, "subscript_expression")) + nodes.extend(find_nodes_by_type(root_node, "call_expression")) + sinks = [] + + for node in nodes: + first_child = node.children[0] + line_number = source_code[: first_child.start_byte].count("\n") + 1 + name = source_code[first_child.start_byte : first_child.end_byte] + sinks.append(Value(name, line_number, ValueLabel.SINK, file_path, -1)) + return sinks diff --git a/src/tstool/dfbscan_extractor/Javascript/__init__.py b/src/tstool/dfbscan_extractor/Javascript/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/tstool/dfbscan_extractor/Python/Python_NPD_extractor.py b/src/tstool/dfbscan_extractor/Python/Python_NPD_extractor.py index c59562d..1fe805f 100644 --- a/src/tstool/dfbscan_extractor/Python/Python_NPD_extractor.py +++ b/src/tstool/dfbscan_extractor/Python/Python_NPD_extractor.py @@ -1,11 +1,16 @@ from tstool.analyzer.TS_analyzer import * from tstool.analyzer.Python_TS_analyzer import * from ..dfbscan_extractor import * -import tree_sitter -import argparse class Python_NPD_Extractor(DFBScanExtractor): + def is_global_source(self, global_declarator_node: Node) -> bool: + """ + Determines whether the global variable is initially a source. + """ + # TODO: Implement source detection for global variables if needed + return False + def extract_sources(self, function: Function) -> List[Value]: root_node = function.parse_tree_root_node source_code = self.ts_analyzer.code_in_files[function.file_path] diff --git a/src/tstool/dfbscan_extractor/dfbscan_extractor.py b/src/tstool/dfbscan_extractor/dfbscan_extractor.py index d300cdb..9ce30be 100644 --- a/src/tstool/dfbscan_extractor/dfbscan_extractor.py +++ b/src/tstool/dfbscan_extractor/dfbscan_extractor.py @@ -24,18 +24,39 @@ def extract_all(self) -> Tuple[List[Value], List[Value]]: """ Start the source/sink extraction process. """ - pbar = tqdm(total=len(self.ts_analyzer.function_env), desc="Parsing files") + pbar = tqdm( + total=len(self.ts_analyzer.function_env) + + len(self.ts_analyzer.globalsRawDataDic), + desc="Parsing files", + ) + + # Extract src/sink values from functions for function_id in self.ts_analyzer.function_env: pbar.update(1) function: Function = self.ts_analyzer.function_env[function_id] if "test" in function.file_path or "example" in function.file_path: continue - file_content = self.ts_analyzer.code_in_files[function.file_path] - function_root_node = function.parse_tree_root_node + self.sources.extend(self.extract_sources(function)) self.sinks.extend(self.extract_sinks(function)) + + # Filter out non src global values in global_env + for global_id, global_data in self.ts_analyzer.globalsRawDataDic.items(): + pbar.update(1) + global_node = global_data[2] + if self.is_global_source(global_node): + self.ts_analyzer.globals_env[global_id].label = ValueLabel.SRC + else: + del self.ts_analyzer.globals_env[global_id] + + pbar.close() + return self.sources, self.sinks + @abstractmethod + def is_global_source(self, global_var: Node) -> bool: + pass + @abstractmethod def extract_sources(self, function: Function) -> List[Value]: """