Compare commits
44 Commits
cbc6a7493d...xgboost
| SHA1 |
|---|
| 65f30a4020 |
| be331ed631 |
| 6c5dcc1183 |
| 02e5db2a36 |
| a877f14e65 |
| 082a2835b6 |
| ada6150413 |
| ced64825bd |
| 2f98463df8 |
| 2a52ffde9a |
| a22914731f |
| 81e4b640a7 |
| 2dba88b620 |
| de67b27e37 |
| 1284549106 |
| 5f03524d6a |
| 74c8048ed5 |
| 2fd73085b8 |
| 806697116d |
| 14905017c8 |
| ec1a86e098 |
| 0a919f825e |
| c2886a2aab |
| 10cc047975 |
| 955a340d02 |
| 07b9824b69 |
| 369b3c1daf |
| 08c871e05a |
| 837c505828 |
| 1cdfe3973a |
| 8ff86339d6 |
| 7f788a4d4e |
| 0eb7fc77f9 |
| 170751db0e |
| f7f0fc6dd5 |
| e4ded694b1 |
| fa12bcb61a |
| 125d4f7d52 |
| ec8b1a7cf2 |
| 7c4db08b1b |
| f316571a3c |
| c7732881c5 |
| b0ffedc6af |
| e9bfcd03eb |
61
.cursor/rules/always-global.mdc
Normal file
@@ -0,0 +1,61 @@
---
description: Global development standards and AI interaction principles
globs:
alwaysApply: true
---

# Rule: Always Apply - Global Development Standards

## AI Interaction Principles

### Step-by-Step Development
- **NEVER** generate large blocks of code without explanation
- **ALWAYS** provide your plan as a concise bullet list and wait for the user's confirmation before proceeding
- Break complex tasks into smaller, manageable pieces (≤250 lines per file, ≤50 lines per function)
- Explain your reasoning step-by-step before writing code
- Wait for explicit approval before moving to the next sub-task

### Context Awareness
- **ALWAYS** reference existing code patterns and data structures before suggesting new approaches
- Ask about existing conventions before implementing new functionality
- Preserve established architectural decisions unless explicitly asked to change them
- Maintain consistency with existing naming conventions and code style

## Code Quality Standards

### File and Function Limits
- **Maximum file size**: 250 lines
- **Maximum function size**: 50 lines
- **Maximum complexity**: If a function does more than one main thing, break it down
- **Naming**: Use clear, descriptive names that explain purpose

### Documentation Requirements
- **Every public function** must have a docstring explaining purpose, parameters, and return value
- **Every class** must have a class-level docstring
- **Complex logic** must have inline comments explaining the "why", not just the "what"
- **API endpoints** must be documented with request/response examples

### Error Handling
- **ALWAYS** include proper error handling for external dependencies
- **NEVER** use bare except clauses (see the example below)
- Provide meaningful error messages that help with debugging
- Log errors appropriately for the application context
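
As an illustration of all four points, a minimal sketch; the service URL and the `fetch_profile` function are hypothetical, and `requests` stands in for whatever HTTP client the project uses:

```python
import logging

import requests  # third-party HTTP client, used here purely for illustration

logger = logging.getLogger(__name__)


def fetch_profile(user_id: str) -> dict:
    """Fetch a user profile from a hypothetical external service."""
    url = f"https://api.example.com/users/{user_id}"
    try:
        response = requests.get(url, timeout=5)
        response.raise_for_status()
        return response.json()
    except requests.Timeout:
        # Specific exception type plus a meaningful, debuggable message --
        # never a bare `except:`.
        logger.error("Timed out fetching profile for user %s from %s", user_id, url)
        raise
    except requests.HTTPError as exc:
        logger.error("Profile request for user %s failed: %s", user_id, exc)
        raise
```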

## Security and Best Practices
- **NEVER** hardcode credentials, API keys, or sensitive data
- **ALWAYS** validate user inputs
- Use parameterized queries for database operations (see the sketch below)
- Follow the principle of least privilege
- Implement proper authentication and authorization
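
A minimal sketch of the parameterized-query rule using the standard-library `sqlite3` module; the database file, table, and columns are illustrative:

```python
import sqlite3

conn = sqlite3.connect("app.db")   # illustrative database file
email = "john@example.com"         # imagine this came from user input

# Good: the `?` placeholder lets the driver bind the value, so user
# input is treated as data, never as SQL text.
row = conn.execute(
    "SELECT id, name FROM users WHERE email = ?", (email,)
).fetchone()

# Bad: string interpolation turns user input into executable SQL.
# row = conn.execute(f"SELECT id, name FROM users WHERE email = '{email}'")
```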

## Testing Requirements
- **Every implementation** should have corresponding unit tests
- **Every API endpoint** should have integration tests
- Test files should be placed alongside the code they test
- Use descriptive test names that explain what is being tested

## Response Format
- Be concise and avoid unnecessary repetition
- Focus on actionable information
- Provide examples when explaining complex concepts
- Ask clarifying questions when requirements are ambiguous
237
.cursor/rules/architecture.mdc
Normal file
@@ -0,0 +1,237 @@
---
description: Modular design principles and architecture guidelines for scalable development
globs:
alwaysApply: false
---

# Rule: Architecture and Modular Design

## Goal
Maintain a clean, modular architecture that scales effectively and prevents the complexity issues that arise in AI-assisted development.

## Core Architecture Principles

### 1. Modular Design
- **Single Responsibility**: Each module has one clear purpose
- **Loose Coupling**: Modules depend on interfaces, not implementations
- **High Cohesion**: Related functionality is grouped together
- **Clear Boundaries**: Module interfaces are well-defined and stable

### 2. Size Constraints
- **Files**: Maximum 250 lines per file
- **Functions**: Maximum 50 lines per function
- **Classes**: Maximum 300 lines per class
- **Modules**: Maximum 10 public functions/classes per module

### 3. Dependency Management
- **Layer Dependencies**: Higher layers depend on lower layers only
- **No Circular Dependencies**: Modules cannot depend on each other cyclically
- **Interface Segregation**: Depend on specific interfaces, not broad ones
- **Dependency Injection**: Pass dependencies rather than creating them internally

## Modular Architecture Patterns

### Layer Structure
```
src/
├── presentation/    # UI, API endpoints, CLI interfaces
├── application/     # Business logic, use cases, workflows
├── domain/          # Core business entities and rules
├── infrastructure/  # Database, external APIs, file systems
└── shared/          # Common utilities, constants, types
```

### Module Organization
```
module_name/
├── __init__.py   # Public interface exports
├── core.py       # Main module logic
├── types.py      # Type definitions and interfaces
├── utils.py      # Module-specific utilities
├── tests/        # Module tests
└── README.md     # Module documentation
```

## Design Patterns for AI Development

### 1. Repository Pattern
Separate data access from business logic:

```python
# Domain interface
class UserRepository:
    def get_by_id(self, user_id: str) -> User: ...
    def save(self, user: User) -> None: ...

# Infrastructure implementation
class SqlUserRepository(UserRepository):
    def get_by_id(self, user_id: str) -> User:
        # Database-specific implementation
        pass
```

### 2. Service Pattern
Encapsulate business logic in focused services:

```python
class UserService:
    def __init__(self, user_repo: UserRepository):
        self._user_repo = user_repo

    def create_user(self, data: UserData) -> User:
        # Validation and business logic
        # Single responsibility: user creation
        pass
```
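
One way to read these two snippets together: the concrete repository is constructed once at the application's entry point and injected, and tests can swap in a fake behind the same interface. A sketch continuing the names above (it assumes `User` has an `id` attribute):

```python
# Composition root (e.g. main.py): build concrete dependencies once,
# then inject them; UserService never imports the SQL implementation.
repo = SqlUserRepository()
service = UserService(user_repo=repo)

# Tests substitute an in-memory fake behind the same interface:
class InMemoryUserRepository(UserRepository):
    def __init__(self) -> None:
        self._users: dict[str, User] = {}

    def get_by_id(self, user_id: str) -> User:
        return self._users[user_id]

    def save(self, user: User) -> None:
        self._users[user.id] = user

test_service = UserService(user_repo=InMemoryUserRepository())
```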

### 3. Factory Pattern
Create complex objects with clear interfaces:

```python
class DatabaseFactory:
    @staticmethod
    def create_connection(config: DatabaseConfig) -> Connection:
        # Handle different database types
        # Encapsulate connection complexity
        pass
```

## Architecture Decision Guidelines

### When to Create New Modules
Create a new module when:
- **Functionality** exceeds size constraints (250 lines)
- **Responsibility** is distinct from existing modules
- **Dependencies** would create circular references
- **Reusability** would benefit other parts of the system
- **Testing** requires isolated test environments

### When to Split Existing Modules
Split modules when:
- **File size** exceeds 250 lines
- **Multiple responsibilities** are evident
- **Testing** becomes difficult due to complexity
- **Dependencies** become too numerous
- **Change frequency** differs significantly between parts

### Module Interface Design
```python
# Good: Clear, focused interface
class PaymentProcessor:
    def process_payment(self, amount: Money, method: PaymentMethod) -> PaymentResult:
        """Process a single payment transaction."""
        pass

# Bad: Unfocused, kitchen-sink interface
class PaymentManager:
    def process_payment(self, ...): pass
    def validate_card(self, ...): pass
    def send_receipt(self, ...): pass
    def update_inventory(self, ...): pass  # Wrong responsibility!
```

## Architecture Validation

### Architecture Review Checklist
- [ ] **Dependencies flow in one direction** (no cycles)
- [ ] **Layers are respected** (presentation doesn't call infrastructure directly)
- [ ] **Modules have single responsibility**
- [ ] **Interfaces are stable** and well-defined
- [ ] **Size constraints** are maintained
- [ ] **Testing** is straightforward for each module

### Red Flags
- **God Objects**: Classes/modules that do too many things
- **Circular Dependencies**: Modules that depend on each other
- **Deep Inheritance**: More than 3 levels of inheritance
- **Large Interfaces**: Interfaces with more than 7 methods
- **Tight Coupling**: Modules that know too much about each other's internals

## Refactoring Guidelines

### When to Refactor
- Module exceeds size constraints
- Code duplication across modules
- Difficult to test individual components
- New features require changing multiple unrelated modules
- Performance bottlenecks due to poor separation

### Refactoring Process
1. **Identify** the specific architectural problem
2. **Design** the target architecture
3. **Create tests** to verify current behavior
4. **Implement changes** incrementally
5. **Validate** that tests still pass
6. **Update documentation** to reflect changes

### Safe Refactoring Practices
- **One change at a time**: Don't mix refactoring with new features
- **Tests first**: Ensure comprehensive test coverage before refactoring
- **Incremental changes**: Small steps with verification at each stage
- **Backward compatibility**: Maintain existing interfaces during transition
- **Documentation updates**: Keep architecture documentation current

## Architecture Documentation

### Architecture Decision Records (ADRs)
Document significant decisions in `./docs/decisions/`:

```markdown
# ADR-003: Service Layer Architecture

## Status
Accepted

## Context
As the application grows, business logic is scattered across controllers and models.

## Decision
Implement a service layer to encapsulate business logic.

## Consequences
**Positive:**
- Clear separation of concerns
- Easier testing of business logic
- Better reusability across different interfaces

**Negative:**
- Additional abstraction layer
- More files to maintain
```

### Module Documentation Template
```markdown
# Module: [Name]

## Purpose
What this module does and why it exists.

## Dependencies
- **Imports from**: List of modules this depends on
- **Used by**: List of modules that depend on this one
- **External**: Third-party dependencies

## Public Interface
```python
# Key functions and classes exposed by this module
```

## Architecture Notes
- Design patterns used
- Important architectural decisions
- Known limitations or constraints
```

## Migration Strategies

### Legacy Code Integration
- **Strangler Fig Pattern**: Gradually replace old code with new modules
- **Adapter Pattern**: Create interfaces to integrate old and new code
- **Facade Pattern**: Simplify complex legacy interfaces

### Gradual Modernization
1. **Identify boundaries** in existing code
2. **Extract modules** one at a time
3. **Create interfaces** for each extracted module
4. **Test thoroughly** at each step
5. **Update documentation** continuously
123
.cursor/rules/code-review.mdc
Normal file
@@ -0,0 +1,123 @@
---
description: AI-generated code review checklist and quality assurance guidelines
globs:
alwaysApply: false
---

# Rule: Code Review and Quality Assurance

## Goal
Establish systematic review processes for AI-generated code to maintain quality, security, and maintainability standards.

## AI Code Review Checklist

### Pre-Implementation Review
Before accepting any AI-generated code:

1. **Understand the Code**
   - [ ] Can you explain what the code does in your own words?
   - [ ] Do you understand each function and its purpose?
   - [ ] Are there any "magic" values or unexplained logic?
   - [ ] Does the code solve the actual problem stated?

2. **Architecture Alignment**
   - [ ] Does the code follow established project patterns?
   - [ ] Is it consistent with existing data structures?
   - [ ] Does it integrate cleanly with existing components?
   - [ ] Are new dependencies justified and necessary?

3. **Code Quality**
   - [ ] Are functions smaller than 50 lines?
   - [ ] Are files smaller than 250 lines?
   - [ ] Are variable and function names descriptive?
   - [ ] Is the code DRY (Don't Repeat Yourself)?

### Security Review
- [ ] **Input Validation**: All user inputs are validated and sanitized
- [ ] **Authentication**: Proper authentication checks are in place
- [ ] **Authorization**: Access controls are implemented correctly
- [ ] **Data Protection**: Sensitive data is handled securely
- [ ] **SQL Injection**: Database queries use parameterized statements
- [ ] **XSS Prevention**: Output is properly escaped (see the sketch below)
- [ ] **Error Handling**: Errors don't leak sensitive information
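
For the escaping item, a minimal standard-library sketch; the comment-rendering function is hypothetical, not any particular framework's API:

```python
import html

def render_comment(user_text: str) -> str:
    """Escape user-supplied text before embedding it in an HTML page."""
    # html.escape turns <, >, & and quotes into entities, so a payload
    # like "<script>alert(1)</script>" renders as inert text.
    return f'<p class="comment">{html.escape(user_text)}</p>'

assert "<script>" not in render_comment("<script>alert(1)</script>")
```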

### Integration Review
- [ ] **Existing Functionality**: New code doesn't break existing features
- [ ] **Data Consistency**: Database changes maintain referential integrity
- [ ] **API Compatibility**: Changes don't break existing API contracts
- [ ] **Performance Impact**: New code doesn't introduce performance bottlenecks
- [ ] **Testing Coverage**: Appropriate tests are included

## Review Process

### Step 1: Initial Code Analysis
1. **Read through the entire generated code** before running it
2. **Identify patterns** that don't match the existing codebase
3. **Check dependencies** - are new packages really needed?
4. **Verify logic flow** - does the algorithm make sense?

### Step 2: Security and Error Handling Review
1. **Trace data flow** from input to output
2. **Identify potential failure points** and verify error handling
3. **Check for security vulnerabilities** using the security checklist
4. **Verify proper logging** and monitoring implementation

### Step 3: Integration Testing
1. **Test with existing code** to ensure compatibility
2. **Run the existing test suite** to verify no regressions
3. **Test edge cases** and error conditions
4. **Verify performance** under realistic conditions

## Common AI Code Issues to Watch For

### Overcomplication Patterns
- **Unnecessary abstractions**: AI creating complex patterns for simple tasks
- **Over-engineering**: Solutions that are more complex than needed
- **Redundant code**: AI recreating existing functionality
- **Inappropriate design patterns**: Using patterns that don't fit the use case

### Context Loss Indicators
- **Inconsistent naming**: Different conventions from existing code
- **Wrong data structures**: Using different patterns than established
- **Ignored existing functions**: Reimplementing existing functionality
- **Architectural misalignment**: Code that doesn't fit the overall design

### Technical Debt Indicators
- **Magic numbers**: Hardcoded values without explanation
- **Poor error messages**: Generic or unhelpful error handling
- **Missing documentation**: Code without adequate comments
- **Tight coupling**: Components that are too interdependent

## Quality Gates

### Mandatory Reviews
All AI-generated code must pass these gates before acceptance:

1. **Security Review**: No security vulnerabilities detected
2. **Integration Review**: Integrates cleanly with existing code
3. **Performance Review**: Meets performance requirements
4. **Maintainability Review**: Code can be easily modified by team members
5. **Documentation Review**: Adequate documentation is provided

### Acceptance Criteria
- [ ] Code is understandable by any team member
- [ ] Integration requires minimal changes to existing code
- [ ] Security review passes all checks
- [ ] Performance meets established benchmarks
- [ ] Documentation is complete and accurate

## Rejection Criteria
Reject AI-generated code if:
- Security vulnerabilities are present
- Code is too complex for the problem being solved
- Integration requires major refactoring of existing code
- Code duplicates existing functionality without justification
- Documentation is missing or inadequate

## Review Documentation
For each review, document:
- Issues found and how they were resolved
- Performance impact assessment
- Security concerns and mitigations
- Integration challenges and solutions
- Recommendations for future similar tasks
93
.cursor/rules/context-management.mdc
Normal file
@@ -0,0 +1,93 @@
---
description: Context management for maintaining codebase awareness and preventing context drift
globs:
alwaysApply: false
---

# Rule: Context Management

## Goal
Maintain comprehensive project context to prevent context drift and ensure AI-generated code integrates seamlessly with existing codebase patterns and architecture.

## Context Documentation Requirements

### PRD.md File Documentation
1. **Project Overview**
   - Business objectives and goals
   - Target users and use cases
   - Key success metrics

### CONTEXT.md File Structure
Every project must maintain a `CONTEXT.md` file in the root directory (a minimal skeleton follows this list) with:

1. **Architecture Overview**
   - High-level system architecture
   - Key design patterns used
   - Database schema overview
   - API structure and conventions

2. **Technology Stack**
   - Programming languages and versions
   - Frameworks and libraries
   - Database systems
   - Development and deployment tools

3. **Coding Conventions**
   - Naming conventions
   - File organization patterns
   - Code structure preferences
   - Import/export patterns

4. **Current Implementation Status**
   - Completed features
   - Work in progress
   - Known technical debt
   - Planned improvements
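
A minimal `CONTEXT.md` skeleton matching the four sections above; every entry is a placeholder to be replaced with project specifics:

```markdown
# Project Context

## Architecture Overview
- Layered architecture: presentation → application → domain → infrastructure
- Key patterns: repository, service layer
- Database: [schema summary or link]
- API: [conventions, e.g. REST under /api/v1]

## Technology Stack
- [Language and version, e.g. Python 3.x]
- [Frameworks and libraries]
- [Database system]
- [Tooling: test runner, linter, deployment]

## Coding Conventions
- [Naming conventions]
- [File organization pattern]
- [Import/export pattern]

## Current Implementation Status
- Done: [features]
- In progress: [features]
- Tech debt: [items]
- Planned: [items]
```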

## Context Maintenance Protocol

### Before Every Coding Session
1. **Review CONTEXT.md and PRD.md** to understand the current project state
2. **Scan recent changes** in git history to understand the latest patterns
3. **Identify existing patterns** for similar functionality before implementing new features
4. **Ask for clarification** if existing patterns are unclear or conflicting

### During Development
1. **Reference existing code** when explaining implementation approaches
2. **Maintain consistency** with established patterns and conventions
3. **Update CONTEXT.md** when making architectural decisions
4. **Document deviations** from established patterns with reasoning

### Context Preservation Strategies
- **Incremental development**: Build on existing patterns rather than creating new ones
- **Pattern consistency**: Use established data structures and function signatures
- **Integration awareness**: Consider how new code affects existing functionality
- **Dependency management**: Understand existing dependencies before adding new ones

## Context Prompting Best Practices

### Effective Context Sharing
- Include relevant sections of CONTEXT.md in prompts for complex tasks
- Reference specific existing files when asking for similar functionality
- Provide examples of existing patterns when requesting new implementations
- Share recent git commit messages to understand the latest changes

### Context Window Optimization
- Prioritize the most relevant context for the current task
- Use @filename references to include specific files
- Break large contexts into focused, task-specific chunks
- Update context references as the project evolves

## Red Flags - Context Loss Indicators
- AI suggests patterns that conflict with existing code
- New implementations ignore established conventions
- Proposed solutions don't integrate with the existing architecture
- Code suggestions require significant refactoring of existing functionality

## Recovery Protocol
When context loss is detected:
1. **Stop development** and review CONTEXT.md
2. **Analyze the existing codebase** for established patterns
3. **Update context documentation** with missing information
4. **Restart the task** with proper context provided
5. **Test integration** with existing code before proceeding
67
.cursor/rules/create-prd.mdc
Normal file
@@ -0,0 +1,67 @@
---
description: Creating PRD for a project or specific task/function
globs:
alwaysApply: false
---

# Rule: Generating a Product Requirements Document (PRD)

## Goal

To guide an AI assistant in creating a detailed Product Requirements Document (PRD) in Markdown format, based on an initial user prompt. The PRD should be clear, actionable, and suitable for a junior developer to understand and implement the feature.

## Process

1. **Receive Initial Prompt:** The user provides a brief description or request for a new feature or functionality.
2. **Ask Clarifying Questions:** Before writing the PRD, the AI *must* ask clarifying questions to gather sufficient detail. The goal is to understand the "what" and "why" of the feature, not necessarily the "how" (which the developer will figure out).
3. **Generate PRD:** Based on the initial prompt and the user's answers to the clarifying questions, generate a PRD using the structure outlined below.
4. **Save PRD:** Save the generated document as `prd-[feature-name].md` inside the `/tasks` directory.

## Clarifying Questions (Examples)

The AI should adapt its questions based on the prompt, but here are some common areas to explore:

* **Problem/Goal:** "What problem does this feature solve for the user?" or "What is the main goal we want to achieve with this feature?"
* **Target User:** "Who is the primary user of this feature?"
* **Core Functionality:** "Can you describe the key actions a user should be able to perform with this feature?"
* **User Stories:** "Could you provide a few user stories? (e.g., As a [type of user], I want to [perform an action] so that [benefit].)"
* **Acceptance Criteria:** "How will we know when this feature is successfully implemented? What are the key success criteria?"
* **Scope/Boundaries:** "Are there any specific things this feature *should not* do (non-goals)?"
* **Data Requirements:** "What kind of data does this feature need to display or manipulate?"
* **Design/UI:** "Are there any existing design mockups or UI guidelines to follow?" or "Can you describe the desired look and feel?"
* **Edge Cases:** "Are there any potential edge cases or error conditions we should consider?"

## PRD Structure

The generated PRD should include the following sections:

1. **Introduction/Overview:** Briefly describe the feature and the problem it solves. State the goal.
2. **Goals:** List the specific, measurable objectives for this feature.
3. **User Stories:** Detail the user narratives describing feature usage and benefits.
4. **Functional Requirements:** List the specific functionalities the feature must have. Use clear, concise language (e.g., "The system must allow users to upload a profile picture."). Number these requirements.
5. **Non-Goals (Out of Scope):** Clearly state what this feature will *not* include to manage scope.
6. **Design Considerations (Optional):** Link to mockups, describe UI/UX requirements, or mention relevant components/styles if applicable.
7. **Technical Considerations (Optional):** Mention any known technical constraints, dependencies, or suggestions (e.g., "Should integrate with the existing Auth module").
8. **Success Metrics:** How will the success of this feature be measured? (e.g., "Increase user engagement by 10%", "Reduce support tickets related to X").
9. **Open Questions:** List any remaining questions or areas needing further clarification.

## Target Audience

Assume the primary reader of the PRD is a **junior developer**. Therefore, requirements should be explicit, unambiguous, and avoid jargon where possible. Provide enough detail for them to understand the feature's purpose and core logic.

## Output

* **Format:** Markdown (`.md`)
* **Location:** `/tasks/`
* **Filename:** `prd-[feature-name].md`

## Final instructions

1. Do NOT start implementing the PRD
2. Make sure to ask the user clarifying questions
3. Take the user's answers to the clarifying questions and improve the PRD
244
.cursor/rules/documentation.mdc
Normal file
@@ -0,0 +1,244 @@
---
description: Documentation standards for code, architecture, and development decisions
globs:
alwaysApply: false
---

# Rule: Documentation Standards

## Goal
Maintain comprehensive, up-to-date documentation that supports development, onboarding, and long-term maintenance of the codebase.

## Documentation Hierarchy

### 1. Project Level Documentation (in ./docs/)
- **README.md**: Project overview, setup instructions, basic usage
- **CONTEXT.md**: Current project state, architecture decisions, patterns
- **CHANGELOG.md**: Version history and significant changes
- **CONTRIBUTING.md**: Development guidelines and processes
- **API.md**: API endpoints, request/response formats, authentication

### 2. Module Level Documentation (in ./docs/modules/)
- **[module-name].md**: Purpose, public interfaces, usage examples
- **dependencies.md**: External dependencies and their purposes
- **architecture.md**: Module relationships and data flow

### 3. Code Level Documentation
- **Docstrings**: Function and class documentation
- **Inline comments**: Complex logic explanations
- **Type hints**: Clear parameter and return types
- **README files**: Directory-specific instructions

## Documentation Standards

### Code Documentation
```python
def process_user_data(user_id: str, data: dict) -> UserResult:
    """
    Process and validate user data before storage.

    Args:
        user_id: Unique identifier for the user
        data: Dictionary containing user information to process

    Returns:
        UserResult: Processed user data with validation status

    Raises:
        ValidationError: When user data fails validation
        DatabaseError: When storage operation fails

    Example:
        >>> result = process_user_data("123", {"name": "John", "email": "john@example.com"})
        >>> print(result.status)
        'valid'
    """
```

### API Documentation Format
```markdown
### POST /api/users

Create a new user account.

**Request:**
```json
{
  "name": "string (required)",
  "email": "string (required, valid email)",
  "age": "number (optional, min: 13)"
}
```

**Response (201):**
```json
{
  "id": "uuid",
  "name": "string",
  "email": "string",
  "created_at": "iso_datetime"
}
```

**Errors:**
- 400: Invalid input data
- 409: Email already exists
```

### Architecture Decision Records (ADRs)
Document significant architecture decisions in `./docs/decisions/`:

```markdown
# ADR-001: Database Choice - PostgreSQL

## Status
Accepted

## Context
We need to choose a database for storing user data and application state.

## Decision
We will use PostgreSQL as our primary database.

## Consequences
**Positive:**
- ACID compliance ensures data integrity
- Rich query capabilities with SQL
- Good performance for our expected load

**Negative:**
- More complex setup than simpler alternatives
- Requires SQL knowledge from team members

## Alternatives Considered
- MongoDB: Rejected due to consistency requirements
- SQLite: Rejected due to scalability needs
```

## Documentation Maintenance

### When to Update Documentation

#### Always Update:
- **API changes**: Any modification to public interfaces
- **Architecture changes**: New patterns, data structures, or workflows
- **Configuration changes**: Environment variables, deployment settings
- **Dependencies**: Adding, removing, or upgrading packages
- **Business logic changes**: Core functionality modifications

#### Update Weekly:
- **CONTEXT.md**: Current development status and priorities
- **Known issues**: Bug reports and workarounds
- **Performance notes**: Bottlenecks and optimization opportunities

#### Update per Release:
- **CHANGELOG.md**: User-facing changes and improvements
- **Version documentation**: Breaking changes and migration guides
- **Examples and tutorials**: Keep sample code current

### Documentation Quality Checklist

#### Completeness
- [ ] Purpose and scope clearly explained
- [ ] All public interfaces documented
- [ ] Examples provided for complex usage
- [ ] Error conditions and handling described
- [ ] Dependencies and requirements listed

#### Accuracy
- [ ] Code examples are tested and working
- [ ] Links point to correct locations
- [ ] Version numbers are current
- [ ] Screenshots reflect current UI

#### Clarity
- [ ] Written for the intended audience
- [ ] Technical jargon is explained
- [ ] Step-by-step instructions are clear
- [ ] Visual aids used where helpful

## Documentation Automation

### Auto-Generated Documentation
- **API docs**: Generate from code annotations
- **Type documentation**: Extract from type hints
- **Module dependencies**: Auto-update from imports
- **Test coverage**: Include coverage reports

### Documentation Testing
```python
# Test that code examples in documentation work
def test_documentation_examples():
    """Verify code examples in docs actually work."""
    # Test examples from README.md
    # Test API examples from docs/API.md
    # Test configuration examples
```
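
One way to realize this stub is with the standard library's `doctest` module, which executes `>>>` examples embedded in a file. A sketch, assuming the docs' examples are written as doctest sessions and the path is illustrative:

```python
import doctest

def test_documentation_examples():
    """Verify code examples in docs actually work."""
    # Executes every `>>>` example found in the file and returns a
    # TestResults(failed, attempted) tuple; path is relative to the
    # working directory and purely illustrative.
    results = doctest.testfile("docs/API.md", module_relative=False)
    assert results.failed == 0
```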

## Documentation Templates

### New Module Documentation Template
```markdown
# Module: [Name]

## Purpose
Brief description of what this module does and why it exists.

## Public Interface
### Functions
- `function_name(params)`: Description and example

### Classes
- `ClassName`: Purpose and basic usage

## Usage Examples
```python
# Basic usage example
```

## Dependencies
- Internal: List of internal modules this depends on
- External: List of external packages required

## Testing
How to run tests for this module.

## Known Issues
Current limitations or bugs.
```

### API Endpoint Template
```markdown
### [METHOD] /api/endpoint

Brief description of what this endpoint does.

**Authentication:** Required/Optional
**Rate Limiting:** X requests per minute

**Request:**
- Headers required
- Body schema
- Query parameters

**Response:**
- Success response format
- Error response format
- Status codes

**Example:**
Working request/response example
```

## Review and Maintenance Process

### Documentation Review
- Include documentation updates in code reviews
- Verify examples still work with code changes
- Check for broken links and outdated information
- Ensure consistency with the current implementation

### Regular Audits
- Monthly review of documentation accuracy
- Quarterly assessment of documentation completeness
- Annual review of documentation structure and organization
207
.cursor/rules/enhanced-task-list.mdc
Normal file
@@ -0,0 +1,207 @@
---
description: Enhanced task list management with quality gates and iterative workflow integration
globs:
alwaysApply: false
---

# Rule: Enhanced Task List Management

## Goal
Manage task lists with integrated quality gates and an iterative workflow to prevent context loss and ensure sustainable development.

## Task Implementation Protocol

### Pre-Implementation Check
Before starting any sub-task:
- [ ] **Context Review**: Have you reviewed CONTEXT.md and relevant documentation?
- [ ] **Pattern Identification**: Do you understand existing patterns to follow?
- [ ] **Integration Planning**: Do you know how this will integrate with existing code?
- [ ] **Size Validation**: Is this task small enough (≤50 lines per function, ≤250 lines per file)?

### Implementation Process
1. **One sub-task at a time**: Do **NOT** start the next sub-task until you ask the user for permission and they say "yes" or "y"
2. **Step-by-step execution**:
   - Plan the approach in bullet points
   - Wait for approval
   - Implement the specific sub-task
   - Test the implementation
   - Update documentation if needed
3. **Quality validation**: Run through the code review checklist before marking complete

### Completion Protocol
When you finish a **sub-task**:
1. **Immediate marking**: Change `[ ]` to `[x]`
2. **Quality check**: Verify the implementation meets quality standards
3. **Integration test**: Ensure new code works with existing functionality
4. **Documentation update**: Update relevant files if needed
5. **Parent task check**: If **all** subtasks underneath a parent task are now `[x]`, also mark the **parent task** as completed
6. **Stop and wait**: Get user approval before proceeding to the next sub-task

## Enhanced Task List Structure

### Task File Header
```markdown
# Task List: [Feature Name]

**Source PRD**: `prd-[feature-name].md`
**Status**: In Progress / Complete / Blocked
**Context Last Updated**: [Date]
**Architecture Review**: Required / Complete / N/A

## Quick Links
- [Context Documentation](./CONTEXT.md)
- [Architecture Guidelines](./docs/architecture.md)
- [Related Files](#relevant-files)
```

### Task Format with Quality Gates
```markdown
- [ ] 1.0 Parent Task Title
  - **Quality Gate**: Architecture review required
  - **Dependencies**: List any dependencies
  - [ ] 1.1 [Sub-task description 1.1]
    - **Size estimate**: [Small/Medium/Large]
    - **Pattern reference**: [Reference to existing pattern]
    - **Test requirements**: [Unit/Integration/Both]
  - [ ] 1.2 [Sub-task description 1.2]
    - **Integration points**: [List affected components]
    - **Risk level**: [Low/Medium/High]
```

## Relevant Files Management

### Enhanced File Tracking
```markdown
## Relevant Files

### Implementation Files
- `path/to/file1.ts` - Brief description of purpose and role
  - **Status**: Created / Modified / Needs Review
  - **Last Modified**: [Date]
  - **Review Status**: Pending / Approved / Needs Changes

### Test Files
- `path/to/file1.test.ts` - Unit tests for file1.ts
  - **Coverage**: [Percentage or status]
  - **Last Run**: [Date and result]

### Documentation Files
- `docs/module-name.md` - Module documentation
  - **Status**: Up to date / Needs update / Missing
  - **Last Updated**: [Date]

### Configuration Files
- `config/setting.json` - Configuration changes
  - **Environment**: [Dev/Staging/Prod affected]
  - **Backup**: [Location of backup]
```

## Task List Maintenance

### During Development
1. **Regular updates**: Update task status after each significant change
2. **File tracking**: Add new files as they are created or modified
3. **Dependency tracking**: Note when new dependencies between tasks emerge
4. **Risk assessment**: Flag tasks that become more complex than anticipated

### Quality Checkpoints
At 25%, 50%, 75%, and 100% completion:
- [ ] **Architecture alignment**: Code follows established patterns
- [ ] **Performance impact**: No significant performance degradation
- [ ] **Security review**: No security vulnerabilities introduced
- [ ] **Documentation current**: All changes are documented

### Weekly Review Process
1. **Completion assessment**: What percentage of tasks are actually complete?
2. **Quality assessment**: Are completed tasks meeting quality standards?
3. **Process assessment**: Is the iterative workflow being followed?
4. **Risk assessment**: Are there emerging risks or blockers?

## Task Status Indicators

### Status Levels
- `[ ]` **Not Started**: Task not yet begun
- `[~]` **In Progress**: Currently being worked on
- `[?]` **Blocked**: Waiting for dependencies or decisions
- `[!]` **Needs Review**: Implementation complete but needs quality review
- `[x]` **Complete**: Finished and quality approved

### Quality Indicators
- ✅ **Quality Approved**: Passed all quality gates
- ⚠️ **Quality Concerns**: Has issues but functional
- ❌ **Quality Failed**: Needs rework before approval
- 🔄 **Under Review**: Currently being reviewed

### Integration Status
- 🔗 **Integrated**: Successfully integrated with existing code
- 🔧 **Integration Issues**: Problems with existing code integration
- ⏳ **Integration Pending**: Ready for integration testing
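
Because the status levels above are plain-text markers, a task file can be audited mechanically. A small sketch, assuming those markers and an illustrative file path:

```python
import re
from collections import Counter
from pathlib import Path

STATUS_NAMES = {
    " ": "Not Started", "~": "In Progress", "?": "Blocked",
    "!": "Needs Review", "x": "Complete",
}

def tally_statuses(task_file: str) -> Counter:
    """Count status markers like `- [x]` in a markdown task list."""
    text = Path(task_file).read_text(encoding="utf-8")
    # Match a list item's leading checkbox and capture its status character.
    markers = re.findall(r"^\s*- \[([ ~?!x])\]", text, flags=re.MULTILINE)
    return Counter(STATUS_NAMES[m] for m in markers)

# e.g. tally_statuses("tasks/tasks-prd-user-profile-editing.md")
```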

## Emergency Procedures

### When Tasks Become Too Complex
If a sub-task grows beyond its expected scope:
1. **Stop implementation** immediately
2. **Document the current state** and what was discovered
3. **Break down** the task into smaller pieces
4. **Update the task list** with new sub-tasks
5. **Get approval** for the new breakdown before proceeding

### When Context is Lost
If the AI seems to lose track of project patterns:
1. **Pause development**
2. **Review CONTEXT.md** and recent changes
3. **Update context documentation** with the current state
4. **Restart** with explicit pattern references
5. **Reduce task size** until context is re-established

### When Quality Gates Fail
If an implementation doesn't meet quality standards:
1. **Mark the task** with `[!]` status
2. **Document specific issues** found
3. **Create remediation tasks** if needed
4. **Don't proceed** until quality issues are resolved

## AI Instructions Integration

### Context Awareness Commands
```markdown
**Before starting any task, run these checks:**
1. @CONTEXT.md - Review current project state
2. @architecture.md - Understand design principles
3. @code-review.md - Know quality standards
4. Look at existing similar code for patterns
```

### Quality Validation Commands
```markdown
**After completing any sub-task:**
1. Run the code review checklist
2. Test integration with existing code
3. Update documentation if needed
4. Mark the task complete only after quality approval
```

### Workflow Commands
```markdown
**For each development session:**
1. Review incomplete tasks and their status
2. Identify the next logical sub-task to work on
3. Check dependencies and blockers
4. Follow the iterative workflow process
5. Update the task list with progress and findings
```

## Success Metrics

### Daily Success Indicators
- Tasks are completed according to quality standards
- No sub-tasks are started without completing previous ones
- File tracking remains accurate and current
- Integration issues are caught early

### Weekly Success Indicators
- Overall task completion rate is sustainable
- Quality issues are decreasing over time
- Context loss incidents are rare
- Team confidence in the codebase remains high
70
.cursor/rules/generate-tasks.mdc
Normal file
@@ -0,0 +1,70 @@
---
description: Generate a task list or TODO for a user requirement or implementation.
globs:
alwaysApply: false
---

# Rule: Generating a Task List from a PRD

## Goal

To guide an AI assistant in creating a detailed, step-by-step task list in Markdown format based on an existing Product Requirements Document (PRD). The task list should guide a developer through implementation.

## Output

- **Format:** Markdown (`.md`)
- **Location:** `/tasks/`
- **Filename:** `tasks-[prd-file-name].md` (e.g., `tasks-prd-user-profile-editing.md`)

## Process

1. **Receive PRD Reference:** The user points the AI to a specific PRD file
2. **Analyze PRD:** The AI reads and analyzes the functional requirements, user stories, and other sections of the specified PRD.
3. **Phase 1: Generate Parent Tasks:** Based on the PRD analysis, create the file and generate the main, high-level tasks required to implement the feature. Use your judgement on how many high-level tasks to use. It's likely to be about 5. Present these tasks to the user in the specified format (without sub-tasks yet). Inform the user: "I have generated the high-level tasks based on the PRD. Ready to generate the sub-tasks? Respond with 'Go' to proceed."
4. **Wait for Confirmation:** Pause and wait for the user to respond with "Go".
5. **Phase 2: Generate Sub-Tasks:** Once the user confirms, break down each parent task into smaller, actionable sub-tasks necessary to complete the parent task. Ensure sub-tasks logically follow from the parent task and cover the implementation details implied by the PRD.
6. **Identify Relevant Files:** Based on the tasks and PRD, identify potential files that will need to be created or modified. List these under the `Relevant Files` section, including corresponding test files if applicable.
7. **Generate Final Output:** Combine the parent tasks, sub-tasks, relevant files, and notes into the final Markdown structure.
8. **Save Task List:** Save the generated document in the `/tasks/` directory with the filename `tasks-[prd-file-name].md`, where `[prd-file-name]` matches the base name of the input PRD file (e.g., if the input was `prd-user-profile-editing.md`, the output is `tasks-prd-user-profile-editing.md`).

## Output Format

The generated task list _must_ follow this structure:

```markdown
## Relevant Files

- `path/to/potential/file1.ts` - Brief description of why this file is relevant (e.g., Contains the main component for this feature).
- `path/to/file1.test.ts` - Unit tests for `file1.ts`.
- `path/to/another/file.tsx` - Brief description (e.g., API route handler for data submission).
- `path/to/another/file.test.tsx` - Unit tests for `another/file.tsx`.
- `lib/utils/helpers.ts` - Brief description (e.g., Utility functions needed for calculations).
- `lib/utils/helpers.test.ts` - Unit tests for `helpers.ts`.

### Notes

- Unit tests should typically be placed alongside the code files they are testing (e.g., `MyComponent.tsx` and `MyComponent.test.tsx` in the same directory).
- Use `npx jest [optional/path/to/test/file]` to run tests. Running without a path executes all tests found by the Jest configuration.

## Tasks

- [ ] 1.0 Parent Task Title
  - [ ] 1.1 [Sub-task description 1.1]
  - [ ] 1.2 [Sub-task description 1.2]
- [ ] 2.0 Parent Task Title
  - [ ] 2.1 [Sub-task description 2.1]
- [ ] 3.0 Parent Task Title (may not require sub-tasks if purely structural or configuration)
```

## Interaction Model

The process explicitly requires a pause after generating parent tasks to get user confirmation ("Go") before proceeding to generate the detailed sub-tasks. This ensures the high-level plan aligns with user expectations before diving into details.

## Target Audience

Assume the primary reader of the task list is a **junior developer** who will implement the feature.
236
.cursor/rules/iterative-workflow.mdc
Normal file
@@ -0,0 +1,236 @@
---
description: Iterative development workflow for AI-assisted coding
globs:
alwaysApply: false
---

# Rule: Iterative Development Workflow

## Goal
Establish a structured, iterative development process that prevents the chaos and complexity that can arise from uncontrolled AI-assisted development.

## Development Phases

### Phase 1: Planning and Design
**Before writing any code:**

1. **Understand the Requirement**
   - Break down the task into specific, measurable objectives
   - Identify existing code patterns that should be followed
   - List dependencies and integration points
   - Define acceptance criteria

2. **Design Review**
   - Propose approach in bullet points
   - Wait for explicit approval before proceeding
   - Consider how the solution fits existing architecture
   - Identify potential risks and mitigation strategies

### Phase 2: Incremental Implementation
**One small piece at a time:**

1. **Micro-Tasks** (≤ 50 lines each)
   - Implement one function or small class at a time
   - Test immediately after implementation
   - Ensure integration with existing code
   - Document decisions and patterns used

2. **Validation Checkpoints**
   - After each micro-task, verify it works correctly
   - Check that it follows established patterns
   - Confirm it integrates cleanly with existing code
   - Get approval before moving to next micro-task

### Phase 3: Integration and Testing
**Ensuring system coherence:**

1. **Integration Testing**
   - Test new code with existing functionality
   - Verify no regressions in existing features
   - Check performance impact
   - Validate error handling

2. **Documentation Update**
   - Update relevant documentation
   - Record any new patterns or decisions
   - Update context files if architecture changed

## Iterative Prompting Strategy

### Step 1: Context Setting
```
Before implementing [feature], help me understand:
1. What existing patterns should I follow?
2. What existing functions/classes are relevant?
3. How should this integrate with [specific existing component]?
4. What are the potential architectural impacts?
```

### Step 2: Plan Creation
```
Based on the context, create a detailed plan for implementing [feature]:
1. Break it into micro-tasks (≤50 lines each)
2. Identify dependencies and order of implementation
3. Specify integration points with existing code
4. List potential risks and mitigation strategies

Wait for my approval before implementing.
```

### Step 3: Incremental Implementation
```
Implement only the first micro-task: [specific task]
- Use existing patterns from [reference file/function]
- Keep it under 50 lines
- Include error handling
- Add appropriate tests
- Explain your implementation choices

Stop after this task and wait for approval.
```

## Quality Gates

### Before Each Implementation
- [ ] **Purpose is clear**: Can explain what this piece does and why
- [ ] **Pattern is established**: Following existing code patterns
- [ ] **Size is manageable**: Implementation is small enough to understand completely
- [ ] **Integration is planned**: Know how it connects to existing code

### After Each Implementation
- [ ] **Code is understood**: Can explain every line of implemented code
- [ ] **Tests pass**: All existing and new tests are passing
- [ ] **Integration works**: New code works with existing functionality
- [ ] **Documentation updated**: Changes are reflected in relevant documentation

### Before Moving to Next Task
- [ ] **Current task complete**: All acceptance criteria met
- [ ] **No regressions**: Existing functionality still works
- [ ] **Clean state**: No temporary code or debugging artifacts
- [ ] **Approval received**: Explicit go-ahead for next task
- [ ] **Documentation updated**: If relevant changes to the module were made

## Anti-Patterns to Avoid

### Large Block Implementation
**Don't:**
```
Implement the entire user management system with authentication,
CRUD operations, and email notifications.
```

**Do:**
```
First, implement just the User model with basic fields.
Stop there and let me review before continuing.
```

### Context Loss
**Don't:**
```
Create a new authentication system.
```

**Do:**
```
Looking at the existing auth patterns in auth.py, implement
password validation following the same structure as the
existing email validation function.
```

### Over-Engineering
**Don't:**
```
Build a flexible, extensible user management framework that
can handle any future requirements.
```

**Do:**
```
Implement user creation functionality that matches the existing
pattern in customer.py, focusing only on the current requirements.
```

## Progress Tracking

### Task Status Indicators
- 🔄 **In Planning**: Requirements gathering and design
- ⏳ **In Progress**: Currently implementing
- ✅ **Complete**: Implemented, tested, and integrated
- 🚫 **Blocked**: Waiting for decisions or dependencies
- 🔧 **Needs Refactor**: Working but needs improvement

### Weekly Review Process
1. **Progress Assessment**
   - What was completed this week?
   - What challenges were encountered?
   - How well did the iterative process work?

2. **Process Adjustment**
   - Were task sizes appropriate?
   - Did context management work effectively?
   - What improvements can be made?

3. **Architecture Review**
   - Is the code remaining maintainable?
   - Are patterns staying consistent?
   - Is technical debt accumulating?

## Emergency Procedures

### When Things Go Wrong
If development becomes chaotic or problematic:

1. **Stop Development**
   - Don't continue adding to the problem
   - Take time to assess the situation
   - Don't rush to "fix" with more AI-generated code

2. **Assess the Situation**
   - What specific problems exist?
   - How far has the code diverged from established patterns?
   - What parts are still working correctly?

3. **Recovery Process**
   - Roll back to last known good state
   - Update context documentation with lessons learned
   - Restart with smaller, more focused tasks
   - Get explicit approval for each step of recovery

### Context Recovery
When AI seems to lose track of project patterns:

1. **Context Refresh**
   - Review and update CONTEXT.md
   - Include examples of current code patterns
   - Clarify architectural decisions

2. **Pattern Re-establishment**
   - Show AI examples of existing, working code
   - Explicitly state patterns to follow
   - Start with very small, pattern-matching tasks

3. **Gradual Re-engagement**
   - Begin with simple, low-risk tasks
   - Verify pattern adherence at each step
   - Gradually increase task complexity as consistency returns

## Success Metrics

### Short-term (Daily)
- Code is understandable and well-integrated
- No major regressions introduced
- Development velocity feels sustainable
- Team confidence in codebase remains high

### Medium-term (Weekly)
- Technical debt is not accumulating
- New features integrate cleanly
- Development patterns remain consistent
- Documentation stays current

### Long-term (Monthly)
- Codebase remains maintainable as it grows
- New team members can understand and contribute
- AI assistance enhances rather than hinders development
- Architecture remains clean and purposeful
24
.cursor/rules/project.mdc
Normal file
24
.cursor/rules/project.mdc
Normal file
@@ -0,0 +1,24 @@
---
description:
globs:
alwaysApply: true
---
# Rule: Project specific rules

## Goal
Unify the project structure and interaction with tools and the console

### System tools
- **ALWAYS** use UV for package management
- **ALWAYS** use Windows PowerShell commands for the terminal

### Coding patterns
- **ALWAYS** check the arguments and methods before use to avoid errors with wrong parameters or names
- If in doubt, check the [CONTEXT.md](mdc:CONTEXT.md) file and [architecture.md](mdc:docs/architecture.md)
- **PREFER** the ORM pattern for databases, with SQLAlchemy.
- **DO NOT USE** emoji in code and comments

### Testing
- Run tests with UV, in the format *uv run pytest [filename]*
237
.cursor/rules/refactoring.mdc
Normal file
237
.cursor/rules/refactoring.mdc
Normal file
@@ -0,0 +1,237 @@
---
description: Code refactoring and technical debt management for AI-assisted development
globs:
alwaysApply: false
---

# Rule: Code Refactoring and Technical Debt Management

## Goal
Guide AI in systematic code refactoring to improve maintainability, reduce complexity, and prevent technical debt accumulation in AI-assisted development projects.

## When to Apply This Rule
- Code complexity has increased beyond manageable levels
- Duplicate code patterns are detected
- Performance issues are identified
- New features are difficult to integrate
- Code review reveals maintainability concerns
- Weekly technical debt assessment indicates refactoring needs

## Pre-Refactoring Assessment

Before starting any refactoring, the AI MUST:

1. **Context Analysis:**
   - Review existing `CONTEXT.md` for architectural decisions
   - Analyze current code patterns and conventions
   - Identify all files that will be affected (search the codebase for usages)
   - Check for existing tests that verify current behavior

2. **Scope Definition:**
   - Clearly define what will and will not be changed
   - Identify the specific refactoring pattern to apply
   - Estimate the blast radius of changes
   - Plan a rollback strategy if needed

3. **Documentation Review:**
   - Check `./docs/` for relevant module documentation
   - Review any existing architectural diagrams
   - Identify dependencies and integration points
   - Note any known constraints or limitations

## Refactoring Process

### Phase 1: Planning and Safety
1. **Create Refactoring Plan:**
   - Document the current state and desired end state
   - Break refactoring into small, atomic steps
   - Identify tests that must pass throughout the process
   - Plan verification steps for each change

2. **Establish Safety Net:**
   - Ensure comprehensive test coverage exists
   - If tests are missing, create them BEFORE refactoring
   - Document current behavior that must be preserved
   - Create a backup of the current implementation approach

3. **Get Approval:**
   - Present the refactoring plan to the user
   - Wait for explicit "Go" or "Proceed" confirmation
   - Do NOT start refactoring without approval

### Phase 2: Incremental Implementation
4. **One Change at a Time:**
   - Implement ONE refactoring step per iteration
   - Run tests after each step to ensure nothing breaks
   - Update documentation if interfaces change
   - Mark progress in the refactoring plan

5. **Verification Protocol:**
   - Run all relevant tests after each change
   - Verify functionality works as expected
   - Check performance hasn't degraded
   - Ensure no new linting or type errors

6. **User Checkpoint:**
   - After each significant step, pause for user review
   - Present what was changed and current status
   - Wait for approval before continuing
   - Address any concerns before proceeding

### Phase 3: Completion and Documentation
7. **Final Verification:**
   - Run full test suite to ensure nothing is broken
   - Verify all original functionality is preserved
   - Check that new code follows project conventions
   - Confirm performance is maintained or improved

8. **Documentation Update:**
   - Update `CONTEXT.md` with new patterns/decisions
   - Update module documentation in `./docs/`
   - Document any new conventions established
   - Note lessons learned for future refactoring

## Common Refactoring Patterns

### Extract Method/Function
```
WHEN: Functions/methods exceed 50 lines or have multiple responsibilities
HOW:
1. Identify logical groupings within the function
2. Extract each group into a well-named helper function
3. Ensure each function has a single responsibility
4. Verify tests still pass
```
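
A minimal Python sketch of this pattern (the `process_orders` example is hypothetical, not from this codebase): one oversized function is split into single-purpose helpers while the public behavior stays identical.

```python
# Hypothetical after-state of an Extract Method refactoring; names are illustrative.

def validate_orders(orders: list[dict]) -> list[dict]:
    """Keep only orders that have the fields later steps rely on."""
    return [o for o in orders if "price" in o and "qty" in o]

def total_value(orders: list[dict]) -> float:
    """Sum price * quantity across orders."""
    return sum(o["price"] * o["qty"] for o in orders)

def process_orders(orders: list[dict]) -> float:
    """Public entry point: same behavior as before, now a short composition."""
    return total_value(validate_orders(orders))
```

Each helper stays well under the 50-line limit and can be tested on its own.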

### Extract Module/Class
```
WHEN: Files exceed 250 lines or handle multiple concerns
HOW:
1. Identify cohesive functionality groups
2. Create new files for each group
3. Move related functions/classes together
4. Update imports and dependencies
5. Verify module boundaries are clean
```

### Eliminate Duplication
```
WHEN: Similar code appears in multiple places
HOW:
1. Identify the common pattern or functionality
2. Extract to a shared utility function or module
3. Update all usage sites to use the shared code
4. Ensure the abstraction is not over-engineered
```

### Improve Data Structures
```
WHEN: Complex nested objects or unclear data flow
HOW:
1. Define clear interfaces/types for data structures
2. Create transformation functions between different representations
3. Ensure data flow is unidirectional where possible
4. Add validation at boundaries
```

### Reduce Coupling
```
WHEN: Modules are tightly interconnected
HOW:
1. Identify dependencies between modules
2. Extract interfaces for external dependencies
3. Use dependency injection where appropriate
4. Ensure modules can be tested in isolation
```
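
As a hedged illustration of steps 2-4 (extract an interface, inject the dependency, test in isolation) using `typing.Protocol`; the names here are hypothetical, not part of this project:

```python
from typing import Protocol

class PriceFeed(Protocol):
    """Interface extracted for an external market-data dependency."""
    def latest_close(self, symbol: str) -> float: ...

class FixedFeed:
    """Test double: lets the caller be verified in isolation."""
    def __init__(self, price: float) -> None:
        self.price = price

    def latest_close(self, symbol: str) -> float:
        return self.price

def should_buy(feed: PriceFeed, symbol: str, threshold: float) -> bool:
    """Depends on the interface, not on a concrete exchange client."""
    return feed.latest_close(symbol) < threshold

# The strategy function can now be tested without any network dependency.
assert should_buy(FixedFeed(100.0), "BTCUSD", threshold=150.0)
```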

## Quality Gates

Every refactoring must pass these gates:

### Technical Quality
- [ ] All existing tests pass
- [ ] No new linting errors introduced
- [ ] Code follows established project conventions
- [ ] No performance regression detected
- [ ] File sizes remain under 250 lines
- [ ] Function sizes remain under 50 lines

### Maintainability
- [ ] Code is more readable than before
- [ ] Duplicated code has been reduced
- [ ] Module responsibilities are clearer
- [ ] Dependencies are explicit and minimal
- [ ] Error handling is consistent

### Documentation
- [ ] Public interfaces are documented
- [ ] Complex logic has explanatory comments
- [ ] Architectural decisions are recorded
- [ ] Examples are provided where helpful

## AI Instructions for Refactoring

1. **Always ask for permission** before starting any refactoring work
2. **Start with tests** - ensure comprehensive coverage before changing code
3. **Work incrementally** - make small changes and verify each step
4. **Preserve behavior** - functionality must remain exactly the same
5. **Update documentation** - keep all docs current with changes
6. **Follow conventions** - maintain consistency with existing codebase
7. **Stop and ask** if any step fails or produces unexpected results
8. **Explain changes** - clearly communicate what was changed and why

## Anti-Patterns to Avoid

### Over-Engineering
- Don't create abstractions for code that isn't duplicated
- Avoid complex inheritance hierarchies
- Don't optimize prematurely

### Breaking Changes
- Never change public APIs without explicit approval
- Don't remove functionality, even if it seems unused
- Avoid changing behavior "while we're here"

### Scope Creep
- Stick to the defined refactoring scope
- Don't add new features during refactoring
- Resist the urge to "improve" unrelated code

## Success Metrics

Track these metrics to ensure refactoring effectiveness:

### Code Quality
- Reduced cyclomatic complexity
- Lower code duplication percentage
- Improved test coverage
- Fewer linting violations

### Developer Experience
- Faster time to understand code
- Easier integration of new features
- Reduced bug introduction rate
- Higher developer confidence in changes

### Maintainability
- Clearer module boundaries
- More predictable behavior
- Easier debugging and troubleshooting
- Better performance characteristics

## Output Files

When refactoring is complete, update:
- `refactoring-log-[date].md` - Document what was changed and why
- `CONTEXT.md` - Update with new patterns and decisions
- `./docs/` - Update relevant module documentation
- Task lists - Mark refactoring tasks as complete

## Final Verification

Before marking refactoring complete:
1. Run the full test suite and verify all tests pass
2. Check that code follows all project conventions
3. Verify documentation is up to date
4. Confirm the user is satisfied with the results
5. Record lessons learned for future refactoring efforts
44
.cursor/rules/task-list.mdc
Normal file
44
.cursor/rules/task-list.mdc
Normal file
@@ -0,0 +1,44 @@
---
description: TODO list task implementation
globs:
alwaysApply: false
---
# Task List Management

Guidelines for managing task lists in markdown files to track progress on completing a PRD

## Task Implementation
- **One sub‑task at a time:** Do **NOT** start the next sub‑task until you ask the user for permission and they say “yes” or "y"
- **Completion protocol:**
  1. When you finish a **sub‑task**, immediately mark it as completed by changing `[ ]` to `[x]`.
  2. If **all** subtasks underneath a parent task are now `[x]`, also mark the **parent task** as completed.
- Stop after each sub‑task and wait for the user’s go‑ahead.
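
For example, a task list updated under this protocol might look like (illustrative task names only):

```
- [x] 1.0 Add RSI indicator
  - [x] 1.1 Implement calculation     <- marked [x] when finished
  - [x] 1.2 Add unit tests            <- last sub-task done, so 1.0 becomes [x]
- [ ] 2.0 Add Bollinger Bands
  - [ ] 2.1 Implement calculation     <- next sub-task; wait for "yes" before starting
```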

## Task List Maintenance

1. **Update the task list as you work:**
   - Mark tasks and subtasks as completed (`[x]`) per the protocol above.
   - Add new tasks as they emerge.

2. **Maintain the “Relevant Files” section:**
   - List every file created or modified.
   - Give each file a one‑line description of its purpose.

## AI Instructions

When working with task lists, the AI must:

1. Regularly update the task list file after finishing any significant work.
2. Follow the completion protocol:
   - Mark each finished **sub‑task** `[x]`.
   - Mark the **parent task** `[x]` once **all** its subtasks are `[x]`.
3. Add newly discovered tasks.
4. Keep “Relevant Files” accurate and up to date.
5. Before starting work, check which sub‑task is next.
6. After implementing a sub‑task, update the file and then pause for user approval.
351
.gitignore
vendored
351
.gitignore
vendored
@@ -1,170 +1,181 @@
# ---> Python
/data/*.db
/credentials/*.json
*.csv
*.png
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
/data/*.npy

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

An introduction to trading cycles.pdf
An introduction to trading cycles.txt
README.md
.vscode/launch.json
data/btcusd_1-day_data.csv
data/btcusd_1-min_data.csv
44
Dockerfile
Normal file
44
Dockerfile
Normal file
@@ -0,0 +1,44 @@
# Use the base image with CUDA and PyTorch
FROM kom4cr0/cuda11.7-pytorch1.13-mamba1.1.1:1.1.1

# Install NVIDIA Container Toolkit prerequisites (necessary for GPU support)
RUN apt-get update && apt-get install -y \
    nvidia-container-runtime \
    python3 \
    python3-pip \
    && rm -rf /var/lib/apt/lists/*

# Install necessary dependencies and configure the NVIDIA repository
RUN apt-get update && apt-get install -y \
    curl \
    gnupg \
    lsb-release \
    sudo \
    && curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \
    && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
    sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
    tee /etc/apt/sources.list.d/nvidia-container-toolkit.list \
    && sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list \
    && apt-get update

# Install NVIDIA Container Toolkit
RUN apt-get install -y nvidia-container-toolkit

# Set the environment variables for CUDA
ENV PATH=/usr/local/cuda-11.7/bin:$PATH
ENV LD_LIBRARY_PATH=/usr/local/cuda-11.7/lib64:$LD_LIBRARY_PATH

# Set the runtime for GPU (requires NVIDIA runtime to be installed on the host machine)
ENV NVIDIA_VISIBLE_DEVICES=all
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility

# Set working directory to /project
WORKDIR /project

# Install necessary Python dependencies
# Modify the next lines as per your project requirements
COPY requirements.txt requirements.txt
RUN pip3 install -r requirements.txt

# Run your Python script
CMD ["python3", "main.py"]
513
README.md
513
README.md
@@ -1 +1,512 @@
# Cycles - Cryptocurrency Trading Strategy Backtesting Framework

A comprehensive Python framework for backtesting cryptocurrency trading strategies using technical indicators, with advanced features like machine learning price prediction to eliminate lookahead bias.

## Table of Contents

- [Overview](#overview)
- [Features](#features)
- [Quick Start](#quick-start)
- [Project Structure](#project-structure)
- [Core Modules](#core-modules)
- [Configuration](#configuration)
- [Usage Examples](#usage-examples)
- [API Documentation](#api-documentation)
- [Testing](#testing)
- [Contributing](#contributing)
- [License](#license)

## Overview

Cycles is a sophisticated backtesting framework designed specifically for cryptocurrency trading strategies. It provides robust tools for:

- **Strategy Backtesting**: Test trading strategies across multiple timeframes with comprehensive metrics
- **Technical Analysis**: Built-in indicators including SuperTrend, RSI, Bollinger Bands, and more
- **Machine Learning Integration**: Eliminate lookahead bias using XGBoost price prediction
- **Multi-timeframe Analysis**: Support for various timeframes from 1-minute to daily data
- **Performance Analytics**: Detailed reporting with profit ratios, drawdowns, win rates, and fee calculations

### Key Goals

1. **Realistic Trading Simulation**: Eliminate common backtesting pitfalls like lookahead bias
2. **Modular Architecture**: Easy to extend with new indicators and strategies
3. **Performance Optimization**: Parallel processing for efficient large-scale backtesting
4. **Comprehensive Analysis**: Rich reporting and visualization capabilities

## Features

### 🚀 Core Features

- **Multi-Strategy Backtesting**: Test multiple trading strategies simultaneously
- **Advanced Stop Loss Management**: Precise stop-loss execution using 1-minute data
- **Fee Integration**: Realistic trading fee calculations (OKX exchange fees)
- **Parallel Processing**: Efficient multi-core backtesting execution
- **Rich Analytics**: Comprehensive performance metrics and reporting

### 📊 Technical Indicators

- **SuperTrend**: Multi-parameter SuperTrend indicator with meta-trend analysis
- **RSI**: Relative Strength Index with customizable periods
- **Bollinger Bands**: Configurable period and standard deviation multipliers
- **Extensible Framework**: Easy to add new technical indicators

### 🤖 Machine Learning

- **Price Prediction**: XGBoost-based closing price prediction
- **Lookahead Bias Elimination**: Realistic trading simulations
- **Feature Engineering**: Advanced technical feature extraction
- **Model Persistence**: Save and load trained models

### 📈 Data Management

- **Multiple Data Sources**: Support for various cryptocurrency exchanges
- **Flexible Timeframes**: 1-minute to daily data aggregation
- **Efficient Storage**: Optimized data loading and caching
- **Google Sheets Integration**: External data source connectivity

## Quick Start

### Prerequisites

- Python 3.10 or higher
- UV package manager (recommended)
- Git

### Installation

1. **Clone the repository**:
```bash
git clone <repository-url>
cd Cycles
```

2. **Install dependencies**:
```bash
uv sync
```

3. **Activate virtual environment**:
```bash
source .venv/bin/activate  # Linux/Mac
# or
.venv\Scripts\activate     # Windows
```

### Basic Usage

1. **Prepare your configuration file** (`config.json`):
```json
{
  "start_date": "2023-01-01",
  "stop_date": "2023-12-31",
  "initial_usd": 10000,
  "timeframes": ["5T", "15T", "1H", "4H"],
  "stop_loss_pcts": [0.02, 0.05, 0.10]
}
```

2. **Run a backtest**:
```bash
uv run python main.py --config config.json
```

3. **View results**:
Results will be saved in timestamped CSV files with comprehensive metrics.

## Project Structure

```
Cycles/
├── cycles/                  # Core library modules
│   ├── Analysis/            # Technical analysis indicators
│   │   ├── boillinger_band.py
│   │   ├── rsi.py
│   │   └── __init__.py
│   ├── utils/               # Utility modules
│   │   ├── storage.py       # Data storage and management
│   │   ├── system.py        # System utilities
│   │   ├── data_utils.py    # Data processing utilities
│   │   └── gsheets.py       # Google Sheets integration
│   ├── backtest.py          # Core backtesting engine
│   ├── supertrend.py        # SuperTrend indicator implementation
│   ├── charts.py            # Visualization utilities
│   ├── market_fees.py       # Trading fee calculations
│   └── __init__.py
├── docs/                    # Documentation
│   ├── analysis.md          # Analysis module documentation
│   ├── utils_storage.md     # Storage utilities documentation
│   └── utils_system.md      # System utilities documentation
├── data/                    # Data directory (not in repo)
├── results/                 # Backtest results (not in repo)
├── xgboost/                 # Machine learning components
├── OHLCVPredictor/          # Price prediction module
├── main.py                  # Main execution script
├── test_bbrsi.py            # Example strategy test
├── pyproject.toml           # Project configuration
├── requirements.txt         # Dependencies
├── uv.lock                  # UV lock file
└── README.md                # This file
```

## Core Modules

### Backtest Engine (`cycles/backtest.py`)

The heart of the framework, providing comprehensive backtesting capabilities:

```python
from cycles.backtest import Backtest

results = Backtest.run(
    min1_df=minute_data,
    df=timeframe_data,
    initial_usd=10000,
    stop_loss_pct=0.05,
    debug=False
)
```

**Key Features**:
- Meta-SuperTrend strategy implementation
- Precise stop-loss execution using 1-minute data
- Comprehensive trade logging and statistics
- Fee-aware profit calculations

### Technical Analysis (`cycles/Analysis/`)

Modular technical indicator implementations:

#### RSI (Relative Strength Index)
```python
from cycles.Analysis.rsi import RSI

rsi_calculator = RSI(period=14)
data_with_rsi = rsi_calculator.calculate(df, price_column='close')
```

#### Bollinger Bands
```python
from cycles.Analysis.boillinger_band import BollingerBands

bb = BollingerBands(period=20, std_dev_multiplier=2.0)
data_with_bb = bb.calculate(df)
```

### Data Management (`cycles/utils/storage.py`)

Efficient data loading, processing, and result storage:

```python
from cycles.utils.storage import Storage

storage = Storage(data_dir='./data', logging=logging)
data = storage.load_data('btcusd_1-min_data.csv', '2023-01-01', '2023-12-31')
```

## Configuration

### Backtest Configuration

Create a `config.json` file with the following structure:

```json
{
  "start_date": "2023-01-01",
  "stop_date": "2023-12-31",
  "initial_usd": 10000,
  "timeframes": ["1T", "5T", "15T", "1H", "4H", "1D"],
  "stop_loss_pcts": [0.02, 0.05, 0.10, 0.15]
}
```

Timeframe strings follow pandas offset aliases: `1T` = 1 minute, `5T` = 5 minutes, `15T` = 15 minutes, `1H` = 1 hour, `4H` = 4 hours, `1D` = 1 day.

### Environment Variables

Set the following environment variables for enhanced functionality:

```bash
# Google Sheets integration (optional)
export GOOGLE_SHEETS_CREDENTIALS_PATH="/path/to/credentials.json"

# Data directory (optional, defaults to ./data)
export DATA_DIR="/path/to/data"

# Results directory (optional, defaults to ./results)
export RESULTS_DIR="/path/to/results"
```
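
In Python these variables would typically be read with `os.environ`; a minimal sketch, assuming the documented defaults apply when a variable is unset (this snippet is illustrative, not the framework's actual code):

```python
import os

# Optional settings described above; defaults mirror the comments in the bash block.
credentials_path = os.environ.get("GOOGLE_SHEETS_CREDENTIALS_PATH")  # None if unset
data_dir = os.environ.get("DATA_DIR", "./data")
results_dir = os.environ.get("RESULTS_DIR", "./results")
```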

## Usage Examples

### Basic Backtest

```python
import json
from cycles.utils.storage import Storage
from cycles.backtest import Backtest

# Load configuration
with open('config.json', 'r') as f:
    config = json.load(f)

# Initialize storage
storage = Storage(data_dir='./data')

# Load data
data_1min = storage.load_data(
    'btcusd_1-min_data.csv',
    config['start_date'],
    config['stop_date']
)

# Run backtest
results = Backtest.run(
    min1_df=data_1min,
    df=data_1min,  # Same data for 1-minute strategy
    initial_usd=config['initial_usd'],
    stop_loss_pct=0.05,
    debug=True
)

print(f"Final USD: {results['final_usd']:.2f}")
print(f"Number of trades: {results['n_trades']}")
print(f"Win rate: {results['win_rate']:.2%}")
```

### Multi-Timeframe Analysis

```python
from main import process

# Define timeframes to test
timeframes = ['5T', '15T', '1H', '4H']
stop_loss_pcts = [0.02, 0.05, 0.10]

# Create tasks for parallel processing
tasks = [
    (timeframe, data_1min, stop_loss_pct, 10000)
    for timeframe in timeframes
    for stop_loss_pct in stop_loss_pcts
]

# Process each task
for task in tasks:
    results, trades = process(task, debug=False)
    print(f"Timeframe: {task[0]}, Stop Loss: {task[2]:.1%}")
    for result in results:
        print(f"  Final USD: {result['final_usd']:.2f}")
```

### Custom Strategy Development

```python
from cycles.Analysis.rsi import RSI
from cycles.Analysis.boillinger_band import BollingerBands

def custom_strategy(df):
    """Example custom trading strategy using RSI and Bollinger Bands"""

    # Calculate indicators
    rsi = RSI(period=14)
    bb = BollingerBands(period=20, std_dev_multiplier=2.0)

    df_with_rsi = rsi.calculate(df.copy())
    df_with_bb = bb.calculate(df_with_rsi)

    # Define signals
    buy_signals = (
        (df_with_bb['close'] < df_with_bb['LowerBand']) &
        (df_with_bb['RSI'] < 30)
    )

    sell_signals = (
        (df_with_bb['close'] > df_with_bb['UpperBand']) &
        (df_with_bb['RSI'] > 70)
    )

    return buy_signals, sell_signals
```

## API Documentation

### Core Classes

#### `Backtest`
Main backtesting engine with static methods for strategy execution.

**Methods**:
- `run(min1_df, df, initial_usd, stop_loss_pct, debug=False)`: Execute backtest
- `check_stop_loss(...)`: Check stop-loss conditions using 1-minute data
- `handle_entry(...)`: Process trade entry logic
- `handle_exit(...)`: Process trade exit logic

#### `Storage`
Data management and persistence utilities.

**Methods**:
- `load_data(filename, start_date, stop_date)`: Load and filter historical data
- `save_data(df, filename)`: Save processed data
- `write_backtest_results(...)`: Save backtest results to CSV

#### `SystemUtils`
System optimization and resource management.

**Methods**:
- `get_optimal_workers()`: Determine optimal number of parallel workers
- `get_memory_usage()`: Monitor memory consumption

### Configuration Parameters

| Parameter | Type | Description | Default |
|-----------|------|-------------|---------|
| `start_date` | string | Backtest start date (YYYY-MM-DD) | Required |
| `stop_date` | string | Backtest end date (YYYY-MM-DD) | Required |
| `initial_usd` | float | Starting capital in USD | Required |
| `timeframes` | array | List of timeframes to test | Required |
| `stop_loss_pcts` | array | Stop-loss percentages to test | Required |
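
Since all five parameters are required, a small guard when loading `config.json` can fail fast on typos; a hedged sketch, not part of the framework's API:

```python
import json

with open("config.json") as f:
    config = json.load(f)

# Keys taken from the parameter table above; all are marked Required.
required = ["start_date", "stop_date", "initial_usd", "timeframes", "stop_loss_pcts"]
missing = [key for key in required if key not in config]
if missing:
    raise KeyError(f"config.json is missing required keys: {missing}")
```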

## Testing

### Running Tests

```bash
# Run all tests
uv run pytest

# Run specific test file
uv run pytest test_bbrsi.py

# Run with verbose output
uv run pytest -v

# Run with coverage
uv run pytest --cov=cycles
```

### Test Structure

- `test_bbrsi.py`: Example strategy testing with RSI and Bollinger Bands
- Unit tests for individual modules (add as needed)
- Integration tests for complete workflows

### Example Test

```python
# test_bbrsi.py demonstrates strategy testing
from cycles.Analysis.rsi import RSI
from cycles.Analysis.boillinger_band import BollingerBands
from cycles.utils.storage import Storage

def test_strategy_signals():
    # Load test data
    storage = Storage()
    data = storage.load_data('test_data.csv', '2023-01-01', '2023-02-01')

    # Calculate indicators
    rsi = RSI(period=14)
    bb = BollingerBands(period=20)

    data_with_indicators = bb.calculate(rsi.calculate(data))

    # Test signal generation
    assert 'RSI' in data_with_indicators.columns
    assert 'UpperBand' in data_with_indicators.columns
    assert 'LowerBand' in data_with_indicators.columns
```

## Contributing

### Development Setup

1. Fork the repository
2. Create a feature branch: `git checkout -b feature/new-indicator`
3. Install development dependencies: `uv sync --dev`
4. Make your changes following the coding standards
5. Add tests for new functionality
6. Run tests: `uv run pytest`
7. Submit a pull request

### Coding Standards

- **Maximum file size**: 250 lines
- **Maximum function size**: 50 lines
- **Documentation**: All public functions must have docstrings
- **Type hints**: Use type hints for all function parameters and returns
- **Error handling**: Include proper error handling and meaningful error messages
- **No emoji**: Avoid emoji in code and comments

### Adding New Indicators

1. Create a new file in `cycles/Analysis/`
2. Follow the existing pattern (see `rsi.py` or `boillinger_band.py`)
3. Include comprehensive docstrings and type hints
4. Add tests for the new indicator
5. Update documentation
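
A hedged skeleton for step 2, modeled on the `calculate(df)` pattern that `RSI` and `BollingerBands` expose above; the `SMA` class itself is hypothetical, not part of the library:

```python
import pandas as pd

class SMA:
    """Simple Moving Average indicator (illustrative skeleton)."""

    def __init__(self, period: int = 20) -> None:
        if period < 1:
            raise ValueError("period must be >= 1")
        self.period = period

    def calculate(self, df: pd.DataFrame, price_column: str = "close") -> pd.DataFrame:
        """Return a copy of df with an added 'SMA' column."""
        result = df.copy()
        result["SMA"] = result[price_column].rolling(self.period).mean()
        return result
```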

## Performance Considerations

### Optimization Tips

1. **Parallel Processing**: Use the built-in parallel processing for multiple timeframes
2. **Data Caching**: Cache frequently used calculations
3. **Memory Management**: Monitor memory usage for large datasets
4. **Efficient Data Types**: Use appropriate pandas data types
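
For tip 4, downcasting the OHLCV columns is often the single biggest memory win; a sketch assuming the column names used throughout this README:

```python
import pandas as pd

df = pd.read_csv("data/btcusd_1-min_data.csv")
for col in ["open", "high", "low", "close", "volume"]:
    # float32 halves memory versus the default float64
    df[col] = pd.to_numeric(df[col], downcast="float")
print(f"{df.memory_usage(deep=True).sum() / 1e6:.1f} MB")
```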

### Benchmarks

Typical performance on modern hardware:
- **1-minute data**: ~1M candles processed in 2-3 minutes
- **Multiple timeframes**: 4 timeframes × 4 stop-loss values in 5-10 minutes
- **Memory usage**: ~2-4GB for 1 year of 1-minute BTC data

## Troubleshooting

### Common Issues

1. **Memory errors with large datasets**:
   - Reduce date range or use data chunking (see the sketch after this list)
   - Increase system RAM or use swap space

2. **Slow performance**:
   - Enable parallel processing
   - Reduce the number of timeframes/stop-loss values
   - Use SSD storage for data files

3. **Missing data errors**:
   - Verify data file format and column names
   - Check date range availability in data
   - Ensure proper data cleaning
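
A minimal sketch of the chunking suggestion in issue 1, streaming the CSV instead of loading it whole (the filter shown is only a placeholder):

```python
import pandas as pd

chunks = pd.read_csv("data/btcusd_1-min_data.csv", chunksize=100_000)
# Keep only the rows you need from each chunk before concatenating
data = pd.concat(chunk[chunk["close"] > 0] for chunk in chunks)
```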

### Debug Mode

Enable debug mode for detailed logging:

```python
# Set debug=True for detailed output
results = Backtest.run(..., debug=True)
```

## License

This project is licensed under the MIT License. See the LICENSE file for details.

## Changelog

### Version 0.1.0 (Current)
- Initial release
- Core backtesting framework
- SuperTrend strategy implementation
- Technical indicators (RSI, Bollinger Bands)
- Multi-timeframe analysis
- Machine learning price prediction
- Parallel processing support

---

For more detailed documentation, see the `docs/` directory or visit our [documentation website](link-to-docs).

**Support**: For questions or issues, please create an issue on GitHub or contact the development team.
462
backtest_runner.py
Normal file
462
backtest_runner.py
Normal file
@@ -0,0 +1,462 @@
import pandas as pd
import concurrent.futures
import logging
from typing import List, Tuple, Dict, Any, Optional

from cycles.utils.storage import Storage
from cycles.utils.system import SystemUtils
from cycles.utils.progress_manager import ProgressManager
from result_processor import ResultProcessor


def _process_single_task_static(task: Tuple[str, str, pd.DataFrame, float, float], progress_callback=None) -> Tuple[List[Dict], List[Dict]]:
    """
    Static version of _process_single_task for use with ProcessPoolExecutor

    Args:
        task: Tuple of (task_id, timeframe, data_1min, stop_loss_pct, initial_usd)
        progress_callback: Optional progress callback function

    Returns:
        Tuple of (results, trades)
    """
    task_id, timeframe, data_1min, stop_loss_pct, initial_usd = task

    try:
        if timeframe == "1T" or timeframe == "1min":
            df = data_1min.copy()
        else:
            df = _resample_data_static(data_1min, timeframe)

        # Create required components for processing
        from cycles.utils.storage import Storage
        from result_processor import ResultProcessor

        # Create storage with default paths (for subprocess)
        storage = Storage()
        result_processor = ResultProcessor(storage)

        results, trades = result_processor.process_timeframe_results(
            data_1min,
            df,
            [stop_loss_pct],
            timeframe,
            initial_usd,
            progress_callback=progress_callback
        )

        return results, trades

    except Exception as e:
        error_msg = f"Failed to process {timeframe} with stop loss {stop_loss_pct}: {e}"
        raise RuntimeError(error_msg) from e


def _resample_data_static(data_1min: pd.DataFrame, timeframe: str) -> pd.DataFrame:
    """
    Static function to resample 1-minute data to specified timeframe

    Args:
        data_1min: 1-minute data DataFrame
        timeframe: Target timeframe string

    Returns:
        Resampled DataFrame
    """
    try:
        agg_dict = {
            'open': 'first',
            'high': 'max',
            'low': 'min',
            'close': 'last',
            'volume': 'sum'
        }

        if 'predicted_close_price' in data_1min.columns:
            agg_dict['predicted_close_price'] = 'last'

        resampled = data_1min.resample(timeframe).agg(agg_dict).dropna()

        return resampled.reset_index()

    except Exception as e:
        error_msg = f"Failed to resample data to {timeframe}: {e}"
        raise ValueError(error_msg) from e


class BacktestRunner:
    """Handles the execution of backtests across multiple timeframes and parameters"""

    def __init__(
        self,
        storage: Storage,
        system_utils: SystemUtils,
        result_processor: ResultProcessor,
        logging_instance: Optional[logging.Logger] = None,
        show_progress: bool = True
    ):
        """
        Initialize backtest runner

        Args:
            storage: Storage instance for data operations
            system_utils: System utilities for resource management
            result_processor: Result processor for handling outputs
            logging_instance: Optional logging instance
            show_progress: Whether to show visual progress bars
        """
        self.storage = storage
        self.system_utils = system_utils
        self.result_processor = result_processor
        self.logging = logging_instance
        self.show_progress = show_progress
        self.progress_manager = ProgressManager() if show_progress else None

    def run_backtests(
        self,
        data_1min: pd.DataFrame,
        timeframes: List[str],
        stop_loss_pcts: List[float],
        initial_usd: float,
        debug: bool = False
    ) -> Tuple[List[Dict], List[Dict]]:
        """
        Run backtests across all timeframe and stop loss combinations

        Args:
            data_1min: 1-minute data DataFrame
            timeframes: List of timeframe strings (e.g., ['1D', '6h'])
            stop_loss_pcts: List of stop loss percentages
            initial_usd: Initial USD amount
            debug: Whether to enable debug mode

        Returns:
            Tuple of (all_results, all_trades)
        """
        # Create tasks for all combinations
        tasks = self._create_tasks(timeframes, stop_loss_pcts, data_1min, initial_usd)

        if self.logging:
            self.logging.info(f"Starting {len(tasks)} backtest tasks")

        if debug:
            return self._run_sequential(tasks)
        else:
            return self._run_parallel(tasks)

    def _create_tasks(
        self,
        timeframes: List[str],
        stop_loss_pcts: List[float],
        data_1min: pd.DataFrame,
        initial_usd: float
    ) -> List[Tuple]:
        """Create task tuples for processing"""
        tasks = []
        for timeframe in timeframes:
            for stop_loss_pct in stop_loss_pcts:
                task_id = f"{timeframe}_{stop_loss_pct}"
                task = (task_id, timeframe, data_1min, stop_loss_pct, initial_usd)
                tasks.append(task)
        return tasks

    def _run_sequential(self, tasks: List[Tuple]) -> Tuple[List[Dict], List[Dict]]:
        """Run tasks sequentially (for debug mode)"""
        # Initialize progress tracking if enabled
        if self.progress_manager:
            for task in tasks:
                task_id, timeframe, data_1min, stop_loss_pct, initial_usd = task

                # Calculate actual DataFrame size that will be processed
                if timeframe == "1T" or timeframe == "1min":
                    actual_df_size = len(data_1min)
                else:
                    # Get the actual resampled DataFrame size
                    temp_df = self._resample_data(data_1min, timeframe)
                    actual_df_size = len(temp_df)

                task_name = f"{timeframe} SL:{stop_loss_pct:.0%}"
                self.progress_manager.start_task(task_id, task_name, actual_df_size)

            self.progress_manager.start_display()

        all_results = []
        all_trades = []

        try:
            for task in tasks:
                try:
                    # Get progress callback for this task if available
                    progress_callback = None
                    if self.progress_manager:
                        progress_callback = self.progress_manager.get_task_progress_callback(task[0])

                    results, trades = self._process_single_task(task, progress_callback)

                    if results:
                        all_results.extend(results)
                    if trades:
                        all_trades.extend(trades)

                    # Mark task as completed
                    if self.progress_manager:
                        self.progress_manager.complete_task(task[0])

                except Exception as e:
                    error_msg = f"Error processing task {task[1]} with stop loss {task[3]}: {e}"

                    if self.logging:
                        self.logging.error(error_msg)

                    raise RuntimeError(error_msg) from e
        finally:
            # Stop progress display
            if self.progress_manager:
                self.progress_manager.stop_display()

        return all_results, all_trades

    def _run_parallel(self, tasks: List[Tuple]) -> Tuple[List[Dict], List[Dict]]:
        """Run tasks in parallel using ProcessPoolExecutor"""
        workers = self.system_utils.get_optimal_workers()

        if self.logging:
            self.logging.info(f"Running {len(tasks)} tasks with {workers} workers")

        # OPTIMIZATION: Disable progress manager for parallel execution to reduce overhead
        # Progress tracking adds significant overhead in multiprocessing
        if self.progress_manager and self.logging:
            self.logging.info("Progress tracking disabled for parallel execution (performance optimization)")

        all_results = []
        all_trades = []
        completed_tasks = 0

        try:
            with concurrent.futures.ProcessPoolExecutor(max_workers=workers) as executor:
                future_to_task = {
                    executor.submit(_process_single_task_static, task): task
                    for task in tasks
                }

                for future in concurrent.futures.as_completed(future_to_task):
                    task = future_to_task[future]
                    try:
                        results, trades = future.result()
                        if results:
                            all_results.extend(results)
                        if trades:
                            all_trades.extend(trades)

                        completed_tasks += 1

                        if self.logging:
                            self.logging.info(f"Completed task {task[0]} ({completed_tasks}/{len(tasks)})")

                    except Exception as e:
                        error_msg = f"Task {task[1]} with stop loss {task[3]} failed: {e}"
                        if self.logging:
                            self.logging.error(error_msg)
                        raise RuntimeError(error_msg) from e

        except Exception as e:
            error_msg = f"Parallel execution failed: {e}"
            if self.logging:
                self.logging.error(error_msg)
            raise RuntimeError(error_msg) from e

        finally:
            # Stop progress display
            if self.progress_manager:
                self.progress_manager.stop_display()
|
if self.logging:
|
||||||
|
self.logging.info(f"All {len(tasks)} tasks completed successfully")
|
||||||
|
|
||||||
|
return all_results, all_trades
|
||||||
|
|
||||||
|
def _process_single_task(
|
||||||
|
self,
|
||||||
|
task: Tuple[str, str, pd.DataFrame, float, float],
|
||||||
|
progress_callback=None
|
||||||
|
) -> Tuple[List[Dict], List[Dict]]:
|
||||||
|
"""
|
||||||
|
Process a single backtest task
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: Tuple of (task_id, timeframe, data_1min, stop_loss_pct, initial_usd)
|
||||||
|
progress_callback: Optional progress callback function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (results, trades)
|
||||||
|
"""
|
||||||
|
task_id, timeframe, data_1min, stop_loss_pct, initial_usd = task
|
||||||
|
|
||||||
|
try:
|
||||||
|
if timeframe == "1T" or timeframe == "1min":
|
||||||
|
df = data_1min.copy()
|
||||||
|
else:
|
||||||
|
df = self._resample_data(data_1min, timeframe)
|
||||||
|
|
||||||
|
results, trades = self.result_processor.process_timeframe_results(
|
||||||
|
data_1min,
|
||||||
|
df,
|
||||||
|
[stop_loss_pct],
|
||||||
|
timeframe,
|
||||||
|
initial_usd,
|
||||||
|
progress_callback=progress_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# OPTIMIZATION: Skip individual trade file saving during parallel execution
|
||||||
|
# Trade files will be saved in batch at the end
|
||||||
|
# if trades:
|
||||||
|
# self.result_processor.save_trade_file(trades, timeframe, stop_loss_pct)
|
||||||
|
|
||||||
|
if self.logging:
|
||||||
|
self.logging.info(f"Completed task {task_id}: {len(results)} results, {len(trades)} trades")
|
||||||
|
|
||||||
|
return results, trades
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Failed to process {timeframe} with stop loss {stop_loss_pct}: {e}"
|
||||||
|
if self.logging:
|
||||||
|
self.logging.error(error_msg)
|
||||||
|
raise RuntimeError(error_msg) from e
|
||||||
|
|
||||||
|
def _resample_data(self, data_1min: pd.DataFrame, timeframe: str) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Resample 1-minute data to specified timeframe
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_1min: 1-minute data DataFrame
|
||||||
|
timeframe: Target timeframe string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Resampled DataFrame
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
agg_dict = {
|
||||||
|
'open': 'first',
|
||||||
|
'high': 'max',
|
||||||
|
'low': 'min',
|
||||||
|
'close': 'last',
|
||||||
|
'volume': 'sum'
|
||||||
|
}
|
||||||
|
|
||||||
|
if 'predicted_close_price' in data_1min.columns:
|
||||||
|
agg_dict['predicted_close_price'] = 'last'
|
||||||
|
|
||||||
|
resampled = data_1min.resample(timeframe).agg(agg_dict).dropna()
|
||||||
|
|
||||||
|
return resampled.reset_index()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Failed to resample data to {timeframe}: {e}"
|
||||||
|
if self.logging:
|
||||||
|
self.logging.error(error_msg)
|
||||||
|
raise ValueError(error_msg) from e
|
||||||
|
|
||||||
|
def _get_timeframe_factor(self, timeframe: str) -> int:
|
||||||
|
"""
|
||||||
|
Get the factor by which data is reduced when resampling to timeframe
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeframe: Target timeframe string (e.g., '1h', '4h', '1D')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Factor for estimating data size after resampling
|
||||||
|
"""
|
||||||
|
timeframe_factors = {
|
||||||
|
'1T': 1, '1min': 1,
|
||||||
|
'5T': 5, '5min': 5,
|
||||||
|
'15T': 15, '15min': 15,
|
||||||
|
'30T': 30, '30min': 30,
|
||||||
|
'1h': 60, '1H': 60,
|
||||||
|
'2h': 120, '2H': 120,
|
||||||
|
'4h': 240, '4H': 240,
|
||||||
|
'6h': 360, '6H': 360,
|
||||||
|
'8h': 480, '8H': 480,
|
||||||
|
'12h': 720, '12H': 720,
|
||||||
|
'1D': 1440, '1d': 1440,
|
||||||
|
'2D': 2880, '2d': 2880,
|
||||||
|
'3D': 4320, '3d': 4320,
|
||||||
|
'1W': 10080, '1w': 10080
|
||||||
|
}
|
||||||
|
return timeframe_factors.get(timeframe, 60) # Default to 1 hour if unknown
|
||||||
|
|
||||||
|
def load_data(self, filename: str, start_date: str, stop_date: str) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Load and validate data for backtesting
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Name of data file
|
||||||
|
start_date: Start date string
|
||||||
|
stop_date: Stop date string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loaded and validated DataFrame
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If data is empty or invalid
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = self.storage.load_data(filename, start_date, stop_date)
|
||||||
|
|
||||||
|
if data.empty:
|
||||||
|
raise ValueError(f"No data loaded for period {start_date} to {stop_date}")
|
||||||
|
|
||||||
|
required_columns = ['open', 'high', 'low', 'close', 'volume']
|
||||||
|
|
||||||
|
if 'predicted_close_price' in data.columns:
|
||||||
|
required_columns.append('predicted_close_price')
|
||||||
|
|
||||||
|
missing_columns = [col for col in required_columns if col not in data.columns]
|
||||||
|
|
||||||
|
if missing_columns:
|
||||||
|
raise ValueError(f"Missing required columns: {missing_columns}")
|
||||||
|
|
||||||
|
if self.logging:
|
||||||
|
self.logging.info(f"Loaded {len(data)} rows of data from {filename}")
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Failed to load data from {filename}: {e}"
|
||||||
|
if self.logging:
|
||||||
|
self.logging.error(error_msg)
|
||||||
|
raise RuntimeError(error_msg) from e
|
||||||
|
|
||||||
|
def validate_inputs(
|
||||||
|
self,
|
||||||
|
timeframes: List[str],
|
||||||
|
stop_loss_pcts: List[float],
|
||||||
|
initial_usd: float
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Validate backtest input parameters
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeframes: List of timeframe strings
|
||||||
|
stop_loss_pcts: List of stop loss percentages
|
||||||
|
initial_usd: Initial USD amount
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If any input is invalid
|
||||||
|
"""
|
||||||
|
if not timeframes:
|
||||||
|
raise ValueError("At least one timeframe must be specified")
|
||||||
|
|
||||||
|
if not stop_loss_pcts:
|
||||||
|
raise ValueError("At least one stop loss percentage must be specified")
|
||||||
|
|
||||||
|
for pct in stop_loss_pcts:
|
||||||
|
if not 0 < pct < 1:
|
||||||
|
raise ValueError(f"Stop loss percentage must be between 0 and 1, got: {pct}")
|
||||||
|
|
||||||
|
if initial_usd <= 0:
|
||||||
|
raise ValueError("Initial USD must be positive")
|
||||||
|
|
||||||
|
if self.logging:
|
||||||
|
self.logging.info("Input validation completed successfully")
|
||||||
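A minimal usage sketch for BacktestRunner, assuming Storage, SystemUtils, and ResultProcessor instances are constructed elsewhere in the project; the data file name and parameter values here are illustrative only:

    runner = BacktestRunner(storage, system_utils, result_processor, logging_instance=logger)
    runner.validate_inputs(["1h", "4h"], [0.02, 0.05], 10000)
    data_1min = runner.load_data("btc_usdt_1min.csv", "2023-01-01", "2024-01-01")  # hypothetical file
    results, trades = runner.run_backtests(data_1min, ["1h", "4h"], [0.02, 0.05], 10000, debug=False)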
175  config_manager.py  Normal file
@@ -0,0 +1,175 @@
import json
import datetime
import logging
from typing import Dict, List, Optional, Any
from pathlib import Path


class ConfigManager:
    """Manages configuration loading, validation, and default values for backtest operations"""

    DEFAULT_CONFIG = {
        "start_date": "2025-05-01",
        "stop_date": datetime.datetime.today().strftime('%Y-%m-%d'),
        "initial_usd": 10000,
        "timeframes": ["1D", "6h", "3h", "1h", "30m", "15m", "5m", "1m"],
        "stop_loss_pcts": [0.01, 0.02, 0.03, 0.05],
        "data_dir": "../data",
        "results_dir": "results"
    }

    def __init__(self, logging_instance: Optional[logging.Logger] = None):
        """
        Initialize configuration manager

        Args:
            logging_instance: Optional logging instance for output
        """
        self.logging = logging_instance
        self.config = {}

    def load_config(self, config_path: Optional[str] = None) -> Dict[str, Any]:
        """
        Load configuration from file or interactive input

        Args:
            config_path: Path to JSON config file, if None prompts for interactive input

        Returns:
            Dictionary containing validated configuration

        Raises:
            FileNotFoundError: If config file doesn't exist
            json.JSONDecodeError: If config file has invalid JSON
            ValueError: If configuration values are invalid
        """
        if config_path:
            self.config = self._load_from_file(config_path)
        else:
            self.config = self._load_interactive()

        self._validate_config()
        return self.config

    def _load_from_file(self, config_path: str) -> Dict[str, Any]:
        """Load configuration from JSON file"""
        try:
            config_file = Path(config_path)
            if not config_file.exists():
                raise FileNotFoundError(f"Configuration file not found: {config_path}")

            with open(config_file, 'r') as f:
                config = json.load(f)

            if self.logging:
                self.logging.info(f"Configuration loaded from {config_path}")

            return config

        except json.JSONDecodeError as e:
            error_msg = f"Invalid JSON in configuration file {config_path}: {e}"
            if self.logging:
                self.logging.error(error_msg)
            raise json.JSONDecodeError(error_msg, e.doc, e.pos)

    def _load_interactive(self) -> Dict[str, Any]:
        """Load configuration through interactive prompts"""
        print("No config file provided. Please enter the following values (press Enter to use default):")

        config = {}

        # Start date
        start_date = input(f"Start date [{self.DEFAULT_CONFIG['start_date']}]: ") or self.DEFAULT_CONFIG['start_date']
        config['start_date'] = start_date

        # Stop date
        stop_date = input(f"Stop date [{self.DEFAULT_CONFIG['stop_date']}]: ") or self.DEFAULT_CONFIG['stop_date']
        config['stop_date'] = stop_date

        # Initial USD
        initial_usd_str = input(f"Initial USD [{self.DEFAULT_CONFIG['initial_usd']}]: ") or str(self.DEFAULT_CONFIG['initial_usd'])
        try:
            config['initial_usd'] = float(initial_usd_str)
        except ValueError:
            raise ValueError(f"Invalid initial USD value: {initial_usd_str}")

        # Timeframes
        timeframes_str = input(f"Timeframes (comma separated) [{', '.join(self.DEFAULT_CONFIG['timeframes'])}]: ") or ','.join(self.DEFAULT_CONFIG['timeframes'])
        config['timeframes'] = [tf.strip() for tf in timeframes_str.split(',') if tf.strip()]

        # Stop loss percentages
        stop_loss_pcts_str = input(f"Stop loss pcts (comma separated) [{', '.join(str(x) for x in self.DEFAULT_CONFIG['stop_loss_pcts'])}]: ") or ','.join(str(x) for x in self.DEFAULT_CONFIG['stop_loss_pcts'])
        try:
            config['stop_loss_pcts'] = [float(x.strip()) for x in stop_loss_pcts_str.split(',') if x.strip()]
        except ValueError:
            raise ValueError(f"Invalid stop loss percentages: {stop_loss_pcts_str}")

        # Add default directories
        config['data_dir'] = self.DEFAULT_CONFIG['data_dir']
        config['results_dir'] = self.DEFAULT_CONFIG['results_dir']

        return config

    def _validate_config(self) -> None:
        """
        Validate configuration values

        Raises:
            ValueError: If any configuration value is invalid
        """
        # Validate initial USD
        if self.config.get('initial_usd', 0) <= 0:
            raise ValueError("Initial USD must be positive")

        # Validate stop loss percentages
        stop_loss_pcts = self.config.get('stop_loss_pcts', [])
        for pct in stop_loss_pcts:
            if not 0 < pct < 1:
                raise ValueError(f"Stop loss percentage must be between 0 and 1, got: {pct}")

        # Validate dates
        try:
            datetime.datetime.strptime(self.config['start_date'], '%Y-%m-%d')
            datetime.datetime.strptime(self.config['stop_date'], '%Y-%m-%d')
        except ValueError as e:
            raise ValueError(f"Invalid date format (should be YYYY-MM-DD): {e}")

        # Validate timeframes
        timeframes = self.config.get('timeframes', [])
        if not timeframes:
            raise ValueError("At least one timeframe must be specified")

        # Validate directories exist or can be created
        for dir_key in ['data_dir', 'results_dir']:
            dir_path = Path(self.config.get(dir_key, ''))
            try:
                dir_path.mkdir(parents=True, exist_ok=True)
            except Exception as e:
                raise ValueError(f"Cannot create directory {dir_path}: {e}")

        if self.logging:
            self.logging.info("Configuration validation completed successfully")

    def get_config(self) -> Dict[str, Any]:
        """Return the current configuration"""
        return self.config.copy()

    def save_config(self, output_path: str) -> None:
        """
        Save current configuration to file

        Args:
            output_path: Path where to save the configuration
        """
        try:
            with open(output_path, 'w') as f:
                json.dump(self.config, f, indent=2)

            if self.logging:
                self.logging.info(f"Configuration saved to {output_path}")

        except Exception as e:
            error_msg = f"Failed to save configuration to {output_path}: {e}"
            if self.logging:
                self.logging.error(error_msg)
            raise
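A hedged usage sketch for ConfigManager; the logger name is illustrative, and the config path points at configs/sample_config.json added later in this changeset:

    import logging

    logger = logging.getLogger("backtest")
    manager = ConfigManager(logging_instance=logger)
    config = manager.load_config("configs/sample_config.json")
    print(config["timeframes"], config["stop_loss_pcts"])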
10  configs/flat_2021_2024_config.json  Normal file
@@ -0,0 +1,10 @@
{
    "start_date": "2021-11-01",
    "stop_date": "2024-04-01",
    "initial_usd": 10000,
    "timeframes": ["1min", "2min", "3min", "4min", "5min", "10min", "15min", "30min", "1h", "2h", "4h", "6h", "8h", "12h", "1d"],
    "stop_loss_pcts": [0.01, 0.02, 0.03, 0.04, 0.05, 0.1],
    "data_dir": "../data",
    "results_dir": "../results",
    "debug": 0
}
10  configs/full_config.json  Normal file
@@ -0,0 +1,10 @@
{
    "start_date": "2020-01-01",
    "stop_date": "2025-07-08",
    "initial_usd": 10000,
    "timeframes": ["1h", "4h", "15ME", "5ME", "1ME"],
    "stop_loss_pcts": [0.01, 0.02, 0.03, 0.05],
    "data_dir": "../data",
    "results_dir": "../results",
    "debug": 1
}
10  configs/sample_config.json  Normal file
@@ -0,0 +1,10 @@
{
    "start_date": "2023-01-01",
    "stop_date": "2025-01-15",
    "initial_usd": 10000,
    "timeframes": ["4h"],
    "stop_loss_pcts": [0.05],
    "data_dir": "../data",
    "results_dir": "../results",
    "debug": 0
}
@@ -1,248 +0,0 @@
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import argrelextrema


class CycleDetector:
    def __init__(self, data, timeframe='daily'):
        """
        Initialize the CycleDetector with price data.

        Parameters:
        - data: DataFrame with at least 'date' or 'datetime' and 'close' columns
        - timeframe: 'daily', 'weekly', or 'monthly'
        """
        self.data = data.copy()
        self.timeframe = timeframe

        # Ensure we have a consistent date column name
        if 'datetime' in self.data.columns and 'date' not in self.data.columns:
            self.data.rename(columns={'datetime': 'date'}, inplace=True)

        # Convert data to specified timeframe if needed
        if timeframe == 'weekly' and 'date' in self.data.columns:
            self.data = self._convert_data(self.data, 'W')
        elif timeframe == 'monthly' and 'date' in self.data.columns:
            self.data = self._convert_data(self.data, 'M')

        # Add columns for local minima and maxima detection
        self._add_swing_points()

    def _convert_data(self, data, timeframe):
        """Resample daily data to the given timeframe."""
        data['date'] = pd.to_datetime(data['date'])
        data.set_index('date', inplace=True)
        resampled = data.resample(timeframe).agg({
            'open': 'first',
            'high': 'max',
            'low': 'min',
            'close': 'last',
            'volume': 'sum'
        })
        return resampled.reset_index()

    def _add_swing_points(self, window=5):
        """
        Identify swing points (local minima and maxima).

        Parameters:
        - window: The window size for local minima/maxima detection
        """
        # Set the index to make calculations easier
        if 'date' in self.data.columns:
            self.data.set_index('date', inplace=True)

        # Detect local minima (swing lows)
        min_idx = argrelextrema(self.data['low'].values, np.less, order=window)[0]
        self.data['swing_low'] = False
        self.data.iloc[min_idx, self.data.columns.get_loc('swing_low')] = True

        # Detect local maxima (swing highs)
        max_idx = argrelextrema(self.data['high'].values, np.greater, order=window)[0]
        self.data['swing_high'] = False
        self.data.iloc[max_idx, self.data.columns.get_loc('swing_high')] = True

        # Reset index
        self.data.reset_index(inplace=True)

    def find_cycle_lows(self):
        """Find all swing lows which represent cycle lows."""
        swing_low_dates = self.data[self.data['swing_low']]['date'].values
        return swing_low_dates

    def calculate_cycle_lengths(self):
        """Calculate the lengths of each cycle between consecutive lows."""
        swing_low_indices = np.where(self.data['swing_low'])[0]
        cycle_lengths = np.diff(swing_low_indices)
        return cycle_lengths

    def get_average_cycle_length(self):
        """Calculate the average cycle length."""
        cycle_lengths = self.calculate_cycle_lengths()
        if len(cycle_lengths) > 0:
            return np.mean(cycle_lengths)
        return None

    def get_cycle_window(self, tolerance=0.10):
        """
        Get the cycle window with the specified tolerance.

        Parameters:
        - tolerance: The tolerance as a percentage (default: 10%)

        Returns:
        - tuple: (min_cycle_length, avg_cycle_length, max_cycle_length)
        """
        avg_length = self.get_average_cycle_length()
        if avg_length is not None:
            min_length = avg_length * (1 - tolerance)
            max_length = avg_length * (1 + tolerance)
            return (min_length, avg_length, max_length)
        return None

    def detect_two_drives_pattern(self, lookback=10):
        """
        Detect 2-drives pattern: a swing low, a counter-trend bounce, and a lower low.

        Parameters:
        - lookback: Number of periods to look back

        Returns:
        - list: Indices where 2-drives patterns are detected
        """
        patterns = []

        for i in range(lookback, len(self.data) - 1):
            if not self.data.iloc[i]['swing_low']:
                continue

            # Get the segment of data to check for pattern
            segment = self.data.iloc[i-lookback:i+1]
            swing_lows = segment[segment['swing_low']]['low'].values

            if len(swing_lows) >= 2 and swing_lows[-1] < swing_lows[-2]:
                # Check if there was a bounce between the two lows
                between_lows = segment.iloc[-len(swing_lows):-1]
                if len(between_lows) > 0 and max(between_lows['high']) > swing_lows[-2]:
                    patterns.append(i)

        return patterns

    def detect_v_shaped_lows(self, window=5, threshold=0.02):
        """
        Detect V-shaped cycle lows (sharp decline followed by sharp rise).

        Parameters:
        - window: Window to look for sharp price changes
        - threshold: Percentage change threshold to consider 'sharp'

        Returns:
        - list: Indices where V-shaped patterns are detected
        """
        patterns = []

        # Find all swing lows
        swing_low_indices = np.where(self.data['swing_low'])[0]

        for idx in swing_low_indices:
            # Need enough data points before and after
            if idx < window or idx + window >= len(self.data):
                continue

            # Get the low price at this swing low
            low_price = self.data.iloc[idx]['low']

            # Check for sharp decline before low (at least window bars before)
            before_segment = self.data.iloc[max(0, idx-window):idx]
            if len(before_segment) > 0:
                max_before = before_segment['high'].max()
                decline = (max_before - low_price) / max_before

            # Check for sharp rise after low (at least window bars after)
            after_segment = self.data.iloc[idx+1:min(len(self.data), idx+window+1)]
            if len(after_segment) > 0:
                max_after = after_segment['high'].max()
                rise = (max_after - low_price) / low_price

            # Both decline and rise must exceed threshold to be considered V-shaped
            if decline > threshold and rise > threshold:
                patterns.append(idx)

        return patterns

    def plot_cycles(self, pattern_detection=None, title_suffix=''):
        """
        Plot the price data with cycle lows and detected patterns.

        Parameters:
        - pattern_detection: 'two_drives', 'v_shape', or None
        - title_suffix: Optional suffix for the plot title
        """
        plt.figure(figsize=(14, 7))

        # Determine the date column name (could be 'date' or 'datetime')
        date_col = 'date' if 'date' in self.data.columns else 'datetime'

        # Plot price data
        plt.plot(self.data[date_col], self.data['close'], label='Close Price')

        # Calculate a consistent vertical position for indicators based on price range
        price_range = self.data['close'].max() - self.data['close'].min()
        indicator_offset = price_range * 0.01  # 1% of price range

        # Plot cycle lows (at a fixed offset below the low price)
        swing_lows = self.data[self.data['swing_low']]
        plt.scatter(swing_lows[date_col], swing_lows['low'] - indicator_offset,
                    color='green', marker='^', s=100, label='Cycle Lows')

        # Plot specific patterns if requested (guard against the default None)
        if pattern_detection and 'two_drives' in pattern_detection:
            pattern_indices = self.detect_two_drives_pattern()
            if pattern_indices:
                patterns = self.data.iloc[pattern_indices]
                plt.scatter(patterns[date_col], patterns['low'] - indicator_offset * 2,
                            color='red', marker='o', s=150, label='Two Drives Pattern')

        elif pattern_detection and 'v_shape' in pattern_detection:
            pattern_indices = self.detect_v_shaped_lows()
            if pattern_indices:
                patterns = self.data.iloc[pattern_indices]
                plt.scatter(patterns[date_col], patterns['low'] - indicator_offset * 2,
                            color='purple', marker='o', s=150, label='V-Shape Pattern')

        # Add cycle lengths and averages
        avg_cycle = self.get_average_cycle_length()
        cycle_window = self.get_cycle_window()

        window_text = ""
        if cycle_window:
            window_text = f"Tolerance Window: [{cycle_window[0]:.2f} - {cycle_window[2]:.2f}]"

        # Guard against avg_cycle being None (fewer than two swing lows)
        avg_text = f"Average Cycle Length: {avg_cycle:.2f} periods" if avg_cycle is not None else "Average Cycle Length: n/a"
        plt.title(f"Detected Cycles - {self.timeframe.capitalize()} Timeframe {title_suffix}\n"
                  f"{avg_text}, {window_text}")

        plt.legend()
        plt.grid(True)
        plt.show()

# Usage example:
# 1. Load your data
# data = pd.read_csv('your_price_data.csv')

# 2. Create cycle detector instances for different timeframes
# weekly_detector = CycleDetector(data, timeframe='weekly')
# daily_detector = CycleDetector(data, timeframe='daily')

# 3. Analyze cycles
# weekly_cycle_length = weekly_detector.get_average_cycle_length()
# daily_cycle_length = daily_detector.get_average_cycle_length()

# 4. Detect patterns
# two_drives = weekly_detector.detect_two_drives_pattern()
# v_shapes = daily_detector.detect_v_shaped_lows()

# 5. Visualize
# weekly_detector.plot_cycles(pattern_detection='two_drives')
# daily_detector.plot_cycles(pattern_detection='v_shape')
0  cycles/Analysis/__init__.py  Normal file
50  cycles/Analysis/boillinger_band.py  Normal file
@@ -0,0 +1,50 @@
import pandas as pd


class BollingerBands:
    """
    Calculates Bollinger Bands for given financial data.
    """
    def __init__(self, period: int = 20, std_dev_multiplier: float = 2.0):
        """
        Initializes the BollingerBands calculator.

        Args:
            period (int): The period for the moving average and standard deviation.
            std_dev_multiplier (float): The number of standard deviations for the upper and lower bands.
        """
        if period <= 0:
            raise ValueError("Period must be a positive integer.")
        if std_dev_multiplier <= 0:
            raise ValueError("Standard deviation multiplier must be positive.")

        self.period = period
        self.std_dev_multiplier = std_dev_multiplier

    def calculate(self, data_df: pd.DataFrame, price_column: str = 'close') -> pd.DataFrame:
        """
        Calculates Bollinger Bands and adds them to the DataFrame.

        Args:
            data_df (pd.DataFrame): DataFrame with price data. Must include the price_column.
            price_column (str): The name of the column containing the price data (e.g., 'close').

        Returns:
            pd.DataFrame: The original DataFrame with added columns:
                'SMA' (Simple Moving Average),
                'UpperBand',
                'LowerBand'.
        """
        if price_column not in data_df.columns:
            raise ValueError(f"Price column '{price_column}' not found in DataFrame.")

        # Calculate SMA
        data_df['SMA'] = data_df[price_column].rolling(window=self.period).mean()

        # Calculate Standard Deviation
        std_dev = data_df[price_column].rolling(window=self.period).std()

        # Calculate Upper and Lower Bands
        data_df['UpperBand'] = data_df['SMA'] + (self.std_dev_multiplier * std_dev)
        data_df['LowerBand'] = data_df['SMA'] - (self.std_dev_multiplier * std_dev)

        return data_df
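A short usage sketch for BollingerBands; the toy price series is made up, and any DataFrame containing the chosen price column works the same way:

    import pandas as pd
    from cycles.Analysis.boillinger_band import BollingerBands

    df = pd.DataFrame({"close": [float(x) for x in range(1, 61)]})  # 60-point toy series
    bb = BollingerBands(period=20, std_dev_multiplier=2.0)
    df = bb.calculate(df, price_column="close")
    print(df[["close", "SMA", "UpperBand", "LowerBand"]].tail())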
109  cycles/Analysis/rsi.py  Normal file
@@ -0,0 +1,109 @@
import pandas as pd
import numpy as np


class RSI:
    """
    A class to calculate the Relative Strength Index (RSI).
    """
    def __init__(self, period: int = 14):
        """
        Initializes the RSI calculator.

        Args:
            period (int): The period for RSI calculation. Default is 14.
                          Must be a positive integer.
        """
        if not isinstance(period, int) or period <= 0:
            raise ValueError("Period must be a positive integer.")
        self.period = period

    def calculate(self, data_df: pd.DataFrame, price_column: str = 'close') -> pd.DataFrame:
        """
        Calculates the RSI and adds it as a column to the input DataFrame.

        Args:
            data_df (pd.DataFrame): DataFrame with historical price data.
                                    Must contain the 'price_column'.
            price_column (str): The name of the column containing price data.
                                Default is 'close'.

        Returns:
            pd.DataFrame: The input DataFrame with an added 'RSI' column.
                          Returns the original DataFrame with no 'RSI' column
                          if the period is larger than the number of data points.
        """
        if price_column not in data_df.columns:
            raise ValueError(f"Price column '{price_column}' not found in DataFrame.")

        if len(data_df) < self.period:
            print(f"Warning: Data length ({len(data_df)}) is less than RSI period ({self.period}). RSI will not be calculated.")
            return data_df.copy()

        df = data_df.copy()
        delta = df[price_column].diff(1)

        gain = delta.where(delta > 0, 0)
        loss = -delta.where(delta < 0, 0)  # Ensure loss is positive

        # Calculate initial average gain and loss (SMA)
        avg_gain = gain.rolling(window=self.period, min_periods=self.period).mean().iloc[self.period - 1:self.period]
        avg_loss = loss.rolling(window=self.period, min_periods=self.period).mean().iloc[self.period - 1:self.period]

        # Calculate subsequent average gains and losses (EMA-like Wilder smoothing)
        # Pre-allocate lists for gains and losses to avoid repeated appending to Series
        gains = [0.0] * len(df)
        losses = [0.0] * len(df)

        if not avg_gain.empty:
            gains[self.period - 1] = avg_gain.iloc[0]
        if not avg_loss.empty:
            losses[self.period - 1] = avg_loss.iloc[0]

        for i in range(self.period, len(df)):
            gains[i] = ((gains[i-1] * (self.period - 1)) + gain.iloc[i]) / self.period
            losses[i] = ((losses[i-1] * (self.period - 1)) + loss.iloc[i]) / self.period

        df['avg_gain'] = pd.Series(gains, index=df.index)
        df['avg_loss'] = pd.Series(losses, index=df.index)

        # Calculate RSI = 100 - (100 / (1 + RS)) with RS = avg_gain / avg_loss,
        # handling the avg_loss == 0 cases explicitly to avoid NaN from a 0/0 ratio:
        #   avg_loss == 0 and avg_gain > 0  -> RSI = 100 (maximum strength)
        #   avg_loss == 0 and avg_gain == 0 -> RSI = 50 (neutral, by convention)
        rsi_values = []
        for i in range(len(df)):
            avg_g = df['avg_gain'].iloc[i]
            avg_l = df['avg_loss'].iloc[i]

            if i < self.period - 1:  # Not enough data for initial SMA
                rsi_values.append(np.nan)
                continue

            if avg_l == 0:
                if avg_g == 0:
                    rsi_values.append(50)  # Neutral when there is neither gain nor loss
                else:
                    rsi_values.append(100)  # Max strength
            else:
                rs_val = avg_g / avg_l
                rsi_values.append(100 - (100 / (1 + rs_val)))

        df['RSI'] = pd.Series(rsi_values, index=df.index)

        # Remove intermediate columns if desired, or keep them for debugging
        # df.drop(columns=['avg_gain', 'avg_loss'], inplace=True)

        return df
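A hedged usage sketch for RSI; the random-walk series is fabricated test data, and note that calculate() works on a copy, so the return value must be kept:

    import numpy as np
    import pandas as pd
    from cycles.Analysis.rsi import RSI

    rng = np.random.default_rng(0)
    prices = pd.DataFrame({"close": 100 + np.cumsum(rng.normal(0, 1, 100))})
    out = RSI(period=14).calculate(prices)
    print(out["RSI"].tail())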
0  cycles/__init__.py  Normal file
332  cycles/backtest.py  Normal file
@@ -0,0 +1,332 @@
import pandas as pd
import numpy as np
import time

from cycles.supertrend import Supertrends
from cycles.market_fees import MarketFees

class Backtest:
    @staticmethod
    def run(min1_df, df, initial_usd, stop_loss_pct, progress_callback=None, verbose=False):
        """
        Backtest a simple strategy using the meta supertrend (all three supertrends agree).
        Buys when meta supertrend is positive, sells when negative, applies a percentage stop loss.

        Parameters:
        - min1_df: pandas DataFrame, 1-minute timeframe data for more accurate stop loss checking (optional)
        - df: pandas DataFrame, main timeframe data for signals
        - initial_usd: float, starting USD amount
        - stop_loss_pct: float, stop loss as a fraction (e.g. 0.05 for 5%)
        - progress_callback: callable, optional callback function to report progress (current_step)
        - verbose: bool, enable debug logging for stop loss checks
        """
        _df = df.copy().reset_index()

        # Ensure we have a timestamp column regardless of original index name
        if 'timestamp' not in _df.columns:
            # If reset_index() created a column with the original index name, rename it
            if len(_df.columns) > 0 and _df.columns[0] not in ['open', 'high', 'low', 'close', 'volume', 'predicted_close_price']:
                _df = _df.rename(columns={_df.columns[0]: 'timestamp'})
            else:
                raise ValueError("Unable to identify timestamp column in DataFrame")

        _df['timestamp'] = pd.to_datetime(_df['timestamp'])

        supertrends = Supertrends(_df, verbose=False, close_column='predicted_close_price')

        supertrend_results_list = supertrends.calculate_supertrend_indicators()
        trends = [st['results']['trend'] for st in supertrend_results_list]
        trends_arr = np.stack(trends, axis=1)
        meta_trend = np.where((trends_arr[:,0] == trends_arr[:,1]) & (trends_arr[:,1] == trends_arr[:,2]),
                              trends_arr[:,0], 0)
        # Shift meta_trend by one to avoid lookahead bias
        meta_trend_signal = np.roll(meta_trend, 1)
        meta_trend_signal[0] = 0  # or np.nan, but 0 means 'no signal' for first bar

        position = 0  # 0 = no position, 1 = long
        entry_price = 0
        usd = initial_usd
        coin = 0
        trade_log = []
        max_balance = initial_usd
        drawdowns = []
        trades = []
        entry_time = None
        stop_loss_count = 0  # Track number of stop losses

        # Ensure min1_df has proper DatetimeIndex
        if min1_df is not None and not min1_df.empty:
            min1_df.index = pd.to_datetime(min1_df.index)

        for i in range(1, len(_df)):
            # Report progress if callback is provided
            if progress_callback:
                # Update more frequently for better responsiveness
                update_frequency = max(1, len(_df) // 50)  # Update every 2% of dataset (50 updates total)
                if i % update_frequency == 0 or i == len(_df) - 1:  # Always update on last iteration
                    if verbose:  # Only print in verbose mode to avoid spam
                        print(f"DEBUG: Progress callback called with i={i}, total={len(_df)-1}")
                    progress_callback(i)

            price_open = _df['open'].iloc[i]
            price_close = _df['close'].iloc[i]
            date = _df['timestamp'].iloc[i]
            prev_mt = meta_trend_signal[i-1]
            curr_mt = meta_trend_signal[i]

            # Check stop loss if in position
            if position == 1:
                stop_loss_result = Backtest.check_stop_loss(
                    min1_df,
                    entry_time,
                    date,
                    entry_price,
                    stop_loss_pct,
                    coin,
                    verbose=verbose
                )
                if stop_loss_result is not None:
                    trade_log_entry, position, coin, entry_price, usd = stop_loss_result
                    trade_log.append(trade_log_entry)
                    stop_loss_count += 1
                    continue

            # Entry: only if not in position and signal changes to 1
            if position == 0 and prev_mt != 1 and curr_mt == 1:
                entry_result = Backtest.handle_entry(usd, price_open, date)
                coin, entry_price, entry_time, usd, position, trade_log_entry = entry_result
                trade_log.append(trade_log_entry)

            # Exit: only if in position and signal changes from 1 to -1
            elif position == 1 and prev_mt == 1 and curr_mt == -1:
                exit_result = Backtest.handle_exit(coin, price_open, entry_price, entry_time, date)
                usd, coin, position, entry_price, trade_log_entry = exit_result
                trade_log.append(trade_log_entry)

            # Track drawdown
            balance = usd if position == 0 else coin * price_close
            if balance > max_balance:
                max_balance = balance
            drawdown = (max_balance - balance) / max_balance
            drawdowns.append(drawdown)

        # Report completion if callback is provided
        if progress_callback:
            progress_callback(len(_df) - 1)

        # If still in position at end, sell at last close
        if position == 1:
            exit_result = Backtest.handle_exit(coin, _df['close'].iloc[-1], entry_price, entry_time, _df['timestamp'].iloc[-1])
            usd, coin, position, entry_price, trade_log_entry = exit_result
            trade_log.append(trade_log_entry)

        # Calculate statistics
        final_balance = usd
        n_trades = len(trade_log)
        wins = [1 for t in trade_log if t['exit'] is not None and t['exit'] > t['entry']]
        win_rate = len(wins) / n_trades if n_trades > 0 else 0
        max_drawdown = max(drawdowns) if drawdowns else 0
        avg_trade = np.mean([t['exit']/t['entry']-1 for t in trade_log if t['exit'] is not None]) if trade_log else 0

        trades = []
        total_fees_usd = 0.0
        for trade in trade_log:
            if trade['exit'] is not None:
                profit_pct = (trade['exit'] - trade['entry']) / trade['entry']
            else:
                profit_pct = 0.0

            # Validate fee_usd field
            if 'fee_usd' not in trade:
                raise ValueError(f"Trade missing required field 'fee_usd': {trade}")
            fee_usd = trade['fee_usd']
            if fee_usd is None:
                raise ValueError(f"Trade fee_usd is None: {trade}")

            # Validate trade type field
            if 'type' not in trade:
                raise ValueError(f"Trade missing required field 'type': {trade}")
            trade_type = trade['type']
            if trade_type is None:
                raise ValueError(f"Trade type is None: {trade}")

            trades.append({
                'entry_time': trade['entry_time'],
                'exit_time': trade['exit_time'],
                'entry': trade['entry'],
                'exit': trade['exit'],
                'profit_pct': profit_pct,
                'type': trade_type,
                'fee_usd': fee_usd
            })
            total_fees_usd += fee_usd

        results = {
            "initial_usd": initial_usd,
            "final_usd": final_balance,
            "n_trades": n_trades,
            "n_stop_loss": stop_loss_count,  # Add stop loss count
            "win_rate": win_rate,
            "max_drawdown": max_drawdown,
            "avg_trade": avg_trade,
            "trade_log": trade_log,
            "trades": trades,
            "total_fees_usd": total_fees_usd,
        }
        if n_trades > 0:
            results["first_trade"] = {
                "entry_time": trade_log[0]['entry_time'],
                "entry": trade_log[0]['entry']
            }
            results["last_trade"] = {
                "exit_time": trade_log[-1]['exit_time'],
                "exit": trade_log[-1]['exit']
            }
        return results

    @staticmethod
    def check_stop_loss(min1_df, entry_time, current_time, entry_price, stop_loss_pct, coin, verbose=False):
        """
        Check if stop loss should be triggered based on 1-minute data

        Args:
            min1_df: 1-minute DataFrame with DatetimeIndex
            entry_time: Entry timestamp
            current_time: Current timestamp
            entry_price: Entry price
            stop_loss_pct: Stop loss percentage (e.g. 0.05 for 5%)
            coin: Current coin position
            verbose: Enable debug logging

        Returns:
            Tuple of (trade_log_entry, position, coin, entry_price, usd) if stop loss triggered, None otherwise
        """
        if min1_df is None or min1_df.empty:
            if verbose:
                print("Warning: No 1-minute data available for stop loss checking")
            return None

        stop_price = entry_price * (1 - stop_loss_pct)

        try:
            # Ensure min1_df has a DatetimeIndex
            if not isinstance(min1_df.index, pd.DatetimeIndex):
                if verbose:
                    print("Warning: min1_df does not have DatetimeIndex")
                return None

            # Convert entry_time and current_time to pandas Timestamps for comparison
            entry_ts = pd.to_datetime(entry_time)
            current_ts = pd.to_datetime(current_time)

            if verbose:
                print(f"Checking stop loss from {entry_ts} to {current_ts}, stop_price: {stop_price:.2f}")

            # Handle edge case where entry and current time are the same (1-minute timeframe)
            if entry_ts == current_ts:
                if verbose:
                    print("Entry and current time are the same, no range to check")
                return None

            # Find the range of 1-minute data to check (exclusive of entry time, inclusive of current time)
            # We start from the candle AFTER entry to avoid checking the entry candle itself
            start_check_time = entry_ts + pd.Timedelta(minutes=1)

            # Get the slice of data to check for stop loss
            mask = (min1_df.index > entry_ts) & (min1_df.index <= current_ts)
            min1_slice = min1_df.loc[mask]

            if len(min1_slice) == 0:
                if verbose:
                    print(f"No 1-minute data found between {start_check_time} and {current_ts}")
                return None

            if verbose:
                print(f"Checking {len(min1_slice)} candles for stop loss")

            # Check if any low price in the slice hits the stop loss
            stop_triggered = (min1_slice['low'] <= stop_price).any()

            if stop_triggered:
                # Find the exact candle where stop loss was triggered
                stop_candle = min1_slice[min1_slice['low'] <= stop_price].iloc[0]

                if verbose:
                    print(f"Stop loss triggered at {stop_candle.name}, low: {stop_candle['low']:.2f}")

                # More realistic fill: if open < stop, fill at open, else at stop
                if stop_candle['open'] < stop_price:
                    sell_price = stop_candle['open']
                    if verbose:
                        print(f"Filled at open price: {sell_price:.2f}")
                else:
                    sell_price = stop_price
                    if verbose:
                        print(f"Filled at stop price: {sell_price:.2f}")

                btc_to_sell = coin
                usd_gross = btc_to_sell * sell_price
                exit_fee = MarketFees.calculate_okx_taker_maker_fee(usd_gross, is_maker=False)
                usd_after_stop = usd_gross - exit_fee

                trade_log_entry = {
                    'type': 'STOP',
                    'entry': entry_price,
                    'exit': sell_price,
                    'entry_time': entry_time,
                    'exit_time': stop_candle.name,
                    'fee_usd': exit_fee
                }
                # After stop loss, reset position and entry, return USD balance
                return trade_log_entry, 0, 0, 0, usd_after_stop
            elif verbose:
                print(f"No stop loss triggered, min low in range: {min1_slice['low'].min():.2f}")

        except Exception as e:
            # In case of any error, don't trigger stop loss but log the issue
            error_msg = f"Warning: Stop loss check failed: {e}"
            print(error_msg)
            if verbose:
                import traceback
                print(traceback.format_exc())
            return None

        return None

    @staticmethod
    def handle_entry(usd, price_open, date):
        entry_fee = MarketFees.calculate_okx_taker_maker_fee(usd, is_maker=False)
        usd_after_fee = usd - entry_fee
        coin = usd_after_fee / price_open
        entry_price = price_open
        entry_time = date
        usd = 0
        position = 1
        trade_log_entry = {
            'type': 'BUY',
            'entry': entry_price,
            'exit': None,
            'entry_time': entry_time,
            'exit_time': None,
            'fee_usd': entry_fee
        }
        return coin, entry_price, entry_time, usd, position, trade_log_entry

    @staticmethod
    def handle_exit(coin, price_open, entry_price, entry_time, date):
        btc_to_sell = coin
        usd_gross = btc_to_sell * price_open
        exit_fee = MarketFees.calculate_okx_taker_maker_fee(usd_gross, is_maker=False)
        usd = usd_gross - exit_fee
        trade_log_entry = {
            'type': 'SELL',
            'entry': entry_price,
            'exit': price_open,
            'entry_time': entry_time,
            'exit_time': date,
            'fee_usd': exit_fee
        }
        coin = 0
        position = 0
        entry_price = 0
        return usd, coin, position, entry_price, trade_log_entry
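A hedged usage sketch for Backtest.run; both frames are placeholders here. min1_df needs OHLC columns under a DatetimeIndex, and df needs OHLC plus the predicted_close_price column that the Supertrends call above reads:

    results = Backtest.run(min1_df, df_4h, initial_usd=10000, stop_loss_pct=0.05)
    print(results["final_usd"], results["n_trades"], results["n_stop_loss"], f"{results['win_rate']:.1%}")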
86  cycles/charts.py  Normal file
@@ -0,0 +1,86 @@
import os
import matplotlib.pyplot as plt
from collections import defaultdict

class BacktestCharts:
    def __init__(self, charts_dir="charts"):
        self.charts_dir = charts_dir
        os.makedirs(self.charts_dir, exist_ok=True)

    def plot_profit_ratio_vs_stop_loss(self, results, filename="profit_ratio_vs_stop_loss.png"):
        """
        Plots profit ratio vs stop loss percentage for each timeframe.

        Parameters:
        - results: list of dicts, each with keys: 'timeframe', 'stop_loss_pct', 'profit_ratio'
        - filename: output filename (will be saved in charts_dir)
        """
        # Organize data by timeframe
        data = defaultdict(lambda: {"stop_loss_pct": [], "profit_ratio": []})
        for row in results:
            tf = row["timeframe"]
            data[tf]["stop_loss_pct"].append(row["stop_loss_pct"])
            data[tf]["profit_ratio"].append(row["profit_ratio"])

        plt.figure(figsize=(10, 6))
        for tf, vals in data.items():
            # Sort by stop_loss_pct for smooth lines
            sorted_pairs = sorted(zip(vals["stop_loss_pct"], vals["profit_ratio"]))
            stop_loss, profit_ratio = zip(*sorted_pairs)
            plt.plot(
                [s * 100 for s in stop_loss],  # Convert to percent
                profit_ratio,
                marker="o",
                label=tf
            )

        plt.xlabel("Stop Loss (%)")
        plt.ylabel("Profit Ratio")
        plt.title("Profit Ratio vs Stop Loss (%) per Timeframe")
        plt.legend(title="Timeframe")
        plt.grid(True, linestyle="--", alpha=0.5)
        plt.tight_layout()

        output_path = os.path.join(self.charts_dir, filename)
        plt.savefig(output_path)
        plt.close()

    def plot_average_trade_vs_stop_loss(self, results, filename="average_trade_vs_stop_loss.png"):
        """
        Plots average trade vs stop loss percentage for each timeframe.

        Parameters:
        - results: list of dicts, each with keys: 'timeframe', 'stop_loss_pct', 'average_trade'
        - filename: output filename (will be saved in charts_dir)
        """
        data = defaultdict(lambda: {"stop_loss_pct": [], "average_trade": []})
        for row in results:
            tf = row["timeframe"]
            if "average_trade" not in row:
                continue  # Skip rows without average_trade
            data[tf]["stop_loss_pct"].append(row["stop_loss_pct"])
            data[tf]["average_trade"].append(row["average_trade"])

        plt.figure(figsize=(10, 6))
        for tf, vals in data.items():
            # Sort by stop_loss_pct for smooth lines
            sorted_pairs = sorted(zip(vals["stop_loss_pct"], vals["average_trade"]))
            stop_loss, average_trade = zip(*sorted_pairs)
            plt.plot(
                [s * 100 for s in stop_loss],  # Convert to percent
                average_trade,
                marker="o",
                label=tf
            )

        plt.xlabel("Stop Loss (%)")
        plt.ylabel("Average Trade")
        plt.title("Average Trade vs Stop Loss (%) per Timeframe")
        plt.legend(title="Timeframe")
        plt.grid(True, linestyle="--", alpha=0.5)
        plt.tight_layout()

        output_path = os.path.join(self.charts_dir, filename)
        plt.savefig(output_path)
        plt.close()
7  cycles/market_fees.py  Normal file
@@ -0,0 +1,7 @@
import pandas as pd

class MarketFees:
    @staticmethod
    def calculate_okx_taker_maker_fee(amount, is_maker=True):
        fee_rate = 0.0008 if is_maker else 0.0010
        return amount * fee_rate
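A quick worked example of the fee math: on a 10,000 USD notional, the taker rate of 0.10% gives calculate_okx_taker_maker_fee(10000, is_maker=False) == 10.0 USD, while the maker rate of 0.08% gives 8.0 USD.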
215  cycles/supertrend.py  Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
import pandas as pd
import numpy as np
import logging
from functools import lru_cache


@lru_cache(maxsize=32)
def cached_supertrend_calculation(period, multiplier, data_tuple):
    """Compute SuperTrend bands and trend direction from (high, low, close) tuples.

    Tuples are used instead of arrays so that results can be memoised by lru_cache.
    """
    high = np.array(data_tuple[0])
    low = np.array(data_tuple[1])
    close = np.array(data_tuple[2])
    # True Range: max of high-low, |high - prev close|, |low - prev close|
    tr = np.zeros_like(close)
    tr[0] = high[0] - low[0]
    hc_range = np.abs(high[1:] - close[:-1])
    lc_range = np.abs(low[1:] - close[:-1])
    hl_range = high[1:] - low[1:]
    tr[1:] = np.maximum.reduce([hl_range, hc_range, lc_range])
    # ATR as an exponential moving average of the True Range
    atr = np.zeros_like(tr)
    atr[0] = tr[0]
    multiplier_ema = 2.0 / (period + 1)
    for i in range(1, len(tr)):
        atr[i] = (tr[i] * multiplier_ema) + (atr[i-1] * (1 - multiplier_ema))
    # Basic bands around the high/low midpoint
    upper_band = np.zeros_like(close)
    lower_band = np.zeros_like(close)
    for i in range(len(close)):
        hl_avg = (high[i] + low[i]) / 2
        upper_band[i] = hl_avg + (multiplier * atr[i])
        lower_band[i] = hl_avg - (multiplier * atr[i])
    final_upper = np.zeros_like(close)
    final_lower = np.zeros_like(close)
    supertrend = np.zeros_like(close)
    trend = np.zeros_like(close)
    final_upper[0] = upper_band[0]
    final_lower[0] = lower_band[0]
    if close[0] <= upper_band[0]:
        supertrend[0] = upper_band[0]
        trend[0] = -1
    else:
        supertrend[0] = lower_band[0]
        trend[0] = 1
    for i in range(1, len(close)):
        if (upper_band[i] < final_upper[i-1]) or (close[i-1] > final_upper[i-1]):
            final_upper[i] = upper_band[i]
        else:
            final_upper[i] = final_upper[i-1]
        if (lower_band[i] > final_lower[i-1]) or (close[i-1] < final_lower[i-1]):
            final_lower[i] = lower_band[i]
        else:
            final_lower[i] = final_lower[i-1]
        if supertrend[i-1] == final_upper[i-1] and close[i] <= final_upper[i]:
            supertrend[i] = final_upper[i]
            trend[i] = -1
        elif supertrend[i-1] == final_upper[i-1] and close[i] > final_upper[i]:
            supertrend[i] = final_lower[i]
            trend[i] = 1
        elif supertrend[i-1] == final_lower[i-1] and close[i] >= final_lower[i]:
            supertrend[i] = final_lower[i]
            trend[i] = 1
        elif supertrend[i-1] == final_lower[i-1] and close[i] < final_lower[i]:
            supertrend[i] = final_upper[i]
            trend[i] = -1
    return {
        'supertrend': supertrend,
        'trend': trend,
        'upper_band': final_upper,
        'lower_band': final_lower
    }


def calculate_supertrend_external(data, period, multiplier, close_column='close'):
    """
    External function to calculate SuperTrend with a configurable close column

    Parameters:
    - data: DataFrame with OHLC data
    - period: int, period for ATR calculation
    - multiplier: float, multiplier for ATR
    - close_column: str, name of the column to use as close price (default: 'close')
    """
    high_tuple = tuple(data['high'])
    low_tuple = tuple(data['low'])
    close_tuple = tuple(data[close_column])
    return cached_supertrend_calculation(period, multiplier, (high_tuple, low_tuple, close_tuple))


class Supertrends:
    def __init__(self, data, close_column='close', verbose=False, display=False):
        """
        Initialize Supertrends calculator

        Parameters:
        - data: pandas DataFrame with OHLC data or list of prices
        - close_column: str, name of the column to use as close price (default: 'close')
        - verbose: bool, enable verbose logging
        - display: bool, display mode (currently unused)
        """
        self.close_column = close_column
        self.data = data
        self.verbose = verbose
        logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
                            format='%(asctime)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger('TrendDetectorSimple')

        if not isinstance(self.data, pd.DataFrame):
            if isinstance(self.data, list):
                # Note: a plain price list yields only the close column, so the
                # 'high'/'low' validation below will reject it; full OHLC data
                # is required for the True Range calculation.
                self.data = pd.DataFrame({self.close_column: self.data})
            else:
                raise ValueError("Data must be a pandas DataFrame or a list")

        # Validate that required columns exist
        required_columns = ['high', 'low', self.close_column]
        missing_columns = [col for col in required_columns if col not in self.data.columns]
        if missing_columns:
            raise ValueError(f"Missing required columns: {missing_columns}")

    def calculate_tr(self):
        """Calculate True Range using the configured close column"""
        df = self.data.copy()
        high = df['high'].values
        low = df['low'].values
        close = df[self.close_column].values
        tr = np.zeros_like(close)
        tr[0] = high[0] - low[0]
        for i in range(1, len(close)):
            hl_range = high[i] - low[i]
            hc_range = abs(high[i] - close[i-1])
            lc_range = abs(low[i] - close[i-1])
            tr[i] = max(hl_range, hc_range, lc_range)
        return tr

    def calculate_atr(self, period=14):
        """Calculate Average True Range"""
        tr = self.calculate_tr()
        atr = np.zeros_like(tr)
        atr[0] = tr[0]
        multiplier = 2.0 / (period + 1)
        for i in range(1, len(tr)):
            atr[i] = (tr[i] * multiplier) + (atr[i-1] * (1 - multiplier))
        return atr

    def calculate_supertrend(self, period=10, multiplier=3.0):
        """
        Calculate the SuperTrend indicator for the price data using the configured close column.
        SuperTrend is a trend-following indicator that uses ATR to determine the trend direction.

        Parameters:
        - period: int, the period for the ATR calculation (default: 10)
        - multiplier: float, the multiplier for the ATR (default: 3.0)

        Returns:
        - Dictionary containing SuperTrend values, trend direction, and upper/lower bands
        """
        df = self.data.copy()
        high = df['high'].values
        low = df['low'].values
        close = df[self.close_column].values
        atr = self.calculate_atr(period)
        upper_band = np.zeros_like(close)
        lower_band = np.zeros_like(close)
        for i in range(len(close)):
            hl_avg = (high[i] + low[i]) / 2
            upper_band[i] = hl_avg + (multiplier * atr[i])
            lower_band[i] = hl_avg - (multiplier * atr[i])
        final_upper = np.zeros_like(close)
        final_lower = np.zeros_like(close)
        supertrend = np.zeros_like(close)
        trend = np.zeros_like(close)
        final_upper[0] = upper_band[0]
        final_lower[0] = lower_band[0]
        if close[0] <= upper_band[0]:
            supertrend[0] = upper_band[0]
            trend[0] = -1
        else:
            supertrend[0] = lower_band[0]
            trend[0] = 1
        for i in range(1, len(close)):
            if (upper_band[i] < final_upper[i-1]) or (close[i-1] > final_upper[i-1]):
                final_upper[i] = upper_band[i]
            else:
                final_upper[i] = final_upper[i-1]
            if (lower_band[i] > final_lower[i-1]) or (close[i-1] < final_lower[i-1]):
                final_lower[i] = lower_band[i]
            else:
                final_lower[i] = final_lower[i-1]
            if supertrend[i-1] == final_upper[i-1] and close[i] <= final_upper[i]:
                supertrend[i] = final_upper[i]
                trend[i] = -1
            elif supertrend[i-1] == final_upper[i-1] and close[i] > final_upper[i]:
                supertrend[i] = final_lower[i]
                trend[i] = 1
            elif supertrend[i-1] == final_lower[i-1] and close[i] >= final_lower[i]:
                supertrend[i] = final_lower[i]
                trend[i] = 1
            elif supertrend[i-1] == final_lower[i-1] and close[i] < final_lower[i]:
                supertrend[i] = final_upper[i]
                trend[i] = -1
        supertrend_results = {
            'supertrend': supertrend,
            'trend': trend,
            'upper_band': final_upper,
            'lower_band': final_lower
        }
        return supertrend_results

    def calculate_supertrend_indicators(self):
        supertrend_params = [
            {"period": 12, "multiplier": 3.0},
            {"period": 10, "multiplier": 1.0},
            {"period": 11, "multiplier": 2.0}
        ]
        results = []
        for p in supertrend_params:
            result = self.calculate_supertrend(period=p["period"], multiplier=p["multiplier"])
            results.append({
                "results": result,
                "params": p
            })
        return results
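A minimal usage sketch of the `Supertrends` class above; the random-walk OHLC data is invented purely for the example:

```python
import numpy as np
import pandas as pd

from cycles.supertrend import Supertrends

# Synthetic OHLC data for demonstration only
rng = np.random.default_rng(42)
close = 100 + np.cumsum(rng.normal(0, 1, 200))
df = pd.DataFrame({
    "high": close + rng.uniform(0.1, 1.0, 200),
    "low": close - rng.uniform(0.1, 1.0, 200),
    "close": close,
})

st = Supertrends(df)
result = st.calculate_supertrend(period=10, multiplier=3.0)

# trend is +1 in an uptrend and -1 in a downtrend
df["supertrend"] = result["supertrend"]
df["trend"] = result["trend"]
print(df[["close", "supertrend", "trend"]].tail())
```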
0 cycles/utils/__init__.py Normal file
152 cycles/utils/data_loader.py Normal file
@@ -0,0 +1,152 @@
import os
import json
import pandas as pd
from typing import Union, Optional
import logging

from .storage_utils import (
    _parse_timestamp_column,
    _filter_by_date_range,
    _normalize_column_names,
    TimestampParsingError,
    DataLoadingError
)


class DataLoader:
    """Handles loading and preprocessing of data from various file formats"""

    def __init__(self, data_dir: str, logging_instance: Optional[logging.Logger] = None):
        """Initialize data loader

        Args:
            data_dir: Directory containing data files
            logging_instance: Optional logging instance
        """
        self.data_dir = data_dir
        self.logging = logging_instance

    def load_data(self, file_path: str, start_date: Union[str, pd.Timestamp],
                  stop_date: Union[str, pd.Timestamp]) -> pd.DataFrame:
        """Load data with optimized dtypes and filtering, supporting CSV and JSON input

        Args:
            file_path: path to the data file
            start_date: start date (string or datetime-like)
            stop_date: stop date (string or datetime-like)

        Returns:
            pandas DataFrame with timestamp index. If loading fails, the error
            is logged and an empty DataFrame with a DatetimeIndex is returned.
        """
        try:
            # Convert string dates to pandas datetime objects for proper comparison
            start_date = pd.to_datetime(start_date)
            stop_date = pd.to_datetime(stop_date)

            # Determine file type
            _, ext = os.path.splitext(file_path)
            ext = ext.lower()

            if ext == ".json":
                return self._load_json_data(file_path, start_date, stop_date)
            else:
                return self._load_csv_data(file_path, start_date, stop_date)

        except Exception as e:
            error_msg = f"Error loading data from {file_path}: {e}"
            if self.logging is not None:
                self.logging.error(error_msg)
            # Return an empty DataFrame with a DatetimeIndex
            return pd.DataFrame(index=pd.to_datetime([]))

    def _load_json_data(self, file_path: str, start_date: pd.Timestamp,
                        stop_date: pd.Timestamp) -> pd.DataFrame:
        """Load and process JSON data file

        Args:
            file_path: Path to JSON file
            start_date: Start date for filtering
            stop_date: Stop date for filtering

        Returns:
            Processed DataFrame with timestamp index
        """
        with open(os.path.join(self.data_dir, file_path), 'r') as f:
            raw = json.load(f)

        data = pd.DataFrame(raw["Data"])
        data = _normalize_column_names(data)

        # Convert timestamp to datetime
        data["timestamp"] = pd.to_datetime(data["timestamp"], unit="s")

        # Filter by date range
        data = _filter_by_date_range(data, "timestamp", start_date, stop_date)

        if self.logging is not None:
            self.logging.info(f"Data loaded from {file_path} for date range {start_date} to {stop_date}")

        return data.set_index("timestamp")

    def _load_csv_data(self, file_path: str, start_date: pd.Timestamp,
                       stop_date: pd.Timestamp) -> pd.DataFrame:
        """Load and process CSV data file

        Args:
            file_path: Path to CSV file
            start_date: Start date for filtering
            stop_date: Stop date for filtering

        Returns:
            Processed DataFrame with timestamp index
        """
        # Define optimized dtypes
        dtypes = {
            'Open': 'float32',
            'High': 'float32',
            'Low': 'float32',
            'Close': 'float32',
            'Volume': 'float32'
        }

        # Read data with original capitalized column names
        data = pd.read_csv(os.path.join(self.data_dir, file_path), dtype=dtypes)

        return self._process_csv_timestamps(data, start_date, stop_date, file_path)

    def _process_csv_timestamps(self, data: pd.DataFrame, start_date: pd.Timestamp,
                                stop_date: pd.Timestamp, file_path: str) -> pd.DataFrame:
        """Process timestamps in CSV data and filter by date range

        Args:
            data: DataFrame with CSV data
            start_date: Start date for filtering
            stop_date: Stop date for filtering
            file_path: Original file path for logging

        Returns:
            Processed DataFrame with timestamp index
        """
        if 'Timestamp' in data.columns:
            data = _parse_timestamp_column(data, 'Timestamp')
            data = _filter_by_date_range(data, 'Timestamp', start_date, stop_date)
            data = _normalize_column_names(data)

            if self.logging is not None:
                self.logging.info(f"Data loaded from {file_path} for date range {start_date} to {stop_date}")

            return data.set_index('timestamp')
        else:
            # Attempt to use the first column if 'Timestamp' is not present
            data.rename(columns={data.columns[0]: 'timestamp'}, inplace=True)
            data = _parse_timestamp_column(data, 'timestamp')
            data = _filter_by_date_range(data, 'timestamp', start_date, stop_date)
            data = _normalize_column_names(data)

            if self.logging is not None:
                self.logging.info(f"Data loaded from {file_path} (using first column as timestamp) for date range {start_date} to {stop_date}")

            return data.set_index('timestamp')
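A minimal usage sketch of `DataLoader`; the file name `btc_1m.csv` is hypothetical, and any CSV in `data_dir` with a `Timestamp` column (Unix seconds or datetime strings) would work:

```python
import logging

from cycles.utils.data_loader import DataLoader

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

loader = DataLoader(data_dir="../data", logging_instance=logger)

# 'btc_1m.csv' is a hypothetical file name for the example
df = loader.load_data("btc_1m.csv", start_date="2024-01-01", stop_date="2024-03-31")
print(df.head())
```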
106 cycles/utils/data_saver.py Normal file
@@ -0,0 +1,106 @@
import os
import pandas as pd
from typing import Optional
import logging

from .storage_utils import DataSavingError


class DataSaver:
    """Handles saving data to various file formats"""

    def __init__(self, data_dir: str, logging_instance: Optional[logging.Logger] = None):
        """Initialize data saver

        Args:
            data_dir: Directory for saving data files
            logging_instance: Optional logging instance
        """
        self.data_dir = data_dir
        self.logging = logging_instance

    def save_data(self, data: pd.DataFrame, file_path: str) -> None:
        """Save processed data to a CSV file.

        If the DataFrame has a DatetimeIndex, it's converted to float Unix timestamps
        (seconds since epoch) before saving. The index is saved as a column named 'timestamp'.

        Args:
            data: DataFrame to save
            file_path: path to the data file relative to the data_dir

        Raises:
            DataSavingError: If saving fails
        """
        try:
            data_to_save = data.copy()
            data_to_save = self._prepare_data_for_saving(data_to_save)

            # Save to CSV, ensuring the 'timestamp' column (if created) is written
            full_path = os.path.join(self.data_dir, file_path)
            data_to_save.to_csv(full_path, index=False)

            if self.logging is not None:
                self.logging.info(f"Data saved to {full_path} with Unix timestamp column.")

        except Exception as e:
            error_msg = f"Failed to save data to {file_path}: {e}"
            if self.logging is not None:
                self.logging.error(error_msg)
            raise DataSavingError(error_msg) from e

    def _prepare_data_for_saving(self, data: pd.DataFrame) -> pd.DataFrame:
        """Prepare DataFrame for saving by handling different index types

        Args:
            data: DataFrame to prepare

        Returns:
            DataFrame ready for saving
        """
        if isinstance(data.index, pd.DatetimeIndex):
            return self._convert_datetime_index_to_timestamp(data)
        elif pd.api.types.is_numeric_dtype(data.index.dtype):
            return self._convert_numeric_index_to_timestamp(data)
        else:
            # For other index types, save with the current index
            return data

    def _convert_datetime_index_to_timestamp(self, data: pd.DataFrame) -> pd.DataFrame:
        """Convert DatetimeIndex to Unix timestamp column

        Args:
            data: DataFrame with DatetimeIndex

        Returns:
            DataFrame with timestamp column
        """
        # Convert DatetimeIndex to Unix timestamp (float seconds since epoch)
        data['timestamp'] = data.index.astype('int64') / 1e9
        data.reset_index(drop=True, inplace=True)

        # Ensure 'timestamp' is the first column if other columns exist
        if 'timestamp' in data.columns and len(data.columns) > 1:
            cols = ['timestamp'] + [col for col in data.columns if col != 'timestamp']
            data = data[cols]

        return data

    def _convert_numeric_index_to_timestamp(self, data: pd.DataFrame) -> pd.DataFrame:
        """Convert numeric index to timestamp column

        Args:
            data: DataFrame with numeric index

        Returns:
            DataFrame with timestamp column
        """
        # If index is already numeric (e.g. float Unix timestamps from a previous save/load cycle)
        data['timestamp'] = data.index
        data.reset_index(drop=True, inplace=True)

        # Ensure 'timestamp' is the first column if other columns exist
        if 'timestamp' in data.columns and len(data.columns) > 1:
            cols = ['timestamp'] + [col for col in data.columns if col != 'timestamp']
            data = data[cols]

        return data
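A short sketch of the save path; the frame and the output name `example_prices.csv` are invented for the example:

```python
import pandas as pd

from cycles.utils.data_saver import DataSaver

# Toy frame with a DatetimeIndex; on save, the index becomes a float
# Unix-seconds 'timestamp' column written as the first CSV column.
idx = pd.date_range("2024-01-01", periods=3, freq="D")
df = pd.DataFrame({"close": [100.0, 101.5, 99.8]}, index=idx)

saver = DataSaver(data_dir="../data")
saver.save_data(df, "example_prices.csv")  # hypothetical output name
```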
60 cycles/utils/data_utils.py Normal file
@@ -0,0 +1,60 @@
import pandas as pd


def aggregate_to_daily(data_df: pd.DataFrame) -> pd.DataFrame:
    """
    Aggregates time-series financial data to daily OHLCV format.

    The input DataFrame is expected to have a DatetimeIndex.
    'open' will be the first 'open' price of the day.
    'close' will be the last 'close' price of the day.
    'high' will be the maximum 'high' price of the day.
    'low' will be the minimum 'low' price of the day.
    'volume' (if present) will be the sum of volumes for the day.

    Args:
        data_df (pd.DataFrame): DataFrame with a DatetimeIndex and columns
            like 'open', 'high', 'low', 'close', and optionally 'volume'.
            Column names are expected to be lowercase.

    Returns:
        pd.DataFrame: DataFrame aggregated to daily OHLCV data.
            The index will be a DatetimeIndex with the time set to noon (12:00:00) for each day.
            Returns an empty DataFrame if no relevant OHLCV columns are found.

    Raises:
        ValueError: If the input DataFrame does not have a DatetimeIndex.
    """
    if not isinstance(data_df.index, pd.DatetimeIndex):
        raise ValueError("Input DataFrame must have a DatetimeIndex.")

    agg_rules = {}

    # Define aggregation rules based on available columns
    if 'open' in data_df.columns:
        agg_rules['open'] = 'first'
    if 'high' in data_df.columns:
        agg_rules['high'] = 'max'
    if 'low' in data_df.columns:
        agg_rules['low'] = 'min'
    if 'close' in data_df.columns:
        agg_rules['close'] = 'last'
    if 'volume' in data_df.columns:
        agg_rules['volume'] = 'sum'

    if not agg_rules:
        # Log a warning or raise an error if no relevant columns are found
        # For now, returning an empty DataFrame with a message might be suitable for some cases
        print("Warning: No standard OHLCV columns (open, high, low, close, volume) found for daily aggregation.")
        return pd.DataFrame(index=pd.to_datetime([]))  # Return empty DF with datetime index

    # Resample to daily frequency and apply aggregation rules
    daily_data = data_df.resample('D').agg(agg_rules)

    # Adjust timestamps to noon if data exists
    if not daily_data.empty and isinstance(daily_data.index, pd.DatetimeIndex):
        daily_data.index = daily_data.index + pd.Timedelta(hours=12)

    # Remove rows where all values are NaN (these are days with no trades in the original data)
    daily_data.dropna(how='all', inplace=True)

    return daily_data
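A minimal usage sketch of `aggregate_to_daily`; the two days of hourly bars are invented for the example:

```python
import pandas as pd

from cycles.utils.data_utils import aggregate_to_daily

# Two days of hourly bars, invented for the example
idx = pd.date_range("2024-01-01", periods=48, freq="h")
hourly = pd.DataFrame({
    "open": [float(v) for v in range(48)],
    "high": [float(v) + 1 for v in range(48)],
    "low": [float(v) - 1 for v in range(48)],
    "close": [float(v) for v in range(48)],
    "volume": [10.0] * 48,
}, index=idx)

daily = aggregate_to_daily(hourly)
# Index timestamps are shifted to noon; volume is summed per day
print(daily)
```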
233 cycles/utils/progress_manager.py Normal file
@@ -0,0 +1,233 @@
#!/usr/bin/env python3
"""
Progress Manager for tracking multiple parallel backtest tasks
"""

import threading
import time
import sys
from typing import Dict, Optional, Callable
from dataclasses import dataclass


@dataclass
class TaskProgress:
    """Represents progress information for a single task"""
    task_id: str
    name: str
    current: int
    total: int
    start_time: float
    last_update: float

    @property
    def percentage(self) -> float:
        """Calculate completion percentage"""
        if self.total == 0:
            return 0.0
        return (self.current / self.total) * 100

    @property
    def elapsed_time(self) -> float:
        """Calculate elapsed time in seconds"""
        return time.time() - self.start_time

    @property
    def eta(self) -> Optional[float]:
        """Estimate time to completion in seconds"""
        if self.current == 0 or self.percentage >= 100:
            return None

        elapsed = self.elapsed_time
        rate = self.current / elapsed
        remaining = self.total - self.current
        return remaining / rate if rate > 0 else None


class ProgressManager:
    """Manages progress tracking for multiple parallel tasks"""

    def __init__(self, update_interval: float = 1.0, display_width: int = 50):
        """
        Initialize progress manager

        Args:
            update_interval: How often to update display (seconds)
            display_width: Width of progress bar in characters
        """
        self.tasks: Dict[str, TaskProgress] = {}
        self.update_interval = update_interval
        self.display_width = display_width
        self.lock = threading.Lock()
        self.display_thread: Optional[threading.Thread] = None
        self.running = False
        self.last_display_height = 0

    def start_task(self, task_id: str, name: str, total: int) -> None:
        """
        Start tracking a new task

        Args:
            task_id: Unique identifier for the task
            name: Human-readable name for the task
            total: Total number of steps in the task
        """
        with self.lock:
            self.tasks[task_id] = TaskProgress(
                task_id=task_id,
                name=name,
                current=0,
                total=total,
                start_time=time.time(),
                last_update=time.time()
            )

    def update_progress(self, task_id: str, current: int) -> None:
        """
        Update progress for a specific task

        Args:
            task_id: Task identifier
            current: Current progress value
        """
        with self.lock:
            if task_id in self.tasks:
                self.tasks[task_id].current = current
                self.tasks[task_id].last_update = time.time()

    def complete_task(self, task_id: str) -> None:
        """
        Mark a task as completed

        Args:
            task_id: Task identifier
        """
        with self.lock:
            if task_id in self.tasks:
                task = self.tasks[task_id]
                task.current = task.total
                task.last_update = time.time()

    def start_display(self) -> None:
        """Start the progress display thread"""
        if not self.running:
            self.running = True
            self.display_thread = threading.Thread(target=self._display_loop, daemon=True)
            self.display_thread.start()

    def stop_display(self) -> None:
        """Stop the progress display thread"""
        self.running = False
        if self.display_thread:
            self.display_thread.join(timeout=1.0)
        self._clear_display()

    def _display_loop(self) -> None:
        """Main loop for updating the progress display"""
        while self.running:
            self._update_display()
            time.sleep(self.update_interval)

    def _update_display(self) -> None:
        """Update the console display with current progress"""
        with self.lock:
            if not self.tasks:
                return

            # Clear previous display
            self._clear_display()

            # Build display lines
            lines = []
            for task in sorted(self.tasks.values(), key=lambda t: t.task_id):
                line = self._format_progress_line(task)
                lines.append(line)

            # Print all lines
            for line in lines:
                print(line, flush=True)

            self.last_display_height = len(lines)

    def _clear_display(self) -> None:
        """Clear the previous progress display"""
        if self.last_display_height > 0:
            # Move cursor up and clear lines
            for _ in range(self.last_display_height):
                sys.stdout.write('\033[F')  # Move cursor up one line
                sys.stdout.write('\033[K')  # Clear line
            sys.stdout.flush()

    def _format_progress_line(self, task: TaskProgress) -> str:
        """
        Format a single progress line for display

        Args:
            task: TaskProgress instance

        Returns:
            Formatted progress string
        """
        # Progress bar
        filled_width = int(task.percentage / 100 * self.display_width)
        bar = '█' * filled_width + '░' * (self.display_width - filled_width)

        # Time information
        elapsed_str = self._format_time(task.elapsed_time)
        eta_str = self._format_time(task.eta) if task.eta else "N/A"

        # Format line
        line = (f"{task.name:<25} │{bar}│ "
                f"{task.percentage:5.1f}% "
                f"({task.current:,}/{task.total:,}) "
                f"⏱ {elapsed_str} ETA: {eta_str}")

        return line

    def _format_time(self, seconds: float) -> str:
        """
        Format time duration for display

        Args:
            seconds: Time in seconds

        Returns:
            Formatted time string
        """
        if seconds < 60:
            return f"{seconds:.0f}s"
        elif seconds < 3600:
            minutes = seconds / 60
            return f"{minutes:.1f}m"
        else:
            hours = seconds / 3600
            return f"{hours:.1f}h"

    def get_task_progress_callback(self, task_id: str) -> Callable[[int], None]:
        """
        Get a progress callback function for a specific task

        Args:
            task_id: Task identifier

        Returns:
            Callback function that updates progress for this task
        """
        def callback(current: int) -> None:
            self.update_progress(task_id, current)

        return callback

    def all_tasks_completed(self) -> bool:
        """Check if all tasks are completed"""
        with self.lock:
            return all(task.current >= task.total for task in self.tasks.values())

    def get_summary(self) -> str:
        """Get a summary of all tasks"""
        with self.lock:
            total_tasks = len(self.tasks)
            completed_tasks = sum(1 for task in self.tasks.values()
                                  if task.current >= task.total)

            return f"Tasks: {completed_tasks}/{total_tasks} completed"
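A minimal driver showing the intended lifecycle of `ProgressManager`; the task name and the `time.sleep` stand-in for real work are invented for the example:

```python
import time

from cycles.utils.progress_manager import ProgressManager

pm = ProgressManager(update_interval=0.5)
pm.start_task("bt-1", "Backtest 1h", total=100)
pm.start_display()

callback = pm.get_task_progress_callback("bt-1")
for step in range(1, 101):
    time.sleep(0.01)   # stand-in for real work
    callback(step)

pm.complete_task("bt-1")
pm.stop_display()
print(pm.get_summary())
```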
179 cycles/utils/result_formatter.py Normal file
@@ -0,0 +1,179 @@
import os
import csv
from typing import Dict, List, Optional, Any
from collections import defaultdict
import logging

from .storage_utils import DataSavingError


class ResultFormatter:
    """Handles formatting and writing of backtest results to CSV files"""

    def __init__(self, results_dir: str, logging_instance: Optional[logging.Logger] = None):
        """Initialize result formatter

        Args:
            results_dir: Directory for saving result files
            logging_instance: Optional logging instance
        """
        self.results_dir = results_dir
        self.logging = logging_instance

    def format_row(self, row: Dict[str, Any]) -> Dict[str, str]:
        """Format a row for a combined results CSV file

        Args:
            row: Dictionary containing row data

        Returns:
            Dictionary with formatted values
        """
        return {
            "timeframe": row["timeframe"],
            "stop_loss_pct": f"{row['stop_loss_pct']*100:.2f}%",
            "n_trades": row["n_trades"],
            "n_stop_loss": row["n_stop_loss"],
            "win_rate": f"{row['win_rate']*100:.2f}%",
            "max_drawdown": f"{row['max_drawdown']*100:.2f}%",
            "avg_trade": f"{row['avg_trade']*100:.2f}%",
            "profit_ratio": f"{row['profit_ratio']*100:.2f}%",
            "final_usd": f"{row['final_usd']:.2f}",
            "total_fees_usd": f"{row['total_fees_usd']:.2f}",
        }

    def write_results_chunk(self, filename: str, fieldnames: List[str],
                            rows: List[Dict], write_header: bool = False,
                            initial_usd: Optional[float] = None) -> None:
        """Write a chunk of results to a CSV file

        Args:
            filename: filename to write to
            fieldnames: list of fieldnames
            rows: list of rows
            write_header: whether to write the header
            initial_usd: initial USD value for header comment

        Raises:
            DataSavingError: If writing fails
        """
        try:
            mode = 'w' if write_header else 'a'

            with open(filename, mode, newline="") as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                if write_header:
                    if initial_usd is not None:
                        csvfile.write(f"# initial_usd: {initial_usd}\n")
                    writer.writeheader()

                for row in rows:
                    # Only keep keys that are in fieldnames
                    filtered_row = {k: v for k, v in row.items() if k in fieldnames}
                    writer.writerow(filtered_row)

        except Exception as e:
            error_msg = f"Failed to write results chunk to {filename}: {e}"
            if self.logging is not None:
                self.logging.error(error_msg)
            raise DataSavingError(error_msg) from e

    def write_backtest_results(self, filename: str, fieldnames: List[str],
                               rows: List[Dict], metadata_lines: Optional[List[str]] = None) -> str:
        """Write combined backtest results to a CSV file

        Args:
            filename: filename to write to
            fieldnames: list of fieldnames
            rows: list of result dictionaries
            metadata_lines: optional list of strings to write as header comments

        Returns:
            Full path to the written file

        Raises:
            DataSavingError: If writing fails
        """
        try:
            fname = os.path.join(self.results_dir, filename)
            with open(fname, "w", newline="") as csvfile:
                if metadata_lines:
                    for line in metadata_lines:
                        csvfile.write(f"{line}\n")

                writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter='\t')
                writer.writeheader()

                for row in rows:
                    writer.writerow(self.format_row(row))

            if self.logging is not None:
                self.logging.info(f"Combined results written to {fname}")

            return fname

        except Exception as e:
            error_msg = f"Failed to write backtest results to {filename}: {e}"
            if self.logging is not None:
                self.logging.error(error_msg)
            raise DataSavingError(error_msg) from e

    def write_trades(self, all_trade_rows: List[Dict], trades_fieldnames: List[str]) -> None:
        """Write trades to separate CSV files grouped by timeframe and stop loss

        Args:
            all_trade_rows: list of trade dictionaries
            trades_fieldnames: list of trade fieldnames

        Raises:
            DataSavingError: If writing fails
        """
        try:
            trades_by_combo = self._group_trades_by_combination(all_trade_rows)

            for (tf, sl), trades in trades_by_combo.items():
                self._write_single_trade_file(tf, sl, trades, trades_fieldnames)

        except Exception as e:
            error_msg = f"Failed to write trades: {e}"
            if self.logging is not None:
                self.logging.error(error_msg)
            raise DataSavingError(error_msg) from e

    def _group_trades_by_combination(self, all_trade_rows: List[Dict]) -> Dict:
        """Group trades by timeframe and stop loss combination

        Args:
            all_trade_rows: List of trade dictionaries

        Returns:
            Dictionary grouped by (timeframe, stop_loss_pct) tuples
        """
        trades_by_combo = defaultdict(list)
        for trade in all_trade_rows:
            tf = trade.get("timeframe")
            sl = trade.get("stop_loss_pct")
            trades_by_combo[(tf, sl)].append(trade)
        return trades_by_combo

    def _write_single_trade_file(self, timeframe: str, stop_loss_pct: float,
                                 trades: List[Dict], trades_fieldnames: List[str]) -> None:
        """Write trades for a single timeframe/stop-loss combination

        Args:
            timeframe: Timeframe identifier
            stop_loss_pct: Stop loss percentage
            trades: List of trades for this combination
            trades_fieldnames: List of field names for trades
        """
        sl_percent = int(round(stop_loss_pct * 100))
        trades_filename = os.path.join(self.results_dir, f"trades_{timeframe}_ST{sl_percent}pct.csv")

        with open(trades_filename, "w", newline="") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=trades_fieldnames)
            writer.writeheader()
            for trade in trades:
                writer.writerow({k: trade.get(k, "") for k in trades_fieldnames})

        if self.logging is not None:
            self.logging.info(f"Trades written to {trades_filename}")
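A quick sketch of what `format_row` produces; all metric values are invented for the example:

```python
from cycles.utils.result_formatter import ResultFormatter

formatter = ResultFormatter(results_dir="../results")

# All values invented for the example; fractions become "xx.xx%" strings
row = {
    "timeframe": "1h",
    "stop_loss_pct": 0.05,
    "n_trades": 42,
    "n_stop_loss": 7,
    "win_rate": 0.571,
    "max_drawdown": 0.123,
    "avg_trade": 0.0042,
    "profit_ratio": 0.18,
    "final_usd": 11234.5,
    "total_fees_usd": 87.6,
}
print(formatter.format_row(row))
# e.g. {'timeframe': '1h', 'stop_loss_pct': '5.00%', 'win_rate': '57.10%', ...}
```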
123 cycles/utils/storage.py Normal file
@@ -0,0 +1,123 @@
import os
import pandas as pd
from typing import Optional, Union, Dict, Any, List
import logging

from .data_loader import DataLoader
from .data_saver import DataSaver
from .result_formatter import ResultFormatter
from .storage_utils import DataLoadingError, DataSavingError

RESULTS_DIR = "../results"
DATA_DIR = "../data"


class Storage:
    """Unified storage interface for data and results operations

    Acts as a coordinator for DataLoader, DataSaver, and ResultFormatter components,
    maintaining backward compatibility while providing a clean separation of concerns.
    """

    def __init__(self, logging=None, results_dir=RESULTS_DIR, data_dir=DATA_DIR):
        """Initialize storage with component instances

        Args:
            logging: Optional logging instance
            results_dir: Directory for results files
            data_dir: Directory for data files
        """
        self.results_dir = results_dir
        self.data_dir = data_dir
        self.logging = logging

        # Create directories if they don't exist
        os.makedirs(self.results_dir, exist_ok=True)
        os.makedirs(self.data_dir, exist_ok=True)

        # Initialize component instances
        self.data_loader = DataLoader(data_dir, logging)
        self.data_saver = DataSaver(data_dir, logging)
        self.result_formatter = ResultFormatter(results_dir, logging)

    def load_data(self, file_path: str, start_date: Union[str, pd.Timestamp],
                  stop_date: Union[str, pd.Timestamp]) -> pd.DataFrame:
        """Load data with optimized dtypes and filtering, supporting CSV and JSON input

        Args:
            file_path: path to the data file
            start_date: start date (string or datetime-like)
            stop_date: stop date (string or datetime-like)

        Returns:
            pandas DataFrame with timestamp index. If loading fails, the error
            is logged and an empty DataFrame with a DatetimeIndex is returned.
        """
        return self.data_loader.load_data(file_path, start_date, stop_date)

    def save_data(self, data: pd.DataFrame, file_path: str) -> None:
        """Save processed data to a CSV file

        Args:
            data: DataFrame to save
            file_path: path to the data file relative to the data_dir

        Raises:
            DataSavingError: If saving fails
        """
        self.data_saver.save_data(data, file_path)

    def format_row(self, row: Dict[str, Any]) -> Dict[str, str]:
        """Format a row for a combined results CSV file

        Args:
            row: Dictionary containing row data

        Returns:
            Dictionary with formatted values
        """
        return self.result_formatter.format_row(row)

    def write_results_chunk(self, filename: str, fieldnames: List[str],
                            rows: List[Dict], write_header: bool = False,
                            initial_usd: Optional[float] = None) -> None:
        """Write a chunk of results to a CSV file

        Args:
            filename: filename to write to
            fieldnames: list of fieldnames
            rows: list of rows
            write_header: whether to write the header
            initial_usd: initial USD value for header comment
        """
        self.result_formatter.write_results_chunk(
            filename, fieldnames, rows, write_header, initial_usd
        )

    def write_backtest_results(self, filename: str, fieldnames: List[str],
                               rows: List[Dict], metadata_lines: Optional[List[str]] = None) -> str:
        """Write combined backtest results to a CSV file

        Args:
            filename: filename to write to
            fieldnames: list of fieldnames
            rows: list of result dictionaries
            metadata_lines: optional list of strings to write as header comments

        Returns:
            Full path to the written file
        """
        return self.result_formatter.write_backtest_results(
            filename, fieldnames, rows, metadata_lines
        )

    def write_trades(self, all_trade_rows: List[Dict], trades_fieldnames: List[str]) -> None:
        """Write trades to separate CSV files grouped by timeframe and stop loss

        Args:
            all_trade_rows: list of trade dictionaries
            trades_fieldnames: list of trade fieldnames
        """
        self.result_formatter.write_trades(all_trade_rows, trades_fieldnames)
73 cycles/utils/storage_utils.py Normal file
@@ -0,0 +1,73 @@
import pandas as pd


class TimestampParsingError(Exception):
    """Custom exception for timestamp parsing errors"""
    pass


class DataLoadingError(Exception):
    """Custom exception for data loading errors"""
    pass


class DataSavingError(Exception):
    """Custom exception for data saving errors"""
    pass


def _parse_timestamp_column(data: pd.DataFrame, column_name: str) -> pd.DataFrame:
    """Parse timestamp column handling both Unix timestamps and datetime strings

    Args:
        data: DataFrame containing the timestamp column
        column_name: Name of the timestamp column

    Returns:
        DataFrame with parsed timestamp column

    Raises:
        TimestampParsingError: If timestamp parsing fails
    """
    try:
        sample_timestamp = str(data[column_name].iloc[0])
        try:
            # Check if it's a Unix timestamp (numeric)
            float(sample_timestamp)
            # It's a Unix timestamp, convert using unit='s'
            data[column_name] = pd.to_datetime(data[column_name], unit='s')
        except ValueError:
            # It's already in datetime string format, convert without unit
            data[column_name] = pd.to_datetime(data[column_name])
        return data
    except Exception as e:
        raise TimestampParsingError(f"Failed to parse timestamp column '{column_name}': {e}") from e


def _filter_by_date_range(data: pd.DataFrame, timestamp_col: str,
                          start_date: pd.Timestamp, stop_date: pd.Timestamp) -> pd.DataFrame:
    """Filter DataFrame by date range

    Args:
        data: DataFrame to filter
        timestamp_col: Name of timestamp column
        start_date: Start date for filtering
        stop_date: Stop date for filtering

    Returns:
        Filtered DataFrame
    """
    return data[(data[timestamp_col] >= start_date) & (data[timestamp_col] <= stop_date)]


def _normalize_column_names(data: pd.DataFrame) -> pd.DataFrame:
    """Convert all column names to lowercase

    Args:
        data: DataFrame to normalize

    Returns:
        DataFrame with lowercase column names
    """
    data.columns = data.columns.str.lower()
    return data
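To illustrate the format detection, a small sketch exercising the (module-private) parser with both input styles; the frames are invented for the example:

```python
import pandas as pd

from cycles.utils.storage_utils import _parse_timestamp_column

# Unix-seconds timestamps are detected by the float() probe on the first value
df_unix = pd.DataFrame({"Timestamp": [1704067200, 1704070800]})
print(_parse_timestamp_column(df_unix, "Timestamp")["Timestamp"].iloc[0])
# 2024-01-01 00:00:00

# Datetime strings fall through to plain pd.to_datetime
df_str = pd.DataFrame({"Timestamp": ["2024-01-01 00:00", "2024-01-01 01:00"]})
print(_parse_timestamp_column(df_str, "Timestamp")["Timestamp"].iloc[0])
```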
21 cycles/utils/system.py Normal file
@@ -0,0 +1,21 @@
import os
import psutil


class SystemUtils:
    """Utility helpers for inspecting system resources (CPU count, memory)."""

    def __init__(self, logging=None):
        self.logging = logging

    def get_optimal_workers(self):
        """Determine optimal number of worker processes based on system resources"""
        cpu_count = os.cpu_count() or 4
        memory_gb = psutil.virtual_memory().total / (1024**3)

        # OPTIMIZATION: More aggressive worker allocation for better performance
        workers_by_memory = max(1, int(memory_gb / 2))  # 2GB per worker
        workers_by_cpu = max(1, int(cpu_count * 0.8))  # Use 80% of CPU cores
        optimal_workers = min(workers_by_cpu, workers_by_memory, 8)  # Cap at 8 workers

        if self.logging is not None:
            self.logging.info(f"Using {optimal_workers} workers for processing (CPU-based: {workers_by_cpu}, Memory-based: {workers_by_memory})")
        return optimal_workers
78 docs/analysis.md Normal file
@@ -0,0 +1,78 @@
# Analysis Module

This document provides an overview of the `Analysis` module and its components, which are used for technical analysis of financial market data.

## Modules

The `Analysis` module includes classes for calculating common technical indicators:

- **Relative Strength Index (RSI)**: Implemented in `cycles/Analysis/rsi.py`.
- **Bollinger Bands**: Implemented in `cycles/Analysis/boillinger_band.py`.

## Class: `RSI`

Found in `cycles/Analysis/rsi.py`.

Calculates the Relative Strength Index.

### Mathematical Model

1. **Average Gain (AvgU)** and **Average Loss (AvgD)** over the lookback period $n$ (the configured `period`, 14 by default):

$$
\text{AvgU} = \frac{\sum \text{Upward Price Changes}}{n}, \quad \text{AvgD} = \frac{\sum \text{Downward Price Changes}}{n}
$$

2. **Relative Strength (RS)**:

$$
RS = \frac{\text{AvgU}}{\text{AvgD}}
$$

3. **RSI**:

$$
RSI = 100 - \frac{100}{1 + RS}
$$

### `__init__(self, period: int = 14)`

- **Description**: Initializes the RSI calculator.
- **Parameters**:
  - `period` (int, optional): The period for RSI calculation. Defaults to 14. Must be a positive integer.

### `calculate(self, data_df: pd.DataFrame, price_column: str = 'close') -> pd.DataFrame`

- **Description**: Calculates the RSI and adds it as an 'RSI' column to the input DataFrame. Handles cases where the data length is less than the period by returning the original DataFrame with a warning.
- **Parameters**:
  - `data_df` (pd.DataFrame): DataFrame with historical price data. Must contain the `price_column`.
  - `price_column` (str, optional): The name of the column containing price data. Defaults to 'close'.
- **Returns**: `pd.DataFrame` - The input DataFrame with an added 'RSI' column (containing `np.nan` for the initial periods where RSI cannot be calculated). Returns a copy of the original DataFrame if the period is larger than the number of data points.
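A minimal usage sketch based on the interface described above; the price series is invented for the example:

```python
import pandas as pd

from cycles.Analysis.rsi import RSI

prices = pd.DataFrame({"close": [44.3, 44.1, 44.5, 43.9, 44.8, 45.1,
                                 45.4, 45.2, 46.0, 46.3, 46.1, 46.5,
                                 46.2, 46.7, 46.9, 47.2]})

rsi = RSI(period=14)
with_rsi = rsi.calculate(prices, price_column="close")
print(with_rsi["RSI"].tail())  # NaN for the first periods, then values in 0-100
```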
## Class: `BollingerBands`

Found in `cycles/Analysis/boillinger_band.py`.

### Mathematical Model

1. **Middle Band**: 20-day Simple Moving Average (SMA)

$$
\text{Middle Band} = \frac{1}{20} \sum_{i=1}^{20} \text{Close}_{t-i}
$$

2. **Upper Band**: Middle Band + 2 × 20-day Standard Deviation (σ)

$$
\text{Upper Band} = \text{Middle Band} + 2 \times \sigma_{20}
$$

3. **Lower Band**: Middle Band − 2 × 20-day Standard Deviation (σ)

$$
\text{Lower Band} = \text{Middle Band} - 2 \times \sigma_{20}
$$

(The period of 20 and the multiplier of 2 are defaults; both are configurable via the constructor.)

### `__init__(self, period: int = 20, std_dev_multiplier: float = 2.0)`

- **Description**: Initializes the BollingerBands calculator.
- **Parameters**:
  - `period` (int, optional): The period for the moving average and standard deviation. Defaults to 20. Must be a positive integer.
  - `std_dev_multiplier` (float, optional): The number of standard deviations for the upper and lower bands. Defaults to 2.0. Must be positive.

### `calculate(self, data_df: pd.DataFrame, price_column: str = 'close') -> pd.DataFrame`

- **Description**: Calculates Bollinger Bands and adds 'SMA' (Simple Moving Average), 'UpperBand', and 'LowerBand' columns to the DataFrame.
- **Parameters**:
  - `data_df` (pd.DataFrame): DataFrame with price data. Must include the `price_column`.
  - `price_column` (str, optional): The name of the column containing the price data (e.g., 'close'). Defaults to 'close'.
- **Returns**: `pd.DataFrame` - The original DataFrame with added columns: 'SMA', 'UpperBand', 'LowerBand'.
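A minimal usage sketch based on the interface described above; the price ramp is invented for the example:

```python
import pandas as pd

from cycles.Analysis.boillinger_band import BollingerBands

close = pd.Series(range(100, 140)).astype(float)
df = pd.DataFrame({"close": close})

bb = BollingerBands(period=20, std_dev_multiplier=2.0)
with_bands = bb.calculate(df, price_column="close")
print(with_bands[["close", "SMA", "UpperBand", "LowerBand"]].tail())
```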
207 docs/utils_storage.md Normal file
@@ -0,0 +1,207 @@
# Storage Utilities

This document describes the refactored storage utilities found in `cycles/utils/` that provide modular, maintainable data and results management.

## Overview

The storage utilities have been refactored into a modular architecture with clear separation of concerns:

- **`Storage`** - Main coordinator class providing a unified interface (backward compatible)
- **`DataLoader`** - Handles loading data from various file formats
- **`DataSaver`** - Manages saving data with proper format handling
- **`ResultFormatter`** - Formats and writes backtest results to CSV files
- **`storage_utils`** - Shared utilities and custom exceptions

This design improves maintainability and testability, and follows the single responsibility principle.

## Constants

- `RESULTS_DIR`: Default directory for storing results (default: "../results")
- `DATA_DIR`: Default directory for storing input data (default: "../data")

## Main Classes

### `Storage` (Coordinator Class)

The main interface that coordinates all storage operations while maintaining backward compatibility.

#### `__init__(self, logging=None, results_dir=RESULTS_DIR, data_dir=DATA_DIR)`

**Description**: Initializes the Storage coordinator with component instances.

**Parameters**:
- `logging` (optional): A logging instance for outputting information
- `results_dir` (str, optional): Path to the directory for storing results
- `data_dir` (str, optional): Path to the directory for storing data

**Creates**: Component instances for DataLoader, DataSaver, and ResultFormatter

#### `load_data(self, file_path: str, start_date: Union[str, pd.Timestamp], stop_date: Union[str, pd.Timestamp]) -> pd.DataFrame`

**Description**: Loads data with optimized dtypes and filtering, supporting CSV and JSON input.

**Parameters**:
- `file_path` (str): Path to the data file (relative to `data_dir`)
- `start_date` (datetime-like): The start date for filtering data
- `stop_date` (datetime-like): The end date for filtering data

**Returns**: `pandas.DataFrame` with timestamp index

**Error handling**: if loading fails, the error is logged and an empty DataFrame with a DatetimeIndex is returned

#### `save_data(self, data: pd.DataFrame, file_path: str) -> None`

**Description**: Saves processed data to a CSV file with proper timestamp handling.

**Parameters**:
- `data` (pd.DataFrame): The DataFrame to save
- `file_path` (str): Path to the data file (relative to `data_dir`)

**Raises**: `DataSavingError` if saving fails

#### `format_row(self, row: Dict[str, Any]) -> Dict[str, str]`

**Description**: Formats a dictionary row for output to results CSV files.

**Parameters**:
- `row` (dict): The row of data to format

**Returns**: `dict` with formatted values (percentages, currency, etc.)

#### `write_results_chunk(self, filename: str, fieldnames: List[str], rows: List[Dict], write_header: bool = False, initial_usd: Optional[float] = None) -> None`

**Description**: Writes a chunk of results to a CSV file with an optional header.

**Parameters**:
- `filename` (str): The name of the file to write to
- `fieldnames` (list): CSV header/column names
- `rows` (list): List of dictionaries representing rows
- `write_header` (bool, optional): Whether to write the header
- `initial_usd` (float, optional): Initial USD value for header comment

#### `write_backtest_results(self, filename: str, fieldnames: List[str], rows: List[Dict], metadata_lines: Optional[List[str]] = None) -> str`

**Description**: Writes combined backtest results to a CSV file with metadata.

**Parameters**:
- `filename` (str): Name of the file to write to (relative to `results_dir`)
- `fieldnames` (list): CSV header/column names
- `rows` (list): List of result dictionaries
- `metadata_lines` (list, optional): Header comment lines

**Returns**: Full path to the written file

#### `write_trades(self, all_trade_rows: List[Dict], trades_fieldnames: List[str]) -> None`

**Description**: Writes trade data to separate CSV files grouped by timeframe and stop-loss; a combined usage sketch follows this section.

**Parameters**:
- `all_trade_rows` (list): List of trade dictionaries
- `trades_fieldnames` (list): CSV header for trade files

**Files Created**: `trades_{timeframe}_ST{sl_percent}pct.csv` in `results_dir`
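A minimal sketch tying the result-writing methods together; all metric values and the output file name are invented for the example:

```python
from cycles.utils.storage import Storage

storage = Storage(results_dir="../results", data_dir="../data")

# One result row per (timeframe, stop-loss) combination; values invented
rows = [{
    "timeframe": "1h", "stop_loss_pct": 0.05, "n_trades": 42,
    "n_stop_loss": 7, "win_rate": 0.571, "max_drawdown": 0.123,
    "avg_trade": 0.0042, "profit_ratio": 0.18,
    "final_usd": 11234.5, "total_fees_usd": 87.6,
}]
fieldnames = list(rows[0].keys())

path = storage.write_backtest_results(
    "combined_results.csv", fieldnames, rows,
    metadata_lines=["# initial_usd: 10000"],
)
print(f"written to {path}")
```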
### `DataLoader`
|
||||||
|
|
||||||
|
Handles loading and preprocessing of data from various file formats.
|
||||||
|
|
||||||
|
#### Key Features:
|
||||||
|
- Supports CSV and JSON formats
|
||||||
|
- Optimized pandas dtypes for financial data
|
||||||
|
- Intelligent timestamp parsing (Unix timestamps and datetime strings)
|
||||||
|
- Date range filtering
|
||||||
|
- Column name normalization (lowercase)
|
||||||
|
- Comprehensive error handling
|
||||||
|
|
||||||
|
#### Methods:
|
||||||
|
- `load_data()` - Main loading interface
|
||||||
|
- `_load_json_data()` - JSON-specific loading logic
|
||||||
|
- `_load_csv_data()` - CSV-specific loading logic
|
||||||
|
- `_process_csv_timestamps()` - Timestamp parsing for CSV data
|
||||||
|
|
||||||
|
### `DataSaver`
|
||||||
|
|
||||||
|
Manages saving data with proper format handling and index conversion.
|
||||||
|
|
||||||
|
#### Key Features:
|
||||||
|
- Converts DatetimeIndex to Unix timestamps for CSV compatibility
|
||||||
|
- Handles numeric indexes appropriately
|
||||||
|
- Ensures 'timestamp' column is first in output
|
||||||
|
- Comprehensive error handling and logging
|
||||||
|
|
||||||
|
#### Methods:
|
||||||
|
- `save_data()` - Main saving interface
|
||||||
|
- `_prepare_data_for_saving()` - Data preparation logic
|
||||||
|
- `_convert_datetime_index_to_timestamp()` - DatetimeIndex conversion
|
||||||
|
- `_convert_numeric_index_to_timestamp()` - Numeric index conversion
|
||||||
|
|
||||||
|
### `ResultFormatter`

Handles formatting and writing of backtest results to CSV files.

#### Key Features:

- Consistent formatting for percentages and currency
- Grouped trade file writing by timeframe/stop-loss
- Metadata header support
- Tab-delimited output for results
- Error handling for all write operations

#### Methods:

- `format_row()` - Format individual result rows
- `write_results_chunk()` - Write result chunks with headers
- `write_backtest_results()` - Write combined results with metadata
- `write_trades()` - Write grouped trade files
## Utility Functions and Exceptions

### Custom Exceptions

- **`TimestampParsingError`** - Raised when timestamp parsing fails
- **`DataLoadingError`** - Raised when data loading operations fail
- **`DataSavingError`** - Raised when data saving operations fail

### Utility Functions

- **`_parse_timestamp_column()`** - Parse timestamp columns with format detection
- **`_filter_by_date_range()`** - Filter DataFrames by date range
- **`_normalize_column_names()`** - Convert column names to lowercase
## Architecture Benefits

### Separation of Concerns

- Each class has a single, well-defined responsibility
- Data loading, saving, and result formatting are cleanly separated
- Shared utilities are extracted to prevent code duplication

### Maintainability

- All files are under 250 lines (quality gate)
- All methods are under 50 lines (quality gate)
- Clear interfaces and comprehensive documentation
- Type hints for better IDE support and clarity

### Error Handling

- Custom exceptions for different error types
- Consistent error logging patterns
- Graceful degradation (empty DataFrames on load failure)

### Backward Compatibility

- The `Storage` class maintains the exact same public interface
- All existing code continues to work unchanged
- Component classes are available for advanced usage
## Migration Notes

The refactoring maintains full backward compatibility. Existing code using `Storage` will continue to work unchanged. For new code, consider using the component classes directly for more focused functionality:

```python
# Existing pattern (still works)
from cycles.utils.storage import Storage

storage = Storage(logging=logger)
data = storage.load_data('file.csv', start, end)

# New pattern for focused usage
from cycles.utils.data_loader import DataLoader

loader = DataLoader(data_dir, logger)
data = loader.load_data('file.csv', start, end)
```
49
docs/utils_system.md
Normal file
@@ -0,0 +1,49 @@
# System Utilities

This document describes the system utility functions found in `cycles/utils/system.py`.

## Overview

The `system.py` module provides utility functions related to system information and resource management. It currently includes a class `SystemUtils` for determining optimal configurations based on system resources.

## Classes and Methods

### `SystemUtils`

A class to provide system-related utility methods.

#### `__init__(self, logging=None)`

- **Description**: Initializes the `SystemUtils` class.
- **Parameters**:
  - `logging` (optional): A logging instance to output information. Defaults to `None`.

#### `get_optimal_workers(self)`

- **Description**: Determines the optimal number of worker processes based on available CPU cores and memory. The heuristic aims to use 75% of the CPU cores, capped by available memory (assuming each worker may need ~2 GB for large datasets); it returns the smaller of the CPU-based and memory-based worker counts.
- **Parameters**: None.
- **Returns**: `int` - The recommended number of worker processes.
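The implementation itself is not shown in this document; a minimal sketch of the described heuristic could look like this (assumes `psutil`, which is listed in `pyproject.toml`):

```python
import os
import psutil

def get_optimal_workers_sketch() -> int:
    """Sketch of the documented heuristic, not the actual implementation."""
    cpu_workers = max(1, int((os.cpu_count() or 1) * 0.75))      # 75% of CPU cores
    available_gb = psutil.virtual_memory().available / (1024 ** 3)
    memory_workers = max(1, int(available_gb // 2))              # ~2 GB per worker
    return min(cpu_workers, memory_workers)
```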
## Usage Examples

```python
from cycles.utils.system import SystemUtils

# Initialize (optionally with a logger)
# import logging
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)
# sys_utils = SystemUtils(logging=logger)
sys_utils = SystemUtils()

optimal_workers = sys_utils.get_optimal_workers()
print(f"Optimal number of workers: {optimal_workers}")

# This value can then be used, for example, when setting up a ThreadPoolExecutor
# from concurrent.futures import ThreadPoolExecutor
# with ThreadPoolExecutor(max_workers=optimal_workers) as executor:
#     # ... submit tasks ...
#     pass
```
211
main.py
@@ -1,56 +1,175 @@
-import pandas as pd
-from trend_detector_macd import TrendDetectorMACD
-from trend_detector_simple import TrendDetectorSimple
-from cycle_detector import CycleDetector
-
-# Load data from CSV file instead of database
-data = pd.read_csv('data/btcusd_1-day_data.csv')
-
-# Convert datetime column to datetime type
-start_date = pd.to_datetime('2025-04-01')
-stop_date = pd.to_datetime('2025-05-06')
-
-daily_data = data[(pd.to_datetime(data['datetime']) >= start_date) &
-                  (pd.to_datetime(data['datetime']) < stop_date)]
-print(f"Number of data points: {len(daily_data)}")
-
-trend_detector = TrendDetectorSimple(daily_data, verbose=True)
-trends, analysis_results = trend_detector.detect_trends()
-trend_detector.plot_trends(trends, analysis_results)
-
-#trend_detector = TrendDetectorMACD(daily_data, True)
-#trends = trend_detector.detect_trends_MACD_signal()
-#trend_detector.plot_trends(trends)
-
-# # Cycle detection (new code)
-# print("\n===== CYCLE DETECTION =====")
-
-# # Daily cycles
-# daily_detector = CycleDetector(daily_data, timeframe='daily')
-# daily_avg_cycle = daily_detector.get_average_cycle_length()
-# daily_window = daily_detector.get_cycle_window()
-
-# print(f"Daily Timeframe: Average Cycle Length = {daily_avg_cycle:.2f} days")
-# if daily_window:
-#     print(f"Daily Cycle Window: {daily_window[0]:.2f} to {daily_window[2]:.2f} days")
-
-# weekly_detector = CycleDetector(daily_data, timeframe='weekly')
-# weekly_avg_cycle = weekly_detector.get_average_cycle_length()
-# weekly_window = weekly_detector.get_cycle_window()
-
-# print(f"\nWeekly Timeframe: Average Cycle Length = {weekly_avg_cycle:.2f} weeks")
-# if weekly_window:
-#     print(f"Weekly Cycle Window: {weekly_window[0]:.2f} to {weekly_window[2]:.2f} weeks")
-
-# # Detect patterns
-# two_drives = daily_detector.detect_two_drives_pattern()
-# v_shapes = daily_detector.detect_v_shaped_lows()
-
-# print(f"\nDetected {len(two_drives)} 'Two Drives' patterns in daily data")
-# print(f"Detected {len(v_shapes)} 'V-Shaped' lows in daily data")
-
-# # Plot cycles with detected patterns
-# print("\nPlotting cycles and patterns...")
-# daily_detector.plot_cycles(pattern_detection=['two_drives', 'v_shape'], title_suffix='(with Two Drives Pattern)')
-# weekly_detector.plot_cycles(title_suffix='(Weekly View)')
+#!/usr/bin/env python3
+"""
+Backtest execution script for cryptocurrency trading strategies
+Refactored for improved maintainability and error handling
+"""
+
+import logging
+import datetime
+import argparse
+import sys
+from pathlib import Path
+
+# Import custom modules
+from config_manager import ConfigManager
+from backtest_runner import BacktestRunner
+from result_processor import ResultProcessor
+from cycles.utils.storage import Storage
+from cycles.utils.system import SystemUtils
+
+
+def setup_logging() -> logging.Logger:
+    """Configure and return logging instance"""
+    logger = logging.getLogger(__name__)
+
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s [%(levelname)s] %(message)s",
+        handlers=[
+            logging.FileHandler("backtest.log"),
+            logging.StreamHandler()
+        ]
+    )
+
+    return logger
+
+
+def create_metadata_lines(config: dict, data_df, result_processor: ResultProcessor) -> list:
+    """Create metadata lines for results file"""
+    start_date = config['start_date']
+    stop_date = config['stop_date']
+    initial_usd = config['initial_usd']
+
+    # Get price information
+    start_time, start_price = result_processor.get_price_info(data_df, start_date)
+    stop_time, stop_price = result_processor.get_price_info(data_df, stop_date)
+
+    metadata_lines = [
+        f"Start date\t{start_date}\tPrice\t{start_price or 'N/A'}",
+        f"Stop date\t{stop_date}\tPrice\t{stop_price or 'N/A'}",
+        f"Initial USD\t{initial_usd}"
+    ]
+
+    return metadata_lines
+
+
+def main():
+    """Main execution function"""
+    logger = setup_logging()
+
+    try:
+        # Parse command line arguments
+        parser = argparse.ArgumentParser(description="Run backtest with config file.")
+        parser.add_argument("config", type=str, nargs="?", help="Path to config JSON file.")
+        args = parser.parse_args()
+
+        # Initialize configuration manager
+        config_manager = ConfigManager(logging_instance=logger)
+
+        # Load configuration
+        logger.info("Loading configuration...")
+        config = config_manager.load_config(args.config)
+
+        # Initialize components
+        logger.info("Initializing components...")
+        storage = Storage(
+            data_dir=config['data_dir'],
+            results_dir=config['results_dir'],
+            logging=logger
+        )
+        system_utils = SystemUtils(logging=logger)
+        result_processor = ResultProcessor(storage, logging_instance=logger)
+
+        # OPTIMIZATION: Disable progress for parallel execution to improve performance
+        show_progress = config.get('show_progress', True)
+        debug_mode = config.get('debug', 0) == 1
+
+        # Only show progress in debug (sequential) mode
+        if not debug_mode:
+            show_progress = False
+            logger.info("Progress tracking disabled for parallel execution (performance optimization)")
+
+        runner = BacktestRunner(
+            storage,
+            system_utils,
+            result_processor,
+            logging_instance=logger,
+            show_progress=show_progress
+        )
+
+        # Validate inputs
+        logger.info("Validating inputs...")
+        runner.validate_inputs(
+            config['timeframes'],
+            config['stop_loss_pcts'],
+            config['initial_usd']
+        )
+
+        # Load data
+        logger.info("Loading market data...")
+        # data_filename = 'btcusd_1-min_data.csv'
+        data_filename = 'btcusd_1-min_data_with_price_predictions.csv'
+        data_1min = runner.load_data(
+            data_filename,
+            config['start_date'],
+            config['stop_date']
+        )
+
+        # Run backtests
+        logger.info("Starting backtest execution...")
+        all_results, all_trades = runner.run_backtests(
+            data_1min,
+            config['timeframes'],
+            config['stop_loss_pcts'],
+            config['initial_usd'],
+            debug=debug_mode
+        )
+
+        # Process and save results
+        logger.info("Processing and saving results...")
+        timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M")
+
+        # OPTIMIZATION: Save trade files in batch after parallel execution
+        if all_trades and not debug_mode:
+            logger.info("Saving trade files in batch...")
+            result_processor.save_all_trade_files(all_trades)
+
+        # Create metadata
+        metadata_lines = create_metadata_lines(config, data_1min, result_processor)
+
+        # Save aggregated results
+        result_file = result_processor.save_backtest_results(
+            all_results,
+            metadata_lines,
+            timestamp
+        )
+
+        logger.info(f"Backtest completed successfully. Results saved to {result_file}")
+        logger.info(f"Processed {len(all_results)} result combinations")
+        logger.info(f"Generated {len(all_trades)} total trades")
+
+    except KeyboardInterrupt:
+        logger.warning("Backtest interrupted by user")
+        sys.exit(130)  # Standard exit code for Ctrl+C
+
+    except FileNotFoundError as e:
+        logger.error(f"File not found: {e}")
+        sys.exit(1)
+
+    except ValueError as e:
+        logger.error(f"Invalid configuration or data: {e}")
+        sys.exit(1)
+
+    except RuntimeError as e:
+        logger.error(f"Runtime error during backtest: {e}")
+        sys.exit(1)
+
+    except Exception as e:
+        logger.error(f"Unexpected error: {e}", exc_info=True)
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
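The config schema itself is not part of this diff; judging from the keys `main()` reads above, a config file needs at least the following fields (a sketch with made-up values, shown as a Python dict mirroring the JSON):

```python
# Illustrative only: keys taken from the accesses in main(), values invented.
example_config = {
    "data_dir": "data",
    "results_dir": "results",
    "start_date": "2024-01-01",
    "stop_date": "2024-06-01",
    "initial_usd": 10000,
    "timeframes": ["1D", "6h"],
    "stop_loss_pcts": [0.02, 0.05],
    "show_progress": True,   # optional, defaults to True
    "debug": 0,              # optional, 1 = sequential debug mode
}
```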
19
pyproject.toml
Normal file
@@ -0,0 +1,19 @@
[project]
name = "cycles"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "dash>=3.0.4",
    "gspread>=6.2.1",
    "matplotlib>=3.10.3",
    "numba>=0.61.2",
    "pandas>=2.2.3",
    "psutil>=7.0.0",
    "scikit-learn>=1.6.1",
    "scipy>=1.15.3",
    "seaborn>=0.13.2",
    "ta>=0.11.0",
    "xgboost>=3.0.2",
]
BIN
requirements.txt
Binary file not shown.
446
result_processor.py
Normal file
@@ -0,0 +1,446 @@
import pandas as pd
import numpy as np
import os
import csv
import logging
from typing import List, Dict, Any, Optional, Tuple
from collections import defaultdict

from cycles.utils.storage import Storage


class ResultProcessor:
    """Handles processing, aggregation, and saving of backtest results"""

    def __init__(self, storage: Storage, logging_instance: Optional[logging.Logger] = None):
        """
        Initialize result processor

        Args:
            storage: Storage instance for file operations
            logging_instance: Optional logging instance
        """
        self.storage = storage
        self.logging = logging_instance

    def process_timeframe_results(
        self,
        min1_df: pd.DataFrame,
        df: pd.DataFrame,
        stop_loss_pcts: List[float],
        timeframe_name: str,
        initial_usd: float,
        progress_callback=None
    ) -> Tuple[List[Dict], List[Dict]]:
        """
        Process results for a single timeframe with multiple stop loss values

        Args:
            min1_df: 1-minute data DataFrame
            df: Resampled timeframe DataFrame
            stop_loss_pcts: List of stop loss percentages to test
            timeframe_name: Name of the timeframe (e.g., '1D', '6h')
            initial_usd: Initial USD amount
            progress_callback: Optional progress callback function

        Returns:
            Tuple of (results_rows, trade_rows)
        """
        from cycles.backtest import Backtest

        df = df.copy().reset_index(drop=True)
        results_rows = []
        trade_rows = []

        for stop_loss_pct in stop_loss_pcts:
            try:
                results = Backtest.run(
                    min1_df,
                    df,
                    initial_usd=initial_usd,
                    stop_loss_pct=stop_loss_pct,
                    progress_callback=progress_callback,
                    verbose=False  # Default to False for production runs
                )

                # Calculate metrics
                metrics = self._calculate_metrics(results, initial_usd, stop_loss_pct, timeframe_name)
                results_rows.append(metrics)

                # Process trades
                if 'trades' not in results:
                    raise ValueError(f"Backtest results missing 'trades' field for {timeframe_name} with {stop_loss_pct} stop loss")
                trades = self._process_trades(results['trades'], timeframe_name, stop_loss_pct)
                trade_rows.extend(trades)

                if self.logging:
                    self.logging.info(f"Timeframe: {timeframe_name}, Stop Loss: {stop_loss_pct}, Trades: {results['n_trades']}")

            except Exception as e:
                error_msg = f"Error processing {timeframe_name} with stop loss {stop_loss_pct}: {e}"
                if self.logging:
                    self.logging.error(error_msg)
                raise RuntimeError(error_msg) from e

        return results_rows, trade_rows

    def _calculate_metrics(
        self,
        results: Dict[str, Any],
        initial_usd: float,
        stop_loss_pct: float,
        timeframe_name: str
    ) -> Dict[str, Any]:
        """Calculate performance metrics from backtest results"""
        if 'trades' not in results:
            raise ValueError(f"Backtest results missing 'trades' field for {timeframe_name} with {stop_loss_pct} stop loss")
        trades = results['trades']
        n_trades = results["n_trades"]

        # Validate that all required fields are present
        required_fields = ['final_usd', 'max_drawdown', 'total_fees_usd', 'n_trades', 'n_stop_loss', 'win_rate', 'avg_trade']
        missing_fields = [field for field in required_fields if field not in results]
        if missing_fields:
            raise ValueError(f"Backtest results missing required fields: {missing_fields}")

        # Calculate win metrics - validate trade fields
        winning_trades = []
        for t in trades:
            if 'exit' not in t:
                raise ValueError(f"Trade missing 'exit' field: {t}")
            if 'entry' not in t:
                raise ValueError(f"Trade missing 'entry' field: {t}")
            if t['exit'] is not None and t['exit'] > t['entry']:
                winning_trades.append(t)
        n_winning_trades = len(winning_trades)
        win_rate = n_winning_trades / n_trades if n_trades > 0 else 0

        # Calculate profit metrics
        total_profit = sum(trade['profit_pct'] for trade in trades if trade['profit_pct'] > 0)
        total_loss = abs(sum(trade['profit_pct'] for trade in trades if trade['profit_pct'] < 0))
        avg_trade = sum(trade['profit_pct'] for trade in trades) / n_trades if n_trades > 0 else 0
        profit_ratio = total_profit / total_loss if total_loss > 0 else (float('inf') if total_profit > 0 else 0)

        # Get values directly from backtest results (no defaults)
        max_drawdown = results['max_drawdown']
        final_usd = results['final_usd']
        total_fees_usd = results['total_fees_usd']
        n_stop_loss = results['n_stop_loss']  # Get stop loss count directly from backtest

        # Validate no None values
        if max_drawdown is None:
            raise ValueError(f"max_drawdown is None for {timeframe_name} with {stop_loss_pct} stop loss")
        if final_usd is None:
            raise ValueError(f"final_usd is None for {timeframe_name} with {stop_loss_pct} stop loss")
        if total_fees_usd is None:
            raise ValueError(f"total_fees_usd is None for {timeframe_name} with {stop_loss_pct} stop loss")
        if n_stop_loss is None:
            raise ValueError(f"n_stop_loss is None for {timeframe_name} with {stop_loss_pct} stop loss")

        return {
            "timeframe": timeframe_name,
            "stop_loss_pct": stop_loss_pct,
            "n_trades": n_trades,
            "n_stop_loss": n_stop_loss,
            "win_rate": win_rate,
            "max_drawdown": max_drawdown,
            "avg_trade": avg_trade,
            "total_profit": total_profit,
            "total_loss": total_loss,
            "profit_ratio": profit_ratio,
            "initial_usd": initial_usd,
            "final_usd": final_usd,
            "total_fees_usd": total_fees_usd,
        }
    def _calculate_max_drawdown(self, trades: List[Dict]) -> float:
        """Calculate maximum drawdown from trade sequence"""
        cumulative_profit = 0
        max_drawdown = 0
        peak = 0

        for trade in trades:
            cumulative_profit += trade['profit_pct']
            if cumulative_profit > peak:
                peak = cumulative_profit
            drawdown = peak - cumulative_profit
            if drawdown > max_drawdown:
                max_drawdown = drawdown

        return max_drawdown

    def _process_trades(
        self,
        trades: List[Dict],
        timeframe_name: str,
        stop_loss_pct: float
    ) -> List[Dict]:
        """Process individual trades with metadata"""
        processed_trades = []

        for trade in trades:
            # Validate all required trade fields
            required_fields = ["entry_time", "exit_time", "entry", "exit", "profit_pct", "type", "fee_usd"]
            missing_fields = [field for field in required_fields if field not in trade]
            if missing_fields:
                raise ValueError(f"Trade missing required fields: {missing_fields} in trade: {trade}")

            processed_trade = {
                "timeframe": timeframe_name,
                "stop_loss_pct": stop_loss_pct,
                "entry_time": trade["entry_time"],
                "exit_time": trade["exit_time"],
                "entry_price": trade["entry"],
                "exit_price": trade["exit"],
                "profit_pct": trade["profit_pct"],
                "type": trade["type"],
                "fee_usd": trade["fee_usd"],
            }
            processed_trades.append(processed_trade)

        return processed_trades

    def _debug_output(self, results: Dict[str, Any]) -> None:
        """Output debug information for backtest results"""
        if 'trades' not in results:
            raise ValueError("Backtest results missing 'trades' field for debug output")
        trades = results['trades']

        # Print stop loss trades
        stop_loss_trades = []
        for t in trades:
            if 'type' not in t:
                raise ValueError(f"Trade missing 'type' field: {t}")
            if t['type'] == 'STOP':
                stop_loss_trades.append(t)

        if stop_loss_trades:
            print("Stop Loss Trades:")
            for trade in stop_loss_trades:
                print(trade)

        # Print large loss trades
        large_loss_trades = []
        for t in trades:
            if 'profit_pct' not in t:
                raise ValueError(f"Trade missing 'profit_pct' field: {t}")
            if t['profit_pct'] < -0.09:
                large_loss_trades.append(t)

        if large_loss_trades:
            print("Large Loss Trades:")
            for trade in large_loss_trades:
                print("Large loss trade:", trade)

    def aggregate_results(self, all_results: List[Dict]) -> List[Dict]:
        """
        Aggregate results per stop_loss_pct and timeframe

        Args:
            all_results: List of result dictionaries from all timeframes

        Returns:
            List of aggregated summary rows
        """
        grouped = defaultdict(list)
        for row in all_results:
            key = (row['timeframe'], row['stop_loss_pct'])
            grouped[key].append(row)

        summary_rows = []
        for (timeframe, stop_loss_pct), rows in grouped.items():
            summary = self._aggregate_group(rows, timeframe, stop_loss_pct)
            summary_rows.append(summary)

        return summary_rows

    def _aggregate_group(self, rows: List[Dict], timeframe: str, stop_loss_pct: float) -> Dict:
        """Aggregate a group of rows with the same timeframe and stop loss"""
        if not rows:
            raise ValueError(f"No rows to aggregate for {timeframe} with {stop_loss_pct} stop loss")

        # Validate all rows have required fields
        required_fields = ['n_trades', 'n_stop_loss', 'win_rate', 'max_drawdown', 'avg_trade', 'profit_ratio', 'final_usd', 'total_fees_usd', 'initial_usd']
        for i, row in enumerate(rows):
            missing_fields = [field for field in required_fields if field not in row]
            if missing_fields:
                raise ValueError(f"Row {i} missing required fields: {missing_fields}")

        total_trades = sum(r['n_trades'] for r in rows)
        total_stop_loss = sum(r['n_stop_loss'] for r in rows)

        # Calculate averages (no defaults, expect all values to be present)
        avg_win_rate = np.mean([r['win_rate'] for r in rows])
        avg_max_drawdown = np.mean([r['max_drawdown'] for r in rows])
        avg_avg_trade = np.mean([r['avg_trade'] for r in rows])

        # Handle infinite profit ratios properly
        finite_profit_ratios = [r['profit_ratio'] for r in rows if not np.isinf(r['profit_ratio'])]
        avg_profit_ratio = np.mean(finite_profit_ratios) if finite_profit_ratios else 0

        # Calculate final USD and fees (no defaults)
        final_usd = np.mean([r['final_usd'] for r in rows])
        total_fees_usd = np.mean([r['total_fees_usd'] for r in rows])
        initial_usd = rows[0]['initial_usd']

        return {
            "timeframe": timeframe,
            "stop_loss_pct": stop_loss_pct,
            "n_trades": total_trades,
            "n_stop_loss": total_stop_loss,
            "win_rate": avg_win_rate,
            "max_drawdown": avg_max_drawdown,
            "avg_trade": avg_avg_trade,
            "profit_ratio": avg_profit_ratio,
            "initial_usd": initial_usd,
            "final_usd": final_usd,
            "total_fees_usd": total_fees_usd,
        }
    def save_trade_file(self, trades: List[Dict], timeframe: str, stop_loss_pct: float) -> None:
        """
        Save individual trade file with summary header

        Args:
            trades: List of trades for this combination
            timeframe: Timeframe name
            stop_loss_pct: Stop loss percentage
        """
        if not trades:
            return

        try:
            # Generate filename
            sl_percent = int(round(stop_loss_pct * 100))
            trades_filename = os.path.join(self.storage.results_dir, f"trades_{timeframe}_ST{sl_percent}pct.csv")

            # Prepare summary from first trade
            sample_trade = trades[0]
            summary_fields = ["timeframe", "stop_loss_pct", "n_trades", "win_rate"]
            summary_values = [timeframe, stop_loss_pct, len(trades), "calculated_elsewhere"]

            # Write file with header and trades
            trades_fieldnames = ["entry_time", "exit_time", "entry_price", "exit_price", "profit_pct", "type", "fee_usd"]

            with open(trades_filename, "w", newline="") as f:
                # Write summary header
                f.write("\t".join(summary_fields) + "\n")
                f.write("\t".join(str(v) for v in summary_values) + "\n")

                # Write trades
                writer = csv.DictWriter(f, fieldnames=trades_fieldnames)
                writer.writeheader()
                for trade in trades:
                    # Validate all required fields are present
                    missing_fields = [k for k in trades_fieldnames if k not in trade]
                    if missing_fields:
                        raise ValueError(f"Trade missing required fields for CSV: {missing_fields} in trade: {trade}")
                    writer.writerow({k: trade[k] for k in trades_fieldnames})

            if self.logging:
                self.logging.info(f"Trades saved to {trades_filename}")

        except Exception as e:
            error_msg = f"Failed to save trades file for {timeframe}_ST{int(round(stop_loss_pct * 100))}pct: {e}"
            if self.logging:
                self.logging.error(error_msg)
            raise RuntimeError(error_msg) from e

    def save_backtest_results(
        self,
        results: List[Dict],
        metadata_lines: List[str],
        timestamp: str
    ) -> str:
        """
        Save aggregated backtest results to CSV file

        Args:
            results: List of aggregated result dictionaries
            metadata_lines: List of metadata strings
            timestamp: Timestamp for filename

        Returns:
            Path to saved file
        """
        try:
            filename = f"{timestamp}_backtest.csv"
            fieldnames = [
                "timeframe", "stop_loss_pct", "n_trades", "n_stop_loss", "win_rate",
                "max_drawdown", "avg_trade", "profit_ratio", "final_usd", "total_fees_usd"
            ]

            filepath = self.storage.write_backtest_results(filename, fieldnames, results, metadata_lines)

            if self.logging:
                self.logging.info(f"Backtest results saved to {filepath}")

            return filepath

        except Exception as e:
            error_msg = f"Failed to save backtest results: {e}"
            if self.logging:
                self.logging.error(error_msg)
            raise RuntimeError(error_msg) from e

    def get_price_info(self, data_df: pd.DataFrame, date: str) -> Tuple[Optional[str], Optional[float]]:
        """
        Get nearest price information for a given date

        Args:
            data_df: DataFrame with price data
            date: Target date string

        Returns:
            Tuple of (nearest_time, price) or (None, None) if no data
        """
        try:
            if len(data_df) == 0:
                return None, None

            target_ts = pd.to_datetime(date)
            nearest_idx = data_df.index.get_indexer([target_ts], method='nearest')[0]
            nearest_time = data_df.index[nearest_idx]
            price = data_df.iloc[nearest_idx]['close']

            return str(nearest_time), float(price)

        except Exception as e:
            if self.logging:
                self.logging.warning(f"Could not get price info for {date}: {e}")
            return None, None

    def save_all_trade_files(self, all_trades: List[Dict]) -> None:
        """
        Save all trade files in batch after parallel execution completes

        Args:
            all_trades: List of all trades from all tasks
        """
        if not all_trades:
            return

        try:
            # Group trades by timeframe and stop loss
            trade_groups = {}
            for trade in all_trades:
                timeframe = trade.get('timeframe')
                stop_loss_pct = trade.get('stop_loss_pct')
                if timeframe and stop_loss_pct is not None:
                    key = (timeframe, stop_loss_pct)
                    if key not in trade_groups:
                        trade_groups[key] = []
                    trade_groups[key].append(trade)

            # Save each group
            for (timeframe, stop_loss_pct), trades in trade_groups.items():
                self.save_trade_file(trades, timeframe, stop_loss_pct)

            if self.logging:
                self.logging.info(f"Saved {len(trade_groups)} trade files in batch")

        except Exception as e:
            error_msg = f"Failed to save trade files in batch: {e}"
            if self.logging:
                self.logging.error(error_msg)
            raise RuntimeError(error_msg) from e
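For a sense of how `aggregate_results` behaves, a hedged, standalone sketch with made-up rows (two runs of the same `1D`/5% combination collapse into one summary row):

```python
# Sketch only: rows are invented; storage is unused by aggregate_results.
rows = [
    {"timeframe": "1D", "stop_loss_pct": 0.05, "n_trades": 10, "n_stop_loss": 2,
     "win_rate": 0.6, "max_drawdown": 0.1, "avg_trade": 0.01, "profit_ratio": 1.5,
     "final_usd": 11000.0, "total_fees_usd": 12.0, "initial_usd": 10000.0},
    {"timeframe": "1D", "stop_loss_pct": 0.05, "n_trades": 8, "n_stop_loss": 1,
     "win_rate": 0.5, "max_drawdown": 0.2, "avg_trade": 0.02, "profit_ratio": 2.5,
     "final_usd": 10500.0, "total_fees_usd": 9.0, "initial_usd": 10000.0},
]
processor = ResultProcessor(storage=None)
summary = processor.aggregate_results(rows)
# -> [{"timeframe": "1D", "stop_loss_pct": 0.05, "n_trades": 18, "n_stop_loss": 3,
#      "win_rate": 0.55, "max_drawdown": 0.15, ...}]
```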
132
test_bbrsi.py
Normal file
@@ -0,0 +1,132 @@
import logging
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

from cycles.utils.storage import Storage
from cycles.utils.data_utils import aggregate_to_daily
from cycles.Analysis.boillinger_band import BollingerBands
from cycles.Analysis.rsi import RSI

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler("backtest.log"),
        logging.StreamHandler()
    ]
)

config_minute = {
    "start_date": "2022-01-01",
    "stop_date": "2023-01-01",
    "data_file": "btcusd_1-min_data.csv"
}

config_day = {
    "start_date": "2022-01-01",
    "stop_date": "2023-01-01",
    "data_file": "btcusd_1-day_data.csv"
}

IS_DAY = True


def no_strategy(data_bb, data_with_rsi):
    buy_condition = pd.Series([False] * len(data_bb), index=data_bb.index)
    sell_condition = pd.Series([False] * len(data_bb), index=data_bb.index)
    return buy_condition, sell_condition


def strategy_1(data_bb, data_with_rsi):
    # Long trade: price moves below the lower Bollinger band and RSI drops below 25
    buy_condition = (data_bb['close'] < data_bb['LowerBand']) & (data_bb['RSI'] < 25)
    # Short trade: price moves above the upper Bollinger band and RSI goes over 75
    sell_condition = (data_bb['close'] > data_bb['UpperBand']) & (data_bb['RSI'] > 75)
    return buy_condition, sell_condition


if __name__ == "__main__":

    storage = Storage(logging=logging)

    if IS_DAY:
        config = config_day
    else:
        config = config_minute

    data = storage.load_data(config["data_file"], config["start_date"], config["stop_date"])

    if not IS_DAY:
        data_daily = aggregate_to_daily(data)
        storage.save_data(data_daily, "btcusd_1-day_data.csv")  # save the aggregated daily data
        df_to_plot = data_daily
    else:
        df_to_plot = data

    bb = BollingerBands(period=30, std_dev_multiplier=2.0)
    data_bb = bb.calculate(df_to_plot.copy())

    rsi_calculator = RSI(period=13)
    data_with_rsi = rsi_calculator.calculate(df_to_plot.copy(), price_column='close')

    # Combine BB and RSI data into a single DataFrame for signal generation
    # Ensure indices are aligned; they should be as both are from df_to_plot.copy()
    if 'RSI' in data_with_rsi.columns:
        data_bb['RSI'] = data_with_rsi['RSI']
    else:
        # If RSI wasn't calculated (e.g., not enough data), create a dummy column with NaNs
        # to prevent errors later, though signals won't be generated.
        data_bb['RSI'] = pd.Series(index=data_bb.index, dtype=float)
        logging.warning("RSI column not found or not calculated. Signals relying on RSI may not be generated.")

    strategy = 1
    if strategy == 1:
        buy_condition, sell_condition = strategy_1(data_bb, data_with_rsi)
    else:
        buy_condition, sell_condition = no_strategy(data_bb, data_with_rsi)

    buy_signals = data_bb[buy_condition]
    sell_signals = data_bb[sell_condition]

    # Plot the data with the seaborn library
    if df_to_plot is not None and not df_to_plot.empty:
        # Create a figure with two subplots, sharing the x-axis
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 8), sharex=True)

        # Plot 1: Close Price and Bollinger Bands
        sns.lineplot(x=data_bb.index, y='close', data=data_bb, label='Close Price', ax=ax1)
        sns.lineplot(x=data_bb.index, y='UpperBand', data=data_bb, label='Upper Band (BB)', ax=ax1)
        sns.lineplot(x=data_bb.index, y='LowerBand', data=data_bb, label='Lower Band (BB)', ax=ax1)
        # Plot Buy/Sell signals on Price chart
        if not buy_signals.empty:
            ax1.scatter(buy_signals.index, buy_signals['close'], color='green', marker='o', s=20, label='Buy Signal', zorder=5)
        if not sell_signals.empty:
            ax1.scatter(sell_signals.index, sell_signals['close'], color='red', marker='o', s=20, label='Sell Signal', zorder=5)
        ax1.set_title('Price and Bollinger Bands with Signals')
        ax1.set_ylabel('Price')
        ax1.legend()
        ax1.grid(True)

        # Plot 2: RSI
        if 'RSI' in data_bb.columns:  # Check data_bb now as it should contain RSI
            sns.lineplot(x=data_bb.index, y='RSI', data=data_bb, label='RSI (13)', ax=ax2, color='purple')
            ax2.axhline(75, color='red', linestyle='--', linewidth=0.8, label='Overbought (75)')
            ax2.axhline(25, color='green', linestyle='--', linewidth=0.8, label='Oversold (25)')
            # Plot Buy/Sell signals on RSI chart
            if not buy_signals.empty:
                ax2.scatter(buy_signals.index, buy_signals['RSI'], color='green', marker='o', s=20, label='Buy Signal (RSI)', zorder=5)
            if not sell_signals.empty:
                ax2.scatter(sell_signals.index, sell_signals['RSI'], color='red', marker='o', s=20, label='Sell Signal (RSI)', zorder=5)
            ax2.set_title('Relative Strength Index (RSI) with Signals')
            ax2.set_ylabel('RSI Value')
            ax2.set_ylim(0, 100)  # RSI is typically bounded between 0 and 100
            ax2.legend()
            ax2.grid(True)
        else:
            logging.info("RSI data not available for plotting.")

        plt.xlabel('Date')  # Common X-axis label
        fig.tight_layout()  # Adjust layout to prevent overlapping titles/labels
        plt.show()
    else:
        logging.info("No data to plot.")
trend_detector_macd.py
@@ -1,259 +0,0 @@
import pandas as pd
import numpy as np
import ta
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import logging
import mplfinance as mpf
from matplotlib.patches import Rectangle


class TrendDetectorMACD:
    def __init__(self, data, verbose=False):
        self.data = data
        self.verbose = verbose
        # Configure logging
        logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
                            format='%(asctime)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger('TrendDetector')

        # Convert data to pandas DataFrame if it's not already
        if not isinstance(self.data, pd.DataFrame):
            if isinstance(self.data, list):
                self.logger.info("Converting list to DataFrame")
                self.data = pd.DataFrame({'close': self.data})
            else:
                self.logger.error("Invalid data format provided")
                raise ValueError("Data must be a pandas DataFrame or a list")

        self.logger.info(f"Initialized TrendDetector with {len(self.data)} data points")

    def detect_trends_MACD_signal(self):
        self.logger.info("Starting trend detection")
        if len(self.data) < 3:
            self.logger.warning("Not enough data points for trend detection")
            return {"error": "Not enough data points for trend detection"}

        # Create a copy of the DataFrame to avoid modifying the original
        df = self.data.copy()
        self.logger.info("Created copy of input data")

        # If 'close' column doesn't exist, try to use a relevant column
        if 'close' not in df.columns and len(df.columns) > 0:
            self.logger.info(f"'close' column not found, using {df.columns[0]} instead")
            df['close'] = df[df.columns[0]]  # Use the first column as 'close'

        # Add trend indicators
        self.logger.info("Calculating MACD indicators")
        # Moving Average Convergence Divergence (MACD)
        df['macd'] = ta.trend.macd(df['close'])
        df['macd_signal'] = ta.trend.macd_signal(df['close'])
        df['macd_diff'] = ta.trend.macd_diff(df['close'])

        # Directional Movement Index (DMI)
        if all(col in df.columns for col in ['high', 'low', 'close']):
            self.logger.info("Calculating ADX indicators")
            df['adx'] = ta.trend.adx(df['high'], df['low'], df['close'])
            df['adx_pos'] = ta.trend.adx_pos(df['high'], df['low'], df['close'])
            df['adx_neg'] = ta.trend.adx_neg(df['high'], df['low'], df['close'])

        # Identify trend changes
        self.logger.info("Identifying trend changes")
        df['trend'] = np.where(df['macd'] > df['macd_signal'], 'up', 'down')
        df['trend_change'] = df['trend'] != df['trend'].shift(1)

        # Generate trend segments
        self.logger.info("Generating trend segments")
        trends = []
        trend_start = 0

        for i in range(1, len(df)):
            if df['trend_change'].iloc[i]:
                if i > trend_start:
                    trends.append({
                        "type": df['trend'].iloc[i-1],
                        "start_index": trend_start,
                        "end_index": i-1,
                        "start_value": df['close'].iloc[trend_start],
                        "end_value": df['close'].iloc[i-1]
                    })
                trend_start = i

        # Add the last trend
        if trend_start < len(df):
            trends.append({
                "type": df['trend'].iloc[-1],
                "start_index": trend_start,
                "end_index": len(df)-1,
                "start_value": df['close'].iloc[trend_start],
                "end_value": df['close'].iloc[-1]
            })

        self.logger.info(f"Detected {len(trends)} trend segments")
        return trends

    def get_strongest_trend(self):
        self.logger.info("Finding strongest trend")
        trends = self.detect_trends_MACD_signal()
        if isinstance(trends, dict) and "error" in trends:
            self.logger.warning(f"Error in trend detection: {trends['error']}")
            return trends

        if not trends:
            self.logger.info("No significant trends detected")
            return {"message": "No significant trends detected"}

        strongest = max(trends, key=lambda x: abs(x["end_value"] - x["start_value"]))
        self.logger.info(f"Strongest trend: {strongest['type']} from index {strongest['start_index']} to {strongest['end_index']}")
        return strongest

    def plot_trends(self, trends):
        """
        Plot price data with identified trends highlighted using candlestick charts.
        """
        self.logger.info("Plotting trends with candlesticks")
        if isinstance(trends, dict) and "error" in trends:
            self.logger.error(trends["error"])
            print(trends["error"])
            return

        if not trends:
            self.logger.warning("No significant trends detected for plotting")
            print("No significant trends detected")
            return

        # Create a figure with 2 subplots that share the x-axis
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8), gridspec_kw={'height_ratios': [2, 1]}, sharex=True)
        self.logger.info("Creating plot figure with shared x-axis")

        # Prepare data for candlestick chart
        df = self.data.copy()

        # Ensure required columns exist for candlestick
        required_cols = ['open', 'high', 'low', 'close']
        if not all(col in df.columns for col in required_cols):
            self.logger.warning("Missing required columns for candlestick. Defaulting to line chart.")
            if 'close' in df.columns:
                ax1.plot(df.index if 'datetime' not in df.columns else df['datetime'],
                         df['close'], color='black', alpha=0.7, linewidth=1, label='Price')
            else:
                ax1.plot(df.index if 'datetime' not in df.columns else df['datetime'],
                         df[df.columns[0]], color='black', alpha=0.7, linewidth=1, label='Price')
        else:
            # Get x values (dates if available, otherwise indices)
            if 'datetime' in df.columns:
                x_label = 'Date'
                # Format date axis
                ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
                ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
                fig.autofmt_xdate()
                self.logger.info("Using datetime for x-axis")

                # For candlestick, ensure datetime is the index
                if df.index.name != 'datetime':
                    df = df.set_index('datetime')
            else:
                x_label = 'Index'
                self.logger.info("Using index for x-axis")

            # Plot candlestick chart
            up_color = 'green'
            down_color = 'red'

            # Draw candlesticks manually
            width = 0.6
            for i in range(len(df)):
                # Get OHLC values for this candle
                open_val = df['open'].iloc[i]
                close_val = df['close'].iloc[i]
                high_val = df['high'].iloc[i]
                low_val = df['low'].iloc[i]
                idx = df.index[i]

                # Determine candle color
                color = up_color if close_val >= open_val else down_color

                # Plot candle body
                body_height = abs(close_val - open_val)
                bottom = min(open_val, close_val)
                rect = Rectangle((i - width/2, bottom), width, body_height, color=color, alpha=0.8)
                ax1.add_patch(rect)

                # Plot candle wicks
                ax1.plot([i, i], [low_val, high_val], color='black', linewidth=1)

            # Set appropriate x-axis limits
            ax1.set_xlim(-0.5, len(df) - 0.5)

        # Highlight each trend with a different color
        self.logger.info("Highlighting trends on plot")
        for trend in trends:
            start_idx = trend['start_index']
            end_idx = trend['end_index']
            trend_type = trend['type']

            # Get x-coordinates for trend plotting
            x_start = start_idx
            x_end = end_idx

            # Get y-coordinates for trend line
            if 'close' in df.columns:
                y_start = df['close'].iloc[start_idx]
                y_end = df['close'].iloc[end_idx]
            else:
                y_start = df[df.columns[0]].iloc[start_idx]
                y_end = df[df.columns[0]].iloc[end_idx]

            # Choose color based on trend type
            color = 'green' if trend_type == 'up' else 'red'

            # Plot trend line
            ax1.plot([x_start, x_end], [y_start, y_end], color=color, linewidth=2,
                     label=f"{trend_type.capitalize()} Trend" if f"{trend_type.capitalize()} Trend" not in ax1.get_legend_handles_labels()[1] else "")

            # Add markers at start and end points
            ax1.scatter(x_start, y_start, color=color, marker='o', s=50)
            ax1.scatter(x_end, y_end, color=color, marker='s', s=50)

        # Configure first subplot
        ax1.set_title('Price with Trends (Candlestick)', fontsize=16)
        ax1.set_ylabel('Price', fontsize=14)
        ax1.grid(alpha=0.3)
        ax1.legend()

        # Create MACD in second subplot
        self.logger.info("Creating MACD subplot")

        # Calculate MACD indicators if not already present
        if 'macd' not in df.columns:
            if 'close' not in df.columns and len(df.columns) > 0:
                df['close'] = df[df.columns[0]]

            df['macd'] = ta.trend.macd(df['close'])
            df['macd_signal'] = ta.trend.macd_signal(df['close'])
            df['macd_diff'] = ta.trend.macd_diff(df['close'])

        # Plot MACD components on second subplot
        x_indices = np.arange(len(df))
        ax2.plot(x_indices, df['macd'], label='MACD', color='blue')
        ax2.plot(x_indices, df['macd_signal'], label='Signal', color='orange')

        # Plot MACD histogram
        for i in range(len(df)):
            if df['macd_diff'].iloc[i] >= 0:
                ax2.bar(i, df['macd_diff'].iloc[i], color='green', alpha=0.5, width=0.8)
            else:
                ax2.bar(i, df['macd_diff'].iloc[i], color='red', alpha=0.5, width=0.8)

        ax2.set_title('MACD Indicator', fontsize=16)
        ax2.set_xlabel(x_label, fontsize=14)
        ax2.set_ylabel('MACD', fontsize=14)
        ax2.grid(alpha=0.3)
        ax2.legend()

        # Enable synchronized zooming
        plt.tight_layout()
        plt.subplots_adjust(hspace=0.1)
        plt.show()

        return plt
@@ -1,205 +0,0 @@
|
|||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
import logging
|
|
||||||
from scipy.signal import find_peaks
|
|
||||||
import matplotlib.dates as mdates
|
|
||||||
from scipy import stats
|
|
||||||
|
|
||||||
class TrendDetectorSimple:
|
|
||||||
def __init__(self, data, verbose=False):
|
|
||||||
"""
|
|
||||||
Initialize the TrendDetectorSimple class.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- data: pandas DataFrame containing price data
|
|
||||||
- verbose: boolean, whether to display detailed logging information
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.data = data
|
|
||||||
self.verbose = verbose
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
|
|
||||||
format='%(asctime)s - %(levelname)s - %(message)s')
|
|
||||||
self.logger = logging.getLogger('TrendDetectorSimple')
|
|
||||||
|
|
||||||
# Convert data to pandas DataFrame if it's not already
|
|
||||||
if not isinstance(self.data, pd.DataFrame):
|
|
||||||
if isinstance(self.data, list):
|
|
||||||
self.logger.info("Converting list to DataFrame")
|
|
||||||
self.data = pd.DataFrame({'close': self.data})
|
|
||||||
else:
|
|
||||||
self.logger.error("Invalid data format provided")
|
|
||||||
raise ValueError("Data must be a pandas DataFrame or a list")
|
|
||||||
|
|
||||||
self.logger.info(f"Initialized TrendDetectorSimple with {len(self.data)} data points")
|
|
||||||
|
|
||||||
def detect_trends(self):
|
|
||||||
"""
|
|
||||||
Detect trends by identifying local minima and maxima in the price data
|
|
||||||
using scipy.signal.find_peaks.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- prominence: float, required prominence of peaks (relative to the price range)
|
|
||||||
- width: int, required width of peaks in data points
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- DataFrame with columns for timestamps, prices, and trend indicators
|
|
||||||
"""
|
|
||||||
self.logger.info(f"Detecting trends")
|
|
||||||
|
|
||||||
df = self.data.copy()
|
|
||||||
close_prices = df['close'].values
|
|
||||||
|
|
||||||
max_peaks, _ = find_peaks(close_prices)
|
|
||||||
min_peaks, _ = find_peaks(-close_prices)
|
|
||||||
|
|
||||||
self.logger.info(f"Found {len(min_peaks)} local minima and {len(max_peaks)} local maxima")
|
|
||||||
|
|
||||||
df['is_min'] = False
|
|
||||||
df['is_max'] = False
|
|
||||||
|
|
||||||
for peak in max_peaks:
|
|
||||||
df.at[peak, 'is_max'] = True
|
|
||||||
for peak in min_peaks:
|
|
||||||
df.at[peak, 'is_min'] = True
|
|
||||||
|
|
||||||
result = df[['datetime', 'close', 'is_min', 'is_max']].copy()
|
|
||||||
|
|
||||||
# Perform linear regression on min_peaks and max_peaks
|
|
||||||
self.logger.info("Performing linear regression on min and max peaks")
|
|
||||||
min_prices = df['close'].iloc[min_peaks].values
|
|
||||||
        max_prices = df['close'].iloc[max_peaks].values

        # Linear regression over the min peaks (support line)
        min_slope, min_intercept, min_r_value, _, _ = stats.linregress(min_peaks, min_prices)

        # Linear regression over the max peaks (resistance line)
        max_slope, max_intercept, max_r_value, _, _ = stats.linregress(max_peaks, max_prices)

        # Calculate Simple Moving Averages (SMA) for 7 and 15 periods
        self.logger.info("Calculating SMA-7 and SMA-15")

        # Calculate SMA values and drop the leading NaN values
        sma_7 = df['close'].rolling(window=7).mean().dropna().values
        sma_15 = df['close'].rolling(window=15).mean().dropna().values

        # Collect regression and SMA results
        analysis_results = {}
        analysis_results['linear_regression'] = {
            'min': {
                'slope': min_slope,
                'intercept': min_intercept,
                'r_squared': min_r_value ** 2
            },
            'max': {
                'slope': max_slope,
                'intercept': max_intercept,
                'r_squared': max_r_value ** 2
            }
        }
        analysis_results['sma'] = {
            '7': sma_7,
            '15': sma_15
        }

        self.logger.info(f"Min peaks regression: slope={min_slope:.4f}, intercept={min_intercept:.4f}, r²={min_r_value**2:.4f}")
        self.logger.info(f"Max peaks regression: slope={max_slope:.4f}, intercept={max_intercept:.4f}, r²={max_r_value**2:.4f}")

        return result, analysis_results

    def plot_trends(self, trend_data, analysis_results):
        """
        Plot the price data with detected trends using a candlestick chart.

        Parameters:
        - trend_data: DataFrame, the output from detect_trends().
        - analysis_results: dict, the regression/SMA results from detect_trends().

        Returns:
        - None (displays the plot)
        """
        import matplotlib.pyplot as plt
        from matplotlib.patches import Rectangle

        # Create the figure and axis
        fig, ax = plt.subplots(figsize=(12, 8))

        # Create a copy of the data
        df = self.data.copy()

        # Candlestick colors
        up_color = 'green'
        down_color = 'red'

        # Draw candlesticks manually
        width = 0.6

        for i in range(len(df)):
            # Get OHLC values for this candle
            open_val = df['open'].iloc[i]
            close_val = df['close'].iloc[i]
            high_val = df['high'].iloc[i]
            low_val = df['low'].iloc[i]

            # Determine candle color
            color = up_color if close_val >= open_val else down_color

            # Plot candle body
            body_height = abs(close_val - open_val)
            bottom = min(open_val, close_val)
            rect = Rectangle((i - width/2, bottom), width, body_height, color=color, alpha=0.8)
            ax.add_patch(rect)

            # Plot candle wick
            ax.plot([i, i], [low_val, high_val], color='black', linewidth=1)

        min_indices = trend_data.index[trend_data['is_min'] == True].tolist()
        if min_indices:
            min_y = [df['close'].iloc[i] for i in min_indices]
            ax.scatter(min_indices, min_y, color='darkred', s=200, marker='^', label='Local Minima', zorder=100)

        max_indices = trend_data.index[trend_data['is_max'] == True].tolist()
        if max_indices:
            max_y = [df['close'].iloc[i] for i in max_indices]
            ax.scatter(max_indices, max_y, color='darkgreen', s=200, marker='v', label='Local Maxima', zorder=100)

        if analysis_results:
            x_vals = np.arange(len(df))

            # Minima regression line (support)
            min_slope = analysis_results['linear_regression']['min']['slope']
            min_intercept = analysis_results['linear_regression']['min']['intercept']
            min_line = min_slope * x_vals + min_intercept
            ax.plot(x_vals, min_line, 'g--', linewidth=2, label='Minima Regression')

            # Maxima regression line (resistance)
            max_slope = analysis_results['linear_regression']['max']['slope']
            max_intercept = analysis_results['linear_regression']['max']['intercept']
            max_line = max_slope * x_vals + max_intercept
            ax.plot(x_vals, max_line, 'r--', linewidth=2, label='Maxima Regression')

            # SMA-7 line (dropna() shortened the series, so align it to the
            # last len(sma_7) bars to keep x and y the same length)
            sma_7 = analysis_results['sma']['7']
            ax.plot(x_vals[len(x_vals) - len(sma_7):], sma_7, 'y-', linewidth=2, label='SMA-7')

            # SMA-15 line (currently disabled)
            # sma_15 = analysis_results['sma']['15']
            # ax.plot(x_vals[len(x_vals) - len(sma_15):], sma_15, 'm-', linewidth=2, label='SMA-15')

        # Set title and labels
        ax.set_title('Price Candlestick Chart with Local Minima and Maxima', fontsize=14)
        ax.set_xlabel('Date', fontsize=12)
        ax.set_ylabel('Price', fontsize=12)

        # Set appropriate x-axis limits
        ax.set_xlim(-0.5, len(df) - 0.5)

        # Add a legend
        ax.legend(loc='best')

        # Adjust layout
        plt.tight_layout()

        # Show the plot
        plt.show()
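# Hypothetical usage sketch (the `analyzer` instance and detect_trends() call
# are assumptions, not shown in this diff): the fitted support/resistance
# lines can be evaluated at any bar index i via price = slope * i + intercept.
#
#   trend_data, analysis = analyzer.detect_trends()
#   support = analysis['linear_regression']['min']
#   last_bar = len(analyzer.data) - 1
#   support_at_last_bar = support['slope'] * last_bar + support['intercept']
#   analyzer.plot_trends(trend_data, analysis)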
39 xgboost/custom_xgboost.py Normal file
@@ -0,0 +1,39 @@
import xgboost as xgb
import numpy as np

class CustomXGBoostGPU:
    """Thin wrapper around xgboost.train for GPU-backed regression on a fixed train/test split."""

    def __init__(self, X_train, X_test, y_train, y_test):
        self.X_train = X_train.astype(np.float32)
        self.X_test = X_test.astype(np.float32)
        self.y_train = y_train.astype(np.float32)
        self.y_test = y_test.astype(np.float32)
        self.model = None
        self.params = None  # Will be set during training

    def train(self, **xgb_params):
        params = {
            'tree_method': 'hist',
            'device': 'cuda',
            'objective': 'reg:squarederror',
            'eval_metric': 'rmse',
            'verbosity': 1,
        }
        params.update(xgb_params)
        self.params = params  # Store params for later access
        dtrain = xgb.DMatrix(self.X_train, label=self.y_train)
        dtest = xgb.DMatrix(self.X_test, label=self.y_test)
        evals = [(dtrain, 'train'), (dtest, 'eval')]
        self.model = xgb.train(params, dtrain, num_boost_round=100, evals=evals, early_stopping_rounds=10)
        return self.model

    def predict(self, X):
        if self.model is None:
            raise ValueError('Model not trained yet.')
        dmatrix = xgb.DMatrix(X.astype(np.float32))
        return self.model.predict(dmatrix)

    def save_model(self, file_path):
        """Save the trained XGBoost model to the specified file path."""
        if self.model is None:
            raise ValueError('Model not trained yet.')
        self.model.save_model(file_path)
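# A minimal usage sketch for CustomXGBoostGPU (the synthetic data below is an
# assumption for illustration; training defaults to a CUDA device, but any
# default can be overridden via keyword arguments, e.g. train(device='cpu')):
import numpy as np
from custom_xgboost import CustomXGBoostGPU

X = np.random.rand(1000, 8)
y = np.random.rand(1000)
model = CustomXGBoostGPU(X[:800], X[800:], y[:800], y[800:])
model.train(max_depth=6, eta=0.1)          # extra params are merged over the defaults
preds = model.predict(X[800:])             # predictions for the 200 held-out rows
model.save_model('./data/xgb_demo.json')   # persist the trained booster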
806 xgboost/main.py Normal file
@@ -0,0 +1,806 @@
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import pandas as pd
import numpy as np
from custom_xgboost import CustomXGBoostGPU
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from plot_results import plot_prediction_error_distribution, plot_direction_transition_heatmap
from cycles.supertrend import Supertrends
import time
from numba import njit
import itertools
import csv
import pandas_ta as ta

def run_indicator(func, *args):
    return func(*args)

def run_indicator_job(job):
    import time
    func, *args = job
    indicator_name = func.__name__
    start = time.time()
    result = func(*args)
    elapsed = time.time() - start
    print(f'Indicator {indicator_name} computed in {elapsed:.4f} seconds')
    return result

def calc_rsi(close):
    from ta.momentum import RSIIndicator
    return ('rsi', RSIIndicator(close, window=14).rsi())

def calc_macd(close):
    from ta.trend import MACD
    return ('macd', MACD(close).macd())

def calc_bollinger(close):
    from ta.volatility import BollingerBands
    bb = BollingerBands(close=close, window=20, window_dev=2)
    return [
        ('bb_bbm', bb.bollinger_mavg()),
        ('bb_bbh', bb.bollinger_hband()),
        ('bb_bbl', bb.bollinger_lband()),
        ('bb_bb_width', bb.bollinger_hband() - bb.bollinger_lband())
    ]

def calc_stochastic(high, low, close):
    from ta.momentum import StochasticOscillator
    stoch = StochasticOscillator(high=high, low=low, close=close, window=14, smooth_window=3)
    return [
        ('stoch_k', stoch.stoch()),
        ('stoch_d', stoch.stoch_signal())
    ]

def calc_atr(high, low, close):
    from ta.volatility import AverageTrueRange
    atr = AverageTrueRange(high=high, low=low, close=close, window=14)
    return ('atr', atr.average_true_range())

def calc_cci(high, low, close):
    from ta.trend import CCIIndicator
    cci = CCIIndicator(high=high, low=low, close=close, window=20)
    return ('cci', cci.cci())

def calc_williamsr(high, low, close):
    from ta.momentum import WilliamsRIndicator
    willr = WilliamsRIndicator(high=high, low=low, close=close, lbp=14)
    return ('williams_r', willr.williams_r())

def calc_ema(close):
    from ta.trend import EMAIndicator
    ema = EMAIndicator(close=close, window=14)
    return ('ema_14', ema.ema_indicator())

def calc_obv(close, volume):
    from ta.volume import OnBalanceVolumeIndicator
    obv = OnBalanceVolumeIndicator(close=close, volume=volume)
    return ('obv', obv.on_balance_volume())

def calc_cmf(high, low, close, volume):
    from ta.volume import ChaikinMoneyFlowIndicator
    cmf = ChaikinMoneyFlowIndicator(high=high, low=low, close=close, volume=volume, window=20)
    return ('cmf', cmf.chaikin_money_flow())

def calc_sma(close):
    from ta.trend import SMAIndicator
    return [
        ('sma_50', SMAIndicator(close, window=50).sma_indicator()),
        ('sma_200', SMAIndicator(close, window=200).sma_indicator())
    ]

def calc_roc(close):
    from ta.momentum import ROCIndicator
    return ('roc_10', ROCIndicator(close, window=10).roc())

def calc_momentum(close):
    return ('momentum_10', close - close.shift(10))

def calc_psar(high, low, close):
    # Use the Numba-accelerated fast_psar function for speed
    psar_values = fast_psar(np.array(high), np.array(low), np.array(close))
    return [('psar', pd.Series(psar_values, index=close.index))]

def calc_donchian(high, low, close):
    from ta.volatility import DonchianChannel
    donchian = DonchianChannel(high, low, close, window=20)
    return [
        ('donchian_hband', donchian.donchian_channel_hband()),
        ('donchian_lband', donchian.donchian_channel_lband()),
        ('donchian_mband', donchian.donchian_channel_mband())
    ]

def calc_keltner(high, low, close):
    from ta.volatility import KeltnerChannel
    keltner = KeltnerChannel(high, low, close, window=20)
    return [
        ('keltner_hband', keltner.keltner_channel_hband()),
        ('keltner_lband', keltner.keltner_channel_lband()),
        ('keltner_mband', keltner.keltner_channel_mband())
    ]

def calc_dpo(close):
    from ta.trend import DPOIndicator
    return ('dpo_20', DPOIndicator(close, window=20).dpo())

def calc_ultimate(high, low, close):
    from ta.momentum import UltimateOscillator
    return ('ultimate_osc', UltimateOscillator(high, low, close).ultimate_oscillator())

def calc_ichimoku(high, low):
    from ta.trend import IchimokuIndicator
    ichimoku = IchimokuIndicator(high, low, window1=9, window2=26, window3=52)
    return [
        ('ichimoku_a', ichimoku.ichimoku_a()),
        ('ichimoku_b', ichimoku.ichimoku_b()),
        ('ichimoku_base_line', ichimoku.ichimoku_base_line()),
        ('ichimoku_conversion_line', ichimoku.ichimoku_conversion_line())
    ]

def calc_elder_ray(close, low, high):
    from ta.trend import EMAIndicator
    ema = EMAIndicator(close, window=13).ema_indicator()
    # Note: these are EMA-relative values with the opposite sign of the
    # textbook Elder Ray definition (bull power = high - EMA, bear = low - EMA)
    return [
        ('elder_ray_bull', ema - low),
        ('elder_ray_bear', ema - high)
    ]

def calc_daily_return(close):
    from ta.others import DailyReturnIndicator
    return ('daily_return', DailyReturnIndicator(close).daily_return())

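# fast_psar (below) is a simplified Parabolic SAR recursion:
#   SAR[t] = SAR[t-1] + AF * (EP - SAR[t-1])
# where EP is the extreme point of the current trend (highest high in an
# uptrend, lowest low in a downtrend) and AF grows by `af` up to `max_af`
# each time a new extreme is set; a bar whose low (uptrend) or high
# (downtrend) crosses the SAR flips the trend and resets AF.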
@njit
def fast_psar(high, low, close, af=0.02, max_af=0.2):
    length = len(close)
    psar = np.zeros(length)
    bull = True
    af_step = af
    ep = low[0]
    psar[0] = low[0]
    for i in range(1, length):
        prev_psar = psar[i-1]
        if bull:
            psar[i] = prev_psar + af_step * (ep - prev_psar)
            if low[i] < psar[i]:
                # Price pierced the SAR from above: flip to a downtrend
                bull = False
                psar[i] = ep
                af_step = af
                ep = low[i]
            else:
                if high[i] > ep:
                    ep = high[i]
                    af_step = min(af_step + af, max_af)
        else:
            psar[i] = prev_psar + af_step * (ep - prev_psar)
            if high[i] > psar[i]:
                # Price pierced the SAR from below: flip to an uptrend
                bull = True
                psar[i] = ep
                af_step = af
                ep = high[i]
            else:
                if low[i] < ep:
                    ep = low[i]
                    af_step = min(af_step + af, max_af)
    return psar

def compute_lag(df, col, lag):
    return df[col].shift(lag)

def compute_rolling(df, col, stat, window):
    if stat == 'mean':
        return df[col].rolling(window).mean()
    elif stat == 'std':
        return df[col].rolling(window).std()
    elif stat == 'min':
        return df[col].rolling(window).min()
    elif stat == 'max':
        return df[col].rolling(window).max()

def compute_log_return(df, horizon):
    return np.log(df['Close'] / df['Close'].shift(horizon))

def compute_volatility(df, window):
    return df['log_return'].rolling(window).std()
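# For reference, the feature definitions implemented above:
#   compute_log_return:  log_return_h[t] = ln(Close[t] / Close[t-h])
#   compute_volatility:  volatility_w[t] = rolling std of the 1-bar log return over w bars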

def run_feature_job(job, df):
    feature_name, func, *args = job
    print(f'Computing feature: {feature_name}')
    result = func(df, *args)
    return feature_name, result

def calc_adx(high, low, close):
    from ta.trend import ADXIndicator
    adx = ADXIndicator(high=high, low=low, close=close, window=14)
    return [
        ('adx', adx.adx()),
        ('adx_pos', adx.adx_pos()),
        ('adx_neg', adx.adx_neg())
    ]

def calc_trix(close):
    from ta.trend import TRIXIndicator
    trix = TRIXIndicator(close=close, window=15)
    return ('trix', trix.trix())

def calc_vortex(high, low, close):
    from ta.trend import VortexIndicator
    vortex = VortexIndicator(high=high, low=low, close=close, window=14)
    return [
        ('vortex_pos', vortex.vortex_indicator_pos()),
        ('vortex_neg', vortex.vortex_indicator_neg())
    ]

def calc_kama(close):
    import pandas_ta as ta
    kama = ta.kama(close, length=10)
    return ('kama', kama)

def calc_force_index(close, volume):
    from ta.volume import ForceIndexIndicator
    fi = ForceIndexIndicator(close=close, volume=volume, window=13)
    return ('force_index', fi.force_index())

def calc_eom(high, low, volume):
    from ta.volume import EaseOfMovementIndicator
    eom = EaseOfMovementIndicator(high=high, low=low, volume=volume, window=14)
    return ('eom', eom.ease_of_movement())

def calc_mfi(high, low, close, volume):
    from ta.volume import MFIIndicator
    mfi = MFIIndicator(high=high, low=low, close=close, volume=volume, window=14)
    return ('mfi', mfi.money_flow_index())

def calc_adi(high, low, close, volume):
    from ta.volume import AccDistIndexIndicator
    adi = AccDistIndexIndicator(high=high, low=low, close=close, volume=volume)
    return ('adi', adi.acc_dist_index())

def calc_tema(close):
    import pandas_ta as ta
    tema = ta.tema(close, length=10)
    return ('tema', tema)

def calc_stochrsi(close):
    from ta.momentum import StochRSIIndicator
    stochrsi = StochRSIIndicator(close=close, window=14, smooth1=3, smooth2=3)
    return [
        ('stochrsi', stochrsi.stochrsi()),
        ('stochrsi_k', stochrsi.stochrsi_k()),
        ('stochrsi_d', stochrsi.stochrsi_d())
    ]

def calc_awesome_oscillator(high, low):
    from ta.momentum import AwesomeOscillatorIndicator
    ao = AwesomeOscillatorIndicator(high=high, low=low, window1=5, window2=34)
    return ('awesome_osc', ao.awesome_oscillator())

if __name__ == '__main__':
    IMPUTE_NANS = True  # Set to True to impute NaNs, False to drop rows with NaNs
    csv_path = './data/btcusd_1-min_data.csv'
    csv_prefix = os.path.splitext(os.path.basename(csv_path))[0]

    print('Reading CSV and filtering data...')
    df = pd.read_csv(csv_path)
    df = df[df['Volume'] != 0]

    min_date = '2017-06-01'
    print('Converting Timestamp and filtering by date...')
    df['Timestamp'] = pd.to_datetime(df['Timestamp'], unit='s')
    df = df[df['Timestamp'] >= min_date]

    lags = 3

    print('Calculating log returns as the new target...')
    df['log_return'] = np.log(df['Close'] / df['Close'].shift(1))

    ohlcv_cols = ['Open', 'High', 'Low', 'Close', 'Volume']
    window_sizes = [5, 15, 30]  # in minutes, adjust as needed

    features_dict = {}

    print('Starting feature computation...')
    feature_start_time = time.time()

    # --- Technical Indicator Features: Calculate or Load from Cache ---
    print('Calculating or loading technical indicator features...')
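    # Caching convention used below: each engineered feature is persisted once
    # as ./data/<csv_prefix>_<feature>.npy and reloaded on later runs. The
    # cache is keyed only by file name, so the .npy files must be deleted to
    # force recomputation if the input CSV or the date filter changes.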
    # RSI
    feature_file = f'./data/{csv_prefix}_rsi.npy'
    if os.path.exists(feature_file):
        print(f'A Loading cached feature: {feature_file}')
        arr = np.load(feature_file)
        features_dict['rsi'] = pd.Series(arr, index=df.index)
    else:
        print('Calculating feature: rsi')
        _, values = calc_rsi(df['Close'])
        features_dict['rsi'] = values
        np.save(feature_file, values.values)
        print(f'Saved feature: {feature_file}')

    # MACD
    feature_file = f'./data/{csv_prefix}_macd.npy'
    if os.path.exists(feature_file):
        print(f'A Loading cached feature: {feature_file}')
        arr = np.load(feature_file)
        features_dict['macd'] = pd.Series(arr, index=df.index)
    else:
        print('Calculating feature: macd')
        _, values = calc_macd(df['Close'])
        features_dict['macd'] = values
        np.save(feature_file, values.values)
        print(f'Saved feature: {feature_file}')

    # ATR
    feature_file = f'./data/{csv_prefix}_atr.npy'
    if os.path.exists(feature_file):
        print(f'A Loading cached feature: {feature_file}')
        arr = np.load(feature_file)
        features_dict['atr'] = pd.Series(arr, index=df.index)
    else:
        print('Calculating feature: atr')
        _, values = calc_atr(df['High'], df['Low'], df['Close'])
        features_dict['atr'] = values
        np.save(feature_file, values.values)
        print(f'Saved feature: {feature_file}')

    # CCI
    feature_file = f'./data/{csv_prefix}_cci.npy'
    if os.path.exists(feature_file):
        print(f'A Loading cached feature: {feature_file}')
        arr = np.load(feature_file)
        features_dict['cci'] = pd.Series(arr, index=df.index)
    else:
        print('Calculating feature: cci')
        _, values = calc_cci(df['High'], df['Low'], df['Close'])
        features_dict['cci'] = values
        np.save(feature_file, values.values)
        print(f'Saved feature: {feature_file}')

    # Williams %R
    feature_file = f'./data/{csv_prefix}_williams_r.npy'
    if os.path.exists(feature_file):
        print(f'A Loading cached feature: {feature_file}')
        arr = np.load(feature_file)
        features_dict['williams_r'] = pd.Series(arr, index=df.index)
    else:
        print('Calculating feature: williams_r')
        _, values = calc_williamsr(df['High'], df['Low'], df['Close'])
        features_dict['williams_r'] = values
        np.save(feature_file, values.values)
        print(f'Saved feature: {feature_file}')

    # EMA 14
    feature_file = f'./data/{csv_prefix}_ema_14.npy'
    if os.path.exists(feature_file):
        print(f'A Loading cached feature: {feature_file}')
        arr = np.load(feature_file)
        features_dict['ema_14'] = pd.Series(arr, index=df.index)
    else:
        print('Calculating feature: ema_14')
        _, values = calc_ema(df['Close'])
        features_dict['ema_14'] = values
        np.save(feature_file, values.values)
        print(f'Saved feature: {feature_file}')

    # OBV
    feature_file = f'./data/{csv_prefix}_obv.npy'
    if os.path.exists(feature_file):
        print(f'A Loading cached feature: {feature_file}')
        arr = np.load(feature_file)
        features_dict['obv'] = pd.Series(arr, index=df.index)
    else:
        print('Calculating feature: obv')
        _, values = calc_obv(df['Close'], df['Volume'])
        features_dict['obv'] = values
        np.save(feature_file, values.values)
        print(f'Saved feature: {feature_file}')

    # CMF
    feature_file = f'./data/{csv_prefix}_cmf.npy'
    if os.path.exists(feature_file):
        print(f'A Loading cached feature: {feature_file}')
        arr = np.load(feature_file)
        features_dict['cmf'] = pd.Series(arr, index=df.index)
    else:
        print('Calculating feature: cmf')
        _, values = calc_cmf(df['High'], df['Low'], df['Close'], df['Volume'])
        features_dict['cmf'] = values
        np.save(feature_file, values.values)
        print(f'Saved feature: {feature_file}')

    # ROC 10
    feature_file = f'./data/{csv_prefix}_roc_10.npy'
    if os.path.exists(feature_file):
        print(f'A Loading cached feature: {feature_file}')
        arr = np.load(feature_file)
        features_dict['roc_10'] = pd.Series(arr, index=df.index)
    else:
        print('Calculating feature: roc_10')
        _, values = calc_roc(df['Close'])
        features_dict['roc_10'] = values
        np.save(feature_file, values.values)
        print(f'Saved feature: {feature_file}')

    # DPO 20
    feature_file = f'./data/{csv_prefix}_dpo_20.npy'
    if os.path.exists(feature_file):
        print(f'A Loading cached feature: {feature_file}')
        arr = np.load(feature_file)
        features_dict['dpo_20'] = pd.Series(arr, index=df.index)
    else:
        print('Calculating feature: dpo_20')
        _, values = calc_dpo(df['Close'])
        features_dict['dpo_20'] = values
        np.save(feature_file, values.values)
        print(f'Saved feature: {feature_file}')

    # Ultimate Oscillator
    feature_file = f'./data/{csv_prefix}_ultimate_osc.npy'
    if os.path.exists(feature_file):
        print(f'A Loading cached feature: {feature_file}')
        arr = np.load(feature_file)
        features_dict['ultimate_osc'] = pd.Series(arr, index=df.index)
    else:
        print('Calculating feature: ultimate_osc')
        _, values = calc_ultimate(df['High'], df['Low'], df['Close'])
        features_dict['ultimate_osc'] = values
        np.save(feature_file, values.values)
        print(f'Saved feature: {feature_file}')

    # Daily Return
    feature_file = f'./data/{csv_prefix}_daily_return.npy'
    if os.path.exists(feature_file):
        print(f'A Loading cached feature: {feature_file}')
        arr = np.load(feature_file)
        features_dict['daily_return'] = pd.Series(arr, index=df.index)
    else:
        print('Calculating feature: daily_return')
        _, values = calc_daily_return(df['Close'])
        features_dict['daily_return'] = values
        np.save(feature_file, values.values)
        print(f'Saved feature: {feature_file}')

    # Multi-column indicators
    # Bollinger Bands
    print('Calculating multi-column indicator: bollinger')
    result = calc_bollinger(df['Close'])
    for subname, values in result:
        print(f"Adding subfeature: {subname}")
        sub_feature_file = f'./data/{csv_prefix}_{subname}.npy'
        if os.path.exists(sub_feature_file):
            print(f'B Loading cached feature: {sub_feature_file}')
            arr = np.load(sub_feature_file)
            features_dict[subname] = pd.Series(arr, index=df.index)
        else:
            features_dict[subname] = values
            np.save(sub_feature_file, values.values)
            print(f'Saved feature: {sub_feature_file}')

    # Stochastic Oscillator
    print('Calculating multi-column indicator: stochastic')
    result = calc_stochastic(df['High'], df['Low'], df['Close'])
    for subname, values in result:
        print(f"Adding subfeature: {subname}")
        sub_feature_file = f'./data/{csv_prefix}_{subname}.npy'
        if os.path.exists(sub_feature_file):
            print(f'B Loading cached feature: {sub_feature_file}')
            arr = np.load(sub_feature_file)
            features_dict[subname] = pd.Series(arr, index=df.index)
        else:
            features_dict[subname] = values
            np.save(sub_feature_file, values.values)
            print(f'Saved feature: {sub_feature_file}')

    # SMA
    print('Calculating multi-column indicator: sma')
    result = calc_sma(df['Close'])
    for subname, values in result:
        print(f"Adding subfeature: {subname}")
        sub_feature_file = f'./data/{csv_prefix}_{subname}.npy'
        if os.path.exists(sub_feature_file):
            print(f'B Loading cached feature: {sub_feature_file}')
            arr = np.load(sub_feature_file)
            features_dict[subname] = pd.Series(arr, index=df.index)
        else:
            features_dict[subname] = values
            np.save(sub_feature_file, values.values)
            print(f'Saved feature: {sub_feature_file}')

    # PSAR
    print('Calculating multi-column indicator: psar')
    result = calc_psar(df['High'], df['Low'], df['Close'])
    for subname, values in result:
        print(f"Adding subfeature: {subname}")
        sub_feature_file = f'./data/{csv_prefix}_{subname}.npy'
        if os.path.exists(sub_feature_file):
            print(f'B Loading cached feature: {sub_feature_file}')
            arr = np.load(sub_feature_file)
            features_dict[subname] = pd.Series(arr, index=df.index)
        else:
            features_dict[subname] = values
            np.save(sub_feature_file, values.values)
            print(f'Saved feature: {sub_feature_file}')

    # Donchian Channel
    print('Calculating multi-column indicator: donchian')
    result = calc_donchian(df['High'], df['Low'], df['Close'])
    for subname, values in result:
        print(f"Adding subfeature: {subname}")
        sub_feature_file = f'./data/{csv_prefix}_{subname}.npy'
        if os.path.exists(sub_feature_file):
            print(f'B Loading cached feature: {sub_feature_file}')
            arr = np.load(sub_feature_file)
            features_dict[subname] = pd.Series(arr, index=df.index)
        else:
            features_dict[subname] = values
            np.save(sub_feature_file, values.values)
            print(f'Saved feature: {sub_feature_file}')

    # Keltner Channel
    print('Calculating multi-column indicator: keltner')
    result = calc_keltner(df['High'], df['Low'], df['Close'])
    for subname, values in result:
        print(f"Adding subfeature: {subname}")
        sub_feature_file = f'./data/{csv_prefix}_{subname}.npy'
        if os.path.exists(sub_feature_file):
            print(f'B Loading cached feature: {sub_feature_file}')
            arr = np.load(sub_feature_file)
            features_dict[subname] = pd.Series(arr, index=df.index)
        else:
            features_dict[subname] = values
            np.save(sub_feature_file, values.values)
            print(f'Saved feature: {sub_feature_file}')

    # Ichimoku
    print('Calculating multi-column indicator: ichimoku')
    result = calc_ichimoku(df['High'], df['Low'])
    for subname, values in result:
        print(f"Adding subfeature: {subname}")
        sub_feature_file = f'./data/{csv_prefix}_{subname}.npy'
        if os.path.exists(sub_feature_file):
            print(f'B Loading cached feature: {sub_feature_file}')
            arr = np.load(sub_feature_file)
            features_dict[subname] = pd.Series(arr, index=df.index)
        else:
            features_dict[subname] = values
            np.save(sub_feature_file, values.values)
            print(f'Saved feature: {sub_feature_file}')

    # Elder Ray
    print('Calculating multi-column indicator: elder_ray')
    result = calc_elder_ray(df['Close'], df['Low'], df['High'])
    for subname, values in result:
        print(f"Adding subfeature: {subname}")
        sub_feature_file = f'./data/{csv_prefix}_{subname}.npy'
        if os.path.exists(sub_feature_file):
            print(f'B Loading cached feature: {sub_feature_file}')
            arr = np.load(sub_feature_file)
            features_dict[subname] = pd.Series(arr, index=df.index)
        else:
            features_dict[subname] = values
            np.save(sub_feature_file, values.values)
            print(f'Saved feature: {sub_feature_file}')

    # Prepare lags, rolling stats, log returns, and volatility features sequentially
    # Lags
    for col in ohlcv_cols:
        for lag in range(1, lags + 1):
            feature_name = f'{col}_lag{lag}'
            feature_file = f'./data/{csv_prefix}_{feature_name}.npy'
            if os.path.exists(feature_file):
                print(f'C Loading cached feature: {feature_file}')
                features_dict[feature_name] = np.load(feature_file)
            else:
                print(f'Computing lag feature: {feature_name}')
                result = compute_lag(df, col, lag)
                features_dict[feature_name] = result
                np.save(feature_file, result.values)
                print(f'Saved feature: {feature_file}')
    # Rolling statistics
    for col in ohlcv_cols:
        for window in window_sizes:
            if (col == 'Open' and window == 5):
                continue
            if (col == 'High' and window == 5):
                continue
            if (col == 'High' and window == 30):
                continue
            if (col == 'Low' and window == 15):
                continue
            for stat in ['mean', 'std', 'min', 'max']:
                feature_name = f'{col}_roll_{stat}_{window}'
                feature_file = f'./data/{csv_prefix}_{feature_name}.npy'
                if os.path.exists(feature_file):
                    print(f'D Loading cached feature: {feature_file}')
                    features_dict[feature_name] = np.load(feature_file)
                else:
                    print(f'Computing rolling stat feature: {feature_name}')
                    result = compute_rolling(df, col, stat, window)
                    features_dict[feature_name] = result
                    np.save(feature_file, result.values)
                    print(f'Saved feature: {feature_file}')
    # Log returns for different horizons
    for horizon in [5, 15, 30]:
        feature_name = f'log_return_{horizon}'
        feature_file = f'./data/{csv_prefix}_{feature_name}.npy'
        if os.path.exists(feature_file):
            print(f'E Loading cached feature: {feature_file}')
            features_dict[feature_name] = np.load(feature_file)
        else:
            print(f'Computing log return feature: {feature_name}')
            result = compute_log_return(df, horizon)
            features_dict[feature_name] = result
            np.save(feature_file, result.values)
            print(f'Saved feature: {feature_file}')
    # Volatility
    for window in window_sizes:
        feature_name = f'volatility_{window}'
        feature_file = f'./data/{csv_prefix}_{feature_name}.npy'
        if os.path.exists(feature_file):
            print(f'F Loading cached feature: {feature_file}')
            features_dict[feature_name] = np.load(feature_file)
        else:
            print(f'Computing volatility feature: {feature_name}')
            result = compute_volatility(df, window)
            features_dict[feature_name] = result
            np.save(feature_file, result.values)
            print(f'Saved feature: {feature_file}')

    # --- Additional Technical Indicator Features ---
    # ADX
    adx_names = ['adx', 'adx_pos', 'adx_neg']
    adx_files = [f'./data/{csv_prefix}_{name}.npy' for name in adx_names]
    if all(os.path.exists(f) for f in adx_files):
        print('G Loading cached features: ADX')
        for name, f in zip(adx_names, adx_files):
            arr = np.load(f)
            features_dict[name] = pd.Series(arr, index=df.index)
    else:
        print('Calculating multi-column indicator: adx')
        result = calc_adx(df['High'], df['Low'], df['Close'])
        for subname, values in result:
            sub_feature_file = f'./data/{csv_prefix}_{subname}.npy'
            features_dict[subname] = values
            np.save(sub_feature_file, values.values)
            print(f'Saved feature: {sub_feature_file}')

    # Force Index
    feature_file = f'./data/{csv_prefix}_force_index.npy'
    if os.path.exists(feature_file):
        print(f'K Loading cached feature: {feature_file}')
        arr = np.load(feature_file)
        features_dict['force_index'] = pd.Series(arr, index=df.index)
    else:
        print('Calculating feature: force_index')
        _, values = calc_force_index(df['Close'], df['Volume'])
        features_dict['force_index'] = values
        np.save(feature_file, values.values)
        print(f'Saved feature: {feature_file}')

    # Supertrend indicators
    for period, multiplier in [(12, 3.0), (10, 1.0), (11, 2.0)]:
        st_name = f'supertrend_{period}_{multiplier}'
        st_trend_name = f'supertrend_trend_{period}_{multiplier}'
        st_file = f'./data/{csv_prefix}_{st_name}.npy'
        st_trend_file = f'./data/{csv_prefix}_{st_trend_name}.npy'
        if os.path.exists(st_file) and os.path.exists(st_trend_file):
            print(f'L Loading cached features: {st_file}, {st_trend_file}')
            features_dict[st_name] = pd.Series(np.load(st_file), index=df.index)
            features_dict[st_trend_name] = pd.Series(np.load(st_trend_file), index=df.index)
        else:
            print(f'Calculating Supertrend indicator: {st_name}')
            st = ta.supertrend(df['High'], df['Low'], df['Close'], length=period, multiplier=multiplier)
            features_dict[st_name] = st[f'SUPERT_{period}_{multiplier}']
            features_dict[st_trend_name] = st[f'SUPERTd_{period}_{multiplier}']
            np.save(st_file, features_dict[st_name].values)
            np.save(st_trend_file, features_dict[st_trend_name].values)
            print(f'Saved features: {st_file}, {st_trend_file}')

    # Concatenate all new features at once
    print('Concatenating all new features to DataFrame...')
    features_df = pd.DataFrame(features_dict)
    print("Columns in features_df:", features_df.columns.tolist())
    print("All-NaN columns in features_df:", features_df.columns[features_df.isna().all()].tolist())
    df = pd.concat([df, features_df], axis=1)

    # Print all columns after concatenation
    print("All columns in df after concat:", df.columns.tolist())

    # Downcast all float columns to save memory
    print('Downcasting float columns to save memory...')
    for col in df.columns:
        try:
            df[col] = pd.to_numeric(df[col], downcast='float')
        except Exception:
            pass

    # Add time features (exclude 'dayofweek')
    print('Adding hour feature...')
    df['Timestamp'] = pd.to_datetime(df['Timestamp'], errors='coerce')
    df['hour'] = df['Timestamp'].dt.hour

    # Handle NaNs after all feature engineering
    if IMPUTE_NANS:
        print('Imputing NaNs after feature engineering (using mean imputation)...')
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        for col in numeric_cols:
            df[col] = df[col].fillna(df[col].mean())
        # If you want to impute non-numeric columns differently, add logic here
    else:
        print('Dropping NaNs after feature engineering...')
        df = df.dropna().reset_index(drop=True)

    # Exclude 'Timestamp', 'Close', 'log_return', and any future target columns from features
    print('Selecting feature columns...')
    exclude_cols = ['Timestamp', 'Close', 'log_return', 'log_return_5', 'log_return_15', 'log_return_30']
    feature_cols = [col for col in df.columns if col not in exclude_cols]
    print('Features used for training:', feature_cols)

    # Prepare CSV for results
    results_csv = './data/leave_one_out_results.csv'
    if not os.path.exists(results_csv):
        with open(results_csv, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['left_out_feature', 'used_features', 'rmse', 'mae', 'r2', 'mape', 'directional_accuracy'])

    total_features = len(feature_cols)
    never_leave_out = {'Open', 'High', 'Low', 'Close', 'Volume'}
    for idx, left_out in enumerate(feature_cols):
        if left_out in never_leave_out:
            continue
        used = [f for f in feature_cols if f != left_out]
        print(f'\n=== Leave-one-out {idx+1}/{total_features}: left out {left_out} ===')
        try:
            # Prepare X and y for this combination
            X = df[used].values.astype(np.float32)
            y = df["log_return"].values.astype(np.float32)
            split_idx = int(len(X) * 0.8)
            X_train, X_test = X[:split_idx], X[split_idx:]
            y_train, y_test = y[:split_idx], y[split_idx:]
            test_timestamps = df['Timestamp'].values[split_idx:]

            model = CustomXGBoostGPU(X_train, X_test, y_train, y_test)
            booster = model.train()
            model.save_model(f'./data/xgboost_model_wo_{left_out}.json')

            test_preds = model.predict(X_test)
            rmse = np.sqrt(mean_squared_error(y_test, test_preds))

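            # Price-level metrics below are computed on reconstructed price
            # paths: starting from the close at the split point,
            # P[k] = P[k-1] * exp(r[k]), applied to the actual and predicted
            # log returns respectively.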
            # Reconstruct price series from log returns
            if 'Close' in df.columns:
                close_prices = df['Close'].values
            else:
                close_prices = pd.read_csv(csv_path)['Close'].values
            start_price = close_prices[split_idx]
            actual_prices = [start_price]
            for r_ in y_test:
                actual_prices.append(actual_prices[-1] * np.exp(r_))
            actual_prices = np.array(actual_prices[1:])
            predicted_prices = [start_price]
            for r_ in test_preds:
                predicted_prices.append(predicted_prices[-1] * np.exp(r_))
            predicted_prices = np.array(predicted_prices[1:])

            mae = mean_absolute_error(actual_prices, predicted_prices)
            r2 = r2_score(actual_prices, predicted_prices)
            direction_actual = np.sign(np.diff(actual_prices))
            direction_pred = np.sign(np.diff(predicted_prices))
            directional_accuracy = (direction_actual == direction_pred).mean()
            mape = np.mean(np.abs((actual_prices - predicted_prices) / actual_prices)) * 100

            # Save results to CSV
            with open(results_csv, 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow([left_out, "|".join(used), rmse, mae, r2, mape, directional_accuracy])
            print(f'Left out {left_out}: RMSE={rmse:.4f}, MAE={mae:.4f}, R2={r2:.4f}, MAPE={mape:.2f}%, DirAcc={directional_accuracy*100:.2f}%')

            # Plotting for this run
            plot_prefix = f'loo_{left_out}'
            print('Plotting distribution of signed prediction errors...')
            plot_prediction_error_distribution(predicted_prices, actual_prices, prefix=plot_prefix)

            print('Plotting direction transition heatmap...')
            plot_direction_transition_heatmap(actual_prices, predicted_prices, prefix=plot_prefix)
        except Exception as e:
            print(f'Leave-one-out failed for {left_out}: {e}')
    print(f'All leave-one-out runs completed. Results saved to {results_csv}')
    sys.exit(0)
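# A possible follow-up analysis (a sketch, not part of main.py): rank the
# leave-one-out runs by test RMSE to see which dropped feature mattered most.
#
#   import pandas as pd
#   res = pd.read_csv('./data/leave_one_out_results.csv')
#   cols = ['left_out_feature', 'rmse', 'directional_accuracy']
#   print(res.sort_values('rmse')[cols].head(10))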
318 xgboost/plot_results.py Normal file
@@ -0,0 +1,318 @@
import numpy as np
import dash
from dash import dcc, html
import plotly.graph_objs as go
import threading


def display_actual_vs_predicted(y_test, test_preds, timestamps, n_plot=200):
    import plotly.offline as pyo
    n_plot = min(n_plot, len(y_test))
    plot_indices = timestamps[:n_plot]
    actual = y_test[:n_plot]
    predicted = test_preds[:n_plot]

    trace_actual = go.Scatter(x=plot_indices, y=actual, mode='lines', name='Actual')
    trace_predicted = go.Scatter(x=plot_indices, y=predicted, mode='lines', name='Predicted')
    data = [trace_actual, trace_predicted]
    layout = go.Layout(
        title='Actual vs. Predicted BTC Close Prices (Test Set)',
        xaxis={'title': 'Timestamp'},
        yaxis={'title': 'BTC Close Price'},
        legend={'x': 0, 'y': 1},
        margin={'l': 40, 'b': 40, 't': 40, 'r': 10},
        hovermode='closest'
    )
    fig = go.Figure(data=data, layout=layout)
    pyo.plot(fig, auto_open=False)

def plot_target_distribution(y_train, y_test):
    import plotly.offline as pyo
    trace_train = go.Histogram(
        x=y_train,
        nbinsx=100,
        opacity=0.5,
        name='Train',
        marker=dict(color='blue')
    )
    trace_test = go.Histogram(
        x=y_test,
        nbinsx=100,
        opacity=0.5,
        name='Test',
        marker=dict(color='orange')
    )
    data = [trace_train, trace_test]
    layout = go.Layout(
        title='Distribution of Target Variable (Close Price)',
        xaxis=dict(title='BTC Close Price'),
        yaxis=dict(title='Frequency'),
        barmode='overlay'
    )
    fig = go.Figure(data=data, layout=layout)
    pyo.plot(fig, auto_open=False)

def plot_predicted_vs_actual_log_returns(y_test, test_preds, timestamps=None, n_plot=200):
    import plotly.offline as pyo
    import plotly.graph_objs as go
    n_plot = min(n_plot, len(y_test))
    actual = y_test[:n_plot]
    predicted = test_preds[:n_plot]
    if timestamps is not None:
        x_axis = timestamps[:n_plot]
        x_label = 'Timestamp'
    else:
        x_axis = list(range(n_plot))
        x_label = 'Index'

    # Line plot: Actual vs Predicted over time
    trace_actual = go.Scatter(x=x_axis, y=actual, mode='lines', name='Actual')
    trace_predicted = go.Scatter(x=x_axis, y=predicted, mode='lines', name='Predicted')
    data_line = [trace_actual, trace_predicted]
    layout_line = go.Layout(
        title='Actual vs. Predicted Log Returns (Test Set)',
        xaxis={'title': x_label},
        yaxis={'title': 'Log Return'},
        legend={'x': 0, 'y': 1},
        margin={'l': 40, 'b': 40, 't': 40, 'r': 10},
        hovermode='closest'
    )
    fig_line = go.Figure(data=data_line, layout=layout_line)
    pyo.plot(fig_line, filename='charts/log_return_line_plot.html', auto_open=False)

    # Scatter plot: Predicted vs Actual
    trace_scatter = go.Scatter(
        x=actual,
        y=predicted,
        mode='markers',
        name='Predicted vs Actual',
        opacity=0.5
    )
    # Diagonal reference line
    min_val = min(np.min(actual), np.min(predicted))
    max_val = max(np.max(actual), np.max(predicted))
    trace_diag = go.Scatter(
        x=[min_val, max_val],
        y=[min_val, max_val],
        mode='lines',
        name='Ideal',
        line=dict(dash='dash', color='red')
    )
    data_scatter = [trace_scatter, trace_diag]
    layout_scatter = go.Layout(
        title='Predicted vs Actual Log Returns (Scatter)',
        xaxis={'title': 'Actual Log Return'},
        yaxis={'title': 'Predicted Log Return'},
        showlegend=True,
        margin={'l': 40, 'b': 40, 't': 40, 'r': 10},
        hovermode='closest'
    )
    fig_scatter = go.Figure(data=data_scatter, layout=layout_scatter)
    pyo.plot(fig_scatter, filename='charts/log_return_scatter_plot.html', auto_open=False)

def plot_predicted_vs_actual_prices(actual_prices, predicted_prices, timestamps=None, n_plot=200):
    import plotly.offline as pyo
    import plotly.graph_objs as go
    n_plot = min(n_plot, len(actual_prices))
    actual = actual_prices[:n_plot]
    predicted = predicted_prices[:n_plot]
    if timestamps is not None:
        x_axis = timestamps[:n_plot]
        x_label = 'Timestamp'
    else:
        x_axis = list(range(n_plot))
        x_label = 'Index'

    # Line plot: Actual vs Predicted over time
    trace_actual = go.Scatter(x=x_axis, y=actual, mode='lines', name='Actual Price')
    trace_predicted = go.Scatter(x=x_axis, y=predicted, mode='lines', name='Predicted Price')
    data_line = [trace_actual, trace_predicted]
    layout_line = go.Layout(
        title='Actual vs. Predicted BTC Prices (Test Set)',
        xaxis={'title': x_label},
        yaxis={'title': 'BTC Price'},
        legend={'x': 0, 'y': 1},
        margin={'l': 40, 'b': 40, 't': 40, 'r': 10},
        hovermode='closest'
    )
    fig_line = go.Figure(data=data_line, layout=layout_line)
    pyo.plot(fig_line, filename='charts/price_line_plot.html', auto_open=False)

    # Scatter plot: Predicted vs Actual
    trace_scatter = go.Scatter(
        x=actual,
        y=predicted,
        mode='markers',
        name='Predicted vs Actual',
        opacity=0.5
    )
    # Diagonal reference line
    min_val = min(np.min(actual), np.min(predicted))
    max_val = max(np.max(actual), np.max(predicted))
    trace_diag = go.Scatter(
        x=[min_val, max_val],
        y=[min_val, max_val],
        mode='lines',
        name='Ideal',
        line=dict(dash='dash', color='red')
    )
    data_scatter = [trace_scatter, trace_diag]
    layout_scatter = go.Layout(
        title='Predicted vs Actual Prices (Scatter)',
        xaxis={'title': 'Actual Price'},
        yaxis={'title': 'Predicted Price'},
        showlegend=True,
        margin={'l': 40, 'b': 40, 't': 40, 'r': 10},
        hovermode='closest'
    )
    fig_scatter = go.Figure(data=data_scatter, layout=layout_scatter)
    pyo.plot(fig_scatter, filename='charts/price_scatter_plot.html', auto_open=False)

def plot_prediction_error_distribution(predicted_prices, actual_prices, nbins=100, prefix=""):
    """
    Plots the distribution of signed prediction errors between predicted and actual prices,
    coloring negative errors (under-prediction) and positive errors (over-prediction) differently.
    """
    import plotly.offline as pyo
    import plotly.graph_objs as go
    errors = np.array(predicted_prices) - np.array(actual_prices)

    # Separate negative and positive errors
    neg_errors = errors[errors < 0]
    pos_errors = errors[errors >= 0]

    # Calculate common bin edges
    min_error = np.min(errors)
    max_error = np.max(errors)
    bin_edges = np.linspace(min_error, max_error, nbins + 1)
    xbins = dict(start=min_error, end=max_error, size=(max_error - min_error) / nbins)

    trace_neg = go.Histogram(
        x=neg_errors,
        opacity=0.75,
        marker=dict(color='blue'),
        name='Negative Error (Under-prediction)',
        xbins=xbins
    )
    trace_pos = go.Histogram(
        x=pos_errors,
        opacity=0.75,
        marker=dict(color='orange'),
        name='Positive Error (Over-prediction)',
        xbins=xbins
    )
    layout = go.Layout(
        title='Distribution of Prediction Errors (Signed)',
        xaxis=dict(title='Prediction Error (Predicted - Actual)'),
        yaxis=dict(title='Frequency'),
        barmode='overlay',
        bargap=0.05
    )
    fig = go.Figure(data=[trace_neg, trace_pos], layout=layout)
    filename = f'charts/{prefix}_prediction_error_distribution.html'
    pyo.plot(fig, filename=filename, auto_open=False)

def plot_directional_accuracy(actual_prices, predicted_prices, timestamps=None, n_plot=200):
    """
    Plots the directional accuracy of predictions compared to actual price movements.
    Shows whether the predicted direction matches the actual direction of price movement.

    Args:
        actual_prices: Array of actual price values
        predicted_prices: Array of predicted price values
        timestamps: Optional array of timestamps for x-axis
        n_plot: Number of points to plot (default 200, plots last n_plot points)
    """
    import plotly.graph_objs as go
    import plotly.offline as pyo
    import numpy as np

    # Calculate price changes
    actual_changes = np.diff(actual_prices)
    predicted_changes = np.diff(predicted_prices)

    # Determine if directions match
    actual_direction = np.sign(actual_changes)
    predicted_direction = np.sign(predicted_changes)
    correct_direction = actual_direction == predicted_direction

    # Get last n_plot points
    actual_changes = actual_changes[-n_plot:]
    predicted_changes = predicted_changes[-n_plot:]
    correct_direction = correct_direction[-n_plot:]

    if timestamps is not None:
        x_values = timestamps[1:]  # Skip first since we took diff
        x_values = x_values[-n_plot:]  # Get last n_plot points
    else:
        x_values = list(range(len(actual_changes)))

    # Create traces for correct and incorrect predictions
    correct_trace = go.Scatter(
        x=np.array(x_values)[correct_direction],
        y=actual_changes[correct_direction],
        mode='markers',
        name='Correct Direction',
        marker=dict(color='green', size=8)
    )

    incorrect_trace = go.Scatter(
        x=np.array(x_values)[~correct_direction],
        y=actual_changes[~correct_direction],
        mode='markers',
        name='Incorrect Direction',
        marker=dict(color='red', size=8)
    )

    # Calculate accuracy percentage
    accuracy = np.mean(correct_direction) * 100

    layout = go.Layout(
        title=f'Directional Accuracy (Overall: {accuracy:.1f}%)',
        xaxis=dict(title='Time' if timestamps is not None else 'Sample'),
        yaxis=dict(title='Price Change'),
        showlegend=True
    )

    fig = go.Figure(data=[correct_trace, incorrect_trace], layout=layout)
    pyo.plot(fig, filename='charts/directional_accuracy.html', auto_open=False)

def plot_direction_transition_heatmap(actual_prices, predicted_prices, prefix=""):
    """
    Plots a heatmap showing the frequency of each (actual, predicted) direction pair.
    """
    import numpy as np
    import plotly.graph_objs as go
    import plotly.offline as pyo

    # Calculate directions
    actual_direction = np.sign(np.diff(actual_prices))
    predicted_direction = np.sign(np.diff(predicted_prices))

    # Build 3x3 matrix: rows=actual, cols=predicted, values=counts
    # Map -1 -> 0, 0 -> 1, 1 -> 2 for indexing
    mapping = {-1: 0, 0: 1, 1: 2}
    matrix = np.zeros((3, 3), dtype=int)
    for a, p in zip(actual_direction, predicted_direction):
        matrix[mapping[a], mapping[p]] += 1

    # Axis labels
    directions = ['Down (-1)', 'No Change (0)', 'Up (+1)']

    # Plot heatmap
    heatmap = go.Heatmap(
        z=matrix,
        x=directions,  # predicted
        y=directions,  # actual
        colorscale='Viridis',
        colorbar=dict(title='Count')
    )
    layout = go.Layout(
        title='Direction Prediction Transition Matrix',
        xaxis=dict(title='Predicted Direction'),
        yaxis=dict(title='Actual Direction')
    )
    fig = go.Figure(data=[heatmap], layout=layout)
    filename = f'charts/{prefix}_direction_transition_heatmap.html'
    pyo.plot(fig, filename=filename, auto_open=False)
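# Minimal usage sketch for the two helpers main.py calls (the synthetic
# random-walk data below is an assumption; the 'charts/' output directory
# must exist before plotting):
#
#   import os
#   import numpy as np
#   from plot_results import plot_prediction_error_distribution, plot_direction_transition_heatmap
#
#   os.makedirs('charts', exist_ok=True)
#   actual = 100 * np.exp(np.cumsum(np.random.normal(0, 0.01, 500)))
#   predicted = actual * (1 + np.random.normal(0, 0.005, 500))
#   plot_prediction_error_distribution(predicted, actual, prefix='demo')
#   plot_direction_transition_heatmap(actual, predicted, prefix='demo')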