PyTorch implementation for SQL-o1: A Self-Reward Heuristic Dynamic Search Method for Text-to-SQL.
conda create -n moeuie python=3.11
conda activate SQL-o1
pip install torch==2.3.0
pip install -r requirements.txt
SQL-o1/
└──dataset/
├── spider/
├── train.json
├── tables.json
├── Spider_DK.json
├── spider-realistic.json
├── dev_syn.json
├── ...
├── dev.json
├── test.json
├── test_database/
└── database/
├── bird/
├── train/
├── train.json
├── train_tables.json
├── ...
└── train_databases/
├── dev/
├── dev.json
├── dev_tables.json
├── ...
└── dev_databases/
python preprocess_data.py --dataset spider|spider_real|spider_DK|spider_syn --mode train --LLM_model meta-llama/Meta-Llama-3-8B-Instruct --PSG --data_path /data/vda/dataset --output_path ./dataset
python preprocess_data.py --dataset bird --mode train --LLM_model meta-llama/Meta-Llama-3-8B-Instruct --PSG --data_path /data/vda/dataset --output_path ./dataset
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export --model_name_or_path /home/huggingface/meta-llama/Llama-3-8B-Instruct --adapter_name_or_path /data/vda/saves/llama3-8b/sft/lora --template llama3 --finetuning_type lora --use_dora --export_dir /data/vda/llama3_merge/ --export_size 2 --export_legacy_format False
python preprocess_data.py --dataset spider|spider_real|spider_DK|spider_syn --mode dev(test: spider_test) --LLM_model meta-llama/Meta-Llama-3-8B-Instruct --data_path /data/vda/dataset --output_path ./dataset
python preprocess_data.py --dataset bird --mode dev --LLM_model meta-llama/Meta-Llama-3-8B-Instruct --data_path /data/vda/dataset --output_path ./dataset
CUDA_VISIBLE_DEVICES=0 API_PORT=8000 nohup python src/llm_api.py --model_name_or_path /data/vda/llama3_merge/ --template llama3 --temperature 0.9 >> result_llm_api_0.log 2>&1 &
2.3 MCTS Explore for Model (Results collection & Please replace it with your own valid parameters. )
nohup python _run_explore.py --task_name bird >> result_mcts_0.txt 2>&1 &
python validation_results.py --json_path ./mcts_results/bird_mcts_dev.json ( | spider_mcts_dev.json | spider_syn.json | spider_DK.json | spider_real.json | spider_test.json ) --db_root_path ./dataset/bird/dev/dev_databases --num_cpus 1 --diff_json_path ./dataset/bird/dev/dev.json --output_file spider_dev.sql (...)
bash kill_llm_api.sh