一、设计目标

设计并实现一个股票交易策略回测系统,采用直观的图形界面,简洁页面设计方便数据读取与理解。

功能需求

(1)数据获取:从金融数据API下载股票历史数据,并存储到SQLite数据库

(2)策略设计:基于技术分析、基本面分析市场情绪面分析或资金面分析设计交易策略。利用历史数据模拟交易,对策略进行回测。

(3)可视化分析:使用Matplotlib/Plotly/PyQt等库展示回测结果,包括资金曲线、交易信号等。

(4)图形用户界面(GUI):使用PyQt/Tkinter等构建交互式界面,方便用户调整策略参数并查看回测结果.

二、运行环境(完整代码在本文底部)

推荐使用 anconda中jupyter notebook或Spyder,PyCharm也可。

本代码数据来源于akshare库从互联网获取股票数据,可下载数据避免过于频繁请求访问(代码中优先识别本地数据,若无下载则访问网络。)使用时请保持网络联通。

运行使用代码前请下载必要库:

pip install tkinter  pandas  numpy  matplotlib  mplfinance  json  sqlite3

三、系统设计

1、整体页面设计

   整体页面分为左右两侧,左侧为功能选择区,右侧作为数据展示及可视化视图展示。请首先点击选择股票按键选择想要分析的股票(首次加载需一段时间请耐心等待),点击获取数据即可使用后续功能。核心功能模块包括股票选择、日期范围设置、策略配置(如均线交叉策略)、初始资金调整以及数据管理(保存、加载、打开目录)。数据以表格形式展示,包含每日的开盘价、最高价、最低价、收盘价、成交量及涨跌幅,方便用户逐日分析市场表现。

2、可视化视图(举例)

K线图

均线图

成交量柱状图

回测示例

四、回测算法

   使用双均线系统(短期均线和长期均线),当短期均线上穿长期均线时产生买入信号(1),当短期均线下穿长期均线时产生卖出信号(-1),通过差分运算确保只在交叉点产生信号。

  

五、代码

import tkinter as tk
from tkinter import ttk, messagebox, scrolledtext, filedialog
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
import mplfinance as mpf
import os
from datetime import datetime, timedelta
import akshare as ak
import traceback
import json
import sqlite3
from sqlite3 import Error

# 设置matplotlib支持中文显示
plt.rcParams["font.family"] = ["SimHei", "Microsoft YaHei"]
plt.rcParams["axes.unicode_minus"] = False

# 设置数据保存目录和数据库路径
DATA_DIR = os.path.join(os.path.expanduser("~"), "stock_data")
if not os.path.exists(DATA_DIR):
    os.makedirs(DATA_DIR)
DB_PATH = os.path.join(DATA_DIR, "stock_data.db")

