Compare commits

6 commits: 120c366576...main

| Author | SHA1 | Date |
|---|---|---|
| | 289d11b0a8 | |
| | 70da858aac | |
| | add3fbcf19 | |
| | a419764fff | |
| | b56d9ea3a1 | |
| | 3e08802194 | |

.cursor/rules/always-global.mdc (new file, 61 lines)
@@ -0,0 +1,61 @@
---
description: Global development standards and AI interaction principles
globs:
alwaysApply: true
---

# Rule: Always Apply - Global Development Standards

## AI Interaction Principles

### Step-by-Step Development

- **NEVER** generate large blocks of code without explanation
- **ALWAYS** present your plan in a concise bullet list and wait for the user's confirmation before proceeding
- Break complex tasks into smaller, manageable pieces (≤250 lines per file, ≤50 lines per function)
- Explain your reasoning step-by-step before writing code
- Wait for explicit approval before moving to the next sub-task

### Context Awareness

- **ALWAYS** reference existing code patterns and data structures before suggesting new approaches
- Ask about existing conventions before implementing new functionality
- Preserve established architectural decisions unless explicitly asked to change them
- Maintain consistency with existing naming conventions and code style

## Code Quality Standards

### File and Function Limits

- **Maximum file size**: 250 lines
- **Maximum function size**: 50 lines
- **Maximum complexity**: If a function does more than one main thing, break it down
- **Naming**: Use clear, descriptive names that explain purpose

### Documentation Requirements

- **Every public function** must have a docstring explaining purpose, parameters, and return value
- **Every class** must have a class-level docstring
- **Complex logic** must have inline comments explaining the "why", not just the "what"
- **API endpoints** must be documented with request/response examples

### Error Handling

- **ALWAYS** include proper error handling for external dependencies
- **NEVER** use bare except clauses
- Provide meaningful error messages that help with debugging
- Log errors appropriately for the application context
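
A minimal sketch of these rules in practice (the `client.get_user` call stands in for any external dependency; all names are illustrative):

```python
import logging

logger = logging.getLogger(__name__)

def fetch_profile(client, user_id: str) -> dict:
    """Fetch a user profile from an external service."""
    try:
        return client.get_user(user_id)
    except ConnectionError as exc:  # specific exception, never a bare `except:`
        logger.error("Profile fetch failed for user %s: %s", user_id, exc)
        raise RuntimeError(f"Could not load profile for user {user_id}") from exc
```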

## Security and Best Practices

- **NEVER** hardcode credentials, API keys, or sensitive data
- **ALWAYS** validate user inputs
- Use parameterized queries for database operations (see the sketch below)
- Follow the principle of least privilege
- Implement proper authentication and authorization
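
A parameterized-query sketch using the standard library's `sqlite3` (table and column names are illustrative):

```python
import sqlite3

def find_user(conn: sqlite3.Connection, email: str):
    # The `?` placeholder lets the driver bind the value safely, so user
    # input never becomes part of the SQL text (prevents SQL injection).
    return conn.execute(
        "SELECT id, name FROM users WHERE email = ?", (email,)
    ).fetchone()
```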

## Testing Requirements

- **Every implementation** should have corresponding unit tests
- **Every API endpoint** should have integration tests
- Test files should be placed alongside the code they test
- Use descriptive test names that explain what is being tested

## Response Format

- Be concise and avoid unnecessary repetition
- Focus on actionable information
- Provide examples when explaining complex concepts
- Ask clarifying questions when requirements are ambiguous

.cursor/rules/architecture.mdc (new file, 237 lines)
@@ -0,0 +1,237 @@
---
description: Modular design principles and architecture guidelines for scalable development
globs:
alwaysApply: false
---

# Rule: Architecture and Modular Design

## Goal

Maintain a clean, modular architecture that scales effectively and prevents the complexity issues that arise in AI-assisted development.

## Core Architecture Principles

### 1. Modular Design

- **Single Responsibility**: Each module has one clear purpose
- **Loose Coupling**: Modules depend on interfaces, not implementations
- **High Cohesion**: Related functionality is grouped together
- **Clear Boundaries**: Module interfaces are well-defined and stable

### 2. Size Constraints

- **Files**: Maximum 250 lines per file
- **Functions**: Maximum 50 lines per function
- **Classes**: Maximum 300 lines per class
- **Modules**: Maximum 10 public functions/classes per module

### 3. Dependency Management

- **Layer Dependencies**: Higher layers depend on lower layers only
- **No Circular Dependencies**: Modules cannot depend on each other cyclically
- **Interface Segregation**: Depend on specific interfaces, not broad ones
- **Dependency Injection**: Pass dependencies rather than creating them internally

## Modular Architecture Patterns

### Layer Structure

```
src/
├── presentation/    # UI, API endpoints, CLI interfaces
├── application/     # Business logic, use cases, workflows
├── domain/          # Core business entities and rules
├── infrastructure/  # Database, external APIs, file systems
└── shared/          # Common utilities, constants, types
```

### Module Organization

```
module_name/
├── __init__.py   # Public interface exports
├── core.py       # Main module logic
├── types.py      # Type definitions and interfaces
├── utils.py      # Module-specific utilities
├── tests/        # Module tests
└── README.md     # Module documentation
```

## Design Patterns for AI Development

### 1. Repository Pattern

Separate data access from business logic:

```python
# Domain interface
class UserRepository:
    def get_by_id(self, user_id: str) -> User: ...
    def save(self, user: User) -> None: ...

# Infrastructure implementation
class SqlUserRepository(UserRepository):
    def get_by_id(self, user_id: str) -> User:
        # Database-specific implementation
        pass
```

### 2. Service Pattern

Encapsulate business logic in focused services:

```python
class UserService:
    def __init__(self, user_repo: UserRepository):
        self._user_repo = user_repo

    def create_user(self, data: UserData) -> User:
        # Validation and business logic
        # Single responsibility: user creation
        pass
```
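
Wiring the two patterns together is what the dependency-injection bullet above describes; a minimal sketch of the composition at application startup (reusing the names from the sketches above):

```python
repo = SqlUserRepository()              # concrete implementation chosen at the edge
service = UserService(user_repo=repo)   # the service receives the interface, not the DB
```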

### 3. Factory Pattern

Create complex objects with clear interfaces:

```python
class DatabaseFactory:
    @staticmethod
    def create_connection(config: DatabaseConfig) -> Connection:
        # Handle different database types
        # Encapsulate connection complexity
        pass
```

## Architecture Decision Guidelines

### When to Create New Modules

Create a new module when:
- **Functionality** exceeds size constraints (250 lines)
- **Responsibility** is distinct from existing modules
- **Dependencies** would create circular references
- **Reusability** would benefit other parts of the system
- **Testing** requires isolated test environments

### When to Split Existing Modules

Split modules when:
- **File size** exceeds 250 lines
- **Multiple responsibilities** are evident
- **Testing** becomes difficult due to complexity
- **Dependencies** become too numerous
- **Change frequency** differs significantly between parts

### Module Interface Design

```python
# Good: Clear, focused interface
class PaymentProcessor:
    def process_payment(self, amount: Money, method: PaymentMethod) -> PaymentResult:
        """Process a single payment transaction."""
        pass

# Bad: Unfocused, kitchen-sink interface
class PaymentManager:
    def process_payment(self, ...): pass
    def validate_card(self, ...): pass
    def send_receipt(self, ...): pass
    def update_inventory(self, ...): pass  # Wrong responsibility!
```

## Architecture Validation

### Architecture Review Checklist

- [ ] **Dependencies flow in one direction** (no cycles)
- [ ] **Layers are respected** (presentation doesn't call infrastructure directly)
- [ ] **Modules have single responsibility**
- [ ] **Interfaces are stable** and well-defined
- [ ] **Size constraints** are maintained
- [ ] **Testing** is straightforward for each module

### Red Flags

- **God Objects**: Classes/modules that do too many things
- **Circular Dependencies**: Modules that depend on each other
- **Deep Inheritance**: More than 3 levels of inheritance
- **Large Interfaces**: Interfaces with more than 7 methods
- **Tight Coupling**: Modules that know too much about each other's internals

## Refactoring Guidelines

### When to Refactor

- Module exceeds size constraints
- Code duplication across modules
- Difficult to test individual components
- New features require changing multiple unrelated modules
- Performance bottlenecks due to poor separation

### Refactoring Process

1. **Identify** the specific architectural problem
2. **Design** the target architecture
3. **Create tests** to verify current behavior
4. **Implement changes** incrementally
5. **Validate** that tests still pass
6. **Update documentation** to reflect changes

### Safe Refactoring Practices

- **One change at a time**: Don't mix refactoring with new features
- **Tests first**: Ensure comprehensive test coverage before refactoring
- **Incremental changes**: Small steps with verification at each stage
- **Backward compatibility**: Maintain existing interfaces during transition
- **Documentation updates**: Keep architecture documentation current

## Architecture Documentation

### Architecture Decision Records (ADRs)

Document significant decisions in `./docs/decisions/`:

```markdown
# ADR-003: Service Layer Architecture

## Status
Accepted

## Context
As the application grows, business logic is scattered across controllers and models.

## Decision
Implement a service layer to encapsulate business logic.

## Consequences
**Positive:**
- Clear separation of concerns
- Easier testing of business logic
- Better reusability across different interfaces

**Negative:**
- Additional abstraction layer
- More files to maintain
```

### Module Documentation Template

```markdown
# Module: [Name]

## Purpose
What this module does and why it exists.

## Dependencies
- **Imports from**: List of modules this depends on
- **Used by**: List of modules that depend on this one
- **External**: Third-party dependencies

## Public Interface
```python
# Key functions and classes exposed by this module
```

## Architecture Notes
- Design patterns used
- Important architectural decisions
- Known limitations or constraints
```

## Migration Strategies

### Legacy Code Integration

- **Strangler Fig Pattern**: Gradually replace old code with new modules
- **Adapter Pattern**: Create interfaces to integrate old and new code (sketched below)
- **Facade Pattern**: Simplify complex legacy interfaces
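
A minimal adapter sketch (all names are hypothetical: `LegacyBilling` stands for the old code, `Invoicer` for the interface new modules expect):

```python
class LegacyBilling:
    def charge_in_cents(self, cents: int) -> bool: ...

class Invoicer:
    def charge(self, amount: float) -> bool: ...

class BillingAdapter(Invoicer):
    """Lets new code talk to LegacyBilling through the Invoicer interface."""

    def __init__(self, legacy: LegacyBilling):
        self._legacy = legacy

    def charge(self, amount: float) -> bool:
        # Translate the new unit (dollars) to the legacy unit (cents)
        return self._legacy.charge_in_cents(int(round(amount * 100)))
```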

### Gradual Modernization

1. **Identify boundaries** in existing code
2. **Extract modules** one at a time
3. **Create interfaces** for each extracted module
4. **Test thoroughly** at each step
5. **Update documentation** continuously

.cursor/rules/code-review.mdc (new file, 123 lines)
@@ -0,0 +1,123 @@
---
description: AI-generated code review checklist and quality assurance guidelines
globs:
alwaysApply: false
---

# Rule: Code Review and Quality Assurance

## Goal

Establish systematic review processes for AI-generated code to maintain quality, security, and maintainability standards.

## AI Code Review Checklist

### Pre-Implementation Review

Before accepting any AI-generated code:

1. **Understand the Code**
   - [ ] Can you explain what the code does in your own words?
   - [ ] Do you understand each function and its purpose?
   - [ ] Are there any "magic" values or unexplained logic?
   - [ ] Does the code solve the actual problem stated?

2. **Architecture Alignment**
   - [ ] Does the code follow established project patterns?
   - [ ] Is it consistent with existing data structures?
   - [ ] Does it integrate cleanly with existing components?
   - [ ] Are new dependencies justified and necessary?

3. **Code Quality**
   - [ ] Are functions smaller than 50 lines?
   - [ ] Are files smaller than 250 lines?
   - [ ] Are variable and function names descriptive?
   - [ ] Is the code DRY (Don't Repeat Yourself)?

### Security Review

- [ ] **Input Validation**: All user inputs are validated and sanitized
- [ ] **Authentication**: Proper authentication checks are in place
- [ ] **Authorization**: Access controls are implemented correctly
- [ ] **Data Protection**: Sensitive data is handled securely
- [ ] **SQL Injection**: Database queries use parameterized statements
- [ ] **XSS Prevention**: Output is properly escaped (see the sketch below)
- [ ] **Error Handling**: Errors don't leak sensitive information
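
A minimal escaping sketch using the standard library's `html` module (the surrounding template code is illustrative):

```python
import html

def render_comment(comment: str) -> str:
    # Escape user-supplied text before it reaches the page, so
    # "<script>..." renders as text instead of executing (XSS prevention).
    return f"<p>{html.escape(comment)}</p>"
```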

### Integration Review

- [ ] **Existing Functionality**: New code doesn't break existing features
- [ ] **Data Consistency**: Database changes maintain referential integrity
- [ ] **API Compatibility**: Changes don't break existing API contracts
- [ ] **Performance Impact**: New code doesn't introduce performance bottlenecks
- [ ] **Testing Coverage**: Appropriate tests are included

## Review Process

### Step 1: Initial Code Analysis

1. **Read through the entire generated code** before running it
2. **Identify patterns** that don't match the existing codebase
3. **Check dependencies**: are new packages really needed?
4. **Verify logic flow**: does the algorithm make sense?

### Step 2: Security and Error Handling Review

1. **Trace data flow** from input to output
2. **Identify potential failure points** and verify error handling
3. **Check for security vulnerabilities** using the security checklist
4. **Verify proper logging** and monitoring implementation

### Step 3: Integration Testing

1. **Test with existing code** to ensure compatibility
2. **Run the existing test suite** to verify no regressions
3. **Test edge cases** and error conditions
4. **Verify performance** under realistic conditions

## Common AI Code Issues to Watch For

### Overcomplication Patterns

- **Unnecessary abstractions**: AI creating complex patterns for simple tasks
- **Over-engineering**: Solutions that are more complex than needed
- **Redundant code**: AI recreating existing functionality
- **Inappropriate design patterns**: Using patterns that don't fit the use case

### Context Loss Indicators

- **Inconsistent naming**: Different conventions from existing code
- **Wrong data structures**: Using different patterns than established
- **Ignored existing functions**: Reimplementing existing functionality
- **Architectural misalignment**: Code that doesn't fit the overall design

### Technical Debt Indicators

- **Magic numbers**: Hardcoded values without explanation (see the sketch below)
- **Poor error messages**: Generic or unhelpful error handling
- **Missing documentation**: Code without adequate comments
- **Tight coupling**: Components that are too interdependent
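
A before/after sketch of the magic-number issue (the `12` policy value is illustrative):

```python
# Before: unexplained literal buried in a condition
def ok(password: str) -> bool:
    return len(password) >= 12

# After: named constant that documents the intent
MIN_PASSWORD_LENGTH = 12  # security policy value (illustrative)

def is_strong_enough(password: str) -> bool:
    return len(password) >= MIN_PASSWORD_LENGTH
```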

## Quality Gates

### Mandatory Reviews

All AI-generated code must pass these gates before acceptance:

1. **Security Review**: No security vulnerabilities detected
2. **Integration Review**: Integrates cleanly with existing code
3. **Performance Review**: Meets performance requirements
4. **Maintainability Review**: Code can be easily modified by team members
5. **Documentation Review**: Adequate documentation is provided

### Acceptance Criteria

- [ ] Code is understandable by any team member
- [ ] Integration requires minimal changes to existing code
- [ ] Security review passes all checks
- [ ] Performance meets established benchmarks
- [ ] Documentation is complete and accurate

## Rejection Criteria

Reject AI-generated code if:
- Security vulnerabilities are present
- Code is too complex for the problem being solved
- Integration requires major refactoring of existing code
- Code duplicates existing functionality without justification
- Documentation is missing or inadequate

## Review Documentation

For each review, document:
- Issues found and how they were resolved
- Performance impact assessment
- Security concerns and mitigations
- Integration challenges and solutions
- Recommendations for future similar tasks

.cursor/rules/context-management.mdc (new file, 93 lines)
@@ -0,0 +1,93 @@
---
description: Context management for maintaining codebase awareness and preventing context drift
globs:
alwaysApply: false
---

# Rule: Context Management

## Goal

Maintain comprehensive project context to prevent context drift and ensure AI-generated code integrates seamlessly with existing codebase patterns and architecture.

## Context Documentation Requirements

### PRD.md File Documentation

1. **Project Overview**
   - Business objectives and goals
   - Target users and use cases
   - Key success metrics

### CONTEXT.md File Structure

Every project must maintain a `CONTEXT.md` file in the root directory with (a skeleton follows this list):

1. **Architecture Overview**
   - High-level system architecture
   - Key design patterns used
   - Database schema overview
   - API structure and conventions

2. **Technology Stack**
   - Programming languages and versions
   - Frameworks and libraries
   - Database systems
   - Development and deployment tools

3. **Coding Conventions**
   - Naming conventions
   - File organization patterns
   - Code structure preferences
   - Import/export patterns

4. **Current Implementation Status**
   - Completed features
   - Work in progress
   - Known technical debt
   - Planned improvements
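
A minimal `CONTEXT.md` skeleton following this structure (all section bodies are illustrative placeholders):

```markdown
# CONTEXT.md

## Architecture Overview
Layered: presentation -> application -> domain -> infrastructure.

## Technology Stack
Python 3.12, FastAPI, PostgreSQL 16.

## Coding Conventions
snake_case modules; dependencies injected via constructors.

## Current Implementation Status
User auth complete; billing in progress; known debt: no retry logic on external calls.
```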

## Context Maintenance Protocol

### Before Every Coding Session

1. **Review CONTEXT.md and PRD.md** to understand the current project state
2. **Scan recent changes** in git history to understand the latest patterns
3. **Identify existing patterns** for similar functionality before implementing new features
4. **Ask for clarification** if existing patterns are unclear or conflicting

### During Development

1. **Reference existing code** when explaining implementation approaches
2. **Maintain consistency** with established patterns and conventions
3. **Update CONTEXT.md** when making architectural decisions
4. **Document deviations** from established patterns with reasoning

### Context Preservation Strategies

- **Incremental development**: Build on existing patterns rather than creating new ones
- **Pattern consistency**: Use established data structures and function signatures
- **Integration awareness**: Consider how new code affects existing functionality
- **Dependency management**: Understand existing dependencies before adding new ones

## Context Prompting Best Practices

### Effective Context Sharing

- Include relevant sections of CONTEXT.md in prompts for complex tasks
- Reference specific existing files when asking for similar functionality
- Provide examples of existing patterns when requesting new implementations
- Share recent git commit messages to understand the latest changes

### Context Window Optimization

- Prioritize the most relevant context for the current task
- Use @filename references to include specific files
- Break large contexts into focused, task-specific chunks
- Update context references as the project evolves

## Red Flags: Context Loss Indicators

- AI suggests patterns that conflict with existing code
- New implementations ignore established conventions
- Proposed solutions don't integrate with existing architecture
- Code suggestions require significant refactoring of existing functionality

## Recovery Protocol

When context loss is detected:

1. **Stop development** and review CONTEXT.md
2. **Analyze the existing codebase** for established patterns
3. **Update context documentation** with missing information
4. **Restart the task** with proper context provided
5. **Test integration** with existing code before proceeding

.cursor/rules/create-prd.mdc (new file, 67 lines)
@@ -0,0 +1,67 @@
---
description: Creating PRD for a project or specific task/function
globs:
alwaysApply: false
---

# Rule: Generating a Product Requirements Document (PRD)

## Goal

To guide an AI assistant in creating a detailed Product Requirements Document (PRD) in Markdown format, based on an initial user prompt. The PRD should be clear, actionable, and suitable for a junior developer to understand and implement the feature.

## Process

1. **Receive Initial Prompt:** The user provides a brief description or request for a new feature or functionality.
2. **Ask Clarifying Questions:** Before writing the PRD, the AI *must* ask clarifying questions to gather sufficient detail. The goal is to understand the "what" and "why" of the feature, not necessarily the "how" (which the developer will figure out).
3. **Generate PRD:** Based on the initial prompt and the user's answers to the clarifying questions, generate a PRD using the structure outlined below.
4. **Save PRD:** Save the generated document as `prd-[feature-name].md` inside the `/tasks` directory.

## Clarifying Questions (Examples)

The AI should adapt its questions based on the prompt, but here are some common areas to explore:

* **Problem/Goal:** "What problem does this feature solve for the user?" or "What is the main goal we want to achieve with this feature?"
* **Target User:** "Who is the primary user of this feature?"
* **Core Functionality:** "Can you describe the key actions a user should be able to perform with this feature?"
* **User Stories:** "Could you provide a few user stories? (e.g., As a [type of user], I want to [perform an action] so that [benefit].)"
* **Acceptance Criteria:** "How will we know when this feature is successfully implemented? What are the key success criteria?"
* **Scope/Boundaries:** "Are there any specific things this feature *should not* do (non-goals)?"
* **Data Requirements:** "What kind of data does this feature need to display or manipulate?"
* **Design/UI:** "Are there any existing design mockups or UI guidelines to follow?" or "Can you describe the desired look and feel?"
* **Edge Cases:** "Are there any potential edge cases or error conditions we should consider?"

## PRD Structure

The generated PRD should include the following sections (a skeleton follows this list):

1. **Introduction/Overview:** Briefly describe the feature and the problem it solves. State the goal.
2. **Goals:** List the specific, measurable objectives for this feature.
3. **User Stories:** Detail the user narratives describing feature usage and benefits.
4. **Functional Requirements:** List the specific functionalities the feature must have. Use clear, concise language (e.g., "The system must allow users to upload a profile picture."). Number these requirements.
5. **Non-Goals (Out of Scope):** Clearly state what this feature will *not* include to manage scope.
6. **Design Considerations (Optional):** Link to mockups, describe UI/UX requirements, or mention relevant components/styles if applicable.
7. **Technical Considerations (Optional):** Mention any known technical constraints, dependencies, or suggestions (e.g., "Should integrate with the existing Auth module").
8. **Success Metrics:** How will the success of this feature be measured? (e.g., "Increase user engagement by 10%", "Reduce support tickets related to X").
9. **Open Questions:** List any remaining questions or areas needing further clarification.
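
A minimal skeleton of the required sections (optional sections 6 and 7 omitted; bracketed text is placeholder):

```markdown
# PRD: [Feature Name]

## Introduction/Overview
One paragraph: the feature, the problem it solves, and the goal.

## Goals
1. [Specific, measurable objective]

## User Stories
- As a [type of user], I want to [perform an action] so that [benefit].

## Functional Requirements
1. The system must [specific behavior].

## Non-Goals (Out of Scope)
- [Explicit exclusion]

## Success Metrics
- [How success will be measured]

## Open Questions
- [Unresolved question]
```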

## Target Audience

Assume the primary reader of the PRD is a **junior developer**. Therefore, requirements should be explicit, unambiguous, and avoid jargon where possible. Provide enough detail for them to understand the feature's purpose and core logic.

## Output

* **Format:** Markdown (`.md`)
* **Location:** `/tasks/`
* **Filename:** `prd-[feature-name].md`

## Final Instructions

1. Do NOT start implementing the PRD
2. Make sure to ask the user clarifying questions
3. Take the user's answers to the clarifying questions and improve the PRD

.cursor/rules/documentation.mdc (new file, 244 lines)
@@ -0,0 +1,244 @@
---
description: Documentation standards for code, architecture, and development decisions
globs:
alwaysApply: false
---

# Rule: Documentation Standards

## Goal

Maintain comprehensive, up-to-date documentation that supports development, onboarding, and long-term maintenance of the codebase.

## Documentation Hierarchy

### 1. Project Level Documentation (in ./docs/)

- **README.md**: Project overview, setup instructions, basic usage
- **CONTEXT.md**: Current project state, architecture decisions, patterns
- **CHANGELOG.md**: Version history and significant changes
- **CONTRIBUTING.md**: Development guidelines and processes
- **API.md**: API endpoints, request/response formats, authentication

### 2. Module Level Documentation (in ./docs/modules/)

- **[module-name].md**: Purpose, public interfaces, usage examples
- **dependencies.md**: External dependencies and their purposes
- **architecture.md**: Module relationships and data flow

### 3. Code Level Documentation

- **Docstrings**: Function and class documentation
- **Inline comments**: Complex logic explanations
- **Type hints**: Clear parameter and return types
- **README files**: Directory-specific instructions

## Documentation Standards

### Code Documentation

```python
def process_user_data(user_id: str, data: dict) -> UserResult:
    """
    Process and validate user data before storage.

    Args:
        user_id: Unique identifier for the user
        data: Dictionary containing user information to process

    Returns:
        UserResult: Processed user data with validation status

    Raises:
        ValidationError: When user data fails validation
        DatabaseError: When storage operation fails

    Example:
        >>> result = process_user_data("123", {"name": "John", "email": "john@example.com"})
        >>> print(result.status)
        'valid'
    """
```

### API Documentation Format

```markdown
### POST /api/users

Create a new user account.

**Request:**
```json
{
  "name": "string (required)",
  "email": "string (required, valid email)",
  "age": "number (optional, min: 13)"
}
```

**Response (201):**
```json
{
  "id": "uuid",
  "name": "string",
  "email": "string",
  "created_at": "iso_datetime"
}
```

**Errors:**
- 400: Invalid input data
- 409: Email already exists
```

### Architecture Decision Records (ADRs)

Document significant architecture decisions in `./docs/decisions/`:

```markdown
# ADR-001: Database Choice - PostgreSQL

## Status
Accepted

## Context
We need to choose a database for storing user data and application state.

## Decision
We will use PostgreSQL as our primary database.

## Consequences
**Positive:**
- ACID compliance ensures data integrity
- Rich query capabilities with SQL
- Good performance for our expected load

**Negative:**
- More complex setup than simpler alternatives
- Requires SQL knowledge from team members

## Alternatives Considered
- MongoDB: Rejected due to consistency requirements
- SQLite: Rejected due to scalability needs
```

## Documentation Maintenance

### When to Update Documentation

#### Always Update:
- **API changes**: Any modification to public interfaces
- **Architecture changes**: New patterns, data structures, or workflows
- **Configuration changes**: Environment variables, deployment settings
- **Dependencies**: Adding, removing, or upgrading packages
- **Business logic changes**: Core functionality modifications

#### Update Weekly:
- **CONTEXT.md**: Current development status and priorities
- **Known issues**: Bug reports and workarounds
- **Performance notes**: Bottlenecks and optimization opportunities

#### Update per Release:
- **CHANGELOG.md**: User-facing changes and improvements
- **Version documentation**: Breaking changes and migration guides
- **Examples and tutorials**: Keep sample code current

### Documentation Quality Checklist

#### Completeness
- [ ] Purpose and scope clearly explained
- [ ] All public interfaces documented
- [ ] Examples provided for complex usage
- [ ] Error conditions and handling described
- [ ] Dependencies and requirements listed

#### Accuracy
- [ ] Code examples are tested and working
- [ ] Links point to correct locations
- [ ] Version numbers are current
- [ ] Screenshots reflect current UI

#### Clarity
- [ ] Written for the intended audience
- [ ] Technical jargon is explained
- [ ] Step-by-step instructions are clear
- [ ] Visual aids used where helpful

## Documentation Automation

### Auto-Generated Documentation

- **API docs**: Generate from code annotations
- **Type documentation**: Extract from type hints
- **Module dependencies**: Auto-update from imports
- **Test coverage**: Include coverage reports

### Documentation Testing

```python
# Test that code examples in documentation work
def test_documentation_examples():
    """Verify code examples in docs actually work."""
    # Test examples from README.md
    # Test API examples from docs/API.md
    # Test configuration examples
```
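
One concrete way to realize this is the standard library's `doctest` module, which executes the `>>>` examples embedded in docstrings. A sketch (`mymodule` is a hypothetical module containing such examples):

```python
import doctest
import mymodule  # hypothetical: its docstrings contain >>> examples

if __name__ == "__main__":
    results = doctest.testmod(mymodule)
    print(f"{results.attempted} examples run, {results.failed} failed")
```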

## Documentation Templates

### New Module Documentation Template

```markdown
# Module: [Name]

## Purpose
Brief description of what this module does and why it exists.

## Public Interface

### Functions
- `function_name(params)`: Description and example

### Classes
- `ClassName`: Purpose and basic usage

## Usage Examples
```python
# Basic usage example
```

## Dependencies
- Internal: List of internal modules this depends on
- External: List of external packages required

## Testing
How to run tests for this module.

## Known Issues
Current limitations or bugs.
```

### API Endpoint Template

```markdown
### [METHOD] /api/endpoint

Brief description of what this endpoint does.

**Authentication:** Required/Optional
**Rate Limiting:** X requests per minute

**Request:**
- Headers required
- Body schema
- Query parameters

**Response:**
- Success response format
- Error response format
- Status codes

**Example:**
Working request/response example
```

## Review and Maintenance Process

### Documentation Review

- Include documentation updates in code reviews
- Verify examples still work with code changes
- Check for broken links and outdated information
- Ensure consistency with current implementation

### Regular Audits

- Monthly review of documentation accuracy
- Quarterly assessment of documentation completeness
- Annual review of documentation structure and organization

.cursor/rules/enhanced-task-list.mdc (new file, 207 lines)
@@ -0,0 +1,207 @@
---
description: Enhanced task list management with quality gates and iterative workflow integration
globs:
alwaysApply: false
---

# Rule: Enhanced Task List Management

## Goal

Manage task lists with integrated quality gates and an iterative workflow to prevent context loss and ensure sustainable development.

## Task Implementation Protocol

### Pre-Implementation Check

Before starting any sub-task:

- [ ] **Context Review**: Have you reviewed CONTEXT.md and relevant documentation?
- [ ] **Pattern Identification**: Do you understand existing patterns to follow?
- [ ] **Integration Planning**: Do you know how this will integrate with existing code?
- [ ] **Size Validation**: Is this task small enough (≤50 lines, ≤250 lines per file)?

### Implementation Process

1. **One sub-task at a time**: Do **NOT** start the next sub-task until you ask the user for permission and they say "yes" or "y"
2. **Step-by-step execution**:
   - Plan the approach in bullet points
   - Wait for approval
   - Implement the specific sub-task
   - Test the implementation
   - Update documentation if needed
3. **Quality validation**: Run through the code review checklist before marking complete

### Completion Protocol

When you finish a **sub-task**:

1. **Immediate marking**: Change `[ ]` to `[x]`
2. **Quality check**: Verify the implementation meets quality standards
3. **Integration test**: Ensure new code works with existing functionality
4. **Documentation update**: Update relevant files if needed
5. **Parent task check**: If **all** subtasks underneath a parent task are now `[x]`, also mark the **parent task** as completed
6. **Stop and wait**: Get user approval before proceeding to the next sub-task

## Enhanced Task List Structure

### Task File Header

```markdown
# Task List: [Feature Name]

**Source PRD**: `prd-[feature-name].md`
**Status**: In Progress / Complete / Blocked
**Context Last Updated**: [Date]
**Architecture Review**: Required / Complete / N/A

## Quick Links
- [Context Documentation](./CONTEXT.md)
- [Architecture Guidelines](./docs/architecture.md)
- [Related Files](#relevant-files)
```

### Task Format with Quality Gates

```markdown
- [ ] 1.0 Parent Task Title
  - **Quality Gate**: Architecture review required
  - **Dependencies**: List any dependencies
  - [ ] 1.1 [Sub-task description 1.1]
    - **Size estimate**: [Small/Medium/Large]
    - **Pattern reference**: [Reference to existing pattern]
    - **Test requirements**: [Unit/Integration/Both]
  - [ ] 1.2 [Sub-task description 1.2]
    - **Integration points**: [List affected components]
    - **Risk level**: [Low/Medium/High]
```

## Relevant Files Management

### Enhanced File Tracking

```markdown
## Relevant Files

### Implementation Files
- `path/to/file1.ts` - Brief description of purpose and role
  - **Status**: Created / Modified / Needs Review
  - **Last Modified**: [Date]
  - **Review Status**: Pending / Approved / Needs Changes

### Test Files
- `path/to/file1.test.ts` - Unit tests for file1.ts
  - **Coverage**: [Percentage or status]
  - **Last Run**: [Date and result]

### Documentation Files
- `docs/module-name.md` - Module documentation
  - **Status**: Up to date / Needs update / Missing
  - **Last Updated**: [Date]

### Configuration Files
- `config/setting.json` - Configuration changes
  - **Environment**: [Dev/Staging/Prod affected]
  - **Backup**: [Location of backup]
```

## Task List Maintenance

### During Development

1. **Regular updates**: Update task status after each significant change
2. **File tracking**: Add new files as they are created or modified
3. **Dependency tracking**: Note when new dependencies between tasks emerge
4. **Risk assessment**: Flag tasks that become more complex than anticipated

### Quality Checkpoints

At 25%, 50%, 75%, and 100% completion:

- [ ] **Architecture alignment**: Code follows established patterns
- [ ] **Performance impact**: No significant performance degradation
- [ ] **Security review**: No security vulnerabilities introduced
- [ ] **Documentation current**: All changes are documented

### Weekly Review Process

1. **Completion assessment**: What percentage of tasks are actually complete?
2. **Quality assessment**: Are completed tasks meeting quality standards?
3. **Process assessment**: Is the iterative workflow being followed?
4. **Risk assessment**: Are there emerging risks or blockers?

## Task Status Indicators

### Status Levels

- `[ ]` **Not Started**: Task not yet begun
- `[~]` **In Progress**: Currently being worked on
- `[?]` **Blocked**: Waiting for dependencies or decisions
- `[!]` **Needs Review**: Implementation complete but needs quality review
- `[x]` **Complete**: Finished and quality approved
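
For example, a task list mid-development might read (contents illustrative):

```markdown
- [x] 1.0 Set up data model
- [~] 2.0 Implement API endpoints
  - [x] 2.1 GET /users
  - [!] 2.2 POST /users (needs review)
- [?] 3.0 Email notifications (blocked on provider choice)
```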

### Quality Indicators

- ✅ **Quality Approved**: Passed all quality gates
- ⚠️ **Quality Concerns**: Has issues but functional
- ❌ **Quality Failed**: Needs rework before approval
- 🔄 **Under Review**: Currently being reviewed

### Integration Status

- 🔗 **Integrated**: Successfully integrated with existing code
- 🔧 **Integration Issues**: Problems with existing code integration
- ⏳ **Integration Pending**: Ready for integration testing

## Emergency Procedures

### When Tasks Become Too Complex

If a sub-task grows beyond expected scope:

1. **Stop implementation** immediately
2. **Document current state** and what was discovered
3. **Break down** the task into smaller pieces
4. **Update task list** with new sub-tasks
5. **Get approval** for the new breakdown before proceeding

### When Context is Lost

If the AI seems to lose track of project patterns:

1. **Pause development**
2. **Review CONTEXT.md** and recent changes
3. **Update context documentation** with the current state
4. **Restart** with explicit pattern references
5. **Reduce task size** until context is re-established

### When Quality Gates Fail

If an implementation doesn't meet quality standards:

1. **Mark task** with `[!]` status
2. **Document specific issues** found
3. **Create remediation tasks** if needed
4. **Don't proceed** until quality issues are resolved

## AI Instructions Integration

### Context Awareness Commands

```markdown
**Before starting any task, run these checks:**
1. @CONTEXT.md - Review current project state
2. @architecture.md - Understand design principles
3. @code-review.md - Know quality standards
4. Look at existing similar code for patterns
```

### Quality Validation Commands

```markdown
**After completing any sub-task:**
1. Run code review checklist
2. Test integration with existing code
3. Update documentation if needed
4. Mark task complete only after quality approval
```

### Workflow Commands

```markdown
**For each development session:**
1. Review incomplete tasks and their status
2. Identify the next logical sub-task to work on
3. Check dependencies and blockers
4. Follow the iterative workflow process
5. Update the task list with progress and findings
```

## Success Metrics

### Daily Success Indicators

- Tasks are completed according to quality standards
- No sub-tasks are started without completing previous ones
- File tracking remains accurate and current
- Integration issues are caught early

### Weekly Success Indicators

- Overall task completion rate is sustainable
- Quality issues are decreasing over time
- Context loss incidents are rare
- Team confidence in the codebase remains high

.cursor/rules/generate-tasks.mdc (new file, 70 lines)
@@ -0,0 +1,70 @@
---
description: Generate a task list or TODO for a user requirement or implementation.
globs:
alwaysApply: false
---

# Rule: Generating a Task List from a PRD

## Goal

To guide an AI assistant in creating a detailed, step-by-step task list in Markdown format based on an existing Product Requirements Document (PRD). The task list should guide a developer through implementation.

## Output

- **Format:** Markdown (`.md`)
- **Location:** `/tasks/`
- **Filename:** `tasks-[prd-file-name].md` (e.g., `tasks-prd-user-profile-editing.md`)

## Process

1. **Receive PRD Reference:** The user points the AI to a specific PRD file.
2. **Analyze PRD:** The AI reads and analyzes the functional requirements, user stories, and other sections of the specified PRD.
3. **Phase 1: Generate Parent Tasks:** Based on the PRD analysis, create the file and generate the main, high-level tasks required to implement the feature. Use your judgement on how many high-level tasks to use. It's likely to be about 5. Present these tasks to the user in the specified format (without sub-tasks yet). Inform the user: "I have generated the high-level tasks based on the PRD. Ready to generate the sub-tasks? Respond with 'Go' to proceed."
4. **Wait for Confirmation:** Pause and wait for the user to respond with "Go".
5. **Phase 2: Generate Sub-Tasks:** Once the user confirms, break down each parent task into smaller, actionable sub-tasks necessary to complete the parent task. Ensure sub-tasks logically follow from the parent task and cover the implementation details implied by the PRD.
6. **Identify Relevant Files:** Based on the tasks and PRD, identify potential files that will need to be created or modified. List these under the `Relevant Files` section, including corresponding test files if applicable.
7. **Generate Final Output:** Combine the parent tasks, sub-tasks, relevant files, and notes into the final Markdown structure.
8. **Save Task List:** Save the generated document in the `/tasks/` directory with the filename `tasks-[prd-file-name].md`, where `[prd-file-name]` matches the base name of the input PRD file (e.g., if the input was `prd-user-profile-editing.md`, the output is `tasks-prd-user-profile-editing.md`).

## Output Format

The generated task list _must_ follow this structure:

```markdown
## Relevant Files

- `path/to/potential/file1.ts` - Brief description of why this file is relevant (e.g., Contains the main component for this feature).
- `path/to/file1.test.ts` - Unit tests for `file1.ts`.
- `path/to/another/file.tsx` - Brief description (e.g., API route handler for data submission).
- `path/to/another/file.test.tsx` - Unit tests for `another/file.tsx`.
- `lib/utils/helpers.ts` - Brief description (e.g., Utility functions needed for calculations).
- `lib/utils/helpers.test.ts` - Unit tests for `helpers.ts`.

### Notes

- Unit tests should typically be placed alongside the code files they are testing (e.g., `MyComponent.tsx` and `MyComponent.test.tsx` in the same directory).
- Use `npx jest [optional/path/to/test/file]` to run tests. Running without a path executes all tests found by the Jest configuration.

## Tasks

- [ ] 1.0 Parent Task Title
  - [ ] 1.1 [Sub-task description 1.1]
  - [ ] 1.2 [Sub-task description 1.2]
- [ ] 2.0 Parent Task Title
  - [ ] 2.1 [Sub-task description 2.1]
- [ ] 3.0 Parent Task Title (may not require sub-tasks if purely structural or configuration)
```

## Interaction Model

The process explicitly requires a pause after generating parent tasks to get user confirmation ("Go") before proceeding to generate the detailed sub-tasks. This ensures the high-level plan aligns with user expectations before diving into details.

## Target Audience

Assume the primary reader of the task list is a **junior developer** who will implement the feature.

.cursor/rules/iterative-workflow.mdc (new file, 236 lines)
@@ -0,0 +1,236 @@
|
|||||||
|
---
|
||||||
|
description: Iterative development workflow for AI-assisted coding
|
||||||
|
globs:
|
||||||
|
alwaysApply: false
|
||||||
|
---
|
||||||
|
|
||||||
|
# Rule: Iterative Development Workflow
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
Establish a structured, iterative development process that prevents the chaos and complexity that can arise from uncontrolled AI-assisted development.
|
||||||
|
|
||||||
|
## Development Phases
|
||||||
|
|
||||||
|
### Phase 1: Planning and Design
|
||||||
|
**Before writing any code:**
|
||||||
|
|
||||||
|
1. **Understand the Requirement**
|
||||||
|
- Break down the task into specific, measurable objectives
|
||||||
|
   - Identify existing code patterns that should be followed
   - List dependencies and integration points
   - Define acceptance criteria

2. **Design Review**
   - Propose approach in bullet points
   - Wait for explicit approval before proceeding
   - Consider how the solution fits existing architecture
   - Identify potential risks and mitigation strategies

### Phase 2: Incremental Implementation
**One small piece at a time:**

1. **Micro-Tasks** (≤ 50 lines each)
   - Implement one function or small class at a time
   - Test immediately after implementation
   - Ensure integration with existing code
   - Document decisions and patterns used

2. **Validation Checkpoints**
   - After each micro-task, verify it works correctly
   - Check that it follows established patterns
   - Confirm it integrates cleanly with existing code
   - Get approval before moving to next micro-task

### Phase 3: Integration and Testing
**Ensuring system coherence:**

1. **Integration Testing**
   - Test new code with existing functionality
   - Verify no regressions in existing features
   - Check performance impact
   - Validate error handling

2. **Documentation Update**
   - Update relevant documentation
   - Record any new patterns or decisions
   - Update context files if architecture changed

## Iterative Prompting Strategy

### Step 1: Context Setting
```
Before implementing [feature], help me understand:
1. What existing patterns should I follow?
2. What existing functions/classes are relevant?
3. How should this integrate with [specific existing component]?
4. What are the potential architectural impacts?
```

### Step 2: Plan Creation
```
Based on the context, create a detailed plan for implementing [feature]:
1. Break it into micro-tasks (≤50 lines each)
2. Identify dependencies and order of implementation
3. Specify integration points with existing code
4. List potential risks and mitigation strategies

Wait for my approval before implementing.
```

### Step 3: Incremental Implementation
```
Implement only the first micro-task: [specific task]
- Use existing patterns from [reference file/function]
- Keep it under 50 lines
- Include error handling
- Add appropriate tests
- Explain your implementation choices

Stop after this task and wait for approval.
```

## Quality Gates

### Before Each Implementation
- [ ] **Purpose is clear**: Can explain what this piece does and why
- [ ] **Pattern is established**: Following existing code patterns
- [ ] **Size is manageable**: Implementation is small enough to understand completely
- [ ] **Integration is planned**: Know how it connects to existing code

### After Each Implementation
- [ ] **Code is understood**: Can explain every line of implemented code
- [ ] **Tests pass**: All existing and new tests are passing
- [ ] **Integration works**: New code works with existing functionality
- [ ] **Documentation updated**: Changes are reflected in relevant documentation

### Before Moving to Next Task
- [ ] **Current task complete**: All acceptance criteria met
- [ ] **No regressions**: Existing functionality still works
- [ ] **Clean state**: No temporary code or debugging artifacts
- [ ] **Approval received**: Explicit go-ahead for next task
- [ ] **Documentation updated**: If relevant changes to the module were made

## Anti-Patterns to Avoid

### Large Block Implementation
**Don't:**
```
Implement the entire user management system with authentication,
CRUD operations, and email notifications.
```

**Do:**
```
First, implement just the User model with basic fields.
Stop there and let me review before continuing.
```

### Context Loss
**Don't:**
```
Create a new authentication system.
```

**Do:**
```
Looking at the existing auth patterns in auth.py, implement
password validation following the same structure as the
existing email validation function.
```

### Over-Engineering
**Don't:**
```
Build a flexible, extensible user management framework that
can handle any future requirements.
```

**Do:**
```
Implement user creation functionality that matches the existing
pattern in customer.py, focusing only on the current requirements.
```

## Progress Tracking

### Task Status Indicators
- 🔄 **In Planning**: Requirements gathering and design
- ⏳ **In Progress**: Currently implementing
- ✅ **Complete**: Implemented, tested, and integrated
- 🚫 **Blocked**: Waiting for decisions or dependencies
- 🔧 **Needs Refactor**: Working but needs improvement

### Weekly Review Process
1. **Progress Assessment**
   - What was completed this week?
   - What challenges were encountered?
   - How well did the iterative process work?

2. **Process Adjustment**
   - Were task sizes appropriate?
   - Did context management work effectively?
   - What improvements can be made?

3. **Architecture Review**
   - Is the code remaining maintainable?
   - Are patterns staying consistent?
   - Is technical debt accumulating?

## Emergency Procedures

### When Things Go Wrong
If development becomes chaotic or problematic:

1. **Stop Development**
   - Don't continue adding to the problem
   - Take time to assess the situation
   - Don't rush to "fix" with more AI-generated code

2. **Assess the Situation**
   - What specific problems exist?
   - How far has the code diverged from established patterns?
   - What parts are still working correctly?

3. **Recovery Process**
   - Roll back to last known good state
   - Update context documentation with lessons learned
   - Restart with smaller, more focused tasks
   - Get explicit approval for each step of recovery

### Context Recovery
When AI seems to lose track of project patterns:

1. **Context Refresh**
   - Review and update CONTEXT.md
   - Include examples of current code patterns
   - Clarify architectural decisions

2. **Pattern Re-establishment**
   - Show AI examples of existing, working code
   - Explicitly state patterns to follow
   - Start with very small, pattern-matching tasks

3. **Gradual Re-engagement**
   - Begin with simple, low-risk tasks
   - Verify pattern adherence at each step
   - Gradually increase task complexity as consistency returns

## Success Metrics

### Short-term (Daily)
- Code is understandable and well-integrated
- No major regressions introduced
- Development velocity feels sustainable
- Team confidence in codebase remains high

### Medium-term (Weekly)
- Technical debt is not accumulating
- New features integrate cleanly
- Development patterns remain consistent
- Documentation stays current

### Long-term (Monthly)
- Codebase remains maintainable as it grows
- New team members can understand and contribute
- AI assistance enhances rather than hinders development
- Architecture remains clean and purposeful
24
.cursor/rules/project.mdc
Normal file
@@ -0,0 +1,24 @@
---
description:
globs:
alwaysApply: true
---

# Rule: Project specific rules

## Goal
Unify the project structure and interaction with tools and the console

### System tools
- **ALWAYS** use UV for package management
- **ALWAYS** use Windows PowerShell commands for the terminal

### Coding patterns
- **ALWAYS** check arguments and methods before use, to avoid errors from wrong parameters or names
- If in doubt, check the [CONTEXT.md](mdc:CONTEXT.md) file and [architecture.md](mdc:docs/architecture.md)
- **PREFER** the ORM pattern for databases, with SQLAlchemy (see the sketch below)
- **DO NOT USE** emoji in code and comments
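
For illustration, a minimal SQLAlchemy 2.0-style ORM sketch (the table and fields here are hypothetical; this rule only states the preference):

```python
from sqlalchemy import create_engine, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session

class Base(DeclarativeBase):
    pass

class Candle(Base):
    __tablename__ = 'candles'
    id: Mapped[int] = mapped_column(primary_key=True)
    symbol: Mapped[str] = mapped_column(String(16))

# Parameterized queries come for free with the ORM
engine = create_engine('sqlite:///:memory:')
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Candle(symbol='BTCUSD'))
    session.commit()
```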

### Testing
- Use UV for tests, in the format *uv run pytest [filename]*, as shown below
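
For example (the test file name is hypothetical):

```powershell
uv run pytest test_predictor.py
```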
237
.cursor/rules/refactoring.mdc
Normal file
@@ -0,0 +1,237 @@
---
description: Code refactoring and technical debt management for AI-assisted development
globs:
alwaysApply: false
---

# Rule: Code Refactoring and Technical Debt Management

## Goal
Guide AI in systematic code refactoring to improve maintainability, reduce complexity, and prevent technical debt accumulation in AI-assisted development projects.

## When to Apply This Rule
- Code complexity has increased beyond manageable levels
- Duplicate code patterns are detected
- Performance issues are identified
- New features are difficult to integrate
- Code review reveals maintainability concerns
- Weekly technical debt assessment indicates refactoring needs

## Pre-Refactoring Assessment

Before starting any refactoring, the AI MUST:

1. **Context Analysis:**
   - Review existing `CONTEXT.md` for architectural decisions
   - Analyze current code patterns and conventions
   - Identify all files that will be affected (search the codebase for usages)
   - Check for existing tests that verify current behavior

2. **Scope Definition:**
   - Clearly define what will and will not be changed
   - Identify the specific refactoring pattern to apply
   - Estimate the blast radius of changes
   - Plan rollback strategy if needed

3. **Documentation Review:**
   - Check `./docs/` for relevant module documentation
   - Review any existing architectural diagrams
   - Identify dependencies and integration points
   - Note any known constraints or limitations

## Refactoring Process

### Phase 1: Planning and Safety
1. **Create Refactoring Plan:**
   - Document the current state and desired end state
   - Break refactoring into small, atomic steps
   - Identify tests that must pass throughout the process
   - Plan verification steps for each change

2. **Establish Safety Net:**
   - Ensure comprehensive test coverage exists
   - If tests are missing, create them BEFORE refactoring
   - Document current behavior that must be preserved
   - Create backup of current implementation approach

3. **Get Approval:**
   - Present the refactoring plan to the user
   - Wait for explicit "Go" or "Proceed" confirmation
   - Do NOT start refactoring without approval

### Phase 2: Incremental Implementation
4. **One Change at a Time:**
   - Implement ONE refactoring step per iteration
   - Run tests after each step to ensure nothing breaks
   - Update documentation if interfaces change
   - Mark progress in the refactoring plan

5. **Verification Protocol:**
   - Run all relevant tests after each change
   - Verify functionality works as expected
   - Check performance hasn't degraded
   - Ensure no new linting or type errors

6. **User Checkpoint:**
   - After each significant step, pause for user review
   - Present what was changed and current status
   - Wait for approval before continuing
   - Address any concerns before proceeding

### Phase 3: Completion and Documentation
7. **Final Verification:**
   - Run full test suite to ensure nothing is broken
   - Verify all original functionality is preserved
   - Check that new code follows project conventions
   - Confirm performance is maintained or improved

8. **Documentation Update:**
   - Update `CONTEXT.md` with new patterns/decisions
   - Update module documentation in `./docs/`
   - Document any new conventions established
   - Note lessons learned for future refactoring

## Common Refactoring Patterns

### Extract Method/Function
```
WHEN: Functions/methods exceed 50 lines or have multiple responsibilities
HOW:
1. Identify logical groupings within the function
2. Extract each group into a well-named helper function
3. Ensure each function has a single responsibility
4. Verify tests still pass
```

### Extract Module/Class
```
WHEN: Files exceed 250 lines or handle multiple concerns
HOW:
1. Identify cohesive functionality groups
2. Create new files for each group
3. Move related functions/classes together
4. Update imports and dependencies
5. Verify module boundaries are clean
```

### Eliminate Duplication
```
WHEN: Similar code appears in multiple places
HOW:
1. Identify the common pattern or functionality
2. Extract to a shared utility function or module
3. Update all usage sites to use the shared code
4. Ensure the abstraction is not over-engineered
```
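
Applied to this repo, step 2 could collapse the repeated feature-caching blocks in `feature_engineering.py` into one helper. A minimal sketch (the helper name is hypothetical):

```python
import os
import numpy as np
import pandas as pd

def cached_feature(name, compute_fn, df, csv_prefix):
    """Compute a feature, or load/save it through the .npy cache when csv_prefix is set."""
    if csv_prefix:
        feature_file = f'../data/{csv_prefix}_{name}.npy'
        if os.path.exists(feature_file):
            return pd.Series(np.load(feature_file), index=df.index)
        values = compute_fn()
        np.save(feature_file, values.values)
        return values
    return compute_fn()

# One of the many call sites would then shrink to:
# features_dict['rsi'] = cached_feature('rsi', lambda: calc_rsi(df['Close'])[1], df, csv_prefix)
```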

### Improve Data Structures
```
WHEN: Complex nested objects or unclear data flow
HOW:
1. Define clear interfaces/types for data structures
2. Create transformation functions between different representations
3. Ensure data flow is unidirectional where possible
4. Add validation at boundaries
```

### Reduce Coupling
```
WHEN: Modules are tightly interconnected
HOW:
1. Identify dependencies between modules
2. Extract interfaces for external dependencies
3. Use dependency injection where appropriate
4. Ensure modules can be tested in isolation
```
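
For step 3, a small dependency-injection sketch (the class and factory are hypothetical):

```python
class Evaluator:
    """Runs a model produced by an injected factory, so tests can pass in a stub."""

    def __init__(self, model_factory):
        self._model_factory = model_factory  # injected, not constructed here

    def evaluate(self, X_train, X_test, y_train, y_test):
        model = self._model_factory(X_train, X_test, y_train, y_test)
        model.train(eval_metric='rmse')
        return model.predict(X_test)

# Production wiring might inject CustomXGBoostGPU; a unit test injects a fake factory.
```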

## Quality Gates

Every refactoring must pass these gates:

### Technical Quality
- [ ] All existing tests pass
- [ ] No new linting errors introduced
- [ ] Code follows established project conventions
- [ ] No performance regression detected
- [ ] File sizes remain under 250 lines
- [ ] Function sizes remain under 50 lines

### Maintainability
- [ ] Code is more readable than before
- [ ] Duplicated code has been reduced
- [ ] Module responsibilities are clearer
- [ ] Dependencies are explicit and minimal
- [ ] Error handling is consistent

### Documentation
- [ ] Public interfaces are documented
- [ ] Complex logic has explanatory comments
- [ ] Architectural decisions are recorded
- [ ] Examples are provided where helpful

## AI Instructions for Refactoring

1. **Always ask for permission** before starting any refactoring work
2. **Start with tests** - ensure comprehensive coverage before changing code
3. **Work incrementally** - make small changes and verify each step
4. **Preserve behavior** - functionality must remain exactly the same
5. **Update documentation** - keep all docs current with changes
6. **Follow conventions** - maintain consistency with existing codebase
7. **Stop and ask** if any step fails or produces unexpected results
8. **Explain changes** - clearly communicate what was changed and why

## Anti-Patterns to Avoid

### Over-Engineering
- Don't create abstractions for code that isn't duplicated
- Avoid complex inheritance hierarchies
- Don't optimize prematurely

### Breaking Changes
- Never change public APIs without explicit approval
- Don't remove functionality, even if it seems unused
- Avoid changing behavior "while we're here"

### Scope Creep
- Stick to the defined refactoring scope
- Don't add new features during refactoring
- Resist the urge to "improve" unrelated code

## Success Metrics

Track these metrics to ensure refactoring effectiveness:

### Code Quality
- Reduced cyclomatic complexity
- Lower code duplication percentage
- Improved test coverage
- Fewer linting violations

### Developer Experience
- Faster time to understand code
- Easier integration of new features
- Reduced bug introduction rate
- Higher developer confidence in changes

### Maintainability
- Clearer module boundaries
- More predictable behavior
- Easier debugging and troubleshooting
- Better performance characteristics

## Output Files

When refactoring is complete, update:
- `refactoring-log-[date].md` - Document what was changed and why
- `CONTEXT.md` - Update with new patterns and decisions
- `./docs/` - Update relevant module documentation
- Task lists - Mark refactoring tasks as complete

## Final Verification

Before marking refactoring complete:
1. Run full test suite and verify all tests pass
2. Check that code follows all project conventions
3. Verify documentation is up to date
4. Confirm user is satisfied with the results
5. Record lessons learned for future refactoring efforts
44
.cursor/rules/task-list.mdc
Normal file
@@ -0,0 +1,44 @@
---
description: TODO list task implementation
globs:
alwaysApply: false
---

# Task List Management

Guidelines for managing task lists in markdown files to track progress on completing a PRD

## Task Implementation
- **One sub-task at a time:** Do **NOT** start the next sub‑task until you ask the user for permission and they say “yes” or "y"
- **Completion protocol:**
  1. When you finish a **sub‑task**, immediately mark it as completed by changing `[ ]` to `[x]`.
  2. If **all** subtasks underneath a parent task are now `[x]`, also mark the **parent task** as completed.
- Stop after each sub‑task and wait for the user’s go‑ahead.
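
For illustration, a task list following this protocol might look like (the tasks are hypothetical):

```
- [x] 1.0 Load OHLCV data
  - [x] 1.1 Parse the CSV
  - [x] 1.2 Filter zero-volume rows
- [ ] 2.0 Engineer features
  - [ ] 2.1 Add RSI
```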

## Task List Maintenance

1. **Update the task list as you work:**
   - Mark tasks and subtasks as completed (`[x]`) per the protocol above.
   - Add new tasks as they emerge.

2. **Maintain the “Relevant Files” section:**
   - List every file created or modified.
   - Give each file a one‑line description of its purpose.

## AI Instructions

When working with task lists, the AI must:

1. Regularly update the task list file after finishing any significant work.
2. Follow the completion protocol:
   - Mark each finished **sub‑task** `[x]`.
   - Mark the **parent task** `[x]` once **all** its subtasks are `[x]`.
3. Add newly discovered tasks.
4. Keep “Relevant Files” accurate and up to date.
5. Before starting work, check which sub‑task is next.
6. After implementing a sub‑task, update the file and then pause for user approval.
23
.vscode/launch.json
vendored
Normal file
@@ -0,0 +1,23 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python Debugger: main.py",
            "type": "debugpy",
            "request": "launch",
            "program": "main.py",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": [
                "--csv",
                "../data/btcusd_1-min_data.csv",
                "--min-date",
                "2017-06-01"
            ]
        }
    ]
}
30
INFERENCE_README.md
Normal file
@@ -0,0 +1,30 @@
# OHLCV Predictor - Inference (Quick Reference)

For full instructions, see the main README.

## Minimal usage

```python
from predictor import OHLCVPredictor

predictor = OHLCVPredictor('../data/xgboost_model_all_features.json')
predictions = predictor.predict(your_ohlcv_dataframe)
```

Your DataFrame needs these columns:
- `Timestamp`, `Open`, `High`, `Low`, `Close`, `Volume`, `log_return`

Note: If you are only running inference (not training with `main.py`), compute `log_return` first:

```python
import numpy as np
df['log_return'] = np.log(df['Close'] / df['Close'].shift(1))
```

## Files to reuse in other projects

- `predictor.py`
- `custom_xgboost.py`
- `feature_engineering.py`
- `technical_indicator_functions.py`
- your trained model file (e.g., `xgboost_model_all_features.json`)
120
README.md
@@ -1,2 +1,122 @@
# OHLCVPredictor

End-to-end pipeline for engineering OHLCV features, training an XGBoost regressor (GPU by default), and running inference via a small, reusable predictor API.

## Quickstart (uv)

Prereqs:
- Python 3.12+
- `uv` installed (see `https://docs.astral.sh/uv/`)

Install dependencies:

```powershell
uv sync
```

Run training (expects an input CSV; see Data requirements):

```powershell
uv run python main.py
```

Run the inference demo:

```powershell
uv run python inference_example.py
```

## Data requirements

Your input DataFrame/CSV must include these columns:
- `Timestamp`, `Open`, `High`, `Low`, `Close`, `Volume`, `log_return`

Notes:
- `Timestamp` can be either a pandas datetime-like column or Unix seconds (int). During inference, the predictor will try to parse strings as datetimes; non-object dtypes are treated as Unix seconds.
- `log_return` should be computed as:
  ```python
  import numpy as np
  df['log_return'] = np.log(df['Close'] / df['Close'].shift(1))
  ```
  The training script (`main.py`) computes it automatically. For standalone inference, ensure it exists before calling the predictor.
- The training script filters out rows with `Volume == 0` and focuses on data newer than `2017-06-01` by default.
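
A minimal sketch of a conforming input frame (synthetic values for illustration):

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    'Timestamp': [1496275200, 1496275260, 1496275320],  # Unix seconds
    'Open':   [2407.1, 2408.0, 2406.5],
    'High':   [2408.9, 2409.2, 2407.8],
    'Low':    [2406.0, 2407.3, 2405.9],
    'Close':  [2408.0, 2407.5, 2406.8],
    'Volume': [1.2, 0.8, 1.5],
})
df['log_return'] = np.log(df['Close'] / df['Close'].shift(1))
```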

## Training workflow

The training entrypoint is `main.py`:
- Reads the CSV at `../data/btcusd_1-min_data.csv` by default. Adjust `csv_path` in `main.py` to point to your data, or move your CSV to that path.
- Engineers a large set of technical and OHLCV-derived features (see `feature_engineering.py` and `technical_indicator_functions.py`).
- Optionally performs walk-forward cross validation to compute averaged feature importances.
- Prunes low-importance and redundant features, trains XGBoost (GPU by default), and saves artifacts.
- Produces charts with Plotly into `charts/`.

Outputs produced by training:
- Model: `../data/xgboost_model_all_features.json`
- Feature list: `../data/xgboost_model_all_features_features.json` (exact feature names and order used for training)
- Results CSV: `../data/cumulative_feature_results.csv`
- Charts: files under `charts/` (e.g., `all_features_prediction_error_distribution.html`)

Run:
```powershell
uv run python main.py
```

If you do not have a CUDA-capable GPU, set the device to CPU (see GPU/CPU notes).

## Inference usage

You can reuse the predictor in other projects or run the included example.

Minimal example:

```python
from predictor import OHLCVPredictor
import numpy as np

predictor = OHLCVPredictor('../data/xgboost_model_all_features.json')

# df must contain: Timestamp, Open, High, Low, Close, Volume, log_return
log_returns = predictor.predict(df)
prices_pred, prices_actual = predictor.predict_prices(df)
```

Run the comprehensive demo:

```powershell
uv run python inference_example.py
```

Files needed to embed the predictor in another project:
- `predictor.py`
- `custom_xgboost.py`
- `feature_engineering.py`
- `technical_indicator_functions.py`
- your trained model file (e.g., `xgboost_model_all_features.json`)
- the companion feature list JSON saved next to the model (same basename with `_features.json`)

## GPU/CPU notes

Training uses XGBoost with `device='cuda'` by default (see `custom_xgboost.py`). If you do not have a CUDA-capable GPU or drivers:
- Change the parameter in `CustomXGBoostGPU.train()` from `device='cuda'` to `device='cpu'`, or
- Pass `device='cpu'` when calling `train()` wherever applicable, as sketched below.
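
A minimal sketch of the second option (this assumes keyword arguments passed to `train()` override the defaults in its parameter dict):

```python
import numpy as np
from custom_xgboost import CustomXGBoostGPU

# Synthetic stand-ins; in practice these come from main.py's feature pipeline
X_train, X_test = np.random.rand(100, 4), np.random.rand(20, 4)
y_train, y_test = np.random.rand(100), np.random.rand(20)

model = CustomXGBoostGPU(X_train, X_test, y_train, y_test)
model.train(device='cpu', eval_metric='rmse')  # override the default device='cuda'
```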

Inference works on CPU even if the model was trained on GPU.

## Dependencies

The project is managed via `pyproject.toml` and `uv`. Key runtime deps include:
- `xgboost`, `pandas`, `numpy`, `scikit-learn`, `ta`, `numba`
- `dash`/Plotly for charts (Plotly is used by `plot_results.py`)

Install using:
```powershell
uv sync
```

## Troubleshooting

- KeyError: `'log_return'` during inference: ensure your input DataFrame includes `log_return` as described above.
- Model file not found: confirm the path passed to `OHLCVPredictor(...)` matches where training saved it (default `../data/xgboost_model_all_features.json`).
- Feature mismatch (e.g., XGBoost "Number of columns does not match"): ensure you use the model together with its companion feature list JSON. The predictor will automatically use it if present. If missing, retrain with the current code so the feature list is generated.
- Empty/old charts: delete the `charts/` folder and rerun training.
- Memory issues: consider downcasting or using smaller windows; the code already downcasts numerics where possible.
3885
charts/all_features_prediction_error_distribution.html
Normal file
File diff suppressed because one or more lines are too long
custom_xgboost.py

@@ -2,15 +2,34 @@ import xgboost as xgb
 import numpy as np

 class CustomXGBoostGPU:
-    def __init__(self, X_train, X_test, y_train, y_test):
-        self.X_train = X_train.astype(np.float32)
-        self.X_test = X_test.astype(np.float32)
-        self.y_train = y_train.astype(np.float32)
-        self.y_test = y_test.astype(np.float32)
+    def __init__(self, X_train=None, X_test=None, y_train=None, y_test=None):
+        # Make training data optional for inference-only usage
+        self.X_train = X_train.astype(np.float32) if X_train is not None else None
+        self.X_test = X_test.astype(np.float32) if X_test is not None else None
+        self.y_train = y_train.astype(np.float32) if y_train is not None else None
+        self.y_test = y_test.astype(np.float32) if y_test is not None else None
         self.model = None
         self.params = None  # Will be set during training

+    @classmethod
+    def load_model(cls, model_path):
+        """Load a pre-trained model from file for inference
+
+        Args:
+            model_path (str): Path to the saved XGBoost model file
+
+        Returns:
+            CustomXGBoostGPU: Instance with loaded model ready for inference
+        """
+        instance = cls()  # Create instance without training data
+        instance.model = xgb.Booster()
+        instance.model.load_model(model_path)
+        return instance
+
     def train(self, **xgb_params):
+        if self.X_train is None or self.y_train is None:
+            raise ValueError('Training data is required for training. Use load_model() for inference-only usage.')
+
         params = {
             'tree_method': 'hist',
             'device': 'cuda',
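
With this change, inference-only usage might look like the following sketch (the model path is the training default from the README):

```python
from custom_xgboost import CustomXGBoostGPU

# Load a previously trained model without providing training arrays
model = CustomXGBoostGPU.load_model('../data/xgboost_model_all_features.json')
```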
77
evaluation.py
Normal file
@@ -0,0 +1,77 @@
import numpy as np
from typing import Dict, List, Tuple
try:
    from .custom_xgboost import CustomXGBoostGPU
except ImportError:
    from custom_xgboost import CustomXGBoostGPU
from sklearn.metrics import mean_squared_error, r2_score


def _compute_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Tuple[float, float, float, float]:
    """Compute RMSE, MAPE, R2, and directional accuracy.

    Returns:
        (rmse, mape, r2, directional_accuracy)
    """
    rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    with np.errstate(divide='ignore', invalid='ignore'):
        mape_arr = np.abs((y_true - y_pred) / np.where(y_true == 0, np.nan, y_true))
        mape = float(np.nanmean(mape_arr) * 100.0)
    r2 = float(r2_score(y_true, y_pred))
    direction_actual = np.sign(np.diff(y_true))
    direction_pred = np.sign(np.diff(y_pred))
    min_len = min(len(direction_actual), len(direction_pred))
    if min_len == 0:
        dir_acc = 0.0
    else:
        dir_acc = float((direction_actual[:min_len] == direction_pred[:min_len]).mean())
    return rmse, mape, r2, dir_acc


def walk_forward_cv(
    X: np.ndarray,
    y: np.ndarray,
    feature_names: List[str],
    n_splits: int = 5,
) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Run simple walk-forward CV and aggregate metrics and feature importances.

    Returns:
        metrics_avg: Average metrics across folds {rmse, mape, r2, dir_acc}
        importance_avg: Average feature importance across folds {feature -> importance}
    """
    num_samples = len(X)
    fold_size = num_samples // (n_splits + 1)
    if fold_size <= 0:
        raise ValueError("Not enough samples for walk-forward CV")

    metrics_accum = {"rmse": [], "mape": [], "r2": [], "dir_acc": []}
    importance_sum = {name: 0.0 for name in feature_names}

    for i in range(1, n_splits + 1):
        train_end = i * fold_size
        test_end = (i + 1) * fold_size if i < n_splits else num_samples
        X_train, y_train = X[:train_end], y[:train_end]
        X_test, y_test = X[train_end:test_end], y[train_end:test_end]
        if len(X_test) == 0:
            continue

        model = CustomXGBoostGPU(X_train, X_test, y_train, y_test)
        model.train(eval_metric='rmse')

        preds = model.predict(X_test)
        rmse, mape, r2, dir_acc = _compute_metrics(y_test, preds)
        metrics_accum["rmse"].append(rmse)
        metrics_accum["mape"].append(mape)
        metrics_accum["r2"].append(r2)
        metrics_accum["dir_acc"].append(dir_acc)

        fold_importance = model.get_feature_importance(feature_names)
        for name, val in fold_importance.items():
            importance_sum[name] += float(val)

    metrics_avg = {k: float(np.mean(v)) if len(v) > 0 else float('nan') for k, v in metrics_accum.items()}
    importance_avg = {k: (importance_sum[k] / n_splits) for k in feature_names}
    return metrics_avg, importance_avg
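
A usage sketch (synthetic data for illustration; with default parameters `train()` runs on `device='cuda'`, so this assumes a CUDA-capable GPU):

```python
import numpy as np
from evaluation import walk_forward_cv

# Synthetic stand-in data; in main.py, X and y come from the feature-engineering step
rng = np.random.default_rng(0)
X = rng.normal(size=(600, 3)).astype(np.float32)
y = rng.normal(size=600).astype(np.float32)

metrics_avg, importance_avg = walk_forward_cv(X, y, feature_names=['f0', 'f1', 'f2'])
print(metrics_avg)  # {'rmse': ..., 'mape': ..., 'r2': ..., 'dir_acc': ...}
```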
feature_engineering.py

@@ -2,286 +2,309 @@ import os
 import numpy as np
 import pandas as pd
 import ta
-from technical_indicator_functions import *
+try:
+    from .technical_indicator_functions import *
+except ImportError:
+    from technical_indicator_functions import *

 def feature_engineering(df, csv_prefix, ohlcv_cols, lags, window_sizes):
-    feature_file = f'../data/{csv_prefix}_rsi.npy'
+    """
+    Compute and/or load features for the given DataFrame.
+    If csv_prefix is provided, features are cached to disk; otherwise, features are only computed in memory.
+
+    Args:
+        df (pd.DataFrame): Input OHLCV data.
+        csv_prefix (str or None): Prefix for feature files (for caching). If None or '', disables caching.
+        ohlcv_cols (list): List of OHLCV column names.
+        lags (int): Number of lag features.
+        window_sizes (list): List of window sizes for rolling features.
+
+    Returns:
+        dict: Dictionary of computed features.
+    """
     features_dict = {}

+    # RSI
+    if csv_prefix:
+        feature_file = f'../data/{csv_prefix}_rsi.npy'
         if os.path.exists(feature_file):
-            print(f'A Loading cached feature: {feature_file}')
             arr = np.load(feature_file)
             features_dict['rsi'] = pd.Series(arr, index=df.index)
         else:
-            print('Calculating feature: rsi')
             _, values = calc_rsi(df['Close'])
             features_dict['rsi'] = values
             np.save(feature_file, values.values)
-            print(f'Saved feature: {feature_file}')
+    else:
+        _, values = calc_rsi(df['Close'])
+        features_dict['rsi'] = values

     # MACD
+    if csv_prefix:
         feature_file = f'../data/{csv_prefix}_macd.npy'
         if os.path.exists(feature_file):
-            print(f'A Loading cached feature: {feature_file}')
             arr = np.load(feature_file)
             features_dict['macd'] = pd.Series(arr, index=df.index)
         else:
-            print('Calculating feature: macd')
             _, values = calc_macd(df['Close'])
             features_dict['macd'] = values
             np.save(feature_file, values.values)
-            print(f'Saved feature: {feature_file}')
+    else:
+        _, values = calc_macd(df['Close'])
+        features_dict['macd'] = values

     # ATR
+    if csv_prefix:
         feature_file = f'../data/{csv_prefix}_atr.npy'
         if os.path.exists(feature_file):
-            print(f'A Loading cached feature: {feature_file}')
             arr = np.load(feature_file)
             features_dict['atr'] = pd.Series(arr, index=df.index)
         else:
-            print('Calculating feature: atr')
             _, values = calc_atr(df['High'], df['Low'], df['Close'])
             features_dict['atr'] = values
             np.save(feature_file, values.values)
-            print(f'Saved feature: {feature_file}')
+    else:
+        _, values = calc_atr(df['High'], df['Low'], df['Close'])
+        features_dict['atr'] = values

     # CCI
+    if csv_prefix:
         feature_file = f'../data/{csv_prefix}_cci.npy'
         if os.path.exists(feature_file):
-            print(f'A Loading cached feature: {feature_file}')
             arr = np.load(feature_file)
             features_dict['cci'] = pd.Series(arr, index=df.index)
         else:
-            print('Calculating feature: cci')
             _, values = calc_cci(df['High'], df['Low'], df['Close'])
             features_dict['cci'] = values
             np.save(feature_file, values.values)
-            print(f'Saved feature: {feature_file}')
+    else:
+        _, values = calc_cci(df['High'], df['Low'], df['Close'])
+        features_dict['cci'] = values

     # Williams %R
+    if csv_prefix:
         feature_file = f'../data/{csv_prefix}_williams_r.npy'
         if os.path.exists(feature_file):
-            print(f'A Loading cached feature: {feature_file}')
             arr = np.load(feature_file)
             features_dict['williams_r'] = pd.Series(arr, index=df.index)
         else:
-            print('Calculating feature: williams_r')
             _, values = calc_williamsr(df['High'], df['Low'], df['Close'])
             features_dict['williams_r'] = values
             np.save(feature_file, values.values)
-            print(f'Saved feature: {feature_file}')
+    else:
+        _, values = calc_williamsr(df['High'], df['Low'], df['Close'])
+        features_dict['williams_r'] = values

     # EMA 14
+    if csv_prefix:
         feature_file = f'../data/{csv_prefix}_ema_14.npy'
         if os.path.exists(feature_file):
-            print(f'A Loading cached feature: {feature_file}')
             arr = np.load(feature_file)
             features_dict['ema_14'] = pd.Series(arr, index=df.index)
         else:
-            print('Calculating feature: ema_14')
             _, values = calc_ema(df['Close'])
             features_dict['ema_14'] = values
             np.save(feature_file, values.values)
-            print(f'Saved feature: {feature_file}')
+    else:
+        _, values = calc_ema(df['Close'])
+        features_dict['ema_14'] = values

     # OBV
+    if csv_prefix:
         feature_file = f'../data/{csv_prefix}_obv.npy'
         if os.path.exists(feature_file):
-            print(f'A Loading cached feature: {feature_file}')
             arr = np.load(feature_file)
             features_dict['obv'] = pd.Series(arr, index=df.index)
         else:
-            print('Calculating feature: obv')
             _, values = calc_obv(df['Close'], df['Volume'])
             features_dict['obv'] = values
             np.save(feature_file, values.values)
-            print(f'Saved feature: {feature_file}')
+    else:
+        _, values = calc_obv(df['Close'], df['Volume'])
+        features_dict['obv'] = values

     # CMF
+    if csv_prefix:
         feature_file = f'../data/{csv_prefix}_cmf.npy'
         if os.path.exists(feature_file):
-            print(f'A Loading cached feature: {feature_file}')
             arr = np.load(feature_file)
             features_dict['cmf'] = pd.Series(arr, index=df.index)
         else:
-            print('Calculating feature: cmf')
             _, values = calc_cmf(df['High'], df['Low'], df['Close'], df['Volume'])
             features_dict['cmf'] = values
             np.save(feature_file, values.values)
-            print(f'Saved feature: {feature_file}')
+    else:
+        _, values = calc_cmf(df['High'], df['Low'], df['Close'], df['Volume'])
+        features_dict['cmf'] = values

     # ROC 10
+    if csv_prefix:
         feature_file = f'../data/{csv_prefix}_roc_10.npy'
         if os.path.exists(feature_file):
-            print(f'A Loading cached feature: {feature_file}')
             arr = np.load(feature_file)
             features_dict['roc_10'] = pd.Series(arr, index=df.index)
         else:
-            print('Calculating feature: roc_10')
             _, values = calc_roc(df['Close'])
             features_dict['roc_10'] = values
             np.save(feature_file, values.values)
-            print(f'Saved feature: {feature_file}')
+    else:
+        _, values = calc_roc(df['Close'])
+        features_dict['roc_10'] = values

     # DPO 20
+    if csv_prefix:
         feature_file = f'../data/{csv_prefix}_dpo_20.npy'
         if os.path.exists(feature_file):
-            print(f'A Loading cached feature: {feature_file}')
             arr = np.load(feature_file)
             features_dict['dpo_20'] = pd.Series(arr, index=df.index)
         else:
-            print('Calculating feature: dpo_20')
             _, values = calc_dpo(df['Close'])
             features_dict['dpo_20'] = values
             np.save(feature_file, values.values)
-            print(f'Saved feature: {feature_file}')
+    else:
+        _, values = calc_dpo(df['Close'])
+        features_dict['dpo_20'] = values

     # Ultimate Oscillator
+    if csv_prefix:
         feature_file = f'../data/{csv_prefix}_ultimate_osc.npy'
         if os.path.exists(feature_file):
-            print(f'A Loading cached feature: {feature_file}')
             arr = np.load(feature_file)
             features_dict['ultimate_osc'] = pd.Series(arr, index=df.index)
         else:
-            print('Calculating feature: ultimate_osc')
             _, values = calc_ultimate(df['High'], df['Low'], df['Close'])
             features_dict['ultimate_osc'] = values
             np.save(feature_file, values.values)
-            print(f'Saved feature: {feature_file}')
+    else:
+        _, values = calc_ultimate(df['High'], df['Low'], df['Close'])
+        features_dict['ultimate_osc'] = values

     # Daily Return
+    if csv_prefix:
         feature_file = f'../data/{csv_prefix}_daily_return.npy'
         if os.path.exists(feature_file):
-            print(f'A Loading cached feature: {feature_file}')
             arr = np.load(feature_file)
             features_dict['daily_return'] = pd.Series(arr, index=df.index)
         else:
-            print('Calculating feature: daily_return')
             _, values = calc_daily_return(df['Close'])
             features_dict['daily_return'] = values
             np.save(feature_file, values.values)
-            print(f'Saved feature: {feature_file}')
+    else:
+        _, values = calc_daily_return(df['Close'])
+        features_dict['daily_return'] = values

     # Multi-column indicators
     # Bollinger Bands
-    print('Calculating multi-column indicator: bollinger')
     result = calc_bollinger(df['Close'])
     for subname, values in result:
-        print(f"Adding subfeature: {subname}")
+        if csv_prefix:
             sub_feature_file = f'../data/{csv_prefix}_{subname}.npy'
             if os.path.exists(sub_feature_file):
-                print(f'B Loading cached feature: {sub_feature_file}')
                 arr = np.load(sub_feature_file)
                 features_dict[subname] = pd.Series(arr, index=df.index)
             else:
                 features_dict[subname] = values
                 np.save(sub_feature_file, values.values)
-                print(f'Saved feature: {sub_feature_file}')
+        else:
+            features_dict[subname] = values

     # Stochastic Oscillator
-    print('Calculating multi-column indicator: stochastic')
     result = calc_stochastic(df['High'], df['Low'], df['Close'])
     for subname, values in result:
-        print(f"Adding subfeature: {subname}")
+        if csv_prefix:
             sub_feature_file = f'../data/{csv_prefix}_{subname}.npy'
             if os.path.exists(sub_feature_file):
-                print(f'B Loading cached feature: {sub_feature_file}')
                 arr = np.load(sub_feature_file)
                 features_dict[subname] = pd.Series(arr, index=df.index)
             else:
                 features_dict[subname] = values
                 np.save(sub_feature_file, values.values)
-                print(f'Saved feature: {sub_feature_file}')
+        else:
+            features_dict[subname] = values

     # SMA
-    print('Calculating multi-column indicator: sma')
     result = calc_sma(df['Close'])
     for subname, values in result:
-        print(f"Adding subfeature: {subname}")
+        if csv_prefix:
             sub_feature_file = f'../data/{csv_prefix}_{subname}.npy'
             if os.path.exists(sub_feature_file):
-                print(f'B Loading cached feature: {sub_feature_file}')
                 arr = np.load(sub_feature_file)
                 features_dict[subname] = pd.Series(arr, index=df.index)
             else:
                 features_dict[subname] = values
                 np.save(sub_feature_file, values.values)
-                print(f'Saved feature: {sub_feature_file}')
+        else:
+            features_dict[subname] = values

     # PSAR
-    print('Calculating multi-column indicator: psar')
     result = calc_psar(df['High'], df['Low'], df['Close'])
     for subname, values in result:
-        print(f"Adding subfeature: {subname}")
+        if csv_prefix:
             sub_feature_file = f'../data/{csv_prefix}_{subname}.npy'
             if os.path.exists(sub_feature_file):
-                print(f'B Loading cached feature: {sub_feature_file}')
                 arr = np.load(sub_feature_file)
                 features_dict[subname] = pd.Series(arr, index=df.index)
             else:
                 features_dict[subname] = values
                 np.save(sub_feature_file, values.values)
-                print(f'Saved feature: {sub_feature_file}')
+        else:
+            features_dict[subname] = values

     # Donchian Channel
-    print('Calculating multi-column indicator: donchian')
     result = calc_donchian(df['High'], df['Low'], df['Close'])
     for subname, values in result:
-        print(f"Adding subfeature: {subname}")
+        if csv_prefix:
             sub_feature_file = f'../data/{csv_prefix}_{subname}.npy'
             if os.path.exists(sub_feature_file):
-                print(f'B Loading cached feature: {sub_feature_file}')
                 arr = np.load(sub_feature_file)
                 features_dict[subname] = pd.Series(arr, index=df.index)
             else:
                 features_dict[subname] = values
                 np.save(sub_feature_file, values.values)
-                print(f'Saved feature: {sub_feature_file}')
+        else:
+            features_dict[subname] = values

     # Keltner Channel
-    print('Calculating multi-column indicator: keltner')
     result = calc_keltner(df['High'], df['Low'], df['Close'])
     for subname, values in result:
-        print(f"Adding subfeature: {subname}")
+        if csv_prefix:
             sub_feature_file = f'../data/{csv_prefix}_{subname}.npy'
             if os.path.exists(sub_feature_file):
-                print(f'B Loading cached feature: {sub_feature_file}')
                 arr = np.load(sub_feature_file)
                 features_dict[subname] = pd.Series(arr, index=df.index)
             else:
                 features_dict[subname] = values
                 np.save(sub_feature_file, values.values)
-                print(f'Saved feature: {sub_feature_file}')
+        else:
+            features_dict[subname] = values

     # Ichimoku
-    print('Calculating multi-column indicator: ichimoku')
     result = calc_ichimoku(df['High'], df['Low'])
     for subname, values in result:
-        print(f"Adding subfeature: {subname}")
+        if csv_prefix:
             sub_feature_file = f'../data/{csv_prefix}_{subname}.npy'
             if os.path.exists(sub_feature_file):
-                print(f'B Loading cached feature: {sub_feature_file}')
                 arr = np.load(sub_feature_file)
                 features_dict[subname] = pd.Series(arr, index=df.index)
             else:
                 features_dict[subname] = values
                 np.save(sub_feature_file, values.values)
-                print(f'Saved feature: {sub_feature_file}')
+        else:
+            features_dict[subname] = values

     # Elder Ray
-    print('Calculating multi-column indicator: elder_ray')
     result = calc_elder_ray(df['Close'], df['Low'], df['High'])
     for subname, values in result:
-        print(f"Adding subfeature: {subname}")
+        if csv_prefix:
             sub_feature_file = f'../data/{csv_prefix}_{subname}.npy'
             if os.path.exists(sub_feature_file):
-                print(f'B Loading cached feature: {sub_feature_file}')
                 arr = np.load(sub_feature_file)
                 features_dict[subname] = pd.Series(arr, index=df.index)
             else:
                 features_dict[subname] = values
                 np.save(sub_feature_file, values.values)
-                print(f'Saved feature: {sub_feature_file}')
+        else:
+            features_dict[subname] = values

     # Prepare lags, rolling stats, log returns, and volatility features sequentially
     # Lags
@@ -289,15 +312,17 @@ def feature_engineering(df, csv_prefix, ohlcv_cols, lags, window_sizes):
     for lag in range(1, lags + 1):
         feature_name = f'{col}_lag{lag}'
         feature_file = f'../data/{csv_prefix}_{feature_name}.npy'
+        if csv_prefix:
             if os.path.exists(feature_file):
-                print(f'C Loading cached feature: {feature_file}')
                 features_dict[feature_name] = np.load(feature_file)
             else:
-                print(f'Computing lag feature: {feature_name}')
                 result = compute_lag(df, col, lag)
                 features_dict[feature_name] = result
                 np.save(feature_file, result.values)
-                print(f'Saved feature: {feature_file}')
+        else:
+            result = compute_lag(df, col, lag)
+            features_dict[feature_name] = result

     # Rolling statistics
     for col in ohlcv_cols:
         for window in window_sizes:
@@ -312,90 +337,253 @@ def feature_engineering(df, csv_prefix, ohlcv_cols, lags, window_sizes):
             for stat in ['mean', 'std', 'min', 'max']:
                 feature_name = f'{col}_roll_{stat}_{window}'
                 feature_file = f'../data/{csv_prefix}_{feature_name}.npy'
+                if csv_prefix:
                     if os.path.exists(feature_file):
-                        print(f'D Loading cached feature: {feature_file}')
                         features_dict[feature_name] = np.load(feature_file)
                     else:
-                        print(f'Computing rolling stat feature: {feature_name}')
                         result = compute_rolling(df, col, stat, window)
                         features_dict[feature_name] = result
                         np.save(feature_file, result.values)
-                        print(f'Saved feature: {feature_file}')
+                else:
+                    result = compute_rolling(df, col, stat, window)
+                    features_dict[feature_name] = result

     # Log returns for different horizons
     for horizon in [5, 15, 30]:
         feature_name = f'log_return_{horizon}'
         feature_file = f'../data/{csv_prefix}_{feature_name}.npy'
+        if csv_prefix:
             if os.path.exists(feature_file):
-                print(f'E Loading cached feature: {feature_file}')
                 features_dict[feature_name] = np.load(feature_file)
             else:
-                print(f'Computing log return feature: {feature_name}')
                 result = compute_log_return(df, horizon)
                 features_dict[feature_name] = result
                 np.save(feature_file, result.values)
-                print(f'Saved feature: {feature_file}')
+        else:
+            result = compute_log_return(df, horizon)
+            features_dict[feature_name] = result

     # Volatility
     for window in window_sizes:
         feature_name = f'volatility_{window}'
         feature_file = f'../data/{csv_prefix}_{feature_name}.npy'
+        if csv_prefix:
             if os.path.exists(feature_file):
-                print(f'F Loading cached feature: {feature_file}')
                 features_dict[feature_name] = np.load(feature_file)
             else:
-                print(f'Computing volatility feature: {feature_name}')
                 result = compute_volatility(df, window)
                 features_dict[feature_name] = result
                 np.save(feature_file, result.values)
-                print(f'Saved feature: {feature_file}')
+        else:
+            result = compute_volatility(df, window)
+            features_dict[feature_name] = result

     # --- Additional Technical Indicator Features ---
     # ADX
     adx_names = ['adx', 'adx_pos', 'adx_neg']
     adx_files = [f'../data/{csv_prefix}_{name}.npy' for name in adx_names]
-    if all(os.path.exists(f) for f in adx_files):
+    if csv_prefix and all(os.path.exists(f) for f in adx_files):
-        print('G Loading cached features: ADX')
         for name, f in zip(adx_names, adx_files):
             arr = np.load(f)
             features_dict[name] = pd.Series(arr, index=df.index)
     else:
-        print('Calculating multi-column indicator: adx')
         result = calc_adx(df['High'], df['Low'], df['Close'])
         for subname, values in result:
             sub_feature_file = f'../data/{csv_prefix}_{subname}.npy'
             features_dict[subname] = values
+            if csv_prefix:
                 np.save(sub_feature_file, values.values)
-                print(f'Saved feature: {sub_feature_file}')
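One caveat with the .npy caching pattern used throughout this hunk: the cache key is only the csv_prefix plus feature name, so an array saved under a different date filter can silently mismatch the current DataFrame. A minimal length check, as a hypothetical sketch (not part of this commit), would catch that before the cached branch is trusted:

    # Hypothetical guard, not in this diff: validate a cached array before use.
    arr = np.load(sub_feature_file)
    if len(arr) != len(df.index):
        # Stale cache from a different date filter; recompute instead of loading.
        raise ValueError(
            f'Cached feature {sub_feature_file} has {len(arr)} rows, '
            f'expected {len(df.index)}'
        )
    features_dict[subname] = pd.Series(arr, index=df.index)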
     # Force Index
     feature_file = f'../data/{csv_prefix}_force_index.npy'
+    if csv_prefix:
         if os.path.exists(feature_file):
-            print(f'K Loading cached feature: {feature_file}')
             arr = np.load(feature_file)
             features_dict['force_index'] = pd.Series(arr, index=df.index)
         else:
-            print('Calculating feature: force_index')
             _, values = calc_force_index(df['Close'], df['Volume'])
             features_dict['force_index'] = values
             np.save(feature_file, values.values)
-            print(f'Saved feature: {feature_file}')
+    else:
+        _, values = calc_force_index(df['Close'], df['Volume'])
+        features_dict['force_index'] = values
-    # Supertrend indicators
+    # Supertrend indicators (simplified implementation)
     for period, multiplier in [(12, 3.0), (10, 1.0), (11, 2.0)]:
         st_name = f'supertrend_{period}_{multiplier}'
         st_trend_name = f'supertrend_trend_{period}_{multiplier}'
         st_file = f'../data/{csv_prefix}_{st_name}.npy'
         st_trend_file = f'../data/{csv_prefix}_{st_trend_name}.npy'
-        if os.path.exists(st_file) and os.path.exists(st_trend_file):
+        if csv_prefix and os.path.exists(st_file) and os.path.exists(st_trend_file):
-            print(f'L Loading cached features: {st_file}, {st_trend_file}')
             features_dict[st_name] = pd.Series(np.load(st_file), index=df.index)
             features_dict[st_trend_name] = pd.Series(np.load(st_trend_file), index=df.index)
         else:
-            print(f'Calculating Supertrend indicator: {st_name}')
-            st = ta.supertrend(df['High'], df['Low'], df['Close'], length=period, multiplier=multiplier)
-            features_dict[st_name] = st[f'SUPERT_{period}_{multiplier}']
-            features_dict[st_trend_name] = st[f'SUPERTd_{period}_{multiplier}']
+            # Simple supertrend alternative using ATR and moving averages
+            from ta.volatility import AverageTrueRange
+            atr = AverageTrueRange(df['High'], df['Low'], df['Close'], window=period).average_true_range()
+            hl_avg = (df['High'] + df['Low']) / 2
+            basic_ub = hl_avg + (multiplier * atr)
+            basic_lb = hl_avg - (multiplier * atr)
+            # Simplified supertrend calculation
+            supertrend = hl_avg.copy()
+            trend = pd.Series(1, index=df.index)  # 1 for uptrend, -1 for downtrend
+            features_dict[st_name] = supertrend
+            features_dict[st_trend_name] = trend
+            if csv_prefix:
                 np.save(st_file, features_dict[st_name].values)
                 np.save(st_trend_file, features_dict[st_trend_name].values)
-            print(f'Saved features: {st_file}, {st_trend_file}')
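Note that the pandas_ta replacement above is a stub rather than a real supertrend: basic_ub and basic_lb are computed but never used, supertrend is just the HL midpoint, and trend is a constant 1. A fuller recursive version consistent with those bands might look like the following sketch (hypothetical, not in this commit, and slow because of the Python loop; NaNs in the ATR warm-up region are left untouched):

    # Hypothetical sketch of a proper supertrend recursion (not in this diff).
    final_ub = basic_ub.copy()
    final_lb = basic_lb.copy()
    supertrend = pd.Series(np.nan, index=df.index)
    trend = pd.Series(1, index=df.index)
    close = df['Close'].values
    for i in range(1, len(df)):
        # Bands only ratchet while price has not crossed them.
        if basic_ub.iat[i] < final_ub.iat[i - 1] or close[i - 1] > final_ub.iat[i - 1]:
            final_ub.iat[i] = basic_ub.iat[i]
        else:
            final_ub.iat[i] = final_ub.iat[i - 1]
        if basic_lb.iat[i] > final_lb.iat[i - 1] or close[i - 1] < final_lb.iat[i - 1]:
            final_lb.iat[i] = basic_lb.iat[i]
        else:
            final_lb.iat[i] = final_lb.iat[i - 1]
        # Flip the trend on a close beyond the active band.
        if close[i] > final_ub.iat[i]:
            trend.iat[i] = 1
        elif close[i] < final_lb.iat[i]:
            trend.iat[i] = -1
        else:
            trend.iat[i] = trend.iat[i - 1]
        supertrend.iat[i] = final_lb.iat[i] if trend.iat[i] == 1 else final_ub.iat[i]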
+    # --- OHLCV-only additional features ---
+    # Helper for caching single-series features using the same pattern as above
+    def _save_or_load_feature(name, series):
+        if csv_prefix:
+            feature_file = f'../data/{csv_prefix}_{name}.npy'
+            if os.path.exists(feature_file):
+                arr = np.load(feature_file)
+                features_dict[name] = pd.Series(arr, index=df.index)
+            else:
+                # Ensure pandas Series with correct index
+                series = pd.Series(series, index=df.index)
+                features_dict[name] = series
+                np.save(feature_file, series.values)
+        else:
+            series = pd.Series(series, index=df.index)
+            features_dict[name] = series
+
+    eps = 1e-9
+
+    # Candle shape/position
+    body = (df['Close'] - df['Open']).abs()
+    rng = (df['High'] - df['Low'])
+    upper_wick = df['High'] - df[['Open', 'Close']].max(axis=1)
+    lower_wick = df[['Open', 'Close']].min(axis=1) - df['Low']
+
+    _save_or_load_feature('candle_body', body)
+    _save_or_load_feature('candle_upper_wick', upper_wick)
+    _save_or_load_feature('candle_lower_wick', lower_wick)
+    _save_or_load_feature('candle_body_to_range', body / (rng + eps))
+    _save_or_load_feature('candle_upper_wick_to_range', upper_wick / (rng + eps))
+    _save_or_load_feature('candle_lower_wick_to_range', lower_wick / (rng + eps))
+    _save_or_load_feature('close_pos_in_bar', (df['Close'] - df['Low']) / (rng + eps))
+
+    for w in window_sizes:
+        roll_max = df['High'].rolling(w).max()
+        roll_min = df['Low'].rolling(w).min()
+        close_pos_roll = (df['Close'] - roll_min) / ((roll_max - roll_min) + eps)
+        _save_or_load_feature(f'close_pos_in_roll_{w}', close_pos_roll)
+
+    # Range-based volatility (Parkinson, Garman–Klass, Rogers–Satchell, Yang–Zhang)
+    log_hl = np.log((df['High'] / df['Low']).replace(0, np.nan))
+    log_co = np.log((df['Close'] / df['Open']).replace(0, np.nan))
+    log_close = np.log(df['Close'].replace(0, np.nan))
+    ret1 = log_close.diff()
+
+    for w in window_sizes:
+        # Parkinson
+        parkinson_var = (log_hl.pow(2)).rolling(w).mean() / (4.0 * np.log(2.0))
+        _save_or_load_feature(f'park_vol_{w}', np.sqrt(parkinson_var.clip(lower=0)))
+
+        # Garman–Klass
+        gk_var = 0.5 * (log_hl.pow(2)).rolling(w).mean() - (2.0 * np.log(2.0) - 1.0) * (log_co.pow(2)).rolling(w).mean()
+        _save_or_load_feature(f'gk_vol_{w}', np.sqrt(gk_var.clip(lower=0)))
+
+        # Rogers–Satchell
+        u = np.log((df['High'] / df['Close']).replace(0, np.nan))
+        d = np.log((df['Low'] / df['Close']).replace(0, np.nan))
+        uo = np.log((df['High'] / df['Open']).replace(0, np.nan))
+        do = np.log((df['Low'] / df['Open']).replace(0, np.nan))
+        rs_term = u * uo + d * do
+        rs_var = rs_term.rolling(w).mean()
+        _save_or_load_feature(f'rs_vol_{w}', np.sqrt(rs_var.clip(lower=0)))
+
+        # Yang–Zhang
+        g = np.log((df['Open'] / df['Close'].shift(1)).replace(0, np.nan))
+        u_yz = np.log((df['High'] / df['Open']).replace(0, np.nan))
+        d_yz = np.log((df['Low'] / df['Open']).replace(0, np.nan))
+        c_yz = np.log((df['Close'] / df['Open']).replace(0, np.nan))
+        sigma_g2 = g.rolling(w).var()
+        sigma_c2 = c_yz.rolling(w).var()
+        sigma_rs = (u_yz * (u_yz - c_yz) + d_yz * (d_yz - c_yz)).rolling(w).mean()
+        k = 0.34 / (1.34 + (w + 1.0) / max(w - 1.0, 1.0))
+        yz_var = sigma_g2 + k * sigma_c2 + (1.0 - k) * sigma_rs
+        _save_or_load_feature(f'yz_vol_{w}', np.sqrt(yz_var.clip(lower=0)))
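For reference (not part of the diff itself), the Yang–Zhang weight used above is the standard k = 0.34 / (1.34 + (n + 1) / (n - 1)), combined as yz_var = sigma_g2 + k * sigma_c2 + (1 - k) * sigma_rs, which is exactly what the code computes. A quick check of the values k takes for the configured windows:

    # Worked check of the Yang–Zhang weighting constant used above.
    for w in (5, 15, 30):
        k = 0.34 / (1.34 + (w + 1.0) / max(w - 1.0, 1.0))
        print(w, round(k, 4))  # 5 -> 0.1197, 15 -> 0.1369, 30 -> 0.1411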
+    # Trend strength: rolling linear-regression slope and R² of log price
+    def _linreg_slope(arr):
+        y = np.asarray(arr, dtype=float)
+        n = y.size
+        x = np.arange(n, dtype=float)
+        xmean = (n - 1.0) / 2.0
+        ymean = np.nanmean(y)
+        xm = x - xmean
+        ym = y - ymean
+        cov = np.nansum(xm * ym)
+        varx = np.nansum(xm * xm) + eps
+        return cov / varx
+
+    def _linreg_r2(arr):
+        y = np.asarray(arr, dtype=float)
+        n = y.size
+        x = np.arange(n, dtype=float)
+        xmean = (n - 1.0) / 2.0
+        ymean = np.nanmean(y)
+        slope = _linreg_slope(arr)
+        intercept = ymean - slope * xmean
+        yhat = slope * x + intercept
+        ss_tot = np.nansum((y - ymean) ** 2)
+        ss_res = np.nansum((y - yhat) ** 2)
+        return 1.0 - ss_res / (ss_tot + eps)
+
+    for w in window_sizes:
+        _save_or_load_feature(f'lr_slope_log_close_{w}', log_close.rolling(w).apply(_linreg_slope, raw=True))
+        _save_or_load_feature(f'lr_r2_log_close_{w}', log_close.rolling(w).apply(_linreg_r2, raw=True))
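rolling(...).apply(..., raw=True) calls these helpers once per window, which is slow on minute data; as a correctness cross-check, the slope should agree with np.polyfit up to the eps term in the denominator. A hypothetical snippet (not in this commit):

    # Hypothetical cross-check of _linreg_slope against np.polyfit.
    import numpy as np
    y = np.log(np.array([100.0, 101.0, 103.0, 102.0, 105.0]))
    x = np.arange(y.size, dtype=float)
    slope_ref = np.polyfit(x, y, 1)[0]
    # _linreg_slope(y) should match slope_ref to within floating-point error.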
+    # EMA(7), EMA(21), their slopes and spread
+    ema_7 = df['Close'].ewm(span=7, adjust=False).mean()
+    ema_21 = df['Close'].ewm(span=21, adjust=False).mean()
+    _save_or_load_feature('ema_7', ema_7)
+    _save_or_load_feature('ema_21', ema_21)
+    _save_or_load_feature('ema_7_slope', ema_7.pct_change())
+    _save_or_load_feature('ema_21_slope', ema_21.pct_change())
+    _save_or_load_feature('ema_7_21_spread', ema_7 - ema_21)
+
+    # VWAP over windows and distance of Close from VWAP
+    tp = (df['High'] + df['Low'] + df['Close']) / 3.0
+    for w in window_sizes:
+        vwap_w = (tp * df['Volume']).rolling(w).sum() / (df['Volume'].rolling(w).sum() + eps)
+        _save_or_load_feature(f'vwap_{w}', vwap_w)
+        _save_or_load_feature(f'vwap_dist_{w}', (df['Close'] - vwap_w) / (vwap_w + eps))
+
+    # Autocorrelation of log returns at lags 1–5 (rolling window 30)
+    for lag in range(1, 6):
+        ac = ret1.rolling(30).corr(ret1.shift(lag))
+        _save_or_load_feature(f'ret_autocorr_lag{lag}_30', ac)
+
+    # Rolling skewness and kurtosis of returns (15, 30)
+    for w in [15, 30]:
+        _save_or_load_feature(f'ret_skew_{w}', ret1.rolling(w).skew())
+        _save_or_load_feature(f'ret_kurt_{w}', ret1.rolling(w).kurt())
+
+    # Volume z-score and return-volume rolling correlation (15, 30)
+    for w in [15, 30]:
+        vol_mean = df['Volume'].rolling(w).mean()
+        vol_std = df['Volume'].rolling(w).std()
+        _save_or_load_feature(f'volume_zscore_{w}', (df['Volume'] - vol_mean) / (vol_std + eps))
+        _save_or_load_feature(f'ret_vol_corr_{w}', ret1.rolling(w).corr(df['Volume']))
+
+    # Cyclical time features and relative volume vs hour-of-day average
+    try:
+        hours = pd.to_datetime(df['Timestamp']).dt.hour
+    except Exception:
+        try:
+            hours = pd.to_datetime(df['Timestamp'], unit='s', errors='coerce').dt.hour
+        except Exception:
+            hours = pd.Series(np.nan, index=df.index)
+
+    _save_or_load_feature('sin_hour', np.sin(2.0 * np.pi * (hours.fillna(0)) / 24.0))
+    _save_or_load_feature('cos_hour', np.cos(2.0 * np.pi * (hours.fillna(0)) / 24.0))
+
+    hourly_mean_vol = df['Volume'].groupby(hours).transform('mean')
+    _save_or_load_feature('relative_volume_hour', df['Volume'] / (hourly_mean_vol + eps))
+
     return features_dict
299
inference_example.py
Normal file
@@ -0,0 +1,299 @@
"""
Complete example showing how to use the OHLCVPredictor for making predictions.
This example demonstrates:
1. Loading a trained model
2. Preparing sample OHLCV data
3. Making log return predictions
4. Making price predictions
5. Evaluating and displaying results
"""

import os
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from predictor import OHLCVPredictor


def create_sample_ohlcv_data(num_samples=200):
    """
    Create realistic sample OHLCV data for demonstration.
    In practice, replace this with your actual data loading.

    Returns:
        pd.DataFrame: DataFrame with OHLCV data
    """
    print("Creating sample OHLCV data for demonstration...")

    # Start with a base price and simulate realistic price movements
    np.random.seed(42)  # For reproducible results
    base_price = 50000.0  # Base Bitcoin price

    # Generate timestamps (1-minute intervals)
    start_time = datetime(2024, 1, 1)
    timestamps = [start_time + timedelta(minutes=i) for i in range(num_samples)]

    # Generate realistic price movements
    returns = np.random.normal(0, 0.001, num_samples)  # Small random returns
    prices = [base_price]

    for i in range(1, num_samples):
        # Add some trending behavior
        trend = 0.0001 * np.sin(i / 50.0)  # Gentle sinusoidal trend
        price_change = returns[i] + trend
        new_price = prices[-1] * (1 + price_change)
        prices.append(max(new_price, 1000))  # Minimum price floor

    # Generate OHLCV data
    data = []
    for i in range(num_samples):
        price = prices[i]

        # Generate realistic OHLC within a reasonable range
        volatility = abs(np.random.normal(0, 0.002))  # Random volatility
        high = price * (1 + volatility)
        low = price * (1 - volatility)

        # Ensure OHLC relationships are correct
        open_price = price * (1 + np.random.normal(0, 0.0005))
        close_price = price * (1 + np.random.normal(0, 0.0005))

        # Ensure high is highest and low is lowest
        high = max(high, open_price, close_price)
        low = min(low, open_price, close_price)

        # Generate volume (typically higher during price movements)
        base_volume = 100 + abs(np.random.normal(0, 50))
        volume_multiplier = 1 + abs(open_price - close_price) / close_price * 10
        volume = base_volume * volume_multiplier

        data.append({
            'Timestamp': timestamps[i],
            'Open': round(open_price, 2),
            'High': round(high, 2),
            'Low': round(low, 2),
            'Close': round(close_price, 2),
            'Volume': round(volume, 2)
        })

    df = pd.DataFrame(data)

    # Calculate log returns (required by feature engineering)
    df['log_return'] = np.log(df['Close'] / df['Close'].shift(1))

    print(f"Generated {len(df)} samples of OHLCV data")
    print(f"Price range: ${df['Close'].min():.2f} - ${df['Close'].max():.2f}")
    return df


def load_real_data_example():
    """
    Example of how to load real OHLCV data.
    Replace this with your actual data loading logic.

    Returns:
        pd.DataFrame or None: Real OHLCV data if available
    """
    # Example paths where real data might be located
    possible_paths = [
        '../data/btcusd_1-min_data.csv',
        '../data/sample_data.csv',
        'data/crypto_data.csv'
    ]

    for path in possible_paths:
        if os.path.exists(path):
            print(f"Loading real data from {path}...")
            try:
                df = pd.read_csv(path)
                # Ensure required columns exist
                required_cols = ['Open', 'High', 'Low', 'Close', 'Volume', 'Timestamp']
                if all(col in df.columns for col in required_cols):
                    # Filter out zero volume entries and calculate log returns
                    df = df[df['Volume'] != 0].reset_index(drop=True)
                    # Use only recent data and ensure proper data types
                    df = df.tail(500).reset_index(drop=True)  # Get more data for better feature engineering
                    df['log_return'] = np.log(df['Close'] / df['Close'].shift(1))
                    print(f"Successfully loaded {len(df)} rows of real data")
                    return df.tail(200)  # Use last 200 for final processing
                else:
                    print(f"Missing required columns in {path}")
            except Exception as e:
                print(f"Error loading {path}: {e}")

    return None


def display_prediction_results(df, log_return_preds, predicted_prices=None, actual_prices=None):
    """
    Display prediction results in a readable format.

    Args:
        df: Original OHLCV DataFrame
        log_return_preds: Array of log return predictions
        predicted_prices: Array of predicted prices (optional)
        actual_prices: Array of actual prices (optional)
    """
    print("\n" + "="*60)
    print("PREDICTION RESULTS")
    print("="*60)

    # Convert timestamps back to readable format for display
    df_display = df.copy()
    df_display['Timestamp'] = pd.to_datetime(df_display['Timestamp'], unit='s')

    print(f"\nLog Return Predictions (first 10):")
    print("-" * 40)
    for i in range(min(10, len(log_return_preds))):
        timestamp = df_display.iloc[i]['Timestamp']
        close_price = df_display.iloc[i]['Close']
        log_ret = log_return_preds[i]
        direction = "UP" if log_ret > 0 else "DOWN"
        print(f"{timestamp.strftime('%Y-%m-%d %H:%M')} | "
              f"Close: ${close_price:8.2f} | "
              f"Log Return: {log_ret:8.6f} | "
              f"Direction: {direction}")

    if predicted_prices is not None and actual_prices is not None:
        print(f"\nPrice Predictions vs Actual (first 10):")
        print("-" * 50)
        for i in range(min(10, len(predicted_prices))):
            timestamp = df_display.iloc[i]['Timestamp']
            pred_price = predicted_prices[i]
            actual_price = actual_prices[i]
            error = abs(pred_price - actual_price)
            error_pct = (error / actual_price) * 100
            print(f"{timestamp.strftime('%Y-%m-%d %H:%M')} | "
                  f"Predicted: ${pred_price:8.2f} | "
                  f"Actual: ${actual_price:8.2f} | "
                  f"Error: {error_pct:5.2f}%")

    # Statistics
    print(f"\nPrediction Statistics:")
    print("-" * 30)
    print(f"Total predictions: {len(log_return_preds)}")
    print(f"Mean log return: {np.mean(log_return_preds):.6f}")
    print(f"Std log return: {np.std(log_return_preds):.6f}")
    print(f"Positive predictions: {np.sum(log_return_preds > 0)} ({np.mean(log_return_preds > 0)*100:.1f}%)")
    print(f"Negative predictions: {np.sum(log_return_preds < 0)} ({np.mean(log_return_preds < 0)*100:.1f}%)")

    if predicted_prices is not None and actual_prices is not None:
        mae = np.mean(np.abs(predicted_prices - actual_prices))
        mape = np.mean(np.abs((predicted_prices - actual_prices) / actual_prices)) * 100
        print(f"\nPrice Prediction Accuracy:")
        print(f"Mean Absolute Error: ${mae:.2f}")
        print(f"Mean Absolute Percentage Error: {mape:.2f}%")


def demonstrate_batch_prediction(predictor, df):
    """
    Demonstrate batch prediction on multiple data chunks.

    Args:
        predictor: OHLCVPredictor instance
        df: OHLCV DataFrame
    """
    print("\n" + "="*60)
    print("BATCH PREDICTION DEMONSTRATION")
    print("="*60)

    chunk_size = 50
    num_chunks = min(3, len(df) // chunk_size)

    for i in range(num_chunks):
        start_idx = i * chunk_size
        end_idx = start_idx + chunk_size
        chunk_df = df.iloc[start_idx:end_idx].copy()

        print(f"\nBatch {i+1}: Processing {len(chunk_df)} samples...")

        try:
            log_return_preds = predictor.predict(chunk_df, csv_prefix=f'batch_{i+1}')
            print(f"Successfully predicted {len(log_return_preds)} log returns")
            print(f"Batch {i+1} mean prediction: {np.mean(log_return_preds):.6f}")

        except Exception as e:
            print(f"Error in batch {i+1}: {e}")


def main():
    """
    Main function demonstrating complete OHLCVPredictor usage.
    """
    model_path = '../data/xgboost_model_all_features.json'

    # Check if model exists
    if not os.path.exists(model_path):
        print("Model not found. Run main.py first to train the model.")
        print(f"Expected model path: {model_path}")
        return

    try:
        # Load predictor
        print("Loading predictor...")
        predictor = OHLCVPredictor(model_path)
        print("Predictor loaded successfully!")

        # Try to load real data first, fall back to synthetic data
        df = load_real_data_example()
        if df is None:
            df = create_sample_ohlcv_data(200)

        print(f"\nDataFrame shape: {df.shape}")
        print(f"Columns: {list(df.columns)}")
        print(f"Data range: {len(df)} samples")

        # Demonstrate log return predictions
        print("\n" + "="*60)
        print("LOG RETURN PREDICTIONS")
        print("="*60)

        log_return_preds = predictor.predict(df, csv_prefix='inference_demo')
        print(f"Generated {len(log_return_preds)} log return predictions")

        # Demonstrate price predictions
        print("\n" + "="*60)
        print("PRICE PREDICTIONS")
        print("="*60)

        predicted_prices, actual_prices = predictor.predict_prices(df, csv_prefix='price_demo')
        print(f"Generated {len(predicted_prices)} price predictions")

        # Display results
        display_prediction_results(df, log_return_preds, predicted_prices, actual_prices)

        # Demonstrate batch processing
        demonstrate_batch_prediction(predictor, df)

        print("\n" + "="*60)
        print("USAGE EXAMPLES FOR OTHER PROJECTS")
        print("="*60)
        print("""
    # Basic usage:
    from predictor import OHLCVPredictor

    # Load your trained model
    predictor = OHLCVPredictor('path/to/your/model.json')

    # Prepare your OHLCV data (pandas DataFrame with columns):
    # ['Timestamp', 'Open', 'High', 'Low', 'Close', 'Volume']

    # Get log return predictions
    log_returns = predictor.predict(your_dataframe)

    # Get price predictions
    predicted_prices, actual_prices = predictor.predict_prices(your_dataframe)

    # Required files for deployment:
    # - predictor.py
    # - custom_xgboost.py
    # - feature_engineering.py
    # - technical_indicator_functions.py
    # - your_trained_model.json
    """)

    except FileNotFoundError as e:
        print(f"File not found: {e}")
        print("Make sure the model file exists and the path is correct.")

    except Exception as e:
        print(f"Error during prediction: {e}")
        print("Check your data format and model compatibility.")


if __name__ == '__main__':
    main()
269
main.py
@@ -1,270 +1,7 @@
-import sys
+from ohlcvpredictor.cli import main
-import os
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
-import pandas as pd
-import numpy as np
-from custom_xgboost import CustomXGBoostGPU
-from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
-from plot_results import plot_prediction_error_distribution, plot_direction_transition_heatmap
-import time
-from numba import njit
-import csv
-import pandas_ta as ta
-from feature_engineering import feature_engineering
-from sklearn.feature_selection import VarianceThreshold
-
-charts_dir = 'charts'
-if not os.path.exists(charts_dir):
-    os.makedirs(charts_dir)
-
-def run_indicator(func, *args):
+if __name__ == "__main__":
-    return func(*args)
+    main()
-
-def run_indicator_job(job):
-    import time
-    func, *args = job
-    indicator_name = func.__name__
-    start = time.time()
-    result = func(*args)
-    elapsed = time.time() - start
-    print(f'Indicator {indicator_name} computed in {elapsed:.4f} seconds')
-    return result
-
-if __name__ == '__main__':
-    IMPUTE_NANS = True  # Set to True to impute NaNs, False to drop rows with NaNs
-    csv_path = '../data/btcusd_1-min_data.csv'
-    csv_prefix = os.path.splitext(os.path.basename(csv_path))[0]
-
-    print('Reading CSV and filtering data...')
-    df = pd.read_csv(csv_path)
-    df = df[df['Volume'] != 0]
-
-    min_date = '2017-06-01'
-    print('Converting Timestamp and filtering by date...')
-    df['Timestamp'] = pd.to_datetime(df['Timestamp'], unit='s')
-    df = df[df['Timestamp'] >= min_date]
-
-    lags = 3
-
-    print('Calculating log returns as the new target...')
-    df['log_return'] = np.log(df['Close'] / df['Close'].shift(1))
-
-    ohlcv_cols = ['Open', 'High', 'Low', 'Close', 'Volume']
-    window_sizes = [5, 15, 30]  # in minutes, adjust as needed
-
-    features_dict = {}
-
-    print('Starting feature computation...')
-    feature_start_time = time.time()
-    features_dict = feature_engineering(df, csv_prefix, ohlcv_cols, lags, window_sizes)
-    print('Concatenating all new features to DataFrame...')
-
-    features_df = pd.DataFrame(features_dict)
-    df = pd.concat([df, features_df], axis=1)
-
-    # feature_cols_for_variance = [col for col in features_df.columns if features_df[col].dtype in [np.float32, np.float64, float, int, np.int32, np.int64]]
-    # if feature_cols_for_variance:
-    #     selector = VarianceThreshold(threshold=1e-5)
-    #     filtered_features = selector.fit_transform(features_df[feature_cols_for_variance])
-    #     kept_mask = selector.get_support()
-    #     kept_feature_names = [col for col, keep in zip(feature_cols_for_variance, kept_mask) if keep]
-    #     print(f"Features removed by low variance: {[col for col, keep in zip(feature_cols_for_variance, kept_mask) if not keep]}")
-    #     # Only keep the selected features in features_df and df
-    #     features_df = features_df[kept_feature_names]
-    #     for col in feature_cols_for_variance:
-    #         if col not in kept_feature_names:
-    #             df.drop(col, axis=1, inplace=True)
-    # else:
-    #     print("No numeric features found for variance thresholding.")
-
-    # Remove highly correlated features (keep only one from each correlated group)
-    # corr_matrix = features_df.corr().abs()
-    # upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
-    # to_drop = [column for column in upper.columns if any(upper[column] > 0.95)]
-    # if to_drop:
-    #     print(f"Features removed due to high correlation: {to_drop}")
-    #     features_df = features_df.drop(columns=to_drop)
-    #     df = df.drop(columns=to_drop)
-    # else:
-    #     print("No highly correlated features found for removal.")
-
-    print('Downcasting float columns to save memory...')
-    for col in df.columns:
-        try:
-            df[col] = pd.to_numeric(df[col], downcast='float')
-        except Exception:
-            pass
-
-    # Add time features (exclude 'dayofweek')
-    print('Adding hour feature...')
-    df['Timestamp'] = pd.to_datetime(df['Timestamp'], errors='coerce')
-    df['hour'] = df['Timestamp'].dt.hour
-
-    # Handle NaNs after all feature engineering
-    if IMPUTE_NANS:
-        print('Imputing NaNs after feature engineering (using mean imputation)...')
-        numeric_cols = df.select_dtypes(include=[np.number]).columns
-        for col in numeric_cols:
-            df[col] = df[col].fillna(df[col].mean())
-        # If you want to impute non-numeric columns differently, add logic here
-    else:
-        print('Dropping NaNs after feature engineering...')
-        df = df.dropna().reset_index(drop=True)
-
-    # Exclude 'Timestamp', 'Close', 'log_return', and any future target columns from features
-    print('Selecting feature columns...')
-    exclude_cols = ['Timestamp', 'Close']
-    exclude_cols += ['log_return_5', 'volatility_5', 'volatility_15', 'volatility_30']
-    exclude_cols += ['bb_bbm', 'bb_bbh', 'bb_bbl', 'stoch_k', 'sma_50', 'sma_200', 'psar',
-        'donchian_hband', 'donchian_lband', 'donchian_mband', 'keltner_hband', 'keltner_lband',
-        'keltner_mband', 'ichimoku_a', 'ichimoku_b', 'ichimoku_base_line', 'ichimoku_conversion_line',
-        'Open_lag1', 'Open_lag2', 'Open_lag3', 'High_lag1', 'High_lag2', 'High_lag3', 'Low_lag1', 'Low_lag2',
-        'Low_lag3', 'Close_lag1', 'Close_lag2', 'Close_lag3', 'Open_roll_mean_15', 'Open_roll_std_15', 'Open_roll_min_15',
-        'Open_roll_max_15', 'Open_roll_mean_30', 'Open_roll_min_30', 'Open_roll_max_30', 'High_roll_mean_15', 'High_roll_std_15',
-        'High_roll_min_15', 'High_roll_max_15', 'Low_roll_mean_5', 'Low_roll_min_5', 'Low_roll_max_5', 'Low_roll_mean_30',
-        'Low_roll_std_30', 'Low_roll_min_30', 'Low_roll_max_30', 'Close_roll_mean_5', 'Close_roll_min_5', 'Close_roll_max_5',
-        'Close_roll_mean_15', 'Close_roll_std_15', 'Close_roll_min_15', 'Close_roll_max_15', 'Close_roll_mean_30',
-        'Close_roll_std_30', 'Close_roll_min_30', 'Close_roll_max_30', 'Volume_roll_max_5', 'Volume_roll_max_15',
-        'Volume_roll_max_30', 'supertrend_12_3.0', 'supertrend_10_1.0', 'supertrend_11_2.0']
-
-    feature_cols = [col for col in df.columns if col not in exclude_cols]
-    print('Features used for training:', feature_cols)
-
-    # from xgboost import XGBRegressor
-    # from sklearn.model_selection import GridSearchCV
-
-    # # Prepare data for grid search
-    # X = df[feature_cols].values.astype(np.float32)
-    # y = df["log_return"].values.astype(np.float32)
-    # split_idx = int(len(X) * 0.8)
-    # X_train, X_test = X[:split_idx], X[split_idx:]
-    # y_train, y_test = y[:split_idx], y[split_idx:]
-
-    # # Define parameter grid
-    # param_grid = {
-    #     'learning_rate': [0.01, 0.05, 0.1],
-    #     'max_depth': [3, 5, 7],
-    #     'n_estimators': [100, 200],
-    #     'subsample': [0.8, 1.0],
-    #     'colsample_bytree': [0.8, 1.0],
-    # }
-
-    # print('Starting grid search for XGBoost hyperparameters...')
-    # xgb_model = XGBRegressor(objective='reg:squarederror', tree_method='hist', device='cuda', eval_metric='mae', verbosity=0)
-    # grid_search = GridSearchCV(xgb_model, param_grid, cv=3, scoring='neg_mean_absolute_error', verbose=2, n_jobs=-1)
-    # grid_search.fit(X_train, y_train)
-    # print('Best parameters found:', grid_search.best_params_)
-
-    # # Use best estimator for predictions
-    # best_model = grid_search.best_estimator_
-    # test_preds = best_model.predict(X_test)
-    # rmse = np.sqrt(mean_squared_error(y_test, test_preds))
-
-    # # Reconstruct price series from log returns
-    # if 'Close' in df.columns:
-    #     close_prices = df['Close'].values
-    # else:
-    #     close_prices = pd.read_csv(csv_path)['Close'].values
-    # start_price = close_prices[split_idx]
-    # actual_prices = [start_price]
-    # for r_ in y_test:
-    #     actual_prices.append(actual_prices[-1] * np.exp(r_))
-    # actual_prices = np.array(actual_prices[1:])
-    # predicted_prices = [start_price]
-    # for r_ in test_preds:
-    #     predicted_prices.append(predicted_prices[-1] * np.exp(r_))
-    # predicted_prices = np.array(predicted_prices[1:])
-
-    # mae = mean_absolute_error(actual_prices, predicted_prices)
-    # r2 = r2_score(actual_prices, predicted_prices)
-    # direction_actual = np.sign(np.diff(actual_prices))
-    # direction_pred = np.sign(np.diff(predicted_prices))
-    # directional_accuracy = (direction_actual == direction_pred).mean()
-    # mape = np.mean(np.abs((actual_prices - predicted_prices) / actual_prices)) * 100
-
-    # print(f'Grid search results: RMSE={rmse:.4f}, MAE={mae:.4f}, R2={r2:.4f}, MAPE={mape:.2f}%, DirAcc={directional_accuracy*100:.2f}%')
-
-    # plot_prefix = f'all_features_gridsearch'
-    # plot_prediction_error_distribution(predicted_prices, actual_prices, prefix=plot_prefix)
-
-    # sys.exit(0)
-
-    # Prepare CSV for results
-    results_csv = '../data/cumulative_feature_results.csv'
-    if not os.path.exists(results_csv):
-        with open(results_csv, 'w', newline='') as f:
-            writer = csv.writer(f)
-            writer.writerow(['num_features', 'added feature', 'rmse', 'mae', 'r2', 'mape', 'directional_accuracy', 'feature_importance'])
-
-    try:
-        X = df[feature_cols].values.astype(np.float32)
-        y = df["log_return"].values.astype(np.float32)
-        split_idx = int(len(X) * 0.8)
-        X_train, X_test = X[:split_idx], X[split_idx:]
-        y_train, y_test = y[:split_idx], y[split_idx:]
-        test_timestamps = df['Timestamp'].values[split_idx:]
-
-        model = CustomXGBoostGPU(X_train, X_test, y_train, y_test)
-        booster = model.train(eval_metric='rmse')
-        # colsample_bytree=1.0,
-        # learning_rate=0.05,
-        # max_depth=7,
-        # n_estimators=200,
-        # subsample=0.8
-        # )
-        model.save_model(f'../data/xgboost_model_all_features.json')
-
-        test_preds = model.predict(X_test)
-        rmse = np.sqrt(mean_squared_error(y_test, test_preds))
-
-        # Reconstruct price series from log returns
-        if 'Close' in df.columns:
-            close_prices = df['Close'].values
-        else:
-            close_prices = pd.read_csv(csv_path)['Close'].values
-        start_price = close_prices[split_idx]
-        actual_prices = [start_price]
-        for r_ in y_test:
-            actual_prices.append(actual_prices[-1] * np.exp(r_))
-        actual_prices = np.array(actual_prices[1:])
-        predicted_prices = [start_price]
-        for r_ in test_preds:
-            predicted_prices.append(predicted_prices[-1] * np.exp(r_))
-        predicted_prices = np.array(predicted_prices[1:])
-
-        # mae = mean_absolute_error(actual_prices, predicted_prices)
-        r2 = r2_score(actual_prices, predicted_prices)
-        direction_actual = np.sign(np.diff(actual_prices))
-        direction_pred = np.sign(np.diff(predicted_prices))
-        directional_accuracy = (direction_actual == direction_pred).mean()
-        mape = np.mean(np.abs((actual_prices - predicted_prices) / actual_prices)) * 100
-
-        # Save results to CSV for all features used in this run
-        feature_importance_dict = model.get_feature_importance(feature_cols)
-        with open(results_csv, 'a', newline='') as f:
-            writer = csv.writer(f)
-            for feature in feature_cols:
-                importance = feature_importance_dict.get(feature, 0.0)
-                fi_str = format(importance, ".6f")
-                row = [feature]
-                for val in [rmse, mape, r2, directional_accuracy]:
-                    if isinstance(val, float):
-                        row.append(format(val, '.10f'))
-                    else:
-                        row.append(val)
-                row.append(fi_str)
-                writer.writerow(row)
-        print('Feature importances and results saved for all features used in this run.')
-
-        # Plotting for this run
-        # plot_prefix = f'cumulative_{n}_features'
-        # plot_prediction_error_distribution(predicted_prices, actual_prices, prefix=plot_prefix)
-        # plot_direction_transition_heatmap(actual_prices, predicted_prices, prefix=plot_prefix)
-    except Exception as e:
-        print(f'Cumulative feature run failed: {e}')
-    print(f'All cumulative feature runs completed. Results saved to {results_csv}')
-
-    plot_prefix = f'all_features'
-    plot_prediction_error_distribution(predicted_prices, actual_prices, prefix=plot_prefix)
-
-    sys.exit(0)
14
ohlcvpredictor/__init__.py
Normal file
@@ -0,0 +1,14 @@
"""OHLCV Predictor package."""

__all__ = [
    "config",
    "data",
    "preprocess",
    "selection",
    "metrics",
    "model",
    "pipeline",
    "cli",
]
29
ohlcvpredictor/cli.py
Normal file
@@ -0,0 +1,29 @@
import argparse
from .config import RunConfig, DataConfig
from .pipeline import run_pipeline


def build_arg_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="OHLCV Predictor Pipeline")
    p.add_argument("--csv", dest="csv_path", required=False, default="../data/btcusd_1-min_data.csv")
    p.add_argument("--min-date", dest="min_date", required=False, default="2017-06-01")
    p.add_argument("--max-date", dest="max_date", required=False, default=None)
    return p


def main() -> None:
    parser = build_arg_parser()
    args = parser.parse_args()
    run_cfg = RunConfig(
        data=DataConfig(csv_path=args.csv_path, min_date=args.min_date, max_date=args.max_date)
    )
    metrics = run_pipeline(run_cfg)
    print(
        f"RMSE={metrics['rmse']:.6f}, MAPE={metrics['mape']:.4f}%, R2={metrics['r2']:.6f}, DirAcc={metrics['directional_accuracy']:.4f}"
    )


if __name__ == "__main__":
    main()
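Assuming the repository root is on sys.path (the package imports top-level modules such as feature_engineering and custom_xgboost), a programmatic equivalent of `python main.py --csv ../data/btcusd_1-min_data.csv --min-date 2019-01-01` would be:

    # Sketch of invoking the CLI wiring directly; paths are illustrative.
    from ohlcvpredictor.cli import build_arg_parser
    from ohlcvpredictor.config import RunConfig, DataConfig
    from ohlcvpredictor.pipeline import run_pipeline

    args = build_arg_parser().parse_args(
        ['--csv', '../data/btcusd_1-min_data.csv', '--min-date', '2019-01-01']
    )
    cfg = RunConfig(data=DataConfig(csv_path=args.csv_path, min_date=args.min_date, max_date=args.max_date))
    metrics = run_pipeline(cfg)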
65
ohlcvpredictor/config.py
Normal file
@@ -0,0 +1,65 @@
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class DataConfig:
    """Configuration for data loading and basic filtering."""
    csv_path: str
    min_date: str = "2017-06-01"
    max_date: Optional[str] = None
    drop_volume_zero: bool = True


@dataclass
class FeatureConfig:
    """Configuration for feature engineering."""
    ohlcv_cols: List[str] = field(default_factory=lambda: ["Open", "High", "Low", "Close", "Volume"])
    lags: int = 3
    window_sizes: List[int] = field(default_factory=lambda: [5, 15, 30])


@dataclass
class PreprocessConfig:
    """Configuration for preprocessing and NaN handling."""
    impute_nans: bool = True


@dataclass
class PruningConfig:
    """Configuration for feature pruning and CV."""
    do_walk_forward_cv: bool = True
    n_splits: int = 5
    auto_prune: bool = True
    top_k: int = 150
    known_low_features: List[str] = field(
        default_factory=lambda: [
            "supertrend_12_3.0",
            "supertrend_10_1.0",
            "supertrend_11_2.0",
            "supertrend_trend_12_3.0",
            "supertrend_trend_10_1.0",
            "supertrend_trend_11_2.0",
            "hour",
        ]
    )


@dataclass
class OutputConfig:
    """Configuration for outputs and artifacts."""
    charts_dir: str = "charts"
    results_csv: str = "../data/cumulative_feature_results.csv"
    model_output_path: str = "../data/xgboost_model_all_features.json"


@dataclass
class RunConfig:
    """Top-level configuration grouping for a pipeline run."""
    data: DataConfig
    features: FeatureConfig = field(default_factory=FeatureConfig)
    preprocess: PreprocessConfig = field(default_factory=PreprocessConfig)
    pruning: PruningConfig = field(default_factory=PruningConfig)
    output: OutputConfig = field(default_factory=OutputConfig)
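Because every sub-config carries defaults, callers only need to supply DataConfig; overriding individual knobs is a one-liner per group. A sketch with illustrative values:

    # Override only what differs from the defaults above.
    from ohlcvpredictor.config import RunConfig, DataConfig, PruningConfig

    cfg = RunConfig(
        data=DataConfig(csv_path="../data/btcusd_1-min_data.csv", min_date="2019-01-01"),
        pruning=PruningConfig(top_k=100, do_walk_forward_cv=False),
    )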
38
ohlcvpredictor/data.py
Normal file
@@ -0,0 +1,38 @@
from typing import Tuple
import os
import pandas as pd
import numpy as np

from .config import DataConfig


def load_and_filter_data(cfg: DataConfig) -> pd.DataFrame:
    """Load CSV, filter, and convert timestamp.

    - Reads the CSV at cfg.csv_path
    - Drops rows with Volume == 0 if configured
    - Converts 'Timestamp' from seconds to datetime and filters by cfg.min_date
    - Adds 'log_return' target column
    """
    if not os.path.exists(cfg.csv_path):
        raise FileNotFoundError(f"CSV not found: {cfg.csv_path}")

    df = pd.read_csv(cfg.csv_path)
    if cfg.drop_volume_zero and 'Volume' in df.columns:
        df = df[df['Volume'] != 0]

    if 'Timestamp' not in df.columns:
        raise ValueError("Expected 'Timestamp' column in input CSV")

    df['Timestamp'] = pd.to_datetime(df['Timestamp'], unit='s')
    df = df[df['Timestamp'] >= cfg.min_date]
    if cfg.max_date:
        df = df[df['Timestamp'] <= cfg.max_date]

    if 'Close' not in df.columns:
        raise ValueError("Expected 'Close' column in input CSV")

    df['log_return'] = np.log(df['Close'] / df['Close'].shift(1))
    return df
26
ohlcvpredictor/metrics.py
Normal file
@@ -0,0 +1,26 @@
from typing import Dict, Tuple
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score


def compute_price_series_from_log_returns(start_price: float, log_returns: np.ndarray) -> np.ndarray:
    """Reconstruct price series from log returns starting at start_price."""
    prices = [start_price]
    for r in log_returns:
        prices.append(prices[-1] * float(np.exp(r)))
    return np.asarray(prices[1:])


def compute_metrics_from_prices(actual_prices: np.ndarray, predicted_prices: np.ndarray) -> Dict[str, float]:
    """Compute RMSE, MAPE, R2, and directional accuracy given price series."""
    rmse = float(np.sqrt(mean_squared_error(actual_prices, predicted_prices)))
    with np.errstate(divide='ignore', invalid='ignore'):
        mape_arr = np.abs((actual_prices - predicted_prices) / np.where(actual_prices == 0, np.nan, actual_prices))
        mape = float(np.nanmean(mape_arr) * 100.0)
    r2 = float(r2_score(actual_prices, predicted_prices))
    direction_actual = np.sign(np.diff(actual_prices))
    direction_pred = np.sign(np.diff(predicted_prices))
    dir_acc = float((direction_actual == direction_pred).mean()) if len(direction_actual) > 0 else 0.0
    return {"rmse": rmse, "mape": mape, "r2": r2, "directional_accuracy": dir_acc}
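A tiny worked example of the price reconstruction (values computed by hand, not taken from the repo): starting at 100 with log returns [0.01, -0.01], the series is 100 * e^0.01, about 101.005, and then back to 100.0 exactly, since the two returns cancel.

    import numpy as np
    from ohlcvpredictor.metrics import compute_price_series_from_log_returns

    prices = compute_price_series_from_log_returns(100.0, np.array([0.01, -0.01]))
    # prices -> array([101.00501671, 100.0]) up to floating-point error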
28
ohlcvpredictor/model.py
Normal file
@@ -0,0 +1,28 @@
from typing import Dict, List, Tuple
import numpy as np

from custom_xgboost import CustomXGBoostGPU


def train_model(
    X_train: np.ndarray,
    X_test: np.ndarray,
    y_train: np.ndarray,
    y_test: np.ndarray,
    eval_metric: str = 'rmse',
):
    """Train the XGBoost model and return the fitted wrapper."""
    model = CustomXGBoostGPU(X_train, X_test, y_train, y_test)
    model.train(eval_metric=eval_metric)
    return model


def predict(model: CustomXGBoostGPU, X: np.ndarray) -> np.ndarray:
    """Predict using the trained model."""
    return model.predict(X)


def get_feature_importance(model: CustomXGBoostGPU, feature_names: List[str]) -> Dict[str, float]:
    return model.get_feature_importance(feature_names)
125
ohlcvpredictor/pipeline.py
Normal file
@@ -0,0 +1,125 @@
from __future__ import annotations

from typing import Dict, List, Tuple
import os
import csv
import json
import numpy as np
import pandas as pd

from .config import RunConfig
from .data import load_and_filter_data
from .preprocess import add_basic_time_features, downcast_numeric_columns, handle_nans
from .selection import build_feature_list, prune_features
from .model import train_model, predict, get_feature_importance
from .metrics import compute_price_series_from_log_returns, compute_metrics_from_prices
from evaluation import walk_forward_cv
from feature_engineering import feature_engineering
from plot_results import plot_prediction_error_distribution


def ensure_charts_dir(path: str) -> None:
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)


def run_pipeline(cfg: RunConfig) -> Dict[str, float]:
    # Setup outputs
    ensure_charts_dir(cfg.output.charts_dir)

    # Load and target
    df = load_and_filter_data(cfg.data)

    # Features
    features_dict = feature_engineering(
        df,
        os.path.splitext(os.path.basename(cfg.data.csv_path))[0],
        cfg.features.ohlcv_cols,
        cfg.features.lags,
        cfg.features.window_sizes,
    )
    features_df = pd.DataFrame(features_dict)
    df = pd.concat([df, features_df], axis=1)

    # Preprocess
    df = downcast_numeric_columns(df)
    df = add_basic_time_features(df)
    df = handle_nans(df, cfg.preprocess)

    # Feature selection and pruning
    feature_cols = build_feature_list(df.columns)

    X = df[feature_cols].values.astype(np.float32)
    y = df["log_return"].values.astype(np.float32)
    split_idx = int(len(X) * 0.8)
    X_train, X_test = X[:split_idx], X[split_idx:]
    y_train, y_test = y[:split_idx], y[split_idx:]

    importance_avg = None
    if cfg.pruning.do_walk_forward_cv:
        metrics_avg, importance_avg = walk_forward_cv(X, y, feature_cols, n_splits=cfg.pruning.n_splits)
        # Optional: you may log or return metrics_avg

    kept_feature_cols = prune_features(feature_cols, importance_avg, cfg.pruning) if cfg.pruning.auto_prune else feature_cols

    # Train model
    model = train_model(
        df[kept_feature_cols].values.astype(np.float32)[:split_idx],
        df[kept_feature_cols].values.astype(np.float32)[split_idx:],
        y[:split_idx],
        y[split_idx:],
        eval_metric='rmse',
    )

    # Save model
    model.save_model(cfg.output.model_output_path)

    # Persist the exact feature list used for training next to the model
    try:
        features_path = os.path.splitext(cfg.output.model_output_path)[0] + "_features.json"
        with open(features_path, "w") as f:
            json.dump({"feature_names": kept_feature_cols}, f)
    except Exception:
        # Feature list persistence is optional; avoid breaking the run on failure
        pass

    # Predict
    X_test_kept = df[kept_feature_cols].values.astype(np.float32)[split_idx:]
    test_preds = predict(model, X_test_kept)

    # Reconstruct price series
    close_prices = df['Close'].values
    start_price = close_prices[split_idx]
    actual_prices = compute_price_series_from_log_returns(start_price, y_test)
    predicted_prices = compute_price_series_from_log_returns(start_price, test_preds)

    # Metrics
    metrics = compute_metrics_from_prices(actual_prices, predicted_prices)

    # Plot prediction error distribution to charts dir (parity with previous behavior)
    try:
        plot_prediction_error_distribution(predicted_prices, actual_prices, prefix="all_features")
    except Exception:
        # plotting is optional; ignore failures in headless environments
        pass

    # Persist per-feature metrics and importances
    feat_importance = get_feature_importance(model, kept_feature_cols)
    if not os.path.exists(cfg.output.results_csv):
        with open(cfg.output.results_csv, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['feature', 'rmse', 'mape', 'r2', 'directional_accuracy', 'feature_importance'])
    with open(cfg.output.results_csv, 'a', newline='') as f:
        writer = csv.writer(f)
        for feature in kept_feature_cols:
            importance = feat_importance.get(feature, 0.0)
            row = [feature]
            for key in ['rmse', 'mape', 'r2', 'directional_accuracy']:
                val = metrics[key]
                row.append(f"{val:.10f}")
            row.append(f"{importance:.6f}")
            writer.writerow(row)

    return metrics
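walk_forward_cv lives in the repo's evaluation module and is not shown in this compare view; from the call site it takes (X, y, feature_cols, n_splits=...) and returns averaged metrics plus averaged importances. An expanding-window splitter consistent with that contract could look like the sketch below (hypothetical; the real implementation may differ):

    # Hypothetical expanding-window splits matching the call site above.
    import numpy as np

    def walk_forward_splits(n_rows: int, n_splits: int = 5):
        """Yield (train_idx, test_idx) pairs with train always preceding test."""
        fold = n_rows // (n_splits + 1)
        for i in range(1, n_splits + 1):
            train_end = fold * i
            test_end = min(fold * (i + 1), n_rows)
            yield np.arange(0, train_end), np.arange(train_end, test_end)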
39
ohlcvpredictor/preprocess.py
Normal file
@@ -0,0 +1,39 @@
|
from typing import List
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .config import PreprocessConfig
|
||||||
|
|
||||||
|
|
||||||
|
def add_basic_time_features(df: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
"""Add basic time features such as hour-of-day."""
|
||||||
|
df = df.copy()
|
||||||
|
df['Timestamp'] = pd.to_datetime(df['Timestamp'], errors='coerce')
|
||||||
|
df['hour'] = df['Timestamp'].dt.hour
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def downcast_numeric_columns(df: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
"""Downcast numeric columns to save memory."""
|
||||||
|
df = df.copy()
|
||||||
|
for col in df.columns:
|
||||||
|
try:
|
||||||
|
df[col] = pd.to_numeric(df[col], downcast='float')
|
||||||
|
except Exception:
|
||||||
|
# ignore non-numeric columns
|
||||||
|
pass
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def handle_nans(df: pd.DataFrame, cfg: PreprocessConfig) -> pd.DataFrame:
|
||||||
|
"""Impute NaNs (mean) or drop rows, based on config."""
|
||||||
|
df = df.copy()
|
||||||
|
if cfg.impute_nans:
|
||||||
|
numeric_cols = df.select_dtypes(include=[np.number]).columns
|
||||||
|
for col in numeric_cols:
|
||||||
|
df[col] = df[col].fillna(df[col].mean())
|
||||||
|
else:
|
||||||
|
df = df.dropna().reset_index(drop=True)
|
||||||
|
return df
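Taken together, the three helpers form the whole preprocessing pass. A hedged usage sketch: the PreprocessConfig(impute_nans=True) constructor call is assumed from the single field accessed above, and the CSV path is a placeholder, not a file shipped with the repo:

import pandas as pd
from ohlcvpredictor.config import PreprocessConfig
from ohlcvpredictor.preprocess import add_basic_time_features, downcast_numeric_columns, handle_nans

df = pd.read_csv("ohlcv.csv")      # placeholder path; needs a Timestamp column
df = add_basic_time_features(df)   # parse Timestamp, add 'hour'
df = downcast_numeric_columns(df)  # shrink numerics to float32 where possible
df = handle_nans(df, PreprocessConfig(impute_nans=True))  # mean-impute numeric NaNs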
59
ohlcvpredictor/selection.py
Normal file
@@ -0,0 +1,59 @@
from typing import Dict, Iterable, List, Sequence, Set, Tuple
import numpy as np

from .config import PruningConfig


EXCLUDE_BASE_FEATURES: List[str] = [
    'Timestamp', 'Close',
    'log_return_5', 'volatility_5', 'volatility_15', 'volatility_30',
    'bb_bbm', 'bb_bbh', 'bb_bbl', 'stoch_k', 'sma_50', 'sma_200', 'psar',
    'donchian_hband', 'donchian_lband', 'donchian_mband', 'keltner_hband', 'keltner_lband',
    'keltner_mband', 'ichimoku_a', 'ichimoku_b', 'ichimoku_base_line', 'ichimoku_conversion_line',
    'Open_lag1', 'Open_lag2', 'Open_lag3', 'High_lag1', 'High_lag2', 'High_lag3', 'Low_lag1', 'Low_lag2',
    'Low_lag3', 'Close_lag1', 'Close_lag2', 'Close_lag3', 'Open_roll_mean_15', 'Open_roll_std_15', 'Open_roll_min_15',
    'Open_roll_max_15', 'Open_roll_mean_30', 'Open_roll_min_30', 'Open_roll_max_30', 'High_roll_mean_15', 'High_roll_std_15',
    'High_roll_min_15', 'High_roll_max_15', 'Low_roll_mean_5', 'Low_roll_min_5', 'Low_roll_max_5', 'Low_roll_mean_30',
    'Low_roll_std_30', 'Low_roll_min_30', 'Low_roll_max_30', 'Close_roll_mean_5', 'Close_roll_min_5', 'Close_roll_max_5',
    'Close_roll_mean_15', 'Close_roll_std_15', 'Close_roll_min_15', 'Close_roll_max_15', 'Close_roll_mean_30',
    'Close_roll_std_30', 'Close_roll_min_30', 'Close_roll_max_30', 'Volume_roll_max_5', 'Volume_roll_max_15',
    'Volume_roll_max_30', 'supertrend_12_3.0', 'supertrend_10_1.0', 'supertrend_11_2.0',
]


def build_feature_list(all_columns: Sequence[str]) -> List[str]:
    """Return the model feature list by excluding base columns and targets."""
    return [col for col in all_columns if col not in EXCLUDE_BASE_FEATURES]


def prune_features(
    feature_cols: Sequence[str],
    importance_avg: Dict[str, float] | None,
    cfg: PruningConfig,
) -> List[str]:
    """Decide which features to keep using averaged importances and rules."""
    prune_set: Set[str] = set()

    if importance_avg is not None:
        sorted_feats = sorted(importance_avg.items(), key=lambda kv: kv[1], reverse=True)
        keep_names = set(name for name, _ in sorted_feats[: cfg.top_k])
        for name in feature_cols:
            if name not in keep_names:
                prune_set.add(name)

    for name in cfg.known_low_features:
        if name in feature_cols:
            prune_set.add(name)

    # If Parkinson vol exists, drop alternatives at same window
    for w in [5, 15, 30]:
        park = f'park_vol_{w}'
        if park in feature_cols:
            for alt in [f'gk_vol_{w}', f'rs_vol_{w}', f'yz_vol_{w}']:
                if alt in feature_cols:
                    prune_set.add(alt)

    kept = [c for c in feature_cols if c not in prune_set]
    return kept
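A hedged sketch of how the two functions compose. The PruningConfig constructor call and the importance values are illustrative; only the top_k, known_low_features, and auto_prune fields are attested by the code in this diff:

from ohlcvpredictor.config import PruningConfig
from ohlcvpredictor.selection import build_feature_list, prune_features

feature_cols = build_feature_list(df.columns)
importance_avg = {name: 1.0 for name in feature_cols}  # e.g. gain averaged over folds
cfg = PruningConfig(top_k=50, known_low_features=[])    # hypothetical values
kept = prune_features(feature_cols, importance_avg, cfg)

Passing importance_avg=None skips the top-k cut but still applies known_low_features and the Parkinson-volatility preference rule.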
120
predictor.py
Normal file
@@ -0,0 +1,120 @@
import pandas as pd
import numpy as np
import os
import json

try:
    from .custom_xgboost import CustomXGBoostGPU
except ImportError:
    from custom_xgboost import CustomXGBoostGPU

try:
    from .feature_engineering import feature_engineering
except ImportError:
    from feature_engineering import feature_engineering

class OHLCVPredictor:
    def __init__(self, model_path):
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        self.model = CustomXGBoostGPU.load_model(model_path)
        self.exclude_cols = self._get_excluded_features()
        self._feature_names = self._load_trained_feature_names(model_path)

    def _load_trained_feature_names(self, model_path: str):
        """Load the exact feature list saved during training, if present."""
        try:
            features_path = os.path.splitext(model_path)[0] + "_features.json"
            if os.path.exists(features_path):
                with open(features_path, "r") as f:
                    data = json.load(f)
                names = data.get("feature_names")
                if isinstance(names, list) and all(isinstance(x, str) for x in names):
                    return names
        except Exception:
            pass
        return None

    def _get_excluded_features(self):
        """Get the list of features to exclude (copied from main.py)"""
        exclude_cols = ['Timestamp', 'Close']
        exclude_cols += ['log_return_5', 'volatility_5', 'volatility_15', 'volatility_30']
        exclude_cols += ['bb_bbm', 'bb_bbh', 'bb_bbl', 'stoch_k', 'sma_50', 'sma_200', 'psar',
                         'donchian_hband', 'donchian_lband', 'donchian_mband', 'keltner_hband', 'keltner_lband',
                         'keltner_mband', 'ichimoku_a', 'ichimoku_b', 'ichimoku_base_line', 'ichimoku_conversion_line',
                         'Open_lag1', 'Open_lag2', 'Open_lag3', 'High_lag1', 'High_lag2', 'High_lag3', 'Low_lag1', 'Low_lag2',
                         'Low_lag3', 'Close_lag1', 'Close_lag2', 'Close_lag3', 'Open_roll_mean_15', 'Open_roll_std_15', 'Open_roll_min_15',
                         'Open_roll_max_15', 'Open_roll_mean_30', 'Open_roll_min_30', 'Open_roll_max_30', 'High_roll_mean_15', 'High_roll_std_15',
                         'High_roll_min_15', 'High_roll_max_15', 'Low_roll_mean_5', 'Low_roll_min_5', 'Low_roll_max_5', 'Low_roll_mean_30',
                         'Low_roll_std_30', 'Low_roll_min_30', 'Low_roll_max_30', 'Close_roll_mean_5', 'Close_roll_min_5', 'Close_roll_max_5',
                         'Close_roll_mean_15', 'Close_roll_std_15', 'Close_roll_min_15', 'Close_roll_max_15', 'Close_roll_mean_30',
                         'Close_roll_std_30', 'Close_roll_min_30', 'Close_roll_max_30', 'Volume_roll_max_5', 'Volume_roll_max_15',
                         'Volume_roll_max_30', 'supertrend_12_3.0', 'supertrend_10_1.0', 'supertrend_11_2.0']
        return exclude_cols

    def predict(self, df, csv_prefix=None):
        # Validate input DataFrame
        required_cols = ['Open', 'High', 'Low', 'Close', 'Volume', 'Timestamp']
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Missing required columns: {missing_cols}")

        # Make a copy and preprocess
        df = df.copy()
        df = df[df['Volume'] != 0].reset_index(drop=True)

        # Convert timestamps
        if df['Timestamp'].dtype == 'object':
            df['Timestamp'] = pd.to_datetime(df['Timestamp'])
        else:
            df['Timestamp'] = pd.to_datetime(df['Timestamp'], unit='s')

        # Feature engineering
        ohlcv_cols = ['Open', 'High', 'Low', 'Close', 'Volume']
        features_dict = feature_engineering(df, csv_prefix, ohlcv_cols, 3, [5, 15, 30])
        features_df = pd.DataFrame(features_dict)
        df = pd.concat([df, features_df], axis=1)

        # Downcast and add time features (exclude Timestamp to preserve datetime)
        for col in df.columns:
            if col != 'Timestamp':  # Don't convert Timestamp to numeric
                try:
                    df[col] = pd.to_numeric(df[col], downcast='float')
                except Exception:
                    pass

        df['hour'] = df['Timestamp'].dt.hour

        # Handle NaNs
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        for col in numeric_cols:
            if df[col].isna().any():
                df[col] = df[col].fillna(df[col].mean())

        # Defragment DataFrame after all columns have been added
        df = df.copy()

        # Select features and predict
        if self._feature_names is not None:
            # Use the exact training feature names and order
            missing = [c for c in self._feature_names if c not in df.columns]
            if missing:
                raise ValueError(f"Input is missing required trained features: {missing[:10]}{'...' if len(missing)>10 else ''}")
            feature_cols = self._feature_names
        else:
            feature_cols = [col for col in df.columns if col not in self.exclude_cols]
        X = df[feature_cols].values.astype(np.float32)
        return self.model.predict(X)

    def predict_prices(self, df, csv_prefix=None):
        log_return_preds = self.predict(df, csv_prefix)
        df_clean = df[df['Volume'] != 0].copy()
        close_prices = df_clean['Close'].values

        predicted_prices = [close_prices[0]]
        for i, log_ret in enumerate(log_return_preds[1:], 1):
            if i < len(close_prices):
                predicted_prices.append(predicted_prices[-1] * np.exp(log_ret))

        return np.array(predicted_prices), close_prices[:len(predicted_prices)]
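End-to-end usage is short; a sketch with placeholder file names (neither ships with this diff). When the *_features.json sidecar written at training time sits next to the model file, predict uses that exact feature list and order; otherwise it falls back to the exclusion list:

import pandas as pd
from predictor import OHLCVPredictor

df = pd.read_csv("ohlcv.csv")         # needs Open, High, Low, Close, Volume, Timestamp
model = OHLCVPredictor("model.json")  # placeholder model path

log_return_preds = model.predict(df)
predicted_prices, actual_prices = model.predict_prices(df)

Because predict_prices seeds the path with the first actual close and compounds predicted log returns from there, errors accumulate with horizon length; it mirrors the reconstruction used for the test-set metrics above.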
pyproject.toml
@@ -8,8 +8,15 @@ dependencies = [
     "dash>=3.0.4",
     "numba>=0.61.2",
     "pandas>=2.2.3",
-    "pandas-ta>=0.3.14b0",
     "scikit-learn>=1.6.1",
     "ta>=0.11.0",
     "xgboost>=3.0.2",
 ]
+
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools.packages.find]
+include = ["ohlcvpredictor*"]
+exclude = ["charts*"]
@@ -207,8 +207,9 @@ def calc_vortex(high, low, close):
 ]
 
 def calc_kama(close):
-    import pandas_ta as ta
-    kama = ta.kama(close, length=10)
+    # Simple alternative to KAMA using EMA
+    from ta.trend import EMAIndicator
+    kama = EMAIndicator(close, window=10).ema_indicator()
     return ('kama', kama)
 
 def calc_force_index(close, volume):
@@ -232,8 +233,12 @@ def calc_adi(high, low, close, volume):
     return ('adi', adi.acc_dist_index())
 
 def calc_tema(close):
-    import pandas_ta as ta
-    tema = ta.tema(close, length=10)
+    # Simple alternative to TEMA using triple EMA
+    from ta.trend import EMAIndicator
+    ema1 = EMAIndicator(close, window=10).ema_indicator()
+    ema2 = EMAIndicator(ema1, window=10).ema_indicator()
+    ema3 = EMAIndicator(ema2, window=10).ema_indicator()
+    tema = 3 * ema1 - 3 * ema2 + ema3
     return ('tema', tema)
 
 def calc_stochrsi(close):
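Two caveats on this hunk. The TEMA replacement is the textbook definition (TEMA = 3*EMA1 - 3*EMA2 + EMA3 over successively smoothed series), so its values should track a true TEMA. The KAMA replacement is a plain EMA, which changes the feature's semantics: Kaufman's adaptive MA scales its smoothing constant by an efficiency ratio. If exact KAMA values matter, a standard-formula sketch (not code from this repo) could restore them:

import numpy as np
import pandas as pd

def kaufman_kama(close: pd.Series, n: int = 10, fast: int = 2, slow: int = 30) -> pd.Series:
    """Kaufman's Adaptive Moving Average, standard formulation."""
    change = close.diff(n).abs()
    volatility = close.diff().abs().rolling(n).sum()
    er = (change / volatility).replace([np.inf, -np.inf], 0.0).fillna(0.0)  # efficiency ratio in [0, 1]
    fast_sc, slow_sc = 2.0 / (fast + 1), 2.0 / (slow + 1)
    sc = (er * (fast_sc - slow_sc) + slow_sc) ** 2  # adaptive smoothing constant
    kama = np.full(len(close), np.nan)
    kama[n - 1] = close.iloc[n - 1]  # seed with the first usable price
    for i in range(n, len(close)):
        kama[i] = kama[i - 1] + sc.iloc[i] * (close.iloc[i] - kama[i - 1])
    return pd.Series(kama, index=close.index, name="kama")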
15
uv.lock
generated
@@ -1,5 +1,5 @@
 version = 1
-revision = 2
+revision = 3
 requires-python = ">=3.12"
 
 [[package]]
@@ -309,12 +309,11 @@ wheels = [
 [[package]]
 name = "ohlcvpredictor"
 version = "0.1.0"
-source = { virtual = "." }
+source = { editable = "." }
 dependencies = [
     { name = "dash" },
     { name = "numba" },
     { name = "pandas" },
-    { name = "pandas-ta" },
     { name = "scikit-learn" },
     { name = "ta" },
     { name = "xgboost" },
@@ -325,7 +324,6 @@ requires-dist = [
     { name = "dash", specifier = ">=3.0.4" },
     { name = "numba", specifier = ">=0.61.2" },
     { name = "pandas", specifier = ">=2.2.3" },
-    { name = "pandas-ta", specifier = ">=0.3.14b0" },
     { name = "scikit-learn", specifier = ">=1.6.1" },
     { name = "ta", specifier = ">=0.11.0" },
     { name = "xgboost", specifier = ">=3.0.2" },
@@ -374,15 +372,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/ab/5f/b38085618b950b79d2d9164a711c52b10aefc0ae6833b96f626b7021b2ed/pandas-2.2.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ad5b65698ab28ed8d7f18790a0dc58005c7629f227be9ecc1072aa74c0c1d43a", size = 13098436, upload-time = "2024-09-20T13:09:48.112Z" },
 ]
 
-[[package]]
-name = "pandas-ta"
-version = "0.3.14b0"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
-    { name = "pandas" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/f7/0b/1666f0a185d4f08215f53cc088122a73c92421447b04028f0464fabe1ce6/pandas_ta-0.3.14b.tar.gz", hash = "sha256:0fa35aec831d2815ea30b871688a8d20a76b288a7be2d26cc00c35cd8c09a993", size = 115089, upload-time = "2021-07-28T20:51:17.456Z" }
-
 [[package]]
 name = "plotly"
 version = "6.1.2"