Compare commits: f7f0fc6dd5...main (68 commits)

| SHA1 |
|---|
| 267f040fe8 |
| 65f30a4020 |
| be331ed631 |
| 6c5dcc1183 |
| 02e5db2a36 |
| a877f14e65 |
| 082a2835b6 |
| ada6150413 |
| ced64825bd |
| 2f98463df8 |
| 2a52ffde9a |
| a22914731f |
| 81e4b640a7 |
| 2dba88b620 |
| de67b27e37 |
| 1284549106 |
| 5f03524d6a |
| b013183f67 |
| 74c8048ed5 |
| 6916722a43 |
| 47551d6781 |
| d499c5b8d0 |
| 2418538747 |
| 65ae3060de |
| b71faa9758 |
| c743e81af8 |
| 969e011d48 |
| cb576a9dfc |
| ebd8ef3d87 |
| 1566044fa8 |
| 3483aaf6d7 |
| 256ad67742 |
| f67b6b8ebd |
| 9629d3090b |
| 9b15f9f44f |
| 5d0b707bc6 |
| 235098c045 |
| 4552d7e6b5 |
| 7af8cdcb32 |
| e5c2988d71 |
| 00873d593f |
| 3a9dec543c |
| 934c807246 |
| 8e220b564c |
| 1107346594 |
| 45c853efab |
| 268bc33bbf |
| e286dd881a |
| 736b278ee2 |
| a924328c90 |
| f4873c56ff |
| 2fd73085b8 |
| 806697116d |
| 14905017c8 |
| ec1a86e098 |
| 0a919f825e |
| c2886a2aab |
| 10cc047975 |
| 955a340d02 |
| 07b9824b69 |
| 369b3c1daf |
| 08c871e05a |
| 837c505828 |
| 1cdfe3973a |
| 8ff86339d6 |
| 7f788a4d4e |
| 0eb7fc77f9 |
| 170751db0e |
.cursor/project.mdc (new file, 8 lines)
@@ -0,0 +1,8 @@
---
description:
globs:
alwaysApply: true
---
- Use UV for package management
- Keep documentation and module descriptions in the ./docs folder; update the related files whenever logic changes
.cursor/rules/always-global.mdc (new file, 61 lines)
@@ -0,0 +1,61 @@
---
description: Global development standards and AI interaction principles
globs:
alwaysApply: true
---

# Rule: Always Apply - Global Development Standards

## AI Interaction Principles

### Step-by-Step Development
- **NEVER** generate large blocks of code without explanation
- **ALWAYS** ask "provide your plan in a concise bullet list and wait for my confirmation before proceeding"
- Break complex tasks into smaller, manageable pieces (≤250 lines per file, ≤50 lines per function)
- Explain your reasoning step-by-step before writing code
- Wait for explicit approval before moving to the next sub-task

### Context Awareness
- **ALWAYS** reference existing code patterns and data structures before suggesting new approaches
- Ask about existing conventions before implementing new functionality
- Preserve established architectural decisions unless explicitly asked to change them
- Maintain consistency with existing naming conventions and code style

## Code Quality Standards

### File and Function Limits
- **Maximum file size**: 250 lines
- **Maximum function size**: 50 lines
- **Maximum complexity**: If a function does more than one main thing, break it down
- **Naming**: Use clear, descriptive names that explain purpose

### Documentation Requirements
- **Every public function** must have a docstring explaining purpose, parameters, and return value
- **Every class** must have a class-level docstring
- **Complex logic** must have inline comments explaining the "why", not just the "what"
- **API endpoints** must be documented with request/response examples

### Error Handling
- **ALWAYS** include proper error handling for external dependencies
- **NEVER** use bare except clauses (see the sketch after this list)
- Provide meaningful error messages that help with debugging
- Log errors appropriately for the application context
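As a concrete illustration, here is a minimal sketch of these rules applied to an external-service call; the service URL and the `fetch_user_profile` helper are hypothetical. It catches a specific exception instead of a bare `except`, logs with enough context to debug, and re-raises with a meaningful message.

```python
import logging

import requests

logger = logging.getLogger(__name__)


def fetch_user_profile(user_id: str) -> dict:
    """Fetch a user's profile from an external HTTP service (hypothetical endpoint)."""
    try:
        response = requests.get(
            f"https://api.example.com/users/{user_id}", timeout=5
        )
        response.raise_for_status()
        return response.json()
    except requests.RequestException as exc:  # specific class, never a bare except
        logger.error("Failed to fetch profile for user %s: %s", user_id, exc)
        raise RuntimeError(
            f"Profile service unavailable while fetching user {user_id}"
        ) from exc
```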

## Security and Best Practices
- **NEVER** hardcode credentials, API keys, or sensitive data
- **ALWAYS** validate user inputs
- Use parameterized queries for database operations (see the sketch after this list)
- Follow the principle of least privilege
- Implement proper authentication and authorization
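A minimal sketch of the parameterized-queries rule, using the standard-library `sqlite3` module; the `users` table is illustrative, and most DB-API drivers accept the same placeholder style.

```python
import sqlite3


def find_user_by_email(conn: sqlite3.Connection, email: str):
    # Good: the driver binds the value, so user input never becomes SQL text
    return conn.execute(
        "SELECT id, name FROM users WHERE email = ?", (email,)
    ).fetchone()

# Bad: string interpolation invites SQL injection
# conn.execute(f"SELECT id, name FROM users WHERE email = '{email}'")
```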

## Testing Requirements
- **Every implementation** should have corresponding unit tests
- **Every API endpoint** should have integration tests
- Test files should be placed alongside the code they test
- Use descriptive test names that explain what is being tested (see the example below)
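A small sketch of descriptive test naming, assuming pytest and a hypothetical `create_user` function in a hypothetical `user_service` module: each name states the behavior under test, so a failing test reads as a sentence in the report.

```python
# test_user_service.py, placed alongside the (hypothetical) user_service module
import pytest

from user_service import create_user


def test_create_user_rejects_invalid_email():
    with pytest.raises(ValueError):
        create_user(name="John", email="not-an-email")


def test_create_user_returns_user_with_generated_id():
    user = create_user(name="John", email="john@example.com")
    assert user.id is not None
```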

## Response Format
- Be concise and avoid unnecessary repetition
- Focus on actionable information
- Provide examples when explaining complex concepts
- Ask clarifying questions when requirements are ambiguous
.cursor/rules/architecture.mdc (new file, 237 lines)
@@ -0,0 +1,237 @@
---
description: Modular design principles and architecture guidelines for scalable development
globs:
alwaysApply: false
---

# Rule: Architecture and Modular Design

## Goal
Maintain a clean, modular architecture that scales effectively and prevents the complexity issues that arise in AI-assisted development.

## Core Architecture Principles

### 1. Modular Design
- **Single Responsibility**: Each module has one clear purpose
- **Loose Coupling**: Modules depend on interfaces, not implementations
- **High Cohesion**: Related functionality is grouped together
- **Clear Boundaries**: Module interfaces are well-defined and stable

### 2. Size Constraints
- **Files**: Maximum 250 lines per file
- **Functions**: Maximum 50 lines per function
- **Classes**: Maximum 300 lines per class
- **Modules**: Maximum 10 public functions/classes per module

### 3. Dependency Management
- **Layer Dependencies**: Higher layers depend on lower layers only
- **No Circular Dependencies**: Modules cannot depend on each other cyclically
- **Interface Segregation**: Depend on specific interfaces, not broad ones
- **Dependency Injection**: Pass dependencies rather than creating them internally (see the sketch after this list)
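A minimal sketch of constructor injection (the `ReportStore` and `ReportService` names are illustrative): the service receives its store instead of building one, so the layers stay decoupled and tests can pass in a fake.

```python
from typing import Protocol


class ReportStore(Protocol):
    def save(self, report: dict) -> None: ...


class ReportService:
    # Good: the dependency is injected; any conforming store works
    def __init__(self, store: ReportStore):
        self._store = store

    def publish(self, report: dict) -> None:
        self._store.save(report)


# Bad alternative: self._store = SqlReportStore(dsn="...") inside __init__
# would hard-wire this service to a single infrastructure implementation.
```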

## Modular Architecture Patterns

### Layer Structure
```
src/
├── presentation/    # UI, API endpoints, CLI interfaces
├── application/     # Business logic, use cases, workflows
├── domain/          # Core business entities and rules
├── infrastructure/  # Database, external APIs, file systems
└── shared/          # Common utilities, constants, types
```

### Module Organization
```
module_name/
├── __init__.py      # Public interface exports
├── core.py          # Main module logic
├── types.py         # Type definitions and interfaces
├── utils.py         # Module-specific utilities
├── tests/           # Module tests
└── README.md        # Module documentation
```
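To make the `__init__.py` role above concrete, here is a sketch of a public-interface export file; the module and symbol names are illustrative.

```python
# module_name/__init__.py - the module's public surface
from .core import process_order        # main entry point (illustrative name)
from .types import Order, OrderResult  # public types (illustrative names)

__all__ = ["process_order", "Order", "OrderResult"]
```

Callers then import from `module_name` directly and never reach into `core.py` or other internals.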

## Design Patterns for AI Development

### 1. Repository Pattern
Separate data access from business logic:

```python
# Domain interface
class UserRepository:
    def get_by_id(self, user_id: str) -> User: ...
    def save(self, user: User) -> None: ...

# Infrastructure implementation
class SqlUserRepository(UserRepository):
    def get_by_id(self, user_id: str) -> User:
        # Database-specific implementation
        pass
```

### 2. Service Pattern
Encapsulate business logic in focused services:

```python
class UserService:
    def __init__(self, user_repo: UserRepository):
        self._user_repo = user_repo

    def create_user(self, data: UserData) -> User:
        # Validation and business logic
        # Single responsibility: user creation
        pass
```

### 3. Factory Pattern
Create complex objects with clear interfaces:

```python
class DatabaseFactory:
    @staticmethod
    def create_connection(config: DatabaseConfig) -> Connection:
        # Handle different database types
        # Encapsulate connection complexity
        pass
```

## Architecture Decision Guidelines

### When to Create New Modules
Create a new module when:
- **Functionality** exceeds size constraints (250 lines)
- **Responsibility** is distinct from existing modules
- **Dependencies** would create circular references
- **Reusability** would benefit other parts of the system
- **Testing** requires isolated test environments

### When to Split Existing Modules
Split modules when:
- **File size** exceeds 250 lines
- **Multiple responsibilities** are evident
- **Testing** becomes difficult due to complexity
- **Dependencies** become too numerous
- **Change frequency** differs significantly between parts

### Module Interface Design
```python
# Good: Clear, focused interface
class PaymentProcessor:
    def process_payment(self, amount: Money, method: PaymentMethod) -> PaymentResult:
        """Process a single payment transaction."""
        pass

# Bad: Unfocused, kitchen-sink interface
class PaymentManager:
    def process_payment(self, *args, **kwargs): pass
    def validate_card(self, *args, **kwargs): pass
    def send_receipt(self, *args, **kwargs): pass
    def update_inventory(self, *args, **kwargs): pass  # Wrong responsibility!
```

## Architecture Validation

### Architecture Review Checklist
- [ ] **Dependencies flow in one direction** (no cycles)
- [ ] **Layers are respected** (presentation doesn't call infrastructure directly)
- [ ] **Modules have single responsibility**
- [ ] **Interfaces are stable** and well-defined
- [ ] **Size constraints** are maintained
- [ ] **Testing** is straightforward for each module

### Red Flags
- **God Objects**: Classes/modules that do too many things
- **Circular Dependencies**: Modules that depend on each other
- **Deep Inheritance**: More than 3 levels of inheritance
- **Large Interfaces**: Interfaces with more than 7 methods
- **Tight Coupling**: Modules that know too much about each other's internals

## Refactoring Guidelines

### When to Refactor
- Module exceeds size constraints
- Code duplication across modules
- Difficult to test individual components
- New features require changing multiple unrelated modules
- Performance bottlenecks due to poor separation

### Refactoring Process
1. **Identify** the specific architectural problem
2. **Design** the target architecture
3. **Create tests** to verify current behavior
4. **Implement changes** incrementally
5. **Validate** that tests still pass
6. **Update documentation** to reflect changes

### Safe Refactoring Practices
- **One change at a time**: Don't mix refactoring with new features
- **Tests first**: Ensure comprehensive test coverage before refactoring
- **Incremental changes**: Small steps with verification at each stage
- **Backward compatibility**: Maintain existing interfaces during transition
- **Documentation updates**: Keep architecture documentation current

## Architecture Documentation

### Architecture Decision Records (ADRs)
Document significant decisions in `./docs/decisions/`:

```markdown
# ADR-003: Service Layer Architecture

## Status
Accepted

## Context
As the application grows, business logic is scattered across controllers and models.

## Decision
Implement a service layer to encapsulate business logic.

## Consequences
**Positive:**
- Clear separation of concerns
- Easier testing of business logic
- Better reusability across different interfaces

**Negative:**
- Additional abstraction layer
- More files to maintain
```

### Module Documentation Template
````markdown
# Module: [Name]

## Purpose
What this module does and why it exists.

## Dependencies
- **Imports from**: List of modules this depends on
- **Used by**: List of modules that depend on this one
- **External**: Third-party dependencies

## Public Interface
```python
# Key functions and classes exposed by this module
```

## Architecture Notes
- Design patterns used
- Important architectural decisions
- Known limitations or constraints
````

## Migration Strategies

### Legacy Code Integration
- **Strangler Fig Pattern**: Gradually replace old code with new modules
- **Adapter Pattern**: Create interfaces to integrate old and new code (see the sketch after this list)
- **Facade Pattern**: Simplify complex legacy interfaces
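A brief sketch of the Adapter idea named above; `LegacyBilling` and its `chargeCard` method are hypothetical stand-ins for code that cannot be changed yet. New modules depend on the clean `PaymentGateway` interface while the adapter translates calls to the legacy API.

```python
class LegacyBilling:
    """Stand-in for legacy code with an awkward interface."""

    def chargeCard(self, card_no: str, cents: int) -> bool:
        return True  # legacy behavior elided


class PaymentGateway:
    """The clean interface that new modules depend on."""

    def charge(self, card_number: str, amount_cents: int) -> bool:
        raise NotImplementedError


class LegacyBillingAdapter(PaymentGateway):
    """Translates the new interface onto the legacy API."""

    def __init__(self, legacy: LegacyBilling):
        self._legacy = legacy

    def charge(self, card_number: str, amount_cents: int) -> bool:
        return self._legacy.chargeCard(card_number, amount_cents)
```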

### Gradual Modernization
1. **Identify boundaries** in existing code
2. **Extract modules** one at a time
3. **Create interfaces** for each extracted module
4. **Test thoroughly** at each step
5. **Update documentation** continuously
.cursor/rules/code-review.mdc (new file, 123 lines)
@@ -0,0 +1,123 @@
---
description: AI-generated code review checklist and quality assurance guidelines
globs:
alwaysApply: false
---

# Rule: Code Review and Quality Assurance

## Goal
Establish systematic review processes for AI-generated code to maintain quality, security, and maintainability standards.

## AI Code Review Checklist

### Pre-Implementation Review
Before accepting any AI-generated code:

1. **Understand the Code**
   - [ ] Can you explain what the code does in your own words?
   - [ ] Do you understand each function and its purpose?
   - [ ] Are there any "magic" values or unexplained logic?
   - [ ] Does the code solve the actual problem stated?

2. **Architecture Alignment**
   - [ ] Does the code follow established project patterns?
   - [ ] Is it consistent with existing data structures?
   - [ ] Does it integrate cleanly with existing components?
   - [ ] Are new dependencies justified and necessary?

3. **Code Quality**
   - [ ] Are functions smaller than 50 lines?
   - [ ] Are files smaller than 250 lines?
   - [ ] Are variable and function names descriptive?
   - [ ] Is the code DRY (Don't Repeat Yourself)?

### Security Review
- [ ] **Input Validation**: All user inputs are validated and sanitized
- [ ] **Authentication**: Proper authentication checks are in place
- [ ] **Authorization**: Access controls are implemented correctly
- [ ] **Data Protection**: Sensitive data is handled securely
- [ ] **SQL Injection**: Database queries use parameterized statements
- [ ] **XSS Prevention**: Output is properly escaped
- [ ] **Error Handling**: Errors don't leak sensitive information

### Integration Review
- [ ] **Existing Functionality**: New code doesn't break existing features
- [ ] **Data Consistency**: Database changes maintain referential integrity
- [ ] **API Compatibility**: Changes don't break existing API contracts
- [ ] **Performance Impact**: New code doesn't introduce performance bottlenecks
- [ ] **Testing Coverage**: Appropriate tests are included

## Review Process

### Step 1: Initial Code Analysis
1. **Read through the entire generated code** before running it
2. **Identify patterns** that don't match the existing codebase
3. **Check dependencies** - are new packages really needed?
4. **Verify logic flow** - does the algorithm make sense?

### Step 2: Security and Error Handling Review
1. **Trace data flow** from input to output
2. **Identify potential failure points** and verify error handling
3. **Check for security vulnerabilities** using the security checklist
4. **Verify proper logging** and monitoring implementation

### Step 3: Integration Testing
1. **Test with existing code** to ensure compatibility
2. **Run existing test suite** to verify no regressions
3. **Test edge cases** and error conditions
4. **Verify performance** under realistic conditions

## Common AI Code Issues to Watch For

### Overcomplication Patterns
- **Unnecessary abstractions**: AI creating complex patterns for simple tasks
- **Over-engineering**: Solutions that are more complex than needed
- **Redundant code**: AI recreating existing functionality
- **Inappropriate design patterns**: Using patterns that don't fit the use case

### Context Loss Indicators
- **Inconsistent naming**: Different conventions from existing code
- **Wrong data structures**: Using different patterns than established
- **Ignored existing functions**: Reimplementing existing functionality
- **Architectural misalignment**: Code that doesn't fit the overall design

### Technical Debt Indicators
- **Magic numbers**: Hardcoded values without explanation (see the before/after sketch following this list)
- **Poor error messages**: Generic or unhelpful error handling
- **Missing documentation**: Code without adequate comments
- **Tight coupling**: Components that are too interdependent
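A before/after sketch of the magic-number indicator; the values and names are illustrative.

```python
# Before: unexplained values the reviewer has to decode
def is_eligible(age, score):
    return age >= 13 and score > 0.75


# After: named constants document intent and centralize future changes
MINIMUM_AGE = 13      # e.g., a product policy on minimum account age
PASSING_SCORE = 0.75  # e.g., a threshold agreed with QA


def is_eligible(age: int, score: float) -> bool:
    return age >= MINIMUM_AGE and score > PASSING_SCORE
```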

## Quality Gates

### Mandatory Reviews
All AI-generated code must pass these gates before acceptance:

1. **Security Review**: No security vulnerabilities detected
2. **Integration Review**: Integrates cleanly with existing code
3. **Performance Review**: Meets performance requirements
4. **Maintainability Review**: Code can be easily modified by team members
5. **Documentation Review**: Adequate documentation is provided

### Acceptance Criteria
- [ ] Code is understandable by any team member
- [ ] Integration requires minimal changes to existing code
- [ ] Security review passes all checks
- [ ] Performance meets established benchmarks
- [ ] Documentation is complete and accurate

## Rejection Criteria
Reject AI-generated code if:
- Security vulnerabilities are present
- Code is too complex for the problem being solved
- Integration requires major refactoring of existing code
- Code duplicates existing functionality without justification
- Documentation is missing or inadequate

## Review Documentation
For each review, document:
- Issues found and how they were resolved
- Performance impact assessment
- Security concerns and mitigations
- Integration challenges and solutions
- Recommendations for future similar tasks
.cursor/rules/context-management.mdc (new file, 93 lines)
@@ -0,0 +1,93 @@
---
description: Context management for maintaining codebase awareness and preventing context drift
globs:
alwaysApply: false
---

# Rule: Context Management

## Goal
Maintain comprehensive project context to prevent context drift and ensure AI-generated code integrates seamlessly with existing codebase patterns and architecture.

## Context Documentation Requirements

### PRD.md File Documentation
1. **Project Overview**
   - Business objectives and goals
   - Target users and use cases
   - Key success metrics

### CONTEXT.md File Structure
Every project must maintain a `CONTEXT.md` file in the root directory with the following sections (a minimal skeleton follows this list):

1. **Architecture Overview**
   - High-level system architecture
   - Key design patterns used
   - Database schema overview
   - API structure and conventions

2. **Technology Stack**
   - Programming languages and versions
   - Frameworks and libraries
   - Database systems
   - Development and deployment tools

3. **Coding Conventions**
   - Naming conventions
   - File organization patterns
   - Code structure preferences
   - Import/export patterns

4. **Current Implementation Status**
   - Completed features
   - Work in progress
   - Known technical debt
   - Planned improvements
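A minimal skeleton of such a file; every entry below is an illustrative placeholder to be replaced with the project's actual details.

```markdown
# CONTEXT.md

## Architecture Overview
Layered architecture (presentation/application/domain/infrastructure);
repository pattern for data access.

## Technology Stack
- Python 3.x, PostgreSQL, UV for package management

## Coding Conventions
- snake_case module names, ≤250 lines per file, tests alongside code

## Current Implementation Status
- Completed: user registration
- In progress: payment flow
- Known technical debt: legacy billing adapter
```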

## Context Maintenance Protocol

### Before Every Coding Session
1. **Review CONTEXT.md and PRD.md** to understand current project state
2. **Scan recent changes** in git history to understand latest patterns
3. **Identify existing patterns** for similar functionality before implementing new features
4. **Ask for clarification** if existing patterns are unclear or conflicting

### During Development
1. **Reference existing code** when explaining implementation approaches
2. **Maintain consistency** with established patterns and conventions
3. **Update CONTEXT.md** when making architectural decisions
4. **Document deviations** from established patterns with reasoning

### Context Preservation Strategies
- **Incremental development**: Build on existing patterns rather than creating new ones
- **Pattern consistency**: Use established data structures and function signatures
- **Integration awareness**: Consider how new code affects existing functionality
- **Dependency management**: Understand existing dependencies before adding new ones

## Context Prompting Best Practices

### Effective Context Sharing
- Include relevant sections of CONTEXT.md in prompts for complex tasks
- Reference specific existing files when asking for similar functionality
- Provide examples of existing patterns when requesting new implementations
- Share recent git commit messages to understand latest changes (a sample prompt follows this list)
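A sample prompt putting these points together; the file names are placeholders.

```
Using the "Coding Conventions" section of @CONTEXT.md and the existing
validation pattern in @user_service.py, implement order validation the
same way. Recent commits touched the Order model, so review those first.
```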

### Context Window Optimization
- Prioritize most relevant context for current task
- Use @filename references to include specific files
- Break large contexts into focused, task-specific chunks
- Update context references as project evolves

## Red Flags - Context Loss Indicators
- AI suggests patterns that conflict with existing code
- New implementations ignore established conventions
- Proposed solutions don't integrate with existing architecture
- Code suggestions require significant refactoring of existing functionality

## Recovery Protocol
When context loss is detected:
1. **Stop development** and review CONTEXT.md
2. **Analyze existing codebase** for established patterns
3. **Update context documentation** with missing information
4. **Restart task** with proper context provided
5. **Test integration** with existing code before proceeding
.cursor/rules/create-prd.mdc (new file, 67 lines)
@@ -0,0 +1,67 @@
---
description: Creating PRD for a project or specific task/function
globs:
alwaysApply: false
---

# Rule: Generating a Product Requirements Document (PRD)

## Goal

To guide an AI assistant in creating a detailed Product Requirements Document (PRD) in Markdown format, based on an initial user prompt. The PRD should be clear, actionable, and suitable for a junior developer to understand and implement the feature.

## Process

1. **Receive Initial Prompt:** The user provides a brief description or request for a new feature or functionality.
2. **Ask Clarifying Questions:** Before writing the PRD, the AI *must* ask clarifying questions to gather sufficient detail. The goal is to understand the "what" and "why" of the feature, not necessarily the "how" (which the developer will figure out).
3. **Generate PRD:** Based on the initial prompt and the user's answers to the clarifying questions, generate a PRD using the structure outlined below.
4. **Save PRD:** Save the generated document as `prd-[feature-name].md` inside the `/tasks` directory.

## Clarifying Questions (Examples)

The AI should adapt its questions based on the prompt, but here are some common areas to explore:

* **Problem/Goal:** "What problem does this feature solve for the user?" or "What is the main goal we want to achieve with this feature?"
* **Target User:** "Who is the primary user of this feature?"
* **Core Functionality:** "Can you describe the key actions a user should be able to perform with this feature?"
* **User Stories:** "Could you provide a few user stories? (e.g., As a [type of user], I want to [perform an action] so that [benefit].)"
* **Acceptance Criteria:** "How will we know when this feature is successfully implemented? What are the key success criteria?"
* **Scope/Boundaries:** "Are there any specific things this feature *should not* do (non-goals)?"
* **Data Requirements:** "What kind of data does this feature need to display or manipulate?"
* **Design/UI:** "Are there any existing design mockups or UI guidelines to follow?" or "Can you describe the desired look and feel?"
* **Edge Cases:** "Are there any potential edge cases or error conditions we should consider?"

## PRD Structure

The generated PRD should include the following sections (a minimal skeleton follows this list):

1. **Introduction/Overview:** Briefly describe the feature and the problem it solves. State the goal.
2. **Goals:** List the specific, measurable objectives for this feature.
3. **User Stories:** Detail the user narratives describing feature usage and benefits.
4. **Functional Requirements:** List the specific functionalities the feature must have. Use clear, concise language (e.g., "The system must allow users to upload a profile picture."). Number these requirements.
5. **Non-Goals (Out of Scope):** Clearly state what this feature will *not* include to manage scope.
6. **Design Considerations (Optional):** Link to mockups, describe UI/UX requirements, or mention relevant components/styles if applicable.
7. **Technical Considerations (Optional):** Mention any known technical constraints, dependencies, or suggestions (e.g., "Should integrate with the existing Auth module").
8. **Success Metrics:** How will the success of this feature be measured? (e.g., "Increase user engagement by 10%", "Reduce support tickets related to X").
9. **Open Questions:** List any remaining questions or areas needing further clarification.
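A minimal skeleton of the resulting document; the section bodies are placeholders.

```markdown
# PRD: [Feature Name]

## Introduction/Overview
One paragraph: the feature, the problem it solves, and the goal.

## Goals
1. [Specific, measurable objective]

## User Stories
- As a [type of user], I want to [perform an action] so that [benefit].

## Functional Requirements
1. The system must [specific functionality].

## Non-Goals (Out of Scope)
- [Explicitly excluded behavior]

## Success Metrics
- [How success will be measured]

## Open Questions
- [Anything still unclear]
```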

## Target Audience

Assume the primary reader of the PRD is a **junior developer**. Therefore, requirements should be explicit, unambiguous, and avoid jargon where possible. Provide enough detail for them to understand the feature's purpose and core logic.

## Output

* **Format:** Markdown (`.md`)
* **Location:** `/tasks/`
* **Filename:** `prd-[feature-name].md`

## Final instructions

1. Do NOT start implementing the PRD
2. Make sure to ask the user clarifying questions
3. Take the user's answers to the clarifying questions and improve the PRD
.cursor/rules/documentation.mdc (new file, 244 lines)
@@ -0,0 +1,244 @@
---
description: Documentation standards for code, architecture, and development decisions
globs:
alwaysApply: false
---

# Rule: Documentation Standards

## Goal
Maintain comprehensive, up-to-date documentation that supports development, onboarding, and long-term maintenance of the codebase.

## Documentation Hierarchy

### 1. Project Level Documentation (in ./docs/)
- **README.md**: Project overview, setup instructions, basic usage
- **CONTEXT.md**: Current project state, architecture decisions, patterns
- **CHANGELOG.md**: Version history and significant changes
- **CONTRIBUTING.md**: Development guidelines and processes
- **API.md**: API endpoints, request/response formats, authentication

### 2. Module Level Documentation (in ./docs/modules/)
- **[module-name].md**: Purpose, public interfaces, usage examples
- **dependencies.md**: External dependencies and their purposes
- **architecture.md**: Module relationships and data flow

### 3. Code Level Documentation
- **Docstrings**: Function and class documentation
- **Inline comments**: Complex logic explanations
- **Type hints**: Clear parameter and return types
- **README files**: Directory-specific instructions

## Documentation Standards

### Code Documentation
```python
def process_user_data(user_id: str, data: dict) -> UserResult:
    """
    Process and validate user data before storage.

    Args:
        user_id: Unique identifier for the user
        data: Dictionary containing user information to process

    Returns:
        UserResult: Processed user data with validation status

    Raises:
        ValidationError: When user data fails validation
        DatabaseError: When storage operation fails

    Example:
        >>> result = process_user_data("123", {"name": "John", "email": "john@example.com"})
        >>> print(result.status)
        'valid'
    """
```

### API Documentation Format
````markdown
### POST /api/users

Create a new user account.

**Request:**
```json
{
  "name": "string (required)",
  "email": "string (required, valid email)",
  "age": "number (optional, min: 13)"
}
```

**Response (201):**
```json
{
  "id": "uuid",
  "name": "string",
  "email": "string",
  "created_at": "iso_datetime"
}
```

**Errors:**
- 400: Invalid input data
- 409: Email already exists
````

### Architecture Decision Records (ADRs)
Document significant architecture decisions in `./docs/decisions/`:

```markdown
# ADR-001: Database Choice - PostgreSQL

## Status
Accepted

## Context
We need to choose a database for storing user data and application state.

## Decision
We will use PostgreSQL as our primary database.

## Consequences
**Positive:**
- ACID compliance ensures data integrity
- Rich query capabilities with SQL
- Good performance for our expected load

**Negative:**
- More complex setup than simpler alternatives
- Requires SQL knowledge from team members

## Alternatives Considered
- MongoDB: Rejected due to consistency requirements
- SQLite: Rejected due to scalability needs
```

## Documentation Maintenance

### When to Update Documentation

#### Always Update:
- **API changes**: Any modification to public interfaces
- **Architecture changes**: New patterns, data structures, or workflows
- **Configuration changes**: Environment variables, deployment settings
- **Dependencies**: Adding, removing, or upgrading packages
- **Business logic changes**: Core functionality modifications

#### Update Weekly:
- **CONTEXT.md**: Current development status and priorities
- **Known issues**: Bug reports and workarounds
- **Performance notes**: Bottlenecks and optimization opportunities

#### Update per Release:
- **CHANGELOG.md**: User-facing changes and improvements
- **Version documentation**: Breaking changes and migration guides
- **Examples and tutorials**: Keep sample code current

### Documentation Quality Checklist

#### Completeness
- [ ] Purpose and scope clearly explained
- [ ] All public interfaces documented
- [ ] Examples provided for complex usage
- [ ] Error conditions and handling described
- [ ] Dependencies and requirements listed

#### Accuracy
- [ ] Code examples are tested and working
- [ ] Links point to correct locations
- [ ] Version numbers are current
- [ ] Screenshots reflect current UI

#### Clarity
- [ ] Written for the intended audience
- [ ] Technical jargon is explained
- [ ] Step-by-step instructions are clear
- [ ] Visual aids used where helpful

## Documentation Automation

### Auto-Generated Documentation
- **API docs**: Generate from code annotations
- **Type documentation**: Extract from type hints
- **Module dependencies**: Auto-update from imports
- **Test coverage**: Include coverage reports

### Documentation Testing
```python
# Test that code examples in documentation work
def test_documentation_examples():
    """Verify code examples in docs actually work."""
    # Test examples from README.md
    # Test API examples from docs/API.md
    # Test configuration examples
```
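The stub above can be made executable with the standard-library `doctest` module, as in this sketch; the file paths are illustrative, and `doctest.testfile` runs the `>>>` examples embedded in each document.

```python
import doctest


def test_documentation_examples():
    """Verify that code examples embedded in the docs actually run."""
    for path in ("README.md", "docs/API.md"):
        results = doctest.testfile(path, module_relative=False)
        assert results.failed == 0, f"{results.failed} doc example(s) failed in {path}"
```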

## Documentation Templates

### New Module Documentation Template
````markdown
# Module: [Name]

## Purpose
Brief description of what this module does and why it exists.

## Public Interface
### Functions
- `function_name(params)`: Description and example

### Classes
- `ClassName`: Purpose and basic usage

## Usage Examples
```python
# Basic usage example
```

## Dependencies
- Internal: List of internal modules this depends on
- External: List of external packages required

## Testing
How to run tests for this module.

## Known Issues
Current limitations or bugs.
````

### API Endpoint Template
```markdown
### [METHOD] /api/endpoint

Brief description of what this endpoint does.

**Authentication:** Required/Optional
**Rate Limiting:** X requests per minute

**Request:**
- Headers required
- Body schema
- Query parameters

**Response:**
- Success response format
- Error response format
- Status codes

**Example:**
Working request/response example
```

## Review and Maintenance Process

### Documentation Review
- Include documentation updates in code reviews
- Verify examples still work with code changes
- Check for broken links and outdated information
- Ensure consistency with current implementation

### Regular Audits
- Monthly review of documentation accuracy
- Quarterly assessment of documentation completeness
- Annual review of documentation structure and organization
.cursor/rules/enhanced-task-list.mdc (new file, 207 lines)
@@ -0,0 +1,207 @@
---
description: Enhanced task list management with quality gates and iterative workflow integration
globs:
alwaysApply: false
---

# Rule: Enhanced Task List Management

## Goal
Manage task lists with integrated quality gates and an iterative workflow to prevent context loss and ensure sustainable development.

## Task Implementation Protocol

### Pre-Implementation Check
Before starting any sub-task:
- [ ] **Context Review**: Have you reviewed CONTEXT.md and relevant documentation?
- [ ] **Pattern Identification**: Do you understand existing patterns to follow?
- [ ] **Integration Planning**: Do you know how this will integrate with existing code?
- [ ] **Size Validation**: Is this task small enough (≤50 lines per function, ≤250 lines per file)?

### Implementation Process
1. **One sub-task at a time**: Do **NOT** start the next sub-task until you ask the user for permission and they say "yes" or "y"
2. **Step-by-step execution**:
   - Plan the approach in bullet points
   - Wait for approval
   - Implement the specific sub-task
   - Test the implementation
   - Update documentation if needed
3. **Quality validation**: Run through the code review checklist before marking complete

### Completion Protocol
When you finish a **sub-task** (see the example after this list):
1. **Immediate marking**: Change `[ ]` to `[x]`
2. **Quality check**: Verify the implementation meets quality standards
3. **Integration test**: Ensure new code works with existing functionality
4. **Documentation update**: Update relevant files if needed
5. **Parent task check**: If **all** subtasks underneath a parent task are now `[x]`, also mark the **parent task** as completed
6. **Stop and wait**: Get user approval before proceeding to the next sub-task
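For example, once all sub-tasks under a parent are approved, the list changes like this (the task titles are illustrative):

```markdown
- [x] 1.0 Add password validation        <- marked once 1.1 and 1.2 are [x]
  - [x] 1.1 Implement the validation rule
  - [x] 1.2 Add unit tests
- [ ] 2.0 Wire validation into signup    <- wait for "yes"/"y" before starting
```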

## Enhanced Task List Structure

### Task File Header
```markdown
# Task List: [Feature Name]

**Source PRD**: `prd-[feature-name].md`
**Status**: In Progress / Complete / Blocked
**Context Last Updated**: [Date]
**Architecture Review**: Required / Complete / N/A

## Quick Links
- [Context Documentation](./CONTEXT.md)
- [Architecture Guidelines](./docs/architecture.md)
- [Related Files](#relevant-files)
```

### Task Format with Quality Gates
```markdown
- [ ] 1.0 Parent Task Title
  - **Quality Gate**: Architecture review required
  - **Dependencies**: List any dependencies
  - [ ] 1.1 [Sub-task description 1.1]
    - **Size estimate**: [Small/Medium/Large]
    - **Pattern reference**: [Reference to existing pattern]
    - **Test requirements**: [Unit/Integration/Both]
  - [ ] 1.2 [Sub-task description 1.2]
    - **Integration points**: [List affected components]
    - **Risk level**: [Low/Medium/High]
```

## Relevant Files Management

### Enhanced File Tracking
```markdown
## Relevant Files

### Implementation Files
- `path/to/file1.ts` - Brief description of purpose and role
  - **Status**: Created / Modified / Needs Review
  - **Last Modified**: [Date]
  - **Review Status**: Pending / Approved / Needs Changes

### Test Files
- `path/to/file1.test.ts` - Unit tests for file1.ts
  - **Coverage**: [Percentage or status]
  - **Last Run**: [Date and result]

### Documentation Files
- `docs/module-name.md` - Module documentation
  - **Status**: Up to date / Needs update / Missing
  - **Last Updated**: [Date]

### Configuration Files
- `config/setting.json` - Configuration changes
  - **Environment**: [Dev/Staging/Prod affected]
  - **Backup**: [Location of backup]
```

## Task List Maintenance

### During Development
1. **Regular updates**: Update task status after each significant change
2. **File tracking**: Add new files as they are created or modified
3. **Dependency tracking**: Note when new dependencies between tasks emerge
4. **Risk assessment**: Flag tasks that become more complex than anticipated

### Quality Checkpoints
At 25%, 50%, 75%, and 100% completion:
- [ ] **Architecture alignment**: Code follows established patterns
- [ ] **Performance impact**: No significant performance degradation
- [ ] **Security review**: No security vulnerabilities introduced
- [ ] **Documentation current**: All changes are documented

### Weekly Review Process
1. **Completion assessment**: What percentage of tasks are actually complete?
2. **Quality assessment**: Are completed tasks meeting quality standards?
3. **Process assessment**: Is the iterative workflow being followed?
4. **Risk assessment**: Are there emerging risks or blockers?

## Task Status Indicators

### Status Levels
- `[ ]` **Not Started**: Task not yet begun
- `[~]` **In Progress**: Currently being worked on
- `[?]` **Blocked**: Waiting for dependencies or decisions
- `[!]` **Needs Review**: Implementation complete but needs quality review
- `[x]` **Complete**: Finished and quality approved

### Quality Indicators
- ✅ **Quality Approved**: Passed all quality gates
- ⚠️ **Quality Concerns**: Has issues but functional
- ❌ **Quality Failed**: Needs rework before approval
- 🔄 **Under Review**: Currently being reviewed

### Integration Status
- 🔗 **Integrated**: Successfully integrated with existing code
- 🔧 **Integration Issues**: Problems with existing code integration
- ⏳ **Integration Pending**: Ready for integration testing

## Emergency Procedures

### When Tasks Become Too Complex
If a sub-task grows beyond expected scope:
1. **Stop implementation** immediately
2. **Document current state** and what was discovered
3. **Break down** the task into smaller pieces
4. **Update task list** with new sub-tasks
5. **Get approval** for the new breakdown before proceeding

### When Context is Lost
If AI seems to lose track of project patterns:
1. **Pause development**
2. **Review CONTEXT.md** and recent changes
3. **Update context documentation** with current state
4. **Restart** with explicit pattern references
5. **Reduce task size** until context is re-established

### When Quality Gates Fail
If implementation doesn't meet quality standards:
1. **Mark task** with `[!]` status
2. **Document specific issues** found
3. **Create remediation tasks** if needed
4. **Don't proceed** until quality issues are resolved

## AI Instructions Integration

### Context Awareness Commands
```markdown
**Before starting any task, run these checks:**
1. @CONTEXT.md - Review current project state
2. @architecture.md - Understand design principles
3. @code-review.md - Know quality standards
4. Look at existing similar code for patterns
```

### Quality Validation Commands
```markdown
**After completing any sub-task:**
1. Run code review checklist
2. Test integration with existing code
3. Update documentation if needed
4. Mark task complete only after quality approval
```

### Workflow Commands
```markdown
**For each development session:**
1. Review incomplete tasks and their status
2. Identify next logical sub-task to work on
3. Check dependencies and blockers
4. Follow iterative workflow process
5. Update task list with progress and findings
```

## Success Metrics

### Daily Success Indicators
- Tasks are completed according to quality standards
- No sub-tasks are started without completing previous ones
- File tracking remains accurate and current
- Integration issues are caught early

### Weekly Success Indicators
- Overall task completion rate is sustainable
- Quality issues are decreasing over time
- Context loss incidents are rare
- Team confidence in the codebase remains high
.cursor/rules/generate-tasks.mdc (new file, 70 lines)
@@ -0,0 +1,70 @@
---
description: Generate a task list or TODO for a user requirement or implementation.
globs:
alwaysApply: false
---

# Rule: Generating a Task List from a PRD

## Goal

To guide an AI assistant in creating a detailed, step-by-step task list in Markdown format based on an existing Product Requirements Document (PRD). The task list should guide a developer through implementation.

## Output

- **Format:** Markdown (`.md`)
- **Location:** `/tasks/`
- **Filename:** `tasks-[prd-file-name].md` (e.g., `tasks-prd-user-profile-editing.md`)

## Process

1. **Receive PRD Reference:** The user points the AI to a specific PRD file
2. **Analyze PRD:** The AI reads and analyzes the functional requirements, user stories, and other sections of the specified PRD.
3. **Phase 1: Generate Parent Tasks:** Based on the PRD analysis, create the file and generate the main, high-level tasks required to implement the feature. Use your judgement on how many high-level tasks to use; it's likely to be about 5. Present these tasks to the user in the specified format (without sub-tasks yet). Inform the user: "I have generated the high-level tasks based on the PRD. Ready to generate the sub-tasks? Respond with 'Go' to proceed."
4. **Wait for Confirmation:** Pause and wait for the user to respond with "Go".
5. **Phase 2: Generate Sub-Tasks:** Once the user confirms, break down each parent task into smaller, actionable sub-tasks necessary to complete the parent task. Ensure sub-tasks logically follow from the parent task and cover the implementation details implied by the PRD.
6. **Identify Relevant Files:** Based on the tasks and PRD, identify potential files that will need to be created or modified. List these under the `Relevant Files` section, including corresponding test files if applicable.
7. **Generate Final Output:** Combine the parent tasks, sub-tasks, relevant files, and notes into the final Markdown structure.
8. **Save Task List:** Save the generated document in the `/tasks/` directory with the filename `tasks-[prd-file-name].md`, where `[prd-file-name]` matches the base name of the input PRD file (e.g., if the input was `prd-user-profile-editing.md`, the output is `tasks-prd-user-profile-editing.md`).

## Output Format

The generated task list _must_ follow this structure:

```markdown
## Relevant Files

- `path/to/potential/file1.ts` - Brief description of why this file is relevant (e.g., Contains the main component for this feature).
- `path/to/file1.test.ts` - Unit tests for `file1.ts`.
- `path/to/another/file.tsx` - Brief description (e.g., API route handler for data submission).
- `path/to/another/file.test.tsx` - Unit tests for `another/file.tsx`.
- `lib/utils/helpers.ts` - Brief description (e.g., Utility functions needed for calculations).
- `lib/utils/helpers.test.ts` - Unit tests for `helpers.ts`.

### Notes

- Unit tests should typically be placed alongside the code files they are testing (e.g., `MyComponent.tsx` and `MyComponent.test.tsx` in the same directory).
- Use `npx jest [optional/path/to/test/file]` to run tests. Running without a path executes all tests found by the Jest configuration.

## Tasks

- [ ] 1.0 Parent Task Title
  - [ ] 1.1 [Sub-task description 1.1]
  - [ ] 1.2 [Sub-task description 1.2]
- [ ] 2.0 Parent Task Title
  - [ ] 2.1 [Sub-task description 2.1]
- [ ] 3.0 Parent Task Title (may not require sub-tasks if purely structural or configuration)
```

## Interaction Model

The process explicitly requires a pause after generating parent tasks to get user confirmation ("Go") before proceeding to generate the detailed sub-tasks. This ensures the high-level plan aligns with user expectations before diving into details.

## Target Audience

Assume the primary reader of the task list is a **junior developer** who will implement the feature.
236
.cursor/rules/iterative-workflow.mdc
Normal file
236
.cursor/rules/iterative-workflow.mdc
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
---
description: Iterative development workflow for AI-assisted coding
globs:
alwaysApply: false
---

# Rule: Iterative Development Workflow

## Goal
Establish a structured, iterative development process that prevents the chaos and complexity that can arise from uncontrolled AI-assisted development.

## Development Phases

### Phase 1: Planning and Design
**Before writing any code:**

1. **Understand the Requirement**
   - Break down the task into specific, measurable objectives
   - Identify existing code patterns that should be followed
   - List dependencies and integration points
   - Define acceptance criteria

2. **Design Review**
   - Propose the approach in bullet points
   - Wait for explicit approval before proceeding
   - Consider how the solution fits the existing architecture
   - Identify potential risks and mitigation strategies

### Phase 2: Incremental Implementation
**One small piece at a time:**

1. **Micro-Tasks** (≤ 50 lines each)
   - Implement one function or small class at a time
   - Test immediately after implementation
   - Ensure integration with existing code
   - Document decisions and patterns used

2. **Validation Checkpoints**
   - After each micro-task, verify it works correctly
   - Check that it follows established patterns
   - Confirm it integrates cleanly with existing code
   - Get approval before moving to the next micro-task

### Phase 3: Integration and Testing
**Ensuring system coherence:**

1. **Integration Testing**
   - Test new code with existing functionality
   - Verify no regressions in existing features
   - Check performance impact
   - Validate error handling

2. **Documentation Update**
   - Update relevant documentation
   - Record any new patterns or decisions
   - Update context files if the architecture changed

## Iterative Prompting Strategy

### Step 1: Context Setting
```
Before implementing [feature], help me understand:
1. What existing patterns should I follow?
2. What existing functions/classes are relevant?
3. How should this integrate with [specific existing component]?
4. What are the potential architectural impacts?
```

### Step 2: Plan Creation
```
Based on the context, create a detailed plan for implementing [feature]:
1. Break it into micro-tasks (≤50 lines each)
2. Identify dependencies and order of implementation
3. Specify integration points with existing code
4. List potential risks and mitigation strategies

Wait for my approval before implementing.
```

### Step 3: Incremental Implementation
```
Implement only the first micro-task: [specific task]
- Use existing patterns from [reference file/function]
- Keep it under 50 lines
- Include error handling
- Add appropriate tests
- Explain your implementation choices

Stop after this task and wait for approval.
```

## Quality Gates

### Before Each Implementation
- [ ] **Purpose is clear**: Can explain what this piece does and why
- [ ] **Pattern is established**: Following existing code patterns
- [ ] **Size is manageable**: Implementation is small enough to understand completely
- [ ] **Integration is planned**: Know how it connects to existing code

### After Each Implementation
- [ ] **Code is understood**: Can explain every line of implemented code
- [ ] **Tests pass**: All existing and new tests are passing
- [ ] **Integration works**: New code works with existing functionality
- [ ] **Documentation updated**: Changes are reflected in relevant documentation

### Before Moving to Next Task
- [ ] **Current task complete**: All acceptance criteria met
- [ ] **No regressions**: Existing functionality still works
- [ ] **Clean state**: No temporary code or debugging artifacts
- [ ] **Approval received**: Explicit go-ahead for the next task
- [ ] **Documentation updated**: If relevant changes to the module were made

## Anti-Patterns to Avoid

### Large Block Implementation
**Don't:**
```
Implement the entire user management system with authentication,
CRUD operations, and email notifications.
```

**Do:**
```
First, implement just the User model with basic fields.
Stop there and let me review before continuing.
```

### Context Loss
**Don't:**
```
Create a new authentication system.
```

**Do:**
```
Looking at the existing auth patterns in auth.py, implement
password validation following the same structure as the
existing email validation function.
```

### Over-Engineering
**Don't:**
```
Build a flexible, extensible user management framework that
can handle any future requirements.
```

**Do:**
```
Implement user creation functionality that matches the existing
pattern in customer.py, focusing only on the current requirements.
```

## Progress Tracking

### Task Status Indicators
- 🔄 **In Planning**: Requirements gathering and design
- ⏳ **In Progress**: Currently implementing
- ✅ **Complete**: Implemented, tested, and integrated
- 🚫 **Blocked**: Waiting for decisions or dependencies
- 🔧 **Needs Refactor**: Working but needs improvement

### Weekly Review Process
1. **Progress Assessment**
   - What was completed this week?
   - What challenges were encountered?
   - How well did the iterative process work?

2. **Process Adjustment**
   - Were task sizes appropriate?
   - Did context management work effectively?
   - What improvements can be made?

3. **Architecture Review**
   - Is the code remaining maintainable?
   - Are patterns staying consistent?
   - Is technical debt accumulating?

## Emergency Procedures

### When Things Go Wrong
If development becomes chaotic or problematic:

1. **Stop Development**
   - Don't continue adding to the problem
   - Take time to assess the situation
   - Don't rush to "fix" with more AI-generated code

2. **Assess the Situation**
   - What specific problems exist?
   - How far has the code diverged from established patterns?
   - What parts are still working correctly?

3. **Recovery Process**
   - Roll back to the last known good state
   - Update context documentation with lessons learned
   - Restart with smaller, more focused tasks
   - Get explicit approval for each step of recovery

### Context Recovery
When the AI seems to lose track of project patterns:

1. **Context Refresh**
   - Review and update CONTEXT.md
   - Include examples of current code patterns
   - Clarify architectural decisions

2. **Pattern Re-establishment**
   - Show the AI examples of existing, working code
   - Explicitly state patterns to follow
   - Start with very small, pattern-matching tasks

3. **Gradual Re-engagement**
   - Begin with simple, low-risk tasks
   - Verify pattern adherence at each step
   - Gradually increase task complexity as consistency returns

## Success Metrics

### Short-term (Daily)
- Code is understandable and well-integrated
- No major regressions introduced
- Development velocity feels sustainable
- Team confidence in the codebase remains high

### Medium-term (Weekly)
- Technical debt is not accumulating
- New features integrate cleanly
- Development patterns remain consistent
- Documentation stays current

### Long-term (Monthly)
- Codebase remains maintainable as it grows
- New team members can understand and contribute
- AI assistance enhances rather than hinders development
- Architecture remains clean and purposeful
24
.cursor/rules/project.mdc
Normal file
@@ -0,0 +1,24 @@
---
description:
globs:
alwaysApply: true
---
# Rule: Project specific rules

## Goal
Unify the project structure and the interaction with tools and the console

### System tools
- **ALWAYS** use UV for package management
- **ALWAYS** use Windows PowerShell commands for the terminal

### Coding patterns
- **ALWAYS** check arguments and method signatures before use, to avoid errors from wrong parameter names or values
- If in doubt, check the [CONTEXT.md](mdc:CONTEXT.md) file and [architecture.md](mdc:docs/architecture.md)
- **PREFER** the ORM pattern for databases, using SQLAlchemy (see the sketch below)
- **DO NOT USE** emoji in code and comments
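
A minimal sketch of the preferred SQLAlchemy ORM pattern. The model, table, and database path below are illustrative assumptions, not taken from this project:

```python
# Hypothetical example: the Trade model and the SQLite path are placeholders.
from sqlalchemy import Integer, String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Trade(Base):
    """Illustrative ORM model; replace the fields with the real schema."""

    __tablename__ = "trades"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    symbol: Mapped[str] = mapped_column(String(16))


engine = create_engine("sqlite:///data/example.db")
Base.metadata.create_all(engine)

# All reads and writes go through the Session rather than raw SQL.
with Session(engine) as session:
    session.add(Trade(symbol="BTCUSD"))
    session.commit()
```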

### Testing
- Run tests with UV, in the format *uv run pytest [filename]*
237
.cursor/rules/refactoring.mdc
Normal file
@@ -0,0 +1,237 @@
---
description: Code refactoring and technical debt management for AI-assisted development
globs:
alwaysApply: false
---

# Rule: Code Refactoring and Technical Debt Management

## Goal
Guide AI in systematic code refactoring to improve maintainability, reduce complexity, and prevent technical debt accumulation in AI-assisted development projects.

## When to Apply This Rule
- Code complexity has increased beyond manageable levels
- Duplicate code patterns are detected
- Performance issues are identified
- New features are difficult to integrate
- Code review reveals maintainability concerns
- Weekly technical debt assessment indicates refactoring needs

## Pre-Refactoring Assessment

Before starting any refactoring, the AI MUST:

1. **Context Analysis:**
   - Review the existing `CONTEXT.md` for architectural decisions
   - Analyze current code patterns and conventions
   - Identify all files that will be affected (search the codebase for usages)
   - Check for existing tests that verify current behavior

2. **Scope Definition:**
   - Clearly define what will and will not be changed
   - Identify the specific refactoring pattern to apply
   - Estimate the blast radius of changes
   - Plan a rollback strategy if needed

3. **Documentation Review:**
   - Check `./docs/` for relevant module documentation
   - Review any existing architectural diagrams
   - Identify dependencies and integration points
   - Note any known constraints or limitations

## Refactoring Process

### Phase 1: Planning and Safety
1. **Create Refactoring Plan:**
   - Document the current state and desired end state
   - Break refactoring into small, atomic steps
   - Identify tests that must pass throughout the process
   - Plan verification steps for each change

2. **Establish Safety Net:**
   - Ensure comprehensive test coverage exists
   - If tests are missing, create them BEFORE refactoring
   - Document current behavior that must be preserved
   - Create a backup of the current implementation approach

3. **Get Approval:**
   - Present the refactoring plan to the user
   - Wait for explicit "Go" or "Proceed" confirmation
   - Do NOT start refactoring without approval

### Phase 2: Incremental Implementation
4. **One Change at a Time:**
   - Implement ONE refactoring step per iteration
   - Run tests after each step to ensure nothing breaks
   - Update documentation if interfaces change
   - Mark progress in the refactoring plan

5. **Verification Protocol:**
   - Run all relevant tests after each change
   - Verify functionality works as expected
   - Check performance hasn't degraded
   - Ensure no new linting or type errors

6. **User Checkpoint:**
   - After each significant step, pause for user review
   - Present what was changed and the current status
   - Wait for approval before continuing
   - Address any concerns before proceeding

### Phase 3: Completion and Documentation
7. **Final Verification:**
   - Run the full test suite to ensure nothing is broken
   - Verify all original functionality is preserved
   - Check that new code follows project conventions
   - Confirm performance is maintained or improved

8. **Documentation Update:**
   - Update `CONTEXT.md` with new patterns/decisions
   - Update module documentation in `./docs/`
   - Document any new conventions established
   - Note lessons learned for future refactoring

## Common Refactoring Patterns

### Extract Method/Function
```
WHEN: Functions/methods exceed 50 lines or have multiple responsibilities
HOW:
1. Identify logical groupings within the function
2. Extract each group into a well-named helper function
3. Ensure each function has a single responsibility
4. Verify tests still pass
```
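
An illustrative sketch of the before/after shape (the function below is invented for demonstration, not taken from this codebase):

```python
# Before: one function mixes parsing and validation.
def load_prices(raw_lines):
    prices = []
    for line in raw_lines:
        value = float(line.strip())
        if value <= 0:
            raise ValueError(f"Price must be positive, got {value}")
        prices.append(value)
    return prices


# After: each helper does exactly one thing and is independently testable.
def parse_price(line: str) -> float:
    return float(line.strip())


def validate_price(value: float) -> float:
    if value <= 0:
        raise ValueError(f"Price must be positive, got {value}")
    return value


def load_prices_refactored(raw_lines):
    return [validate_price(parse_price(line)) for line in raw_lines]
```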

### Extract Module/Class
```
WHEN: Files exceed 250 lines or handle multiple concerns
HOW:
1. Identify cohesive functionality groups
2. Create new files for each group
3. Move related functions/classes together
4. Update imports and dependencies
5. Verify module boundaries are clean
```

### Eliminate Duplication
```
WHEN: Similar code appears in multiple places
HOW:
1. Identify the common pattern or functionality
2. Extract to a shared utility function or module
3. Update all usage sites to use the shared code
4. Ensure the abstraction is not over-engineered
```
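
A small sketch of the idea, using an invented fee calculation:

```python
# Before: the same fee formula is repeated at two call sites.
spot_fee = 1_000.0 * 0.001
futures_fee = 2_500.0 * 0.001


# After: one shared helper is the single source of truth.
def trading_fee(notional: float, rate: float = 0.001) -> float:
    """Shared fee calculation used by every call site."""
    return notional * rate


spot_fee = trading_fee(1_000.0)
futures_fee = trading_fee(2_500.0)
```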

### Improve Data Structures
```
WHEN: Complex nested objects or unclear data flow
HOW:
1. Define clear interfaces/types for data structures
2. Create transformation functions between different representations
3. Ensure data flow is unidirectional where possible
4. Add validation at boundaries
```

### Reduce Coupling
```
WHEN: Modules are tightly interconnected
HOW:
1. Identify dependencies between modules
2. Extract interfaces for external dependencies
3. Use dependency injection where appropriate
4. Ensure modules can be tested in isolation
```

## Quality Gates

Every refactoring must pass these gates:

### Technical Quality
- [ ] All existing tests pass
- [ ] No new linting errors introduced
- [ ] Code follows established project conventions
- [ ] No performance regression detected
- [ ] File sizes remain under 250 lines
- [ ] Function sizes remain under 50 lines

### Maintainability
- [ ] Code is more readable than before
- [ ] Duplicated code has been reduced
- [ ] Module responsibilities are clearer
- [ ] Dependencies are explicit and minimal
- [ ] Error handling is consistent

### Documentation
- [ ] Public interfaces are documented
- [ ] Complex logic has explanatory comments
- [ ] Architectural decisions are recorded
- [ ] Examples are provided where helpful

## AI Instructions for Refactoring

1. **Always ask for permission** before starting any refactoring work
2. **Start with tests** - ensure comprehensive coverage before changing code
3. **Work incrementally** - make small changes and verify each step
4. **Preserve behavior** - functionality must remain exactly the same
5. **Update documentation** - keep all docs current with changes
6. **Follow conventions** - maintain consistency with the existing codebase
7. **Stop and ask** if any step fails or produces unexpected results
8. **Explain changes** - clearly communicate what was changed and why

## Anti-Patterns to Avoid

### Over-Engineering
- Don't create abstractions for code that isn't duplicated
- Avoid complex inheritance hierarchies
- Don't optimize prematurely

### Breaking Changes
- Never change public APIs without explicit approval
- Don't remove functionality, even if it seems unused
- Avoid changing behavior "while we're here"

### Scope Creep
- Stick to the defined refactoring scope
- Don't add new features during refactoring
- Resist the urge to "improve" unrelated code

## Success Metrics

Track these metrics to ensure refactoring effectiveness:

### Code Quality
- Reduced cyclomatic complexity
- Lower code duplication percentage
- Improved test coverage
- Fewer linting violations

### Developer Experience
- Faster time to understand code
- Easier integration of new features
- Reduced bug introduction rate
- Higher developer confidence in changes

### Maintainability
- Clearer module boundaries
- More predictable behavior
- Easier debugging and troubleshooting
- Better performance characteristics

## Output Files

When refactoring is complete, update:
- `refactoring-log-[date].md` - Document what was changed and why
- `CONTEXT.md` - Update with new patterns and decisions
- `./docs/` - Update relevant module documentation
- Task lists - Mark refactoring tasks as complete

## Final Verification

Before marking refactoring complete:
1. Run the full test suite and verify all tests pass
2. Check that code follows all project conventions
3. Verify documentation is up to date
4. Confirm the user is satisfied with the results
5. Record lessons learned for future refactoring efforts
44
.cursor/rules/task-list.mdc
Normal file
@@ -0,0 +1,44 @@
---
description: TODO list task implementation
globs:
alwaysApply: false
---

# Task List Management

Guidelines for managing task lists in markdown files to track progress on completing a PRD.

## Task Implementation
- **One sub-task at a time:** Do **NOT** start the next sub-task until you ask the user for permission and they say "yes" or "y"
- **Completion protocol:**
  1. When you finish a **sub-task**, immediately mark it as completed by changing `[ ]` to `[x]`.
  2. If **all** subtasks underneath a parent task are now `[x]`, also mark the **parent task** as completed (see the example below).
- Stop after each sub-task and wait for the user's go-ahead.
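
For example, once sub-tasks 1.1 and 1.2 are done, the parent is marked too:

```
- [x] 1.0 Parent Task Title
  - [x] 1.1 [Sub-task description 1.1]
  - [x] 1.2 [Sub-task description 1.2]
- [ ] 2.0 Parent Task Title
  - [ ] 2.1 [Sub-task description 2.1]
```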

## Task List Maintenance

1. **Update the task list as you work:**
   - Mark tasks and subtasks as completed (`[x]`) per the protocol above.
   - Add new tasks as they emerge.

2. **Maintain the "Relevant Files" section:**
   - List every file created or modified.
   - Give each file a one-line description of its purpose (see the sketch below).
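
A sketch of the expected format (the test file name is illustrative):

```
## Relevant Files

- cycles/backtest.py - Core backtesting engine; modified for this task.
- tests/test_backtest.py - Unit tests covering the change.
```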

## AI Instructions

When working with task lists, the AI must:

1. Regularly update the task list file after finishing any significant work.
2. Follow the completion protocol:
   - Mark each finished **sub-task** `[x]`.
   - Mark the **parent task** `[x]` once **all** its subtasks are `[x]`.
3. Add newly discovered tasks.
4. Keep "Relevant Files" accurate and up to date.
5. Before starting work, check which sub-task is next.
6. After implementing a sub-task, update the file and then pause for user approval.
3
.gitignore
vendored
@@ -1,10 +1,13 @@
# ---> Python
/data/*.db
/credentials/*.json
*.csv
*.png
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
/data/*.npy

# C extensions
*.so
513
README.md
@@ -1 +1,512 @@
# Cycles - Cryptocurrency Trading Strategy Backtesting Framework

A comprehensive Python framework for backtesting cryptocurrency trading strategies using technical indicators, with advanced features like machine learning price prediction to eliminate lookahead bias.

## Table of Contents

- [Overview](#overview)
- [Features](#features)
- [Quick Start](#quick-start)
- [Project Structure](#project-structure)
- [Core Modules](#core-modules)
- [Configuration](#configuration)
- [Usage Examples](#usage-examples)
- [API Documentation](#api-documentation)
- [Testing](#testing)
- [Contributing](#contributing)
- [License](#license)

## Overview

Cycles is a sophisticated backtesting framework designed specifically for cryptocurrency trading strategies. It provides robust tools for:

- **Strategy Backtesting**: Test trading strategies across multiple timeframes with comprehensive metrics
- **Technical Analysis**: Built-in indicators including SuperTrend, RSI, Bollinger Bands, and more
- **Machine Learning Integration**: Eliminate lookahead bias using XGBoost price prediction
- **Multi-timeframe Analysis**: Support for various timeframes from 1-minute to daily data
- **Performance Analytics**: Detailed reporting with profit ratios, drawdowns, win rates, and fee calculations

### Key Goals

1. **Realistic Trading Simulation**: Eliminate common backtesting pitfalls like lookahead bias
2. **Modular Architecture**: Easy to extend with new indicators and strategies
3. **Performance Optimization**: Parallel processing for efficient large-scale backtesting
4. **Comprehensive Analysis**: Rich reporting and visualization capabilities

## Features

### 🚀 Core Features

- **Multi-Strategy Backtesting**: Test multiple trading strategies simultaneously
- **Advanced Stop Loss Management**: Precise stop-loss execution using 1-minute data
- **Fee Integration**: Realistic trading fee calculations (OKX exchange fees)
- **Parallel Processing**: Efficient multi-core backtesting execution
- **Rich Analytics**: Comprehensive performance metrics and reporting

### 📊 Technical Indicators

- **SuperTrend**: Multi-parameter SuperTrend indicator with meta-trend analysis
- **RSI**: Relative Strength Index with customizable periods
- **Bollinger Bands**: Configurable period and standard deviation multipliers
- **Extensible Framework**: Easy to add new technical indicators

### 🤖 Machine Learning

- **Price Prediction**: XGBoost-based closing price prediction
- **Lookahead Bias Elimination**: Realistic trading simulations
- **Feature Engineering**: Advanced technical feature extraction
- **Model Persistence**: Save and load trained models

### 📈 Data Management

- **Multiple Data Sources**: Support for various cryptocurrency exchanges
- **Flexible Timeframes**: 1-minute to daily data aggregation
- **Efficient Storage**: Optimized data loading and caching
- **Google Sheets Integration**: External data source connectivity

## Quick Start

### Prerequisites

- Python 3.10 or higher
- UV package manager (recommended)
- Git

### Installation

1. **Clone the repository**:
   ```bash
   git clone <repository-url>
   cd Cycles
   ```

2. **Install dependencies**:
   ```bash
   uv sync
   ```

3. **Activate the virtual environment**:
   ```bash
   source .venv/bin/activate  # Linux/Mac
   # or
   .venv\Scripts\activate     # Windows
   ```

### Basic Usage

1. **Prepare your configuration file** (`config.json`):
   ```json
   {
       "start_date": "2023-01-01",
       "stop_date": "2023-12-31",
       "initial_usd": 10000,
       "timeframes": ["5T", "15T", "1H", "4H"],
       "stop_loss_pcts": [0.02, 0.05, 0.10]
   }
   ```

2. **Run a backtest**:
   ```bash
   uv run python main.py --config config.json
   ```

3. **View results**:
   Results will be saved in timestamped CSV files with comprehensive metrics.

## Project Structure

```
Cycles/
├── cycles/                    # Core library modules
│   ├── Analysis/              # Technical analysis indicators
│   │   ├── boillinger_band.py
│   │   ├── rsi.py
│   │   └── __init__.py
│   ├── utils/                 # Utility modules
│   │   ├── storage.py         # Data storage and management
│   │   ├── system.py          # System utilities
│   │   ├── data_utils.py      # Data processing utilities
│   │   └── gsheets.py         # Google Sheets integration
│   ├── backtest.py            # Core backtesting engine
│   ├── supertrend.py          # SuperTrend indicator implementation
│   ├── charts.py              # Visualization utilities
│   ├── market_fees.py         # Trading fee calculations
│   └── __init__.py
├── docs/                      # Documentation
│   ├── analysis.md            # Analysis module documentation
│   ├── utils_storage.md       # Storage utilities documentation
│   └── utils_system.md        # System utilities documentation
├── data/                      # Data directory (not in repo)
├── results/                   # Backtest results (not in repo)
├── xgboost/                   # Machine learning components
├── OHLCVPredictor/            # Price prediction module
├── main.py                    # Main execution script
├── test_bbrsi.py              # Example strategy test
├── pyproject.toml             # Project configuration
├── requirements.txt           # Dependencies
├── uv.lock                    # UV lock file
└── README.md                  # This file
```

## Core Modules

### Backtest Engine (`cycles/backtest.py`)

The heart of the framework, providing comprehensive backtesting capabilities:

```python
from cycles.backtest import Backtest

results = Backtest.run(
    min1_df=minute_data,
    df=timeframe_data,
    initial_usd=10000,
    stop_loss_pct=0.05,
    debug=False
)
```

**Key Features**:
- Meta-SuperTrend strategy implementation
- Precise stop-loss execution using 1-minute data
- Comprehensive trade logging and statistics
- Fee-aware profit calculations

### Technical Analysis (`cycles/Analysis/`)

Modular technical indicator implementations:

#### RSI (Relative Strength Index)
```python
from cycles.Analysis.rsi import RSI

rsi_calculator = RSI(period=14)
data_with_rsi = rsi_calculator.calculate(df, price_column='close')
```

#### Bollinger Bands
```python
from cycles.Analysis.boillinger_band import BollingerBands

bb = BollingerBands(period=20, std_dev_multiplier=2.0)
data_with_bb = bb.calculate(df)
```

### Data Management (`cycles/utils/storage.py`)

Efficient data loading, processing, and result storage:

```python
from cycles.utils.storage import Storage

storage = Storage(data_dir='./data', logging=logging)
data = storage.load_data('btcusd_1-min_data.csv', '2023-01-01', '2023-12-31')
```

## Configuration

### Backtest Configuration

Create a `config.json` file with the following structure:

```json
{
    "start_date": "2023-01-01",
    "stop_date": "2023-12-31",
    "initial_usd": 10000,
    "timeframes": [
        "1T",   // 1 minute
        "5T",   // 5 minutes
        "15T",  // 15 minutes
        "1H",   // 1 hour
        "4H",   // 4 hours
        "1D"    // 1 day
    ],
    "stop_loss_pcts": [0.02, 0.05, 0.10, 0.15]
}
```

Note: the `//` comments above are for illustration only; remove them in a real `config.json`, since standard JSON does not allow comments.

### Environment Variables

Set the following environment variables for enhanced functionality:

```bash
# Google Sheets integration (optional)
export GOOGLE_SHEETS_CREDENTIALS_PATH="/path/to/credentials.json"

# Data directory (optional, defaults to ./data)
export DATA_DIR="/path/to/data"

# Results directory (optional, defaults to ./results)
export RESULTS_DIR="/path/to/results"
```

## Usage Examples

### Basic Backtest

```python
import json
from cycles.utils.storage import Storage
from cycles.backtest import Backtest

# Load configuration
with open('config.json', 'r') as f:
    config = json.load(f)

# Initialize storage
storage = Storage(data_dir='./data')

# Load data
data_1min = storage.load_data(
    'btcusd_1-min_data.csv',
    config['start_date'],
    config['stop_date']
)

# Run backtest
results = Backtest.run(
    min1_df=data_1min,
    df=data_1min,  # Same data for 1-minute strategy
    initial_usd=config['initial_usd'],
    stop_loss_pct=0.05,
    debug=True
)

print(f"Final USD: {results['final_usd']:.2f}")
print(f"Number of trades: {results['n_trades']}")
print(f"Win rate: {results['win_rate']:.2%}")
```

### Multi-Timeframe Analysis

```python
from main import process

# Define timeframes to test
timeframes = ['5T', '15T', '1H', '4H']
stop_loss_pcts = [0.02, 0.05, 0.10]

# Create tasks for parallel processing
tasks = [
    (timeframe, data_1min, stop_loss_pct, 10000)
    for timeframe in timeframes
    for stop_loss_pct in stop_loss_pcts
]

# Process each task
for task in tasks:
    results, trades = process(task, debug=False)
    print(f"Timeframe: {task[0]}, Stop Loss: {task[2]:.1%}")
    for result in results:
        print(f"  Final USD: {result['final_usd']:.2f}")
```

### Custom Strategy Development

```python
from cycles.Analysis.rsi import RSI
from cycles.Analysis.boillinger_band import BollingerBands

def custom_strategy(df):
    """Example custom trading strategy using RSI and Bollinger Bands"""

    # Calculate indicators
    rsi = RSI(period=14)
    bb = BollingerBands(period=20, std_dev_multiplier=2.0)

    df_with_rsi = rsi.calculate(df.copy())
    df_with_bb = bb.calculate(df_with_rsi)

    # Define signals
    buy_signals = (
        (df_with_bb['close'] < df_with_bb['LowerBand']) &
        (df_with_bb['RSI'] < 30)
    )

    sell_signals = (
        (df_with_bb['close'] > df_with_bb['UpperBand']) &
        (df_with_bb['RSI'] > 70)
    )

    return buy_signals, sell_signals
```

## API Documentation

### Core Classes

#### `Backtest`
Main backtesting engine with static methods for strategy execution.

**Methods**:
- `run(min1_df, df, initial_usd, stop_loss_pct, debug=False)`: Execute a backtest
- `check_stop_loss(...)`: Check stop-loss conditions using 1-minute data
- `handle_entry(...)`: Process trade entry logic
- `handle_exit(...)`: Process trade exit logic

#### `Storage`
Data management and persistence utilities.

**Methods**:
- `load_data(filename, start_date, stop_date)`: Load and filter historical data
- `save_data(df, filename)`: Save processed data
- `write_backtest_results(...)`: Save backtest results to CSV

#### `SystemUtils`
System optimization and resource management.

**Methods**:
- `get_optimal_workers()`: Determine the optimal number of parallel workers
- `get_memory_usage()`: Monitor memory consumption
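
A minimal usage sketch. Constructing `SystemUtils` with no arguments is an assumption inferred from how it is used here, not a documented guarantee:

```python
from cycles.utils.system import SystemUtils

system_utils = SystemUtils()  # assumption: no-argument constructor
workers = system_utils.get_optimal_workers()
print(f"Running backtests with {workers} parallel workers")
```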

### Configuration Parameters

| Parameter | Type | Description | Default |
|-----------|------|-------------|---------|
| `start_date` | string | Backtest start date (YYYY-MM-DD) | Required |
| `stop_date` | string | Backtest end date (YYYY-MM-DD) | Required |
| `initial_usd` | float | Starting capital in USD | Required |
| `timeframes` | array | List of timeframes to test | Required |
| `stop_loss_pcts` | array | Stop-loss percentages to test | Required |

## Testing

### Running Tests

```bash
# Run all tests
uv run pytest

# Run a specific test file
uv run pytest test_bbrsi.py

# Run with verbose output
uv run pytest -v

# Run with coverage
uv run pytest --cov=cycles
```

### Test Structure

- `test_bbrsi.py`: Example strategy testing with RSI and Bollinger Bands
- Unit tests for individual modules (add as needed)
- Integration tests for complete workflows

### Example Test

```python
# test_bbrsi.py demonstrates strategy testing
from cycles.utils.storage import Storage
from cycles.Analysis.rsi import RSI
from cycles.Analysis.boillinger_band import BollingerBands

def test_strategy_signals():
    # Load test data
    storage = Storage()
    data = storage.load_data('test_data.csv', '2023-01-01', '2023-02-01')

    # Calculate indicators
    rsi = RSI(period=14)
    bb = BollingerBands(period=20)

    data_with_indicators = bb.calculate(rsi.calculate(data))

    # Test signal generation
    assert 'RSI' in data_with_indicators.columns
    assert 'UpperBand' in data_with_indicators.columns
    assert 'LowerBand' in data_with_indicators.columns
```

## Contributing

### Development Setup

1. Fork the repository
2. Create a feature branch: `git checkout -b feature/new-indicator`
3. Install development dependencies: `uv sync --dev`
4. Make your changes following the coding standards
5. Add tests for new functionality
6. Run tests: `uv run pytest`
7. Submit a pull request

### Coding Standards

- **Maximum file size**: 250 lines
- **Maximum function size**: 50 lines
- **Documentation**: All public functions must have docstrings
- **Type hints**: Use type hints for all function parameters and returns
- **Error handling**: Include proper error handling and meaningful error messages
- **No emoji**: Avoid emoji in code and comments

### Adding New Indicators

1. Create a new file in `cycles/Analysis/`
2. Follow the existing pattern (see `rsi.py` or `boillinger_band.py`)
3. Include comprehensive docstrings and type hints
4. Add tests for the new indicator
5. Update documentation

## Performance Considerations

### Optimization Tips

1. **Parallel Processing**: Use the built-in parallel processing for multiple timeframes
2. **Data Caching**: Cache frequently used calculations
3. **Memory Management**: Monitor memory usage for large datasets
4. **Efficient Data Types**: Use appropriate pandas data types (see the sketch below)
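
For example, a small downcasting sketch (column names follow the OHLCV layout used elsewhere in this README):

```python
import pandas as pd

df = pd.read_csv("data/btcusd_1-min_data.csv")

# Downcast float64 OHLCV columns to float32 to roughly halve their memory footprint.
for col in ["open", "high", "low", "close", "volume"]:
    df[col] = pd.to_numeric(df[col], downcast="float")

print(f"{df.memory_usage(deep=True).sum() / 1e6:.1f} MB")
```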

### Benchmarks

Typical performance on modern hardware:
- **1-minute data**: ~1M candles processed in 2-3 minutes
- **Multiple timeframes**: 4 timeframes × 4 stop-loss values in 5-10 minutes
- **Memory usage**: ~2-4GB for 1 year of 1-minute BTC data

## Troubleshooting

### Common Issues

1. **Memory errors with large datasets**:
   - Reduce the date range or use data chunking (see the sketch below)
   - Increase system RAM or use swap space

2. **Slow performance**:
   - Enable parallel processing
   - Reduce the number of timeframes/stop-loss values
   - Use SSD storage for data files

3. **Missing data errors**:
   - Verify the data file format and column names
   - Check date range availability in the data
   - Ensure proper data cleaning
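
A minimal chunked-loading sketch (the `timestamp` column name is an assumption; match it to your data file):

```python
import pandas as pd

# Stream the CSV in 1M-row chunks so only one chunk is in memory at a time,
# keeping only the rows inside the backtest window.
chunks = pd.read_csv(
    "data/btcusd_1-min_data.csv",
    chunksize=1_000_000,
    parse_dates=["timestamp"],  # assumption: adjust to the actual column name
)
df = pd.concat(
    chunk[(chunk["timestamp"] >= "2023-01-01") & (chunk["timestamp"] < "2023-07-01")]
    for chunk in chunks
)
```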

### Debug Mode

Enable debug mode for detailed logging:

```python
# Set debug=True for detailed output
results = Backtest.run(..., debug=True)
```

## License

This project is licensed under the MIT License. See the LICENSE file for details.

## Changelog

### Version 0.1.0 (Current)
- Initial release
- Core backtesting framework
- SuperTrend strategy implementation
- Technical indicators (RSI, Bollinger Bands)
- Multi-timeframe analysis
- Machine learning price prediction
- Parallel processing support

---

For more detailed documentation, see the `docs/` directory or visit our [documentation website](link-to-docs).

**Support**: For questions or issues, please create an issue on GitHub or contact the development team.
462
backtest_runner.py
Normal file
@@ -0,0 +1,462 @@
import pandas as pd
import concurrent.futures
import logging
from typing import List, Tuple, Dict, Any, Optional

from cycles.utils.storage import Storage
from cycles.utils.system import SystemUtils
from cycles.utils.progress_manager import ProgressManager
from result_processor import ResultProcessor


def _process_single_task_static(task: Tuple[str, str, pd.DataFrame, float, float], progress_callback=None) -> Tuple[List[Dict], List[Dict]]:
    """
    Static version of _process_single_task for use with ProcessPoolExecutor

    Args:
        task: Tuple of (task_id, timeframe, data_1min, stop_loss_pct, initial_usd)
        progress_callback: Optional progress callback function

    Returns:
        Tuple of (results, trades)
    """
    task_id, timeframe, data_1min, stop_loss_pct, initial_usd = task

    try:
        if timeframe == "1T" or timeframe == "1min":
            df = data_1min.copy()
        else:
            df = _resample_data_static(data_1min, timeframe)

        # Create required components for processing
        from cycles.utils.storage import Storage
        from result_processor import ResultProcessor

        # Create storage with default paths (for the subprocess)
        storage = Storage()
        result_processor = ResultProcessor(storage)

        results, trades = result_processor.process_timeframe_results(
            data_1min,
            df,
            [stop_loss_pct],
            timeframe,
            initial_usd,
            progress_callback=progress_callback
        )

        return results, trades

    except Exception as e:
        error_msg = f"Failed to process {timeframe} with stop loss {stop_loss_pct}: {e}"
        raise RuntimeError(error_msg) from e


def _resample_data_static(data_1min: pd.DataFrame, timeframe: str) -> pd.DataFrame:
    """
    Static function to resample 1-minute data to the specified timeframe

    Args:
        data_1min: 1-minute data DataFrame
        timeframe: Target timeframe string

    Returns:
        Resampled DataFrame
    """
    try:
        agg_dict = {
            'open': 'first',
            'high': 'max',
            'low': 'min',
            'close': 'last',
            'volume': 'sum'
        }

        if 'predicted_close_price' in data_1min.columns:
            agg_dict['predicted_close_price'] = 'last'

        resampled = data_1min.resample(timeframe).agg(agg_dict).dropna()

        return resampled.reset_index()

    except Exception as e:
        error_msg = f"Failed to resample data to {timeframe}: {e}"
        raise ValueError(error_msg) from e
class BacktestRunner:
    """Handles the execution of backtests across multiple timeframes and parameters"""

    def __init__(
        self,
        storage: Storage,
        system_utils: SystemUtils,
        result_processor: ResultProcessor,
        logging_instance: Optional[logging.Logger] = None,
        show_progress: bool = True
    ):
        """
        Initialize the backtest runner

        Args:
            storage: Storage instance for data operations
            system_utils: System utilities for resource management
            result_processor: Result processor for handling outputs
            logging_instance: Optional logging instance
            show_progress: Whether to show visual progress bars
        """
        self.storage = storage
        self.system_utils = system_utils
        self.result_processor = result_processor
        self.logging = logging_instance
        self.show_progress = show_progress
        self.progress_manager = ProgressManager() if show_progress else None

    def run_backtests(
        self,
        data_1min: pd.DataFrame,
        timeframes: List[str],
        stop_loss_pcts: List[float],
        initial_usd: float,
        debug: bool = False
    ) -> Tuple[List[Dict], List[Dict]]:
        """
        Run backtests across all timeframe and stop loss combinations

        Args:
            data_1min: 1-minute data DataFrame
            timeframes: List of timeframe strings (e.g., ['1D', '6h'])
            stop_loss_pcts: List of stop loss percentages
            initial_usd: Initial USD amount
            debug: Whether to enable debug mode

        Returns:
            Tuple of (all_results, all_trades)
        """
        # Create tasks for all combinations
        tasks = self._create_tasks(timeframes, stop_loss_pcts, data_1min, initial_usd)

        if self.logging:
            self.logging.info(f"Starting {len(tasks)} backtest tasks")

        if debug:
            return self._run_sequential(tasks)
        else:
            return self._run_parallel(tasks)

    def _create_tasks(
        self,
        timeframes: List[str],
        stop_loss_pcts: List[float],
        data_1min: pd.DataFrame,
        initial_usd: float
    ) -> List[Tuple]:
        """Create task tuples for processing"""
        tasks = []
        for timeframe in timeframes:
            for stop_loss_pct in stop_loss_pcts:
                task_id = f"{timeframe}_{stop_loss_pct}"
                task = (task_id, timeframe, data_1min, stop_loss_pct, initial_usd)
                tasks.append(task)
        return tasks
    def _run_sequential(self, tasks: List[Tuple]) -> Tuple[List[Dict], List[Dict]]:
        """Run tasks sequentially (for debug mode)"""
        # Initialize progress tracking if enabled
        if self.progress_manager:
            for task in tasks:
                task_id, timeframe, data_1min, stop_loss_pct, initial_usd = task

                # Calculate the actual DataFrame size that will be processed
                if timeframe == "1T" or timeframe == "1min":
                    actual_df_size = len(data_1min)
                else:
                    # Get the actual resampled DataFrame size
                    temp_df = self._resample_data(data_1min, timeframe)
                    actual_df_size = len(temp_df)

                task_name = f"{timeframe} SL:{stop_loss_pct:.0%}"
                self.progress_manager.start_task(task_id, task_name, actual_df_size)

            self.progress_manager.start_display()

        all_results = []
        all_trades = []

        try:
            for task in tasks:
                try:
                    # Get the progress callback for this task if available
                    progress_callback = None
                    if self.progress_manager:
                        progress_callback = self.progress_manager.get_task_progress_callback(task[0])

                    results, trades = self._process_single_task(task, progress_callback)

                    if results:
                        all_results.extend(results)
                    if trades:
                        all_trades.extend(trades)

                    # Mark the task as completed
                    if self.progress_manager:
                        self.progress_manager.complete_task(task[0])

                except Exception as e:
                    error_msg = f"Error processing task {task[1]} with stop loss {task[3]}: {e}"

                    if self.logging:
                        self.logging.error(error_msg)

                    raise RuntimeError(error_msg) from e
        finally:
            # Stop the progress display
            if self.progress_manager:
                self.progress_manager.stop_display()

        return all_results, all_trades
    def _run_parallel(self, tasks: List[Tuple]) -> Tuple[List[Dict], List[Dict]]:
        """Run tasks in parallel using ProcessPoolExecutor"""
        workers = self.system_utils.get_optimal_workers()

        if self.logging:
            self.logging.info(f"Running {len(tasks)} tasks with {workers} workers")

        # OPTIMIZATION: Disable the progress manager for parallel execution to reduce overhead
        # Progress tracking adds significant overhead in multiprocessing
        if self.progress_manager and self.logging:
            self.logging.info("Progress tracking disabled for parallel execution (performance optimization)")

        all_results = []
        all_trades = []
        completed_tasks = 0

        try:
            with concurrent.futures.ProcessPoolExecutor(max_workers=workers) as executor:
                future_to_task = {
                    executor.submit(_process_single_task_static, task): task
                    for task in tasks
                }

                for future in concurrent.futures.as_completed(future_to_task):
                    task = future_to_task[future]
                    try:
                        results, trades = future.result()
                        if results:
                            all_results.extend(results)
                        if trades:
                            all_trades.extend(trades)

                        completed_tasks += 1

                        if self.logging:
                            self.logging.info(f"Completed task {task[0]} ({completed_tasks}/{len(tasks)})")

                    except Exception as e:
                        error_msg = f"Task {task[1]} with stop loss {task[3]} failed: {e}"
                        if self.logging:
                            self.logging.error(error_msg)
                        raise RuntimeError(error_msg) from e

        except Exception as e:
            error_msg = f"Parallel execution failed: {e}"
            if self.logging:
                self.logging.error(error_msg)
            raise RuntimeError(error_msg) from e

        finally:
            # Stop the progress display
            if self.progress_manager:
                self.progress_manager.stop_display()

        if self.logging:
            self.logging.info(f"All {len(tasks)} tasks completed successfully")

        return all_results, all_trades
    def _process_single_task(
        self,
        task: Tuple[str, str, pd.DataFrame, float, float],
        progress_callback=None
    ) -> Tuple[List[Dict], List[Dict]]:
        """
        Process a single backtest task

        Args:
            task: Tuple of (task_id, timeframe, data_1min, stop_loss_pct, initial_usd)
            progress_callback: Optional progress callback function

        Returns:
            Tuple of (results, trades)
        """
        task_id, timeframe, data_1min, stop_loss_pct, initial_usd = task

        try:
            if timeframe == "1T" or timeframe == "1min":
                df = data_1min.copy()
            else:
                df = self._resample_data(data_1min, timeframe)

            results, trades = self.result_processor.process_timeframe_results(
                data_1min,
                df,
                [stop_loss_pct],
                timeframe,
                initial_usd,
                progress_callback=progress_callback
            )

            # OPTIMIZATION: Skip individual trade file saving during parallel execution
            # Trade files will be saved in batch at the end
            # if trades:
            #     self.result_processor.save_trade_file(trades, timeframe, stop_loss_pct)

            if self.logging:
                self.logging.info(f"Completed task {task_id}: {len(results)} results, {len(trades)} trades")

            return results, trades

        except Exception as e:
            error_msg = f"Failed to process {timeframe} with stop loss {stop_loss_pct}: {e}"
            if self.logging:
                self.logging.error(error_msg)
            raise RuntimeError(error_msg) from e
    def _resample_data(self, data_1min: pd.DataFrame, timeframe: str) -> pd.DataFrame:
        """
        Resample 1-minute data to the specified timeframe

        Args:
            data_1min: 1-minute data DataFrame
            timeframe: Target timeframe string

        Returns:
            Resampled DataFrame
        """
        try:
            agg_dict = {
                'open': 'first',
                'high': 'max',
                'low': 'min',
                'close': 'last',
                'volume': 'sum'
            }

            if 'predicted_close_price' in data_1min.columns:
                agg_dict['predicted_close_price'] = 'last'

            resampled = data_1min.resample(timeframe).agg(agg_dict).dropna()

            return resampled.reset_index()

        except Exception as e:
            error_msg = f"Failed to resample data to {timeframe}: {e}"
            if self.logging:
                self.logging.error(error_msg)
            raise ValueError(error_msg) from e
    def _get_timeframe_factor(self, timeframe: str) -> int:
        """
        Get the factor by which data is reduced when resampling to a timeframe

        Args:
            timeframe: Target timeframe string (e.g., '1h', '4h', '1D')

        Returns:
            Factor for estimating data size after resampling
        """
        timeframe_factors = {
            '1T': 1, '1min': 1,
            '5T': 5, '5min': 5,
            '15T': 15, '15min': 15,
            '30T': 30, '30min': 30,
            '1h': 60, '1H': 60,
            '2h': 120, '2H': 120,
            '4h': 240, '4H': 240,
            '6h': 360, '6H': 360,
            '8h': 480, '8H': 480,
            '12h': 720, '12H': 720,
            '1D': 1440, '1d': 1440,
            '2D': 2880, '2d': 2880,
            '3D': 4320, '3d': 4320,
            '1W': 10080, '1w': 10080
        }
        return timeframe_factors.get(timeframe, 60)  # Default to 1 hour if unknown
    def load_data(self, filename: str, start_date: str, stop_date: str) -> pd.DataFrame:
        """
        Load and validate data for backtesting

        Args:
            filename: Name of the data file
            start_date: Start date string
            stop_date: Stop date string

        Returns:
            Loaded and validated DataFrame

        Raises:
            ValueError: If data is empty or invalid
        """
        try:
            data = self.storage.load_data(filename, start_date, stop_date)

            if data.empty:
                raise ValueError(f"No data loaded for period {start_date} to {stop_date}")

            required_columns = ['open', 'high', 'low', 'close', 'volume']

            if 'predicted_close_price' in data.columns:
                required_columns.append('predicted_close_price')

            missing_columns = [col for col in required_columns if col not in data.columns]

            if missing_columns:
                raise ValueError(f"Missing required columns: {missing_columns}")

            if self.logging:
                self.logging.info(f"Loaded {len(data)} rows of data from {filename}")

            return data

        except Exception as e:
            error_msg = f"Failed to load data from {filename}: {e}"
            if self.logging:
                self.logging.error(error_msg)
            raise RuntimeError(error_msg) from e
def validate_inputs(
|
||||||
|
self,
|
||||||
|
timeframes: List[str],
|
||||||
|
stop_loss_pcts: List[float],
|
||||||
|
initial_usd: float
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Validate backtest input parameters
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeframes: List of timeframe strings
|
||||||
|
stop_loss_pcts: List of stop loss percentages
|
||||||
|
initial_usd: Initial USD amount
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If any input is invalid
|
||||||
|
"""
|
||||||
|
if not timeframes:
|
||||||
|
raise ValueError("At least one timeframe must be specified")
|
||||||
|
|
||||||
|
if not stop_loss_pcts:
|
||||||
|
raise ValueError("At least one stop loss percentage must be specified")
|
||||||
|
|
||||||
|
for pct in stop_loss_pcts:
|
||||||
|
if not 0 < pct < 1:
|
||||||
|
raise ValueError(f"Stop loss percentage must be between 0 and 1, got: {pct}")
|
||||||
|
|
||||||
|
if initial_usd <= 0:
|
||||||
|
raise ValueError("Initial USD must be positive")
|
||||||
|
|
||||||
|
if self.logging:
|
||||||
|
self.logging.info("Input validation completed successfully")
|
||||||
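Editor's note: a minimal sketch of what `_get_timeframe_factor` is for. The owning class is not named in this hunk, so a subset of the factor table is inlined here; the row counts are illustrative.

# The factor is minutes-per-bar: dividing a 1-minute row count by it
# estimates the DataFrame size after resampling (illustrative subset only).
timeframe_factors = {'1h': 60, '4h': 240, '1D': 1440}

rows_1min = 525_600  # roughly one year of 1-minute candles
for tf, factor in timeframe_factors.items():
    print(tf, rows_1min // factor)  # 1h: 8760, 4h: 2190, 1D: 365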
charts.py  (deleted file, 86 lines)
@@ -1,86 +0,0 @@
import os
import matplotlib.pyplot as plt


class BacktestCharts:
    def __init__(self, charts_dir="charts"):
        self.charts_dir = charts_dir
        os.makedirs(self.charts_dir, exist_ok=True)

    def plot_profit_ratio_vs_stop_loss(self, results, filename="profit_ratio_vs_stop_loss.png"):
        """
        Plots profit ratio vs stop loss percentage for each timeframe.

        Parameters:
        - results: list of dicts, each with keys: 'timeframe', 'stop_loss_pct', 'profit_ratio'
        - filename: output filename (will be saved in charts_dir)
        """
        # Organize data by timeframe
        from collections import defaultdict
        data = defaultdict(lambda: {"stop_loss_pct": [], "profit_ratio": []})
        for row in results:
            tf = row["timeframe"]
            data[tf]["stop_loss_pct"].append(row["stop_loss_pct"])
            data[tf]["profit_ratio"].append(row["profit_ratio"])

        plt.figure(figsize=(10, 6))
        for tf, vals in data.items():
            # Sort by stop_loss_pct for smooth lines
            sorted_pairs = sorted(zip(vals["stop_loss_pct"], vals["profit_ratio"]))
            stop_loss, profit_ratio = zip(*sorted_pairs)
            plt.plot(
                [s * 100 for s in stop_loss],  # Convert to percent
                profit_ratio,
                marker="o",
                label=tf
            )

        plt.xlabel("Stop Loss (%)")
        plt.ylabel("Profit Ratio")
        plt.title("Profit Ratio vs Stop Loss (%) per Timeframe")
        plt.legend(title="Timeframe")
        plt.grid(True, linestyle="--", alpha=0.5)
        plt.tight_layout()

        output_path = os.path.join(self.charts_dir, filename)
        plt.savefig(output_path)
        plt.close()

    def plot_average_trade_vs_stop_loss(self, results, filename="average_trade_vs_stop_loss.png"):
        """
        Plots average trade vs stop loss percentage for each timeframe.

        Parameters:
        - results: list of dicts, each with keys: 'timeframe', 'stop_loss_pct', 'average_trade'
        - filename: output filename (will be saved in charts_dir)
        """
        from collections import defaultdict
        data = defaultdict(lambda: {"stop_loss_pct": [], "average_trade": []})
        for row in results:
            tf = row["timeframe"]
            if "average_trade" not in row:
                continue  # Skip rows without average_trade
            data[tf]["stop_loss_pct"].append(row["stop_loss_pct"])
            data[tf]["average_trade"].append(row["average_trade"])

        plt.figure(figsize=(10, 6))
        for tf, vals in data.items():
            # Sort by stop_loss_pct for smooth lines
            sorted_pairs = sorted(zip(vals["stop_loss_pct"], vals["average_trade"]))
            stop_loss, average_trade = zip(*sorted_pairs)
            plt.plot(
                [s * 100 for s in stop_loss],  # Convert to percent
                average_trade,
                marker="o",
                label=tf
            )

        plt.xlabel("Stop Loss (%)")
        plt.ylabel("Average Trade")
        plt.title("Average Trade vs Stop Loss (%) per Timeframe")
        plt.legend(title="Timeframe")
        plt.grid(True, linestyle="--", alpha=0.5)
        plt.tight_layout()

        output_path = os.path.join(self.charts_dir, filename)
        plt.savefig(output_path)
        plt.close()
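Editor's note: a usage sketch for `BacktestCharts` as it existed before this deletion; the result rows are synthetic but carry the keys its docstrings require.

from charts import BacktestCharts  # pre-deletion module path, assumed importable

results = [
    {"timeframe": "1h", "stop_loss_pct": 0.01, "profit_ratio": 1.10, "average_trade": 12.5},
    {"timeframe": "1h", "stop_loss_pct": 0.05, "profit_ratio": 1.25, "average_trade": 20.1},
    {"timeframe": "4h", "stop_loss_pct": 0.01, "profit_ratio": 0.95, "average_trade": -3.2},
    {"timeframe": "4h", "stop_loss_pct": 0.05, "profit_ratio": 1.05, "average_trade": 8.7},
]
charts = BacktestCharts(charts_dir="charts")
charts.plot_profit_ratio_vs_stop_loss(results)   # writes charts/profit_ratio_vs_stop_loss.png
charts.plot_average_trade_vs_stop_loss(results)  # writes charts/average_trade_vs_stop_loss.png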
config_manager.py  (new file, 175 lines)
@@ -0,0 +1,175 @@
import json
import datetime
import logging
from typing import Dict, List, Optional, Any
from pathlib import Path


class ConfigManager:
    """Manages configuration loading, validation, and default values for backtest operations"""

    DEFAULT_CONFIG = {
        "start_date": "2025-05-01",
        "stop_date": datetime.datetime.today().strftime('%Y-%m-%d'),
        "initial_usd": 10000,
        "timeframes": ["1D", "6h", "3h", "1h", "30m", "15m", "5m", "1m"],
        "stop_loss_pcts": [0.01, 0.02, 0.03, 0.05],
        "data_dir": "../data",
        "results_dir": "results"
    }

    def __init__(self, logging_instance: Optional[logging.Logger] = None):
        """
        Initialize configuration manager

        Args:
            logging_instance: Optional logging instance for output
        """
        self.logging = logging_instance
        self.config = {}

    def load_config(self, config_path: Optional[str] = None) -> Dict[str, Any]:
        """
        Load configuration from file or interactive input

        Args:
            config_path: Path to JSON config file; if None, prompts for interactive input

        Returns:
            Dictionary containing validated configuration

        Raises:
            FileNotFoundError: If config file doesn't exist
            json.JSONDecodeError: If config file has invalid JSON
            ValueError: If configuration values are invalid
        """
        if config_path:
            self.config = self._load_from_file(config_path)
        else:
            self.config = self._load_interactive()

        self._validate_config()
        return self.config

    def _load_from_file(self, config_path: str) -> Dict[str, Any]:
        """Load configuration from JSON file"""
        try:
            config_file = Path(config_path)
            if not config_file.exists():
                raise FileNotFoundError(f"Configuration file not found: {config_path}")

            with open(config_file, 'r') as f:
                config = json.load(f)

            if self.logging:
                self.logging.info(f"Configuration loaded from {config_path}")

            return config

        except json.JSONDecodeError as e:
            error_msg = f"Invalid JSON in configuration file {config_path}: {e}"
            if self.logging:
                self.logging.error(error_msg)
            raise json.JSONDecodeError(error_msg, e.doc, e.pos)

    def _load_interactive(self) -> Dict[str, Any]:
        """Load configuration through interactive prompts"""
        print("No config file provided. Please enter the following values (press Enter to use default):")

        config = {}

        # Start date
        start_date = input(f"Start date [{self.DEFAULT_CONFIG['start_date']}]: ") or self.DEFAULT_CONFIG['start_date']
        config['start_date'] = start_date

        # Stop date
        stop_date = input(f"Stop date [{self.DEFAULT_CONFIG['stop_date']}]: ") or self.DEFAULT_CONFIG['stop_date']
        config['stop_date'] = stop_date

        # Initial USD
        initial_usd_str = input(f"Initial USD [{self.DEFAULT_CONFIG['initial_usd']}]: ") or str(self.DEFAULT_CONFIG['initial_usd'])
        try:
            config['initial_usd'] = float(initial_usd_str)
        except ValueError:
            raise ValueError(f"Invalid initial USD value: {initial_usd_str}")

        # Timeframes
        timeframes_str = input(f"Timeframes (comma separated) [{', '.join(self.DEFAULT_CONFIG['timeframes'])}]: ") or ','.join(self.DEFAULT_CONFIG['timeframes'])
        config['timeframes'] = [tf.strip() for tf in timeframes_str.split(',') if tf.strip()]

        # Stop loss percentages
        stop_loss_pcts_str = input(f"Stop loss pcts (comma separated) [{', '.join(str(x) for x in self.DEFAULT_CONFIG['stop_loss_pcts'])}]: ") or ','.join(str(x) for x in self.DEFAULT_CONFIG['stop_loss_pcts'])
        try:
            config['stop_loss_pcts'] = [float(x.strip()) for x in stop_loss_pcts_str.split(',') if x.strip()]
        except ValueError:
            raise ValueError(f"Invalid stop loss percentages: {stop_loss_pcts_str}")

        # Add default directories
        config['data_dir'] = self.DEFAULT_CONFIG['data_dir']
        config['results_dir'] = self.DEFAULT_CONFIG['results_dir']

        return config

    def _validate_config(self) -> None:
        """
        Validate configuration values

        Raises:
            ValueError: If any configuration value is invalid
        """
        # Validate initial USD
        if self.config.get('initial_usd', 0) <= 0:
            raise ValueError("Initial USD must be positive")

        # Validate stop loss percentages
        stop_loss_pcts = self.config.get('stop_loss_pcts', [])
        for pct in stop_loss_pcts:
            if not 0 < pct < 1:
                raise ValueError(f"Stop loss percentage must be between 0 and 1, got: {pct}")

        # Validate dates
        try:
            datetime.datetime.strptime(self.config['start_date'], '%Y-%m-%d')
            datetime.datetime.strptime(self.config['stop_date'], '%Y-%m-%d')
        except ValueError as e:
            raise ValueError(f"Invalid date format (should be YYYY-MM-DD): {e}")

        # Validate timeframes
        timeframes = self.config.get('timeframes', [])
        if not timeframes:
            raise ValueError("At least one timeframe must be specified")

        # Validate directories exist or can be created
        for dir_key in ['data_dir', 'results_dir']:
            dir_path = Path(self.config.get(dir_key, ''))
            try:
                dir_path.mkdir(parents=True, exist_ok=True)
            except Exception as e:
                raise ValueError(f"Cannot create directory {dir_path}: {e}")

        if self.logging:
            self.logging.info("Configuration validation completed successfully")

    def get_config(self) -> Dict[str, Any]:
        """Return the current configuration"""
        return self.config.copy()

    def save_config(self, output_path: str) -> None:
        """
        Save current configuration to file

        Args:
            output_path: Path where to save the configuration
        """
        try:
            with open(output_path, 'w') as f:
                json.dump(self.config, f, indent=2)

            if self.logging:
                self.logging.info(f"Configuration saved to {output_path}")

        except Exception as e:
            error_msg = f"Failed to save configuration to {output_path}: {e}"
            if self.logging:
                self.logging.error(error_msg)
            raise
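Editor's note: a usage sketch for `ConfigManager` as defined above; the config path is illustrative, and importing it as `config_manager` is an assumption about the project layout.

import logging
from config_manager import ConfigManager  # assumed module path

logging.basicConfig(level=logging.INFO)
manager = ConfigManager(logging_instance=logging.getLogger("backtest"))
config = manager.load_config("configs/sample_config.json")  # pass None for interactive prompts
print(config["timeframes"], config["stop_loss_pcts"])
manager.save_config("configs/last_run.json")  # snapshot the validated config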
configs/config_bbrs.json  (new file, 29 lines)
@@ -0,0 +1,29 @@
{
  "start_date": "2025-01-01",
  "stop_date": null,
  "initial_usd": 10000,
  "timeframes": ["1min"],
  "strategies": [
    {
      "name": "bbrs",
      "weight": 1.0,
      "params": {
        "bb_width": 0.05,
        "bb_period": 20,
        "rsi_period": 14,
        "trending_rsi_threshold": [30, 70],
        "trending_bb_multiplier": 2.5,
        "sideways_rsi_threshold": [40, 60],
        "sideways_bb_multiplier": 1.8,
        "strategy_name": "MarketRegimeStrategy",
        "SqueezeStrategy": true,
        "stop_loss_pct": 0.05
      }
    }
  ],
  "combination_rules": {
    "entry": "any",
    "exit": "any",
    "min_confidence": 0.5
  }
}
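Editor's note: the `params` block uses flat keys (`trending_bb_multiplier`), while `BollingerBands` in cycles/Analysis/boillinger_band.py below expects nested ones. A mapping like the following is presumably performed by the strategy loader; this adapter is an assumption, not code from the repo.

# Hypothetical adapter from the flat JSON params to the nested config
# that BollingerBands.__init__ validates (see boillinger_band.py below).
params = {
    "bb_width": 0.05,
    "bb_period": 20,
    "trending_bb_multiplier": 2.5,
    "sideways_bb_multiplier": 1.8,
}
bb_config = {
    "bb_period": params["bb_period"],
    "bb_width": params["bb_width"],
    "trending": {"bb_std_dev_multiplier": params["trending_bb_multiplier"]},
    "sideways": {"bb_std_dev_multiplier": params["sideways_bb_multiplier"]},
}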
configs/config_bbrs_multi_timeframe.json  (new file, 29 lines)
@@ -0,0 +1,29 @@
{
  "start_date": "2024-01-01",
  "stop_date": "2024-01-31",
  "initial_usd": 10000,
  "timeframes": ["1min"],
  "stop_loss_pcts": [0.05],
  "strategies": [
    {
      "name": "bbrs",
      "weight": 1.0,
      "params": {
        "bb_width": 0.05,
        "bb_period": 20,
        "rsi_period": 14,
        "trending_rsi_threshold": [30, 70],
        "trending_bb_multiplier": 2.5,
        "sideways_rsi_threshold": [40, 60],
        "sideways_bb_multiplier": 1.8,
        "strategy_name": "MarketRegimeStrategy",
        "SqueezeStrategy": true
      }
    }
  ],
  "combination_rules": {
    "entry": "any",
    "exit": "any",
    "min_confidence": 0.5
  }
}
configs/config_combined.json  (new file, 37 lines)
@@ -0,0 +1,37 @@
{
  "start_date": "2025-03-01",
  "stop_date": "2025-03-15",
  "initial_usd": 10000,
  "timeframes": ["15min"],
  "strategies": [
    {
      "name": "default",
      "weight": 0.6,
      "params": {
        "timeframe": "15min",
        "stop_loss_pct": 0.03
      }
    },
    {
      "name": "bbrs",
      "weight": 0.4,
      "params": {
        "bb_width": 0.05,
        "bb_period": 20,
        "rsi_period": 14,
        "trending_rsi_threshold": [30, 70],
        "trending_bb_multiplier": 2.5,
        "sideways_rsi_threshold": [40, 60],
        "sideways_bb_multiplier": 1.8,
        "strategy_name": "MarketRegimeStrategy",
        "SqueezeStrategy": true,
        "stop_loss_pct": 0.05
      }
    }
  ],
  "combination_rules": {
    "entry": "weighted_consensus",
    "exit": "any",
    "min_confidence": 0.6
  }
}
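Editor's note: `weighted_consensus` is only named here; the combiner itself is not part of this diff. One plausible semantics, stated purely as an assumption, is a weight-normalized vote gated by `min_confidence`:

# Assumed semantics only -- the real combiner is not shown in this diff.
def weighted_consensus(signals, weights, min_confidence=0.6):
    """signals: list of bools (one per strategy); weights: matching floats."""
    total = sum(weights)
    confidence = sum(w for s, w in zip(signals, weights) if s) / total
    return confidence >= min_confidence

# With the weights above (0.6 default strategy, 0.4 bbrs):
print(weighted_consensus([True, False], [0.6, 0.4]))  # True  (0.6 >= 0.6)
print(weighted_consensus([False, True], [0.6, 0.4]))  # False (0.4 <  0.6)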
configs/config_default.json  (new file, 21 lines)
@@ -0,0 +1,21 @@
{
  "start_date": "2024-01-01",
  "stop_date": null,
  "initial_usd": 10000,
  "timeframes": ["15min"],
  "strategies": [
    {
      "name": "default",
      "weight": 1.0,
      "params": {
        "timeframe": "15min",
        "stop_loss_pct": 0.03
      }
    }
  ],
  "combination_rules": {
    "entry": "any",
    "exit": "any",
    "min_confidence": 0.5
  }
}
configs/config_default_5min.json  (new file, 21 lines)
@@ -0,0 +1,21 @@
{
  "start_date": "2024-01-01",
  "stop_date": "2024-01-31",
  "initial_usd": 10000,
  "timeframes": ["5min"],
  "strategies": [
    {
      "name": "default",
      "weight": 1.0,
      "params": {
        "timeframe": "5min",
        "stop_loss_pct": 0.03
      }
    }
  ],
  "combination_rules": {
    "entry": "any",
    "exit": "any",
    "min_confidence": 0.5
  }
}
configs/flat_2021_2024_config.json  (new file, 10 lines)
@@ -0,0 +1,10 @@
{
  "start_date": "2021-11-01",
  "stop_date": "2024-04-01",
  "initial_usd": 10000,
  "timeframes": ["1min", "2min", "3min", "4min", "5min", "10min", "15min", "30min", "1h", "2h", "4h", "6h", "8h", "12h", "1d"],
  "stop_loss_pcts": [0.01, 0.02, 0.03, 0.04, 0.05, 0.1],
  "data_dir": "../data",
  "results_dir": "../results",
  "debug": 0
}
configs/full_config.json  (new file, 10 lines)
@@ -0,0 +1,10 @@
{
  "start_date": "2020-01-01",
  "stop_date": "2025-07-08",
  "initial_usd": 10000,
  "timeframes": ["1h", "4h", "15ME", "5ME", "1ME"],
  "stop_loss_pcts": [0.01, 0.02, 0.03, 0.05],
  "data_dir": "../data",
  "results_dir": "../results",
  "debug": 1
}
configs/sample_config.json  (new file, 10 lines)
@@ -0,0 +1,10 @@
{
  "start_date": "2023-01-01",
  "stop_date": "2025-01-15",
  "initial_usd": 10000,
  "timeframes": ["4h"],
  "stop_loss_pcts": [0.05],
  "data_dir": "../data",
  "results_dir": "../results",
  "debug": 0
}
cycles/Analysis/__init__.py  (new file, 0 lines)
cycles/Analysis/bb_rsi.py  (new file, 415 lines)
@@ -0,0 +1,415 @@
import pandas as pd
import numpy as np

from cycles.Analysis.boillinger_band import BollingerBands
from cycles.Analysis.rsi import RSI
from cycles.utils.data_utils import aggregate_to_daily, aggregate_to_hourly, aggregate_to_minutes


class BollingerBandsStrategy:

    def __init__(self, config=None, logging=None):
        if config is None:
            raise ValueError("Config must be provided.")
        self.config = config
        self.logging = logging

    def _ensure_datetime_index(self, data):
        """
        Ensure the DataFrame has a DatetimeIndex for proper time-series operations.
        If the DataFrame has a 'timestamp' column but not a DatetimeIndex, convert it.

        Args:
            data (DataFrame): Input DataFrame

        Returns:
            DataFrame: DataFrame with proper DatetimeIndex
        """
        if data.empty:
            return data

        # Check if we have a DatetimeIndex already
        if isinstance(data.index, pd.DatetimeIndex):
            return data

        # Check if we have a 'timestamp' column that we can use as index
        if 'timestamp' in data.columns:
            data_copy = data.copy()
            # Convert timestamp column to datetime if it's not already
            if not pd.api.types.is_datetime64_any_dtype(data_copy['timestamp']):
                data_copy['timestamp'] = pd.to_datetime(data_copy['timestamp'])
            # Set timestamp as index and drop the column
            data_copy = data_copy.set_index('timestamp')
            if self.logging:
                self.logging.info("Converted 'timestamp' column to DatetimeIndex for strategy processing.")
            return data_copy

        # If we have a regular index that might hold datetime strings, try to convert
        try:
            if data.index.dtype == 'object':
                data_copy = data.copy()
                data_copy.index = pd.to_datetime(data_copy.index)
                if self.logging:
                    self.logging.info("Converted index to DatetimeIndex for strategy processing.")
                return data_copy
        except Exception:
            pass

        # If we can't create a proper DatetimeIndex, warn and return as-is
        if self.logging:
            self.logging.warning("Could not create DatetimeIndex for strategy processing. Time-based operations may fail.")
        return data

    def run(self, data, strategy_name):
        # Ensure proper DatetimeIndex before processing
        data = self._ensure_datetime_index(data)

        if strategy_name == "MarketRegimeStrategy":
            result = self.MarketRegimeStrategy(data)
            return self.standardize_output(result, strategy_name)
        elif strategy_name == "CryptoTradingStrategy":
            result = self.CryptoTradingStrategy(data)
            return self.standardize_output(result, strategy_name)
        else:
            if self.logging is not None:
                self.logging.warning(f"Strategy {strategy_name} not found. Using no_strategy instead.")
            return self.no_strategy(data)

    def standardize_output(self, data, strategy_name):
        """
        Standardize column names across different strategies to ensure consistent plotting and analysis

        Args:
            data (DataFrame): Strategy output DataFrame
            strategy_name (str): Name of the strategy that generated this data

        Returns:
            DataFrame: Data with standardized column names
        """
        if data.empty:
            return data

        # Create a copy to avoid modifying the original
        standardized = data.copy()

        # Standardize column names based on strategy
        if strategy_name == "MarketRegimeStrategy":
            # MarketRegimeStrategy already has standard column names for most fields;
            # just ensure all standard columns exist
            pass
        elif strategy_name == "CryptoTradingStrategy":
            # Map strategy-specific column names to standard names
            column_mapping = {
                'UpperBand_15m': 'UpperBand',
                'LowerBand_15m': 'LowerBand',
                'SMA_15m': 'SMA',
                'RSI_15m': 'RSI',
                'VolumeMA_15m': 'VolumeMA',
                # Keep StopLoss and TakeProfit as they are
            }

            # Add standard columns from mapped columns
            for old_col, new_col in column_mapping.items():
                if old_col in standardized.columns and new_col not in standardized.columns:
                    standardized[new_col] = standardized[old_col]

            # Add additional strategy-specific data as metadata columns
            if 'UpperBand_1h' in standardized.columns:
                standardized['UpperBand_1h_meta'] = standardized['UpperBand_1h']
            if 'LowerBand_1h' in standardized.columns:
                standardized['LowerBand_1h_meta'] = standardized['LowerBand_1h']

        # Ensure all strategies have BBWidth if possible
        if 'BBWidth' not in standardized.columns and 'UpperBand' in standardized.columns and 'LowerBand' in standardized.columns:
            standardized['BBWidth'] = (standardized['UpperBand'] - standardized['LowerBand']) / standardized['SMA'] if 'SMA' in standardized.columns else np.nan

        return standardized

    def no_strategy(self, data):
        """No strategy: returns False for both buy and sell conditions"""
        buy_condition = pd.Series([False] * len(data), index=data.index)
        sell_condition = pd.Series([False] * len(data), index=data.index)
        return buy_condition, sell_condition

    def rsi_bollinger_confirmation(self, rsi, window=14, std_mult=1.5):
        """Calculate RSI Bollinger Bands for confirmation

        Args:
            rsi (Series): RSI values
            window (int): Rolling window for SMA
            std_mult (float): Standard deviation multiplier

        Returns:
            tuple: (oversold condition, overbought condition)
        """
        valid_rsi = ~rsi.isna()
        if not valid_rsi.any():
            # Return all-False Series if there is no valid RSI data
            return pd.Series(False, index=rsi.index), pd.Series(False, index=rsi.index)

        rsi_sma = rsi.rolling(window).mean()
        rsi_std = rsi.rolling(window).std()
        upper_rsi_band = rsi_sma + std_mult * rsi_std
        lower_rsi_band = rsi_sma - std_mult * rsi_std

        return (rsi < lower_rsi_band), (rsi > upper_rsi_band)

    def MarketRegimeStrategy(self, data):
        """Optimized Bollinger Bands + RSI strategy for crypto trading (including sideways markets),
        with adaptive Bollinger Bands.

        This strategy combines volatility analysis, momentum confirmation, and regime detection
        to adapt to Bitcoin's market conditions.

        Entry Conditions:
        - Trending Market (Breakout Mode):
          Buy: Price < Lower Band ∧ RSI < 50 ∧ Volume Spike (≥1.5× 20D Avg)
          Sell: Price > Upper Band ∧ RSI > 50 ∧ Volume Spike
        - Sideways Market (Mean Reversion):
          Buy: Price ≤ Lower Band ∧ RSI ≤ 40
          Sell: Price ≥ Upper Band ∧ RSI ≥ 60

        Enhanced with RSI Bollinger Squeeze for signal confirmation when enabled.

        Returns:
            DataFrame: A unified DataFrame containing original data, BB, RSI, and signals.
        """

        data = aggregate_to_hourly(data, 1)
        # data = aggregate_to_daily(data)

        # Calculate Bollinger Bands
        bb_calculator = BollingerBands(config=self.config)
        # Work on a copy to avoid modifying the original DataFrame upstream
        data_bb = bb_calculator.calculate(data.copy())

        # Calculate RSI
        rsi_calculator = RSI(config=self.config)
        # Use a copy of the original data for RSI as well, to maintain index integrity
        data_with_rsi = rsi_calculator.calculate(data.copy(), price_column='close')

        # Combine BB and RSI data into a single DataFrame for signal generation.
        # Indices should be aligned since both come from data.copy().
        if 'RSI' in data_with_rsi.columns:
            data_bb['RSI'] = data_with_rsi['RSI']
        else:
            # If RSI wasn't calculated (e.g., not enough data), create a NaN column
            # to prevent errors later, though signals won't be generated.
            data_bb['RSI'] = pd.Series(index=data_bb.index, dtype=float)
            if self.logging:
                self.logging.warning("RSI column not found or not calculated. Signals relying on RSI may not be generated.")

        # Initialize conditions as all False
        buy_condition = pd.Series(False, index=data_bb.index)
        sell_condition = pd.Series(False, index=data_bb.index)

        # Create masks for the different market regimes.
        # MarketRegime is expected in data_bb from the BollingerBands calculation.
        sideways_mask = data_bb['MarketRegime'] > 0
        trending_mask = data_bb['MarketRegime'] <= 0
        valid_data_mask = ~data_bb['MarketRegime'].isna()  # Handle potential NaN values

        # Calculate volume spike (≥1.5× 20D Avg).
        # The 'volume' column should be present in the input 'data', and thus in 'data_bb'.
        if 'volume' in data_bb.columns:
            volume_20d_avg = data_bb['volume'].rolling(window=20).mean()
            volume_spike = data_bb['volume'] >= 1.5 * volume_20d_avg

            # Additional volume contraction filter for sideways markets
            volume_30d_avg = data_bb['volume'].rolling(window=30).mean()
            volume_contraction = data_bb['volume'] < 0.7 * volume_30d_avg
        else:
            # If volume data is not available, assume no volume spike
            volume_spike = pd.Series(False, index=data_bb.index)
            volume_contraction = pd.Series(False, index=data_bb.index)
            if self.logging is not None:
                self.logging.warning("Volume data not available. Volume conditions will not be triggered.")

        # Calculate RSI Bollinger Squeeze confirmation (the RSI column is now part of data_bb)
        if 'RSI' in data_bb.columns and not data_bb['RSI'].isna().all():
            oversold_rsi, overbought_rsi = self.rsi_bollinger_confirmation(data_bb['RSI'])
        else:
            oversold_rsi = pd.Series(False, index=data_bb.index)
            overbought_rsi = pd.Series(False, index=data_bb.index)
            if self.logging is not None and ('RSI' not in data_bb.columns or data_bb['RSI'].isna().all()):
                self.logging.warning("RSI data not available or all NaN. RSI Bollinger Squeeze will not be triggered.")

        # Calculate conditions for sideways market (Mean Reversion)
        if sideways_mask.any():
            sideways_buy = (data_bb['close'] <= data_bb['LowerBand']) & (data_bb['RSI'] <= 40)
            sideways_sell = (data_bb['close'] >= data_bb['UpperBand']) & (data_bb['RSI'] >= 60)

            # Add enhanced confirmation for sideways markets
            if self.config.get("SqueezeStrategy", False):
                sideways_buy = sideways_buy & oversold_rsi & volume_contraction
                sideways_sell = sideways_sell & overbought_rsi & volume_contraction

            # Apply only where the market is sideways and data is valid
            buy_condition = buy_condition | (sideways_buy & sideways_mask & valid_data_mask)
            sell_condition = sell_condition | (sideways_sell & sideways_mask & valid_data_mask)

        # Calculate conditions for trending market (Breakout Mode)
        if trending_mask.any():
            trending_buy = (data_bb['close'] < data_bb['LowerBand']) & (data_bb['RSI'] < 50) & volume_spike
            trending_sell = (data_bb['close'] > data_bb['UpperBand']) & (data_bb['RSI'] > 50) & volume_spike

            # Add enhanced confirmation for trending markets
            if self.config.get("SqueezeStrategy", False):
                trending_buy = trending_buy & oversold_rsi
                trending_sell = trending_sell & overbought_rsi

            # Apply only where the market is trending and data is valid
            buy_condition = buy_condition | (trending_buy & trending_mask & valid_data_mask)
            sell_condition = sell_condition | (trending_sell & trending_mask & valid_data_mask)

        # Add buy/sell conditions as columns to the DataFrame
        data_bb['BuySignal'] = buy_condition
        data_bb['SellSignal'] = sell_condition

        return data_bb
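Editor's note: a synthetic end-to-end sketch of `MarketRegimeStrategy`. It assumes `cycles.utils.data_utils` provides the aggregation helpers imported at the top of this file, and that the config keys mirror the nested shape validated by `BollingerBands` and `RSI` below; the price series is random-walk data, not real candles.

# Synthetic demo of the regime-adaptive signals above.
import numpy as np
import pandas as pd
from cycles.Analysis.bb_rsi import BollingerBandsStrategy

idx = pd.date_range("2025-01-01", periods=3000, freq="1min")
rng = np.random.default_rng(0)
close = 100 + np.cumsum(rng.normal(0, 0.1, len(idx)))
data = pd.DataFrame({
    "open": close, "high": close + 0.1, "low": close - 0.1,
    "close": close, "volume": rng.uniform(1, 10, len(idx)),
}, index=idx)

config = {
    "bb_period": 20, "bb_width": 0.05, "rsi_period": 14,
    "trending": {"bb_std_dev_multiplier": 2.5},
    "sideways": {"bb_std_dev_multiplier": 1.8},
    "SqueezeStrategy": False,
}
strategy = BollingerBandsStrategy(config=config)
signals = strategy.run(data, "MarketRegimeStrategy")  # aggregates to hourly internally
print(signals[["close", "RSI", "BuySignal", "SellSignal"]].tail())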

    # Helper functions for CryptoTradingStrategy
    def _volume_confirmation_crypto(self, current_volume, volume_ma):
        """Check volume surge against moving average for crypto strategy"""
        if pd.isna(current_volume) or pd.isna(volume_ma) or volume_ma == 0:
            return False
        return current_volume > 1.5 * volume_ma

    def _multi_timeframe_signal_crypto(self, current_price, rsi_value,
                                       lower_band_15m, lower_band_1h,
                                       upper_band_15m, upper_band_1h):
        """Generate signals with multi-timeframe confirmation for crypto strategy"""
        # Ensure all inputs are not NaN before making comparisons
        if any(pd.isna(val) for val in [current_price, rsi_value, lower_band_15m, lower_band_1h, upper_band_15m, upper_band_1h]):
            return False, False

        buy_signal = (current_price <= lower_band_15m and
                      current_price <= lower_band_1h and
                      rsi_value < 35)

        sell_signal = (current_price >= upper_band_15m and
                       current_price >= upper_band_1h and
                       rsi_value > 65)

        return buy_signal, sell_signal

    def CryptoTradingStrategy(self, data):
        """Core trading algorithm with risk management
        - Multi-Timeframe Confirmation: Combines 15-minute and 1-hour Bollinger Bands
        - Adaptive Volatility Filtering: Uses ATR for dynamic stop-loss/take-profit
        - Volume Spike Detection: Requires 1.5× average volume for confirmation
        - EMA-Smoothed RSI: Reduces false signals in choppy markets
        - Regime-Adaptive Parameters:
          - Trending: 2σ bands, RSI 35/65 thresholds
          - Sideways: 1.8σ bands, RSI 40/60 thresholds
        - Strategy Logic:
          - Long Entry: Price ≤ both 15m & 1h lower bands + RSI < 35 + Volume surge
          - Short Entry: Price ≥ both 15m & 1h upper bands + RSI > 65 + Volume surge
          - Exit: 2:1 risk-reward ratio with ATR-based stops
        """
        if data.empty or 'close' not in data.columns or 'volume' not in data.columns:
            if self.logging:
                self.logging.warning("CryptoTradingStrategy: Input data is empty or missing 'close'/'volume' columns.")
            return pd.DataFrame()  # Return empty DataFrame if essential data is missing

        print(f"data: {data.head()}")

        # Aggregate data
        data_15m = aggregate_to_minutes(data.copy(), 15)
        data_1h = aggregate_to_hourly(data.copy(), 1)

        if data_15m.empty or data_1h.empty:
            if self.logging:
                self.logging.warning("CryptoTradingStrategy: Not enough data for 15m or 1h aggregation.")
            return pd.DataFrame()  # Return empty DataFrame if aggregation fails

        # --- Calculate indicators for 15m timeframe ---
        # Ensure 'close' and 'volume' exist before trying to access them
        if 'close' not in data_15m.columns or 'volume' not in data_15m.columns:
            if self.logging:
                self.logging.warning("CryptoTradingStrategy: 15m data missing close or volume.")
            return data  # Or an empty DataFrame

        price_data_15m = data_15m['close']
        volume_data_15m = data_15m['volume']

        upper_15m, sma_15m, lower_15m = BollingerBands.calculate_custom_bands(price_data_15m, window=20, num_std=2, min_periods=1)
        # Use the static method from the RSI class
        rsi_15m = RSI.calculate_custom_rsi(price_data_15m, window=14, smoothing='EMA')
        volume_ma_15m = volume_data_15m.rolling(window=20, min_periods=1).mean()

        # Add 15m indicators to the data_15m DataFrame
        data_15m['UpperBand_15m'] = upper_15m
        data_15m['SMA_15m'] = sma_15m
        data_15m['LowerBand_15m'] = lower_15m
        data_15m['RSI_15m'] = rsi_15m
        data_15m['VolumeMA_15m'] = volume_ma_15m

        # --- Calculate indicators for 1h timeframe ---
        if 'close' not in data_1h.columns:
            if self.logging:
                self.logging.warning("CryptoTradingStrategy: 1h data missing close.")
            return data_15m  # Return 15m data as 1h failed

        price_data_1h = data_1h['close']
        # Use the static method from the BollingerBands class, setting min_periods to 1 explicitly
        upper_1h, _, lower_1h = BollingerBands.calculate_custom_bands(price_data_1h, window=50, num_std=1.8, min_periods=1)

        # Add 1h indicators to a temporary DataFrame to be merged
        df_1h_indicators = pd.DataFrame(index=data_1h.index)
        df_1h_indicators['UpperBand_1h'] = upper_1h
        df_1h_indicators['LowerBand_1h'] = lower_1h

        # Merge 1h indicators into the 15m DataFrame, then forward-fill
        # to propagate 1h values to the 15m intervals
        data_15m = pd.merge(data_15m, df_1h_indicators, left_index=True, right_index=True, how='left')
        data_15m['UpperBand_1h'] = data_15m['UpperBand_1h'].ffill()
        data_15m['LowerBand_1h'] = data_15m['LowerBand_1h'].ffill()

        # --- Generate Signals ---
        buy_signals = pd.Series(False, index=data_15m.index)
        sell_signals = pd.Series(False, index=data_15m.index)
        stop_loss_levels = pd.Series(np.nan, index=data_15m.index)
        take_profit_levels = pd.Series(np.nan, index=data_15m.index)

        # A full ATR needs a rolling window over 'high', 'low', 'close'.
        # Simplified ATR for now: std of close prices over the last four 15-min periods (1 hour).
        if 'close' in data_15m.columns:
            atr_series = price_data_15m.rolling(window=4, min_periods=1).std()
        else:
            atr_series = pd.Series(0, index=data_15m.index)  # No ATR if close is missing

        for i in range(len(data_15m)):
            if i == 0:
                continue  # Skip first row: the volume MA lookup needs the previous period

            current_price = data_15m['close'].iloc[i]
            current_volume = data_15m['volume'].iloc[i]
            rsi_val = data_15m['RSI_15m'].iloc[i]
            lb_15m = data_15m['LowerBand_15m'].iloc[i]
            ub_15m = data_15m['UpperBand_15m'].iloc[i]
            lb_1h = data_15m['LowerBand_1h'].iloc[i]
            ub_1h = data_15m['UpperBand_1h'].iloc[i]
            vol_ma = data_15m['VolumeMA_15m'].iloc[i-1]  # Use previous period's MA
            atr = atr_series.iloc[i]

            vol_confirm = self._volume_confirmation_crypto(current_volume, vol_ma)
            buy_signal, sell_signal = self._multi_timeframe_signal_crypto(
                current_price, rsi_val, lb_15m, lb_1h, ub_15m, ub_1h
            )

            if buy_signal and vol_confirm:
                buy_signals.iloc[i] = True
                if not pd.isna(atr) and atr > 0:
                    stop_loss_levels.iloc[i] = current_price - 2 * atr
                    take_profit_levels.iloc[i] = current_price + 4 * atr
            elif sell_signal and vol_confirm:
                sell_signals.iloc[i] = True
                if not pd.isna(atr) and atr > 0:
                    stop_loss_levels.iloc[i] = current_price + 2 * atr
                    take_profit_levels.iloc[i] = current_price - 4 * atr

        data_15m['BuySignal'] = buy_signals
        data_15m['SellSignal'] = sell_signals
        data_15m['StopLoss'] = stop_loss_levels
        data_15m['TakeProfit'] = take_profit_levels

        return data_15m
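Editor's note: the 2:1 reward-to-risk exits used in `CryptoTradingStrategy`, shown in isolation. Stops sit 2×ATR from entry and targets 4×ATR, so the potential reward is always twice the risk:

# The exit arithmetic from the loop above, with illustrative numbers.
entry, atr = 100.0, 1.5
long_stop, long_take = entry - 2 * atr, entry + 4 * atr    # 97.0, 106.0
short_stop, short_take = entry + 2 * atr, entry - 4 * atr  # 103.0, 94.0
print((long_take - entry) / (entry - long_stop))           # 2.0 -> 2:1 reward-to-risk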
cycles/Analysis/boillinger_band.py  (new file, 145 lines)
@@ -0,0 +1,145 @@
import pandas as pd
import numpy as np

class BollingerBands:
    """
    Calculates Bollinger Bands for given financial data.
    """
    def __init__(self, config):
        """
        Initializes the BollingerBands calculator.

        Args:
            config (dict): Configuration containing 'bb_period' (moving average /
                standard deviation period), 'bb_width' (band-width threshold for
                regime detection), and 'bb_std_dev_multiplier' values nested
                under 'trending' and 'sideways'.
        """
        if config['bb_period'] <= 0:
            raise ValueError("Period must be a positive integer.")
        if config['trending']['bb_std_dev_multiplier'] <= 0 or config['sideways']['bb_std_dev_multiplier'] <= 0:
            raise ValueError("Standard deviation multiplier must be positive.")
        if config['bb_width'] <= 0:
            raise ValueError("BB width must be positive.")

        self.config = config

    def calculate(self, data_df: pd.DataFrame, price_column: str = 'close', squeeze=False) -> pd.DataFrame:
        """
        Calculates Bollinger Bands and adds them to the DataFrame.

        Args:
            data_df (pd.DataFrame): DataFrame with price data. Must include the price_column.
            price_column (str): The name of the column containing the price data (e.g., 'close').

        Returns:
            pd.DataFrame: The original DataFrame with added columns:
                'SMA' (Simple Moving Average),
                'UpperBand',
                'LowerBand'.
        """

        # Work on a copy to avoid modifying the original DataFrame passed to the function
        data_df = data_df.copy()

        if price_column not in data_df.columns:
            raise ValueError(f"Price column '{price_column}' not found in DataFrame.")

        if not squeeze:
            period = self.config['bb_period']
            bb_width_threshold = self.config['bb_width']
            trending_std_multiplier = self.config['trending']['bb_std_dev_multiplier']
            sideways_std_multiplier = self.config['sideways']['bb_std_dev_multiplier']

            # Calculate SMA
            data_df['SMA'] = data_df[price_column].rolling(window=period).mean()

            # Calculate Standard Deviation
            std_dev = data_df[price_column].rolling(window=period).std()

            # Calculate reference Upper and Lower Bands for BBWidth (using 2.0 std dev).
            # This keeps BBWidth on a consistent band definition before adaptive multipliers apply.
            ref_upper_band = data_df['SMA'] + (2.0 * std_dev)
            ref_lower_band = data_df['SMA'] - (2.0 * std_dev)

            # Calculate the width of the Bollinger Bands.
            # Avoid division by zero or NaN if SMA is zero or NaN by substituting np.nan.
            data_df['BBWidth'] = np.where(data_df['SMA'] != 0, (ref_upper_band - ref_lower_band) / data_df['SMA'], np.nan)

            # Calculate the market regime (1 = sideways, 0 = trending).
            # If BBWidth is NaN, MarketRegime is NaN as well.
            data_df['MarketRegime'] = np.where(data_df['BBWidth'].isna(), np.nan,
                                               (data_df['BBWidth'] < bb_width_threshold).astype(float))  # Use float for NaN compatibility

            # Determine the std dev multiplier for each row based on its market regime
            conditions = [
                data_df['MarketRegime'] == 1,  # Sideways market
                data_df['MarketRegime'] == 0   # Trending market
            ]
            choices = [
                sideways_std_multiplier,
                trending_std_multiplier
            ]
            # Default multiplier if MarketRegime is NaN: use trending_std_multiplier.
            # This can be adjusted based on the desired behavior when the regime is undetermined.
            row_specific_std_multiplier = np.select(conditions, choices, default=trending_std_multiplier)

            # Calculate final Upper and Lower Bands using the row-specific multiplier
            data_df['UpperBand'] = data_df['SMA'] + (row_specific_std_multiplier * std_dev)
            data_df['LowerBand'] = data_df['SMA'] - (row_specific_std_multiplier * std_dev)

        else:  # squeeze is True
            price_series = data_df[price_column]
            # Use the static method for the squeeze case with fixed parameters
            upper_band, sma, lower_band = self.calculate_custom_bands(
                price_series,
                window=14,
                num_std=1.5,
                min_periods=14  # Match typical squeeze behavior where bands appear after the full period
            )
            data_df['SMA'] = sma
            data_df['UpperBand'] = upper_band
            data_df['LowerBand'] = lower_band
            # BBWidth and MarketRegime are not calculated in the simple squeeze context;
            # they belong to the non-squeeze path above.
            data_df['BBWidth'] = np.nan
            data_df['MarketRegime'] = np.nan

        return data_df

    @staticmethod
    def calculate_custom_bands(price_series: pd.Series, window: int = 20, num_std: float = 2.0, min_periods: int = None) -> tuple[pd.Series, pd.Series, pd.Series]:
        """
        Calculates Bollinger Bands with specified window and standard deviation multiplier.

        Args:
            price_series (pd.Series): Series of prices.
            window (int): The period for the moving average and standard deviation.
            num_std (float): The number of standard deviations for the upper and lower bands.
            min_periods (int, optional): Minimum number of observations in window required to have a value.
                Defaults to `window` if None.

        Returns:
            tuple[pd.Series, pd.Series, pd.Series]: Upper band, SMA, Lower band.
        """
        if not isinstance(price_series, pd.Series):
            raise TypeError("price_series must be a pandas Series.")
        if not isinstance(window, int) or window <= 0:
            raise ValueError("window must be a positive integer.")
        if not isinstance(num_std, (int, float)) or num_std <= 0:
            raise ValueError("num_std must be a positive number.")
        if min_periods is not None and (not isinstance(min_periods, int) or min_periods <= 0):
            raise ValueError("min_periods must be a positive integer if provided.")

        actual_min_periods = window if min_periods is None else min_periods

        sma = price_series.rolling(window=window, min_periods=actual_min_periods).mean()
        std = price_series.rolling(window=window, min_periods=actual_min_periods).std()

        # Replace NaN std with 0 to avoid issues if the SMA is present but std is not (e.g. constant price in window)
        std = std.fillna(0)

        upper_band = sma + (std * num_std)
        lower_band = sma - (std * num_std)

        return upper_band, sma, lower_band
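Editor's note: a toy check of the static band helper. With `min_periods=1` the earliest bars collapse onto the SMA, because the one-sample standard deviation is NaN and gets filled with 0.

import pandas as pd
from cycles.Analysis.boillinger_band import BollingerBands

prices = pd.Series([100.0, 101, 102, 101, 100, 99, 100, 101, 102, 103])
upper, sma, lower = BollingerBands.calculate_custom_bands(prices, window=5, num_std=2.0, min_periods=1)
print(pd.DataFrame({"lower": lower, "sma": sma, "upper": upper}).round(2))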
cycles/Analysis/rsi.py  (new file, 113 lines)
@@ -0,0 +1,113 @@
import pandas as pd
import numpy as np

class RSI:
    """
    A class to calculate the Relative Strength Index (RSI).
    """
    def __init__(self, config):
        """
        Initializes the RSI calculator.

        Args:
            config (dict): Configuration containing 'rsi_period', the period
                for RSI calculation. Must be a positive integer.
        """
        if not isinstance(config['rsi_period'], int) or config['rsi_period'] <= 0:
            raise ValueError("Period must be a positive integer.")
        self.period = config['rsi_period']

    def calculate(self, data_df: pd.DataFrame, price_column: str = 'close') -> pd.DataFrame:
        """
        Calculates the RSI (using Wilder's smoothing) and adds it as a column to the input DataFrame.

        Args:
            data_df (pd.DataFrame): DataFrame with historical price data.
                Must contain the 'price_column'.
            price_column (str): The name of the column containing price data.
                Default is 'close'.

        Returns:
            pd.DataFrame: The input DataFrame with an added 'RSI' column.
                The 'RSI' column is all NaN if the period is larger than
                the number of data points.
        """
        if price_column not in data_df.columns:
            raise ValueError(f"Price column '{price_column}' not found in DataFrame.")

        # Check if data is sufficient for calculation (need period + 1 for one diff calculation)
        if len(data_df) < self.period + 1:
            print(f"Warning: Data length ({len(data_df)}) is less than RSI period ({self.period}) + 1. RSI will not be calculated meaningfully.")
            df_copy = data_df.copy()
            df_copy['RSI'] = np.nan  # Add an RSI column with NaNs
            return df_copy

        df = data_df.copy()  # Work on a copy

        price_series = df[price_column]

        # Call the static custom RSI calculator, defaulting to EMA for Wilder's smoothing
        rsi_series = self.calculate_custom_rsi(price_series, window=self.period, smoothing='EMA')

        df['RSI'] = rsi_series

        return df

    @staticmethod
    def calculate_custom_rsi(price_series: pd.Series, window: int = 14, smoothing: str = 'SMA') -> pd.Series:
        """
        Calculates RSI with specified window and smoothing (SMA or EMA).

        Args:
            price_series (pd.Series): Series of prices.
            window (int): The period for RSI calculation. Must be a positive integer.
            smoothing (str): Smoothing method, 'SMA' or 'EMA'. Defaults to 'SMA'.

        Returns:
            pd.Series: Series containing the RSI values.
        """
        if not isinstance(price_series, pd.Series):
            raise TypeError("price_series must be a pandas Series.")
        if not isinstance(window, int) or window <= 0:
            raise ValueError("window must be a positive integer.")
        if smoothing not in ['SMA', 'EMA']:
            raise ValueError("smoothing must be either 'SMA' or 'EMA'.")
        if len(price_series) < window + 1:  # Need at least window + 1 prices for one diff
            return pd.Series(np.nan, index=price_series.index)

        delta = price_series.diff()
        # The first delta is NaN. For gain/loss calculations it can be treated as 0;
        # subsequent rolling/ewm calls handle NaNs appropriately when min_periods is set.

        gain = delta.where(delta > 0, 0.0)
        loss = -delta.where(delta < 0, 0.0)  # Ensure loss is positive

        # Ensure the gain and loss Series share the price_series index for rolling/ewm.
        # This matters if price_series has missing dates/times.
        gain = gain.reindex(price_series.index, fill_value=0.0)
        loss = loss.reindex(price_series.index, fill_value=0.0)

        if smoothing == 'EMA':
            # adjust=False for the Wilder's smoothing used in RSI
            avg_gain = gain.ewm(alpha=1/window, adjust=False, min_periods=window).mean()
            avg_loss = loss.ewm(alpha=1/window, adjust=False, min_periods=window).mean()
        else:  # SMA
            avg_gain = gain.rolling(window=window, min_periods=window).mean()
            avg_loss = loss.rolling(window=window, min_periods=window).mean()

        # Handle division by zero for the RS calculation.
        # If avg_loss is 0, RS is effectively infinite (avg_gain > 0) or undefined (avg_gain also 0).
        rs = avg_gain / avg_loss.replace(0, 1e-9)  # Replace 0 with a tiny number to avoid division by zero

        rsi = 100 - (100 / (1 + rs))

        # Correct RSI values for edge cases where avg_loss was 0:
        # if avg_gain > 0, RSI is 100; if avg_gain is also 0, RSI is 50 (neutral).
        rsi[avg_loss == 0] = np.where(avg_gain[avg_loss == 0] > 0, 100, 50)

        # Ensure RSI is NaN where avg_gain or avg_loss is NaN (due to min_periods)
        rsi[avg_gain.isna() | avg_loss.isna()] = np.nan

        return rsi
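Editor's note: a toy check of the static RSI helper. A monotonically rising series has no losses, so once the `min_periods=window` warm-up passes, the zero-loss edge-case branch pins RSI at 100.

import numpy as np
import pandas as pd
from cycles.Analysis.rsi import RSI

prices = pd.Series(np.linspace(100, 120, 30))
rsi = RSI.calculate_custom_rsi(prices, window=14, smoothing="EMA")
print(rsi.iloc[-1])  # 100.0 -- only gains, never losses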
cycles/Analysis/supertrend.py  (new file, 215 lines)
@@ -0,0 +1,215 @@
import pandas as pd
import numpy as np
import logging
from functools import lru_cache

@lru_cache(maxsize=32)
def cached_supertrend_calculation(period, multiplier, data_tuple):
    high = np.array(data_tuple[0])
    low = np.array(data_tuple[1])
    close = np.array(data_tuple[2])
    tr = np.zeros_like(close)
    tr[0] = high[0] - low[0]
    hc_range = np.abs(high[1:] - close[:-1])
    lc_range = np.abs(low[1:] - close[:-1])
    hl_range = high[1:] - low[1:]
    tr[1:] = np.maximum.reduce([hl_range, hc_range, lc_range])
    atr = np.zeros_like(tr)
    atr[0] = tr[0]
    multiplier_ema = 2.0 / (period + 1)
    for i in range(1, len(tr)):
        atr[i] = (tr[i] * multiplier_ema) + (atr[i-1] * (1 - multiplier_ema))
    upper_band = np.zeros_like(close)
    lower_band = np.zeros_like(close)
    for i in range(len(close)):
        hl_avg = (high[i] + low[i]) / 2
        upper_band[i] = hl_avg + (multiplier * atr[i])
        lower_band[i] = hl_avg - (multiplier * atr[i])
    final_upper = np.zeros_like(close)
    final_lower = np.zeros_like(close)
    supertrend = np.zeros_like(close)
    trend = np.zeros_like(close)
    final_upper[0] = upper_band[0]
    final_lower[0] = lower_band[0]
    if close[0] <= upper_band[0]:
        supertrend[0] = upper_band[0]
        trend[0] = -1
    else:
        supertrend[0] = lower_band[0]
        trend[0] = 1
    for i in range(1, len(close)):
        if (upper_band[i] < final_upper[i-1]) or (close[i-1] > final_upper[i-1]):
            final_upper[i] = upper_band[i]
        else:
            final_upper[i] = final_upper[i-1]
        if (lower_band[i] > final_lower[i-1]) or (close[i-1] < final_lower[i-1]):
            final_lower[i] = lower_band[i]
        else:
            final_lower[i] = final_lower[i-1]
        if supertrend[i-1] == final_upper[i-1] and close[i] <= final_upper[i]:
            supertrend[i] = final_upper[i]
            trend[i] = -1
        elif supertrend[i-1] == final_upper[i-1] and close[i] > final_upper[i]:
            supertrend[i] = final_lower[i]
            trend[i] = 1
        elif supertrend[i-1] == final_lower[i-1] and close[i] >= final_lower[i]:
            supertrend[i] = final_lower[i]
            trend[i] = 1
        elif supertrend[i-1] == final_lower[i-1] and close[i] < final_lower[i]:
            supertrend[i] = final_upper[i]
            trend[i] = -1
    return {
        'supertrend': supertrend,
        'trend': trend,
        'upper_band': final_upper,
        'lower_band': final_lower
    }


def calculate_supertrend_external(data, period, multiplier, close_column='close'):
    """
    External function to calculate SuperTrend with a configurable close column

    Parameters:
    - data: DataFrame with OHLC data
    - period: int, period for ATR calculation
    - multiplier: float, multiplier for ATR
    - close_column: str, name of the column to use as close price (default: 'close')
    """
    high_tuple = tuple(data['high'])
    low_tuple = tuple(data['low'])
    close_tuple = tuple(data[close_column])
    return cached_supertrend_calculation(period, multiplier, (high_tuple, low_tuple, close_tuple))
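Editor's note: a usage sketch for the cached module-level helper above; the OHLC values are synthetic. Converting the columns to tuples (as `calculate_supertrend_external` does) keeps the inputs hashable for `lru_cache`.

import pandas as pd
from cycles.Analysis.supertrend import calculate_supertrend_external

data = pd.DataFrame({
    "high":  [10.5, 10.8, 11.2, 11.0, 10.6],
    "low":   [10.0, 10.3, 10.7, 10.4, 10.1],
    "close": [10.2, 10.6, 11.0, 10.5, 10.3],
})
result = calculate_supertrend_external(data, period=3, multiplier=2.0)
print(result["trend"])  # +1 uptrend / -1 downtrend per bar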

class Supertrends:
    def __init__(self, data, close_column='close', verbose=False, display=False):
        """
        Initialize Supertrends calculator

        Parameters:
        - data: pandas DataFrame with OHLC data or list of prices
        - close_column: str, name of the column to use as close price (default: 'close')
        - verbose: bool, enable verbose logging
        - display: bool, display mode (currently unused)
        """
        self.close_column = close_column
        self.data = data
        self.verbose = verbose
        logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
                            format='%(asctime)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger('TrendDetectorSimple')

        if not isinstance(self.data, pd.DataFrame):
            if isinstance(self.data, list):
                self.data = pd.DataFrame({self.close_column: self.data})
            else:
                raise ValueError("Data must be a pandas DataFrame or a list")

        # Validate that required columns exist
        required_columns = ['high', 'low', self.close_column]
        missing_columns = [col for col in required_columns if col not in self.data.columns]
        if missing_columns:
            raise ValueError(f"Missing required columns: {missing_columns}")

    def calculate_tr(self):
        """Calculate True Range using the configured close column"""
        df = self.data.copy()
        high = df['high'].values
        low = df['low'].values
        close = df[self.close_column].values
        tr = np.zeros_like(close)
        tr[0] = high[0] - low[0]
        for i in range(1, len(close)):
            hl_range = high[i] - low[i]
            hc_range = abs(high[i] - close[i-1])
            lc_range = abs(low[i] - close[i-1])
            tr[i] = max(hl_range, hc_range, lc_range)
        return tr

    def calculate_atr(self, period=14):
        """Calculate Average True Range"""
        tr = self.calculate_tr()
        atr = np.zeros_like(tr)
        atr[0] = tr[0]
        multiplier = 2.0 / (period + 1)
        for i in range(1, len(tr)):
            atr[i] = (tr[i] * multiplier) + (atr[i-1] * (1 - multiplier))
        return atr

    def calculate_supertrend(self, period=10, multiplier=3.0):
        """
        Calculate the SuperTrend indicator for the price data using the configured close column.
        SuperTrend is a trend-following indicator that uses ATR to determine the trend direction.

        Parameters:
        - period: int, the period for the ATR calculation (default: 10)
        - multiplier: float, the multiplier for the ATR (default: 3.0)

        Returns:
        - Dictionary containing SuperTrend values, trend direction, and upper/lower bands
        """
        df = self.data.copy()
        high = df['high'].values
        low = df['low'].values
        close = df[self.close_column].values
        atr = self.calculate_atr(period)
        upper_band = np.zeros_like(close)
        lower_band = np.zeros_like(close)
        for i in range(len(close)):
            hl_avg = (high[i] + low[i]) / 2
            upper_band[i] = hl_avg + (multiplier * atr[i])
            lower_band[i] = hl_avg - (multiplier * atr[i])
        final_upper = np.zeros_like(close)
        final_lower = np.zeros_like(close)
        supertrend = np.zeros_like(close)
        trend = np.zeros_like(close)
        final_upper[0] = upper_band[0]
        final_lower[0] = lower_band[0]
        if close[0] <= upper_band[0]:
            supertrend[0] = upper_band[0]
            trend[0] = -1
        else:
            supertrend[0] = lower_band[0]
            trend[0] = 1
        for i in range(1, len(close)):
            if (upper_band[i] < final_upper[i-1]) or (close[i-1] > final_upper[i-1]):
                final_upper[i] = upper_band[i]
            else:
                final_upper[i] = final_upper[i-1]
            if (lower_band[i] > final_lower[i-1]) or (close[i-1] < final_lower[i-1]):
                final_lower[i] = lower_band[i]
            else:
                final_lower[i] = final_lower[i-1]
            if supertrend[i-1] == final_upper[i-1] and close[i] <= final_upper[i]:
                supertrend[i] = final_upper[i]
                trend[i] = -1
            elif supertrend[i-1] == final_upper[i-1] and close[i] > final_upper[i]:
                supertrend[i] = final_lower[i]
                trend[i] = 1
            elif supertrend[i-1] == final_lower[i-1] and close[i] >= final_lower[i]:
                supertrend[i] = final_lower[i]
                trend[i] = 1
            elif supertrend[i-1] == final_lower[i-1] and close[i] < final_lower[i]:
                supertrend[i] = final_upper[i]
                trend[i] = -1
        supertrend_results = {
            'supertrend': supertrend,
            'trend': trend,
            'upper_band': final_upper,
            'lower_band': final_lower
        }
        return supertrend_results

    def calculate_supertrend_indicators(self):
        supertrend_params = [
            {"period": 12, "multiplier": 3.0},
            {"period": 10, "multiplier": 1.0},
|
||||||
|
{"period": 11, "multiplier": 2.0}
|
||||||
|
]
|
||||||
|
results = []
|
||||||
|
for p in supertrend_params:
|
||||||
|
result = self.calculate_supertrend(period=p["period"], multiplier=p["multiplier"])
|
||||||
|
results.append({
|
||||||
|
"results": result,
|
||||||
|
"params": p
|
||||||
|
})
|
||||||
|
return results
|
||||||
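For context, a minimal usage sketch of the module above (illustrative only; the synthetic OHLC data is made up, and the meta-trend agreement rule mirrors how cycles/backtest.py consumes these results):

import numpy as np
import pandas as pd

from cycles.supertrend import Supertrends

# Hypothetical OHLC data, for illustration only
rng = np.random.default_rng(0)
close = 100 + rng.standard_normal(500).cumsum()
df = pd.DataFrame({'high': close + 1.0, 'low': close - 1.0, 'close': close})

st = Supertrends(df, close_column='close')
results = st.calculate_supertrend_indicators()

# Meta trend: +1/-1 only where all three SuperTrend variants agree, else 0
trends = np.stack([r['results']['trend'] for r in results], axis=1)
meta_trend = np.where((trends[:, 0] == trends[:, 1]) & (trends[:, 1] == trends[:, 2]),
                      trends[:, 0], 0)
print(meta_trend[-5:])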
0
cycles/__init__.py
Normal file
332
cycles/backtest.py
Normal file
@@ -0,0 +1,332 @@
import pandas as pd
import numpy as np
import time

from cycles.supertrend import Supertrends
from cycles.market_fees import MarketFees


class Backtest:
    @staticmethod
    def run(min1_df, df, initial_usd, stop_loss_pct, progress_callback=None, verbose=False):
        """
        Backtest a simple strategy using the meta supertrend (all three supertrends agree).
        Buys when the meta supertrend is positive, sells when it is negative, and applies a percentage stop loss.

        Parameters:
        - min1_df: pandas DataFrame, 1-minute timeframe data for more accurate stop loss checking (optional)
        - df: pandas DataFrame, main timeframe data for signals
        - initial_usd: float, starting USD amount
        - stop_loss_pct: float, stop loss as a fraction (e.g. 0.05 for 5%)
        - progress_callback: callable, optional callback function to report progress (current_step)
        - verbose: bool, enable debug logging for stop loss checks
        """
        _df = df.copy().reset_index()

        # Ensure we have a timestamp column regardless of the original index name
        if 'timestamp' not in _df.columns:
            # If reset_index() created a column with the original index name, rename it
            if len(_df.columns) > 0 and _df.columns[0] not in ['open', 'high', 'low', 'close', 'volume', 'predicted_close_price']:
                _df = _df.rename(columns={_df.columns[0]: 'timestamp'})
            else:
                raise ValueError("Unable to identify timestamp column in DataFrame")

        _df['timestamp'] = pd.to_datetime(_df['timestamp'])

        supertrends = Supertrends(_df, verbose=False, close_column='predicted_close_price')

        supertrend_results_list = supertrends.calculate_supertrend_indicators()
        trends = [st['results']['trend'] for st in supertrend_results_list]
        trends_arr = np.stack(trends, axis=1)
        meta_trend = np.where((trends_arr[:, 0] == trends_arr[:, 1]) & (trends_arr[:, 1] == trends_arr[:, 2]),
                              trends_arr[:, 0], 0)
        # Shift meta_trend by one to avoid lookahead bias
        meta_trend_signal = np.roll(meta_trend, 1)
        meta_trend_signal[0] = 0  # or np.nan, but 0 means 'no signal' for the first bar

        position = 0  # 0 = no position, 1 = long
        entry_price = 0
        usd = initial_usd
        coin = 0
        trade_log = []
        max_balance = initial_usd
        drawdowns = []
        trades = []
        entry_time = None
        stop_loss_count = 0  # Track the number of stop losses

        # Ensure min1_df has a proper DatetimeIndex
        if min1_df is not None and not min1_df.empty:
            min1_df.index = pd.to_datetime(min1_df.index)

        for i in range(1, len(_df)):
            # Report progress if a callback is provided
            if progress_callback:
                # Update more frequently for better responsiveness
                update_frequency = max(1, len(_df) // 50)  # Update every 2% of the dataset (50 updates total)
                if i % update_frequency == 0 or i == len(_df) - 1:  # Always update on the last iteration
                    if verbose:  # Only print in verbose mode to avoid spam
                        print(f"DEBUG: Progress callback called with i={i}, total={len(_df)-1}")
                    progress_callback(i)

            price_open = _df['open'].iloc[i]
            price_close = _df['close'].iloc[i]
            date = _df['timestamp'].iloc[i]
            prev_mt = meta_trend_signal[i-1]
            curr_mt = meta_trend_signal[i]

            # Check the stop loss if in position
            if position == 1:
                stop_loss_result = Backtest.check_stop_loss(
                    min1_df,
                    entry_time,
                    date,
                    entry_price,
                    stop_loss_pct,
                    coin,
                    verbose=verbose
                )
                if stop_loss_result is not None:
                    trade_log_entry, position, coin, entry_price, usd = stop_loss_result
                    trade_log.append(trade_log_entry)
                    stop_loss_count += 1
                    continue

            # Entry: only if not in position and the signal changes to 1
            if position == 0 and prev_mt != 1 and curr_mt == 1:
                entry_result = Backtest.handle_entry(usd, price_open, date)
                coin, entry_price, entry_time, usd, position, trade_log_entry = entry_result
                trade_log.append(trade_log_entry)

            # Exit: only if in position and the signal changes from 1 to -1
            elif position == 1 and prev_mt == 1 and curr_mt == -1:
                exit_result = Backtest.handle_exit(coin, price_open, entry_price, entry_time, date)
                usd, coin, position, entry_price, trade_log_entry = exit_result
                trade_log.append(trade_log_entry)

            # Track drawdown
            balance = usd if position == 0 else coin * price_close
            if balance > max_balance:
                max_balance = balance
            drawdown = (max_balance - balance) / max_balance
            drawdowns.append(drawdown)

        # Report completion if a callback is provided
        if progress_callback:
            progress_callback(len(_df) - 1)

        # If still in position at the end, sell at the last close
        if position == 1:
            exit_result = Backtest.handle_exit(coin, _df['close'].iloc[-1], entry_price, entry_time, _df['timestamp'].iloc[-1])
            usd, coin, position, entry_price, trade_log_entry = exit_result
            trade_log.append(trade_log_entry)

        # Calculate statistics
        final_balance = usd
        n_trades = len(trade_log)
        wins = [1 for t in trade_log if t['exit'] is not None and t['exit'] > t['entry']]
        win_rate = len(wins) / n_trades if n_trades > 0 else 0
        max_drawdown = max(drawdowns) if drawdowns else 0
        avg_trade = np.mean([t['exit']/t['entry'] - 1 for t in trade_log if t['exit'] is not None]) if trade_log else 0

        trades = []
        total_fees_usd = 0.0
        for trade in trade_log:
            if trade['exit'] is not None:
                profit_pct = (trade['exit'] - trade['entry']) / trade['entry']
            else:
                profit_pct = 0.0

            # Validate the fee_usd field
            if 'fee_usd' not in trade:
                raise ValueError(f"Trade missing required field 'fee_usd': {trade}")
            fee_usd = trade['fee_usd']
            if fee_usd is None:
                raise ValueError(f"Trade fee_usd is None: {trade}")

            # Validate the trade type field
            if 'type' not in trade:
                raise ValueError(f"Trade missing required field 'type': {trade}")
            trade_type = trade['type']
            if trade_type is None:
                raise ValueError(f"Trade type is None: {trade}")

            trades.append({
                'entry_time': trade['entry_time'],
                'exit_time': trade['exit_time'],
                'entry': trade['entry'],
                'exit': trade['exit'],
                'profit_pct': profit_pct,
                'type': trade_type,
                'fee_usd': fee_usd
            })
            total_fees_usd += fee_usd

        results = {
            "initial_usd": initial_usd,
            "final_usd": final_balance,
            "n_trades": n_trades,
            "n_stop_loss": stop_loss_count,  # Stop loss count
            "win_rate": win_rate,
            "max_drawdown": max_drawdown,
            "avg_trade": avg_trade,
            "trade_log": trade_log,
            "trades": trades,
            "total_fees_usd": total_fees_usd,
        }
        if n_trades > 0:
            results["first_trade"] = {
                "entry_time": trade_log[0]['entry_time'],
                "entry": trade_log[0]['entry']
            }
            results["last_trade"] = {
                "exit_time": trade_log[-1]['exit_time'],
                "exit": trade_log[-1]['exit']
            }
        return results

    @staticmethod
    def check_stop_loss(min1_df, entry_time, current_time, entry_price, stop_loss_pct, coin, verbose=False):
        """
        Check whether the stop loss should be triggered based on 1-minute data.

        Args:
            min1_df: 1-minute DataFrame with a DatetimeIndex
            entry_time: Entry timestamp
            current_time: Current timestamp
            entry_price: Entry price
            stop_loss_pct: Stop loss fraction (e.g. 0.05 for 5%)
            coin: Current coin position
            verbose: Enable debug logging

        Returns:
            Tuple of (trade_log_entry, position, coin, entry_price, usd) if the stop loss triggered, None otherwise
        """
        if min1_df is None or min1_df.empty:
            if verbose:
                print("Warning: No 1-minute data available for stop loss checking")
            return None

        stop_price = entry_price * (1 - stop_loss_pct)

        try:
            # Ensure min1_df has a DatetimeIndex
            if not isinstance(min1_df.index, pd.DatetimeIndex):
                if verbose:
                    print("Warning: min1_df does not have a DatetimeIndex")
                return None

            # Convert entry_time and current_time to pandas Timestamps for comparison
            entry_ts = pd.to_datetime(entry_time)
            current_ts = pd.to_datetime(current_time)

            if verbose:
                print(f"Checking stop loss from {entry_ts} to {current_ts}, stop_price: {stop_price:.2f}")

            # Handle the edge case where entry and current time are the same (1-minute timeframe)
            if entry_ts == current_ts:
                if verbose:
                    print("Entry and current time are the same, no range to check")
                return None

            # Find the range of 1-minute data to check (exclusive of entry time, inclusive of current time)
            # We start from the candle AFTER entry to avoid checking the entry candle itself
            start_check_time = entry_ts + pd.Timedelta(minutes=1)

            # Get the slice of data to check for the stop loss
            mask = (min1_df.index > entry_ts) & (min1_df.index <= current_ts)
            min1_slice = min1_df.loc[mask]

            if len(min1_slice) == 0:
                if verbose:
                    print(f"No 1-minute data found between {start_check_time} and {current_ts}")
                return None

            if verbose:
                print(f"Checking {len(min1_slice)} candles for stop loss")

            # Check whether any low price in the slice hits the stop loss
            stop_triggered = (min1_slice['low'] <= stop_price).any()

            if stop_triggered:
                # Find the exact candle where the stop loss was triggered
                stop_candle = min1_slice[min1_slice['low'] <= stop_price].iloc[0]

                if verbose:
                    print(f"Stop loss triggered at {stop_candle.name}, low: {stop_candle['low']:.2f}")

                # More realistic fill: if the open is below the stop, fill at the open, else at the stop
                if stop_candle['open'] < stop_price:
                    sell_price = stop_candle['open']
                    if verbose:
                        print(f"Filled at open price: {sell_price:.2f}")
                else:
                    sell_price = stop_price
                    if verbose:
                        print(f"Filled at stop price: {sell_price:.2f}")

                btc_to_sell = coin
                usd_gross = btc_to_sell * sell_price
                exit_fee = MarketFees.calculate_okx_taker_maker_fee(usd_gross, is_maker=False)
                usd_after_stop = usd_gross - exit_fee

                trade_log_entry = {
                    'type': 'STOP',
                    'entry': entry_price,
                    'exit': sell_price,
                    'entry_time': entry_time,
                    'exit_time': stop_candle.name,
                    'fee_usd': exit_fee
                }
                # After a stop loss, reset position and entry, and return the USD balance
                return trade_log_entry, 0, 0, 0, usd_after_stop
            elif verbose:
                print(f"No stop loss triggered, min low in range: {min1_slice['low'].min():.2f}")

        except Exception as e:
            # In case of any error, don't trigger the stop loss but log the issue
            error_msg = f"Warning: Stop loss check failed: {e}"
            print(error_msg)
            if verbose:
                import traceback
                print(traceback.format_exc())
            return None

        return None

    @staticmethod
    def handle_entry(usd, price_open, date):
        entry_fee = MarketFees.calculate_okx_taker_maker_fee(usd, is_maker=False)
        usd_after_fee = usd - entry_fee
        coin = usd_after_fee / price_open
        entry_price = price_open
        entry_time = date
        usd = 0
        position = 1
        trade_log_entry = {
            'type': 'BUY',
            'entry': entry_price,
            'exit': None,
            'entry_time': entry_time,
            'exit_time': None,
            'fee_usd': entry_fee
        }
        return coin, entry_price, entry_time, usd, position, trade_log_entry

    @staticmethod
    def handle_exit(coin, price_open, entry_price, entry_time, date):
        btc_to_sell = coin
        usd_gross = btc_to_sell * price_open
        exit_fee = MarketFees.calculate_okx_taker_maker_fee(usd_gross, is_maker=False)
        usd = usd_gross - exit_fee
        trade_log_entry = {
            'type': 'SELL',
            'entry': entry_price,
            'exit': price_open,
            'entry_time': entry_time,
            'exit_time': date,
            'fee_usd': exit_fee
        }
        coin = 0
        position = 0
        entry_price = 0
        return usd, coin, position, entry_price, trade_log_entry
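A hedged invocation sketch for Backtest.run (the file name and the hourly resampling are assumptions; the 'predicted_close_price' column is required because the Supertrends call above is configured to read it):

import pandas as pd

from cycles.backtest import Backtest

# Hypothetical 1-minute OHLCV CSV with a datetime index
min1_df = pd.read_csv('btc_1min.csv', index_col=0, parse_dates=True)

# Signal timeframe: resample to 1 hour (an assumption; any main timeframe works)
df = min1_df.resample('1h').agg({'open': 'first', 'high': 'max', 'low': 'min',
                                 'close': 'last', 'volume': 'sum'}).dropna()
df['predicted_close_price'] = df['close']  # the column Supertrends reads for signals

results = Backtest.run(min1_df, df, initial_usd=10_000, stop_loss_pct=0.05)
print(results['final_usd'], results['n_trades'], results['win_rate'], results['n_stop_loss'])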
453
cycles/charts.py
Normal file
@@ -0,0 +1,453 @@
import os
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np


class BacktestCharts:
    @staticmethod
    def plot(df, meta_trend):
        """
        Plot a close price line chart with a bar at the bottom: green when the trend is 1, red when it is 0.
        The bar stays at the bottom even when zooming/panning.
        - df: DataFrame with columns ['close', ...] and a datetime index or a 'timestamp' column.
        - meta_trend: array-like, same length as df, values 1 (green) or 0 (red).
        """
        fig, (ax_price, ax_bar) = plt.subplots(
            nrows=2, ncols=1, figsize=(16, 8), sharex=True,
            gridspec_kw={'height_ratios': [12, 1]}
        )

        sns.lineplot(x=df.index, y=df['close'], label='Close Price', color='blue', ax=ax_price)
        ax_price.set_title('Close Price with Trend Bar (Green=1, Red=0)')
        ax_price.set_ylabel('Price')
        ax_price.grid(True, alpha=0.3)
        ax_price.legend()

        # Clean meta_trend: ensure only 0/1, handle NaNs by forward-fill then fill the remainder with 0
        meta_trend_arr = np.asarray(meta_trend)
        if not np.issubdtype(meta_trend_arr.dtype, np.number):
            meta_trend_arr = pd.Series(meta_trend_arr).astype(float).to_numpy()
        if np.isnan(meta_trend_arr).any():
            meta_trend_arr = pd.Series(meta_trend_arr).fillna(method='ffill').fillna(0).astype(int).to_numpy()
        else:
            meta_trend_arr = meta_trend_arr.astype(int)
        meta_trend_arr = np.where(meta_trend_arr != 1, 0, 1)  # force only 0 or 1
        if hasattr(df.index, 'to_numpy'):
            x_vals = df.index.to_numpy()
        else:
            x_vals = np.array(df.index)

        # Find contiguous regions
        regions = []
        start = 0
        for i in range(1, len(meta_trend_arr)):
            if meta_trend_arr[i] != meta_trend_arr[i-1]:
                regions.append((start, i-1, meta_trend_arr[i-1]))
                start = i
        regions.append((start, len(meta_trend_arr)-1, meta_trend_arr[-1]))

        # Draw black dashed vertical lines at the start of each new region (except the first)
        for region_idx in range(1, len(regions)):
            region_start = regions[region_idx][0]
            ax_price.axvline(x=x_vals[region_start], color='black', linestyle='--', alpha=0.7, linewidth=1)

        for start, end, trend in regions:
            color = '#089981' if trend == 1 else '#F23645'
            # Offset by 1 on x: span from x_vals[start] to x_vals[end+1] if possible
            x_start = x_vals[start]
            x_end = x_vals[end+1] if end+1 < len(x_vals) else x_vals[end]
            ax_bar.axvspan(x_start, x_end, color=color, alpha=1, ymin=0, ymax=1)

        ax_bar.set_ylim(0, 1)
        ax_bar.set_yticks([])
        ax_bar.set_ylabel('Trend')
        ax_bar.set_xlabel('Time')
        ax_bar.grid(False)
        ax_bar.set_title('Meta Trend')

        plt.tight_layout(h_pad=0.1)
        plt.show()

    @staticmethod
    def format_strategy_data_with_trades(strategy_data, backtest_results):
        """
        Format strategy data for universal plotting with the actual executed trades.
        Converts strategy output into the expected column format: "x_type_name".

        Args:
            strategy_data (DataFrame): Output from a strategy with columns like 'close', 'UpperBand', 'LowerBand', 'RSI'
            backtest_results (dict): Results from backtest.run() containing the actual executed trades

        Returns:
            DataFrame: Formatted data ready for the plot_data function
        """
        formatted_df = pd.DataFrame(index=strategy_data.index)

        # Plot 1: Price data with Bollinger Bands and actual trade signals
        if 'close' in strategy_data.columns:
            formatted_df['1_line_close'] = strategy_data['close']

        # Bollinger Bands area (prefer standard names, fall back to timeframe-specific ones)
        upper_band_col = None
        lower_band_col = None
        sma_col = None

        # Check for standard BB columns first
        if 'UpperBand' in strategy_data.columns and 'LowerBand' in strategy_data.columns:
            upper_band_col = 'UpperBand'
            lower_band_col = 'LowerBand'
        # Check for 15m BB columns
        elif 'UpperBand_15m' in strategy_data.columns and 'LowerBand_15m' in strategy_data.columns:
            upper_band_col = 'UpperBand_15m'
            lower_band_col = 'LowerBand_15m'

        if upper_band_col and lower_band_col:
            formatted_df['1_area_bb_upper'] = strategy_data[upper_band_col]
            formatted_df['1_area_bb_lower'] = strategy_data[lower_band_col]

        # SMA/Moving Average line
        if 'SMA' in strategy_data.columns:
            sma_col = 'SMA'
        elif 'SMA_15m' in strategy_data.columns:
            sma_col = 'SMA_15m'

        if sma_col:
            formatted_df['1_line_sma'] = strategy_data[sma_col]

        # Strategy buy/sell signals (all signals from the strategy) as smaller scatter points
        if 'BuySignal' in strategy_data.columns and 'close' in strategy_data.columns:
            strategy_buy_points = strategy_data['close'].where(strategy_data['BuySignal'], np.nan)
            formatted_df['1_scatter_strategy_buy'] = strategy_buy_points

        if 'SellSignal' in strategy_data.columns and 'close' in strategy_data.columns:
            strategy_sell_points = strategy_data['close'].where(strategy_data['SellSignal'], np.nan)
            formatted_df['1_scatter_strategy_sell'] = strategy_sell_points

        # Actual executed trades from the backtest results (larger, more prominent)
        if 'trades' in backtest_results and backtest_results['trades']:
            # Create series for buy and sell points
            buy_points = pd.Series(np.nan, index=strategy_data.index)
            sell_points = pd.Series(np.nan, index=strategy_data.index)

            for trade in backtest_results['trades']:
                entry_time = trade.get('entry_time')
                exit_time = trade.get('exit_time')
                entry_price = trade.get('entry')
                exit_price = trade.get('exit')

                # Find the closest index for the entry time
                if entry_time is not None and entry_price is not None:
                    try:
                        if isinstance(entry_time, str):
                            entry_time = pd.to_datetime(entry_time)
                        # Find the closest index to entry_time
                        closest_entry_idx = strategy_data.index.get_indexer([entry_time], method='nearest')[0]
                        if closest_entry_idx >= 0:
                            buy_points.iloc[closest_entry_idx] = entry_price
                    except (ValueError, IndexError, TypeError):
                        pass  # Skip if we can't find a matching time

                # Find the closest index for the exit time
                if exit_time is not None and exit_price is not None:
                    try:
                        if isinstance(exit_time, str):
                            exit_time = pd.to_datetime(exit_time)
                        # Find the closest index to exit_time
                        closest_exit_idx = strategy_data.index.get_indexer([exit_time], method='nearest')[0]
                        if closest_exit_idx >= 0:
                            sell_points.iloc[closest_exit_idx] = exit_price
                    except (ValueError, IndexError, TypeError):
                        pass  # Skip if we can't find a matching time

            formatted_df['1_scatter_actual_buy'] = buy_points
            formatted_df['1_scatter_actual_sell'] = sell_points

        # Stop Loss and Take Profit levels
        if 'StopLoss' in strategy_data.columns:
            formatted_df['1_line_stop_loss'] = strategy_data['StopLoss']
        if 'TakeProfit' in strategy_data.columns:
            formatted_df['1_line_take_profit'] = strategy_data['TakeProfit']

        # Plot 2: RSI
        rsi_col = None
        if 'RSI' in strategy_data.columns:
            rsi_col = 'RSI'
        elif 'RSI_15m' in strategy_data.columns:
            rsi_col = 'RSI_15m'

        if rsi_col:
            formatted_df['2_line_rsi'] = strategy_data[rsi_col]
            # Add RSI overbought/oversold levels
            formatted_df['2_line_rsi_overbought'] = 70
            formatted_df['2_line_rsi_oversold'] = 30

        # Plot 3: Volume (if available)
        if 'volume' in strategy_data.columns:
            formatted_df['3_bar_volume'] = strategy_data['volume']

            # Add the volume moving average if available
            if 'VolumeMA_15m' in strategy_data.columns:
                formatted_df['3_line_volume_ma'] = strategy_data['VolumeMA_15m']

        return formatted_df

    @staticmethod
    def format_strategy_data(strategy_data):
        """
        Format strategy data for universal plotting (without trade signals).
        Converts strategy output into the expected column format: "x_type_name".

        Args:
            strategy_data (DataFrame): Output from a strategy with columns like 'close', 'UpperBand', 'LowerBand', 'RSI'

        Returns:
            DataFrame: Formatted data ready for the plot_data function
        """
        formatted_df = pd.DataFrame(index=strategy_data.index)

        # Plot 1: Price data with Bollinger Bands
        if 'close' in strategy_data.columns:
            formatted_df['1_line_close'] = strategy_data['close']

        # Bollinger Bands area (prefer standard names, fall back to timeframe-specific ones)
        upper_band_col = None
        lower_band_col = None
        sma_col = None

        # Check for standard BB columns first
        if 'UpperBand' in strategy_data.columns and 'LowerBand' in strategy_data.columns:
            upper_band_col = 'UpperBand'
            lower_band_col = 'LowerBand'
        # Check for 15m BB columns
        elif 'UpperBand_15m' in strategy_data.columns and 'LowerBand_15m' in strategy_data.columns:
            upper_band_col = 'UpperBand_15m'
            lower_band_col = 'LowerBand_15m'

        if upper_band_col and lower_band_col:
            formatted_df['1_area_bb_upper'] = strategy_data[upper_band_col]
            formatted_df['1_area_bb_lower'] = strategy_data[lower_band_col]

        # SMA/Moving Average line
        if 'SMA' in strategy_data.columns:
            sma_col = 'SMA'
        elif 'SMA_15m' in strategy_data.columns:
            sma_col = 'SMA_15m'

        if sma_col:
            formatted_df['1_line_sma'] = strategy_data[sma_col]

        # Stop Loss and Take Profit levels
        if 'StopLoss' in strategy_data.columns:
            formatted_df['1_line_stop_loss'] = strategy_data['StopLoss']
        if 'TakeProfit' in strategy_data.columns:
            formatted_df['1_line_take_profit'] = strategy_data['TakeProfit']

        # Plot 2: RSI
        rsi_col = None
        if 'RSI' in strategy_data.columns:
            rsi_col = 'RSI'
        elif 'RSI_15m' in strategy_data.columns:
            rsi_col = 'RSI_15m'

        if rsi_col:
            formatted_df['2_line_rsi'] = strategy_data[rsi_col]
            # Add RSI overbought/oversold levels
            formatted_df['2_line_rsi_overbought'] = 70
            formatted_df['2_line_rsi_oversold'] = 30

        # Plot 3: Volume (if available)
        if 'volume' in strategy_data.columns:
            formatted_df['3_bar_volume'] = strategy_data['volume']

            # Add the volume moving average if available
            if 'VolumeMA_15m' in strategy_data.columns:
                formatted_df['3_line_volume_ma'] = strategy_data['VolumeMA_15m']

        return formatted_df

    @staticmethod
    def plot_data(df):
        """
        Universal plot function for any formatted data.
        - df: DataFrame with column names in the format "x_type_name" where:
            x = plot number (subplot)
            type = plot type (line, area, scatter, bar, etc.)
            name = descriptive name for the data series
        """
        if df.empty:
            print("No data to plot")
            return

        # Parse all columns
        plot_info = []
        for column in df.columns:
            parts = column.split('_', 2)  # Split into at most 3 parts
            if len(parts) < 3:
                print(f"Warning: Skipping column '{column}' - invalid format. Expected 'x_type_name'")
                continue

            try:
                plot_number = int(parts[0])
                plot_type = parts[1].lower()
                plot_name = parts[2]
                plot_info.append((plot_number, plot_type, plot_name, column))
            except ValueError:
                print(f"Warning: Skipping column '{column}' - invalid plot number")
                continue

        if not plot_info:
            print("No valid columns found for plotting")
            return

        # Group by plot number
        plots = {}
        for plot_num, plot_type, plot_name, column in plot_info:
            if plot_num not in plots:
                plots[plot_num] = []
            plots[plot_num].append((plot_type, plot_name, column))

        # Sort plot numbers
        plot_numbers = sorted(plots.keys())
        n_plots = len(plot_numbers)

        # Create subplots
        fig, axs = plt.subplots(n_plots, 1, figsize=(16, 6 * n_plots), sharex=True)
        if n_plots == 1:
            axs = [axs]  # Ensure axs is always a list

        # Plot each subplot
        for i, plot_num in enumerate(plot_numbers):
            ax = axs[i]
            plot_items = plots[plot_num]

            # Handle the Bollinger Bands area first (needs special handling)
            bb_upper = None
            bb_lower = None

            for plot_type, plot_name, column in plot_items:
                if plot_type == 'area' and 'bb_upper' in plot_name:
                    bb_upper = df[column]
                elif plot_type == 'area' and 'bb_lower' in plot_name:
                    bb_lower = df[column]

            # Plot the Bollinger Bands area if both bounds exist
            if bb_upper is not None and bb_lower is not None:
                ax.fill_between(df.index, bb_upper, bb_lower, alpha=0.2, color='gray', label='Bollinger Bands')

            # Plot the other items
            for plot_type, plot_name, column in plot_items:
                if plot_type == 'area' and ('bb_upper' in plot_name or 'bb_lower' in plot_name):
                    continue  # Already handled above

                data = df[column].dropna()  # Remove NaN values for cleaner plots

                if plot_type == 'line':
                    color = None
                    linestyle = '-'
                    alpha = 1.0

                    # Special styling for different line types
                    if 'overbought' in plot_name:
                        color = 'red'
                        linestyle = '--'
                        alpha = 0.7
                    elif 'oversold' in plot_name:
                        color = 'green'
                        linestyle = '--'
                        alpha = 0.7
                    elif 'stop_loss' in plot_name:
                        color = 'red'
                        linestyle = ':'
                        alpha = 0.8
                    elif 'take_profit' in plot_name:
                        color = 'green'
                        linestyle = ':'
                        alpha = 0.8
                    elif 'sma' in plot_name:
                        color = 'orange'
                        alpha = 0.8
                    elif 'volume_ma' in plot_name:
                        color = 'purple'
                        alpha = 0.7

                    ax.plot(data.index, data, label=plot_name.replace('_', ' ').title(),
                            color=color, linestyle=linestyle, alpha=alpha)

                elif plot_type == 'scatter':
                    color = 'green' if 'buy' in plot_name else 'red' if 'sell' in plot_name else 'blue'
                    marker = '^' if 'buy' in plot_name else 'v' if 'sell' in plot_name else 'o'
                    size = 100 if 'buy' in plot_name or 'sell' in plot_name else 50
                    alpha = 0.8
                    zorder = 5
                    label_name = plot_name.replace('_', ' ').title()

                    # Special styling for different signal types
                    if 'actual_buy' in plot_name:
                        color = 'darkgreen'
                        marker = '^'
                        size = 120
                        alpha = 1.0
                        zorder = 10  # Higher z-order to appear on top
                        label_name = 'Actual Buy Trades'
                    elif 'actual_sell' in plot_name:
                        color = 'darkred'
                        marker = 'v'
                        size = 120
                        alpha = 1.0
                        zorder = 10  # Higher z-order to appear on top
                        label_name = 'Actual Sell Trades'
                    elif 'strategy_buy' in plot_name:
                        color = 'lightgreen'
                        marker = '^'
                        size = 60
                        alpha = 0.6
                        zorder = 3  # Lower z-order to appear behind actual trades
                        label_name = 'Strategy Buy Signals'
                    elif 'strategy_sell' in plot_name:
                        color = 'lightcoral'
                        marker = 'v'
                        size = 60
                        alpha = 0.6
                        zorder = 3  # Lower z-order to appear behind actual trades
                        label_name = 'Strategy Sell Signals'

                    ax.scatter(data.index, data, label=label_name,
                               color=color, marker=marker, s=size, alpha=alpha, zorder=zorder)

                elif plot_type == 'area':
                    ax.fill_between(data.index, data, alpha=0.5, label=plot_name.replace('_', ' ').title())

                elif plot_type == 'bar':
                    ax.bar(data.index, data, alpha=0.7, label=plot_name.replace('_', ' ').title())

                else:
                    print(f"Warning: Plot type '{plot_type}' not supported for column '{column}'")

            # Customize the subplot
            ax.grid(True, alpha=0.3)
            ax.legend()

            # Set titles and labels
            if plot_num == 1:
                ax.set_title('Price Chart with Bollinger Bands and Signals')
                ax.set_ylabel('Price')
            elif plot_num == 2:
                ax.set_title('RSI Indicator')
                ax.set_ylabel('RSI')
                ax.set_ylim(0, 100)
            elif plot_num == 3:
                ax.set_title('Volume')
                ax.set_ylabel('Volume')
            else:
                ax.set_title(f'Plot {plot_num}')

        # Set the x-axis label only on the bottom subplot
        axs[-1].set_xlabel('Time')

        plt.tight_layout()
        plt.show()
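To make the "x_type_name" column convention concrete, a small self-contained sketch (all series names and values here are made up):

import numpy as np
import pandas as pd

from cycles.charts import BacktestCharts

idx = pd.date_range('2024-01-01', periods=100, freq='15min')  # hypothetical index
demo = pd.DataFrame(index=idx)
demo['1_line_close'] = np.linspace(100, 110, 100)        # subplot 1, line
demo['2_line_rsi'] = np.random.uniform(20, 80, 100)      # subplot 2, line
demo['2_line_rsi_overbought'] = 70                       # dashed red via name matching
demo['3_bar_volume'] = np.random.randint(1, 100, 100)    # subplot 3, bars

BacktestCharts.plot_data(demo)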
7
cycles/market_fees.py
Normal file
@@ -0,0 +1,7 @@
import pandas as pd


class MarketFees:
    @staticmethod
    def calculate_okx_taker_maker_fee(amount, is_maker=True) -> float:
        fee_rate = 0.0008 if is_maker else 0.0010
        return amount * fee_rate
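The flat rates above work out to $8 (maker, 0.08%) and $10 (taker, 0.10%) per $10,000 of notional; a quick check:

from cycles.market_fees import MarketFees

print(MarketFees.calculate_okx_taker_maker_fee(10_000, is_maker=True))   # 8.0
print(MarketFees.calculate_okx_taker_maker_fee(10_000, is_maker=False))  # 10.0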
42
cycles/strategies/__init__.py
Normal file
@@ -0,0 +1,42 @@
"""
Strategies Module

This module contains the strategy management system for trading strategies.
It provides a flexible framework for implementing, combining, and managing multiple trading strategies.

Components:
- StrategyBase: Abstract base class for all strategies
- DefaultStrategy: Meta-trend based strategy
- BBRSStrategy: Bollinger Bands + RSI strategy
- StrategyManager: Orchestrates multiple strategies
- StrategySignal: Represents trading signals with confidence levels

Usage:
    from cycles.strategies import StrategyManager, create_strategy_manager

    # Create a strategy manager from config
    strategy_manager = create_strategy_manager(config)

    # Or create individual strategies
    from cycles.strategies import DefaultStrategy, BBRSStrategy
    default_strategy = DefaultStrategy(weight=1.0, params={})
"""

from .base import StrategyBase, StrategySignal
from .default_strategy import DefaultStrategy
from .bbrs_strategy import BBRSStrategy
from .random_strategy import RandomStrategy
from .manager import StrategyManager, create_strategy_manager

__all__ = [
    'StrategyBase',
    'StrategySignal',
    'DefaultStrategy',
    'BBRSStrategy',
    'RandomStrategy',
    'StrategyManager',
    'create_strategy_manager'
]

__version__ = '1.0.0'
__author__ = 'TCP Cycles Team'
250
cycles/strategies/base.py
Normal file
@@ -0,0 +1,250 @@
"""
Base classes for the strategy management system.

This module contains the fundamental building blocks for all trading strategies:
- StrategySignal: Represents trading signals with confidence and metadata
- StrategyBase: Abstract base class that all strategies must inherit from
"""

import pandas as pd
from abc import ABC, abstractmethod
from typing import Dict, Optional, List, Union


class StrategySignal:
    """
    Represents a trading signal from a strategy.

    A signal encapsulates the strategy's recommendation along with a confidence
    level, an optional price target, and additional metadata.

    Attributes:
        signal_type (str): Type of signal - "ENTRY", "EXIT", or "HOLD"
        confidence (float): Confidence level from 0.0 to 1.0
        price (Optional[float]): Optional specific price for the signal
        metadata (Dict): Additional signal data and context

    Example:
        # Entry signal with high confidence
        signal = StrategySignal("ENTRY", confidence=0.8)

        # Exit signal with a stop loss price
        signal = StrategySignal("EXIT", confidence=1.0, price=50000,
                                metadata={"type": "STOP_LOSS"})
    """

    def __init__(self, signal_type: str, confidence: float = 1.0,
                 price: Optional[float] = None, metadata: Optional[Dict] = None):
        """
        Initialize a strategy signal.

        Args:
            signal_type: Type of signal ("ENTRY", "EXIT", "HOLD")
            confidence: Confidence level (0.0 to 1.0)
            price: Optional specific price for the signal
            metadata: Additional signal data and context
        """
        self.signal_type = signal_type
        self.confidence = max(0.0, min(1.0, confidence))  # Clamp to [0, 1]
        self.price = price
        self.metadata = metadata or {}

    def __repr__(self) -> str:
        """String representation of the signal."""
        return (f"StrategySignal(type={self.signal_type}, "
                f"confidence={self.confidence:.2f}, "
                f"price={self.price}, metadata={self.metadata})")


class StrategyBase(ABC):
    """
    Abstract base class for all trading strategies.

    This class defines the interface that all strategies must implement:
    - get_timeframes(): Specify the timeframes the strategy requires
    - initialize(): Set up the strategy with backtester data
    - get_entry_signal(): Generate entry signals
    - get_exit_signal(): Generate exit signals
    - get_confidence(): Optional confidence calculation

    Attributes:
        name (str): Strategy name
        weight (float): Strategy weight for combination
        params (Dict): Strategy parameters
        initialized (bool): Whether the strategy has been initialized
        timeframes_data (Dict): Resampled data for different timeframes

    Example:
        class MyStrategy(StrategyBase):
            def get_timeframes(self):
                return ["15min"]  # This strategy works on 15-minute data

            def initialize(self, backtester):
                # Set up strategy indicators using self.timeframes_data["15min"]
                self.initialized = True

            def get_entry_signal(self, backtester, df_index):
                # Return a StrategySignal based on analysis
                if should_enter:
                    return StrategySignal("ENTRY", confidence=0.7)
                return StrategySignal("HOLD", confidence=0.0)
    """

    def __init__(self, name: str, weight: float = 1.0, params: Optional[Dict] = None):
        """
        Initialize the strategy base.

        Args:
            name: Strategy name/identifier
            weight: Strategy weight for combination (default: 1.0)
            params: Strategy-specific parameters
        """
        self.name = name
        self.weight = weight
        self.params = params or {}
        self.initialized = False
        self.timeframes_data = {}  # Will store resampled data for each timeframe

    def get_timeframes(self) -> List[str]:
        """
        Get the list of timeframes required by this strategy.

        Override this method to specify which timeframes your strategy needs.
        The base class will automatically resample the 1-minute data to these timeframes
        and make them available in self.timeframes_data.

        Returns:
            List[str]: List of timeframe strings (e.g., ["1min", "15min", "1h"])

        Example:
            def get_timeframes(self):
                return ["15min"]  # Strategy needs 15-minute data

            def get_timeframes(self):
                return ["5min", "15min", "1h"]  # Multi-timeframe strategy
        """
        return ["1min"]  # Default to 1-minute data

    def _resample_data(self, original_data: pd.DataFrame) -> None:
        """
        Resample the original 1-minute data to all required timeframes.

        This method is called automatically during initialization to create
        resampled versions of the data for each timeframe the strategy needs.

        Args:
            original_data: Original 1-minute OHLCV data with a DatetimeIndex
        """
        self.timeframes_data = {}

        for timeframe in self.get_timeframes():
            if timeframe == "1min":
                # For 1-minute data, just use the original
                self.timeframes_data[timeframe] = original_data.copy()
            else:
                # Resample to the specified timeframe
                resampled = original_data.resample(timeframe).agg({
                    'open': 'first',
                    'high': 'max',
                    'low': 'min',
                    'close': 'last',
                    'volume': 'sum'
                }).dropna()

                self.timeframes_data[timeframe] = resampled

    def get_data_for_timeframe(self, timeframe: str) -> Optional[pd.DataFrame]:
        """
        Get resampled data for a specific timeframe.

        Args:
            timeframe: Timeframe string (e.g., "15min", "1h")

        Returns:
            pd.DataFrame: Resampled OHLCV data, or None if the timeframe is not available
        """
        return self.timeframes_data.get(timeframe)

    def get_primary_timeframe_data(self) -> pd.DataFrame:
        """
        Get data for the primary (first) timeframe.

        Returns:
            pd.DataFrame: Data for the first timeframe in the get_timeframes() list
        """
        primary_timeframe = self.get_timeframes()[0]
        return self.timeframes_data[primary_timeframe]

    @abstractmethod
    def initialize(self, backtester) -> None:
        """
        Initialize the strategy with backtester data.

        This method is called once before backtesting begins.
        The original 1-minute data will already be resampled to all required timeframes
        and available in self.timeframes_data.

        Strategies should set up indicators, validate data, and
        set self.initialized = True when complete.

        Args:
            backtester: Backtest instance with data and configuration
        """
        pass

    @abstractmethod
    def get_entry_signal(self, backtester, df_index: int) -> StrategySignal:
        """
        Generate an entry signal for the given data index.

        The df_index refers to the index in the backtester's working dataframe,
        which corresponds to the primary timeframe data.

        Args:
            backtester: Backtest instance with current state
            df_index: Current index in the primary timeframe dataframe

        Returns:
            StrategySignal: Entry signal with confidence level
        """
        pass

    @abstractmethod
    def get_exit_signal(self, backtester, df_index: int) -> StrategySignal:
        """
        Generate an exit signal for the given data index.

        The df_index refers to the index in the backtester's working dataframe,
        which corresponds to the primary timeframe data.

        Args:
            backtester: Backtest instance with current state
            df_index: Current index in the primary timeframe dataframe

        Returns:
            StrategySignal: Exit signal with confidence level
        """
        pass

    def get_confidence(self, backtester, df_index: int) -> float:
        """
        Get strategy confidence for the current market state.

        The default implementation returns 1.0. Strategies can override
        this to provide dynamic confidence based on market conditions.

        Args:
            backtester: Backtest instance with current state
            df_index: Current index in the primary timeframe dataframe

        Returns:
            float: Confidence level (0.0 to 1.0)
        """
        return 1.0

    def __repr__(self) -> str:
        """String representation of the strategy."""
        timeframes = self.get_timeframes()
        return (f"{self.__class__.__name__}(name={self.name}, "
                f"weight={self.weight}, timeframes={timeframes}, "
                f"initialized={self.initialized})")
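A minimal concrete subclass, to show the smallest surface a strategy must implement (a toy sketch; `backtester.original_df` is assumed to hold the raw 1-minute data, as BBRSStrategy below uses it):

from cycles.strategies.base import StrategyBase, StrategySignal

class BuyOnceStrategy(StrategyBase):
    """Toy strategy that enters on the first bar and never exits; illustration only."""

    def get_timeframes(self):
        return ["15min"]  # the base class resamples the 1-minute data to this

    def initialize(self, backtester):
        self._resample_data(backtester.original_df)  # assumption: backtester exposes original_df
        self.initialized = True

    def get_entry_signal(self, backtester, df_index):
        if df_index == 0:
            return StrategySignal("ENTRY", confidence=1.0)
        return StrategySignal("HOLD", confidence=0.0)

    def get_exit_signal(self, backtester, df_index):
        return StrategySignal("HOLD", confidence=0.0)

strategy = BuyOnceStrategy("buy_once", weight=1.0)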
344
cycles/strategies/bbrs_strategy.py
Normal file
@@ -0,0 +1,344 @@
"""
Bollinger Bands + RSI Strategy (BBRS)

This module implements a trading strategy that combines Bollinger Bands
and RSI indicators with market regime detection. The strategy adapts its parameters
based on whether the market is trending or moving sideways.

Key Features:
- Dynamic parameter adjustment based on market regime
- Bollinger Band squeeze detection
- RSI overbought/oversold conditions
- Market regime-specific thresholds
- Multi-timeframe analysis support
"""

import pandas as pd
import numpy as np
import logging
from typing import Tuple, Optional, List

from .base import StrategyBase, StrategySignal


class BBRSStrategy(StrategyBase):
    """
    Bollinger Bands + RSI Strategy implementation.

    This strategy uses Bollinger Bands and RSI indicators with market regime detection
    to generate trading signals. It adapts its parameters based on whether the market
    is in a trending or sideways regime.

    The strategy works with 1-minute data as input and lets the underlying Strategy class
    handle internal resampling to the timeframes it needs (typically 15min and 1h).
    Stop-loss execution uses 1-minute precision.

    Parameters:
        bb_width (float): Bollinger Band width threshold (default: 0.05)
        bb_period (int): Bollinger Band period (default: 20)
        rsi_period (int): RSI calculation period (default: 14)
        trending_rsi_threshold (list): RSI thresholds for a trending market [low, high]
        trending_bb_multiplier (float): BB multiplier for a trending market
        sideways_rsi_threshold (list): RSI thresholds for a sideways market [low, high]
        sideways_bb_multiplier (float): BB multiplier for a sideways market
        strategy_name (str): Strategy implementation name ("MarketRegimeStrategy" or "CryptoTradingStrategy")
        SqueezeStrategy (bool): Enable the squeeze strategy
        stop_loss_pct (float): Stop loss percentage (default: 0.05)

    Example:
        params = {
            "bb_width": 0.05,
            "bb_period": 20,
            "rsi_period": 14,
            "strategy_name": "MarketRegimeStrategy",
            "SqueezeStrategy": True
        }
        strategy = BBRSStrategy(weight=1.0, params=params)
    """

    def __init__(self, weight: float = 1.0, params: Optional[dict] = None):
        """
        Initialize the BBRS strategy.

        Args:
            weight: Strategy weight for combination (default: 1.0)
            params: Strategy parameters for Bollinger Bands and RSI
        """
        super().__init__("bbrs", weight, params)

    def get_timeframes(self) -> List[str]:
        """
        Get the timeframes required by the BBRS strategy.

        The BBRS strategy uses 1-minute data as input and lets the Strategy class
        handle internal resampling to the timeframes it needs (15min, 1h, etc.).
        We still include 1min for stop-loss precision.

        Returns:
            List[str]: List of timeframes needed for the strategy
        """
        # The BBRS strategy works with 1-minute data and lets the Strategy class handle resampling
        return ["1min"]

    def initialize(self, backtester) -> None:
        """
        Initialize the BBRS strategy with signal processing.

        Sets up the strategy by:
        1. Using 1-minute data directly (the Strategy class handles internal resampling)
        2. Running the BBRS strategy processing on 1-minute data
        3. Creating signals aligned with backtester expectations

        Args:
            backtester: Backtest instance with OHLCV data
        """
        # Resample to get the 1-minute data (which should be the original data)
        self._resample_data(backtester.original_df)

        # Get 1-minute data for strategy processing - the Strategy class will handle internal resampling
        min1_data = self.get_data_for_timeframe("1min")

        # Initialize empty signal series for backtester compatibility
        # Note: These will be populated after strategy processing
        backtester.strategies["buy_signals"] = pd.Series(False, index=range(len(min1_data)))
        backtester.strategies["sell_signals"] = pd.Series(False, index=range(len(min1_data)))
        backtester.strategies["stop_loss_pct"] = self.params.get("stop_loss_pct", 0.05)
        backtester.strategies["primary_timeframe"] = "1min"

        # Run strategy processing on 1-minute data
        self._run_strategy_processing(backtester)

        self.initialized = True

    def _run_strategy_processing(self, backtester) -> None:
        """
        Run the actual BBRS strategy processing.

        Uses the Strategy class from cycles.Analysis.strategies to process
        the 1-minute data. The Strategy class will handle internal resampling
        to the timeframes it needs (15min, 1h, etc.) and generate buy/sell signals.

        Args:
            backtester: Backtest instance with timeframes_data available
        """
        from cycles.Analysis.bb_rsi import BollingerBandsStrategy

        # Get 1-minute data for strategy processing - let the Strategy class handle resampling
        strategy_data = self.get_data_for_timeframe("1min")

        # Configure strategy parameters with defaults
        config_strategy = {
            "bb_width": self.params.get("bb_width", 0.05),
            "bb_period": self.params.get("bb_period", 20),
            "rsi_period": self.params.get("rsi_period", 14),
            "trending": {
                "rsi_threshold": self.params.get("trending_rsi_threshold", [30, 70]),
                "bb_std_dev_multiplier": self.params.get("trending_bb_multiplier", 2.5),
            },
            "sideways": {
                "rsi_threshold": self.params.get("sideways_rsi_threshold", [40, 60]),
                "bb_std_dev_multiplier": self.params.get("sideways_bb_multiplier", 1.8),
            },
            "strategy_name": self.params.get("strategy_name", "MarketRegimeStrategy"),
            "SqueezeStrategy": self.params.get("SqueezeStrategy", True)
        }

        # Run strategy processing on 1-minute data - the Strategy class handles internal resampling
        strategy = BollingerBandsStrategy(config=config_strategy, logging=logging)
        processed_data = strategy.run(strategy_data, config_strategy["strategy_name"])

        # Store the processed data for plotting and analysis
        backtester.processed_data = processed_data

        if processed_data.empty:
            # If strategy processing failed, keep the empty signals
            return

        # Extract signals from the processed data
        buy_signals_raw = processed_data.get('BuySignal', pd.Series(False, index=processed_data.index)).astype(bool)
        sell_signals_raw = processed_data.get('SellSignal', pd.Series(False, index=processed_data.index)).astype(bool)

        # The processed_data will be on whatever timeframe the Strategy class outputs.
        # We need to map these signals back to 1-minute resolution for backtesting.
        original_1min_data = self.get_data_for_timeframe("1min")

        # Reindex the signals to 1-minute resolution using forward-fill
        buy_signals_1min = buy_signals_raw.reindex(original_1min_data.index, method='ffill').fillna(False)
        sell_signals_1min = sell_signals_raw.reindex(original_1min_data.index, method='ffill').fillna(False)

        # Convert to an integer index to match backtester expectations
        backtester.strategies["buy_signals"] = pd.Series(buy_signals_1min.values, index=range(len(buy_signals_1min)))
        backtester.strategies["sell_signals"] = pd.Series(sell_signals_1min.values, index=range(len(sell_signals_1min)))

    def get_entry_signal(self, backtester, df_index: int) -> StrategySignal:
        """
        Generate an entry signal based on BBRS buy signals.

        Entry occurs when the BBRS strategy processing has generated
        a buy signal based on Bollinger Bands and RSI conditions on
        the primary timeframe.

        Args:
            backtester: Backtest instance with current state
            df_index: Current index in the primary timeframe dataframe

        Returns:
            StrategySignal: Entry signal if the buy condition is met, hold otherwise
        """
        if not self.initialized:
            return StrategySignal("HOLD", confidence=0.0)

        if df_index >= len(backtester.strategies["buy_signals"]):
            return StrategySignal("HOLD", confidence=0.0)

        if backtester.strategies["buy_signals"].iloc[df_index]:
            # High confidence for BBRS buy signals
            confidence = self._calculate_signal_confidence(backtester, df_index, "entry")
            return StrategySignal("ENTRY", confidence=confidence)

        return StrategySignal("HOLD", confidence=0.0)

    def get_exit_signal(self, backtester, df_index: int) -> StrategySignal:
        """
        Generate an exit signal based on BBRS sell signals or the stop loss.

        Exit occurs when:
        1. The BBRS strategy generates a sell signal
        2. The stop loss is triggered based on price movement

        Args:
            backtester: Backtest instance with current state
            df_index: Current index in the primary timeframe dataframe

        Returns:
            StrategySignal: Exit signal with type and price, or a hold signal
        """
        if not self.initialized:
            return StrategySignal("HOLD", confidence=0.0)

        if df_index >= len(backtester.strategies["sell_signals"]):
            return StrategySignal("HOLD", confidence=0.0)

        # Check for a sell signal
        if backtester.strategies["sell_signals"].iloc[df_index]:
            confidence = self._calculate_signal_confidence(backtester, df_index, "exit")
            return StrategySignal("EXIT", confidence=confidence,
                                  metadata={"type": "SELL_SIGNAL"})

        # Check for a stop loss using 1-minute data for precision
        stop_loss_result, sell_price = self._check_stop_loss(backtester)
        if stop_loss_result:
|
||||||
|
return StrategySignal("EXIT", confidence=1.0, price=sell_price,
|
||||||
|
metadata={"type": "STOP_LOSS"})
|
||||||
|
|
||||||
|
return StrategySignal("HOLD", confidence=0.0)
|
||||||
|
|
||||||
|
def get_confidence(self, backtester, df_index: int) -> float:
|
||||||
|
"""
|
||||||
|
Get strategy confidence based on signal strength and market conditions.
|
||||||
|
|
||||||
|
Confidence can be enhanced by analyzing multiple timeframes and
|
||||||
|
market regime consistency.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backtester: Backtest instance with current state
|
||||||
|
df_index: Current index in the primary timeframe dataframe
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: Confidence level (0.0 to 1.0)
|
||||||
|
"""
|
||||||
|
if not self.initialized:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# Check for active signals
|
||||||
|
has_buy_signal = (df_index < len(backtester.strategies["buy_signals"]) and
|
||||||
|
backtester.strategies["buy_signals"].iloc[df_index])
|
||||||
|
has_sell_signal = (df_index < len(backtester.strategies["sell_signals"]) and
|
||||||
|
backtester.strategies["sell_signals"].iloc[df_index])
|
||||||
|
|
||||||
|
if has_buy_signal or has_sell_signal:
|
||||||
|
signal_type = "entry" if has_buy_signal else "exit"
|
||||||
|
return self._calculate_signal_confidence(backtester, df_index, signal_type)
|
||||||
|
|
||||||
|
# Moderate confidence during neutral periods
|
||||||
|
return 0.5
|
||||||
|
|
||||||
|
def _calculate_signal_confidence(self, backtester, df_index: int, signal_type: str) -> float:
|
||||||
|
"""
|
||||||
|
Calculate confidence level for a signal based on multiple factors.
|
||||||
|
|
||||||
|
Can consider multiple timeframes, market regime, volatility, etc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backtester: Backtest instance
|
||||||
|
df_index: Current index
|
||||||
|
signal_type: "entry" or "exit"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: Confidence level (0.0 to 1.0)
|
||||||
|
"""
|
||||||
|
base_confidence = 1.0
|
||||||
|
|
||||||
|
# TODO: Implement multi-timeframe confirmation
|
||||||
|
# For now, return high confidence for primary signals
|
||||||
|
# Future enhancements could include:
|
||||||
|
# - Checking confirmation from additional timeframes
|
||||||
|
# - Analyzing market regime consistency
|
||||||
|
# - Considering volatility levels
|
||||||
|
# - RSI and BB position analysis
|
||||||
|
|
||||||
|
return base_confidence
|
||||||
|
|
||||||
|
def _check_stop_loss(self, backtester) -> Tuple[bool, Optional[float]]:
|
||||||
|
"""
|
||||||
|
Check if stop loss is triggered using 1-minute data for precision.
|
||||||
|
|
||||||
|
Uses 1-minute data regardless of primary timeframe to ensure
|
||||||
|
accurate stop loss execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backtester: Backtest instance with current trade state
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bool, Optional[float]]: (stop_loss_triggered, sell_price)
|
||||||
|
"""
|
||||||
|
# Calculate stop loss price
|
||||||
|
stop_price = backtester.entry_price * (1 - backtester.strategies["stop_loss_pct"])
|
||||||
|
|
||||||
|
# Use 1-minute data for precise stop loss checking
|
||||||
|
min1_data = self.get_data_for_timeframe("1min")
|
||||||
|
if min1_data is None:
|
||||||
|
# Fallback to original_df if 1min timeframe not available
|
||||||
|
min1_data = backtester.original_df if hasattr(backtester, 'original_df') else backtester.min1_df
|
||||||
|
|
||||||
|
min1_index = min1_data.index
|
||||||
|
|
||||||
|
# Find data range from entry to current time
|
||||||
|
start_candidates = min1_index[min1_index >= backtester.entry_time]
|
||||||
|
if len(start_candidates) == 0:
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
backtester.current_trade_min1_start_idx = start_candidates[0]
|
||||||
|
end_candidates = min1_index[min1_index <= backtester.current_date]
|
||||||
|
|
||||||
|
if len(end_candidates) == 0:
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
backtester.current_min1_end_idx = end_candidates[-1]
|
||||||
|
|
||||||
|
# Check if any candle in the range triggered stop loss
|
||||||
|
min1_slice = min1_data.loc[backtester.current_trade_min1_start_idx:backtester.current_min1_end_idx]
|
||||||
|
|
||||||
|
if (min1_slice['low'] <= stop_price).any():
|
||||||
|
# Find the first candle that triggered stop loss
|
||||||
|
stop_candle = min1_slice[min1_slice['low'] <= stop_price].iloc[0]
|
||||||
|
|
||||||
|
# Use open price if it gapped below stop, otherwise use stop price
|
||||||
|
if stop_candle['open'] < stop_price:
|
||||||
|
sell_price = stop_candle['open']
|
||||||
|
else:
|
||||||
|
sell_price = stop_price
|
||||||
|
|
||||||
|
return True, sell_price
|
||||||
|
|
||||||
|
return False, None
|
||||||
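The reindex-with-forward-fill step above is what keeps coarse-timeframe signals usable on 1-minute bars: a decision made on, say, a 15-minute bar stays active for every 1-minute bar that follows it until the next strategy bar overrides it. A minimal, self-contained sketch of that behavior on synthetic data (not part of the diff):

import pandas as pd

# Signals produced on a 15-minute grid (synthetic example)
idx_15m = pd.date_range("2024-01-01 00:00", periods=3, freq="15min")
buy_15m = pd.Series([False, True, False], index=idx_15m)

# Target 1-minute grid covering the same span
idx_1m = pd.date_range("2024-01-01 00:00", periods=45, freq="min")

# Forward-fill: each 15-minute decision is repeated on the 1-minute bars
# that follow it, until the next 15-minute bar takes over.
buy_1m = buy_15m.reindex(idx_1m, method="ffill").fillna(False)

assert buy_1m.loc["2024-01-01 00:15":"2024-01-01 00:29"].all()  # bar 2 stays active
assert not buy_1m.loc["2024-01-01 00:30":].any()                # bar 3 clears it

Note that a single True on the coarse grid becomes True on every minute of that bar, so the entry check in get_entry_signal can fire on any of those minutes.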
254  cycles/strategies/default_strategy.py  Normal file
@@ -0,0 +1,254 @@
"""
Default Meta-Trend Strategy

This module implements the default trading strategy based on meta-trend analysis
using multiple Supertrend indicators. The strategy enters when trends align
and exits on trend reversal or stop loss.

The meta-trend is calculated by comparing three Supertrend indicators:
- Entry: When meta-trend changes from != 1 to == 1
- Exit: When meta-trend changes to -1 or stop loss is triggered
"""

import numpy as np
from typing import Tuple, Optional, List

from .base import StrategyBase, StrategySignal


class DefaultStrategy(StrategyBase):
    """
    Default meta-trend strategy implementation.

    This strategy uses multiple Supertrend indicators to determine market direction.
    It generates entry signals when all three Supertrend indicators align in an
    upward direction, and exit signals when they reverse or stop loss is triggered.

    The strategy works best on 15-minute timeframes but can be configured for other timeframes.

    Parameters:
        stop_loss_pct (float): Stop loss percentage (default: 0.03)
        timeframe (str): Preferred timeframe for analysis (default: "15min")

    Example:
        strategy = DefaultStrategy(weight=1.0, params={"stop_loss_pct": 0.05, "timeframe": "15min"})
    """

    def __init__(self, weight: float = 1.0, params: Optional[dict] = None):
        """
        Initialize the default strategy.

        Args:
            weight: Strategy weight for combination (default: 1.0)
            params: Strategy parameters including stop_loss_pct and timeframe
        """
        super().__init__("default", weight, params)

    def get_timeframes(self) -> List[str]:
        """
        Get the timeframes required by the default strategy.

        The default strategy works on a single timeframe (typically 15min)
        but also needs 1min data for precise stop-loss execution.

        Returns:
            List[str]: List containing the primary timeframe and 1min for stop-loss
        """
        primary_timeframe = self.params.get("timeframe", "15min")

        # Always include 1min for stop-loss precision, avoid duplicates
        timeframes = [primary_timeframe]
        if primary_timeframe != "1min":
            timeframes.append("1min")

        return timeframes

    def initialize(self, backtester) -> None:
        """
        Initialize meta-trend calculation using Supertrend indicators.

        Calculates the meta-trend by comparing three Supertrend indicators.
        When all three agree on direction, the meta-trend follows that direction.
        Otherwise, the meta-trend is neutral (0).

        Args:
            backtester: Backtest instance with OHLCV data
        """
        from cycles.Analysis.supertrend import Supertrends

        # First, resample the original 1-minute data to the required timeframes
        self._resample_data(backtester.original_df)

        # Get the primary timeframe data for strategy calculations
        primary_timeframe = self.get_timeframes()[0]
        strategy_data = self.get_data_for_timeframe(primary_timeframe)

        # Calculate Supertrend indicators on the primary timeframe
        supertrends = Supertrends(strategy_data, verbose=False)
        supertrend_results_list = supertrends.calculate_supertrend_indicators()

        # Extract trend arrays from each Supertrend
        trends = [st['results']['trend'] for st in supertrend_results_list]
        trends_arr = np.stack(trends, axis=1)

        # Calculate meta-trend: all three must agree for a direction signal
        meta_trend = np.where(
            (trends_arr[:, 0] == trends_arr[:, 1]) & (trends_arr[:, 1] == trends_arr[:, 2]),
            trends_arr[:, 0],
            0  # Neutral when trends don't agree
        )

        # Store in backtester for access during trading
        # Note: backtester.df should now be using our primary timeframe
        backtester.strategies["meta_trend"] = meta_trend
        backtester.strategies["stop_loss_pct"] = self.params.get("stop_loss_pct", 0.03)
        backtester.strategies["primary_timeframe"] = primary_timeframe

        self.initialized = True

    def get_entry_signal(self, backtester, df_index: int) -> StrategySignal:
        """
        Generate entry signal based on meta-trend direction change.

        Entry occurs when the meta-trend changes from != 1 to == 1, indicating
        all Supertrend indicators now agree on an upward direction.

        Args:
            backtester: Backtest instance with current state
            df_index: Current index in the primary timeframe dataframe

        Returns:
            StrategySignal: Entry signal if trend aligns, hold signal otherwise
        """
        if not self.initialized:
            return StrategySignal("HOLD", 0.0)

        if df_index < 2:  # shifting one index to prevent lookahead bias
            return StrategySignal("HOLD", 0.0)

        # Check for meta-trend entry condition
        prev_trend = backtester.strategies["meta_trend"][df_index - 2]
        curr_trend = backtester.strategies["meta_trend"][df_index - 1]

        if prev_trend != 1 and curr_trend == 1:
            # Strong confidence when all indicators align for entry
            return StrategySignal("ENTRY", confidence=1.0)

        return StrategySignal("HOLD", confidence=0.0)

    def get_exit_signal(self, backtester, df_index: int) -> StrategySignal:
        """
        Generate exit signal based on meta-trend reversal or stop loss.

        Exit occurs when:
        1. Meta-trend changes to -1 (trend reversal)
        2. Stop loss is triggered based on price movement

        Args:
            backtester: Backtest instance with current state
            df_index: Current index in the primary timeframe dataframe

        Returns:
            StrategySignal: Exit signal with type and price, or hold signal
        """
        if not self.initialized:
            return StrategySignal("HOLD", 0.0)

        if df_index < 1:
            return StrategySignal("HOLD", 0.0)

        # Exit when the meta-trend changes to -1 (trend reversal),
        # including a direct flip from 1 to -1
        prev_trend = backtester.strategies["meta_trend"][df_index - 1]
        curr_trend = backtester.strategies["meta_trend"][df_index]

        if prev_trend != -1 and curr_trend == -1:
            return StrategySignal("EXIT", confidence=1.0,
                                  metadata={"type": "META_TREND_EXIT_SIGNAL"})

        # Check for stop loss using 1-minute data for precision
        stop_loss_result, sell_price = self._check_stop_loss(backtester)
        if stop_loss_result:
            return StrategySignal("EXIT", confidence=1.0, price=sell_price,
                                  metadata={"type": "STOP_LOSS"})

        return StrategySignal("HOLD", confidence=0.0)

    def get_confidence(self, backtester, df_index: int) -> float:
        """
        Get strategy confidence based on meta-trend strength.

        Higher confidence when the meta-trend is strongly directional,
        lower confidence during neutral periods.

        Args:
            backtester: Backtest instance with current state
            df_index: Current index in the primary timeframe dataframe

        Returns:
            float: Confidence level (0.0 to 1.0)
        """
        if not self.initialized or df_index >= len(backtester.strategies["meta_trend"]):
            return 0.0

        curr_trend = backtester.strategies["meta_trend"][df_index]

        # High confidence for strong directional signals
        if curr_trend == 1 or curr_trend == -1:
            return 1.0

        # Low confidence for neutral trend
        return 0.3

    def _check_stop_loss(self, backtester) -> Tuple[bool, Optional[float]]:
        """
        Check if stop loss is triggered based on price movement.

        Uses 1-minute data for precise stop loss checking regardless of
        the primary timeframe used for strategy signals.

        Args:
            backtester: Backtest instance with current trade state

        Returns:
            Tuple[bool, Optional[float]]: (stop_loss_triggered, sell_price)
        """
        # Calculate stop loss price
        stop_price = backtester.entry_price * (1 - backtester.strategies["stop_loss_pct"])

        # Use 1-minute data for precise stop loss checking
        min1_data = self.get_data_for_timeframe("1min")
        if min1_data is None:
            # Fallback to original_df if 1min timeframe not available
            min1_data = backtester.original_df if hasattr(backtester, 'original_df') else backtester.min1_df

        min1_index = min1_data.index

        # Find data range from entry to current time
        start_candidates = min1_index[min1_index >= backtester.entry_time]
        if len(start_candidates) == 0:
            return False, None

        backtester.current_trade_min1_start_idx = start_candidates[0]
        end_candidates = min1_index[min1_index <= backtester.current_date]

        if len(end_candidates) == 0:
            return False, None

        backtester.current_min1_end_idx = end_candidates[-1]

        # Check if any candle in the range triggered stop loss
        min1_slice = min1_data.loc[backtester.current_trade_min1_start_idx:backtester.current_min1_end_idx]

        if (min1_slice['low'] <= stop_price).any():
            # Find the first candle that triggered stop loss
            stop_candle = min1_slice[min1_slice['low'] <= stop_price].iloc[0]

            # Use open price if it gapped below stop, otherwise use stop price
            if stop_candle['open'] < stop_price:
                sell_price = stop_candle['open']
            else:
                sell_price = stop_price

            return True, sell_price

        return False, None
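The np.where agreement rule in initialize reduces three Supertrend direction arrays to a single vector: +/-1 where all three indicators agree, 0 otherwise. A small numeric sketch of that reduction, using synthetic trend arrays:

import numpy as np

# Direction of each of the three Supertrends per bar (synthetic): 1 = up, -1 = down
trends = [
    np.array([ 1,  1, -1, -1, 1]),
    np.array([ 1,  1, -1,  1, 1]),
    np.array([ 1, -1, -1, -1, 1]),
]
trends_arr = np.stack(trends, axis=1)  # shape (bars, 3)

meta_trend = np.where(
    (trends_arr[:, 0] == trends_arr[:, 1]) & (trends_arr[:, 1] == trends_arr[:, 2]),
    trends_arr[:, 0],
    0,  # neutral whenever the indicators disagree
)
print(meta_trend)  # [ 1  0 -1  0  1]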
397  cycles/strategies/manager.py  Normal file
@@ -0,0 +1,397 @@
"""
Strategy Manager

This module contains the StrategyManager class that orchestrates multiple trading strategies
and combines their signals using configurable aggregation rules.

The StrategyManager supports various combination methods for entry and exit signals:
- Entry: any, all, majority, weighted_consensus
- Exit: any, all, priority (with stop loss prioritization)
"""

from typing import Dict, List, Tuple, Optional
import logging

from .base import StrategyBase, StrategySignal
from .default_strategy import DefaultStrategy
from .bbrs_strategy import BBRSStrategy
from .random_strategy import RandomStrategy


class StrategyManager:
    """
    Manages multiple strategies and combines their signals.

    The StrategyManager loads multiple strategies from configuration,
    initializes them with backtester data, and combines their signals
    using configurable aggregation rules.

    Attributes:
        strategies (List[StrategyBase]): List of loaded strategies
        combination_rules (Dict): Rules for combining signals
        initialized (bool): Whether the manager has been initialized

    Example:
        config = {
            "strategies": [
                {"name": "default", "weight": 0.6, "params": {}},
                {"name": "bbrs", "weight": 0.4, "params": {"bb_width": 0.05}}
            ],
            "combination_rules": {
                "entry": "weighted_consensus",
                "exit": "any",
                "min_confidence": 0.6
            }
        }
        manager = StrategyManager(config["strategies"], config["combination_rules"])
    """

    def __init__(self, strategies_config: List[Dict], combination_rules: Optional[Dict] = None):
        """
        Initialize the strategy manager.

        Args:
            strategies_config: List of strategy configurations
            combination_rules: Rules for combining signals
        """
        self.strategies = self._load_strategies(strategies_config)
        self.combination_rules = combination_rules or {
            "entry": "weighted_consensus",
            "exit": "any",
            "min_confidence": 0.5
        }
        self.initialized = False

    def _load_strategies(self, strategies_config: List[Dict]) -> List[StrategyBase]:
        """
        Load strategies from configuration.

        Creates strategy instances based on configuration and registers
        them with the manager. Supports extensible strategy registration.

        Args:
            strategies_config: List of strategy configurations

        Returns:
            List[StrategyBase]: List of instantiated strategies

        Raises:
            ValueError: If an unknown strategy name is specified
        """
        strategies = []

        for config in strategies_config:
            name = config.get("name", "").lower()
            weight = config.get("weight", 1.0)
            params = config.get("params", {})

            if name == "default":
                strategies.append(DefaultStrategy(weight, params))
            elif name == "bbrs":
                strategies.append(BBRSStrategy(weight, params))
            elif name == "random":
                strategies.append(RandomStrategy(weight, params))
            else:
                raise ValueError(f"Unknown strategy: {name}. "
                                 "Available strategies: default, bbrs, random")

        return strategies

    def initialize(self, backtester) -> None:
        """
        Initialize all strategies with backtester data.

        Calls the initialize method on each strategy, allowing them
        to set up indicators, validate data, and prepare for trading.
        Each strategy will handle its own timeframe resampling.

        Args:
            backtester: Backtest instance with OHLCV data
        """
        for strategy in self.strategies:
            try:
                strategy.initialize(backtester)

                # Log strategy timeframe information
                timeframes = strategy.get_timeframes()
                logging.info(f"Initialized strategy: {strategy.name} with timeframes: {timeframes}")

            except Exception as e:
                logging.error(f"Failed to initialize strategy {strategy.name}: {e}")
                raise

        self.initialized = True
        logging.info(f"Strategy manager initialized with {len(self.strategies)} strategies")

        # Log a summary of all timeframes being used
        all_timeframes = set()
        for strategy in self.strategies:
            all_timeframes.update(strategy.get_timeframes())
        logging.info(f"Total unique timeframes in use: {sorted(all_timeframes)}")

    def get_entry_signal(self, backtester, df_index: int) -> bool:
        """
        Get the combined entry signal from all strategies.

        Collects entry signals from all strategies and combines them
        according to the configured combination rules.

        Args:
            backtester: Backtest instance with current state
            df_index: Current index in the dataframe

        Returns:
            bool: True if the combined signal suggests entry, False otherwise
        """
        if not self.initialized:
            return False

        # Collect signals from all strategies
        signals = {}
        for strategy in self.strategies:
            try:
                signal = strategy.get_entry_signal(backtester, df_index)
                signals[strategy.name] = {
                    "signal": signal,
                    "weight": strategy.weight,
                    "confidence": signal.confidence
                }
            except Exception as e:
                logging.warning(f"Strategy {strategy.name} entry signal failed: {e}")
                signals[strategy.name] = {
                    "signal": StrategySignal("HOLD", 0.0),
                    "weight": strategy.weight,
                    "confidence": 0.0
                }

        return self._combine_entry_signals(signals)

    def get_exit_signal(self, backtester, df_index: int) -> Tuple[Optional[str], Optional[float]]:
        """
        Get the combined exit signal from all strategies.

        Collects exit signals from all strategies and combines them
        according to the configured combination rules.

        Args:
            backtester: Backtest instance with current state
            df_index: Current index in the dataframe

        Returns:
            Tuple[Optional[str], Optional[float]]: (exit_type, exit_price) or (None, None)
        """
        if not self.initialized:
            return None, None

        # Collect signals from all strategies
        signals = {}
        for strategy in self.strategies:
            try:
                signal = strategy.get_exit_signal(backtester, df_index)
                signals[strategy.name] = {
                    "signal": signal,
                    "weight": strategy.weight,
                    "confidence": signal.confidence
                }
            except Exception as e:
                logging.warning(f"Strategy {strategy.name} exit signal failed: {e}")
                signals[strategy.name] = {
                    "signal": StrategySignal("HOLD", 0.0),
                    "weight": strategy.weight,
                    "confidence": 0.0
                }

        return self._combine_exit_signals(signals)

    def _combine_entry_signals(self, signals: Dict) -> bool:
        """
        Combine entry signals based on the combination rules.

        Supports multiple combination methods:
        - any: Enter if ANY strategy signals entry
        - all: Enter only if ALL strategies signal entry
        - majority: Enter if the majority of strategies signal entry
        - weighted_consensus: Enter based on weighted average confidence

        Args:
            signals: Dictionary of strategy signals with weights and confidence

        Returns:
            bool: Combined entry decision
        """
        method = self.combination_rules.get("entry", "weighted_consensus")
        min_confidence = self.combination_rules.get("min_confidence", 0.5)

        # Filter for entry signals above minimum confidence
        entry_signals = [
            s for s in signals.values()
            if s["signal"].signal_type == "ENTRY" and s["signal"].confidence >= min_confidence
        ]

        if not entry_signals:
            return False

        if method == "any":
            # Enter if any strategy signals entry
            return len(entry_signals) > 0

        elif method == "all":
            # Enter only if all strategies signal entry
            return len(entry_signals) == len(self.strategies)

        elif method == "majority":
            # Enter if the majority of strategies signal entry
            return len(entry_signals) > len(self.strategies) / 2

        elif method == "weighted_consensus":
            # Enter based on weighted average confidence
            total_weight = sum(s["weight"] for s in entry_signals)
            if total_weight == 0:
                return False

            weighted_confidence = sum(
                s["signal"].confidence * s["weight"]
                for s in entry_signals
            ) / total_weight

            return weighted_confidence >= min_confidence

        else:
            logging.warning(f"Unknown entry combination method: {method}, using 'any'")
            return len(entry_signals) > 0

    def _combine_exit_signals(self, signals: Dict) -> Tuple[Optional[str], Optional[float]]:
        """
        Combine exit signals based on the combination rules.

        Supports multiple combination methods:
        - any: Exit if ANY strategy signals exit (recommended for risk management)
        - all: Exit only if ALL strategies agree on exit
        - priority: Exit based on priority order (STOP_LOSS > SELL_SIGNAL > others)

        Args:
            signals: Dictionary of strategy signals with weights and confidence

        Returns:
            Tuple[Optional[str], Optional[float]]: (exit_type, exit_price) or (None, None)
        """
        method = self.combination_rules.get("exit", "any")

        # Filter for exit signals
        exit_signals = [
            s for s in signals.values()
            if s["signal"].signal_type == "EXIT"
        ]

        if not exit_signals:
            return None, None

        if method == "any":
            # Exit if any strategy signals exit (first one found)
            for signal_data in exit_signals:
                signal = signal_data["signal"]
                exit_type = signal.metadata.get("type", "EXIT")
                return exit_type, signal.price

        elif method == "all":
            # Exit only if all strategies agree on exit
            if len(exit_signals) == len(self.strategies):
                signal = exit_signals[0]["signal"]
                exit_type = signal.metadata.get("type", "EXIT")
                return exit_type, signal.price

        elif method == "priority":
            # Priority order: STOP_LOSS > SELL_SIGNAL > others
            stop_loss_signals = [
                s for s in exit_signals
                if s["signal"].metadata.get("type") == "STOP_LOSS"
            ]
            if stop_loss_signals:
                signal = stop_loss_signals[0]["signal"]
                return "STOP_LOSS", signal.price

            sell_signals = [
                s for s in exit_signals
                if s["signal"].metadata.get("type") == "SELL_SIGNAL"
            ]
            if sell_signals:
                signal = sell_signals[0]["signal"]
                return "SELL_SIGNAL", signal.price

            # Return the first available exit signal
            signal = exit_signals[0]["signal"]
            exit_type = signal.metadata.get("type", "EXIT")
            return exit_type, signal.price

        else:
            logging.warning(f"Unknown exit combination method: {method}, using 'any'")
            # Fallback to the 'any' method
            signal = exit_signals[0]["signal"]
            exit_type = signal.metadata.get("type", "EXIT")
            return exit_type, signal.price

        return None, None

    def get_strategy_summary(self) -> Dict:
        """
        Get a summary of the loaded strategies and their configuration.

        Returns:
            Dict: Summary of strategies, weights, combination rules, and timeframes
        """
        return {
            "strategies": [
                {
                    "name": strategy.name,
                    "weight": strategy.weight,
                    "params": strategy.params,
                    "timeframes": strategy.get_timeframes(),
                    "initialized": strategy.initialized
                }
                for strategy in self.strategies
            ],
            "combination_rules": self.combination_rules,
            "total_strategies": len(self.strategies),
            "initialized": self.initialized,
            "all_timeframes": list(set().union(*[strategy.get_timeframes() for strategy in self.strategies]))
        }

    def __repr__(self) -> str:
        """String representation of the strategy manager."""
        strategy_names = [s.name for s in self.strategies]
        return (f"StrategyManager(strategies={strategy_names}, "
                f"initialized={self.initialized})")


def create_strategy_manager(config: Dict) -> StrategyManager:
    """
    Factory function to create a StrategyManager from configuration.

    Provides a convenient way to create a StrategyManager instance
    from a configuration dictionary.

    Args:
        config: Configuration dictionary with strategies and combination_rules

    Returns:
        StrategyManager: Configured strategy manager instance

    Example:
        config = {
            "strategies": [
                {"name": "default", "weight": 1.0, "params": {}}
            ],
            "combination_rules": {
                "entry": "any",
                "exit": "any"
            }
        }
        manager = create_strategy_manager(config)
    """
    strategies_config = config.get("strategies", [])
    combination_rules = config.get("combination_rules", {})

    if not strategies_config:
        raise ValueError("No strategies specified in configuration")

    return StrategyManager(strategies_config, combination_rules)
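Under weighted_consensus, the combined confidence is the weight-normalized average over the strategies that voted ENTRY at or above min_confidence. A worked example using the two-strategy configuration from the class docstring (plain dicts stand in for StrategySignal to keep the sketch self-contained):

# default (weight 0.6) signals ENTRY with confidence 0.9,
# bbrs (weight 0.4) signals ENTRY with confidence 0.7, min_confidence = 0.6
entry_signals = [
    {"weight": 0.6, "confidence": 0.9},
    {"weight": 0.4, "confidence": 0.7},
]
total_weight = sum(s["weight"] for s in entry_signals)  # 1.0
weighted = sum(s["confidence"] * s["weight"] for s in entry_signals) / total_weight
print(weighted)  # 0.82 -> above min_confidence 0.6, so the combined decision is to enter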
218  cycles/strategies/random_strategy.py  Normal file
@@ -0,0 +1,218 @@
"""
Random Strategy for Testing

This strategy generates random entry and exit signals for testing the strategy system.
It's useful for verifying that the strategy framework is working correctly.
"""

import random
import logging
from typing import Dict, List, Optional
import pandas as pd

from .base import StrategyBase, StrategySignal

logger = logging.getLogger(__name__)


class RandomStrategy(StrategyBase):
    """
    Random signal generator strategy for testing.

    This strategy generates random entry and exit signals with configurable
    probability and confidence levels. It's designed to test the strategy
    framework and signal processing system.

    Parameters:
        entry_probability: Probability of generating an entry signal (0.0-1.0)
        exit_probability: Probability of generating an exit signal (0.0-1.0)
        min_confidence: Minimum confidence level for signals
        max_confidence: Maximum confidence level for signals
        timeframe: Timeframe to operate on (default: "1min")
        signal_frequency: How often to generate signals (every N bars)
    """

    def __init__(self, weight: float = 1.0, params: Optional[Dict] = None):
        """Initialize the random strategy."""
        super().__init__("random", weight, params)

        # Strategy parameters with defaults
        self.entry_probability = self.params.get("entry_probability", 0.05)  # 5% chance per bar
        self.exit_probability = self.params.get("exit_probability", 0.1)  # 10% chance per bar
        self.min_confidence = self.params.get("min_confidence", 0.6)
        self.max_confidence = self.params.get("max_confidence", 0.9)
        self.timeframe = self.params.get("timeframe", "1min")
        self.signal_frequency = self.params.get("signal_frequency", 1)  # Every bar

        # Internal state
        self.bar_count = 0
        self.last_signal_bar = -1
        self.last_processed_timestamp = None  # Track last processed timestamp to avoid duplicates

        logger.info(f"RandomStrategy initialized with entry_prob={self.entry_probability}, "
                    f"exit_prob={self.exit_probability}, timeframe={self.timeframe}")

    def get_timeframes(self) -> List[str]:
        """Return the required timeframes for this strategy."""
        # Always include 1min for precision, avoiding duplicates
        timeframes = [self.timeframe]
        if self.timeframe != "1min":
            timeframes.append("1min")
        return timeframes

    def initialize(self, backtester) -> None:
        """Initialize the strategy with backtester data."""
        try:
            logger.info("RandomStrategy: Starting initialization...")

            # Resample data to the required timeframes
            self._resample_data(backtester.original_df)

            # Get primary timeframe data
            self.df = self.get_primary_timeframe_data()

            if self.df is None or self.df.empty:
                raise ValueError(f"No data available for timeframe {self.timeframe}")

            # Reset internal state
            self.bar_count = 0
            self.last_signal_bar = -1

            self.initialized = True
            logger.info(f"RandomStrategy initialized with {len(self.df)} bars on {self.timeframe}")
            logger.info(f"RandomStrategy: Data range from {self.df.index[0]} to {self.df.index[-1]}")

        except Exception as e:
            logger.error(f"Failed to initialize RandomStrategy: {e}")
            logger.error(f"RandomStrategy: backtester.original_df shape: {backtester.original_df.shape if hasattr(backtester, 'original_df') else 'No original_df'}")
            raise

    def get_entry_signal(self, backtester, df_index: int) -> StrategySignal:
        """Generate random entry signals."""
        if not self.initialized:
            logger.warning("RandomStrategy: get_entry_signal called but not initialized")
            return StrategySignal("HOLD", 0.0)

        try:
            # Get the current timestamp to avoid duplicate signals
            current_timestamp = None
            if hasattr(backtester, 'original_df') and not backtester.original_df.empty:
                current_timestamp = backtester.original_df.index[-1]

            # Skip if we already processed this timestamp
            if current_timestamp and self.last_processed_timestamp == current_timestamp:
                return StrategySignal("HOLD", 0.0)

            self.bar_count += 1

            # Debug logging every 10 bars
            if self.bar_count % 10 == 0:
                logger.info(f"RandomStrategy: Processing bar {self.bar_count}, df_index={df_index}, timestamp={current_timestamp}")

            # Check if we should generate a signal based on frequency
            if (self.bar_count - self.last_signal_bar) < self.signal_frequency:
                return StrategySignal("HOLD", 0.0)

            # Generate a random entry signal
            random_value = random.random()
            if random_value < self.entry_probability:
                confidence = random.uniform(self.min_confidence, self.max_confidence)
                self.last_signal_bar = self.bar_count
                self.last_processed_timestamp = current_timestamp  # Update last processed timestamp

                # Get the current price from the backtester's original data (more reliable)
                try:
                    if hasattr(backtester, 'original_df') and not backtester.original_df.empty:
                        # Use the last available price from the original data
                        current_price = backtester.original_df['close'].iloc[-1]
                    elif hasattr(backtester, 'df') and not backtester.df.empty:
                        # Fall back to the backtester's main dataframe
                        current_price = backtester.df['close'].iloc[min(df_index, len(backtester.df) - 1)]
                    else:
                        # Last resort: use our internal dataframe
                        current_price = self.df.iloc[min(df_index, len(self.df) - 1)]['close']
                except (IndexError, KeyError) as e:
                    logger.warning(f"RandomStrategy: Error getting current price: {e}, using fallback")
                    current_price = self.df.iloc[-1]['close'] if not self.df.empty else 50000.0

                logger.info(f"RandomStrategy: Generated ENTRY signal at bar {self.bar_count}, "
                            f"price=${current_price:.2f}, confidence={confidence:.2f}, random_value={random_value:.3f}")

                return StrategySignal(
                    "ENTRY",
                    confidence=confidence,
                    price=current_price,
                    metadata={
                        "strategy": "random",
                        "bar_count": self.bar_count,
                        "timeframe": self.timeframe
                    }
                )

            # Update the timestamp even if no signal was generated
            if current_timestamp:
                self.last_processed_timestamp = current_timestamp

            return StrategySignal("HOLD", 0.0)

        except Exception as e:
            logger.error(f"RandomStrategy entry signal error: {e}")
            return StrategySignal("HOLD", 0.0)

    def get_exit_signal(self, backtester, df_index: int) -> StrategySignal:
        """Generate random exit signals."""
        if not self.initialized:
            return StrategySignal("HOLD", 0.0)

        try:
            # Only generate exit signals if we have an open position.
            # This is handled by the strategy trader, but we can add logic here.

            # Generate a random exit signal
            if random.random() < self.exit_probability:
                confidence = random.uniform(self.min_confidence, self.max_confidence)

                # Get the current price from the backtester's original data (more reliable)
                try:
                    if hasattr(backtester, 'original_df') and not backtester.original_df.empty:
                        # Use the last available price from the original data
                        current_price = backtester.original_df['close'].iloc[-1]
                    elif hasattr(backtester, 'df') and not backtester.df.empty:
                        # Fall back to the backtester's main dataframe
                        current_price = backtester.df['close'].iloc[min(df_index, len(backtester.df) - 1)]
                    else:
                        # Last resort: use our internal dataframe
                        current_price = self.df.iloc[min(df_index, len(self.df) - 1)]['close']
                except (IndexError, KeyError) as e:
                    logger.warning(f"RandomStrategy: Error getting current price for exit: {e}, using fallback")
                    current_price = self.df.iloc[-1]['close'] if not self.df.empty else 50000.0

                # Randomly choose an exit type
                exit_types = ["SELL_SIGNAL", "TAKE_PROFIT", "STOP_LOSS"]
                exit_type = random.choice(exit_types)

                logger.info(f"RandomStrategy: Generated EXIT signal at bar {self.bar_count}, "
                            f"price=${current_price:.2f}, confidence={confidence:.2f}, type={exit_type}")

                return StrategySignal(
                    "EXIT",
                    confidence=confidence,
                    price=current_price,
                    metadata={
                        "type": exit_type,
                        "strategy": "random",
                        "bar_count": self.bar_count,
                        "timeframe": self.timeframe
                    }
                )

            return StrategySignal("HOLD", 0.0)

        except Exception as e:
            logger.error(f"RandomStrategy exit signal error: {e}")
            return StrategySignal("HOLD", 0.0)

    def get_confidence(self, backtester, df_index: int) -> float:
        """Return a random confidence level."""
        return random.uniform(self.min_confidence, self.max_confidence)

    def __repr__(self) -> str:
        """String representation of the strategy."""
        return (f"RandomStrategy(entry_prob={self.entry_probability}, "
                f"exit_prob={self.exit_probability}, timeframe={self.timeframe})")
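RandomStrategy draws from Python's module-level random generator, so two otherwise identical runs produce different signal sequences unless the generator is seeded. A small sketch, assuming nothing else reseeds the same generator, of making a test run repeatable and of what entry_probability means in aggregate:

import random

random.seed(42)  # fixed seed: the same sequence of entry/exit draws on every run
entry_probability = 0.05
draws = sum(random.random() < entry_probability for _ in range(100_000))
print(draws / 100_000)  # ~0.05: roughly 5% of eligible bars would raise an entry signal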
0  cycles/utils/__init__.py  Normal file
152  cycles/utils/data_loader.py  Normal file
@@ -0,0 +1,152 @@
import os
import json
import pandas as pd
from typing import Union, Optional
import logging

from .storage_utils import (
    _parse_timestamp_column,
    _filter_by_date_range,
    _normalize_column_names,
    TimestampParsingError,
    DataLoadingError
)


class DataLoader:
    """Handles loading and preprocessing of data from various file formats"""

    def __init__(self, data_dir: str, logging_instance: Optional[logging.Logger] = None):
        """Initialize the data loader

        Args:
            data_dir: Directory containing data files
            logging_instance: Optional logging instance
        """
        self.data_dir = data_dir
        self.logging = logging_instance

    def load_data(self, file_path: str, start_date: Union[str, pd.Timestamp],
                  stop_date: Union[str, pd.Timestamp]) -> pd.DataFrame:
        """Load data with optimized dtypes and filtering, supporting CSV and JSON input

        Args:
            file_path: path to the data file
            start_date: start date (string or datetime-like)
            stop_date: stop date (string or datetime-like)

        Returns:
            pandas DataFrame with timestamp index. If loading fails, the error
            is logged and an empty DataFrame with a DatetimeIndex is returned.
        """
        try:
            # Convert string dates to pandas datetime objects for proper comparison
            start_date = pd.to_datetime(start_date)
            stop_date = pd.to_datetime(stop_date)

            # Determine the file type
            _, ext = os.path.splitext(file_path)
            ext = ext.lower()

            if ext == ".json":
                return self._load_json_data(file_path, start_date, stop_date)
            else:
                return self._load_csv_data(file_path, start_date, stop_date)

        except Exception as e:
            error_msg = f"Error loading data from {file_path}: {e}"
            if self.logging is not None:
                self.logging.error(error_msg)
            # Return an empty DataFrame with a DatetimeIndex
            return pd.DataFrame(index=pd.to_datetime([]))

    def _load_json_data(self, file_path: str, start_date: pd.Timestamp,
                        stop_date: pd.Timestamp) -> pd.DataFrame:
        """Load and process a JSON data file

        Args:
            file_path: Path to JSON file
            start_date: Start date for filtering
            stop_date: Stop date for filtering

        Returns:
            Processed DataFrame with timestamp index
        """
        with open(os.path.join(self.data_dir, file_path), 'r') as f:
            raw = json.load(f)

        data = pd.DataFrame(raw["Data"])
        data = _normalize_column_names(data)

        # Convert timestamp to datetime
        data["timestamp"] = pd.to_datetime(data["timestamp"], unit="s")

        # Filter by date range
        data = _filter_by_date_range(data, "timestamp", start_date, stop_date)

        if self.logging is not None:
            self.logging.info(f"Data loaded from {file_path} for date range {start_date} to {stop_date}")

        return data.set_index("timestamp")

    def _load_csv_data(self, file_path: str, start_date: pd.Timestamp,
                       stop_date: pd.Timestamp) -> pd.DataFrame:
        """Load and process a CSV data file

        Args:
            file_path: Path to CSV file
            start_date: Start date for filtering
            stop_date: Stop date for filtering

        Returns:
            Processed DataFrame with timestamp index
        """
        # Define optimized dtypes
        dtypes = {
            'Open': 'float32',
            'High': 'float32',
            'Low': 'float32',
            'Close': 'float32',
            'Volume': 'float32'
        }

        # Read data with the original capitalized column names
        data = pd.read_csv(os.path.join(self.data_dir, file_path), dtype=dtypes)

        return self._process_csv_timestamps(data, start_date, stop_date, file_path)

    def _process_csv_timestamps(self, data: pd.DataFrame, start_date: pd.Timestamp,
                                stop_date: pd.Timestamp, file_path: str) -> pd.DataFrame:
        """Process timestamps in CSV data and filter by date range

        Args:
            data: DataFrame with CSV data
            start_date: Start date for filtering
            stop_date: Stop date for filtering
            file_path: Original file path for logging

        Returns:
            Processed DataFrame with timestamp index
        """
        if 'Timestamp' in data.columns:
            data = _parse_timestamp_column(data, 'Timestamp')
            data = _filter_by_date_range(data, 'Timestamp', start_date, stop_date)
            data = _normalize_column_names(data)

            if self.logging is not None:
                self.logging.info(f"Data loaded from {file_path} for date range {start_date} to {stop_date}")

            return data.set_index('timestamp')
        else:
            # Attempt to use the first column if 'Timestamp' is not present
            data.rename(columns={data.columns[0]: 'timestamp'}, inplace=True)
            data = _parse_timestamp_column(data, 'timestamp')
            data = _filter_by_date_range(data, 'timestamp', start_date, stop_date)
            data = _normalize_column_names(data)

            if self.logging is not None:
                self.logging.info(f"Data loaded from {file_path} (using first column as timestamp) for date range {start_date} to {stop_date}")

            return data.set_index('timestamp')
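The JSON branch expects a payload whose "Data" key holds a list of records with epoch-second timestamps, which pd.to_datetime(..., unit="s") turns into a DatetimeIndex. A minimal sketch of that conversion on a synthetic payload (real payloads carry full OHLCV columns, normalized by _normalize_column_names):

import pandas as pd

# Shape of the JSON payload the loader expects (synthetic example)
raw = {"Data": [
    {"timestamp": 1704067200, "close": 42000.0},
    {"timestamp": 1704067260, "close": 42010.5},
]}

data = pd.DataFrame(raw["Data"])
data["timestamp"] = pd.to_datetime(data["timestamp"], unit="s")
print(data.set_index("timestamp"))
# rows at 2024-01-01 00:00:00 and 00:01:00, ready for date-range filtering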
106  cycles/utils/data_saver.py  Normal file
@@ -0,0 +1,106 @@
import os
import pandas as pd
from typing import Optional
import logging

from .storage_utils import DataSavingError


class DataSaver:
    """Handles saving data to various file formats"""

    def __init__(self, data_dir: str, logging_instance: Optional[logging.Logger] = None):
        """Initialize the data saver

        Args:
            data_dir: Directory for saving data files
            logging_instance: Optional logging instance
        """
        self.data_dir = data_dir
        self.logging = logging_instance

    def save_data(self, data: pd.DataFrame, file_path: str) -> None:
        """Save processed data to a CSV file.

        If the DataFrame has a DatetimeIndex, it's converted to float Unix timestamps
        (seconds since epoch) before saving. The index is saved as a column named 'timestamp'.

        Args:
            data: DataFrame to save
            file_path: path to the data file relative to the data_dir

        Raises:
            DataSavingError: If saving fails
        """
        try:
            data_to_save = data.copy()
            data_to_save = self._prepare_data_for_saving(data_to_save)

            # Save to CSV, ensuring the 'timestamp' column (if created) is written
            full_path = os.path.join(self.data_dir, file_path)
            data_to_save.to_csv(full_path, index=False)

            if self.logging is not None:
                self.logging.info(f"Data saved to {full_path} with Unix timestamp column.")

        except Exception as e:
            error_msg = f"Failed to save data to {file_path}: {e}"
            if self.logging is not None:
                self.logging.error(error_msg)
            raise DataSavingError(error_msg) from e

    def _prepare_data_for_saving(self, data: pd.DataFrame) -> pd.DataFrame:
        """Prepare the DataFrame for saving by handling different index types

        Args:
            data: DataFrame to prepare

        Returns:
            DataFrame ready for saving
        """
        if isinstance(data.index, pd.DatetimeIndex):
            return self._convert_datetime_index_to_timestamp(data)
        elif pd.api.types.is_numeric_dtype(data.index.dtype):
            return self._convert_numeric_index_to_timestamp(data)
        else:
            # For other index types, save with the current index
            return data

    def _convert_datetime_index_to_timestamp(self, data: pd.DataFrame) -> pd.DataFrame:
        """Convert a DatetimeIndex to a Unix timestamp column

        Args:
            data: DataFrame with DatetimeIndex

        Returns:
            DataFrame with timestamp column
        """
        # Convert DatetimeIndex to Unix timestamp (float seconds since epoch)
        data['timestamp'] = data.index.astype('int64') / 1e9
        data.reset_index(drop=True, inplace=True)

        # Ensure 'timestamp' is the first column if other columns exist
        if 'timestamp' in data.columns and len(data.columns) > 1:
            cols = ['timestamp'] + [col for col in data.columns if col != 'timestamp']
            data = data[cols]

        return data

    def _convert_numeric_index_to_timestamp(self, data: pd.DataFrame) -> pd.DataFrame:
        """Convert a numeric index to a timestamp column

        Args:
            data: DataFrame with numeric index

        Returns:
            DataFrame with timestamp column
        """
        # If the index is already numeric (e.g. float Unix timestamps from a previous save/load cycle)
        data['timestamp'] = data.index
        data.reset_index(drop=True, inplace=True)

        # Ensure 'timestamp' is the first column if other columns exist
        if 'timestamp' in data.columns and len(data.columns) > 1:
            cols = ['timestamp'] + [col for col in data.columns if col != 'timestamp']
            data = data[cols]

        return data
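The int64/1e9 conversion above is the inverse of the loader's pd.to_datetime(..., unit="s"), so a save/load cycle should preserve timestamps at this data's resolution. A quick round-trip sketch:

import pandas as pd

idx = pd.date_range("2024-01-01", periods=3, freq="min")

# Save direction: DatetimeIndex -> float seconds since epoch
secs = idx.astype("int64") / 1e9

# Load direction: float seconds -> DatetimeIndex
restored = pd.to_datetime(secs, unit="s")

assert (restored == idx).all()  # round trip is lossless at minute resolution

Float seconds can lose sub-microsecond precision for timestamps far from the epoch, but that is irrelevant at the 1-minute bar resolution used here.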
199
cycles/utils/data_utils.py
Normal file
199
cycles/utils/data_utils.py
Normal file
@@ -0,0 +1,199 @@
import pandas as pd
from typing import Union


def check_data(data_df: pd.DataFrame) -> Union[dict, bool]:
    """
    Validates the input DataFrame and builds OHLCV aggregation rules.

    Checks that the DataFrame has a DatetimeIndex, then assembles the
    resampling aggregation rules for whichever standard OHLCV columns
    are present.

    Args:
        data_df (pd.DataFrame): DataFrame to check.

    Returns:
        Union[dict, bool]: A dict of aggregation rules (e.g. {'open': 'first'})
        if the DataFrame is usable, False otherwise.
    """

    if not isinstance(data_df.index, pd.DatetimeIndex):
        print("Warning: Input DataFrame must have a DatetimeIndex.")
        return False

    agg_rules = {}

    # Define aggregation rules based on available columns
    if 'open' in data_df.columns:
        agg_rules['open'] = 'first'
    if 'high' in data_df.columns:
        agg_rules['high'] = 'max'
    if 'low' in data_df.columns:
        agg_rules['low'] = 'min'
    if 'close' in data_df.columns:
        agg_rules['close'] = 'last'
    if 'volume' in data_df.columns:
        agg_rules['volume'] = 'sum'

    if not agg_rules:
        print("Warning: No standard OHLCV columns (open, high, low, close, volume) found for aggregation.")
        return False

    return agg_rules

def aggregate_to_weekly(data_df: pd.DataFrame, weeks: int = 1) -> pd.DataFrame:
    """
    Aggregates time-series financial data to weekly OHLCV format.

    The input DataFrame is expected to have a DatetimeIndex.
    'open' will be the first 'open' price of the week.
    'close' will be the last 'close' price of the week.
    'high' will be the maximum 'high' price of the week.
    'low' will be the minimum 'low' price of the week.
    'volume' (if present) will be the sum of volumes for the week.

    Args:
        data_df (pd.DataFrame): DataFrame with a DatetimeIndex and columns
            like 'open', 'high', 'low', 'close', and optionally 'volume'.
        weeks (int): The number of weeks to aggregate to. Default is 1.

    Returns:
        pd.DataFrame: DataFrame aggregated to weekly OHLCV data.
        The index will be a DatetimeIndex with the time set to the start of the week.
        Returns an empty DataFrame if no relevant OHLCV columns are found.
    """

    agg_rules = check_data(data_df)

    if not agg_rules:
        print("Warning: No standard OHLCV columns (open, high, low, close, volume) found for weekly aggregation.")
        return pd.DataFrame(index=pd.to_datetime([]))

    # Resample to weekly frequency and apply aggregation rules
    weekly_data = data_df.resample(f'{weeks}W').agg(agg_rules)

    weekly_data.dropna(how='all', inplace=True)

    # Adjust timestamps to the start of the week.
    # Note: DatetimeIndex.floor('W') raises ValueError because weeks are a
    # non-fixed frequency, so map each label to its week period's start instead.
    if not weekly_data.empty and isinstance(weekly_data.index, pd.DatetimeIndex):
        weekly_data.index = weekly_data.index.to_period('W').start_time

    return weekly_data

def aggregate_to_daily(data_df: pd.DataFrame) -> pd.DataFrame:
    """
    Aggregates time-series financial data to daily OHLCV format.

    The input DataFrame is expected to have a DatetimeIndex.
    'open' will be the first 'open' price of the day.
    'close' will be the last 'close' price of the day.
    'high' will be the maximum 'high' price of the day.
    'low' will be the minimum 'low' price of the day.
    'volume' (if present) will be the sum of volumes for the day.

    Args:
        data_df (pd.DataFrame): DataFrame with a DatetimeIndex and columns
            like 'open', 'high', 'low', 'close', and optionally 'volume'.
            Column names are expected to be lowercase.

    Returns:
        pd.DataFrame: DataFrame aggregated to daily OHLCV data.
        The index will be a DatetimeIndex with the time set to noon (12:00:00) for each day.
        Returns an empty DataFrame if the input lacks a DatetimeIndex or if no
        relevant OHLCV columns are found (no exception is raised; check_data
        prints the warning).
    """

    agg_rules = check_data(data_df)

    if not agg_rules:
        # No usable columns or index: return an empty DataFrame with a datetime index
        print("Warning: No standard OHLCV columns (open, high, low, close, volume) found for daily aggregation.")
        return pd.DataFrame(index=pd.to_datetime([]))

    # Resample to daily frequency and apply aggregation rules
    daily_data = data_df.resample('D').agg(agg_rules)

    # Adjust timestamps to noon if data exists
    if not daily_data.empty and isinstance(daily_data.index, pd.DatetimeIndex):
        daily_data.index = daily_data.index + pd.Timedelta(hours=12)

    # Remove rows where all values are NaN (these are days with no trades in the original data)
    daily_data.dropna(how='all', inplace=True)

    return daily_data

def aggregate_to_hourly(data_df: pd.DataFrame, hours: int = 1) -> pd.DataFrame:
    """
    Aggregates time-series financial data to hourly OHLCV format.

    The input DataFrame is expected to have a DatetimeIndex.
    'open' will be the first 'open' price of the hour.
    'close' will be the last 'close' price of the hour.
    'high' will be the maximum 'high' price of the hour.
    'low' will be the minimum 'low' price of the hour.
    'volume' (if present) will be the sum of volumes for the hour.

    Args:
        data_df (pd.DataFrame): DataFrame with a DatetimeIndex and columns
            like 'open', 'high', 'low', 'close', and optionally 'volume'.
        hours (int): The number of hours to aggregate to. Default is 1.

    Returns:
        pd.DataFrame: DataFrame aggregated to hourly OHLCV data.
        The index will be a DatetimeIndex with the time set to the start of the hour.
        Returns an empty DataFrame if no relevant OHLCV columns are found.
    """

    agg_rules = check_data(data_df)

    if not agg_rules:
        print("Warning: No standard OHLCV columns (open, high, low, close, volume) found for hourly aggregation.")
        return pd.DataFrame(index=pd.to_datetime([]))

    # Resample to hourly frequency and apply aggregation rules
    hourly_data = data_df.resample(f'{hours}h').agg(agg_rules)

    hourly_data.dropna(how='all', inplace=True)

    # Adjust timestamps to the start of the hour
    if not hourly_data.empty and isinstance(hourly_data.index, pd.DatetimeIndex):
        hourly_data.index = hourly_data.index.floor('h')

    return hourly_data

def aggregate_to_minutes(data_df: pd.DataFrame, minutes: int) -> pd.DataFrame:
    """
    Aggregates time-series financial data to N-minute OHLCV format.

    The input DataFrame is expected to have a DatetimeIndex.
    'open' will be the first 'open' price of the N-minute interval.
    'close' will be the last 'close' price of the N-minute interval.
    'high' will be the maximum 'high' price of the N-minute interval.
    'low' will be the minimum 'low' price of the N-minute interval.
    'volume' (if present) will be the sum of volumes for the N-minute interval.

    Args:
        data_df (pd.DataFrame): DataFrame with a DatetimeIndex and columns
            like 'open', 'high', 'low', 'close', and optionally 'volume'.
        minutes (int): The number of minutes to aggregate to.

    Returns:
        pd.DataFrame: DataFrame aggregated to N-minute OHLCV data.
        The index will be a DatetimeIndex.
        Returns an empty DataFrame if no relevant OHLCV columns are found or
        if the input DataFrame does not have a DatetimeIndex.
    """
    agg_rules_obj = check_data(data_df)  # check_data returns rules or False

    if not agg_rules_obj:
        # check_data already prints a warning if index is not DatetimeIndex or no OHLCV columns
        # Ensure an empty DataFrame with a DatetimeIndex is returned for consistency
        return pd.DataFrame(index=pd.to_datetime([]))

    # Resample to N-minute frequency and apply aggregation rules
    # Using .agg(agg_rules_obj) where agg_rules_obj is the dict from check_data
    resampled_data = data_df.resample(f'{minutes}min').agg(agg_rules_obj)

    resampled_data.dropna(how='all', inplace=True)

    return resampled_data
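A usage sketch for these aggregation helpers on synthetic data (assumes the module is importable as `cycles.utils.data_utils`):

```python
import numpy as np
import pandas as pd
from cycles.utils.data_utils import aggregate_to_minutes

# One hour of synthetic 1-minute OHLCV bars
idx = pd.date_range("2024-01-01", periods=60, freq="min")
close = 100 + np.cumsum(np.random.randn(60))
df = pd.DataFrame({
    "open": close, "high": close + 0.5,
    "low": close - 0.5, "close": close,
    "volume": np.random.randint(1, 100, 60),
}, index=idx)

bars_15m = aggregate_to_minutes(df, minutes=15)
print(bars_15m.shape)  # (4, 5): four 15-minute bars, five OHLCV columns
```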
233
cycles/utils/progress_manager.py
Normal file
233
cycles/utils/progress_manager.py
Normal file
@@ -0,0 +1,233 @@
#!/usr/bin/env python3
"""
Progress Manager for tracking multiple parallel backtest tasks
"""

import threading
import time
import sys
from typing import Dict, Optional, Callable
from dataclasses import dataclass


@dataclass
class TaskProgress:
    """Represents progress information for a single task"""
    task_id: str
    name: str
    current: int
    total: int
    start_time: float
    last_update: float

    @property
    def percentage(self) -> float:
        """Calculate completion percentage"""
        if self.total == 0:
            return 0.0
        return (self.current / self.total) * 100

    @property
    def elapsed_time(self) -> float:
        """Calculate elapsed time in seconds"""
        return time.time() - self.start_time

    @property
    def eta(self) -> Optional[float]:
        """Estimate time to completion in seconds"""
        if self.current == 0 or self.percentage >= 100:
            return None

        elapsed = self.elapsed_time
        rate = self.current / elapsed
        remaining = self.total - self.current
        return remaining / rate if rate > 0 else None


class ProgressManager:
    """Manages progress tracking for multiple parallel tasks"""

    def __init__(self, update_interval: float = 1.0, display_width: int = 50):
        """
        Initialize progress manager

        Args:
            update_interval: How often to update display (seconds)
            display_width: Width of progress bar in characters
        """
        self.tasks: Dict[str, TaskProgress] = {}
        self.update_interval = update_interval
        self.display_width = display_width
        self.lock = threading.Lock()
        self.display_thread: Optional[threading.Thread] = None
        self.running = False
        self.last_display_height = 0

    def start_task(self, task_id: str, name: str, total: int) -> None:
        """
        Start tracking a new task

        Args:
            task_id: Unique identifier for the task
            name: Human-readable name for the task
            total: Total number of steps in the task
        """
        with self.lock:
            self.tasks[task_id] = TaskProgress(
                task_id=task_id,
                name=name,
                current=0,
                total=total,
                start_time=time.time(),
                last_update=time.time()
            )

    def update_progress(self, task_id: str, current: int) -> None:
        """
        Update progress for a specific task

        Args:
            task_id: Task identifier
            current: Current progress value
        """
        with self.lock:
            if task_id in self.tasks:
                self.tasks[task_id].current = current
                self.tasks[task_id].last_update = time.time()

    def complete_task(self, task_id: str) -> None:
        """
        Mark a task as completed

        Args:
            task_id: Task identifier
        """
        with self.lock:
            if task_id in self.tasks:
                task = self.tasks[task_id]
                task.current = task.total
                task.last_update = time.time()

    def start_display(self) -> None:
        """Start the progress display thread"""
        if not self.running:
            self.running = True
            self.display_thread = threading.Thread(target=self._display_loop, daemon=True)
            self.display_thread.start()

    def stop_display(self) -> None:
        """Stop the progress display thread"""
        self.running = False
        if self.display_thread:
            self.display_thread.join(timeout=1.0)
        self._clear_display()

    def _display_loop(self) -> None:
        """Main loop for updating the progress display"""
        while self.running:
            self._update_display()
            time.sleep(self.update_interval)

    def _update_display(self) -> None:
        """Update the console display with current progress"""
        with self.lock:
            if not self.tasks:
                return

            # Clear previous display
            self._clear_display()

            # Build display lines
            lines = []
            for task in sorted(self.tasks.values(), key=lambda t: t.task_id):
                line = self._format_progress_line(task)
                lines.append(line)

            # Print all lines
            for line in lines:
                print(line, flush=True)

            self.last_display_height = len(lines)

    def _clear_display(self) -> None:
        """Clear the previous progress display"""
        if self.last_display_height > 0:
            # Move cursor up and clear lines
            for _ in range(self.last_display_height):
                sys.stdout.write('\033[F')  # Move cursor up one line
                sys.stdout.write('\033[K')  # Clear line
            sys.stdout.flush()

    def _format_progress_line(self, task: TaskProgress) -> str:
        """
        Format a single progress line for display

        Args:
            task: TaskProgress instance

        Returns:
            Formatted progress string
        """
        # Progress bar
        filled_width = int(task.percentage / 100 * self.display_width)
        bar = '█' * filled_width + '░' * (self.display_width - filled_width)

        # Time information
        elapsed_str = self._format_time(task.elapsed_time)
        eta_str = self._format_time(task.eta) if task.eta else "N/A"

        # Format line
        line = (f"{task.name:<25} │{bar}│ "
                f"{task.percentage:5.1f}% "
                f"({task.current:,}/{task.total:,}) "
                f"⏱ {elapsed_str} ETA: {eta_str}")

        return line

    def _format_time(self, seconds: float) -> str:
        """
        Format time duration for display

        Args:
            seconds: Time in seconds

        Returns:
            Formatted time string
        """
        if seconds < 60:
            return f"{seconds:.0f}s"
        elif seconds < 3600:
            minutes = seconds / 60
            return f"{minutes:.1f}m"
        else:
            hours = seconds / 3600
            return f"{hours:.1f}h"

    def get_task_progress_callback(self, task_id: str) -> Callable[[int], None]:
        """
        Get a progress callback function for a specific task

        Args:
            task_id: Task identifier

        Returns:
            Callback function that updates progress for this task
        """
        def callback(current: int) -> None:
            self.update_progress(task_id, current)

        return callback

    def all_tasks_completed(self) -> bool:
        """Check if all tasks are completed"""
        with self.lock:
            return all(task.current >= task.total for task in self.tasks.values())

    def get_summary(self) -> str:
        """Get a summary of all tasks"""
        with self.lock:
            total_tasks = len(self.tasks)
            completed_tasks = sum(1 for task in self.tasks.values()
                                  if task.current >= task.total)

            return f"Tasks: {completed_tasks}/{total_tasks} completed"
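A minimal single-task usage sketch for `ProgressManager` (import path assumed; real callers drive the callback from worker processes):

```python
import time
from cycles.utils.progress_manager import ProgressManager

pm = ProgressManager(update_interval=0.5)
pm.start_task("bt-1", "BTC 15min backtest", total=100)
pm.start_display()

report = pm.get_task_progress_callback("bt-1")
for step in range(1, 101):
    time.sleep(0.01)  # stand-in for real work
    report(step)

pm.complete_task("bt-1")
pm.stop_display()
print(pm.get_summary())  # Tasks: 1/1 completed
```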
179
cycles/utils/result_formatter.py
Normal file
179
cycles/utils/result_formatter.py
Normal file
@@ -0,0 +1,179 @@
import os
import csv
from typing import Dict, List, Optional, Any
from collections import defaultdict
import logging

from .storage_utils import DataSavingError


class ResultFormatter:
    """Handles formatting and writing of backtest results to CSV files"""

    def __init__(self, results_dir: str, logging_instance: Optional[logging.Logger] = None):
        """Initialize result formatter

        Args:
            results_dir: Directory for saving result files
            logging_instance: Optional logging instance
        """
        self.results_dir = results_dir
        self.logging = logging_instance

    def format_row(self, row: Dict[str, Any]) -> Dict[str, str]:
        """Format a row for a combined results CSV file

        Args:
            row: Dictionary containing row data

        Returns:
            Dictionary with formatted values
        """
        return {
            "timeframe": row["timeframe"],
            "stop_loss_pct": f"{row['stop_loss_pct']*100:.2f}%",
            "n_trades": row["n_trades"],
            "n_stop_loss": row["n_stop_loss"],
            "win_rate": f"{row['win_rate']*100:.2f}%",
            "max_drawdown": f"{row['max_drawdown']*100:.2f}%",
            "avg_trade": f"{row['avg_trade']*100:.2f}%",
            "profit_ratio": f"{row['profit_ratio']*100:.2f}%",
            "final_usd": f"{row['final_usd']:.2f}",
            "total_fees_usd": f"{row['total_fees_usd']:.2f}",
        }

    def write_results_chunk(self, filename: str, fieldnames: List[str],
                            rows: List[Dict], write_header: bool = False,
                            initial_usd: Optional[float] = None) -> None:
        """Write a chunk of results to a CSV file

        Args:
            filename: filename to write to
            fieldnames: list of fieldnames
            rows: list of rows
            write_header: whether to write the header
            initial_usd: initial USD value for header comment

        Raises:
            DataSavingError: If writing fails
        """
        try:
            mode = 'w' if write_header else 'a'

            with open(filename, mode, newline="") as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                if write_header:
                    if initial_usd is not None:
                        csvfile.write(f"# initial_usd: {initial_usd}\n")
                    writer.writeheader()

                for row in rows:
                    # Only keep keys that are in fieldnames
                    filtered_row = {k: v for k, v in row.items() if k in fieldnames}
                    writer.writerow(filtered_row)

        except Exception as e:
            error_msg = f"Failed to write results chunk to {filename}: {e}"
            if self.logging is not None:
                self.logging.error(error_msg)
            raise DataSavingError(error_msg) from e

    def write_backtest_results(self, filename: str, fieldnames: List[str],
                               rows: List[Dict], metadata_lines: Optional[List[str]] = None) -> str:
        """Write combined backtest results to a CSV file

        Args:
            filename: filename to write to
            fieldnames: list of fieldnames
            rows: list of result dictionaries
            metadata_lines: optional list of strings to write as header comments

        Returns:
            Full path to the written file

        Raises:
            DataSavingError: If writing fails
        """
        try:
            fname = os.path.join(self.results_dir, filename)
            with open(fname, "w", newline="") as csvfile:
                if metadata_lines:
                    for line in metadata_lines:
                        csvfile.write(f"{line}\n")

                writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter='\t')
                writer.writeheader()

                for row in rows:
                    writer.writerow(self.format_row(row))

            if self.logging is not None:
                self.logging.info(f"Combined results written to {fname}")

            return fname

        except Exception as e:
            error_msg = f"Failed to write backtest results to {filename}: {e}"
            if self.logging is not None:
                self.logging.error(error_msg)
            raise DataSavingError(error_msg) from e

    def write_trades(self, all_trade_rows: List[Dict], trades_fieldnames: List[str]) -> None:
        """Write trades to separate CSV files grouped by timeframe and stop loss

        Args:
            all_trade_rows: list of trade dictionaries
            trades_fieldnames: list of trade fieldnames

        Raises:
            DataSavingError: If writing fails
        """
        try:
            trades_by_combo = self._group_trades_by_combination(all_trade_rows)

            for (tf, sl), trades in trades_by_combo.items():
                self._write_single_trade_file(tf, sl, trades, trades_fieldnames)

        except Exception as e:
            error_msg = f"Failed to write trades: {e}"
            if self.logging is not None:
                self.logging.error(error_msg)
            raise DataSavingError(error_msg) from e

    def _group_trades_by_combination(self, all_trade_rows: List[Dict]) -> Dict:
        """Group trades by timeframe and stop loss combination

        Args:
            all_trade_rows: List of trade dictionaries

        Returns:
            Dictionary grouped by (timeframe, stop_loss_pct) tuples
        """
        trades_by_combo = defaultdict(list)
        for trade in all_trade_rows:
            tf = trade.get("timeframe")
            sl = trade.get("stop_loss_pct")
            trades_by_combo[(tf, sl)].append(trade)
        return trades_by_combo

    def _write_single_trade_file(self, timeframe: str, stop_loss_pct: float,
                                 trades: List[Dict], trades_fieldnames: List[str]) -> None:
        """Write trades for a single timeframe/stop-loss combination

        Args:
            timeframe: Timeframe identifier
            stop_loss_pct: Stop loss percentage
            trades: List of trades for this combination
            trades_fieldnames: List of field names for trades
        """
        sl_percent = int(round(stop_loss_pct * 100))
        trades_filename = os.path.join(self.results_dir, f"trades_{timeframe}_ST{sl_percent}pct.csv")

        with open(trades_filename, "w", newline="") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=trades_fieldnames)
            writer.writeheader()
            for trade in trades:
                writer.writerow({k: trade.get(k, "") for k in trades_fieldnames})

        if self.logging is not None:
            self.logging.info(f"Trades written to {trades_filename}")
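A usage sketch for the chunked CSV writer (file names illustrative). Note that `write_results_chunk` treats `filename` as a full path, while `write_backtest_results` joins its `filename` with `results_dir`:

```python
import os
from cycles.utils.result_formatter import ResultFormatter

os.makedirs("results", exist_ok=True)  # ResultFormatter does not create the directory itself
fmt = ResultFormatter(results_dir="results")

fields = ["timeframe", "stop_loss_pct", "n_trades"]
rows = [{"timeframe": "15min", "stop_loss_pct": 0.03, "n_trades": 42}]

# First chunk creates the file and writes the '# initial_usd' comment plus header...
fmt.write_results_chunk("results/scan.csv", fields, rows, write_header=True, initial_usd=1000.0)
# ...subsequent chunks append rows without a header.
fmt.write_results_chunk("results/scan.csv", fields, rows)
```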
123
cycles/utils/storage.py
Normal file
123
cycles/utils/storage.py
Normal file
@@ -0,0 +1,123 @@
import os
import pandas as pd
from typing import Optional, Union, Dict, Any, List
import logging

from .data_loader import DataLoader
from .data_saver import DataSaver
from .result_formatter import ResultFormatter
from .storage_utils import DataLoadingError, DataSavingError

RESULTS_DIR = "../results"
DATA_DIR = "../data"


class Storage:
    """Unified storage interface for data and results operations

    Acts as a coordinator for DataLoader, DataSaver, and ResultFormatter components,
    maintaining backward compatibility while providing a clean separation of concerns.
    """

    def __init__(self, logging=None, results_dir=RESULTS_DIR, data_dir=DATA_DIR):
        """Initialize storage with component instances

        Args:
            logging: Optional logging instance
            results_dir: Directory for results files
            data_dir: Directory for data files
        """
        self.results_dir = results_dir
        self.data_dir = data_dir
        self.logging = logging

        # Create directories if they don't exist
        os.makedirs(self.results_dir, exist_ok=True)
        os.makedirs(self.data_dir, exist_ok=True)

        # Initialize component instances
        self.data_loader = DataLoader(data_dir, logging)
        self.data_saver = DataSaver(data_dir, logging)
        self.result_formatter = ResultFormatter(results_dir, logging)

    def load_data(self, file_path: str, start_date: Union[str, pd.Timestamp],
                  stop_date: Union[str, pd.Timestamp]) -> pd.DataFrame:
        """Load data with optimized dtypes and filtering, supporting CSV and JSON input

        Args:
            file_path: path to the data file
            start_date: start date (string or datetime-like)
            stop_date: stop date (string or datetime-like)

        Returns:
            pandas DataFrame with timestamp index

        Raises:
            DataLoadingError: If data loading fails
        """
        return self.data_loader.load_data(file_path, start_date, stop_date)

    def save_data(self, data: pd.DataFrame, file_path: str) -> None:
        """Save processed data to a CSV file

        Args:
            data: DataFrame to save
            file_path: path to the data file relative to the data_dir

        Raises:
            DataSavingError: If saving fails
        """
        self.data_saver.save_data(data, file_path)

    def format_row(self, row: Dict[str, Any]) -> Dict[str, str]:
        """Format a row for a combined results CSV file

        Args:
            row: Dictionary containing row data

        Returns:
            Dictionary with formatted values
        """
        return self.result_formatter.format_row(row)

    def write_results_chunk(self, filename: str, fieldnames: List[str],
                            rows: List[Dict], write_header: bool = False,
                            initial_usd: Optional[float] = None) -> None:
        """Write a chunk of results to a CSV file

        Args:
            filename: filename to write to
            fieldnames: list of fieldnames
            rows: list of rows
            write_header: whether to write the header
            initial_usd: initial USD value for header comment
        """
        self.result_formatter.write_results_chunk(
            filename, fieldnames, rows, write_header, initial_usd
        )

    def write_backtest_results(self, filename: str, fieldnames: List[str],
                               rows: List[Dict], metadata_lines: Optional[List[str]] = None) -> str:
        """Write combined backtest results to a CSV file

        Args:
            filename: filename to write to
            fieldnames: list of fieldnames
            rows: list of result dictionaries
            metadata_lines: optional list of strings to write as header comments

        Returns:
            Full path to the written file
        """
        return self.result_formatter.write_backtest_results(
            filename, fieldnames, rows, metadata_lines
        )

    def write_trades(self, all_trade_rows: List[Dict], trades_fieldnames: List[str]) -> None:
        """Write trades to separate CSV files grouped by timeframe and stop loss

        Args:
            all_trade_rows: list of trade dictionaries
            trades_fieldnames: list of trade fieldnames
        """
        self.result_formatter.write_trades(all_trade_rows, trades_fieldnames)
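A usage sketch for the facade (data file name and dates are illustrative):

```python
from cycles.utils.storage import Storage

storage = Storage(results_dir="results", data_dir="data")

# Delegates to DataLoader / DataSaver under the hood
df = storage.load_data("btc_1min.csv", "2024-01-01", "2024-03-31")
storage.save_data(df, "btc_1min_clean.csv")
```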
73
cycles/utils/storage_utils.py
Normal file
73
cycles/utils/storage_utils.py
Normal file
@@ -0,0 +1,73 @@
import pandas as pd


class TimestampParsingError(Exception):
    """Custom exception for timestamp parsing errors"""
    pass


class DataLoadingError(Exception):
    """Custom exception for data loading errors"""
    pass


class DataSavingError(Exception):
    """Custom exception for data saving errors"""
    pass


def _parse_timestamp_column(data: pd.DataFrame, column_name: str) -> pd.DataFrame:
    """Parse timestamp column handling both Unix timestamps and datetime strings

    Args:
        data: DataFrame containing the timestamp column
        column_name: Name of the timestamp column

    Returns:
        DataFrame with parsed timestamp column

    Raises:
        TimestampParsingError: If timestamp parsing fails
    """
    try:
        sample_timestamp = str(data[column_name].iloc[0])
        try:
            # Check if it's a Unix timestamp (numeric)
            float(sample_timestamp)
            # It's a Unix timestamp, convert using unit='s'
            data[column_name] = pd.to_datetime(data[column_name], unit='s')
        except ValueError:
            # It's already in datetime string format, convert without unit
            data[column_name] = pd.to_datetime(data[column_name])
        return data
    except Exception as e:
        raise TimestampParsingError(f"Failed to parse timestamp column '{column_name}': {e}") from e


def _filter_by_date_range(data: pd.DataFrame, timestamp_col: str,
                          start_date: pd.Timestamp, stop_date: pd.Timestamp) -> pd.DataFrame:
    """Filter DataFrame by date range

    Args:
        data: DataFrame to filter
        timestamp_col: Name of timestamp column
        start_date: Start date for filtering
        stop_date: Stop date for filtering

    Returns:
        Filtered DataFrame
    """
    return data[(data[timestamp_col] >= start_date) & (data[timestamp_col] <= stop_date)]


def _normalize_column_names(data: pd.DataFrame) -> pd.DataFrame:
    """Convert all column names to lowercase

    Args:
        data: DataFrame to normalize

    Returns:
        DataFrame with lowercase column names
    """
    data.columns = data.columns.str.lower()
    return data
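A quick demonstration of the dual-format parsing (the helper is module-private, so this is illustration only):

```python
import pandas as pd
from cycles.utils.storage_utils import _parse_timestamp_column

unix_df = pd.DataFrame({"timestamp": [1609459200, 1609459260]})
iso_df = pd.DataFrame({"timestamp": ["2021-01-01 00:00:00", "2021-01-01 00:01:00"]})

# Both normalize to the same datetime64 column
print(_parse_timestamp_column(unix_df, "timestamp")["timestamp"].iloc[0])  # 2021-01-01 00:00:00
print(_parse_timestamp_column(iso_df, "timestamp")["timestamp"].iloc[0])   # 2021-01-01 00:00:00
```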
21
cycles/utils/system.py
Normal file
21
cycles/utils/system.py
Normal file
@@ -0,0 +1,21 @@
import os
import psutil


class SystemUtils:

    def __init__(self, logging=None):
        self.logging = logging

    def get_optimal_workers(self):
        """Determine optimal number of worker processes based on system resources"""
        cpu_count = os.cpu_count() or 4
        memory_gb = psutil.virtual_memory().total / (1024**3)

        # OPTIMIZATION: More aggressive worker allocation for better performance
        workers_by_memory = max(1, int(memory_gb / 2))  # 2GB per worker
        workers_by_cpu = max(1, int(cpu_count * 0.8))  # Use 80% of CPU cores
        optimal_workers = min(workers_by_cpu, workers_by_memory, 8)  # Cap at 8 workers

        if self.logging is not None:
            self.logging.info(f"Using {optimal_workers} workers for processing (CPU-based: {workers_by_cpu}, Memory-based: {workers_by_memory})")
        return optimal_workers
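For intuition, the heuristic resolved on a hypothetical 12-core, 32 GB machine:

```python
# Hypothetical machine: 12 CPU cores, 32 GB RAM
workers_by_memory = max(1, int(32 / 2))              # 16 (2 GB per worker)
workers_by_cpu = max(1, int(12 * 0.8))               # 9  (80% of cores, truncated)
optimal = min(workers_by_cpu, workers_by_memory, 8)  # 8  (the hard cap wins here)
```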
106
docs/analysis.md
Normal file
106
docs/analysis.md
Normal file
@@ -0,0 +1,106 @@
# Analysis Module

This document provides an overview of the `Analysis` module and its components, which are typically used for technical analysis of financial market data.

## Modules

The `Analysis` module includes classes for calculating common technical indicators:

- **Relative Strength Index (RSI)**: Implemented in `cycles/Analysis/rsi.py`.
- **Bollinger Bands**: Implemented in `cycles/Analysis/boillinger_band.py`.
- Note: Trading strategies are detailed in `strategies.md`.

## Class: `RSI`

Found in `cycles/Analysis/rsi.py`.

Calculates the Relative Strength Index.

### Mathematical Model

The standard RSI calculation typically involves Wilder's smoothing for average gains and losses.

1. **Price Change (Delta)**: Difference between consecutive closing prices.
2. **Gain and Loss**: Separate positive (gain) and negative (loss, expressed as positive) price changes.
3. **Average Gain (AvgU)** and **Average Loss (AvgD)**: Smoothed averages of gains and losses over the RSI period. Wilder's smoothing is a specific type of exponential moving average (EMA):
   - Initial AvgU/AvgD: Simple Moving Average (SMA) over the first `period` values.
   - Subsequent AvgU: `(Previous AvgU * (period - 1) + Current Gain) / period`
   - Subsequent AvgD: `(Previous AvgD * (period - 1) + Current Loss) / period`
4. **Relative Strength (RS)**:

$$
RS = \frac{\text{AvgU}}{\text{AvgD}}
$$

5. **RSI**:

$$
RSI = 100 - \frac{100}{1 + RS}
$$

Special conditions:

- If AvgD is 0: RSI is 100 if AvgU > 0, or 50 if AvgU is also 0 (neutral).
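A minimal pandas sketch of the Wilder-smoothed variant described above (standalone; it approximates the SMA seeding with `ewm` and does not reproduce the class's AvgD = 0 special-casing):

```python
import pandas as pd

def wilder_rsi(prices: pd.Series, period: int = 14) -> pd.Series:
    delta = prices.diff()
    gain = delta.clip(lower=0)
    loss = -delta.clip(upper=0)
    # Wilder's smoothing is the EMA recursion with alpha = 1/period
    avg_gain = gain.ewm(alpha=1 / period, min_periods=period, adjust=False).mean()
    avg_loss = loss.ewm(alpha=1 / period, min_periods=period, adjust=False).mean()
    rs = avg_gain / avg_loss
    return 100 - 100 / (1 + rs)
```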
### `__init__(self, config: dict)`

- **Description**: Initializes the RSI calculator.
- **Parameters**:
  - `config` (dict): Configuration dictionary. Must contain an `'rsi_period'` key with a positive integer value (e.g., `{'rsi_period': 14}`).

### `calculate(self, data_df: pd.DataFrame, price_column: str = 'close') -> pd.DataFrame`

- **Description**: Calculates the RSI (using Wilder's smoothing by default) and adds it as an 'RSI' column to the input DataFrame. This method utilizes `calculate_custom_rsi` internally with `smoothing='EMA'`.
- **Parameters**:
  - `data_df` (pd.DataFrame): DataFrame with historical price data. Must contain the `price_column`.
  - `price_column` (str, optional): The name of the column containing price data. Defaults to 'close'.
- **Returns**: `pd.DataFrame` - A copy of the input DataFrame with an added 'RSI' column. If data length is insufficient for the period, the 'RSI' column will contain `np.nan`.

### `calculate_custom_rsi(price_series: pd.Series, window: int = 14, smoothing: str = 'SMA') -> pd.Series` (Static Method)

- **Description**: Calculates RSI with a specified window and smoothing method (SMA or EMA). This is the core calculation engine.
- **Parameters**:
  - `price_series` (pd.Series): Series of prices.
  - `window` (int, optional): The period for RSI calculation. Defaults to 14. Must be a positive integer.
  - `smoothing` (str, optional): Smoothing method, can be 'SMA' (Simple Moving Average) or 'EMA' (Exponential Moving Average, specifically Wilder's smoothing when `alpha = 1/window`). Defaults to 'SMA'.
- **Returns**: `pd.Series` - Series containing the RSI values. Returns a series of NaNs if data length is insufficient.

## Class: `BollingerBands`

Found in `cycles/Analysis/boillinger_band.py`.

Calculates Bollinger Bands.

### Mathematical Model

1. **Middle Band**: Simple Moving Average (SMA) over `period`.

$$
\text{Middle Band} = \text{SMA}(\text{price}, \text{period})
$$

2. **Standard Deviation (σ)**: Standard deviation of price over `period`.
3. **Upper Band**: Middle Band + `num_std` × σ

$$
\text{Upper Band} = \text{Middle Band} + \text{num\_std} \times \sigma_{\text{period}}
$$

4. **Lower Band**: Middle Band − `num_std` × σ

$$
\text{Lower Band} = \text{Middle Band} - \text{num\_std} \times \sigma_{\text{period}}
$$

For the adaptive calculation in the `calculate` method (when `squeeze=False`):

- **BBWidth**: `(Reference Upper Band - Reference Lower Band) / SMA`, where reference bands are typically calculated using a 2.0 standard deviation multiplier.
- **MarketRegime**: Determined by comparing `BBWidth` to a threshold from the configuration. `1` for sideways, `0` for trending.
- The `num_std` used for the final Upper and Lower Bands then varies based on this `MarketRegime` and the `bb_std_dev_multiplier` values for "trending" and "sideways" markets from the configuration, applied row-wise.
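The fixed-parameter band calculation reduces to a few rolling operations; a minimal sketch matching the formulas above (not the class's adaptive row-wise logic):

```python
import pandas as pd

def bollinger_bands(prices: pd.Series, window: int = 20, num_std: float = 2.0):
    sma = prices.rolling(window).mean()   # middle band
    sigma = prices.rolling(window).std()  # rolling standard deviation
    upper = sma + num_std * sigma
    lower = sma - num_std * sigma
    bb_width = (upper - lower) / sma      # width used for regime detection
    return upper, sma, lower, bb_width
```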
### `__init__(self, config: dict)`

- **Description**: Initializes the BollingerBands calculator.
- **Parameters**:
  - `config` (dict): Configuration dictionary. It must contain:
    - `'bb_period'` (int): Positive integer for the moving average and standard deviation period.
    - `'trending'` (dict): Containing `'bb_std_dev_multiplier'` (float, positive) for trending markets.
    - `'sideways'` (dict): Containing `'bb_std_dev_multiplier'` (float, positive) for sideways markets.
    - `'bb_width'` (float): Positive float threshold for determining market regime.

### `calculate(self, data_df: pd.DataFrame, price_column: str = 'close', squeeze: bool = False) -> pd.DataFrame`

- **Description**: Calculates Bollinger Bands and adds relevant columns to the DataFrame.
  - If `squeeze` is `False` (default): Calculates adaptive Bollinger Bands. It determines the market regime (trending/sideways) based on `BBWidth` and applies different standard deviation multipliers (from the `config`) on a row-by-row basis. Adds 'SMA', 'UpperBand', 'LowerBand', 'BBWidth', and 'MarketRegime' columns.
  - If `squeeze` is `True`: Calculates simpler Bollinger Bands with a fixed window of 14 and a standard deviation multiplier of 1.5 by calling `calculate_custom_bands`. Adds 'SMA', 'UpperBand', 'LowerBand' columns; 'BBWidth' and 'MarketRegime' will be `NaN`.
- **Parameters**:
  - `data_df` (pd.DataFrame): DataFrame with price data. Must include the `price_column`.
  - `price_column` (str, optional): The name of the column containing the price data. Defaults to 'close'.
  - `squeeze` (bool, optional): If `True`, calculates bands with fixed parameters (window 14, std 1.5). Defaults to `False`.
- **Returns**: `pd.DataFrame` - A copy of the original DataFrame with added Bollinger Band related columns.

### `calculate_custom_bands(price_series: pd.Series, window: int = 20, num_std: float = 2.0, min_periods: int = None) -> tuple[pd.Series, pd.Series, pd.Series]` (Static Method)

- **Description**: Calculates Bollinger Bands with a specified window, standard deviation multiplier, and minimum periods.
- **Parameters**:
  - `price_series` (pd.Series): Series of prices.
  - `window` (int, optional): The period for the moving average and standard deviation. Defaults to 20.
  - `num_std` (float, optional): The number of standard deviations for the upper and lower bands. Defaults to 2.0.
  - `min_periods` (int, optional): Minimum number of observations in window required to have a value. Defaults to `window` if `None`.
- **Returns**: `tuple[pd.Series, pd.Series, pd.Series]` - A tuple containing the Upper band, SMA, and Lower band series.
405
docs/strategies.md
Normal file
405
docs/strategies.md
Normal file
@@ -0,0 +1,405 @@
# Strategies Documentation

## Overview

The Cycles framework implements advanced trading strategies with sophisticated timeframe management, signal processing, and multi-strategy combination capabilities. Each strategy can operate on its preferred timeframes while maintaining precise execution control.

## Architecture

### Strategy System Components

1. **StrategyBase**: Abstract base class with timeframe management
2. **Individual Strategies**: DefaultStrategy, BBRSStrategy implementations
3. **StrategyManager**: Multi-strategy orchestration and signal combination
4. **Timeframe System**: Automatic data resampling and signal mapping

### New Timeframe Management

Each strategy now controls its own timeframe requirements:

```python
class MyStrategy(StrategyBase):
    def get_timeframes(self):
        return ["15min", "1h"]  # Strategy specifies needed timeframes

    def initialize(self, backtester):
        # Framework automatically resamples data
        self._resample_data(backtester.original_df)

        # Access resampled data
        data_15m = self.get_data_for_timeframe("15min")
        data_1h = self.get_data_for_timeframe("1h")
```

## Available Strategies

### 1. Default Strategy (Meta-Trend Analysis)

**Purpose**: Meta-trend analysis using multiple Supertrend indicators

**Timeframe Behavior**:
- **Configurable Primary Timeframe**: Set via `params["timeframe"]` (default: "15min")
- **1-Minute Precision**: Always includes 1min data for precise stop-loss execution
- **Example Timeframes**: `["15min", "1min"]` or `["5min", "1min"]`

**Configuration**:
```json
{
  "name": "default",
  "weight": 1.0,
  "params": {
    "timeframe": "15min",     // Configurable: "5min", "15min", "1h", etc.
    "stop_loss_pct": 0.03     // Stop loss percentage
  }
}
```

**Algorithm** (a sketch of the agreement rule follows the list):
1. Calculate 3 Supertrend indicators with different parameters on primary timeframe
2. Determine meta-trend: all three must agree for directional signal
3. **Entry**: Meta-trend changes from != 1 to == 1 (all trends align upward)
4. **Exit**: Meta-trend changes to -1 (trend reversal) or stop-loss triggered
5. **Stop-Loss**: 1-minute precision using percentage-based threshold
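A sketch of the agreement rule in step 2 (assumes +1/−1 direction series from each Supertrend; helper names are illustrative, not the framework's API):

```python
import pandas as pd

def meta_trend(directions: list[pd.Series]) -> pd.Series:
    """+1 when all Supertrends agree up, -1 when all agree down, 0 otherwise."""
    stacked = pd.concat(directions, axis=1)
    up = (stacked == 1).all(axis=1)
    down = (stacked == -1).all(axis=1)
    return up.astype(int) - down.astype(int)

# Entry: meta flips to 1; exit: meta becomes -1
# meta = meta_trend([st1_dir, st2_dir, st3_dir])
# entry = (meta == 1) & (meta.shift(1) != 1)
```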
**Strengths**:
- Robust trend following with multiple confirmations
- Configurable for different market timeframes
- Precise risk management
- Low false signals in trending markets

**Best Use Cases**:
- Medium to long-term trend following
- Markets with clear directional movements
- Risk-conscious trading with defined exits

### 2. BBRS Strategy (Bollinger Bands + RSI)

**Purpose**: Market regime-adaptive strategy combining Bollinger Bands and RSI

**Timeframe Behavior**:
- **1-Minute Input**: Strategy receives 1-minute data
- **Internal Resampling**: Underlying Strategy class handles resampling to 15min/1h
- **No Double-Resampling**: Avoids conflicts with existing resampling logic
- **Signal Mapping**: Results mapped back to 1-minute resolution

**Configuration**:
```json
{
  "name": "bbrs",
  "weight": 1.0,
  "params": {
    "bb_width": 0.05,                         // Bollinger Band width threshold
    "bb_period": 20,                          // Bollinger Band period
    "rsi_period": 14,                         // RSI calculation period
    "trending_rsi_threshold": [30, 70],       // RSI thresholds for trending market
    "trending_bb_multiplier": 2.5,            // BB multiplier for trending market
    "sideways_rsi_threshold": [40, 60],       // RSI thresholds for sideways market
    "sideways_bb_multiplier": 1.8,            // BB multiplier for sideways market
    "strategy_name": "MarketRegimeStrategy",  // Implementation variant
    "SqueezeStrategy": true,                  // Enable squeeze detection
    "stop_loss_pct": 0.05                     // Stop loss percentage
  }
}
```

**Algorithm**:

**MarketRegimeStrategy** (Primary Implementation):
1. **Market Regime Detection**: Determines if market is trending or sideways
2. **Adaptive Parameters**: Adjusts BB/RSI thresholds based on market regime
3. **Trending Market Entry**: Price < Lower Band ∧ RSI < 50 ∧ Volume Spike
4. **Sideways Market Entry**: Price ≤ Lower Band ∧ RSI ≤ 40
5. **Exit Conditions**: Opposite band touch, RSI reversal, or stop-loss
6. **Volume Confirmation**: Requires 1.5× average volume for trending signals

**CryptoTradingStrategy** (Alternative Implementation):
1. **Multi-Timeframe Analysis**: Combines 15-minute and 1-hour Bollinger Bands
2. **Entry**: Price ≤ both 15m & 1h lower bands + RSI < 35 + Volume surge
3. **Exit**: 2:1 risk-reward ratio with ATR-based stops
4. **Adaptive Volatility**: Uses ATR for dynamic stop-loss/take-profit

**Strengths**:
- Adapts to different market regimes
- Multiple timeframe confirmation (internal)
- Volume analysis for signal quality
- Sophisticated entry/exit conditions

**Best Use Cases**:
- Volatile cryptocurrency markets
- Markets with alternating trending/sideways periods
- Short to medium-term trading

## Strategy Combination

### Multi-Strategy Architecture

The StrategyManager allows combining multiple strategies with configurable rules:

```json
{
  "strategies": [
    {
      "name": "default",
      "weight": 0.6,
      "params": {"timeframe": "15min"}
    },
    {
      "name": "bbrs",
      "weight": 0.4,
      "params": {"strategy_name": "MarketRegimeStrategy"}
    }
  ],
  "combination_rules": {
    "entry": "weighted_consensus",
    "exit": "any",
    "min_confidence": 0.6
  }
}
```

### Signal Combination Methods

**Entry Combinations**:
- **`any`**: Enter if ANY strategy signals entry
- **`all`**: Enter only if ALL strategies signal entry
- **`majority`**: Enter if majority of strategies signal entry
- **`weighted_consensus`**: Enter based on weighted confidence average (see the sketch after this list)

**Exit Combinations**:
- **`any`**: Exit if ANY strategy signals exit (recommended for risk management)
- **`all`**: Exit only if ALL strategies agree
- **`priority`**: Prioritized exit (STOP_LOSS > SELL_SIGNAL > others)
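A sketch of how `weighted_consensus` scoring can work (hypothetical helper; the framework's actual aggregation lives in the StrategyManager):

```python
def weighted_consensus(signals, min_confidence=0.6):
    """signals: list of (signal_type, confidence, weight) tuples."""
    total_weight = sum(w for _, _, w in signals)
    score = sum(c * w for s, c, w in signals if s == "ENTRY") / total_weight
    return score >= min_confidence

votes = [("ENTRY", 0.8, 0.6), ("HOLD", 0.0, 0.4)]
print(weighted_consensus(votes))  # 0.48 < 0.6 -> False: no entry
```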
## Performance Characteristics

### Default Strategy Performance

**Strengths**:
- **Trend Accuracy**: High accuracy in strong trending markets
- **Risk Management**: Defined stop-losses with 1-minute precision
- **Low Noise**: Multiple Supertrend confirmation reduces false signals
- **Adaptable**: Works across different timeframes

**Weaknesses**:
- **Sideways Markets**: May generate false signals in ranging markets
- **Lag**: Multiple confirmations can delay entry/exit signals
- **Whipsaws**: Vulnerable to rapid trend reversals

**Optimal Conditions**:
- Clear trending markets
- Medium to low volatility trending
- Sufficient data history for Supertrend calculation

### BBRS Strategy Performance

**Strengths**:
- **Market Adaptation**: Automatically adjusts to market regime
- **Volume Confirmation**: Reduces false signals with volume analysis
- **Multi-Timeframe**: Internal analysis across multiple timeframes
- **Volatility Handling**: Designed for cryptocurrency volatility

**Weaknesses**:
- **Complexity**: More parameters to optimize
- **Market Noise**: Can be sensitive to short-term noise
- **Volume Dependency**: Requires reliable volume data

**Optimal Conditions**:
- High-volume cryptocurrency markets
- Markets with clear regime shifts
- Sufficient data for regime detection

## Usage Examples

### Single Strategy Backtests

```bash
# Default strategy on 15-minute timeframe
uv run .\main.py .\configs\config_default.json

# Default strategy on 5-minute timeframe
uv run .\main.py .\configs\config_default_5min.json

# BBRS strategy with market regime detection
uv run .\main.py .\configs\config_bbrs.json
```

### Multi-Strategy Backtests

```bash
# Combined strategies with weighted consensus
uv run .\main.py .\configs\config_combined.json
```

### Custom Configurations

**Aggressive Default Strategy**:
```json
{
  "name": "default",
  "params": {
    "timeframe": "5min",      // Faster signals
    "stop_loss_pct": 0.02     // Tighter stop-loss
  }
}
```

**Conservative BBRS Strategy**:
```json
{
  "name": "bbrs",
  "params": {
    "bb_width": 0.03,         // Tighter BB width
    "stop_loss_pct": 0.07,    // Wider stop-loss
    "SqueezeStrategy": false  // Disable squeeze for simplicity
  }
}
```

## Development Guidelines

### Creating New Strategies

1. **Inherit from StrategyBase**:
```python
from cycles.strategies.base import StrategyBase, StrategySignal

class NewStrategy(StrategyBase):
    def __init__(self, weight=1.0, params=None):
        super().__init__("new_strategy", weight, params)
```

2. **Specify Timeframes**:
```python
def get_timeframes(self):
    return ["1h"]  # Specify required timeframes
```

3. **Implement Core Methods**:
```python
def initialize(self, backtester):
    self._resample_data(backtester.original_df)
    # Calculate indicators...
    self.initialized = True

def get_entry_signal(self, backtester, df_index):
    # Entry logic...
    return StrategySignal("ENTRY", confidence=0.8)

def get_exit_signal(self, backtester, df_index):
    # Exit logic...
    return StrategySignal("EXIT", confidence=1.0)
```

4. **Register Strategy**:
```python
# In StrategyManager._load_strategies()
elif name == "new_strategy":
    strategies.append(NewStrategy(weight, params))
```

### Timeframe Best Practices

1. **Minimize Timeframe Requirements**:
```python
def get_timeframes(self):
    return ["15min"]  # Only what's needed
```

2. **Include 1min for Stop-Loss**:
```python
def get_timeframes(self):
    primary_tf = self.params.get("timeframe", "15min")
    timeframes = [primary_tf]
    if "1min" not in timeframes:
        timeframes.append("1min")
    return timeframes
```

3. **Handle Multi-Timeframe Synchronization**:
```python
def get_entry_signal(self, backtester, df_index):
    # Get current timestamp from primary timeframe
    primary_data = self.get_primary_timeframe_data()
    current_time = primary_data.index[df_index]

    # Map to other timeframes
    hourly_data = self.get_data_for_timeframe("1h")
    h1_idx = hourly_data.index.get_indexer([current_time], method='ffill')[0]
```

## Testing and Validation

### Strategy Testing Workflow

1. **Individual Strategy Testing**:
   - Test each strategy independently
   - Validate on different timeframes
   - Check edge cases and data sufficiency

2. **Multi-Strategy Testing**:
   - Test strategy combinations
   - Validate combination rules
   - Monitor for signal conflicts

3. **Timeframe Validation**:
   - Ensure consistent behavior across timeframes
   - Validate data alignment
   - Check memory usage with large datasets

### Performance Monitoring

```python
# Get strategy summary
summary = strategy_manager.get_strategy_summary()
print(f"Strategies: {[s['name'] for s in summary['strategies']]}")
print(f"Timeframes: {summary['all_timeframes']}")

# Monitor individual strategy performance
for strategy in strategy_manager.strategies:
    print(f"{strategy.name}: {strategy.get_timeframes()}")
```

## Advanced Topics

### Multi-Timeframe Strategy Development

For strategies requiring multiple timeframes:

```python
class MultiTimeframeStrategy(StrategyBase):
    def get_timeframes(self):
        return ["5min", "15min", "1h"]

    def get_entry_signal(self, backtester, df_index):
        # Analyze multiple timeframes
        data_5m = self.get_data_for_timeframe("5min")
        data_15m = self.get_data_for_timeframe("15min")
        data_1h = self.get_data_for_timeframe("1h")

        # Synchronize across timeframes
        current_time = data_5m.index[df_index]
        idx_15m = data_15m.index.get_indexer([current_time], method='ffill')[0]
        idx_1h = data_1h.index.get_indexer([current_time], method='ffill')[0]

        # Multi-timeframe logic
        short_signal = self._analyze_5min(data_5m, df_index)
        medium_signal = self._analyze_15min(data_15m, idx_15m)
        long_signal = self._analyze_1h(data_1h, idx_1h)

        # Combine signals with appropriate confidence
        if short_signal and medium_signal and long_signal:
            return StrategySignal("ENTRY", confidence=0.9)
        elif short_signal and medium_signal:
            return StrategySignal("ENTRY", confidence=0.7)
        else:
            return StrategySignal("HOLD", confidence=0.0)
```

### Strategy Optimization

1. **Parameter Optimization**: Systematic testing of strategy parameters
2. **Timeframe Optimization**: Finding optimal timeframes for each strategy
3. **Combination Optimization**: Optimizing weights and combination rules
4. **Market Regime Adaptation**: Adapting strategies to different market conditions

For detailed timeframe system documentation, see [Timeframe System](./timeframe_system.md).
390
docs/strategy_manager.md
Normal file
@@ -0,0 +1,390 @@
# Strategy Manager Documentation

## Overview

The Strategy Manager is a sophisticated orchestration system that enables the combination of multiple trading strategies with configurable signal aggregation rules. It supports multi-timeframe analysis, weighted consensus voting, and flexible signal combination methods.

## Architecture

### Core Components

1. **StrategyBase**: Abstract base class defining the strategy interface
2. **StrategySignal**: Encapsulates trading signals with confidence levels
3. **StrategyManager**: Orchestrates multiple strategies and combines signals
4. **Strategy Implementations**: DefaultStrategy, BBRSStrategy, etc.

### New Timeframe System

The framework now supports strategy-level timeframe management:

- **Strategy-Controlled Timeframes**: Each strategy specifies its required timeframes
- **Automatic Data Resampling**: The framework automatically resamples 1-minute data to each strategy's needs
- **Multi-Timeframe Support**: Strategies can use multiple timeframes simultaneously
- **Precision Stop-Loss**: All strategies maintain 1-minute data for precise execution

```python
class MyStrategy(StrategyBase):
    def get_timeframes(self):
        return ["15min", "1h"]  # Strategy needs both timeframes

    def initialize(self, backtester):
        # Access resampled data
        data_15m = self.get_data_for_timeframe("15min")
        data_1h = self.get_data_for_timeframe("1h")
        # Setup indicators...
```

## Strategy Interface

### StrategyBase Class

All strategies must inherit from `StrategyBase` and implement:

```python
from typing import List

from cycles.strategies.base import StrategyBase, StrategySignal

class MyStrategy(StrategyBase):
    def get_timeframes(self) -> List[str]:
        """Specify required timeframes"""
        return ["15min"]

    def initialize(self, backtester) -> None:
        """Setup strategy with data"""
        self._resample_data(backtester.original_df)
        # Calculate indicators...
        self.initialized = True

    def get_entry_signal(self, backtester, df_index: int) -> StrategySignal:
        """Generate entry signals"""
        if condition_met:
            return StrategySignal("ENTRY", confidence=0.8)
        return StrategySignal("HOLD", confidence=0.0)

    def get_exit_signal(self, backtester, df_index: int) -> StrategySignal:
        """Generate exit signals"""
        if exit_condition:
            return StrategySignal("EXIT", confidence=1.0,
                                  metadata={"type": "SELL_SIGNAL"})
        return StrategySignal("HOLD", confidence=0.0)
```

### StrategySignal Class

Encapsulates trading signals with metadata:

```python
# Entry signal with high confidence
entry_signal = StrategySignal("ENTRY", confidence=0.9)

# Exit signal with specific price
exit_signal = StrategySignal("EXIT", confidence=1.0, price=50000,
                             metadata={"type": "STOP_LOSS"})

# Hold signal
hold_signal = StrategySignal("HOLD", confidence=0.0)
```

## Available Strategies

### 1. Default Strategy

Meta-trend analysis using multiple Supertrend indicators.

**Features:**
- Uses 3 Supertrend indicators with different parameters
- Configurable timeframe (default: 15min)
- Entry when all trends align upward
- Exit on trend reversal or stop-loss

**Configuration:**
```json
{
    "name": "default",
    "weight": 1.0,
    "params": {
        "timeframe": "15min",
        "stop_loss_pct": 0.03
    }
}
```

**Timeframes:**
- Primary: Configurable (default 15min)
- Stop-loss: Always includes 1min for precision

### 2. BBRS Strategy

Bollinger Bands + RSI with market regime detection.

**Features:**
- Market regime detection (trending vs sideways)
- Adaptive parameters based on market conditions
- Volume analysis and confirmation
- Multi-timeframe internal analysis (1min → 15min/1h)

**Configuration:**
```json
{
    "name": "bbrs",
    "weight": 1.0,
    "params": {
        "bb_width": 0.05,
        "bb_period": 20,
        "rsi_period": 14,
        "strategy_name": "MarketRegimeStrategy",
        "stop_loss_pct": 0.05
    }
}
```

**Timeframes:**
- Input: 1min (the underlying Strategy class handles internal resampling)
- Internal: 15min, 1h (handled by the underlying Strategy class)
- Output: Mapped back to 1min for backtesting

## Signal Combination

### Entry Signal Combination

```python
combination_rules = {
    "entry": "weighted_consensus",  # or "any", "all", "majority"
    "min_confidence": 0.6
}
```

**Methods:**
- **`any`**: Enter if ANY strategy signals entry
- **`all`**: Enter only if ALL strategies signal entry
- **`majority`**: Enter if the majority of strategies signal entry
- **`weighted_consensus`**: Enter based on weighted average confidence

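The exact aggregation happens inside `StrategyManager`; the following is only a minimal sketch of how `weighted_consensus` can be read as a weight-averaged confidence vote. The `weight`, `signal_type`, and `confidence` attribute names here are illustrative, not the actual implementation:

```python
def combine_entry_weighted_consensus(strategies, signals, min_confidence=0.6):
    """Sketch: enter when the weight-averaged entry confidence clears the threshold."""
    total_weight = sum(s.weight for s in strategies) or 1.0
    weighted_sum = sum(
        s.weight * sig.confidence
        for s, sig in zip(strategies, signals)
        if sig.signal_type == "ENTRY"  # non-entry signals contribute zero
    )
    consensus = weighted_sum / total_weight
    if consensus >= min_confidence:
        return StrategySignal("ENTRY", confidence=consensus)
    return StrategySignal("HOLD", confidence=0.0)
```
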
### Exit Signal Combination

```python
combination_rules = {
    "exit": "priority"  # or "any", "all"
}
```

**Methods:**
- **`any`**: Exit if ANY strategy signals exit (recommended for risk management)
- **`all`**: Exit only if ALL strategies agree
- **`priority`**: Prioritized exit (STOP_LOSS > SELL_SIGNAL > others)

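Read `priority` as "the highest-ranked exit type wins". A minimal sketch, assuming signals carry the `metadata={"type": ...}` convention shown earlier; the ranking table below simply mirrors the ordering stated above, while the authoritative ordering lives in `StrategyManager`:

```python
EXIT_PRIORITY = {"STOP_LOSS": 2, "SELL_SIGNAL": 1}  # anything else ranks 0

def combine_exit_priority(signals):
    """Sketch: return the exit signal whose metadata type has the highest priority."""
    exits = [s for s in signals if s.signal_type == "EXIT"]
    if not exits:
        return StrategySignal("HOLD", confidence=0.0)
    return max(exits, key=lambda s: EXIT_PRIORITY.get(s.metadata.get("type"), 0))
```
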
## Configuration

### Basic Strategy Manager Setup

```json
{
    "strategies": [
        {
            "name": "default",
            "weight": 0.6,
            "params": {
                "timeframe": "15min",
                "stop_loss_pct": 0.03
            }
        },
        {
            "name": "bbrs",
            "weight": 0.4,
            "params": {
                "bb_width": 0.05,
                "strategy_name": "MarketRegimeStrategy"
            }
        }
    ],
    "combination_rules": {
        "entry": "weighted_consensus",
        "exit": "any",
        "min_confidence": 0.5
    }
}
```

### Timeframe Examples

**Single Timeframe Strategy** (runs on 5-minute data):
```json
{
    "name": "default",
    "params": {
        "timeframe": "5min"
    }
}
```

**Multi-Timeframe Strategy (Future Enhancement):**
```json
{
    "name": "multi_tf_strategy",
    "params": {
        "timeframes": ["5min", "15min", "1h"],
        "primary_timeframe": "15min"
    }
}
```

## Usage Examples

### Create Strategy Manager

```python
from cycles.strategies import create_strategy_manager

config = {
    "strategies": [
        {"name": "default", "weight": 1.0, "params": {"timeframe": "15min"}}
    ],
    "combination_rules": {
        "entry": "any",
        "exit": "any"
    }
}

strategy_manager = create_strategy_manager(config)
```

### Initialize and Use

```python
# Initialize with backtester
strategy_manager.initialize(backtester)

# Get signals during backtesting
entry_signal = strategy_manager.get_entry_signal(backtester, df_index)
exit_signal, exit_price = strategy_manager.get_exit_signal(backtester, df_index)

# Get strategy summary
summary = strategy_manager.get_strategy_summary()
print(f"Loaded strategies: {[s['name'] for s in summary['strategies']]}")
print(f"All timeframes: {summary['all_timeframes']}")
```

## Extending the System

### Adding New Strategies

1. **Create Strategy Class:**
```python
class NewStrategy(StrategyBase):
    def get_timeframes(self):
        return ["1h"]  # Specify required timeframes

    def initialize(self, backtester):
        self._resample_data(backtester.original_df)
        # Setup indicators...
        self.initialized = True

    def get_entry_signal(self, backtester, df_index):
        # Implement entry logic
        pass

    def get_exit_signal(self, backtester, df_index):
        # Implement exit logic
        pass
```

2. **Register in StrategyManager:**
```python
# In StrategyManager._load_strategies()
elif name == "new_strategy":
    strategies.append(NewStrategy(weight, params))
```

### Multi-Timeframe Strategy Development

For strategies requiring multiple timeframes:

```python
class MultiTimeframeStrategy(StrategyBase):
    def get_timeframes(self):
        return ["5min", "15min", "1h"]

    def initialize(self, backtester):
        self._resample_data(backtester.original_df)

        # Access different timeframes
        data_5m = self.get_data_for_timeframe("5min")
        data_15m = self.get_data_for_timeframe("15min")
        data_1h = self.get_data_for_timeframe("1h")

        # Calculate indicators on each timeframe
        # ...

    def _calculate_signal_confidence(self, backtester, df_index):
        # Analyze multiple timeframes for confidence
        primary_signal = self._get_primary_signal(df_index)
        confirmation = self._get_timeframe_confirmation(df_index)

        return primary_signal * confirmation
```

## Performance Considerations

### Timeframe Management

- **Efficient Resampling**: Each strategy resamples data once during initialization
- **Memory Usage**: Only required timeframes are kept in memory
- **Signal Mapping**: Efficient mapping between timeframes using pandas reindex

### Strategy Combination

- **Lazy Evaluation**: Signals are calculated only when needed
- **Error Handling**: Individual strategy failures don't crash the system
- **Logging**: Comprehensive logging for debugging and monitoring

## Best Practices

1. **Strategy Design:**
   - Specify the minimal required timeframes
   - Include 1min for stop-loss precision
   - Use confidence levels effectively

2. **Signal Combination:**
   - Use `any` for exits (risk management)
   - Use `weighted_consensus` for entries
   - Set appropriate minimum confidence levels

3. **Error Handling:**
   - Implement robust initialization checks
   - Handle missing data gracefully
   - Log strategy-specific warnings

4. **Testing:**
   - Test strategies individually before combining
   - Validate timeframe requirements
   - Monitor memory usage with large datasets

## Troubleshooting

### Common Issues

1. **Timeframe Mismatches:**
   - Ensure the strategy specifies the correct timeframes
   - Check data availability for all timeframes

2. **Signal Conflicts:**
   - Review combination rules
   - Adjust confidence thresholds
   - Monitor strategy weights

3. **Performance Issues:**
   - Minimize timeframe requirements
   - Optimize indicator calculations
   - Use efficient pandas operations

### Debugging Tips

- Enable detailed logging: `logging.basicConfig(level=logging.DEBUG)`
- Use the strategy summary: `manager.get_strategy_summary()`
- Test individual strategies before combining
- Monitor signal confidence levels

---

**Version**: 1.0.0
**Last Updated**: January 2025
**TCP Cycles Project**
488
docs/timeframe_system.md
Normal file
@@ -0,0 +1,488 @@
# Timeframe System Documentation

## Overview

The Cycles framework features a sophisticated timeframe management system that allows strategies to operate on their preferred timeframes while maintaining precise execution control. This system supports both single-timeframe and multi-timeframe strategies with automatic data resampling and intelligent signal mapping.

## Architecture

### Core Concepts

1. **Strategy-Controlled Timeframes**: Each strategy specifies its required timeframes
2. **Automatic Resampling**: The framework resamples 1-minute data to strategy needs
3. **Precision Execution**: All strategies maintain 1-minute data for accurate stop-loss execution
4. **Signal Mapping**: Intelligent mapping between different timeframe resolutions

### Data Flow

```
Original 1min Data
        ↓
Strategy.get_timeframes() → ["15min", "1h"]
        ↓
Automatic Resampling
        ↓
Strategy Logic (15min + 1h analysis)
        ↓
Signal Generation
        ↓
Map to Working Timeframe
        ↓
Backtesting Engine
```

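The "Automatic Resampling" step above reduces to standard pandas OHLCV aggregation. A minimal sketch of what the framework does per requested timeframe (the helper name is illustrative; the aggregation rules are the usual open/high/low/close/volume ones, also used elsewhere in this codebase):

```python
import pandas as pd

OHLCV_AGG = {"open": "first", "high": "max", "low": "min",
             "close": "last", "volume": "sum"}

def resample_ohlcv(df_1min: pd.DataFrame, timeframe: str) -> pd.DataFrame:
    """Sketch: aggregate 1-minute bars (DatetimeIndex required) up to `timeframe`."""
    if timeframe == "1min":
        return df_1min
    return df_1min.resample(timeframe).agg(OHLCV_AGG).dropna()
```
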
## Strategy Timeframe Interface

### StrategyBase Methods

All strategies inherit timeframe capabilities from `StrategyBase`:

```python
class MyStrategy(StrategyBase):
    def get_timeframes(self) -> List[str]:
        """Specify required timeframes for this strategy"""
        return ["15min", "1h"]  # Strategy needs both timeframes

    def initialize(self, backtester) -> None:
        # Automatic resampling happens here
        self._resample_data(backtester.original_df)

        # Access resampled data
        data_15m = self.get_data_for_timeframe("15min")
        data_1h = self.get_data_for_timeframe("1h")

        # Calculate indicators on each timeframe
        self.indicators_15m = self._calculate_indicators(data_15m)
        self.indicators_1h = self._calculate_indicators(data_1h)

        self.initialized = True
```

### Data Access Methods

```python
# Get data for a specific timeframe
data_15m = strategy.get_data_for_timeframe("15min")

# Get primary timeframe data (first in the list)
primary_data = strategy.get_primary_timeframe_data()

# Check available timeframes
timeframes = strategy.get_timeframes()
```

## Supported Timeframes

### Standard Timeframes

- **`"1min"`**: 1-minute bars (original resolution)
- **`"5min"`**: 5-minute bars
- **`"15min"`**: 15-minute bars
- **`"30min"`**: 30-minute bars
- **`"1h"`**: 1-hour bars
- **`"4h"`**: 4-hour bars
- **`"1d"`**: Daily bars

### Custom Timeframes

Any pandas-compatible frequency string is supported, for example:

- **`"2min"`**: 2-minute bars
- **`"10min"`**: 10-minute bars
- **`"2h"`**: 2-hour bars
- **`"12h"`**: 12-hour bars

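Since timeframes are plain pandas frequency strings, one way to validate a custom value before wiring it into a strategy is to let pandas parse it. A small sketch using the standard `to_offset` parser (the helper itself is not part of the framework):

```python
from pandas.tseries.frequencies import to_offset

def is_valid_timeframe(tf: str) -> bool:
    """Sketch: True if pandas accepts `tf` as a resampling frequency."""
    try:
        to_offset(tf)
        return True
    except ValueError:
        return False

print(is_valid_timeframe("12h"))    # True
print(is_valid_timeframe("12xyz"))  # False
```
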
## Strategy Examples

### Single Timeframe Strategy

```python
class SingleTimeframeStrategy(StrategyBase):
    def get_timeframes(self):
        return ["15min"]  # Only needs 15-minute data

    def initialize(self, backtester):
        self._resample_data(backtester.original_df)

        # Work with 15-minute data
        data = self.get_primary_timeframe_data()
        self.indicators = self._calculate_indicators(data)
        self.initialized = True

    def get_entry_signal(self, backtester, df_index):
        # df_index refers to 15-minute data
        if self.indicators['signal'][df_index]:
            return StrategySignal("ENTRY", confidence=0.8)
        return StrategySignal("HOLD", confidence=0.0)
```

### Multi-Timeframe Strategy

```python
class MultiTimeframeStrategy(StrategyBase):
    def get_timeframes(self):
        return ["15min", "1h", "4h"]  # Multiple timeframes

    def initialize(self, backtester):
        self._resample_data(backtester.original_df)

        # Access different timeframes
        self.data_15m = self.get_data_for_timeframe("15min")
        self.data_1h = self.get_data_for_timeframe("1h")
        self.data_4h = self.get_data_for_timeframe("4h")

        # Calculate indicators on each timeframe
        self.trend_4h = self._calculate_trend(self.data_4h)
        self.momentum_1h = self._calculate_momentum(self.data_1h)
        self.entry_signals_15m = self._calculate_entries(self.data_15m)

        self.initialized = True

    def get_entry_signal(self, backtester, df_index):
        # Primary timeframe is 15min (first in the list)
        # Map df_index to other timeframes for confirmation

        # Get current 15min timestamp
        current_time = self.data_15m.index[df_index]

        # Find corresponding indices in other timeframes
        h1_idx = self.data_1h.index.get_indexer([current_time], method='ffill')[0]
        h4_idx = self.data_4h.index.get_indexer([current_time], method='ffill')[0]

        # Multi-timeframe confirmation
        trend_ok = self.trend_4h[h4_idx] > 0
        momentum_ok = self.momentum_1h[h1_idx] > 0.5
        entry_signal = self.entry_signals_15m[df_index]

        if trend_ok and momentum_ok and entry_signal:
            confidence = 0.9  # High confidence with all timeframes aligned
            return StrategySignal("ENTRY", confidence=confidence)

        return StrategySignal("HOLD", confidence=0.0)
```

### Configurable Timeframe Strategy

```python
class ConfigurableStrategy(StrategyBase):
    def get_timeframes(self):
        # Strategy timeframe configurable via parameters
        primary_tf = self.params.get("timeframe", "15min")
        return [primary_tf, "1min"]  # Primary + 1min for stop-loss

    def initialize(self, backtester):
        self._resample_data(backtester.original_df)

        primary_tf = self.get_timeframes()[0]
        self.data = self.get_data_for_timeframe(primary_tf)

        # Indicator parameters can also be timeframe-dependent
        if primary_tf == "5min":
            self.ma_period = 20
        elif primary_tf == "15min":
            self.ma_period = 14
        else:
            self.ma_period = 10

        self.indicators = self._calculate_indicators(self.data)
        self.initialized = True
```

## Built-in Strategy Timeframe Behavior

### Default Strategy

**Timeframes**: Configurable primary + 1min for stop-loss

```python
# Configuration
{
    "name": "default",
    "params": {
        "timeframe": "5min"  # Configurable timeframe
    }
}

# Resulting timeframes: ["5min", "1min"]
```

**Features**:
- Supertrend analysis on the configured timeframe
- 1-minute precision for stop-loss execution
- Optimized for the 15-minute default, but works on any timeframe

### BBRS Strategy

**Timeframes**: 1min input (internal resampling)

```python
# Configuration
{
    "name": "bbrs",
    "params": {
        "strategy_name": "MarketRegimeStrategy"
    }
}

# Resulting timeframes: ["1min"]
```

**Features**:
- Uses 1-minute data as input
- Internal resampling to 15min/1h by the underlying Strategy class
- Signals mapped back to 1-minute resolution
- No double-resampling issues

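Mapping signals computed on a coarser grid back to 1-minute resolution is essentially a reindex-and-forward-fill, consistent with the "pandas reindex" signal mapping mentioned in the Strategy Manager documentation. A minimal sketch (the function name and series names are illustrative):

```python
import pandas as pd

def map_signals_to_1min(signals_15m: pd.Series,
                        index_1min: pd.DatetimeIndex) -> pd.Series:
    """Sketch: carry each 15-minute signal forward onto the 1-minute grid."""
    return signals_15m.reindex(index_1min, method="ffill")
```
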
## Advanced Features

### Timeframe Synchronization

When working with multiple timeframes, synchronization is crucial:

```python
def _get_synchronized_signals(self, df_index, primary_timeframe="15min"):
    """Get signals synchronized across timeframes"""

    # Get timestamp from primary timeframe
    primary_data = self.get_data_for_timeframe(primary_timeframe)
    current_time = primary_data.index[df_index]

    signals = {}
    for tf in self.get_timeframes():
        if tf == primary_timeframe:
            signals[tf] = df_index
        else:
            # Find the corresponding index in the other timeframe
            tf_data = self.get_data_for_timeframe(tf)
            tf_idx = tf_data.index.get_indexer([current_time], method='ffill')[0]
            signals[tf] = tf_idx

    return signals
```

### Dynamic Timeframe Selection

Strategies can adapt timeframes based on market conditions:

```python
class AdaptiveStrategy(StrategyBase):
    def get_timeframes(self):
        # Fixed set of timeframes the strategy might need
        return ["5min", "15min", "1h"]

    def _select_active_timeframe(self, market_volatility):
        """Select timeframe based on market conditions"""
        if market_volatility > 0.8:
            return "5min"   # High volatility -> shorter timeframe
        elif market_volatility > 0.4:
            return "15min"  # Medium volatility -> medium timeframe
        else:
            return "1h"     # Low volatility -> longer timeframe

    def get_entry_signal(self, backtester, df_index):
        # Calculate market volatility
        volatility = self._calculate_volatility(df_index)

        # Select the appropriate timeframe
        active_tf = self._select_active_timeframe(volatility)

        # Generate a signal on the selected timeframe
        return self._generate_signal_for_timeframe(active_tf, df_index)
```

## Configuration Examples

### Single Timeframe Configuration

```json
{
    "strategies": [
        {
            "name": "default",
            "weight": 1.0,
            "params": {
                "timeframe": "15min",
                "stop_loss_pct": 0.03
            }
        }
    ]
}
```

### Multi-Timeframe Configuration

```json
{
    "strategies": [
        {
            "name": "multi_timeframe_strategy",
            "weight": 1.0,
            "params": {
                "primary_timeframe": "15min",
                "confirmation_timeframes": ["1h", "4h"],
                "signal_timeframe": "5min"
            }
        }
    ]
}
```

### Mixed Strategy Configuration

```json
{
    "strategies": [
        {
            "name": "default",
            "weight": 0.6,
            "params": {
                "timeframe": "15min"
            }
        },
        {
            "name": "bbrs",
            "weight": 0.4,
            "params": {
                "strategy_name": "MarketRegimeStrategy"
            }
        }
    ]
}
```

## Performance Considerations

### Memory Usage

- Only required timeframes are resampled and stored
- The original 1-minute data is shared across all strategies
- Efficient pandas resampling with minimal memory overhead

### Processing Speed

- Resampling happens once during initialization
- No repeated resampling during backtesting
- Vectorized operations on pre-computed timeframes

### Data Alignment

- All timeframes are aligned to the original 1-minute timestamps
- Forward-fill resampling ensures data availability
- Intelligent handling of missing data points

## Best Practices

### 1. Minimize Timeframe Requirements

```python
# Good - minimal timeframes
def get_timeframes(self):
    return ["15min"]

# Less optimal - unnecessary timeframes
def get_timeframes(self):
    return ["1min", "5min", "15min", "1h", "4h", "1d"]
```

### 2. Use Appropriate Timeframes for Strategy Logic

```python
# Good - timeframe matches strategy logic
class TrendStrategy(StrategyBase):
    def get_timeframes(self):
        return ["1h"]  # Trend analysis works well on hourly data

class ScalpingStrategy(StrategyBase):
    def get_timeframes(self):
        return ["1min", "5min"]  # Scalping needs fine-grained data
```

### 3. Include 1min for Stop-Loss Precision

```python
def get_timeframes(self):
    primary_tf = self.params.get("timeframe", "15min")
    timeframes = [primary_tf]

    # Always include 1min for precise stop-loss
    if "1min" not in timeframes:
        timeframes.append("1min")

    return timeframes
```

### 4. Handle Timeframe Edge Cases

```python
def get_entry_signal(self, backtester, df_index):
    # Check bounds for all timeframes
    if df_index >= len(self.get_primary_timeframe_data()):
        return StrategySignal("HOLD", confidence=0.0)

    # Robust timeframe indexing
    try:
        signal = self._calculate_signal(df_index)
        return signal
    except IndexError:
        return StrategySignal("HOLD", confidence=0.0)
```

## Troubleshooting

### Common Issues

1. **Index Out of Bounds**
   ```python
   # Problem: different timeframes have different lengths
   # Solution: always check bounds
   if df_index < len(self.data_1h):
       signal = self.data_1h[df_index]
   ```

2. **Timeframe Misalignment**
   ```python
   # Problem: assuming the same index across timeframes
   # Solution: use timestamp-based alignment
   current_time = primary_data.index[df_index]
   h1_idx = hourly_data.index.get_indexer([current_time], method='ffill')[0]
   ```

3. **Memory Issues with Large Datasets**
   ```python
   # Solution: only include the necessary timeframes
   def get_timeframes(self):
       # Return a minimal set
       return ["15min"]  # Not ["1min", "5min", "15min", "1h"]
   ```

### Debugging Tips

```python
# Log timeframe information
def initialize(self, backtester):
    self._resample_data(backtester.original_df)

    for tf in self.get_timeframes():
        data = self.get_data_for_timeframe(tf)
        print(f"Timeframe {tf}: {len(data)} bars, "
              f"from {data.index[0]} to {data.index[-1]}")

    self.initialized = True
```

## Future Enhancements

### Planned Features

1. **Dynamic Timeframe Switching**: Strategies adapt timeframes based on market conditions
2. **Timeframe Confidence Weighting**: Different confidence levels per timeframe
3. **Cross-Timeframe Signal Validation**: Automatic signal confirmation across timeframes
4. **Optimized Memory Management**: Lazy loading and caching for large datasets

### Extension Points

The timeframe system is designed for easy extension:

- Custom resampling methods
- Alternative timeframe synchronization strategies
- Market-specific timeframe preferences
- Real-time timeframe adaptation
207
docs/utils_storage.md
Normal file
@@ -0,0 +1,207 @@
# Storage Utilities

This document describes the refactored storage utilities found in `cycles/utils/` that provide modular, maintainable data and results management.

## Overview

The storage utilities have been refactored into a modular architecture with a clear separation of concerns:

- **`Storage`** - Main coordinator class providing a unified interface (backward compatible)
- **`DataLoader`** - Handles loading data from various file formats
- **`DataSaver`** - Manages saving data with proper format handling
- **`ResultFormatter`** - Formats and writes backtest results to CSV files
- **`storage_utils`** - Shared utilities and custom exceptions

This design improves maintainability and testability, and follows the single responsibility principle.

## Constants

- `RESULTS_DIR`: Default directory for storing results (default: "../results")
- `DATA_DIR`: Default directory for storing input data (default: "../data")

## Main Classes

### `Storage` (Coordinator Class)

The main interface that coordinates all storage operations while maintaining backward compatibility. A combined usage example follows the method reference below.

#### `__init__(self, logging=None, results_dir=RESULTS_DIR, data_dir=DATA_DIR)`

**Description**: Initializes the Storage coordinator with component instances.

**Parameters**:
- `logging` (optional): A logging instance for outputting information
- `results_dir` (str, optional): Path to the directory for storing results
- `data_dir` (str, optional): Path to the directory for storing data

**Creates**: Component instances for DataLoader, DataSaver, and ResultFormatter

#### `load_data(self, file_path: str, start_date: Union[str, pd.Timestamp], stop_date: Union[str, pd.Timestamp]) -> pd.DataFrame`

**Description**: Loads data with optimized dtypes and filtering, supporting CSV and JSON input.

**Parameters**:
- `file_path` (str): Path to the data file (relative to `data_dir`)
- `start_date` (datetime-like): The start date for filtering data
- `stop_date` (datetime-like): The end date for filtering data

**Returns**: `pandas.DataFrame` with a timestamp index

**Raises**: `DataLoadingError` if loading fails

#### `save_data(self, data: pd.DataFrame, file_path: str) -> None`

**Description**: Saves processed data to a CSV file with proper timestamp handling.

**Parameters**:
- `data` (pd.DataFrame): The DataFrame to save
- `file_path` (str): Path to the data file (relative to `data_dir`)

**Raises**: `DataSavingError` if saving fails

#### `format_row(self, row: Dict[str, Any]) -> Dict[str, str]`

**Description**: Formats a dictionary row for output to results CSV files.

**Parameters**:
- `row` (dict): The row of data to format

**Returns**: `dict` with formatted values (percentages, currency, etc.)

#### `write_results_chunk(self, filename: str, fieldnames: List[str], rows: List[Dict], write_header: bool = False, initial_usd: Optional[float] = None) -> None`

**Description**: Writes a chunk of results to a CSV file with an optional header.

**Parameters**:
- `filename` (str): The name of the file to write to
- `fieldnames` (list): CSV header/column names
- `rows` (list): List of dictionaries representing rows
- `write_header` (bool, optional): Whether to write the header
- `initial_usd` (float, optional): Initial USD value for the header comment

#### `write_backtest_results(self, filename: str, fieldnames: List[str], rows: List[Dict], metadata_lines: Optional[List[str]] = None) -> str`

**Description**: Writes combined backtest results to a CSV file with metadata.

**Parameters**:
- `filename` (str): Name of the file to write to (relative to `results_dir`)
- `fieldnames` (list): CSV header/column names
- `rows` (list): List of result dictionaries
- `metadata_lines` (list, optional): Header comment lines

**Returns**: Full path to the written file

#### `write_trades(self, all_trade_rows: List[Dict], trades_fieldnames: List[str]) -> None`

**Description**: Writes trade data to separate CSV files grouped by timeframe and stop-loss.

**Parameters**:
- `all_trade_rows` (list): List of trade dictionaries
- `trades_fieldnames` (list): CSV header for trade files

**Files Created**: `trades_{timeframe}_ST{sl_percent}pct.csv` in `results_dir`

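Putting the methods above together, a typical results write path might look like the following sketch (the filename, row values, and metadata are illustrative):

```python
from cycles.utils.storage import Storage

storage = Storage(results_dir="results", data_dir="data")

fieldnames = ["timeframe", "stop_loss_pct", "n_trades", "final_usd"]
rows = [{"timeframe": "15min", "stop_loss_pct": 0.03,
         "n_trades": 42, "final_usd": 11234.50}]

# Format each row consistently, then write everything with a metadata header
formatted_rows = [storage.format_row(row) for row in rows]
result_path = storage.write_backtest_results(
    "backtest_summary.csv",
    fieldnames,
    formatted_rows,
    metadata_lines=["Initial USD\t10000"],
)
print(f"Results written to {result_path}")
```
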
### `DataLoader`

Handles loading and preprocessing of data from various file formats.

#### Key Features:
- Supports CSV and JSON formats
- Optimized pandas dtypes for financial data
- Intelligent timestamp parsing (Unix timestamps and datetime strings)
- Date range filtering
- Column name normalization (lowercase)
- Comprehensive error handling

#### Methods:
- `load_data()` - Main loading interface
- `_load_json_data()` - JSON-specific loading logic
- `_load_csv_data()` - CSV-specific loading logic
- `_process_csv_timestamps()` - Timestamp parsing for CSV data

### `DataSaver`

Manages saving data with proper format handling and index conversion.

#### Key Features:
- Converts a DatetimeIndex to Unix timestamps for CSV compatibility
- Handles numeric indexes appropriately
- Ensures the 'timestamp' column is first in the output
- Comprehensive error handling and logging

#### Methods:
- `save_data()` - Main saving interface
- `_prepare_data_for_saving()` - Data preparation logic
- `_convert_datetime_index_to_timestamp()` - DatetimeIndex conversion
- `_convert_numeric_index_to_timestamp()` - Numeric index conversion

### `ResultFormatter`

Handles formatting and writing of backtest results to CSV files.

#### Key Features:
- Consistent formatting for percentages and currency
- Grouped trade file writing by timeframe/stop-loss
- Metadata header support
- Tab-delimited output for results
- Error handling for all write operations

#### Methods:
- `format_row()` - Format individual result rows
- `write_results_chunk()` - Write result chunks with headers
- `write_backtest_results()` - Write combined results with metadata
- `write_trades()` - Write grouped trade files

## Utility Functions and Exceptions

### Custom Exceptions

- **`TimestampParsingError`** - Raised when timestamp parsing fails
- **`DataLoadingError`** - Raised when data loading operations fail
- **`DataSavingError`** - Raised when data saving operations fail

### Utility Functions

- **`_parse_timestamp_column()`** - Parse timestamp columns with format detection
- **`_filter_by_date_range()`** - Filter DataFrames by date range
- **`_normalize_column_names()`** - Convert column names to lowercase

## Architecture Benefits

### Separation of Concerns
- Each class has a single, well-defined responsibility
- Data loading, saving, and result formatting are cleanly separated
- Shared utilities are extracted to prevent code duplication

### Maintainability
- All files are under 250 lines (quality gate)
- All methods are under 50 lines (quality gate)
- Clear interfaces and comprehensive documentation
- Type hints for better IDE support and clarity

### Error Handling
- Custom exceptions for different error types
- Consistent error logging patterns
- Graceful degradation (empty DataFrames on load failure)

### Backward Compatibility
- The Storage class maintains exactly the same public interface
- All existing code continues to work unchanged
- Component classes are available for advanced usage

## Migration Notes

The refactoring maintains full backward compatibility. Existing code using `Storage` will continue to work unchanged. For new code, consider using the component classes directly for more focused functionality:

```python
# Existing pattern (still works)
from cycles.utils.storage import Storage
storage = Storage(logging=logger)
data = storage.load_data('file.csv', start, end)

# New pattern for focused usage
from cycles.utils.data_loader import DataLoader
loader = DataLoader(data_dir, logger)
data = loader.load_data('file.csv', start, end)
```
49
docs/utils_system.md
Normal file
@@ -0,0 +1,49 @@
# System Utilities

This document describes the system utility functions found in `cycles/utils/system.py`.

## Overview

The `system.py` module provides utility functions related to system information and resource management. It currently includes a class `SystemUtils` for determining optimal configurations based on system resources.

## Classes and Methods

### `SystemUtils`

A class that provides system-related utility methods.

#### `__init__(self, logging=None)`

- **Description**: Initializes the `SystemUtils` class.
- **Parameters**:
  - `logging` (optional): A logging instance to output information. Defaults to `None`.

#### `get_optimal_workers(self)`

- **Description**: Determines the optimal number of worker processes based on available CPU cores and memory.
  The heuristic aims to use 75% of the CPU cores, capped by available memory (assuming each worker may need ~2GB for large datasets). It returns the smaller of the CPU-based and memory-based worker counts.
- **Parameters**: None.
- **Returns**: `int` - The recommended number of worker processes.

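The heuristic is small enough to show directly. This sketch mirrors the standalone `get_optimal_workers` helper that previously lived in `main.py` (visible in the diff further below), on which the class method is based:

```python
import os
import psutil

def get_optimal_workers() -> int:
    """Use 75% of CPU cores, capped by available memory (~2GB per worker)."""
    cpu_count = os.cpu_count() or 4
    memory_gb = psutil.virtual_memory().total / (1024 ** 3)
    workers_by_cpu = max(1, int(cpu_count * 0.75))
    workers_by_memory = max(1, int(memory_gb / 2))
    return min(workers_by_cpu, workers_by_memory)
```
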
## Usage Examples

```python
from cycles.utils.system import SystemUtils

# Initialize (optionally with a logger)
# import logging
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)
# sys_utils = SystemUtils(logging=logger)
sys_utils = SystemUtils()

optimal_workers = sys_utils.get_optimal_workers()
print(f"Optimal number of workers: {optimal_workers}")

# This value can then be used, for example, when setting up a ThreadPoolExecutor
# from concurrent.futures import ThreadPoolExecutor
# with ThreadPoolExecutor(max_workers=optimal_workers) as executor:
#     # ... submit tasks ...
#     pass
```
440
main.py
@@ -1,18 +1,27 @@
|
|||||||
import pandas as pd
|
#!/usr/bin/env python3
|
||||||
import numpy as np
|
"""
|
||||||
from trend_detector_macd import TrendDetectorMACD
|
Backtest execution script for cryptocurrency trading strategies
|
||||||
from trend_detector_simple import TrendDetectorSimple
|
Refactored for improved maintainability and error handling
|
||||||
from cycle_detector import CycleDetector
|
"""
|
||||||
import csv
|
|
||||||
import logging
|
import logging
|
||||||
import concurrent.futures
|
import datetime
|
||||||
import os
|
import argparse
|
||||||
import psutil
|
import sys
|
||||||
import datetime
|
from pathlib import Path
|
||||||
from charts import BacktestCharts
|
|
||||||
from collections import Counter
|
# Import custom modules
|
||||||
|
from config_manager import ConfigManager
|
||||||
|
from backtest_runner import BacktestRunner
|
||||||
|
from result_processor import ResultProcessor
|
||||||
|
from cycles.utils.storage import Storage
|
||||||
|
from cycles.utils.system import SystemUtils
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging() -> logging.Logger:
|
||||||
|
"""Configure and return logging instance"""
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Set up logging
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||||
@@ -22,280 +31,145 @@ logging.basicConfig(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_optimal_workers():
|
return logger
|
||||||
"""Determine optimal number of worker processes based on system resources"""
|
|
||||||
cpu_count = os.cpu_count() or 4
|
|
||||||
memory_gb = psutil.virtual_memory().total / (1024**3)
|
|
||||||
# Heuristic: Use 75% of cores, but cap based on available memory
|
|
||||||
# Assume each worker needs ~2GB for large datasets
|
|
||||||
workers_by_memory = max(1, int(memory_gb / 2))
|
|
||||||
workers_by_cpu = max(1, int(cpu_count * 0.75))
|
|
||||||
return min(workers_by_cpu, workers_by_memory)
|
|
||||||
|
|
||||||
def load_data(file_path, start_date, stop_date):
|
|
||||||
"""Load data with optimized dtypes and filtering"""
|
|
||||||
# Define optimized dtypes
|
|
||||||
dtypes = {
|
|
||||||
'Open': 'float32',
|
|
||||||
'High': 'float32',
|
|
||||||
'Low': 'float32',
|
|
||||||
'Close': 'float32',
|
|
||||||
'Volume': 'float32'
|
|
||||||
}
|
|
||||||
|
|
||||||
# Read data with original capitalized column names
|
def create_metadata_lines(config: dict, data_df, result_processor: ResultProcessor) -> list:
|
||||||
data = pd.read_csv(file_path, dtype=dtypes)
|
"""Create metadata lines for results file"""
|
||||||
|
start_date = config['start_date']
|
||||||
|
stop_date = config['stop_date']
|
||||||
|
initial_usd = config['initial_usd']
|
||||||
|
|
||||||
# Convert timestamp to datetime
|
# Get price information
|
||||||
data['Timestamp'] = pd.to_datetime(data['Timestamp'], unit='s')
|
start_time, start_price = result_processor.get_price_info(data_df, start_date)
|
||||||
|
stop_time, stop_price = result_processor.get_price_info(data_df, stop_date)
|
||||||
|
|
||||||
# Filter by date range
|
metadata_lines = [
|
||||||
data = data[(data['Timestamp'] >= start_date) & (data['Timestamp'] <= stop_date)]
|
f"Start date\t{start_date}\tPrice\t{start_price or 'N/A'}",
|
||||||
|
f"Stop date\t{stop_date}\tPrice\t{stop_price or 'N/A'}",
|
||||||
# Now convert column names to lowercase
|
f"Initial USD\t{initial_usd}"
|
||||||
data.columns = data.columns.str.lower()
|
|
||||||
|
|
||||||
return data.set_index('timestamp')
|
|
||||||
|
|
||||||
def process_timeframe_data(min1_df, df, stop_loss_pcts, rule_name, initial_usd, debug=False):
|
|
||||||
"""Process the entire timeframe with all stop loss values (no monthly split)"""
|
|
||||||
df = df.copy().reset_index(drop=True)
|
|
||||||
trend_detector = TrendDetectorSimple(df, verbose=False)
|
|
||||||
|
|
||||||
results_rows = []
|
|
||||||
trade_rows = []
|
|
||||||
for stop_loss_pct in stop_loss_pcts:
|
|
||||||
results = trend_detector.backtest_meta_supertrend(
|
|
||||||
min1_df,
|
|
||||||
initial_usd=initial_usd,
|
|
||||||
stop_loss_pct=stop_loss_pct,
|
|
||||||
debug=debug
|
|
||||||
)
|
|
||||||
n_trades = results["n_trades"]
|
|
||||||
trades = results.get('trades', [])
|
|
||||||
n_winning_trades = sum(1 for trade in trades if trade['profit_pct'] > 0)
|
|
||||||
total_profit = sum(trade['profit_pct'] for trade in trades)
|
|
||||||
total_loss = sum(-trade['profit_pct'] for trade in trades if trade['profit_pct'] < 0)
|
|
||||||
win_rate = n_winning_trades / n_trades if n_trades > 0 else 0
|
|
||||||
avg_trade = total_profit / n_trades if n_trades > 0 else 0
|
|
||||||
profit_ratio = total_profit / total_loss if total_loss > 0 else float('inf')
|
|
||||||
cumulative_profit = 0
|
|
||||||
max_drawdown = 0
|
|
||||||
peak = 0
|
|
||||||
for trade in trades:
|
|
||||||
cumulative_profit += trade['profit_pct']
|
|
||||||
if cumulative_profit > peak:
|
|
||||||
peak = cumulative_profit
|
|
||||||
drawdown = peak - cumulative_profit
|
|
||||||
if drawdown > max_drawdown:
|
|
||||||
max_drawdown = drawdown
|
|
||||||
final_usd = initial_usd
|
|
||||||
for trade in trades:
|
|
||||||
final_usd *= (1 + trade['profit_pct'])
|
|
||||||
row = {
|
|
||||||
"timeframe": rule_name,
|
|
||||||
"stop_loss_pct": stop_loss_pct,
|
|
||||||
"n_trades": n_trades,
|
|
||||||
"n_stop_loss": sum(1 for trade in trades if 'type' in trade and trade['type'] == 'STOP'),
|
|
||||||
"win_rate": win_rate,
|
|
||||||
"max_drawdown": max_drawdown,
|
|
||||||
"avg_trade": avg_trade,
|
|
||||||
"profit_ratio": profit_ratio,
|
|
||||||
"initial_usd": initial_usd,
|
|
||||||
"final_usd": final_usd,
|
|
||||||
}
|
|
||||||
results_rows.append(row)
|
|
||||||
for trade in trades:
|
|
||||||
trade_rows.append({
|
|
||||||
"timeframe": rule_name,
|
|
||||||
"stop_loss_pct": stop_loss_pct,
|
|
||||||
"entry_time": trade.get("entry_time"),
|
|
||||||
"exit_time": trade.get("exit_time"),
|
|
||||||
"entry_price": trade.get("entry"),
|
|
||||||
"exit_price": trade.get("exit"),
|
|
||||||
"profit_pct": trade.get("profit_pct"),
|
|
||||||
"type": trade.get("type", ""),
|
|
||||||
})
|
|
||||||
logging.info(f"Timeframe: {rule_name}, Stop Loss: {stop_loss_pct}, Trades: {n_trades}")
|
|
||||||
if debug:
|
|
||||||
for trade in trades:
|
|
||||||
if trade['type'] == 'STOP':
|
|
||||||
print(trade)
|
|
||||||
for trade in trades:
|
|
||||||
if trade['profit_pct'] < -0.09: # or whatever is close to -0.10
|
|
||||||
print("Large loss trade:", trade)
|
|
||||||
return results_rows, trade_rows
|
|
||||||
|
|
||||||
def process_timeframe(timeframe_info, debug=False):
|
|
||||||
"""Process an entire timeframe (no monthly split)"""
|
|
||||||
rule, data_1min, stop_loss_pcts, initial_usd = timeframe_info
|
|
||||||
if rule == "1T":
|
|
||||||
df = data_1min.copy()
|
|
||||||
else:
|
|
||||||
df = data_1min.resample(rule).agg({
|
|
||||||
'open': 'first',
|
|
||||||
'high': 'max',
|
|
||||||
'low': 'min',
|
|
||||||
'close': 'last',
|
|
||||||
'volume': 'sum'
|
|
||||||
}).dropna()
|
|
||||||
df = df.reset_index()
|
|
||||||
|
|
||||||
# --- Add this block to check alignment ---
|
|
||||||
print("1-min data range:", data_1min.index.min(), "to", data_1min.index.max())
|
|
||||||
print(f"{rule} data range:", df['timestamp'].min(), "to", df['timestamp'].max())
|
|
||||||
# -----------------------------------------
|
|
||||||
|
|
||||||
results_rows, all_trade_rows = process_timeframe_data(data_1min, df, stop_loss_pcts, rule, initial_usd, debug=debug)
|
|
||||||
return results_rows, all_trade_rows
|
|
||||||
|
|
||||||
def write_results_chunk(filename, fieldnames, rows, write_header=False):
|
|
||||||
"""Write a chunk of results to a CSV file"""
|
|
||||||
mode = 'w' if write_header else 'a'
|
|
||||||
|
|
||||||
with open(filename, mode, newline="") as csvfile:
|
|
||||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
|
||||||
if write_header:
|
|
||||||
csvfile.write(f"# initial_usd: {initial_usd}\n")
|
|
||||||
writer.writeheader()
|
|
||||||
|
|
||||||
for row in rows:
|
|
||||||
# Only keep keys that are in fieldnames
|
|
||||||
filtered_row = {k: v for k, v in row.items() if k in fieldnames}
|
|
||||||
writer.writerow(filtered_row)
|
|
||||||
|
|
||||||
def aggregate_results(all_rows):
|
|
||||||
"""Aggregate results per stop_loss_pct and per rule (timeframe)"""
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
grouped = defaultdict(list)
|
|
||||||
for row in all_rows:
|
|
||||||
key = (row['timeframe'], row['stop_loss_pct'])
|
|
||||||
grouped[key].append(row)
|
|
||||||
|
|
||||||
summary_rows = []
|
|
||||||
for (rule, stop_loss_pct), rows in grouped.items():
|
|
||||||
n_months = len(rows)
|
|
||||||
total_trades = sum(r['n_trades'] for r in rows)
|
|
||||||
total_stop_loss = sum(r['n_stop_loss'] for r in rows)
|
|
||||||
avg_win_rate = np.mean([r['win_rate'] for r in rows])
|
|
||||||
avg_max_drawdown = np.mean([r['max_drawdown'] for r in rows])
|
|
||||||
avg_avg_trade = np.mean([r['avg_trade'] for r in rows])
|
|
||||||
avg_profit_ratio = np.mean([r['profit_ratio'] for r in rows])
|
|
||||||
|
|
||||||
# Calculate final USD
|
|
||||||
final_usd = np.mean([r.get('final_usd', initial_usd) for r in rows])
|
|
||||||
|
|
||||||
summary_rows.append({
|
|
||||||
"timeframe": rule,
|
|
||||||
"stop_loss_pct": stop_loss_pct,
|
|
||||||
"n_trades": total_trades,
|
|
||||||
"n_stop_loss": total_stop_loss,
|
|
||||||
"win_rate": avg_win_rate,
|
|
||||||
"max_drawdown": avg_max_drawdown,
|
|
||||||
"avg_trade": avg_avg_trade,
|
|
||||||
"profit_ratio": avg_profit_ratio,
|
|
||||||
"initial_usd": initial_usd,
|
|
||||||
"final_usd": final_usd,
|
|
||||||
})
|
|
||||||
return summary_rows
|
|
||||||
|
|
||||||
def write_results_per_combination(results_rows, trade_rows, timestamp):
|
|
||||||
results_dir = "results"
|
|
||||||
os.makedirs(results_dir, exist_ok=True)
|
|
||||||
for row in results_rows:
|
|
||||||
timeframe = row["timeframe"]
|
|
||||||
stop_loss_pct = row["stop_loss_pct"]
|
|
||||||
filename = os.path.join(
|
|
||||||
results_dir,
|
|
||||||
f"{timestamp}_backtest_{timeframe}_{stop_loss_pct}.csv"
|
|
||||||
)
|
|
||||||
fieldnames = ["timeframe", "stop_loss_pct", "n_trades", "n_stop_loss", "win_rate", "max_drawdown", "avg_trade", "profit_ratio", "initial_usd", "final_usd"]
|
|
||||||
write_results_chunk(filename, fieldnames, [row], write_header=not os.path.exists(filename))
|
|
||||||
for trade in trade_rows:
|
|
||||||
timeframe = trade["timeframe"]
|
|
||||||
stop_loss_pct = trade["stop_loss_pct"]
|
|
||||||
trades_filename = os.path.join(
|
|
||||||
results_dir,
|
|
||||||
f"{timestamp}_trades_{timeframe}_{stop_loss_pct}.csv"
|
|
||||||
)
|
|
||||||
trades_fieldnames = [
|
|
||||||
"timeframe", "stop_loss_pct", "entry_time", "exit_time",
|
|
||||||
"entry_price", "exit_price", "profit_pct", "type"
|
|
||||||
]
|
]
|
||||||
write_results_chunk(trades_filename, trades_fieldnames, [trade], write_header=not os.path.exists(trades_filename))
|
|
||||||
|
return metadata_lines
def main():
    """Main execution function"""
    logger = setup_logging()

    try:
        # Parse command line arguments
        parser = argparse.ArgumentParser(description="Run backtest with config file.")
        parser.add_argument("config", type=str, nargs="?", help="Path to config JSON file.")
        args = parser.parse_args()

        # Initialize configuration manager
        config_manager = ConfigManager(logging_instance=logger)

        # Load configuration
        logger.info("Loading configuration...")
        config = config_manager.load_config(args.config)

        # Initialize components
        logger.info("Initializing components...")
        storage = Storage(
            data_dir=config['data_dir'],
            results_dir=config['results_dir'],
            logging=logger
        )
        system_utils = SystemUtils(logging=logger)
        result_processor = ResultProcessor(storage, logging_instance=logger)

        # OPTIMIZATION: Disable progress for parallel execution to improve performance
        show_progress = config.get('show_progress', True)
        debug_mode = config.get('debug', 0) == 1

        # Only show progress in debug (sequential) mode
        if not debug_mode:
            show_progress = False
            logger.info("Progress tracking disabled for parallel execution (performance optimization)")

        runner = BacktestRunner(
            storage,
            system_utils,
            result_processor,
            logging_instance=logger,
            show_progress=show_progress
        )

        # Validate inputs
        logger.info("Validating inputs...")
        runner.validate_inputs(
            config['timeframes'],
            config['stop_loss_pcts'],
            config['initial_usd']
        )

        # Load data
        logger.info("Loading market data...")
        # data_filename = 'btcusd_1-min_data.csv'
        data_filename = 'btcusd_1-min_data_with_price_predictions.csv'
        data_1min = runner.load_data(
            data_filename,
            config['start_date'],
            config['stop_date']
        )

        # Run backtests
        logger.info("Starting backtest execution...")
        all_results, all_trades = runner.run_backtests(
            data_1min,
            config['timeframes'],
            config['stop_loss_pcts'],
            config['initial_usd'],
            debug=debug_mode
        )

        # Process and save results
        logger.info("Processing and saving results...")
        timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M")

        # OPTIMIZATION: Save trade files in batch after parallel execution
        if all_trades and not debug_mode:
            logger.info("Saving trade files in batch...")
            result_processor.save_all_trade_files(all_trades)

        # Create metadata
        metadata_lines = create_metadata_lines(config, data_1min, result_processor)

        # Save aggregated results
        result_file = result_processor.save_backtest_results(
            all_results,
            metadata_lines,
            timestamp
        )

        logger.info(f"Backtest completed successfully. Results saved to {result_file}")
        logger.info(f"Processed {len(all_results)} result combinations")
        logger.info(f"Generated {len(all_trades)} total trades")

    except KeyboardInterrupt:
        logger.warning("Backtest interrupted by user")
        sys.exit(130)  # Standard exit code for Ctrl+C

    except FileNotFoundError as e:
        logger.error(f"File not found: {e}")
        sys.exit(1)

    except ValueError as e:
        logger.error(f"Invalid configuration or data: {e}")
        sys.exit(1)

    except RuntimeError as e:
        logger.error(f"Runtime error during backtest: {e}")
        sys.exit(1)

    except Exception as e:
        logger.error(f"Unexpected error: {e}", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    main()
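main() reads its settings through ConfigManager.load_config; the exact schema lives elsewhere, but a minimal config covering the keys accessed above might look like the sketch below. The file name and field values are assumptions, mirroring the inline defaults that this change removes.

# Hypothetical config.json covering the keys main() reads; values are examples.
import json

config = {
    "data_dir": "./data",
    "results_dir": "./results",
    "start_date": "2020-01-01",
    "stop_date": "2025-05-15",
    "timeframes": ["15min", "1h", "6h", "1D"],
    "stop_loss_pcts": [0.01, 0.02, 0.03, 0.05, 0.07, 0.10],
    "initial_usd": 10000,
    "show_progress": True,
    "debug": 0,
}

with open("config.json", "w") as f:
    json.dump(config, f, indent=2)
# then: python main.py config.json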
# Previous inline entrypoint removed in this change:
if __name__ == "__main__":
    # Configuration
    start_date = '2020-01-01'
    stop_date = '2025-05-15'
    initial_usd = 10000
    debug = False  # Set to True to enable debug prints

    # --- NEW: Prepare results folder and timestamp ---
    results_dir = "results"
    os.makedirs(results_dir, exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M")
    # --- END NEW ---

    # Replace the dictionary with a list of timeframe names
    timeframes = ["15min", "1h", "6h", "1D"]
    # timeframes = ["6h"]

    stop_loss_pcts = [0.01, 0.02, 0.03, 0.05, 0.07, 0.10]
    # stop_loss_pcts = [0.01]

    # Load data once
    data_1min = load_data('./data/btcusd_1-min_data.csv', start_date, stop_date)
    logging.info(f"1min rows: {len(data_1min)}")

    # Prepare tasks
    tasks = [
        (name, data_1min, stop_loss_pcts, initial_usd)
        for name in timeframes
    ]

    # Determine optimal worker count
    workers = get_optimal_workers()
    logging.info(f"Using {workers} workers for processing")

    # Process tasks with optimized concurrency
    with concurrent.futures.ProcessPoolExecutor(max_workers=workers) as executor:
        futures = {executor.submit(process_timeframe, task, debug): task[1] for task in tasks}
        all_results_rows = []
        for future in concurrent.futures.as_completed(futures):
            # try:
            results, trades = future.result()
            if results or trades:
                all_results_rows.extend(results)
                write_results_per_combination(results, trades, timestamp)
            # except Exception as exc:
            #     logging.error(f"generated an exception: {exc}")

    # Write all results to a single CSV file
    combined_filename = os.path.join(results_dir, f"{timestamp}_backtest_combined.csv")
    combined_fieldnames = [
        "timeframe", "stop_loss_pct", "n_trades", "n_stop_loss", "win_rate",
        "max_drawdown", "avg_trade", "profit_ratio", "final_usd"
    ]

    def format_row(row):
        # Format percentages and floats as in your example
        return {
            "timeframe": row["timeframe"],
            "stop_loss_pct": f"{row['stop_loss_pct']*100:.2f}%",
            "n_trades": row["n_trades"],
            "n_stop_loss": row["n_stop_loss"],
            "win_rate": f"{row['win_rate']*100:.2f}%",
            "max_drawdown": f"{row['max_drawdown']*100:.2f}%",
            "avg_trade": f"{row['avg_trade']*100:.2f}%",
            "profit_ratio": f"{row['profit_ratio']*100:.2f}%",
            "final_usd": f"{row['final_usd']:.2f}",
        }

    with open(combined_filename, "w", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=combined_fieldnames, delimiter='\t')
        writer.writeheader()
        for row in all_results_rows:
            writer.writerow(format_row(row))

    logging.info(f"Combined results written to {combined_filename}")
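With the try/except commented out above, any exception inside a worker propagates from future.result() and aborts the remaining combinations. A minimal sketch of the per-future handling the commented lines gesture at; collect_results is a hypothetical helper, not part of the codebase.

# Sketch: per-future error handling so one failed timeframe is logged
# and skipped instead of aborting the whole pool.
import concurrent.futures
import logging

def collect_results(futures):
    all_results_rows = []
    for future in concurrent.futures.as_completed(futures):
        try:
            results, trades = future.result()
        except Exception as exc:
            logging.error(f"Worker generated an exception: {exc}")
            continue
        all_results_rows.extend(results)
    return all_results_rows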
main_debug.py (197 lines, deleted)
@@ -1,197 +0,0 @@
import pandas as pd
import numpy as np
from trend_detector_simple import TrendDetectorSimple
import os
import datetime
import csv


def load_data(file_path, start_date, stop_date):
    """Load and filter data by date range."""
    data = pd.read_csv(file_path)
    data['Timestamp'] = pd.to_datetime(data['Timestamp'], unit='s')
    data = data[(data['Timestamp'] >= start_date) & (data['Timestamp'] <= stop_date)]
    data.columns = data.columns.str.lower()
    return data.set_index('timestamp')


def process_month_timeframe(min1_df, month_df, stop_loss_pcts, rule_name, initial_usd):
    """Process a single month for a given timeframe with all stop loss values."""
    month_df = month_df.copy().reset_index(drop=True)
    trend_detector = TrendDetectorSimple(month_df, verbose=False)
    analysis_results = trend_detector.detect_trends()
    signal_df = analysis_results.get('signal_df')

    results_rows = []
    trade_rows = []
    for stop_loss_pct in stop_loss_pcts:
        results = trend_detector.backtest_meta_supertrend(
            min1_df,
            initial_usd=initial_usd,
            stop_loss_pct=stop_loss_pct
        )
        trades = results.get('trades', [])
        n_trades = results["n_trades"]
        n_winning_trades = sum(1 for trade in trades if trade['profit_pct'] > 0)
        total_profit = sum(trade['profit_pct'] for trade in trades)
        total_loss = sum(-trade['profit_pct'] for trade in trades if trade['profit_pct'] < 0)
        win_rate = n_winning_trades / n_trades if n_trades > 0 else 0
        avg_trade = total_profit / n_trades if n_trades > 0 else 0
        profit_ratio = total_profit / total_loss if total_loss > 0 else float('inf')

        # Max drawdown
        cumulative_profit = 0
        max_drawdown = 0
        peak = 0
        for trade in trades:
            cumulative_profit += trade['profit_pct']
            if cumulative_profit > peak:
                peak = cumulative_profit
            drawdown = peak - cumulative_profit
            if drawdown > max_drawdown:
                max_drawdown = drawdown

        # Final USD
        final_usd = initial_usd
        for trade in trades:
            final_usd *= (1 + trade['profit_pct'])

        row = {
            "timeframe": rule_name,
            "month": str(month_df['timestamp'].iloc[0].to_period('M')),
            "stop_loss_pct": stop_loss_pct,
            "n_trades": n_trades,
            "n_stop_loss": sum(1 for trade in trades if 'type' in trade and trade['type'] == 'STOP'),
            "win_rate": win_rate,
            "max_drawdown": max_drawdown,
            "avg_trade": avg_trade,
            "profit_ratio": profit_ratio,
            "initial_usd": initial_usd,
            "final_usd": final_usd,
        }
        results_rows.append(row)

        for trade in trades:
            trade_rows.append({
                "timeframe": rule_name,
                "month": str(month_df['timestamp'].iloc[0].to_period('M')),
                "stop_loss_pct": stop_loss_pct,
                "entry_time": trade.get("entry_time"),
                "exit_time": trade.get("exit_time"),
                "entry_price": trade.get("entry_price"),
                "exit_price": trade.get("exit_price"),
                "profit_pct": trade.get("profit_pct"),
                "type": trade.get("type", ""),
            })

    return results_rows, trade_rows


def process_timeframe(rule, data_1min, stop_loss_pcts, initial_usd):
    """Process an entire timeframe sequentially."""
    if rule == "1T":
        df = data_1min.copy()
    else:
        df = data_1min.resample(rule).agg({
            'open': 'first',
            'high': 'max',
            'low': 'min',
            'close': 'last',
            'volume': 'sum'
        }).dropna()

    df = df.reset_index()
    df['month'] = df['timestamp'].dt.to_period('M')
    results_rows = []
    all_trade_rows = []

    for month, month_df in df.groupby('month'):
        if len(month_df) < 10:
            continue
        month_results, month_trades = process_month_timeframe(data_1min, month_df, stop_loss_pcts, rule, initial_usd)
        results_rows.extend(month_results)
        all_trade_rows.extend(month_trades)

    return results_rows, all_trade_rows


def aggregate_results(all_rows, initial_usd):
    """Aggregate results per stop_loss_pct and per rule (timeframe)."""
    from collections import defaultdict
    grouped = defaultdict(list)
    for row in all_rows:
        key = (row['timeframe'], row['stop_loss_pct'])
        grouped[key].append(row)

    summary_rows = []
    for (rule, stop_loss_pct), rows in grouped.items():
        n_months = len(rows)
        total_trades = sum(r['n_trades'] for r in rows)
        total_stop_loss = sum(r['n_stop_loss'] for r in rows)
        avg_win_rate = np.mean([r['win_rate'] for r in rows])
        avg_max_drawdown = np.mean([r['max_drawdown'] for r in rows])
        avg_avg_trade = np.mean([r['avg_trade'] for r in rows])
        avg_profit_ratio = np.mean([r['profit_ratio'] for r in rows])
        final_usd = np.mean([r.get('final_usd', initial_usd) for r in rows])

        summary_rows.append({
            "timeframe": rule,
            "stop_loss_pct": stop_loss_pct,
            "n_trades": total_trades,
            "n_stop_loss": total_stop_loss,
            "win_rate": avg_win_rate,
            "max_drawdown": avg_max_drawdown,
            "avg_trade": avg_avg_trade,
            "profit_ratio": avg_profit_ratio,
            "initial_usd": initial_usd,
            "final_usd": final_usd,
        })
    return summary_rows


def write_results(filename, fieldnames, rows):
    """Write results to a CSV file."""
    with open(filename, 'w', newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for row in rows:
            writer.writerow(row)


if __name__ == "__main__":
    # Config
    start_date = '2020-01-01'
    stop_date = '2025-05-15'
    initial_usd = 10000

    results_dir = "results"
    os.makedirs(results_dir, exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M")

    timeframes = ["6h", "1D"]
    stop_loss_pcts = [0.01, 0.02, 0.03, 0.05, 0.07, 0.10]

    data_1min = load_data('./data/btcusd_1-min_data.csv', start_date, stop_date)
    print(f"1min rows: {len(data_1min)}")

    filename = os.path.join(
        results_dir,
        f"{timestamp}_backtest_results_{start_date}_{stop_date}_multi_timeframe_stoploss.csv"
    )
    fieldnames = ["timeframe", "stop_loss_pct", "n_trades", "n_stop_loss", "win_rate", "max_drawdown", "avg_trade", "profit_ratio", "initial_usd", "final_usd"]

    all_results = []
    all_trades = []

    for name in timeframes:
        print(f"Processing timeframe: {name}")
        results, trades = process_timeframe(name, data_1min, stop_loss_pcts, initial_usd)
        all_results.extend(results)
        all_trades.extend(trades)

    summary_rows = aggregate_results(all_results, initial_usd)
    # write_results(filename, fieldnames, summary_rows)

    trades_filename = os.path.join(
        results_dir,
        f"{timestamp}_backtest_trades.csv"
    )
    trades_fieldnames = [
        "timeframe", "month", "stop_loss_pct", "entry_time", "exit_time",
        "entry_price", "exit_price", "profit_pct", "type"
    ]
    # write_results(trades_filename, trades_fieldnames, all_trades)
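The drawdown loop in process_month_timeframe tracks the running peak of cumulative profit and records the deepest dip below it. A small worked example with toy numbers, not real trades:

# Toy illustration of the max-drawdown loop above.
trades = [{"profit_pct": p} for p in (0.05, -0.03, 0.02, -0.06)]

cumulative_profit = 0
peak = 0
max_drawdown = 0
for trade in trades:
    cumulative_profit += trade["profit_pct"]
    peak = max(peak, cumulative_profit)
    max_drawdown = max(max_drawdown, peak - cumulative_profit)

print(round(max_drawdown, 2))  # 0.07: peak of 0.05, trough of -0.02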
pyproject.toml (19 lines, new file)
@@ -0,0 +1,19 @@
[project]
name = "cycles"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "dash>=3.0.4",
    "gspread>=6.2.1",
    "matplotlib>=3.10.3",
    "numba>=0.61.2",
    "pandas>=2.2.3",
    "psutil>=7.0.0",
    "scikit-learn>=1.6.1",
    "scipy>=1.15.3",
    "seaborn>=0.13.2",
    "ta>=0.11.0",
    "xgboost>=3.0.2",
]
requirements.txt (binary file not shown)
result_processor.py (446 lines, new file)
@@ -0,0 +1,446 @@
import pandas as pd
import numpy as np
import os
import csv
import logging
from typing import List, Dict, Any, Optional, Tuple
from collections import defaultdict

from cycles.utils.storage import Storage


class ResultProcessor:
    """Handles processing, aggregation, and saving of backtest results"""

    def __init__(self, storage: Storage, logging_instance: Optional[logging.Logger] = None):
        """
        Initialize result processor

        Args:
            storage: Storage instance for file operations
            logging_instance: Optional logging instance
        """
        self.storage = storage
        self.logging = logging_instance

    def process_timeframe_results(
        self,
        min1_df: pd.DataFrame,
        df: pd.DataFrame,
        stop_loss_pcts: List[float],
        timeframe_name: str,
        initial_usd: float,
        progress_callback=None
    ) -> Tuple[List[Dict], List[Dict]]:
        """
        Process results for a single timeframe with multiple stop loss values

        Args:
            min1_df: 1-minute data DataFrame
            df: Resampled timeframe DataFrame
            stop_loss_pcts: List of stop loss percentages to test
            timeframe_name: Name of the timeframe (e.g., '1D', '6h')
            initial_usd: Initial USD amount
            progress_callback: Optional progress callback function

        Returns:
            Tuple of (results_rows, trade_rows)
        """
        from cycles.backtest import Backtest

        df = df.copy().reset_index(drop=True)
        results_rows = []
        trade_rows = []

        for stop_loss_pct in stop_loss_pcts:
            try:
                results = Backtest.run(
                    min1_df,
                    df,
                    initial_usd=initial_usd,
                    stop_loss_pct=stop_loss_pct,
                    progress_callback=progress_callback,
                    verbose=False  # Default to False for production runs
                )

                # Calculate metrics
                metrics = self._calculate_metrics(results, initial_usd, stop_loss_pct, timeframe_name)
                results_rows.append(metrics)

                # Process trades
                if 'trades' not in results:
                    raise ValueError(f"Backtest results missing 'trades' field for {timeframe_name} with {stop_loss_pct} stop loss")
                trades = self._process_trades(results['trades'], timeframe_name, stop_loss_pct)
                trade_rows.extend(trades)

                if self.logging:
                    self.logging.info(f"Timeframe: {timeframe_name}, Stop Loss: {stop_loss_pct}, Trades: {results['n_trades']}")

            except Exception as e:
                error_msg = f"Error processing {timeframe_name} with stop loss {stop_loss_pct}: {e}"
                if self.logging:
                    self.logging.error(error_msg)
                raise RuntimeError(error_msg) from e

        return results_rows, trade_rows

    def _calculate_metrics(
        self,
        results: Dict[str, Any],
        initial_usd: float,
        stop_loss_pct: float,
        timeframe_name: str
    ) -> Dict[str, Any]:
        """Calculate performance metrics from backtest results"""
        if 'trades' not in results:
            raise ValueError(f"Backtest results missing 'trades' field for {timeframe_name} with {stop_loss_pct} stop loss")
        trades = results['trades']
        n_trades = results["n_trades"]

        # Validate that all required fields are present
        required_fields = ['final_usd', 'max_drawdown', 'total_fees_usd', 'n_trades', 'n_stop_loss', 'win_rate', 'avg_trade']
        missing_fields = [field for field in required_fields if field not in results]
        if missing_fields:
            raise ValueError(f"Backtest results missing required fields: {missing_fields}")

        # Calculate win metrics - validate trade fields
        winning_trades = []
        for t in trades:
            if 'exit' not in t:
                raise ValueError(f"Trade missing 'exit' field: {t}")
            if 'entry' not in t:
                raise ValueError(f"Trade missing 'entry' field: {t}")
            if t['exit'] is not None and t['exit'] > t['entry']:
                winning_trades.append(t)
        n_winning_trades = len(winning_trades)
        win_rate = n_winning_trades / n_trades if n_trades > 0 else 0

        # Calculate profit metrics
        total_profit = sum(trade['profit_pct'] for trade in trades if trade['profit_pct'] > 0)
        total_loss = abs(sum(trade['profit_pct'] for trade in trades if trade['profit_pct'] < 0))
        avg_trade = sum(trade['profit_pct'] for trade in trades) / n_trades if n_trades > 0 else 0
        profit_ratio = total_profit / total_loss if total_loss > 0 else (float('inf') if total_profit > 0 else 0)

        # Get values directly from backtest results (no defaults)
        max_drawdown = results['max_drawdown']
        final_usd = results['final_usd']
        total_fees_usd = results['total_fees_usd']
        n_stop_loss = results['n_stop_loss']  # Get stop loss count directly from backtest

        # Validate no None values
        if max_drawdown is None:
            raise ValueError(f"max_drawdown is None for {timeframe_name} with {stop_loss_pct} stop loss")
        if final_usd is None:
            raise ValueError(f"final_usd is None for {timeframe_name} with {stop_loss_pct} stop loss")
        if total_fees_usd is None:
            raise ValueError(f"total_fees_usd is None for {timeframe_name} with {stop_loss_pct} stop loss")
        if n_stop_loss is None:
            raise ValueError(f"n_stop_loss is None for {timeframe_name} with {stop_loss_pct} stop loss")

        return {
            "timeframe": timeframe_name,
            "stop_loss_pct": stop_loss_pct,
            "n_trades": n_trades,
            "n_stop_loss": n_stop_loss,
            "win_rate": win_rate,
            "max_drawdown": max_drawdown,
            "avg_trade": avg_trade,
            "total_profit": total_profit,
            "total_loss": total_loss,
            "profit_ratio": profit_ratio,
            "initial_usd": initial_usd,
            "final_usd": final_usd,
            "total_fees_usd": total_fees_usd,
        }

    def _calculate_max_drawdown(self, trades: List[Dict]) -> float:
        """Calculate maximum drawdown from trade sequence"""
        cumulative_profit = 0
        max_drawdown = 0
        peak = 0

        for trade in trades:
            cumulative_profit += trade['profit_pct']
            if cumulative_profit > peak:
                peak = cumulative_profit
            drawdown = peak - cumulative_profit
            if drawdown > max_drawdown:
                max_drawdown = drawdown

        return max_drawdown

    def _process_trades(
        self,
        trades: List[Dict],
        timeframe_name: str,
        stop_loss_pct: float
    ) -> List[Dict]:
        """Process individual trades with metadata"""
        processed_trades = []

        for trade in trades:
            # Validate all required trade fields
            required_fields = ["entry_time", "exit_time", "entry", "exit", "profit_pct", "type", "fee_usd"]
            missing_fields = [field for field in required_fields if field not in trade]
            if missing_fields:
                raise ValueError(f"Trade missing required fields: {missing_fields} in trade: {trade}")

            processed_trade = {
                "timeframe": timeframe_name,
                "stop_loss_pct": stop_loss_pct,
                "entry_time": trade["entry_time"],
                "exit_time": trade["exit_time"],
                "entry_price": trade["entry"],
                "exit_price": trade["exit"],
                "profit_pct": trade["profit_pct"],
                "type": trade["type"],
                "fee_usd": trade["fee_usd"],
            }
            processed_trades.append(processed_trade)

        return processed_trades

    def _debug_output(self, results: Dict[str, Any]) -> None:
        """Output debug information for backtest results"""
        if 'trades' not in results:
            raise ValueError("Backtest results missing 'trades' field for debug output")
        trades = results['trades']

        # Print stop loss trades
        stop_loss_trades = []
        for t in trades:
            if 'type' not in t:
                raise ValueError(f"Trade missing 'type' field: {t}")
            if t['type'] == 'STOP':
                stop_loss_trades.append(t)

        if stop_loss_trades:
            print("Stop Loss Trades:")
            for trade in stop_loss_trades:
                print(trade)

        # Print large loss trades
        large_loss_trades = []
        for t in trades:
            if 'profit_pct' not in t:
                raise ValueError(f"Trade missing 'profit_pct' field: {t}")
            if t['profit_pct'] < -0.09:
                large_loss_trades.append(t)

        if large_loss_trades:
            print("Large Loss Trades:")
            for trade in large_loss_trades:
                print("Large loss trade:", trade)

    def aggregate_results(self, all_results: List[Dict]) -> List[Dict]:
        """
        Aggregate results per stop_loss_pct and timeframe

        Args:
            all_results: List of result dictionaries from all timeframes

        Returns:
            List of aggregated summary rows
        """
        grouped = defaultdict(list)
        for row in all_results:
            key = (row['timeframe'], row['stop_loss_pct'])
            grouped[key].append(row)

        summary_rows = []
        for (timeframe, stop_loss_pct), rows in grouped.items():
            summary = self._aggregate_group(rows, timeframe, stop_loss_pct)
            summary_rows.append(summary)

        return summary_rows

    def _aggregate_group(self, rows: List[Dict], timeframe: str, stop_loss_pct: float) -> Dict:
        """Aggregate a group of rows with the same timeframe and stop loss"""
        if not rows:
            raise ValueError(f"No rows to aggregate for {timeframe} with {stop_loss_pct} stop loss")

        # Validate all rows have required fields
        required_fields = ['n_trades', 'n_stop_loss', 'win_rate', 'max_drawdown', 'avg_trade', 'profit_ratio', 'final_usd', 'total_fees_usd', 'initial_usd']
        for i, row in enumerate(rows):
            missing_fields = [field for field in required_fields if field not in row]
            if missing_fields:
                raise ValueError(f"Row {i} missing required fields: {missing_fields}")

        total_trades = sum(r['n_trades'] for r in rows)
        total_stop_loss = sum(r['n_stop_loss'] for r in rows)

        # Calculate averages (no defaults, expect all values to be present)
        avg_win_rate = np.mean([r['win_rate'] for r in rows])
        avg_max_drawdown = np.mean([r['max_drawdown'] for r in rows])
        avg_avg_trade = np.mean([r['avg_trade'] for r in rows])

        # Handle infinite profit ratios properly
        finite_profit_ratios = [r['profit_ratio'] for r in rows if not np.isinf(r['profit_ratio'])]
        avg_profit_ratio = np.mean(finite_profit_ratios) if finite_profit_ratios else 0

        # Calculate final USD and fees (no defaults)
        final_usd = np.mean([r['final_usd'] for r in rows])
        total_fees_usd = np.mean([r['total_fees_usd'] for r in rows])
        initial_usd = rows[0]['initial_usd']

        return {
            "timeframe": timeframe,
            "stop_loss_pct": stop_loss_pct,
            "n_trades": total_trades,
            "n_stop_loss": total_stop_loss,
            "win_rate": avg_win_rate,
            "max_drawdown": avg_max_drawdown,
            "avg_trade": avg_avg_trade,
            "profit_ratio": avg_profit_ratio,
            "initial_usd": initial_usd,
            "final_usd": final_usd,
            "total_fees_usd": total_fees_usd,
        }

    def save_trade_file(self, trades: List[Dict], timeframe: str, stop_loss_pct: float) -> None:
        """
        Save individual trade file with summary header

        Args:
            trades: List of trades for this combination
            timeframe: Timeframe name
            stop_loss_pct: Stop loss percentage
        """
        if not trades:
            return

        try:
            # Generate filename
            sl_percent = int(round(stop_loss_pct * 100))
            trades_filename = os.path.join(self.storage.results_dir, f"trades_{timeframe}_ST{sl_percent}pct.csv")

            # Prepare summary from first trade
            sample_trade = trades[0]
            summary_fields = ["timeframe", "stop_loss_pct", "n_trades", "win_rate"]
            summary_values = [timeframe, stop_loss_pct, len(trades), "calculated_elsewhere"]

            # Write file with header and trades
            trades_fieldnames = ["entry_time", "exit_time", "entry_price", "exit_price", "profit_pct", "type", "fee_usd"]

            with open(trades_filename, "w", newline="") as f:
                # Write summary header
                f.write("\t".join(summary_fields) + "\n")
                f.write("\t".join(str(v) for v in summary_values) + "\n")

                # Write trades
                writer = csv.DictWriter(f, fieldnames=trades_fieldnames)
                writer.writeheader()
                for trade in trades:
                    # Validate all required fields are present
                    missing_fields = [k for k in trades_fieldnames if k not in trade]
                    if missing_fields:
                        raise ValueError(f"Trade missing required fields for CSV: {missing_fields} in trade: {trade}")
                    writer.writerow({k: trade[k] for k in trades_fieldnames})

            if self.logging:
                self.logging.info(f"Trades saved to {trades_filename}")

        except Exception as e:
            error_msg = f"Failed to save trades file for {timeframe}_ST{int(round(stop_loss_pct * 100))}pct: {e}"
            if self.logging:
                self.logging.error(error_msg)
            raise RuntimeError(error_msg) from e

    def save_backtest_results(
        self,
        results: List[Dict],
        metadata_lines: List[str],
        timestamp: str
    ) -> str:
        """
        Save aggregated backtest results to CSV file

        Args:
            results: List of aggregated result dictionaries
            metadata_lines: List of metadata strings
            timestamp: Timestamp for filename

        Returns:
            Path to saved file
        """
        try:
            filename = f"{timestamp}_backtest.csv"
            fieldnames = [
                "timeframe", "stop_loss_pct", "n_trades", "n_stop_loss", "win_rate",
                "max_drawdown", "avg_trade", "profit_ratio", "final_usd", "total_fees_usd"
            ]

            filepath = self.storage.write_backtest_results(filename, fieldnames, results, metadata_lines)

            if self.logging:
                self.logging.info(f"Backtest results saved to {filepath}")

            return filepath

        except Exception as e:
            error_msg = f"Failed to save backtest results: {e}"
            if self.logging:
                self.logging.error(error_msg)
            raise RuntimeError(error_msg) from e

    def get_price_info(self, data_df: pd.DataFrame, date: str) -> Tuple[Optional[str], Optional[float]]:
        """
        Get nearest price information for a given date

        Args:
            data_df: DataFrame with price data
            date: Target date string

        Returns:
            Tuple of (nearest_time, price) or (None, None) if no data
        """
        try:
            if len(data_df) == 0:
                return None, None

            target_ts = pd.to_datetime(date)
            nearest_idx = data_df.index.get_indexer([target_ts], method='nearest')[0]
            nearest_time = data_df.index[nearest_idx]
            price = data_df.iloc[nearest_idx]['close']

            return str(nearest_time), float(price)

        except Exception as e:
            if self.logging:
                self.logging.warning(f"Could not get price info for {date}: {e}")
            return None, None

    def save_all_trade_files(self, all_trades: List[Dict]) -> None:
        """
        Save all trade files in batch after parallel execution completes

        Args:
            all_trades: List of all trades from all tasks
        """
        if not all_trades:
            return

        try:
            # Group trades by timeframe and stop loss
            trade_groups = {}
            for trade in all_trades:
                timeframe = trade.get('timeframe')
                stop_loss_pct = trade.get('stop_loss_pct')
                if timeframe and stop_loss_pct is not None:
                    key = (timeframe, stop_loss_pct)
                    if key not in trade_groups:
                        trade_groups[key] = []
                    trade_groups[key].append(trade)

            # Save each group
            for (timeframe, stop_loss_pct), trades in trade_groups.items():
                self.save_trade_file(trades, timeframe, stop_loss_pct)

            if self.logging:
                self.logging.info(f"Saved {len(trade_groups)} trade files in batch")

        except Exception as e:
            error_msg = f"Failed to save trade files in batch: {e}"
            if self.logging:
                self.logging.error(error_msg)
            raise RuntimeError(error_msg) from e
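aggregate_results only touches the result rows, not storage, so it can be exercised in isolation. The toy rows below are assumptions built just to satisfy the required-field validation in _aggregate_group:

# Toy illustration: grouping behaviour of ResultProcessor.aggregate_results.
# storage is unused by this method, so None is passed for brevity.
processor = ResultProcessor(storage=None)

def row(tf, sl, trades):
    return {
        "timeframe": tf, "stop_loss_pct": sl, "n_trades": trades,
        "n_stop_loss": 1, "win_rate": 0.5, "max_drawdown": 0.1,
        "avg_trade": 0.01, "profit_ratio": 1.5,
        "initial_usd": 10000, "final_usd": 11000, "total_fees_usd": 25,
    }

summary = processor.aggregate_results([row("1h", 0.02, 10), row("1h", 0.02, 14)])
print(summary[0]["n_trades"])  # 24: trade counts are summed per (timeframe, stop_loss_pct)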
test_bbrsi.py (161 lines, new file)
@@ -0,0 +1,161 @@
import logging
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import datetime

from cycles.utils.storage import Storage
from cycles.Analysis.strategies import Strategy

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler("backtest.log"),
        logging.StreamHandler()
    ]
)

config = {
    "start_date": "2025-03-01",
    "stop_date": datetime.datetime.today().strftime('%Y-%m-%d'),
    "data_file": "btcusd_1-min_data.csv"
}

config_strategy = {
    "bb_width": 0.05,
    "bb_period": 20,
    "rsi_period": 14,
    "trending": {
        "rsi_threshold": [30, 70],
        "bb_std_dev_multiplier": 2.5,
    },
    "sideways": {
        "rsi_threshold": [40, 60],
        "bb_std_dev_multiplier": 1.8,
    },
    "strategy_name": "MarketRegimeStrategy",  # CryptoTradingStrategy
    "SqueezeStrategy": True
}

IS_DAY = False

if __name__ == "__main__":

    # Load data
    storage = Storage(logging=logging)
    data = storage.load_data(config["data_file"], config["start_date"], config["stop_date"])

    # Run strategy
    strategy = Strategy(config=config_strategy, logging=logging)
    processed_data = strategy.run(data.copy(), config_strategy["strategy_name"])

    # Get buy and sell signals
    buy_condition = processed_data.get('BuySignal', pd.Series(False, index=processed_data.index)).astype(bool)
    sell_condition = processed_data.get('SellSignal', pd.Series(False, index=processed_data.index)).astype(bool)

    buy_signals = processed_data[buy_condition]
    sell_signals = processed_data[sell_condition]

    # Plot the data with the seaborn library
    if processed_data is not None and not processed_data.empty:
        # Create a figure with three subplots, sharing the x-axis
        fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(16, 8), sharex=True)

        strategy_name = config_strategy["strategy_name"]

        # Plot 1: Close Price and Strategy-Specific Bands/Levels
        sns.lineplot(x=processed_data.index, y='close', data=processed_data, label='Close Price', ax=ax1)

        # Use standardized column names for bands
        if 'UpperBand' in processed_data.columns and 'LowerBand' in processed_data.columns:
            # Instead of lines, shade the area between upper and lower bands
            ax1.fill_between(processed_data.index,
                             processed_data['LowerBand'],
                             processed_data['UpperBand'],
                             alpha=0.1, color='blue', label='Bollinger Bands')
        else:
            logging.warning(f"{strategy_name}: UpperBand or LowerBand not found for plotting.")

        # Add strategy-specific extra indicators if available
        if strategy_name == "CryptoTradingStrategy":
            if 'StopLoss' in processed_data.columns:
                sns.lineplot(x=processed_data.index, y='StopLoss', data=processed_data, label='Stop Loss', ax=ax1, linestyle='--', color='orange')
            if 'TakeProfit' in processed_data.columns:
                sns.lineplot(x=processed_data.index, y='TakeProfit', data=processed_data, label='Take Profit', ax=ax1, linestyle='--', color='purple')

        # Plot Buy/Sell signals on Price chart
        if not buy_signals.empty:
            ax1.scatter(buy_signals.index, buy_signals['close'], color='green', marker='o', s=20, label='Buy Signal', zorder=5)
        if not sell_signals.empty:
            ax1.scatter(sell_signals.index, sell_signals['close'], color='red', marker='o', s=20, label='Sell Signal', zorder=5)
        ax1.set_title(f'Price and Signals ({strategy_name})')
        ax1.set_ylabel('Price')
        ax1.legend()
        ax1.grid(True)

        # Plot 2: RSI and Strategy-Specific Thresholds
        if 'RSI' in processed_data.columns:
            sns.lineplot(x=processed_data.index, y='RSI', data=processed_data, label=f'RSI ({config_strategy.get("rsi_period", 14)})', ax=ax2, color='purple')
            if strategy_name == "MarketRegimeStrategy":
                # Get threshold values
                upper_threshold = config_strategy.get("trending", {}).get("rsi_threshold", [30, 70])[1]
                lower_threshold = config_strategy.get("trending", {}).get("rsi_threshold", [30, 70])[0]

                # Shade overbought area (upper)
                ax2.fill_between(processed_data.index, upper_threshold, 100,
                                 alpha=0.1, color='red', label=f'Overbought (>{upper_threshold})')

                # Shade oversold area (lower)
                ax2.fill_between(processed_data.index, 0, lower_threshold,
                                 alpha=0.1, color='green', label=f'Oversold (<{lower_threshold})')

            elif strategy_name == "CryptoTradingStrategy":
                # Shade overbought area (upper)
                ax2.fill_between(processed_data.index, 65, 100,
                                 alpha=0.1, color='red', label='Overbought (>65)')

                # Shade oversold area (lower)
                ax2.fill_between(processed_data.index, 0, 35,
                                 alpha=0.1, color='green', label='Oversold (<35)')

            # Plot Buy/Sell signals on RSI chart
            if not buy_signals.empty and 'RSI' in buy_signals.columns:
                ax2.scatter(buy_signals.index, buy_signals['RSI'], color='green', marker='o', s=20, label='Buy Signal (RSI)', zorder=5)
            if not sell_signals.empty and 'RSI' in sell_signals.columns:
                ax2.scatter(sell_signals.index, sell_signals['RSI'], color='red', marker='o', s=20, label='Sell Signal (RSI)', zorder=5)
            ax2.set_title('Relative Strength Index (RSI) with Signals')
            ax2.set_ylabel('RSI Value')
            ax2.set_ylim(0, 100)
            ax2.legend()
            ax2.grid(True)
        else:
            logging.info("RSI data not available for plotting.")

        # Plot 3: Strategy-Specific Indicators
        ax3.clear()  # Clear previous plot content if any
        if 'BBWidth' in processed_data.columns:
            sns.lineplot(x=processed_data.index, y='BBWidth', data=processed_data, label='BB Width', ax=ax3)

        if strategy_name == "MarketRegimeStrategy":
            if 'MarketRegime' in processed_data.columns:
                sns.lineplot(x=processed_data.index, y='MarketRegime', data=processed_data, label='Market Regime (Sideways: 1, Trending: 0)', ax=ax3)
            ax3.set_title('Bollinger Bands Width & Market Regime')
            ax3.set_ylabel('Value')
        elif strategy_name == "CryptoTradingStrategy":
            if 'VolumeMA' in processed_data.columns:
                sns.lineplot(x=processed_data.index, y='VolumeMA', data=processed_data, label='Volume MA', ax=ax3)
            if 'volume' in processed_data.columns:
                sns.lineplot(x=processed_data.index, y='volume', data=processed_data, label='Volume', ax=ax3, alpha=0.5)
            ax3.set_title('Volume Analysis')
            ax3.set_ylabel('Volume')

        ax3.legend()
        ax3.grid(True)

        plt.xlabel('Date')
        fig.tight_layout()
        plt.show()
    else:
        logging.info("No data to plot.")
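The two regime blocks in config_strategy carry separate RSI thresholds and Bollinger band multipliers. A strategy consulting them would pick per-regime parameters along these lines; the regime detection itself lives in the Strategy class, so regime_params below is only a hypothetical lookup helper:

# Sketch: selecting regime-specific parameters from config_strategy.
# The regime label itself would come from the strategy's detection logic.
def regime_params(config_strategy, regime):
    block = config_strategy["sideways"] if regime == "sideways" else config_strategy["trending"]
    lower, upper = block["rsi_threshold"]
    return lower, upper, block["bb_std_dev_multiplier"]

print(regime_params(config_strategy, "trending"))  # (30, 70, 2.5)
print(regime_params(config_strategy, "sideways"))  # (40, 60, 1.8)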
trader/cryptocom_trader.py (229 lines, new file)
@@ -0,0 +1,229 @@
import os
import time
import hmac
import hashlib
import base64
import json
import pandas as pd
import threading
from websocket import create_connection, WebSocketTimeoutException


class CryptoComTrader:
    ENV_URLS = {
        "production": {
            "WS_URL": "wss://deriv-stream.crypto.com/v1/market",
            "WS_PRIVATE_URL": "wss://deriv-stream.crypto.com/v1/user"
        },
        "uat": {
            "WS_URL": "wss://uat-deriv-stream.3ona.co/v1/market",
            "WS_PRIVATE_URL": "wss://uat-deriv-stream.3ona.co/v1/user"
        }
    }

    def __init__(self):
        self.env = os.getenv("CRYPTOCOM_ENV", "UAT").lower()
        urls = self.ENV_URLS.get(self.env, self.ENV_URLS["production"])
        self.WS_URL = urls["WS_URL"]
        self.WS_PRIVATE_URL = urls["WS_PRIVATE_URL"]
        self.api_key = os.getenv("CRYPTOCOM_API_KEY")
        self.api_secret = os.getenv("CRYPTOCOM_API_SECRET")
        self.ws = None
        self.ws_private = None
        self._lock = threading.Lock()
        self._private_lock = threading.Lock()
        self._connect_ws()

    def _connect_ws(self):
        if self.ws is None:
            self.ws = create_connection(self.WS_URL, timeout=10)
        if self.api_key and self.api_secret and self.ws_private is None:
            self.ws_private = create_connection(self.WS_PRIVATE_URL, timeout=10)

    def _send_ws(self, payload, private=False):
        ws = self.ws_private if private else self.ws
        lock = self._private_lock if private else self._lock
        with lock:
            ws.send(json.dumps(payload))
            try:
                resp = ws.recv()
                return json.loads(resp)
            except WebSocketTimeoutException:
                return None

    def _sign(self, params):
        t = str(int(time.time() * 1000))
        params['id'] = t
        params['nonce'] = t
        params['api_key'] = self.api_key
        param_str = json.dumps(params, separators=(',', ':'), sort_keys=True)
        sig = hmac.new(
            bytes(self.api_secret, 'utf-8'),
            msg=bytes(param_str, 'utf-8'),
            digestmod=hashlib.sha256
        ).hexdigest()
        params['sig'] = sig
        return params

    def get_price(self):
        """
        Get the latest ask price for BTC_USDC using WebSocket ticker subscription (one-shot).
        """
        payload = {
            "id": int(time.time() * 1000),
            "method": "subscribe",
            "params": {"channels": ["ticker.BTC_USDC"]}
        }
        resp = self._send_ws(payload)
        # Wait for ticker update
        while True:
            data = self.ws.recv()
            msg = json.loads(data)
            if msg.get("method") == "ticker.update":
                # 'a' is ask price
                return msg["params"]["data"][0].get("a")

    def get_order_book(self, depth=10):
        """
        Fetch the order book for BTC_USDC with the specified depth using WebSocket (one-shot).
        Returns a dict with 'bids' and 'asks'.
        """
        payload = {
            "id": int(time.time() * 1000),
            "method": "subscribe",
            "params": {"channels": [f"book.BTC_USDC.{depth}"]}
        }
        resp = self._send_ws(payload)
        # Wait for book update
        while True:
            data = self.ws.recv()
            msg = json.loads(data)
            if msg.get("method") == "book.update":
                book = msg["params"]["data"][0]
                return {
                    "bids": book.get("bids", []),
                    "asks": book.get("asks", [])
                }

    def _authenticate(self):
        """
        Authenticate the private WebSocket connection. Only needs to be done once per session.
        """
        if not self.api_key or not self.api_secret:
            raise ValueError("API key and secret must be set in environment variables.")
        payload = {
            "id": int(time.time() * 1000),
            "method": "public/auth",
            "api_key": self.api_key,
            "nonce": int(time.time() * 1000),
        }
        # For auth, sig is HMAC_SHA256(method + id + api_key + nonce)
        sig_payload = (
            payload["method"] + str(payload["id"]) + self.api_key + str(payload["nonce"])
        )
        payload["sig"] = hmac.new(
            bytes(self.api_secret, "utf-8"),
            msg=bytes(sig_payload, "utf-8"),
            digestmod=hashlib.sha256,
        ).hexdigest()
        resp = self._send_ws(payload, private=True)
        if not resp or resp.get("code") != 0:
            raise Exception(f"WebSocket authentication failed: {resp}")

    def _ensure_private_auth(self):
        if self.ws_private is None:
            self._connect_ws()
            time.sleep(1)  # recommended by docs
            self._authenticate()

    def get_balance(self, currency="USDC"):
        """
        Fetch user balance using WebSocket private API.
        """
        self._ensure_private_auth()
        payload = {
            "id": int(time.time() * 1000),
            "method": "private/user-balance",
            "params": {},
            "nonce": int(time.time() * 1000),
        }
        resp = self._send_ws(payload, private=True)
        if resp and resp.get("code") == 0:
            balances = resp.get("result", {}).get("data", [])
            if currency:
                return [b for b in balances if b.get("instrument_name") == currency]
            return balances
        return []

    def place_order(self, side, amount):
        """
        Place a market order using WebSocket private API.
        side: 'BUY' or 'SELL', amount: in BTC
        """
        self._ensure_private_auth()
        params = {
            "instrument_name": "BTC_USDC",
            "side": side,
            "type": "MARKET",
            "quantity": str(amount),
        }
        payload = {
            "id": int(time.time() * 1000),
            "method": "private/create-order",
            "params": params,
            "nonce": int(time.time() * 1000),
        }
        resp = self._send_ws(payload, private=True)
        return resp

    def buy_btc(self, amount):
        return self.place_order("BUY", amount)

    def sell_btc(self, amount):
        return self.place_order("SELL", amount)

    def get_candlesticks(self, timeframe='1m', count=100):
        """
        Fetch candlestick (OHLCV) data for BTC_USDC using WebSocket.

        Args:
            timeframe (str): Timeframe for each candle (e.g., '1m', '5m', '15m', '1h', '4h', '1d').
            count (int): Number of candles to fetch (max 1000 per API docs).

        Returns:
            pd.DataFrame: DataFrame with columns ['timestamp', 'open', 'high', 'low', 'close', 'volume']
        """
        payload = {
            "id": int(time.time() * 1000),
            "method": "public/get-candlestick",
            "params": {
                "instrument_name": "BTC_USDC",
                "timeframe": timeframe,
                "count": count
            }
        }
        resp = self._send_ws(payload)
        candles = resp.get("result", {}).get("data", []) if resp else []
        if not candles:
            return pd.DataFrame(columns=["timestamp", "open", "high", "low", "close", "volume"])
        df = pd.DataFrame(candles)
        df['timestamp'] = pd.to_datetime(df['t'], unit='ms')
        df = df.rename(columns={
            'o': 'open',
            'h': 'high',
            'l': 'low',
            'c': 'close',
            'v': 'volume'
        })
        return df[['timestamp', 'open', 'high', 'low', 'close', 'volume']].sort_values('timestamp')

    def get_instruments(self):
        """
        Fetch the list of available trading instruments from Crypto.com using WebSocket.

        Returns:
            list: List of instrument dicts.
        """
        payload = {
            "id": int(time.time() * 1000),
            "method": "public/get-instruments",
            "params": {}
        }
        resp = self._send_ws(payload)
        return resp.get("result", {}).get("data", []) if resp else []
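The authentication signature in _authenticate is an HMAC-SHA256 over the concatenation method + id + api_key + nonce, sent as payload["sig"]. A standalone sketch of the same computation; the key, id, and nonce values below are placeholders:

# Standalone sketch of the public/auth signature from _authenticate above.
# Key and nonce values are placeholders.
import hmac
import hashlib

api_key, api_secret = "KEY", "SECRET"
method, req_id = "public/auth", 1700000000000
nonce = req_id

sig_payload = method + str(req_id) + api_key + str(nonce)
sig = hmac.new(api_secret.encode(), sig_payload.encode(), hashlib.sha256).hexdigest()
print(sig)  # 64-char hex digest sent as payload["sig"]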
trader/main.py (84 lines, new file)
@@ -0,0 +1,84 @@
import time
import plotly.graph_objs as go
import plotly.io as pio
from cryptocom_trader import CryptoComTrader


def plot_candlesticks(df):
    if df.empty:
        print("No data to plot.")
        return None

    # Convert columns to float
    for col in ['open', 'high', 'low', 'close', 'volume']:
        df[col] = df[col].astype(float)

    # Plotly expects datetime for x-axis
    fig = go.Figure(data=[go.Candlestick(
        x=df['timestamp'],
        open=df['open'],
        high=df['high'],
        low=df['low'],
        close=df['close'],
        increasing_line_color='#089981',
        decreasing_line_color='#F23645'
    )])

    fig.update_layout(
        title='BTC/USDC Realtime Candlestick (1m)',
        yaxis_title='Price (USDC)',
        xaxis_title='Time',
        xaxis_rangeslider_visible=False,
        template='plotly_dark'
    )
    return fig


def main():
    trader = CryptoComTrader()
    pio.renderers.default = "browser"  # Open in browser

    # Fetch and print BTC/USDC-related instruments
    instruments = trader.get_instruments()
    btc_usdc_instruments = [
        inst for inst in instruments
        if (
            ('BTC' in inst.get('base_ccy', '') or 'BTC' in inst.get('base_currency', '')) and
            ('USDC' in inst.get('quote_ccy', '') or 'USDC' in inst.get('quote_currency', ''))
        )
    ]
    print("BTC/USDC-related instruments:")
    for inst in btc_usdc_instruments:
        print(inst)

    # Optionally, show balance (private API)
    try:
        balance = trader.get_balance("USDC")
        print("USDC Balance:", balance)
    except Exception as e:
        print("[WARN] Could not fetch balance (private API):", e)

    all_instruments = trader.get_instruments()
    for inst in all_instruments:
        print(inst)

    while True:
        try:
            df = trader.get_candlesticks(timeframe='1m', count=60)
            # fig = plot_candlesticks(df)
            # if fig:
            #     fig.show()
            if not df.empty:
                print(df[['high', 'low', 'open', 'close', 'volume']])
            else:
                print("No data to print.")
            time.sleep(10)
        except KeyboardInterrupt:
            print('Exiting...')
            break
        except Exception as e:
            print(f'Error: {e}')
            time.sleep(10)


if __name__ == '__main__':
    main()
(deleted file, 814 lines)
@@ -1,814 +0,0 @@
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
import logging
|
|
||||||
from scipy.signal import find_peaks
|
|
||||||
from matplotlib.patches import Rectangle
|
|
||||||
from scipy import stats
|
|
||||||
import concurrent.futures
|
|
||||||
from functools import partial
|
|
||||||
from functools import lru_cache
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
# Color configuration
|
|
||||||
# Plot colors
|
|
||||||
DARK_BG_COLOR = '#181C27'
|
|
||||||
LEGEND_BG_COLOR = '#333333'
|
|
||||||
TITLE_COLOR = 'white'
|
|
||||||
AXIS_LABEL_COLOR = 'white'
|
|
||||||
|
|
||||||
# Candlestick colors
|
|
||||||
CANDLE_UP_COLOR = '#089981' # Green
|
|
||||||
CANDLE_DOWN_COLOR = '#F23645' # Red
|
|
||||||
|
|
||||||
# Marker colors
|
|
||||||
MIN_COLOR = 'red'
|
|
||||||
MAX_COLOR = 'green'
|
|
||||||
|
|
||||||
# Line style colors
|
|
||||||
MIN_LINE_STYLE = 'g--' # Green dashed
|
|
||||||
MAX_LINE_STYLE = 'r--' # Red dashed
|
|
||||||
SMA7_LINE_STYLE = 'y-' # Yellow solid
|
|
||||||
SMA15_LINE_STYLE = 'm-'  # Magenta solid

# SuperTrend colors
ST_COLOR_UP = 'g-'
ST_COLOR_DOWN = 'r-'


# Cache the calculation results by function parameters
@lru_cache(maxsize=32)
def cached_supertrend_calculation(period, multiplier, data_tuple):
    # Convert tuple back to numpy arrays
    high = np.array(data_tuple[0])
    low = np.array(data_tuple[1])
    close = np.array(data_tuple[2])

    # Calculate TR and ATR using vectorized operations
    tr = np.zeros_like(close)
    tr[0] = high[0] - low[0]
    hc_range = np.abs(high[1:] - close[:-1])
    lc_range = np.abs(low[1:] - close[:-1])
    hl_range = high[1:] - low[1:]
    tr[1:] = np.maximum.reduce([hl_range, hc_range, lc_range])

    # Smooth TR into ATR with an exponential moving average
    atr = np.zeros_like(tr)
    atr[0] = tr[0]
    multiplier_ema = 2.0 / (period + 1)
    for i in range(1, len(tr)):
        atr[i] = (tr[i] * multiplier_ema) + (atr[i-1] * (1 - multiplier_ema))

    # Calculate bands
    upper_band = np.zeros_like(close)
    lower_band = np.zeros_like(close)
    for i in range(len(close)):
        hl_avg = (high[i] + low[i]) / 2
        upper_band[i] = hl_avg + (multiplier * atr[i])
        lower_band[i] = hl_avg - (multiplier * atr[i])

    final_upper = np.zeros_like(close)
    final_lower = np.zeros_like(close)
    supertrend = np.zeros_like(close)
    trend = np.zeros_like(close)
    final_upper[0] = upper_band[0]
    final_lower[0] = lower_band[0]
    if close[0] <= upper_band[0]:
        supertrend[0] = upper_band[0]
        trend[0] = -1
    else:
        supertrend[0] = lower_band[0]
        trend[0] = 1
    for i in range(1, len(close)):
        if (upper_band[i] < final_upper[i-1]) or (close[i-1] > final_upper[i-1]):
            final_upper[i] = upper_band[i]
        else:
            final_upper[i] = final_upper[i-1]
        if (lower_band[i] > final_lower[i-1]) or (close[i-1] < final_lower[i-1]):
            final_lower[i] = lower_band[i]
        else:
            final_lower[i] = final_lower[i-1]
        if supertrend[i-1] == final_upper[i-1] and close[i] <= final_upper[i]:
            supertrend[i] = final_upper[i]
            trend[i] = -1
        elif supertrend[i-1] == final_upper[i-1] and close[i] > final_upper[i]:
            supertrend[i] = final_lower[i]
            trend[i] = 1
        elif supertrend[i-1] == final_lower[i-1] and close[i] >= final_lower[i]:
            supertrend[i] = final_lower[i]
            trend[i] = 1
        elif supertrend[i-1] == final_lower[i-1] and close[i] < final_lower[i]:
            supertrend[i] = final_upper[i]
            trend[i] = -1
    return {
        'supertrend': supertrend,
        'trend': trend,
        'upper_band': final_upper,
        'lower_band': final_lower
    }


def calculate_supertrend_external(data, period, multiplier):
    # Convert DataFrame columns to hashable tuples so lru_cache can key on them
    high_tuple = tuple(data['high'])
    low_tuple = tuple(data['low'])
    close_tuple = tuple(data['close'])

    # Call the cached function
    return cached_supertrend_calculation(period, multiplier, (high_tuple, low_tuple, close_tuple))
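
# Minimal usage sketch (hypothetical toy data; any OHLC DataFrame with
# 'high'/'low'/'close' columns works). Repeated calls with the same data and
# parameters are served from the lru_cache instead of being recomputed:
#
#   demo_df = pd.DataFrame({'high': [2.0, 3.0, 3.5],
#                           'low': [1.0, 2.0, 2.5],
#                           'close': [1.5, 2.5, 3.0]})
#   st = calculate_supertrend_external(demo_df, period=10, multiplier=3.0)
#   print(st['trend'])  # array of +1 / -1 trend flags
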
class TrendDetectorSimple:
    def __init__(self, data, verbose=False, display=False):
        """
        Initialize the TrendDetectorSimple class.

        Parameters:
        - data: pandas DataFrame containing price data
        - verbose: boolean, whether to display detailed logging information
        - display: boolean, whether to enable display/plotting features
        """
        self.data = data
        self.verbose = verbose
        self.display = display

        # Only define display-related variables if display is True
        if self.display:
            # Plot style configuration
            self.plot_style = 'dark_background'
            self.bg_color = DARK_BG_COLOR
            self.plot_size = (12, 8)

            # Candlestick configuration
            self.candle_width = 0.6
            self.candle_up_color = CANDLE_UP_COLOR
            self.candle_down_color = CANDLE_DOWN_COLOR
            self.candle_alpha = 0.8
            self.wick_width = 1

            # Marker configuration
            self.min_marker = '^'
            self.min_color = MIN_COLOR
            self.min_size = 100
            self.max_marker = 'v'
            self.max_color = MAX_COLOR
            self.max_size = 100
            self.marker_zorder = 100

            # Line configuration
            self.line_width = 1
            self.min_line_style = MIN_LINE_STYLE
            self.max_line_style = MAX_LINE_STYLE
            self.sma7_line_style = SMA7_LINE_STYLE
            self.sma15_line_style = SMA15_LINE_STYLE

            # Text configuration
            self.title_size = 14
            self.title_color = TITLE_COLOR
            self.axis_label_size = 12
            self.axis_label_color = AXIS_LABEL_COLOR

            # Legend configuration
            self.legend_loc = 'best'
            self.legend_bg_color = LEGEND_BG_COLOR

        # Configure logging
        logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
                            format='%(asctime)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger('TrendDetectorSimple')

        # Convert data to pandas DataFrame if it's not already
        if not isinstance(self.data, pd.DataFrame):
            if isinstance(self.data, list):
                self.data = pd.DataFrame({'close': self.data})
            else:
                raise ValueError("Data must be a pandas DataFrame or a list")
    def calculate_tr(self):
        """
        Calculate True Range (TR) for the price data.

        True Range is the greatest of:
        1. Current high - current low
        2. |Current high - previous close|
        3. |Current low - previous close|

        Returns:
        - Numpy array of TR values
        """
        df = self.data.copy()
        high = df['high'].values
        low = df['low'].values
        close = df['close'].values

        tr = np.zeros_like(close)
        tr[0] = high[0] - low[0]  # First TR is just the first day's range

        for i in range(1, len(close)):
            # Current high - current low
            hl_range = high[i] - low[i]
            # |Current high - previous close|
            hc_range = abs(high[i] - close[i-1])
            # |Current low - previous close|
            lc_range = abs(low[i] - close[i-1])

            # TR is the maximum of these three values
            tr[i] = max(hl_range, hc_range, lc_range)

        return tr
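    # Worked example (illustrative values only): for two bars with
    # high = [10.0, 11.0], low = [9.0, 9.5], close = [9.5, 10.5]:
    #   tr[0] = 10.0 - 9.0 = 1.0
    #   tr[1] = max(11.0 - 9.5, |11.0 - 9.5|, |9.5 - 9.5|) = 1.5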
    def calculate_atr(self, period=14):
        """
        Calculate Average True Range (ATR) for the price data.

        ATR is the exponential moving average of the True Range over a specified period.

        Parameters:
        - period: int, the period for the ATR calculation (default: 14)

        Returns:
        - Numpy array of ATR values
        """
        tr = self.calculate_tr()
        atr = np.zeros_like(tr)

        # First ATR value is just the first TR
        atr[0] = tr[0]

        # Calculate exponential moving average (EMA) of TR
        multiplier = 2.0 / (period + 1)

        for i in range(1, len(tr)):
            atr[i] = (tr[i] * multiplier) + (atr[i-1] * (1 - multiplier))

        return atr
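    # Note: the recurrence above is the standard EMA with
    # alpha = 2 / (period + 1); for the default period=14, alpha = 2/15 ~ 0.1333,
    # so atr[i] ~ 0.1333 * tr[i] + 0.8667 * atr[i-1].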
    def detect_trends(self):
        """
        Run trend analysis on the price data.

        The earlier local minima/maxima detection (scipy.signal.find_peaks),
        linear regression, and SMA steps are currently commented out below;
        only the SuperTrend indicators are computed.

        Returns:
        - Dictionary containing analysis results (currently the SuperTrend
          indicators under the 'supertrend' key)
        """
        df = self.data
        # close_prices = df['close'].values

        # max_peaks, _ = find_peaks(close_prices)
        # min_peaks, _ = find_peaks(-close_prices)

        # df['is_min'] = False
        # df['is_max'] = False

        # for peak in max_peaks:
        #     df.at[peak, 'is_max'] = True
        # for peak in min_peaks:
        #     df.at[peak, 'is_min'] = True

        # result = df[['timestamp', 'close', 'is_min', 'is_max']].copy()

        # Perform linear regression on min_peaks and max_peaks
        # min_prices = df['close'].iloc[min_peaks].values
        # max_prices = df['close'].iloc[max_peaks].values

        # Linear regression for min peaks if we have at least 2 points
        # min_slope, min_intercept, min_r_value, _, _ = stats.linregress(min_peaks, min_prices)
        # Linear regression for max peaks if we have at least 2 points
        # max_slope, max_intercept, max_r_value, _, _ = stats.linregress(max_peaks, max_prices)

        # Calculate Simple Moving Averages (SMA) for 7 and 15 periods
        # sma_7 = pd.Series(close_prices).rolling(window=7, min_periods=1).mean().values
        # sma_15 = pd.Series(close_prices).rolling(window=15, min_periods=1).mean().values

        analysis_results = {}
        # analysis_results['linear_regression'] = {
        #     'min': {
        #         'slope': min_slope,
        #         'intercept': min_intercept,
        #         'r_squared': min_r_value ** 2
        #     },
        #     'max': {
        #         'slope': max_slope,
        #         'intercept': max_intercept,
        #         'r_squared': max_r_value ** 2
        #     }
        # }
        # analysis_results['sma'] = {
        #     '7': sma_7,
        #     '15': sma_15
        # }

        # Calculate SuperTrend indicators
        supertrend_results_list = self._calculate_supertrend_indicators()
        analysis_results['supertrend'] = supertrend_results_list

        return analysis_results

    def _calculate_supertrend_indicators(self):
        """
        Calculate SuperTrend indicators with three different parameter sets.

        Returns:
        - list, the SuperTrend results
        """
        supertrend_params = [
            {"period": 12, "multiplier": 3.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN},
            {"period": 10, "multiplier": 1.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN},
            {"period": 11, "multiplier": 2.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN}
        ]
        data = self.data.copy()

        # For just 3 calculations, direct calculation is faster than a process pool
        results = []
        for p in supertrend_params:
            result = calculate_supertrend_external(data, p["period"], p["multiplier"])
            results.append(result)

        supertrend_results_list = []
        for params, result in zip(supertrend_params, results):
            supertrend_results_list.append({
                "results": result,
                "params": params
            })
        return supertrend_results_list

    def plot_trends(self, trend_data, analysis_results, view="both"):
        """
        Plot the price data with detected trends using a candlestick chart.
        Also plots SuperTrend indicators with three different parameter sets.

        Parameters:
        - trend_data: DataFrame, the output from detect_trends()
        - analysis_results: Dictionary containing analysis results from detect_trends()
        - view: str, one of 'both', 'trend', 'supertrend'; determines which plot(s) to display

        Returns:
        - None (displays the plot)
        """
        if not self.display:
            return  # Do nothing if display is False

        plt.style.use(self.plot_style)

        if view == "both":
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(self.plot_size[0]*2, self.plot_size[1]))
        else:
            fig, ax = plt.subplots(figsize=self.plot_size)
            ax1 = ax2 = None
            if view == "trend":
                ax1 = ax
            elif view == "supertrend":
                ax2 = ax

        fig.patch.set_facecolor(self.bg_color)
        if ax1: ax1.set_facecolor(self.bg_color)
        if ax2: ax2.set_facecolor(self.bg_color)

        df = self.data.copy()

        if ax1:
            self._plot_trend_analysis(ax1, df, trend_data, analysis_results)

        if ax2:
            self._plot_supertrend_analysis(ax2, df, analysis_results['supertrend'])

        plt.tight_layout()
        plt.show()

    def _plot_candlesticks(self, ax, df):
        """
        Plot candlesticks on the given axis.

        Parameters:
        - ax: matplotlib.axes.Axes, the axis to plot on
        - df: pandas.DataFrame, the data to plot
        """
        from matplotlib.patches import Rectangle

        for i in range(len(df)):
            # Get OHLC values for this candle
            open_val = df['open'].iloc[i]
            close_val = df['close'].iloc[i]
            high_val = df['high'].iloc[i]
            low_val = df['low'].iloc[i]

            # Determine candle color
            color = self.candle_up_color if close_val >= open_val else self.candle_down_color

            # Plot candle body
            body_height = abs(close_val - open_val)
            bottom = min(open_val, close_val)
            rect = Rectangle((i - self.candle_width/2, bottom), self.candle_width, body_height,
                             color=color, alpha=self.candle_alpha)
            ax.add_patch(rect)

            # Plot candle wicks
            ax.plot([i, i], [low_val, high_val], color=color, linewidth=self.wick_width)

    def _plot_trend_analysis(self, ax, df, trend_data, analysis_results):
        """
        Plot trend analysis on the given axis.

        Parameters:
        - ax: matplotlib.axes.Axes, the axis to plot on
        - df: pandas.DataFrame, the data to plot
        - trend_data: pandas.DataFrame, the trend data
        - analysis_results: dict, the analysis results
        """
        # Draw candlesticks
        self._plot_candlesticks(ax, df)

        # Plot minima and maxima points
        self._plot_min_max_points(ax, df, trend_data)

        # Plot trend lines and moving averages
        if analysis_results:
            self._plot_trend_lines(ax, df, analysis_results)

        # Configure the subplot
        self._configure_subplot(ax, 'Price Chart with Trend Analysis', len(df))

    def _plot_min_max_points(self, ax, df, trend_data):
        """
        Plot minimum and maximum points on the given axis.

        Parameters:
        - ax: matplotlib.axes.Axes, the axis to plot on
        - df: pandas.DataFrame, the data to plot
        - trend_data: pandas.DataFrame, the trend data
        """
        min_indices = trend_data.index[trend_data['is_min'] == True].tolist()
        if min_indices:
            min_y = [df['close'].iloc[i] for i in min_indices]
            ax.scatter(min_indices, min_y, color=self.min_color, s=self.min_size,
                       marker=self.min_marker, label='Local Minima', zorder=self.marker_zorder)

        max_indices = trend_data.index[trend_data['is_max'] == True].tolist()
        if max_indices:
            max_y = [df['close'].iloc[i] for i in max_indices]
            ax.scatter(max_indices, max_y, color=self.max_color, s=self.max_size,
                       marker=self.max_marker, label='Local Maxima', zorder=self.marker_zorder)

    def _plot_trend_lines(self, ax, df, analysis_results):
        """
        Plot trend lines on the given axis.

        Parameters:
        - ax: matplotlib.axes.Axes, the axis to plot on
        - df: pandas.DataFrame, the data to plot
        - analysis_results: dict, the analysis results
        """
        x_vals = np.arange(len(df))

        # Minima regression line (support)
        min_slope = analysis_results['linear_regression']['min']['slope']
        min_intercept = analysis_results['linear_regression']['min']['intercept']
        min_line = min_slope * x_vals + min_intercept
        ax.plot(x_vals, min_line, self.min_line_style, linewidth=self.line_width,
                label='Minima Regression')

        # Maxima regression line (resistance)
        max_slope = analysis_results['linear_regression']['max']['slope']
        max_intercept = analysis_results['linear_regression']['max']['intercept']
        max_line = max_slope * x_vals + max_intercept
        ax.plot(x_vals, max_line, self.max_line_style, linewidth=self.line_width,
                label='Maxima Regression')

        # SMA-7 line
        sma_7 = analysis_results['sma']['7']
        ax.plot(x_vals, sma_7, self.sma7_line_style, linewidth=self.line_width,
                label='SMA-7')

        # SMA-15 line
        sma_15 = analysis_results['sma']['15']
        valid_idx_15 = ~np.isnan(sma_15)
        ax.plot(x_vals[valid_idx_15], sma_15[valid_idx_15], self.sma15_line_style,
                linewidth=self.line_width, label='SMA-15')

    def _configure_subplot(self, ax, title, data_length):
        """
        Configure the subplot with title, labels, limits, and legend.

        Parameters:
        - ax: matplotlib.axes.Axes, the axis to configure
        - title: str, the title of the subplot
        - data_length: int, the length of the data
        """
        # Set title and labels
        ax.set_title(title, fontsize=self.title_size, color=self.title_color)
        ax.set_xlabel('Date', fontsize=self.axis_label_size, color=self.axis_label_color)
        ax.set_ylabel('Price', fontsize=self.axis_label_size, color=self.axis_label_color)

        # Set appropriate x-axis limits
        ax.set_xlim(-0.5, data_length - 0.5)

        # Add a legend
        ax.legend(loc=self.legend_loc, facecolor=self.legend_bg_color)

    def _plot_supertrend_analysis(self, ax, df, supertrend_results_list=None):
        """
        Plot SuperTrend analysis on the given axis.

        Parameters:
        - ax: matplotlib.axes.Axes, the axis to plot on
        - df: pandas.DataFrame, the data to plot
        - supertrend_results_list: list, the SuperTrend results (optional)
        """
        self._plot_candlesticks(ax, df)
        self._plot_supertrend_lines(ax, df, supertrend_results_list, style='Both')
        self._configure_subplot(ax, 'Multiple SuperTrend Indicators', len(df))

    def _plot_supertrend_lines(self, ax, df, supertrend_results_list, style="Horizontal"):
        """
        Plot SuperTrend lines on the given axis.

        Parameters:
        - ax: matplotlib.axes.Axes, the axis to plot on
        - df: pandas.DataFrame, the data to plot
        - supertrend_results_list: list, the SuperTrend results
        - style: str, one of 'Horizontal', 'Curves', 'Both'; selects the agreement
          band strip, the SuperTrend curves, or both
        """
        from matplotlib.patches import Rectangle  # needed for the horizontal band strip

        x_vals = np.arange(len(df))

        if style == 'Horizontal' or style == 'Both':
            if len(supertrend_results_list) != 3:
                raise ValueError("Expected exactly 3 SuperTrend results for meta calculation")

            trends = [st["results"]["trend"] for st in supertrend_results_list]

            band_height = 0.02 * (df["high"].max() - df["low"].min())
            y_base = df["low"].min() - band_height * 1.5

            prev_color = None
            for i in range(1, len(x_vals)):
                t_vals = [t[i] for t in trends]
                up_count = t_vals.count(1)
                down_count = t_vals.count(-1)

                if down_count == 3:
                    color = "red"
                elif down_count == 2 and up_count == 1:
                    color = "orange"
                elif down_count == 1 and up_count == 2:
                    color = "yellow"
                elif up_count == 3:
                    color = "green"
                else:
                    continue  # skip if unknown or inconsistent values

                ax.add_patch(Rectangle(
                    (x_vals[i-1], y_base),
                    1,
                    band_height,
                    color=color,
                    linewidth=0,
                    alpha=0.6
                ))
                # Draw a vertical line at the change of color
                if prev_color and prev_color != color:
                    ax.axvline(x_vals[i-1], color="grey", alpha=0.3, linewidth=1)
                prev_color = color

            ax.set_ylim(bottom=y_base - band_height * 0.5)
        if style == 'Curves' or style == 'Both':
            for st in supertrend_results_list:
                params = st["params"]
                results = st["results"]
                supertrend = results["supertrend"]
                trend = results["trend"]

                # Plot SuperTrend line with color based on trend
                for i in range(1, len(x_vals)):
                    if trend[i] == 1:  # Uptrend
                        ax.plot(x_vals[i-1:i+1], supertrend[i-1:i+1], params["color_up"], linewidth=self.line_width)
                    else:  # Downtrend
                        ax.plot(x_vals[i-1:i+1], supertrend[i-1:i+1], params["color_down"], linewidth=self.line_width)
            self._plot_metasupertrend_lines(ax, df, supertrend_results_list)
        self._add_supertrend_legend(ax, supertrend_results_list)

    def _plot_metasupertrend_lines(self, ax, df, supertrend_results_list):
        """
        Plot a Meta SuperTrend line where all individual SuperTrends agree on trend.

        Parameters:
        - ax: matplotlib.axes.Axes, the axis to plot on
        - df: pandas.DataFrame, the data to plot
        - supertrend_results_list: list, each item contains SuperTrend 'results' and 'params'
        """
        x_vals = np.arange(len(df))

        if len(supertrend_results_list) != 3:
            raise ValueError("Expected exactly 3 SuperTrend results for meta calculation")

        trends = [st["results"]["trend"] for st in supertrend_results_list]
        supertrends = [st["results"]["supertrend"] for st in supertrend_results_list]
        params = supertrend_results_list[0]["params"]  # Use first config for styling

        for i in range(1, len(x_vals)):
            t1, t2, t3 = trends[0][i], trends[1][i], trends[2][i]
            if t1 == t2 == t3:
                meta_trend = t1
                # Average the 3 supertrend values
                st_avg_prev = np.mean([s[i-1] for s in supertrends])
                st_avg_curr = np.mean([s[i] for s in supertrends])
                color = params["color_up"] if meta_trend == 1 else params["color_down"]
                ax.plot(x_vals[i-1:i+1], [st_avg_prev, st_avg_curr], color, linewidth=self.line_width)
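    # Agreement example (illustrative): if the three trends at bar i are
    # (1, 1, 1) the meta trend is 1 and a segment is drawn in the up color;
    # (-1, -1, -1) gives -1; any mixed vote such as (1, -1, 1) draws nothing.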
    def _add_supertrend_legend(self, ax, supertrend_results_list):
        """
        Add SuperTrend legend entries to the given axis.

        Parameters:
        - ax: matplotlib.axes.Axes, the axis to add legend entries to
        - supertrend_results_list: list, the SuperTrend results
        """
        for st in supertrend_results_list:
            params = st["params"]
            period = params["period"]
            multiplier = params["multiplier"]
            color_up = params["color_up"]
            color_down = params["color_down"]

            ax.plot([], [], color_up, linewidth=self.line_width,
                    label=f'ST (P:{period}, M:{multiplier}) Up')
            ax.plot([], [], color_down, linewidth=self.line_width,
                    label=f'ST (P:{period}, M:{multiplier}) Down')

    def backtest_meta_supertrend(self, min1_df, initial_usd=10000, stop_loss_pct=0.05, transaction_cost=0.001, debug=False):
        """
        Backtest a simple strategy using the meta supertrend (all three supertrends agree).
        Buys when the meta supertrend turns positive, sells when it turns negative, and
        applies a percentage stop loss checked against 1-minute candles.

        Parameters:
        - min1_df: pandas DataFrame, 1-minute timeframe data with a DatetimeIndex,
          required for intrabar stop-loss checking
        - initial_usd: float, starting USD amount
        - stop_loss_pct: float, stop loss as a fraction (e.g. 0.05 for 5%)
        - transaction_cost: float, transaction cost as a fraction (e.g. 0.001 for 0.1%)
        - debug: bool, whether to print debug info

        Returns:
        - dict with backtest statistics, the raw trade_log, and a per-trade summary
        """
        df = self.data.copy().reset_index(drop=True)
        df['timestamp'] = pd.to_datetime(df['timestamp'])

        # Get meta supertrend (all three agree)
        supertrend_results_list = self._calculate_supertrend_indicators()
        trends = [st['results']['trend'] for st in supertrend_results_list]
        trends_arr = np.stack(trends, axis=1)
        meta_trend = np.where((trends_arr[:,0] == trends_arr[:,1]) & (trends_arr[:,1] == trends_arr[:,2]),
                              trends_arr[:,0], 0)

        position = 0  # 0 = no position, 1 = long
        entry_price = 0
        usd = initial_usd
        coin = 0
        trade_log = []
        max_balance = initial_usd
        drawdowns = []
        entry_time = None
        current_trade_min1_start_idx = None

        min1_df['timestamp'] = pd.to_datetime(min1_df.index)

        for i in range(1, len(df)):
            if i % 100 == 0 and debug:
                self.logger.debug(f"Progress: {i}/{len(df)} rows processed.")

            price_open = df['open'].iloc[i]
            price_high = df['high'].iloc[i]
            price_low = df['low'].iloc[i]
            price_close = df['close'].iloc[i]
            date = df['timestamp'].iloc[i]
            mt = meta_trend[i]

            # Check stop loss if in position
            if position == 1:
                stop_price = entry_price * (1 - stop_loss_pct)

                if current_trade_min1_start_idx is None:
                    # First check after entry, find the entry point in 1-min data
                    current_trade_min1_start_idx = min1_df.index[min1_df.index >= entry_time][0]

                # Get the end index for current check
                current_min1_end_idx = min1_df.index[min1_df.index <= date][-1]

                # Check all 1-minute candles in between for stop loss
                min1_slice = min1_df.loc[current_trade_min1_start_idx:current_min1_end_idx]
                if (min1_slice['low'] <= stop_price).any():
                    # Stop loss triggered, find the exact candle
                    stop_candle = min1_slice[min1_slice['low'] <= stop_price].iloc[0]
                    # More realistic fill: if open < stop, fill at open, else at stop
                    if stop_candle['open'] < stop_price:
                        sell_price = stop_candle['open']
                    else:
                        sell_price = stop_price
                    if debug:
                        print(f"STOP LOSS triggered: entry={entry_price}, stop={stop_price}, sell_price={sell_price}, entry_time={entry_time}, stop_time={stop_candle.name}")
                    usd = coin * sell_price * (1 - transaction_cost)  # Apply transaction cost
                    trade_log.append({
                        'type': 'STOP',
                        'entry': entry_price,
                        'exit': sell_price,
                        'entry_time': entry_time,
                        'exit_time': stop_candle.name  # Use index name instead of timestamp column
                    })
                    coin = 0
                    position = 0
                    entry_price = 0
                    current_trade_min1_start_idx = None
                    continue

                # Update the start index for next check
                current_trade_min1_start_idx = current_min1_end_idx

            # Entry logic
            if position == 0 and mt == 1:
                # Buy at open, apply transaction cost
                coin = (usd * (1 - transaction_cost)) / price_open
                entry_price = price_open
                entry_time = date
                usd = 0
                position = 1
                current_trade_min1_start_idx = None  # Will be set on first stop loss check

            # Exit logic
            elif position == 1 and mt == -1:
                # Sell at open, apply transaction cost
                usd = coin * price_open * (1 - transaction_cost)
                trade_log.append({
                    'type': 'SELL',
                    'entry': entry_price,
                    'exit': price_open,
                    'entry_time': entry_time,
                    'exit_time': date
                })
                coin = 0
                position = 0
                entry_price = 0
                current_trade_min1_start_idx = None

            # Track drawdown
            balance = usd if position == 0 else coin * price_close
            if balance > max_balance:
                max_balance = balance
            drawdown = (max_balance - balance) / max_balance
            drawdowns.append(drawdown)

        # If still in position at end, sell at last close
        if position == 1:
            usd = coin * df['close'].iloc[-1] * (1 - transaction_cost)  # Apply transaction cost
            trade_log.append({
                'type': 'EOD',
                'entry': entry_price,
                'exit': df['close'].iloc[-1],
                'entry_time': entry_time,
                'exit_time': df['timestamp'].iloc[-1]
            })
            coin = 0
            position = 0
            entry_price = 0

        # Calculate statistics
        final_balance = usd
        n_trades = len(trade_log)
        wins = [1 for t in trade_log if t['exit'] > t['entry']]
        win_rate = len(wins) / n_trades if n_trades > 0 else 0
        max_drawdown = max(drawdowns) if drawdowns else 0
        avg_trade = np.mean([t['exit']/t['entry']-1 for t in trade_log]) if trade_log else 0

        trades = []
        for trade in trade_log:
            profit_pct = (trade['exit'] - trade['entry']) / trade['entry']
            trades.append({
                'entry_time': trade['entry_time'],
                'exit_time': trade['exit_time'],
                'entry': trade['entry'],
                'exit': trade['exit'],
                'profit_pct': profit_pct,
                'type': trade.get('type', 'SELL')
            })

        results = {
            "initial_usd": initial_usd,
            "final_usd": final_balance,
            "n_trades": n_trades,
            "win_rate": win_rate,
            "max_drawdown": max_drawdown,
            "avg_trade": avg_trade,
            "trade_log": trade_log,
            "trades": trades,
        }
        if n_trades > 0:
            results["first_trade"] = {
                "entry_time": trade_log[0]['entry_time'],
                "entry": trade_log[0]['entry']
            }
            results["last_trade"] = {
                "exit_time": trade_log[-1]['exit_time'],
                "exit": trade_log[-1]['exit']
            }
        return results
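
# Minimal end-to-end sketch (hedged; `ohlc_df` and `min1_df` are hypothetical
# inputs: `ohlc_df` needs 'timestamp'/'open'/'high'/'low'/'close' columns and
# `min1_df` a DatetimeIndex with at least 'open' and 'low' columns):
#
#   detector = TrendDetectorSimple(ohlc_df, verbose=True)
#   stats = detector.backtest_meta_supertrend(min1_df, initial_usd=10000,
#                                             stop_loss_pct=0.05)
#   print(stats['final_usd'], stats['win_rate'], stats['max_drawdown'])
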
39
xgboost/custom_xgboost.py
Normal file
@@ -0,0 +1,39 @@

import xgboost as xgb
import numpy as np


class CustomXGBoostGPU:
    def __init__(self, X_train, X_test, y_train, y_test):
        self.X_train = X_train.astype(np.float32)
        self.X_test = X_test.astype(np.float32)
        self.y_train = y_train.astype(np.float32)
        self.y_test = y_test.astype(np.float32)
        self.model = None
        self.params = None  # Will be set during training

    def train(self, **xgb_params):
        params = {
            'tree_method': 'hist',
            'device': 'cuda',
            'objective': 'reg:squarederror',
            'eval_metric': 'rmse',
            'verbosity': 1,
        }
        params.update(xgb_params)
        self.params = params  # Store params for later access
        dtrain = xgb.DMatrix(self.X_train, label=self.y_train)
        dtest = xgb.DMatrix(self.X_test, label=self.y_test)
        evals = [(dtrain, 'train'), (dtest, 'eval')]
        self.model = xgb.train(params, dtrain, num_boost_round=100, evals=evals, early_stopping_rounds=10)
        return self.model

    def predict(self, X):
        if self.model is None:
            raise ValueError('Model not trained yet.')
        dmatrix = xgb.DMatrix(X.astype(np.float32))
        return self.model.predict(dmatrix)

    def save_model(self, file_path):
        """Save the trained XGBoost model to the specified file path."""
        if self.model is None:
            raise ValueError('Model not trained yet.')
        self.model.save_model(file_path)
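
# Usage sketch for CustomXGBoostGPU (synthetic data; assumes a CUDA-capable
# GPU and an xgboost build that accepts the 'device' parameter, i.e. 2.x;
# on CPU-only machines pass device='cpu' to train()):
#
#   import numpy as np
#   from custom_xgboost import CustomXGBoostGPU
#   rng = np.random.default_rng(0)
#   X = rng.random((1000, 8)); y = rng.random(1000)
#   reg = CustomXGBoostGPU(X[:800], X[800:], y[:800], y[800:])
#   reg.train(max_depth=6, eta=0.1)
#   preds = reg.predict(X[800:])
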
806
xgboost/main.py
Normal file
@@ -0,0 +1,806 @@

import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import pandas as pd
import numpy as np
from custom_xgboost import CustomXGBoostGPU
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from plot_results import plot_prediction_error_distribution, plot_direction_transition_heatmap
from cycles.supertrend import Supertrends
import time
from numba import njit
import itertools
import csv
import pandas_ta as ta


def run_indicator(func, *args):
    return func(*args)


def run_indicator_job(job):
    import time
    func, *args = job
    indicator_name = func.__name__
    start = time.time()
    result = func(*args)
    elapsed = time.time() - start
    print(f'Indicator {indicator_name} computed in {elapsed:.4f} seconds')
    return result


def calc_rsi(close):
    from ta.momentum import RSIIndicator
    return ('rsi', RSIIndicator(close, window=14).rsi())


def calc_macd(close):
    from ta.trend import MACD
    return ('macd', MACD(close).macd())


def calc_bollinger(close):
    from ta.volatility import BollingerBands
    bb = BollingerBands(close=close, window=20, window_dev=2)
    return [
        ('bb_bbm', bb.bollinger_mavg()),
        ('bb_bbh', bb.bollinger_hband()),
        ('bb_bbl', bb.bollinger_lband()),
        ('bb_bb_width', bb.bollinger_hband() - bb.bollinger_lband())
    ]


def calc_stochastic(high, low, close):
    from ta.momentum import StochasticOscillator
    stoch = StochasticOscillator(high=high, low=low, close=close, window=14, smooth_window=3)
    return [
        ('stoch_k', stoch.stoch()),
        ('stoch_d', stoch.stoch_signal())
    ]


def calc_atr(high, low, close):
    from ta.volatility import AverageTrueRange
    atr = AverageTrueRange(high=high, low=low, close=close, window=14)
    return ('atr', atr.average_true_range())


def calc_cci(high, low, close):
    from ta.trend import CCIIndicator
    cci = CCIIndicator(high=high, low=low, close=close, window=20)
    return ('cci', cci.cci())


def calc_williamsr(high, low, close):
    from ta.momentum import WilliamsRIndicator
    willr = WilliamsRIndicator(high=high, low=low, close=close, lbp=14)
    return ('williams_r', willr.williams_r())


def calc_ema(close):
    from ta.trend import EMAIndicator
    ema = EMAIndicator(close=close, window=14)
    return ('ema_14', ema.ema_indicator())


def calc_obv(close, volume):
    from ta.volume import OnBalanceVolumeIndicator
    obv = OnBalanceVolumeIndicator(close=close, volume=volume)
    return ('obv', obv.on_balance_volume())


def calc_cmf(high, low, close, volume):
    from ta.volume import ChaikinMoneyFlowIndicator
    cmf = ChaikinMoneyFlowIndicator(high=high, low=low, close=close, volume=volume, window=20)
    return ('cmf', cmf.chaikin_money_flow())


def calc_sma(close):
    from ta.trend import SMAIndicator
    return [
        ('sma_50', SMAIndicator(close, window=50).sma_indicator()),
        ('sma_200', SMAIndicator(close, window=200).sma_indicator())
    ]


def calc_roc(close):
    from ta.momentum import ROCIndicator
    return ('roc_10', ROCIndicator(close, window=10).roc())


def calc_momentum(close):
    return ('momentum_10', close - close.shift(10))


def calc_psar(high, low, close):
    # Use the Numba-accelerated fast_psar function for speed
    psar_values = fast_psar(np.array(high), np.array(low), np.array(close))
    return [('psar', pd.Series(psar_values, index=close.index))]


def calc_donchian(high, low, close):
    from ta.volatility import DonchianChannel
    donchian = DonchianChannel(high, low, close, window=20)
    return [
        ('donchian_hband', donchian.donchian_channel_hband()),
        ('donchian_lband', donchian.donchian_channel_lband()),
        ('donchian_mband', donchian.donchian_channel_mband())
    ]


def calc_keltner(high, low, close):
    from ta.volatility import KeltnerChannel
    keltner = KeltnerChannel(high, low, close, window=20)
    return [
        ('keltner_hband', keltner.keltner_channel_hband()),
        ('keltner_lband', keltner.keltner_channel_lband()),
        ('keltner_mband', keltner.keltner_channel_mband())
    ]


def calc_dpo(close):
    from ta.trend import DPOIndicator
    return ('dpo_20', DPOIndicator(close, window=20).dpo())


def calc_ultimate(high, low, close):
    from ta.momentum import UltimateOscillator
    return ('ultimate_osc', UltimateOscillator(high, low, close).ultimate_oscillator())


def calc_ichimoku(high, low):
    from ta.trend import IchimokuIndicator
    ichimoku = IchimokuIndicator(high, low, window1=9, window2=26, window3=52)
    return [
        ('ichimoku_a', ichimoku.ichimoku_a()),
        ('ichimoku_b', ichimoku.ichimoku_b()),
        ('ichimoku_base_line', ichimoku.ichimoku_base_line()),
        ('ichimoku_conversion_line', ichimoku.ichimoku_conversion_line())
    ]


def calc_elder_ray(close, low, high):
    from ta.trend import EMAIndicator
    ema = EMAIndicator(close, window=13).ema_indicator()
    return [
        ('elder_ray_bull', ema - low),
        ('elder_ray_bear', ema - high)
    ]


def calc_daily_return(close):
    from ta.others import DailyReturnIndicator
    return ('daily_return', DailyReturnIndicator(close).daily_return())


@njit
def fast_psar(high, low, close, af=0.02, max_af=0.2):
    length = len(close)
    psar = np.zeros(length)
    bull = True
    af_step = af
    ep = low[0]
    psar[0] = low[0]
    for i in range(1, length):
        prev_psar = psar[i-1]
        if bull:
            psar[i] = prev_psar + af_step * (ep - prev_psar)
            if low[i] < psar[i]:
                bull = False
                psar[i] = ep
                af_step = af
                ep = low[i]
            else:
                if high[i] > ep:
                    ep = high[i]
                    af_step = min(af_step + af, max_af)
        else:
            psar[i] = prev_psar + af_step * (ep - prev_psar)
            if high[i] > psar[i]:
                bull = True
                psar[i] = ep
                af_step = af
                ep = high[i]
            else:
                if low[i] < ep:
                    ep = low[i]
                    af_step = min(af_step + af, max_af)
    return psar
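
# Note: the first fast_psar call triggers numba JIT compilation, so timing
# runs should discard it. Hypothetical warm-up sketch:
#
#   h = np.random.rand(100) + 1.0; l = h - 0.5; c = h - 0.25
#   _ = fast_psar(h, l, c)     # warm-up (compiles)
#   psar = fast_psar(h, l, c)  # runs the compiled kernel
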
def compute_lag(df, col, lag):
    return df[col].shift(lag)


def compute_rolling(df, col, stat, window):
    if stat == 'mean':
        return df[col].rolling(window).mean()
    elif stat == 'std':
        return df[col].rolling(window).std()
    elif stat == 'min':
        return df[col].rolling(window).min()
    elif stat == 'max':
        return df[col].rolling(window).max()
    else:
        raise ValueError(f'Unsupported rolling stat: {stat}')


def compute_log_return(df, horizon):
    return np.log(df['Close'] / df['Close'].shift(horizon))


def compute_volatility(df, window):
    return df['log_return'].rolling(window).std()


def run_feature_job(job, df):
    feature_name, func, *args = job
    print(f'Computing feature: {feature_name}')
    result = func(df, *args)
    return feature_name, result


def calc_adx(high, low, close):
    from ta.trend import ADXIndicator
    adx = ADXIndicator(high=high, low=low, close=close, window=14)
    return [
        ('adx', adx.adx()),
        ('adx_pos', adx.adx_pos()),
        ('adx_neg', adx.adx_neg())
    ]


def calc_trix(close):
    from ta.trend import TRIXIndicator
    trix = TRIXIndicator(close=close, window=15)
    return ('trix', trix.trix())


def calc_vortex(high, low, close):
    from ta.trend import VortexIndicator
    vortex = VortexIndicator(high=high, low=low, close=close, window=14)
    return [
        ('vortex_pos', vortex.vortex_indicator_pos()),
        ('vortex_neg', vortex.vortex_indicator_neg())
    ]


def calc_kama(close):
    import pandas_ta as ta
    kama = ta.kama(close, length=10)
    return ('kama', kama)


def calc_force_index(close, volume):
    from ta.volume import ForceIndexIndicator
    fi = ForceIndexIndicator(close=close, volume=volume, window=13)
    return ('force_index', fi.force_index())


def calc_eom(high, low, volume):
    from ta.volume import EaseOfMovementIndicator
    eom = EaseOfMovementIndicator(high=high, low=low, volume=volume, window=14)
    return ('eom', eom.ease_of_movement())


def calc_mfi(high, low, close, volume):
    from ta.volume import MFIIndicator
    mfi = MFIIndicator(high=high, low=low, close=close, volume=volume, window=14)
    return ('mfi', mfi.money_flow_index())


def calc_adi(high, low, close, volume):
    from ta.volume import AccDistIndexIndicator
    adi = AccDistIndexIndicator(high=high, low=low, close=close, volume=volume)
    return ('adi', adi.acc_dist_index())


def calc_tema(close):
    import pandas_ta as ta
    tema = ta.tema(close, length=10)
    return ('tema', tema)


def calc_stochrsi(close):
    from ta.momentum import StochRSIIndicator
    stochrsi = StochRSIIndicator(close=close, window=14, smooth1=3, smooth2=3)
    return [
        ('stochrsi', stochrsi.stochrsi()),
        ('stochrsi_k', stochrsi.stochrsi_k()),
        ('stochrsi_d', stochrsi.stochrsi_d())
    ]


def calc_awesome_oscillator(high, low):
    from ta.momentum import AwesomeOscillatorIndicator
    ao = AwesomeOscillatorIndicator(high=high, low=low, window1=5, window2=34)
    return ('awesome_osc', ao.awesome_oscillator())


if __name__ == '__main__':
    IMPUTE_NANS = True  # Set to True to impute NaNs, False to drop rows with NaNs
    csv_path = './data/btcusd_1-min_data.csv'
    csv_prefix = os.path.splitext(os.path.basename(csv_path))[0]

    print('Reading CSV and filtering data...')
    df = pd.read_csv(csv_path)
    df = df[df['Volume'] != 0]

    min_date = '2017-06-01'
    print('Converting Timestamp and filtering by date...')
    df['Timestamp'] = pd.to_datetime(df['Timestamp'], unit='s')
    df = df[df['Timestamp'] >= min_date]

    lags = 3

    print('Calculating log returns as the new target...')
    df['log_return'] = np.log(df['Close'] / df['Close'].shift(1))
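    # Sanity check on log returns (illustrative numbers): prices 100 -> 110 -> 99
    # give log returns ln(1.10) ~= 0.0953 and ln(99/110) ~= -0.1054, which sum
    # to ln(99/100) ~= -0.0101, i.e. log returns add across consecutive bars.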
    ohlcv_cols = ['Open', 'High', 'Low', 'Close', 'Volume']
    window_sizes = [5, 15, 30]  # in minutes, adjust as needed

    features_dict = {}

    print('Starting feature computation...')
    feature_start_time = time.time()

    # --- Technical Indicator Features: Calculate or Load from Cache ---
    print('Calculating or loading technical indicator features...')
    # RSI
    feature_file = f'./data/{csv_prefix}_rsi.npy'
    if os.path.exists(feature_file):
        print(f'A Loading cached feature: {feature_file}')
        arr = np.load(feature_file)
        features_dict['rsi'] = pd.Series(arr, index=df.index)
    else:
        print('Calculating feature: rsi')
        _, values = calc_rsi(df['Close'])
        features_dict['rsi'] = values
        np.save(feature_file, values.values)
        print(f'Saved feature: {feature_file}')
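    # The load-or-compute pattern above repeats for each indicator below; a
    # hypothetical helper could factor it out without changing behavior:
    #
    #   def cached_feature(name, compute):
    #       path = f'./data/{csv_prefix}_{name}.npy'
    #       if os.path.exists(path):
    #           return pd.Series(np.load(path), index=df.index)
    #       values = compute()
    #       np.save(path, values.values)
    #       return values
    #
    #   features_dict['rsi'] = cached_feature('rsi', lambda: calc_rsi(df['Close'])[1])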
# MACD
|
||||||
|
feature_file = f'./data/{csv_prefix}_macd.npy'
|
||||||
|
if os.path.exists(feature_file):
|
||||||
|
print(f'A Loading cached feature: {feature_file}')
|
||||||
|
arr = np.load(feature_file)
|
||||||
|
features_dict['macd'] = pd.Series(arr, index=df.index)
|
||||||
|
else:
|
||||||
|
print('Calculating feature: macd')
|
||||||
|
_, values = calc_macd(df['Close'])
|
||||||
|
features_dict['macd'] = values
|
||||||
|
np.save(feature_file, values.values)
|
||||||
|
print(f'Saved feature: {feature_file}')
|
||||||
|
|
||||||
|
# ATR
|
||||||
|
feature_file = f'./data/{csv_prefix}_atr.npy'
|
||||||
|
if os.path.exists(feature_file):
|
||||||
|
print(f'A Loading cached feature: {feature_file}')
|
||||||
|
arr = np.load(feature_file)
|
||||||
|
features_dict['atr'] = pd.Series(arr, index=df.index)
|
||||||
|
else:
|
||||||
|
print('Calculating feature: atr')
|
||||||
|
_, values = calc_atr(df['High'], df['Low'], df['Close'])
|
||||||
|
features_dict['atr'] = values
|
||||||
|
np.save(feature_file, values.values)
|
||||||
|
print(f'Saved feature: {feature_file}')
|
||||||
|
|
||||||
|
# CCI
|
||||||
|
feature_file = f'./data/{csv_prefix}_cci.npy'
|
||||||
|
if os.path.exists(feature_file):
|
||||||
|
print(f'A Loading cached feature: {feature_file}')
|
||||||
|
arr = np.load(feature_file)
|
||||||
|
features_dict['cci'] = pd.Series(arr, index=df.index)
|
||||||
|
else:
|
||||||
|
print('Calculating feature: cci')
|
||||||
|
_, values = calc_cci(df['High'], df['Low'], df['Close'])
|
||||||
|
features_dict['cci'] = values
|
||||||
|
np.save(feature_file, values.values)
|
||||||
|
print(f'Saved feature: {feature_file}')
|
||||||
|
|
||||||
|
# Williams %R
|
||||||
|
feature_file = f'./data/{csv_prefix}_williams_r.npy'
|
||||||
|
if os.path.exists(feature_file):
|
||||||
|
print(f'A Loading cached feature: {feature_file}')
|
||||||
|
arr = np.load(feature_file)
|
||||||
|
features_dict['williams_r'] = pd.Series(arr, index=df.index)
|
||||||
|
else:
|
||||||
|
print('Calculating feature: williams_r')
|
||||||
|
_, values = calc_williamsr(df['High'], df['Low'], df['Close'])
|
||||||
|
features_dict['williams_r'] = values
|
||||||
|
np.save(feature_file, values.values)
|
||||||
|
print(f'Saved feature: {feature_file}')
|
||||||
|
|
||||||
|
# EMA 14
|
||||||
|
feature_file = f'./data/{csv_prefix}_ema_14.npy'
|
||||||
|
if os.path.exists(feature_file):
|
||||||
|
print(f'A Loading cached feature: {feature_file}')
|
||||||
|
arr = np.load(feature_file)
|
||||||
|
features_dict['ema_14'] = pd.Series(arr, index=df.index)
|
||||||
|
else:
|
||||||
|
print('Calculating feature: ema_14')
|
||||||
|
_, values = calc_ema(df['Close'])
|
||||||
|
features_dict['ema_14'] = values
|
||||||
|
np.save(feature_file, values.values)
|
||||||
|
print(f'Saved feature: {feature_file}')
|
||||||
|
|
||||||
|
# OBV
|
||||||
|
feature_file = f'./data/{csv_prefix}_obv.npy'
|
||||||
|
if os.path.exists(feature_file):
|
||||||
|
print(f'A Loading cached feature: {feature_file}')
|
||||||
|
arr = np.load(feature_file)
|
||||||
|
features_dict['obv'] = pd.Series(arr, index=df.index)
|
||||||
|
else:
|
||||||
|
print('Calculating feature: obv')
|
||||||
|
_, values = calc_obv(df['Close'], df['Volume'])
|
||||||
|
features_dict['obv'] = values
|
||||||
|
np.save(feature_file, values.values)
|
||||||
|
print(f'Saved feature: {feature_file}')
|
||||||
|
|
||||||
|
# CMF
|
||||||
|
feature_file = f'./data/{csv_prefix}_cmf.npy'
|
||||||
|
if os.path.exists(feature_file):
|
||||||
|
print(f'A Loading cached feature: {feature_file}')
|
||||||
|
arr = np.load(feature_file)
|
||||||
|
features_dict['cmf'] = pd.Series(arr, index=df.index)
|
||||||
|
else:
|
||||||
|
print('Calculating feature: cmf')
|
||||||
|
_, values = calc_cmf(df['High'], df['Low'], df['Close'], df['Volume'])
|
||||||
|
features_dict['cmf'] = values
|
||||||
|
np.save(feature_file, values.values)
|
||||||
|
print(f'Saved feature: {feature_file}')
|
||||||
|
|
||||||
|
# ROC 10
|
||||||
|
feature_file = f'./data/{csv_prefix}_roc_10.npy'
|
||||||
|
if os.path.exists(feature_file):
|
||||||
|
print(f'A Loading cached feature: {feature_file}')
|
||||||
|
arr = np.load(feature_file)
|
||||||
|
features_dict['roc_10'] = pd.Series(arr, index=df.index)
|
||||||
|
else:
|
||||||
|
print('Calculating feature: roc_10')
|
||||||
|
_, values = calc_roc(df['Close'])
|
||||||
|
features_dict['roc_10'] = values
|
||||||
|
np.save(feature_file, values.values)
|
||||||
|
print(f'Saved feature: {feature_file}')
|
||||||
|
|
||||||
|
# DPO 20
|
||||||
|
feature_file = f'./data/{csv_prefix}_dpo_20.npy'
|
||||||
|
if os.path.exists(feature_file):
|
||||||
|
print(f'A Loading cached feature: {feature_file}')
|
||||||
|
arr = np.load(feature_file)
|
||||||
|
features_dict['dpo_20'] = pd.Series(arr, index=df.index)
|
||||||
|
else:
|
||||||
|
print('Calculating feature: dpo_20')
|
||||||
|
_, values = calc_dpo(df['Close'])
|
||||||
|
features_dict['dpo_20'] = values
|
||||||
|
np.save(feature_file, values.values)
|
||||||
|
print(f'Saved feature: {feature_file}')
|
||||||
|
|
||||||
|
# Ultimate Oscillator
|
||||||
|
feature_file = f'./data/{csv_prefix}_ultimate_osc.npy'
|
||||||
|
if os.path.exists(feature_file):
|
||||||
|
print(f'A Loading cached feature: {feature_file}')
|
||||||
|
arr = np.load(feature_file)
|
||||||
|
features_dict['ultimate_osc'] = pd.Series(arr, index=df.index)
|
||||||
|
else:
|
||||||
|
print('Calculating feature: ultimate_osc')
|
||||||
|
_, values = calc_ultimate(df['High'], df['Low'], df['Close'])
|
||||||
|
features_dict['ultimate_osc'] = values
|
||||||
|
np.save(feature_file, values.values)
|
||||||
|
print(f'Saved feature: {feature_file}')
|
||||||
|
|
||||||
|
# Daily Return
|
||||||
|
feature_file = f'./data/{csv_prefix}_daily_return.npy'
|
||||||
|
if os.path.exists(feature_file):
|
||||||
|
print(f'A Loading cached feature: {feature_file}')
|
||||||
|
arr = np.load(feature_file)
|
||||||
|
features_dict['daily_return'] = pd.Series(arr, index=df.index)
|
||||||
|
else:
|
||||||
|
print('Calculating feature: daily_return')
|
||||||
|
_, values = calc_daily_return(df['Close'])
|
||||||
|
features_dict['daily_return'] = values
|
||||||
|
np.save(feature_file, values.values)
|
||||||
|
print(f'Saved feature: {feature_file}')
|
||||||
|
|
||||||
|
# Multi-column indicators
|
||||||
|
# Bollinger Bands
|
||||||
|
print('Calculating multi-column indicator: bollinger')
|
||||||
|
result = calc_bollinger(df['Close'])
|
||||||
|
for subname, values in result:
|
||||||
|
print(f"Adding subfeature: {subname}")
|
||||||
|
sub_feature_file = f'./data/{csv_prefix}_{subname}.npy'
|
||||||
|
if os.path.exists(sub_feature_file):
|
||||||
|
print(f'B Loading cached feature: {sub_feature_file}')
|
||||||
|
arr = np.load(sub_feature_file)
|
||||||
|
features_dict[subname] = pd.Series(arr, index=df.index)
|
||||||
|
else:
|
||||||
|
features_dict[subname] = values
|
||||||
|
np.save(sub_feature_file, values.values)
|
||||||
|
print(f'Saved feature: {sub_feature_file}')
|
||||||
|
|
||||||
|
# Stochastic Oscillator
|
||||||
|
print('Calculating multi-column indicator: stochastic')
|
||||||
|
result = calc_stochastic(df['High'], df['Low'], df['Close'])
|
||||||
|
for subname, values in result:
|
||||||
|
print(f"Adding subfeature: {subname}")
|
||||||
|
sub_feature_file = f'./data/{csv_prefix}_{subname}.npy'
|
||||||
|
if os.path.exists(sub_feature_file):
|
||||||
|
print(f'B Loading cached feature: {sub_feature_file}')
|
||||||
|
arr = np.load(sub_feature_file)
|
||||||
|
features_dict[subname] = pd.Series(arr, index=df.index)
|
||||||
|
else:
|
||||||
|
features_dict[subname] = values
|
||||||
|
np.save(sub_feature_file, values.values)
|
||||||
|
print(f'Saved feature: {sub_feature_file}')
|
||||||
|
|
||||||
|
# SMA
|
||||||
|
print('Calculating multi-column indicator: sma')
|
||||||
|
result = calc_sma(df['Close'])
|
||||||
|
for subname, values in result:
|
||||||
|
print(f"Adding subfeature: {subname}")
|
||||||
|
sub_feature_file = f'./data/{csv_prefix}_{subname}.npy'
|
||||||
|
if os.path.exists(sub_feature_file):
|
||||||
|
print(f'B Loading cached feature: {sub_feature_file}')
|
||||||
|
arr = np.load(sub_feature_file)
|
||||||
|
features_dict[subname] = pd.Series(arr, index=df.index)
|
||||||
|
else:
|
||||||
|
features_dict[subname] = values
|
||||||
|
np.save(sub_feature_file, values.values)
|
||||||
|
print(f'Saved feature: {sub_feature_file}')
|
||||||
|
|
||||||
|
# PSAR
|
||||||
|
print('Calculating multi-column indicator: psar')
|
||||||
|
result = calc_psar(df['High'], df['Low'], df['Close'])
|
||||||
|
for subname, values in result:
|
||||||
|
print(f"Adding subfeature: {subname}")
|
||||||
|
sub_feature_file = f'./data/{csv_prefix}_{subname}.npy'
|
||||||
|
if os.path.exists(sub_feature_file):
|
||||||
|
print(f'B Loading cached feature: {sub_feature_file}')
|
||||||
|
arr = np.load(sub_feature_file)
|
||||||
|
features_dict[subname] = pd.Series(arr, index=df.index)
|
||||||
|
else:
|
||||||
|
features_dict[subname] = values
|
||||||
|
np.save(sub_feature_file, values.values)
|
||||||
|
print(f'Saved feature: {sub_feature_file}')
|
||||||
|
|
||||||
|
# Donchian Channel
|
||||||
|
print('Calculating multi-column indicator: donchian')
|
||||||
|
result = calc_donchian(df['High'], df['Low'], df['Close'])
|
||||||
|
for subname, values in result:
|
||||||
|
print(f"Adding subfeature: {subname}")
|
||||||
|
sub_feature_file = f'./data/{csv_prefix}_{subname}.npy'
|
||||||
|
if os.path.exists(sub_feature_file):
|
||||||
|
print(f'B Loading cached feature: {sub_feature_file}')
|
||||||
|
arr = np.load(sub_feature_file)
|
||||||
|
features_dict[subname] = pd.Series(arr, index=df.index)
|
||||||
|
else:
|
||||||
|
features_dict[subname] = values
|
||||||
|
np.save(sub_feature_file, values.values)
|
||||||
|
print(f'Saved feature: {sub_feature_file}')
|
||||||
|
|
||||||
|
# Keltner Channel
|
||||||
|
print('Calculating multi-column indicator: keltner')
|
||||||
|
result = calc_keltner(df['High'], df['Low'], df['Close'])
|
||||||
|
for subname, values in result:
|
||||||
|
print(f"Adding subfeature: {subname}")
|
||||||
|
sub_feature_file = f'./data/{csv_prefix}_{subname}.npy'
|
||||||
|
if os.path.exists(sub_feature_file):
|
||||||
|
print(f'B Loading cached feature: {sub_feature_file}')
|
||||||
|
arr = np.load(sub_feature_file)
|
||||||
|
features_dict[subname] = pd.Series(arr, index=df.index)
|
||||||
|
else:
|
||||||
|
features_dict[subname] = values
|
||||||
|
np.save(sub_feature_file, values.values)
|
||||||
|
print(f'Saved feature: {sub_feature_file}')
|
||||||
|
|
||||||
|
# Ichimoku
|
||||||
|
print('Calculating multi-column indicator: ichimoku')
|
||||||
|
result = calc_ichimoku(df['High'], df['Low'])
|
||||||
|
for subname, values in result:
|
||||||
|
print(f"Adding subfeature: {subname}")
|
||||||
|
sub_feature_file = f'./data/{csv_prefix}_{subname}.npy'
|
||||||
|
if os.path.exists(sub_feature_file):
|
||||||
|
print(f'B Loading cached feature: {sub_feature_file}')
|
||||||
|
arr = np.load(sub_feature_file)
|
||||||
|
features_dict[subname] = pd.Series(arr, index=df.index)
|
||||||
|
else:
|
||||||
|
features_dict[subname] = values
|
||||||
|
np.save(sub_feature_file, values.values)
|
||||||
|
print(f'Saved feature: {sub_feature_file}')
|
||||||
|
|
||||||
|
# Elder Ray
|
||||||
|
print('Calculating multi-column indicator: elder_ray')
|
||||||
|
result = calc_elder_ray(df['Close'], df['Low'], df['High'])
|
||||||
|
for subname, values in result:
|
||||||
|
print(f"Adding subfeature: {subname}")
|
||||||
|
sub_feature_file = f'./data/{csv_prefix}_{subname}.npy'
|
||||||
|
if os.path.exists(sub_feature_file):
|
||||||
|
print(f'B Loading cached feature: {sub_feature_file}')
|
||||||
|
arr = np.load(sub_feature_file)
|
||||||
|
features_dict[subname] = pd.Series(arr, index=df.index)
|
||||||
|
else:
|
||||||
|
features_dict[subname] = values
|
||||||
|
np.save(sub_feature_file, values.values)
|
||||||
|
print(f'Saved feature: {sub_feature_file}')
|
||||||
|
|
||||||
|
# Prepare lags, rolling stats, log returns, and volatility features sequentially
|
||||||
|
# Lags
|
||||||
|
for col in ohlcv_cols:
|
||||||
|
for lag in range(1, lags + 1):
|
||||||
|
feature_name = f'{col}_lag{lag}'
|
||||||
|
feature_file = f'./data/{csv_prefix}_{feature_name}.npy'
|
||||||
|
if os.path.exists(feature_file):
|
||||||
|
print(f'C Loading cached feature: {feature_file}')
|
||||||
|
features_dict[feature_name] = np.load(feature_file)
|
||||||
|
else:
|
||||||
|
print(f'Computing lag feature: {feature_name}')
|
||||||
|
result = compute_lag(df, col, lag)
|
||||||
|
features_dict[feature_name] = result
|
||||||
|
np.save(feature_file, result.values)
|
||||||
|
print(f'Saved feature: {feature_file}')
|
||||||
# Rolling statistics
for col in ohlcv_cols:
    for window in window_sizes:
        # Skip (column, window) combinations excluded from the feature set
        if (col, window) in {('Open', 5), ('High', 5), ('High', 30), ('Low', 15)}:
            continue
        for stat in ['mean', 'std', 'min', 'max']:
            feature_name = f'{col}_roll_{stat}_{window}'
            feature_file = f'./data/{csv_prefix}_{feature_name}.npy'
            if os.path.exists(feature_file):
                print(f'D Loading cached feature: {feature_file}')
                features_dict[feature_name] = np.load(feature_file)
            else:
                print(f'Computing rolling stat feature: {feature_name}')
                result = compute_rolling(df, col, stat, window)
                features_dict[feature_name] = result
                np.save(feature_file, result.values)
                print(f'Saved feature: {feature_file}')
# Log returns for different horizons
for horizon in [5, 15, 30]:
    feature_name = f'log_return_{horizon}'
    feature_file = f'./data/{csv_prefix}_{feature_name}.npy'
    if os.path.exists(feature_file):
        print(f'E Loading cached feature: {feature_file}')
        features_dict[feature_name] = np.load(feature_file)
    else:
        print(f'Computing log return feature: {feature_name}')
        result = compute_log_return(df, horizon)
        features_dict[feature_name] = result
        np.save(feature_file, result.values)
        print(f'Saved feature: {feature_file}')
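# Assuming compute_log_return follows the standard definition, the horizon-h
# feature is r_t = ln(C_t / C_{t-h}), i.e. roughly:
#
#     def log_return(close: pd.Series, horizon: int) -> pd.Series:
#         return np.log(close / close.shift(horizon))
#
# The first `horizon` rows come out NaN and are handled by the NaN step below.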
# Volatility
for window in window_sizes:
    feature_name = f'volatility_{window}'
    feature_file = f'./data/{csv_prefix}_{feature_name}.npy'
    if os.path.exists(feature_file):
        print(f'F Loading cached feature: {feature_file}')
        features_dict[feature_name] = np.load(feature_file)
    else:
        print(f'Computing volatility feature: {feature_name}')
        result = compute_volatility(df, window)
        features_dict[feature_name] = result
        np.save(feature_file, result.values)
        print(f'Saved feature: {feature_file}')

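# compute_volatility is likewise assumed to be a rolling dispersion measure;
# one common choice is the rolling standard deviation of one-step log returns:
#
#     def rolling_volatility(close: pd.Series, window: int) -> pd.Series:
#         r = np.log(close / close.shift(1))
#         return r.rolling(window).std()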
# --- Additional Technical Indicator Features ---
# ADX
adx_names = ['adx', 'adx_pos', 'adx_neg']
adx_files = [f'./data/{csv_prefix}_{name}.npy' for name in adx_names]
if all(os.path.exists(f) for f in adx_files):
    print('G Loading cached features: ADX')
    for name, f in zip(adx_names, adx_files):
        arr = np.load(f)
        features_dict[name] = pd.Series(arr, index=df.index)
else:
    print('Calculating multi-column indicator: adx')
    result = calc_adx(df['High'], df['Low'], df['Close'])
    for subname, values in result:
        sub_feature_file = f'./data/{csv_prefix}_{subname}.npy'
        features_dict[subname] = values
        np.save(sub_feature_file, values.values)
        print(f'Saved feature: {sub_feature_file}')

# Force Index
feature_file = f'./data/{csv_prefix}_force_index.npy'
if os.path.exists(feature_file):
    print(f'K Loading cached feature: {feature_file}')
    arr = np.load(feature_file)
    features_dict['force_index'] = pd.Series(arr, index=df.index)
else:
    print('Calculating feature: force_index')
    _, values = calc_force_index(df['Close'], df['Volume'])
    features_dict['force_index'] = values
    np.save(feature_file, values.values)
    print(f'Saved feature: {feature_file}')

# Supertrend indicators
for period, multiplier in [(12, 3.0), (10, 1.0), (11, 2.0)]:
    st_name = f'supertrend_{period}_{multiplier}'
    st_trend_name = f'supertrend_trend_{period}_{multiplier}'
    st_file = f'./data/{csv_prefix}_{st_name}.npy'
    st_trend_file = f'./data/{csv_prefix}_{st_trend_name}.npy'
    if os.path.exists(st_file) and os.path.exists(st_trend_file):
        print(f'L Loading cached features: {st_file}, {st_trend_file}')
        features_dict[st_name] = pd.Series(np.load(st_file), index=df.index)
        features_dict[st_trend_name] = pd.Series(np.load(st_trend_file), index=df.index)
    else:
        print(f'Calculating Supertrend indicator: {st_name}')
        st = ta.supertrend(df['High'], df['Low'], df['Close'], length=period, multiplier=multiplier)
        features_dict[st_name] = st[f'SUPERT_{period}_{multiplier}']
        features_dict[st_trend_name] = st[f'SUPERTd_{period}_{multiplier}']
        np.save(st_file, features_dict[st_name].values)
        np.save(st_trend_file, features_dict[st_trend_name].values)
        print(f'Saved features: {st_file}, {st_trend_file}')

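# pandas_ta names the Supertrend columns SUPERT_{length}_{multiplier} and
# SUPERTd_{length}_{multiplier}, with the multiplier rendered as a float
# ('3.0'), which is why the f-strings above line up. A more defensive lookup,
# in case an installed version formats the names differently, would be:
#
#     trend_col = next(c for c in st.columns if c.startswith('SUPERT_'))
#     direction_col = next(c for c in st.columns if c.startswith('SUPERTd'))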
# Concatenate all new features at once
print('Concatenating all new features to DataFrame...')
features_df = pd.DataFrame(features_dict)
print("Columns in features_df:", features_df.columns.tolist())
print("All-NaN columns in features_df:", features_df.columns[features_df.isna().all()].tolist())
df = pd.concat([df, features_df], axis=1)

# Print all columns after concatenation
print("All columns in df after concat:", df.columns.tolist())

# Downcast all float columns to save memory
print('Downcasting float columns to save memory...')
for col in df.columns:
    try:
        df[col] = pd.to_numeric(df[col], downcast='float')
    except Exception:
        # Non-numeric columns (e.g. Timestamp) are left unchanged
        pass

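# A more targeted variant would downcast only float64 columns and leave
# integer and datetime columns untouched (a sketch; the try/except above may
# also convert integer columns, since pd.to_numeric with downcast='float'
# does not restrict itself to float inputs):
#
#     float_cols = df.select_dtypes(include=['float64']).columns
#     df[float_cols] = df[float_cols].astype(np.float32)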
# Add time features (exclude 'dayofweek')
print('Adding hour feature...')
df['Timestamp'] = pd.to_datetime(df['Timestamp'], errors='coerce')
df['hour'] = df['Timestamp'].dt.hour

# Handle NaNs after all feature engineering
if IMPUTE_NANS:
    print('Imputing NaNs after feature engineering (using mean imputation)...')
    numeric_cols = df.select_dtypes(include=[np.number]).columns
    for col in numeric_cols:
        df[col] = df[col].fillna(df[col].mean())
    # If you want to impute non-numeric columns differently, add logic here
else:
    print('Dropping NaNs after feature engineering...')
    df = df.dropna().reset_index(drop=True)

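# Caveat: filling with the full-column mean uses statistics from the test
# region of the series. A leakage-free sketch, assuming the same 80/20
# chronological split used below:
#
#     split_idx = int(len(df) * 0.8)
#     numeric_cols = df.select_dtypes(include=[np.number]).columns
#     train_means = df.iloc[:split_idx][numeric_cols].mean()
#     df[numeric_cols] = df[numeric_cols].fillna(train_means)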
# Exclude 'Timestamp', 'Close', 'log_return', and any future target columns from features
print('Selecting feature columns...')
exclude_cols = ['Timestamp', 'Close', 'log_return', 'log_return_5', 'log_return_15', 'log_return_30']
feature_cols = [col for col in df.columns if col not in exclude_cols]
print('Features used for training:', feature_cols)

# Prepare CSV for results
results_csv = './data/leave_one_out_results.csv'
if not os.path.exists(results_csv):
    with open(results_csv, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['left_out_feature', 'used_features', 'rmse', 'mae', 'r2', 'mape', 'directional_accuracy'])

total_features = len(feature_cols)
never_leave_out = {'Open', 'High', 'Low', 'Close', 'Volume'}
for idx, left_out in enumerate(feature_cols):
    if left_out in never_leave_out:
        continue
    used = [f for f in feature_cols if f != left_out]
    print(f'\n=== Leave-one-out {idx+1}/{total_features}: left out {left_out} ===')
    try:
        # Prepare X and y for this combination
        X = df[used].values.astype(np.float32)
        y = df["log_return"].values.astype(np.float32)
        split_idx = int(len(X) * 0.8)
        X_train, X_test = X[:split_idx], X[split_idx:]
        y_train, y_test = y[:split_idx], y[split_idx:]
        test_timestamps = df['Timestamp'].values[split_idx:]

        model = CustomXGBoostGPU(X_train, X_test, y_train, y_test)
        booster = model.train()
        model.save_model(f'./data/xgboost_model_wo_{left_out}.json')

        test_preds = model.predict(X_test)
        rmse = np.sqrt(mean_squared_error(y_test, test_preds))
        # Reconstruct price series from log returns
        if 'Close' in df.columns:
            close_prices = df['Close'].values
        else:
            close_prices = pd.read_csv(csv_path)['Close'].values
        start_price = close_prices[split_idx]
        actual_prices = [start_price]
        for r_ in y_test:
            actual_prices.append(actual_prices[-1] * np.exp(r_))
        actual_prices = np.array(actual_prices[1:])
        predicted_prices = [start_price]
        for r_ in test_preds:
            predicted_prices.append(predicted_prices[-1] * np.exp(r_))
        predicted_prices = np.array(predicted_prices[1:])

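        # Note: the two loops above amount to a cumulative sum of log returns;
        # a vectorized form with the same start_price would be:
        #   actual_prices = start_price * np.exp(np.cumsum(y_test))
        #   predicted_prices = start_price * np.exp(np.cumsum(test_preds))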
        mae = mean_absolute_error(actual_prices, predicted_prices)
        r2 = r2_score(actual_prices, predicted_prices)
        direction_actual = np.sign(np.diff(actual_prices))
        direction_pred = np.sign(np.diff(predicted_prices))
        directional_accuracy = (direction_actual == direction_pred).mean()
        mape = np.mean(np.abs((actual_prices - predicted_prices) / actual_prices)) * 100

        # Save results to CSV
        with open(results_csv, 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([left_out, "|".join(used), rmse, mae, r2, mape, directional_accuracy])
        print(f'Left out {left_out}: RMSE={rmse:.4f}, MAE={mae:.4f}, R2={r2:.4f}, MAPE={mape:.2f}%, DirAcc={directional_accuracy*100:.2f}%')

        # Plotting for this run
        plot_prefix = f'loo_{left_out}'
        print('Plotting distribution of prediction errors...')
        plot_prediction_error_distribution(predicted_prices, actual_prices, prefix=plot_prefix)

        print('Plotting direction transition heatmap...')
        plot_direction_transition_heatmap(actual_prices, predicted_prices, prefix=plot_prefix)
    except Exception as e:
        print(f'Leave-one-out failed for {left_out}: {e}')

print(f'All leave-one-out runs completed. Results saved to {results_csv}')
sys.exit(0)
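Once every run has finished, `leave_one_out_results.csv` can be ranked to see which left-out feature moved the error most. A minimal sketch (assuming the column names written by the script above; a run whose RMSE rises when a feature is removed suggests that feature carried signal):

    import pandas as pd

    results = pd.read_csv('./data/leave_one_out_results.csv')
    ranked = results.sort_values('rmse', ascending=False)
    print(ranked[['left_out_feature', 'rmse', 'directional_accuracy']].head(20))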
xgboost/plot_results.py (new file, 318 lines)
@@ -0,0 +1,318 @@
import numpy as np
import dash
from dash import dcc, html
import plotly.graph_objs as go
import threading


def display_actual_vs_predicted(y_test, test_preds, timestamps, n_plot=200):
    """Plot the first n_plot actual vs. predicted close prices of the test set."""
    import plotly.offline as pyo
    n_plot = min(n_plot, len(y_test))
    plot_indices = timestamps[:n_plot]
    actual = y_test[:n_plot]
    predicted = test_preds[:n_plot]

    trace_actual = go.Scatter(x=plot_indices, y=actual, mode='lines', name='Actual')
    trace_predicted = go.Scatter(x=plot_indices, y=predicted, mode='lines', name='Predicted')
    data = [trace_actual, trace_predicted]
    layout = go.Layout(
        title='Actual vs. Predicted BTC Close Prices (Test Set)',
        xaxis={'title': 'Timestamp'},
        yaxis={'title': 'BTC Close Price'},
        legend={'x': 0, 'y': 1},
        margin={'l': 40, 'b': 40, 't': 40, 'r': 10},
        hovermode='closest'
    )
    fig = go.Figure(data=data, layout=layout)
    pyo.plot(fig, auto_open=False)

def plot_target_distribution(y_train, y_test):
    """Plot overlaid histograms of the train and test target distributions."""
    import plotly.offline as pyo
    trace_train = go.Histogram(
        x=y_train,
        nbinsx=100,
        opacity=0.5,
        name='Train',
        marker=dict(color='blue')
    )
    trace_test = go.Histogram(
        x=y_test,
        nbinsx=100,
        opacity=0.5,
        name='Test',
        marker=dict(color='orange')
    )
    data = [trace_train, trace_test]
    layout = go.Layout(
        title='Distribution of Target Variable (Close Price)',
        xaxis=dict(title='BTC Close Price'),
        yaxis=dict(title='Frequency'),
        barmode='overlay'
    )
    fig = go.Figure(data=data, layout=layout)
    pyo.plot(fig, auto_open=False)

def plot_predicted_vs_actual_log_returns(y_test, test_preds, timestamps=None, n_plot=200):
    """Plot actual vs. predicted log returns as a line plot and a scatter plot."""
    import plotly.offline as pyo
    import plotly.graph_objs as go
    n_plot = min(n_plot, len(y_test))
    actual = y_test[:n_plot]
    predicted = test_preds[:n_plot]
    if timestamps is not None:
        x_axis = timestamps[:n_plot]
        x_label = 'Timestamp'
    else:
        x_axis = list(range(n_plot))
        x_label = 'Index'

    # Line plot: Actual vs Predicted over time
    trace_actual = go.Scatter(x=x_axis, y=actual, mode='lines', name='Actual')
    trace_predicted = go.Scatter(x=x_axis, y=predicted, mode='lines', name='Predicted')
    data_line = [trace_actual, trace_predicted]
    layout_line = go.Layout(
        title='Actual vs. Predicted Log Returns (Test Set)',
        xaxis={'title': x_label},
        yaxis={'title': 'Log Return'},
        legend={'x': 0, 'y': 1},
        margin={'l': 40, 'b': 40, 't': 40, 'r': 10},
        hovermode='closest'
    )
    fig_line = go.Figure(data=data_line, layout=layout_line)
    pyo.plot(fig_line, filename='charts/log_return_line_plot.html', auto_open=False)

    # Scatter plot: Predicted vs Actual
    trace_scatter = go.Scatter(
        x=actual,
        y=predicted,
        mode='markers',
        name='Predicted vs Actual',
        opacity=0.5
    )
    # Diagonal reference line
    min_val = min(np.min(actual), np.min(predicted))
    max_val = max(np.max(actual), np.max(predicted))
    trace_diag = go.Scatter(
        x=[min_val, max_val],
        y=[min_val, max_val],
        mode='lines',
        name='Ideal',
        line=dict(dash='dash', color='red')
    )
    data_scatter = [trace_scatter, trace_diag]
    layout_scatter = go.Layout(
        title='Predicted vs Actual Log Returns (Scatter)',
        xaxis={'title': 'Actual Log Return'},
        yaxis={'title': 'Predicted Log Return'},
        showlegend=True,
        margin={'l': 40, 'b': 40, 't': 40, 'r': 10},
        hovermode='closest'
    )
    fig_scatter = go.Figure(data=data_scatter, layout=layout_scatter)
    pyo.plot(fig_scatter, filename='charts/log_return_scatter_plot.html', auto_open=False)

def plot_predicted_vs_actual_prices(actual_prices, predicted_prices, timestamps=None, n_plot=200):
    """Plot actual vs. predicted prices as a line plot and a scatter plot."""
    import plotly.offline as pyo
    import plotly.graph_objs as go
    n_plot = min(n_plot, len(actual_prices))
    actual = actual_prices[:n_plot]
    predicted = predicted_prices[:n_plot]
    if timestamps is not None:
        x_axis = timestamps[:n_plot]
        x_label = 'Timestamp'
    else:
        x_axis = list(range(n_plot))
        x_label = 'Index'

    # Line plot: Actual vs Predicted over time
    trace_actual = go.Scatter(x=x_axis, y=actual, mode='lines', name='Actual Price')
    trace_predicted = go.Scatter(x=x_axis, y=predicted, mode='lines', name='Predicted Price')
    data_line = [trace_actual, trace_predicted]
    layout_line = go.Layout(
        title='Actual vs. Predicted BTC Prices (Test Set)',
        xaxis={'title': x_label},
        yaxis={'title': 'BTC Price'},
        legend={'x': 0, 'y': 1},
        margin={'l': 40, 'b': 40, 't': 40, 'r': 10},
        hovermode='closest'
    )
    fig_line = go.Figure(data=data_line, layout=layout_line)
    pyo.plot(fig_line, filename='charts/price_line_plot.html', auto_open=False)

    # Scatter plot: Predicted vs Actual
    trace_scatter = go.Scatter(
        x=actual,
        y=predicted,
        mode='markers',
        name='Predicted vs Actual',
        opacity=0.5
    )
    # Diagonal reference line
    min_val = min(np.min(actual), np.min(predicted))
    max_val = max(np.max(actual), np.max(predicted))
    trace_diag = go.Scatter(
        x=[min_val, max_val],
        y=[min_val, max_val],
        mode='lines',
        name='Ideal',
        line=dict(dash='dash', color='red')
    )
    data_scatter = [trace_scatter, trace_diag]
    layout_scatter = go.Layout(
        title='Predicted vs Actual Prices (Scatter)',
        xaxis={'title': 'Actual Price'},
        yaxis={'title': 'Predicted Price'},
        showlegend=True,
        margin={'l': 40, 'b': 40, 't': 40, 'r': 10},
        hovermode='closest'
    )
    fig_scatter = go.Figure(data=data_scatter, layout=layout_scatter)
    pyo.plot(fig_scatter, filename='charts/price_scatter_plot.html', auto_open=False)

def plot_prediction_error_distribution(predicted_prices, actual_prices, nbins=100, prefix=""):
    """
    Plots the distribution of signed prediction errors between predicted and actual prices,
    coloring negative errors (under-prediction) and positive errors (over-prediction) differently.
    """
    import plotly.offline as pyo
    import plotly.graph_objs as go
    errors = np.array(predicted_prices) - np.array(actual_prices)

    # Separate negative and positive errors
    neg_errors = errors[errors < 0]
    pos_errors = errors[errors >= 0]

    # Shared binning so both histograms line up on the same edges
    min_error = np.min(errors)
    max_error = np.max(errors)
    xbins = dict(start=min_error, end=max_error, size=(max_error - min_error) / nbins)

    trace_neg = go.Histogram(
        x=neg_errors,
        opacity=0.75,
        marker=dict(color='blue'),
        name='Negative Error (Under-prediction)',
        xbins=xbins
    )
    trace_pos = go.Histogram(
        x=pos_errors,
        opacity=0.75,
        marker=dict(color='orange'),
        name='Positive Error (Over-prediction)',
        xbins=xbins
    )
    layout = go.Layout(
        title='Distribution of Prediction Errors (Signed)',
        xaxis=dict(title='Prediction Error (Predicted - Actual)'),
        yaxis=dict(title='Frequency'),
        barmode='overlay',
        bargap=0.05
    )
    fig = go.Figure(data=[trace_neg, trace_pos], layout=layout)
    filename = f'charts/{prefix}_prediction_error_distribution.html'
    pyo.plot(fig, filename=filename, auto_open=False)

def plot_directional_accuracy(actual_prices, predicted_prices, timestamps=None, n_plot=200):
    """
    Plots the directional accuracy of predictions compared to actual price movements.
    Shows whether the predicted direction matches the actual direction of price movement.

    Args:
        actual_prices: Array of actual price values
        predicted_prices: Array of predicted price values
        timestamps: Optional array of timestamps for x-axis
        n_plot: Number of points to plot (default 200, plots last n_plot points)
    """
    import plotly.graph_objs as go
    import plotly.offline as pyo
    import numpy as np

    # Calculate price changes
    actual_changes = np.diff(actual_prices)
    predicted_changes = np.diff(predicted_prices)

    # Determine if directions match
    actual_direction = np.sign(actual_changes)
    predicted_direction = np.sign(predicted_changes)
    correct_direction = actual_direction == predicted_direction

    # Keep the last n_plot points
    actual_changes = actual_changes[-n_plot:]
    predicted_changes = predicted_changes[-n_plot:]
    correct_direction = correct_direction[-n_plot:]

    if timestamps is not None:
        x_values = timestamps[1:]  # Skip first since we took diff
        x_values = x_values[-n_plot:]  # Keep the last n_plot points
    else:
        x_values = list(range(len(actual_changes)))

    # Create traces for correct and incorrect predictions
    correct_trace = go.Scatter(
        x=np.array(x_values)[correct_direction],
        y=actual_changes[correct_direction],
        mode='markers',
        name='Correct Direction',
        marker=dict(color='green', size=8)
    )

    incorrect_trace = go.Scatter(
        x=np.array(x_values)[~correct_direction],
        y=actual_changes[~correct_direction],
        mode='markers',
        name='Incorrect Direction',
        marker=dict(color='red', size=8)
    )

    # Calculate accuracy percentage
    accuracy = np.mean(correct_direction) * 100

    layout = go.Layout(
        title=f'Directional Accuracy (Overall: {accuracy:.1f}%)',
        xaxis=dict(title='Time' if timestamps is not None else 'Sample'),
        yaxis=dict(title='Price Change'),
        showlegend=True
    )

    fig = go.Figure(data=[correct_trace, incorrect_trace], layout=layout)
    pyo.plot(fig, filename='charts/directional_accuracy.html', auto_open=False)

def plot_direction_transition_heatmap(actual_prices, predicted_prices, prefix=""):
    """
    Plots a heatmap showing the frequency of each (actual, predicted) direction pair.
    """
    import numpy as np
    import plotly.graph_objs as go
    import plotly.offline as pyo

    # Calculate directions
    actual_direction = np.sign(np.diff(actual_prices))
    predicted_direction = np.sign(np.diff(predicted_prices))

    # Build 3x3 matrix: rows=actual, cols=predicted, values=counts
    # Map -1 -> 0, 0 -> 1, 1 -> 2 for indexing
    mapping = {-1: 0, 0: 1, 1: 2}
    matrix = np.zeros((3, 3), dtype=int)
    for a, p in zip(actual_direction, predicted_direction):
        matrix[mapping[a], mapping[p]] += 1

    # Axis labels
    directions = ['Down (-1)', 'No Change (0)', 'Up (+1)']

    # Plot heatmap
    heatmap = go.Heatmap(
        z=matrix,
        x=directions,  # predicted
        y=directions,  # actual
        colorscale='Viridis',
        colorbar=dict(title='Count')
    )
    layout = go.Layout(
        title='Direction Prediction Transition Matrix',
        xaxis=dict(title='Predicted Direction'),
        yaxis=dict(title='Actual Direction')
    )
    fig = go.Figure(data=[heatmap], layout=layout)
    filename = f'charts/{prefix}_direction_transition_heatmap.html'
    pyo.plot(fig, filename=filename, auto_open=False)
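A natural extension of the transition heatmap is to normalize each actual-direction row to percentages, which makes rows with very different counts comparable. A minimal sketch building on the `matrix` and `directions` computed above (not part of the file):

    row_sums = np.maximum(matrix.sum(axis=1, keepdims=True), 1)  # avoid division by zero
    pct = matrix / row_sums * 100
    heatmap_pct = go.Heatmap(z=pct, x=directions, y=directions,
                             colorscale='Viridis', colorbar=dict(title='% of actual row'))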