class DatabaseManager:
    """数据库管理类"""
    def __init__(self):
        self.conn = None
        self._initialize_database()
    
    def _initialize_database(self):
        """初始化数据库"""
        try:
            self.conn = sqlite3.connect(DB_PATH)
            self._create_tables()
        except Error as e:
            print(f"数据库初始化失败: {str(e)}")
            traceback.print_exc()
    
    def _create_tables(self):
        """创建数据表"""
        sql_create_stock_list_table = """
        CREATE TABLE IF NOT EXISTS stock_list (
            code TEXT PRIMARY KEY,
            name TEXT NOT NULL,
            update_date TEXT
        );
        """
        
        sql_create_stock_daily_table = """
        CREATE TABLE IF NOT EXISTS stock_daily (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            code TEXT NOT NULL,
            trade_date TEXT NOT NULL,
            open REAL,
            high REAL,
            low REAL,
            close REAL,
            volume REAL,
            amount REAL,
            pct_chg REAL,
            adjust TEXT,
            UNIQUE(code, trade_date, adjust)
        );
        """
        
        sql_create_stock_info_table = """
        CREATE TABLE IF NOT EXISTS stock_info (
            code TEXT PRIMARY KEY,
            info_json TEXT,
            update_date TEXT
        );
        """
        
        try:
            c = self.conn.cursor()
            c.execute(sql_create_stock_list_table)
            c.execute(sql_create_stock_daily_table)
            c.execute(sql_create_stock_info_table)
            self.conn.commit()
        except Error as e:
            print(f"创建表失败: {str(e)}")
            traceback.print_exc()
    
    def save_stock_list(self, stock_list):
        """保存股票列表到数据库"""
        try:
            today = datetime.now().strftime("%Y-%m-%d")
            c = self.conn.cursor()
            
            # 先删除旧数据
            c.execute("DELETE FROM stock_list")
            
            # 插入新数据
            for _, row in stock_list.iterrows():
                c.execute(
                    "INSERT INTO stock_list (code, name, update_date) VALUES (?, ?, ?)",
                    (row['代码'], row['名称'], today)
                )
            
            self.conn.commit()
            return True
        except Error as e:
            print(f"保存股票列表失败: {str(e)}")
            traceback.print_exc()
            return False
    
    def get_stock_list(self):
        """从数据库获取股票列表"""
        try:
            query = "SELECT code, name FROM stock_list ORDER BY code"
            df = pd.read_sql(query, self.conn)
            return df
        except Error as e:
            print(f"获取股票列表失败: {str(e)}")
            traceback.print_exc()
            return pd.DataFrame()
    
    def save_stock_daily(self, stock_code, df, adjust="qfq"):
        """保存股票日K线数据到数据库"""
        try:
            c = self.conn.cursor()
            
            # 批量插入数据
            for _, row in df.iterrows():
                try:
                    c.execute(
                        """INSERT INTO stock_daily 
                        (code, trade_date, open, high, low, close, volume, amount, pct_chg, adjust)
                        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
                        (
                            stock_code,
                            row['trade_date'].strftime("%Y-%m-%d"),
                            row['open'],
                            row['high'],
                            row['low'],
                            row['close'],
                            row['volume'],
                            row['amount'],
                            row['pct_chg'],
                            adjust
                        )
                    )
                except sqlite3.IntegrityError:
                    # 数据已存在,跳过
                    continue
            
            self.conn.commit()
            return True
        except Error as e:
            print(f"保存股票日K线数据失败: {str(e)}")
            traceback.print_exc()
            return False
    
    def get_stock_daily(self, stock_code, start_date=None, end_date=None, adjust="qfq"):
        """从数据库获取股票日K线数据"""
        try:
            query = """
            SELECT 
                trade_date, open, high, low, close, volume, amount, pct_chg
            FROM 
                stock_daily 
            WHERE 
                code = ? 
                AND adjust = ?
            """
            params = [stock_code, adjust]
            
            if start_date:
                query += " AND trade_date >= ?"
                params.append(start_date.strftime("%Y-%m-%d"))
            if end_date:
                query += " AND trade_date <= ?"
                params.append(end_date.strftime("%Y-%m-%d"))
            
            query += " ORDER BY trade_date"
            
            df = pd.read_sql(query, self.conn, params=params)
            
            if df.empty:
                return pd.DataFrame()
            
            df['trade_date'] = pd.to_datetime(df['trade_date'])
            return df
        except Error as e:
            print(f"获取股票日K线数据失败: {str(e)}")
            traceback.print_exc()
            return pd.DataFrame()
    
    def save_stock_info(self, stock_code, info_dict):
        """保存股票信息到数据库"""
        try:
            today = datetime.now().strftime("%Y-%m-%d")
            info_json = json.dumps(info_dict, ensure_ascii=False)
            
            c = self.conn.cursor()
            c.execute(
                """INSERT OR REPLACE INTO stock_info 
                (code, info_json, update_date) 
                VALUES (?, ?, ?)""",
                (stock_code, info_json, today)
            )
            
            self.conn.commit()
            return True
        except Error as e:
            print(f"保存股票信息失败: {str(e)}")
            traceback.print_exc()
            return False
    
    def get_stock_info(self, stock_code):
        """从数据库获取股票信息"""
        try:
            query = "SELECT info_json FROM stock_info WHERE code = ?"
            c = self.conn.cursor()
            c.execute(query, (stock_code,))
            result = c.fetchone()
            
            if result is None:
                return {}
            
            return json.loads(result[0])
        except Error as e:
            print(f"获取股票信息失败: {str(e)}")
            traceback.print_exc()
            return {}
    
    def close(self):
        """关闭数据库连接"""
        if self.conn:
            self.conn.close()

class StockDataManager:
    """股票数据管理类"""
    def __init__(self):
        self.db = DatabaseManager()
        self.stock_list = pd.DataFrame()
    
    def get_stock_list(self, refresh=False):
        """获取沪深A股股票列表"""
        try:
            if refresh:
                print("正在获取股票列表...")
                stock_list = ak.stock_info_a_code_name()
                stock_list = stock_list.rename(columns={
                    'code': '代码',
                    'name': '名称'
                })
                
                # 保存到数据库
                if not self.db.save_stock_list(stock_list):
                    print("保存股票列表到数据库失败")
                
                self.stock_list = stock_list
            else:
                # 从数据库获取
                self.stock_list = self.db.get_stock_list()
                if self.stock_list.empty:
                    # 数据库中没有数据,重新获取
                    return self.get_stock_list(refresh=True)
            
            return self.stock_list
        except Exception as e:
            print(f"获取股票列表失败: {str(e)}")
            traceback.print_exc()
            return pd.DataFrame()
    
    def get_stock_daily(self, stock_code, adjust="qfq"):
        """获取股票历史日K线数据"""
        try:
            print(f"正在获取股票 {stock_code} 的日K线数据...")
            
            # 首先尝试从数据库获取
            df = self.db.get_stock_daily(stock_code, adjust=adjust)
            
            if df.empty:
                # 数据库中没有数据,从API获取
                df = ak.stock_zh_a_hist(
                    symbol=stock_code, 
                    period="daily", 
                    start_date="19900101", 
                    end_date=datetime.now().strftime("%Y%m%d"),
                    adjust=adjust
                )
                
                if df.empty:
                    return pd.DataFrame()
                
                # 重命名列
                df = df.rename(columns={
                    "日期": "trade_date",
                    "开盘": "open",
                    "收盘": "close",
                    "最高": "high",
                    "最低": "low",
                    "成交量": "volume",
                    "成交额": "amount",
                    "涨跌幅": "pct_chg"
                })
                
                df['trade_date'] = pd.to_datetime(df['trade_date'])
                
                # 保存到数据库
                if not self.db.save_stock_daily(stock_code, df, adjust):
                    print("保存股票日K线数据到数据库失败")
            
            return df.sort_values('trade_date')
        
        except Exception as e:
            print(f"获取股票 {stock_code} 日K线数据失败: {str(e)}")
            traceback.print_exc()
            return pd.DataFrame()
    
    def get_stock_info(self, stock_code):
        """获取股票基本信息"""
        try:
            print(f"正在获取股票 {stock_code} 的基本信息...")
            
            # 首先尝试从数据库获取
            info_dict = self.db.get_stock_info(stock_code)
            
            if not info_dict:
                # 数据库中没有数据,从API获取
                info_df = ak.stock_individual_info_em(symbol=stock_code)
                
                if info_df.empty:
                    return {}
                
                # 转换为字典
                info_dict = {}
                for _, row in info_df.iterrows():
                    info_dict[row['item']] = row['value']
                
                # 添加股票代码和名称
                info_dict['股票代码'] = stock_code
                if '股票简称' in info_dict:
                    info_dict['股票名称'] = info_dict['股票简称']
                
                # 保存到数据库
                if not self.db.save_stock_info(stock_code, info_dict):
                    print("保存股票信息到数据库失败")
            
            return info_dict
        
        except Exception as e:
            print(f"获取股票 {stock_code} 基本信息失败: {str(e)}")
            traceback.print_exc()
            return {}

class StockBacktestApp:
    """主应用程序"""
    def __init__(self, root):
        self.root = root
        self.root.title("股票回测系统 v5.0")
        self.data_manager = StockDataManager()
        self.current_data = None
        self.current_stock = {'code': '', 'name': ''}
        self.backtest_results = None
        
        # 设置窗口大小
        self.root.geometry("1200x800")
        
        # 创建主界面
        self._setup_ui()
    
    def _setup_ui(self):
        """初始化用户界面"""
        # 主框架
        main_frame = ttk.Frame(self.root)
        main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
        
        # 左侧控制面板
        control_panel = ttk.Frame(main_frame, width=300)
        control_panel.pack(side=tk.LEFT, fill=tk.Y)
        
        # 右侧内容区域
        content_panel = ttk.Frame(main_frame)
        content_panel.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)
        
        # 左侧控制面板内容
        self._setup_control_panel(control_panel)
        
        # 右侧内容区域
        self._setup_content_panel(content_panel)
        
        # 状态栏
        self.status_var = tk.StringVar(value="就绪")
        status_bar = ttk.Label(self.root, textvariable=self.status_var, relief=tk.SUNKEN, anchor=tk.W)
        status_bar.pack(side=tk.BOTTOM, fill=tk.X)
    
    def _setup_control_panel(self, parent):
        """设置控制面板"""
        # 股票选择部分
        stock_frame = ttk.LabelFrame(parent, text="股票选择")
        stock_frame.pack(fill=tk.X, padx=5, pady=5)
        
        # 显示当前选择的股票
        self.current_stock_label = ttk.Label(
            stock_frame, 
            text="当前股票: 未选择",
            font=('Microsoft YaHei', 10, 'bold'),
            anchor=tk.CENTER
        )
        self.current_stock_label.pack(fill=tk.X, padx=5, pady=5)
        
        # 选择股票按钮
        ttk.Button(
            stock_frame, 
            text="选择股票", 
            command=self._show_stock_selector
        ).pack(fill=tk.X, padx=5, pady=5)
        
        # 显示股票信息按钮
        ttk.Button(
            stock_frame,
            text="查看股票详情",
            command=self._show_stock_info
        ).pack(fill=tk.X, padx=5, pady=5)
        
        # 日期范围设置
        date_frame = ttk.LabelFrame(parent, text="日期范围")
        date_frame.pack(fill=tk.X, padx=5, pady=5)
        
        ttk.Label(date_frame, text="开始日期:").pack(anchor=tk.W, padx=5)
        self.start_entry = ttk.Entry(date_frame)
        self.start_entry.pack(fill=tk.X, padx=5, pady=(0,5))
        self.start_entry.insert(0, (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d"))
        
        ttk.Label(date_frame, text="结束日期:").pack(anchor=tk.W, padx=5)
        self.end_entry = ttk.Entry(date_frame)
        self.end_entry.pack(fill=tk.X, padx=5, pady=(0,5))
        self.end_entry.insert(0, datetime.now().strftime("%Y-%m-%d"))
        
        ttk.Button(
            date_frame, 
            text="获取数据", 
            command=self._fetch_data
        ).pack(fill=tk.X, padx=5, pady=5)
        
        # 策略设置
        strategy_frame = ttk.LabelFrame(parent, text="策略设置")
        strategy_frame.pack(fill=tk.X, padx=5, pady=5)
        
        self.strategy_var = tk.StringVar(value="均线交叉")
        ttk.Radiobutton(
            strategy_frame, 
            text="均线交叉", 
            variable=self.strategy_var, 
            value="均线交叉"
        ).pack(anchor=tk.W, padx=5, pady=2)
        
        params_frame = ttk.Frame(strategy_frame)
        params_frame.pack(fill=tk.X, padx=5, pady=5)
        
        ttk.Label(params_frame, text="短期均线:").grid(row=0, column=0, sticky=tk.W, padx=5)
        self.short_ma = ttk.Entry(params_frame, width=8)
        self.short_ma.grid(row=0, column=1, sticky=tk.E, padx=5)
        self.short_ma.insert(0, "5")
        
        ttk.Label(params_frame, text="长期均线:").grid(row=1, column=0, sticky=tk.W, padx=5)
        self.long_ma = ttk.Entry(params_frame, width=8)
        self.long_ma.grid(row=1, column=1, sticky=tk.E, padx=5)
        self.long_ma.insert(0, "20")
        
        ttk.Label(params_frame, text="初始资金:").grid(row=2, column=0, sticky=tk.W, padx=5)
        self.capital = ttk.Entry(params_frame, width=8)
        self.capital.grid(row=2, column=1, sticky=tk.E, padx=5)
        self.capital.insert(0, "100000")
        
        ttk.Button(
            strategy_frame, 
            text="运行回测", 
            command=self._run_backtest
        ).pack(fill=tk.X, padx=5, pady=5)
        
        # 数据管理
        manage_frame = ttk.LabelFrame(parent, text="数据管理")
        manage_frame.pack(fill=tk.X, padx=5, pady=5)
        
        ttk.Button(
            manage_frame, 
            text="保存数据", 
            command=self._save_data
        ).pack(fill=tk.X, padx=5, pady=2)
        
        ttk.Button(
            manage_frame, 
            text="加载数据", 
            command=self._load_data
        ).pack(fill=tk.X, padx=5, pady=2)
        
        ttk.Button(
            manage_frame, 
            text="打开数据目录", 
            command=self._open_data_dir
        ).pack(fill=tk.X, padx=5, pady=2)
    
    def _setup_content_panel(self, parent):
        """设置内容面板"""
        self.notebook = ttk.Notebook(parent)
        self.notebook.pack(fill=tk.BOTH, expand=True)
        
        # 数据预览标签页
        self._setup_data_tab()
        
        # 图表分析标签页
        self._setup_chart_tab()
        
        # 回测结果标签页
        self._setup_backtest_tab()
        
        # 指标分析标签页
        self._setup_metrics_tab()
    
    def _setup_data_tab(self):
        """设置数据预览标签页"""
        tab = ttk.Frame(self.notebook)
        self.notebook.add(tab, text="数据预览")
        
        # 创建表格
        columns = ("日期", "开盘", "最高", "最低", "收盘", "成交量", "涨跌幅")
        self.data_tree = ttk.Treeview(tab, columns=columns, show="headings")
        
        for col in columns:
            self.data_tree.heading(col, text=col)
            self.data_tree.column(col, width=100, anchor=tk.CENTER)
        
        # 添加滚动条
        scrollbar = ttk.Scrollbar(tab, orient=tk.VERTICAL, command=self.data_tree.yview)
        self.data_tree.configure(yscroll=scrollbar.set)
        scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
        self.data_tree.pack(fill=tk.BOTH, expand=True)
    
    def _setup_chart_tab(self):
        """设置图表分析标签页"""
        tab = ttk.Frame(self.notebook)
        self.notebook.add(tab, text="图表分析")
        
        # 图表类型选择
        chart_type_frame = ttk.Frame(tab)
        chart_type_frame.pack(fill=tk.X, padx=5, pady=5)
        
        self.chart_type = tk.StringVar(value="kline")
        ttk.Radiobutton(
            chart_type_frame, 
            text="K线图", 
            variable=self.chart_type, 
            value="kline",
            command=self._plot_chart
        ).pack(side=tk.LEFT, padx=5)
        
        ttk.Radiobutton(
            chart_type_frame, 
            text="均线图", 
            variable=self.chart_type, 
            value="ma",
            command=self._plot_chart
        ).pack(side=tk.LEFT, padx=5)
        
        ttk.Radiobutton(
            chart_type_frame, 
            text="成交量", 
            variable=self.chart_type, 
            value="volume",
            command=self._plot_chart
        ).pack(side=tk.LEFT, padx=5)
        
        # 图表容器
        self.chart_frame = ttk.Frame(tab)
        self.chart_frame.pack(fill=tk.BOTH, expand=True)
    
    def _setup_backtest_tab(self):
        """设置回测结果标签页"""
        tab = ttk.Frame(self.notebook)
        self.notebook.add(tab, text="回测结果")
        
        self.backtest_frame = ttk.Frame(tab)
        self.backtest_frame.pack(fill=tk.BOTH, expand=True)
    
    def _setup_metrics_tab(self):
        """设置指标分析标签页"""
        tab = ttk.Frame(self.notebook)
        self.notebook.add(tab, text="回测指标")
        
        # 创建指标表格
        columns = ("指标", "数值")
        self.metrics_tree = ttk.Treeview(tab, columns=columns, show="headings")
        
        for col in columns:
            self.metrics_tree.heading(col, text=col)
            self.metrics_tree.column(col, width=200, anchor=tk.CENTER)
        
        self.metrics_tree.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
    
    def _show_stock_selector(self):
        """显示股票选择对话框"""
        dialog = tk.Toplevel(self.root)
        dialog.title("股票选择器")
        dialog.geometry("800x600")
        
        # 主框架
        main_frame = ttk.Frame(dialog)
        main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
        
        # 搜索框区域
        search_frame = ttk.Frame(main_frame)
        search_frame.pack(fill=tk.X, pady=5)
        
        ttk.Label(search_frame, text="🔍 搜索:").pack(side=tk.LEFT)
        self.search_var = tk.StringVar()
        search_entry = ttk.Entry(search_frame, textvariable=self.search_var)
        search_entry.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=5)
        search_entry.bind("<KeyRelease>", self._filter_stock_list)
        
        # 股票列表区域
        list_frame = ttk.LabelFrame(main_frame, text="股票列表")
        list_frame.pack(fill=tk.BOTH, expand=True)
        
        columns = ("代码", "名称")
        self.stock_tree = ttk.Treeview(list_frame, columns=columns, show="headings", height=20)
        
        for col in columns:
            self.stock_tree.heading(col, text=col)
            self.stock_tree.column(col, width=100, anchor=tk.W)
        
        scrollbar = ttk.Scrollbar(list_frame, orient=tk.VERTICAL, command=self.stock_tree.yview)
        self.stock_tree.configure(yscroll=scrollbar.set)
        scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
        self.stock_tree.pack(fill=tk.BOTH, expand=True)
        
        # 按钮区域
        button_frame = ttk.Frame(main_frame)
        button_frame.pack(fill=tk.X, pady=5)
        
        ttk.Button(
            button_frame,
            text="选择",
            command=lambda: self._select_stock(dialog)
        ).pack(side=tk.LEFT, padx=5)
        
        ttk.Button(
            button_frame,
            text="刷新列表",
            command=self._load_stock_list
        ).pack(side=tk.LEFT, padx=5)
        
        ttk.Button(
            button_frame,
            text="取消",
            command=dialog.destroy
        ).pack(side=tk.RIGHT, padx=5)
        
        # 绑定事件
        self.stock_tree.bind("<Double-1>", lambda e: self._select_stock(dialog))
        
        # 加载股票列表
        self._load_stock_list()
    
    def _show_stock_info(self):
        """显示股票详细信息"""
        if not self.current_stock['code']:
            messagebox.showwarning("警告", "请先选择股票")
            return
        
        dialog = tk.Toplevel(self.root)
        dialog.title(f"{self.current_stock['code']} {self.current_stock['name']} - 股票详情")
        dialog.geometry("600x500")
        
        # 获取股票信息
        info = self.data_manager.get_stock_info(self.current_stock['code'])
        
        # 创建文本区域
        text_area = scrolledtext.ScrolledText(
            dialog,
            wrap=tk.WORD,
            width=70,
            height=30,
            font=('Microsoft YaHei', 10)
        )
        text_area.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
        
        # 格式化显示股票信息
        text_area.insert(tk.END, f"股票代码: {self.current_stock['code']}\n")
        text_area.insert(tk.END, f"股票名称: {self.current_stock['name']}\n")
        text_area.insert(tk.END, "-"*50 + "\n")
        
        # 添加重要信息
        important_fields = [
            '公司全称', '英文名称', '上市日期', '发行价格', 
            '注册资本', '行业分类', '主营业务',
            '市盈率', '市净率', '每股收益',
            '总股本', '流通股本', '更新时间'
        ]
        
        for field in important_fields:
            if field in info:
                text_area.insert(tk.END, f"{field}: {info[field]}\n")
        
        text_area.config(state=tk.DISABLED)
        
        # 关闭按钮
        ttk.Button(
            dialog,
            text="关闭",
            command=dialog.destroy
        ).pack(pady=5)
    
    def _load_stock_list(self):
        """加载股票列表"""
        self.status_var.set("正在获取股票列表...")
        self.root.update()
        
        try:
            stock_list = self.data_manager.get_stock_list(refresh=True)
            
            # 清空现有列表
            for item in self.stock_tree.get_children():
                self.stock_tree.delete(item)
            
            # 添加股票数据
            for _, row in stock_list.iterrows():
                self.stock_tree.insert("", tk.END, values=(
                    row['代码'],
                    row['名称']
                ))
            
            self.status_var.set("股票列表加载完成")
        
        except Exception as e:
            messagebox.showerror("错误", f"获取股票列表失败: {str(e)}")
            self.status_var.set(f"错误: {str(e)}")
    
    def _filter_stock_list(self, event=None):
        """过滤股票列表"""
        keyword = self.search_var.get().lower()
        
        for item in self.stock_tree.get_children():
            values = self.stock_tree.item(item, 'values')
            if any(keyword in str(v).lower() for v in values):
                self.stock_tree.item(item, tags=('match',))
                self.stock_tree.tag_configure('match', foreground='black')
            else:
                self.stock_tree.item(item, tags=('no_match',))
                self.stock_tree.tag_configure('no_match', foreground='gray')
    
    def _select_stock(self, dialog):
        """选择股票"""
        selected = self.stock_tree.focus()
        if selected:
            values = self.stock_tree.item(selected, 'values')
            self.current_stock = {'code': values[0], 'name': values[1]}
            self.current_stock_label.config(text=f"当前股票: {values[0]} {values[1]}")
            dialog.destroy()
    
    def _fetch_data(self):
        """获取股票数据"""
        if not self.current_stock['code']:
            messagebox.showwarning("警告", "请先选择股票")
            return
        
        code = self.current_stock['code']
        start = self.start_entry.get()
        end = self.end_entry.get()
        
        self.status_var.set(f"正在获取 {code} 的数据...")
        self.root.update()
        
        try:
            # 从数据库获取数据
            df = self.data_manager.get_stock_daily(code)
            
            if df.empty:
                messagebox.showwarning("警告", "未获取到数据")
                return
            
            mask = (df['trade_date'] >= pd.to_datetime(start)) & (df['trade_date'] <= pd.to_datetime(end))
            self.current_data = df[mask].sort_values('trade_date')
            
            info = self.data_manager.get_stock_info(code)
            self.current_stock = {
                'code': code,
                'name': info.get('股票名称', code)
            }
            self.current_stock_label.config(text=f"当前股票: {code} {self.current_stock['name']}")
            
            self._update_data_table()
            self._plot_chart()
            
            self.status_var.set(f"成功获取 {code} 数据 ({len(self.current_data)} 条记录)")
            self.notebook.select(0)
        
        except Exception as e:
            messagebox.showerror("错误", f"获取数据失败: {str(e)}")
            self.status_var.set(f"错误: {str(e)}")
    
    def _update_data_table(self):
        """更新数据表格"""
        for item in self.data_tree.get_children():
            self.data_tree.delete(item)
        
        if self.current_data is None:
            return
        
        for _, row in self.current_data.iterrows():
            self.data_tree.insert("", tk.END, values=(
                row['trade_date'].strftime("%Y-%m-%d"),
                f"{row['open']:.2f}",
                f"{row['high']:.2f}",
                f"{row['low']:.2f}",
                f"{row['close']:.2f}",
                f"{row['volume']:,.0f}",
                f"{row['pct_chg']:.2f}%"
            ))
    
    def _plot_chart(self):
        """绘制图表"""
        if self.current_data is None:
            return
        
        for widget in self.chart_frame.winfo_children():
            widget.destroy()
        
        df = self.current_data.set_index('trade_date')
        chart_type = self.chart_type.get()
        
        fig = Figure(figsize=(10, 6), dpi=100)
        ax = fig.add_subplot(111)
        
        if chart_type == "kline":
            mpf.plot(df, type='candle', ax=ax, volume=False, style='yahoo')
            ax.set_title(f"{self.current_stock['code']} {self.current_stock['name']} K线图")
        elif chart_type == "ma":
            ax.plot(df.index, df['close'], label='收盘价')
            ax.plot(df.index, df['close'].rolling(5).mean(), label='5日均线')
            ax.plot(df.index, df['close'].rolling(20).mean(), label='20日均线')
            ax.set_title(f"{self.current_stock['code']} {self.current_stock['name']} 均线图")
            ax.legend()
        elif chart_type == "volume":
            ax.bar(df.index, df['volume'], width=0.6, color='blue')
            ax.set_title(f"{self.current_stock['code']} {self.current_stock['name']} 成交量")
        
        ax.grid(True)
        
        canvas = FigureCanvasTkAgg(fig, master=self.chart_frame)
        canvas.draw()
        canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
    
    def _run_backtest(self):
        """运行回测"""
        if self.current_data is None:
            messagebox.showwarning("警告", "请先获取股票数据")
            return
        
        try:
            self.status_var.set("正在运行回测...")
            self.root.update()
            
            # 创建策略
            strategy = MovingAverageCrossStrategy(
                self.current_data, 
                {
                    'short_window': int(self.short_ma.get()),
                    'long_window': int(self.long_ma.get())
                }
            )
            
            # 运行回测
            backtester = Backtester(
                self.current_data, 
                strategy, 
                initial_capital=float(self.capital.get())
            )
            self.backtest_results = backtester
            
            # 显示结果
            self._plot_backtest_results()
            self._show_metrics()
            
            self.status_var.set("回测完成")
            self.notebook.select(2)
        
        except Exception as e:
            messagebox.showerror("错误", f"回测失败: {str(e)}")
            self.status_var.set(f"错误: {str(e)}")
    
    def _plot_backtest_results(self):
        """绘制回测结果图表"""
        if self.backtest_results is None:
            return
        
        for widget in self.backtest_frame.winfo_children():
            widget.destroy()
        
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), sharex=True, gridspec_kw={'height_ratios': [3, 1]})
        
        ax1.plot(self.backtest_results.results.index, self.backtest_results.results['total'])
        ax1.set_title('资产曲线')
        ax1.set_ylabel('总资产')
        ax1.grid(True)
        
        ax2.fill_between(
            self.backtest_results.results.index,
            self.backtest_results.results['drawdown'],
            0,
            color='red',
            alpha=0.3
        )
        ax2.set_title('回撤')
        ax2.set_ylabel('回撤比例')
        ax2.grid(True)
        
        plt.tight_layout()
        
        canvas = FigureCanvasTkAgg(fig, master=self.backtest_frame)
        canvas.draw()
        canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
    
    def _show_metrics(self):
        """显示回测指标"""
        if self.backtest_results is None:
            return
        
        for item in self.metrics_tree.get_children():
            self.metrics_tree.delete(item)
        
        metrics = self.backtest_results.metrics
        self.metrics_tree.insert("", tk.END, values=("总收益率", f"{metrics['total_return']*100:.2f}%"))
        self.metrics_tree.insert("", tk.END, values=("年化收益率", f"{metrics['annual_return']*100:.2f}%"))
        self.metrics_tree.insert("", tk.END, values=("最大回撤", f"{metrics['max_drawdown']:.2%}"))
        self.metrics_tree.insert("", tk.END, values=("夏普比率", f"{metrics['sharpe_ratio']:.2f}"))
        self.metrics_tree.insert("", tk.END, values=("初始资金", f"{float(self.capital.get()):,.2f}"))
        self.metrics_tree.insert("", tk.END, values=("最终资金", f"{self.backtest_results.results['total'].iloc[-1]:,.2f}"))
    
    def _save_data(self):
        """保存数据到数据库"""
        if self.current_data is None:
            messagebox.showwarning("警告", "没有可保存的数据")
            return
        
        try:
            code = self.current_stock['code']
            df = self.current_data.copy()
            
            # 保存到数据库
            if self.data_manager.db.save_stock_daily(code, df):
                messagebox.showinfo("成功", "数据已保存到数据库")
                self.status_var.set("数据已保存到数据库")
            else:
                messagebox.showerror("错误", "保存数据到数据库失败")
                self.status_var.set("保存数据到数据库失败")
        except Exception as e:
            messagebox.showerror("错误", f"保存失败: {str(e)}")
            self.status_var.set(f"错误: {str(e)}")
    
    def _load_data(self):
        """从数据库加载数据"""
        if not self.current_stock['code']:
            messagebox.showwarning("警告", "请先选择股票")
            return
        
        try:
            code = self.current_stock['code']
            start = self.start_entry.get()
            end = self.end_entry.get()
            
            self.status_var.set(f"正在从数据库加载 {code} 的数据...")
            self.root.update()
            
            # 从数据库获取数据
            df = self.data_manager.db.get_stock_daily(
                code,
                start_date=pd.to_datetime(start),
                end_date=pd.to_datetime(end)
            )
            
            if df.empty:
                messagebox.showwarning("警告", "数据库中未找到该股票的数据")
                return
            
            self.current_data = df.sort_values('trade_date')
            
            info = self.data_manager.get_stock_info(code)
            self.current_stock = {
                'code': code,
                'name': info.get('股票名称', code)
            }
            self.current_stock_label.config(text=f"当前股票: {code} {self.current_stock['name']}")
            
            self._update_data_table()
            self._plot_chart()
            
            self.status_var.set(f"成功从数据库加载 {code} 数据 ({len(self.current_data)} 条记录)")
            self.notebook.select(0)
        
        except Exception as e:
            messagebox.showerror("错误", f"加载失败: {str(e)}")
            self.status_var.set(f"错误: {str(e)}")
    
    def _open_data_dir(self):
        """打开数据目录"""
        try:
            os.startfile(DATA_DIR)
            self.status_var.set(f"已打开数据目录: {DATA_DIR}")
        except:
            messagebox.showerror("错误", f"无法打开目录: {DATA_DIR}")
            self.status_var.set(f"错误: 无法打开目录 {DATA_DIR}")

class MovingAverageCrossStrategy:
    """均线交叉策略"""
    def __init__(self, stock_data, params=None):
        self.stock_data = stock_data
        self.params = params or {}
        self.signals = pd.DataFrame(index=stock_data.index)
        self.signals['signal'] = 0  # 0 不操作, 1 买入, -1 卖出
    
    def calculate_indicators(self):
        price = self.stock_data['close']
        self.signals['short_mavg'] = price.rolling(self.params['short_window']).mean()
        self.signals['long_mavg'] = price.rolling(self.params['long_window']).mean()
    
    def generate_signals(self):
        self.signals['signal'] = np.where(
            self.signals['short_mavg'] > self.signals['long_mavg'], 1, 
            np.where(self.signals['short_mavg'] < self.signals['long_mavg'], -1, 0)
        )
        self.signals['signal'] = self.signals['signal'].diff().fillna(0).replace([2,-2], 0)
    
    def get_signals(self):
        if self.signals['signal'].sum() == 0:
            self.calculate_indicators()
            self.generate_signals()
        return self.signals

class Backtester:
    """回测器"""
    def __init__(self, stock_data, strategy, initial_capital=100000, commission=0.0003):
        self.stock_data = stock_data
        self.strategy = strategy
        self.initial_capital = initial_capital
        self.commission = commission
        self.results = self._run_backtest()
        self.metrics = self._calculate_metrics()
    
    def _run_backtest(self):
        signals = self.strategy.get_signals()
        portfolio = pd.DataFrame(index=signals.index)
        portfolio['cash'] = self.initial_capital
        portfolio['shares'] = 0
        
        for i in range(1, len(portfolio)):
            signal = signals['signal'].iloc[i]
            price = self.stock_data['close'].iloc[i]
            
            if signal == 1:  # 买入
                available = portfolio['cash'].iloc[i-1] * (1 - self.commission)
                shares = int(available / price)
                if shares > 0:
                    portfolio['cash'].iloc[i] = portfolio['cash'].iloc[i-1] - shares * price * (1 + self.commission)
                    portfolio['shares'].iloc[i] = portfolio['shares'].iloc[i-1] + shares
                else:
                    portfolio.iloc[i] = portfolio.iloc[i-1]
            elif signal == -1:  # 卖出
                shares = portfolio['shares'].iloc[i-1]
                if shares > 0:
                    portfolio['cash'].iloc[i] = portfolio['cash'].iloc[i-1] + shares * price * (1 - self.commission)
                    portfolio['shares'].iloc[i] = 0
                else:
                    portfolio.iloc[i] = portfolio.iloc[i-1]
            else:
                portfolio.iloc[i] = portfolio.iloc[i-1]
        
        portfolio['total'] = portfolio['cash'] + portfolio['shares'] * self.stock_data['close']
        portfolio['return'] = portfolio['total'].pct_change()
        return portfolio
    
    def _calculate_metrics(self):
        """计算回测指标"""
        total_days = len(self.results)
        trading_days_per_year = 252
        
        # 计算累积收益率
        total_return = (self.results['total'].iloc[-1] / self.initial_capital) - 1
        
        # 计算年化收益率
        annual_return = (1 + total_return) ** (trading_days_per_year / total_days) - 1
        
        # 计算最大回撤
        self.results['cummax'] = self.results['total'].cummax()
        self.results['drawdown'] = (self.results['total'] / self.results['cummax']) - 1
        max_drawdown = self.results['drawdown'].min()
        
        # 计算夏普比率
        risk_free_rate = 0.03  # 假设无风险利率为3%
        daily_risk_free = (1 + risk_free_rate) ** (1/trading_days_per_year) - 1
        excess_return = self.results['return'] - daily_risk_free
        sharpe_ratio = np.sqrt(trading_days_per_year) * excess_return.mean() / excess_return.std()
        
        return {
            'total_return': total_return,
            'annual_return': annual_return,
            'max_drawdown': max_drawdown,
            'sharpe_ratio': sharpe_ratio
        }

if __name__ == "__main__":
    root = tk.Tk()
    app = StockBacktestApp(root)
    root.mainloop()

Logo

专业量化交易与投资者大本营

更多推荐