|
| 1 | +# TrendMaster API Reference |
| 2 | + |
| 3 | +This document provides detailed information about the classes and functions in the TrendMaster package. |
| 4 | + |
| 5 | +## Table of Contents |
| 6 | + |
| 7 | +1. [DataLoader](#dataloader) |
| 8 | +2. [TransAm (Model)](#transam-model) |
| 9 | +3. [Trainer](#trainer) |
| 10 | +4. [Inferencer](#inferencer) |
| 11 | +5. [Utility Functions](#utility-functions) |
| 12 | + |
| 13 | +## DataLoader |
| 14 | + |
| 15 | +The `DataLoader` class is responsible for loading and preprocessing stock data. |
| 16 | + |
| 17 | +### Methods |
| 18 | + |
| 19 | +#### `__init__(self)` |
| 20 | +Initializes the DataLoader object. |
| 21 | + |
| 22 | +#### `authenticate(self, user_id=None, password=None, twofa=None)` |
| 23 | +Authenticates with the Zerodha API. |
| 24 | + |
| 25 | +- Parameters: |
| 26 | + - `user_id` (str, optional): Zerodha user ID |
| 27 | + - `password` (str, optional): Zerodha password |
| 28 | + - `twofa` (str, optional): Zerodha two-factor authentication code |
| 29 | +- Returns: Authenticated Zerodha kite instance |
| 30 | + |
| 31 | +#### `get_stock_data(self, symbol, from_date, to_date, interval='minute')` |
| 32 | +Fetches stock data for a given symbol. |
| 33 | + |
| 34 | +- Parameters: |
| 35 | + - `symbol` (str): Stock symbol |
| 36 | + - `from_date` (str): Start date in 'YYYY-MM-DD' format |
| 37 | + - `to_date` (str): End date in 'YYYY-MM-DD' format |
| 38 | + - `interval` (str, optional): Data interval (default: 'minute') |
| 39 | +- Returns: pandas DataFrame containing stock data |
| 40 | + |
| 41 | +#### `preprocess_data(self, data, column='close')` |
| 42 | +Preprocesses the data for model input. |
| 43 | + |
| 44 | +- Parameters: |
| 45 | + - `data` (pandas.DataFrame): Stock data |
| 46 | + - `column` (str, optional): Column to preprocess (default: 'close') |
| 47 | +- Returns: Preprocessed numpy array |
| 48 | + |
| 49 | +#### `create_sequences(self, data, input_window, output_window)` |
| 50 | +Creates input-output sequences for training. |
| 51 | + |
| 52 | +- Parameters: |
| 53 | + - `data` (numpy.array): Preprocessed data |
| 54 | + - `input_window` (int): Number of input time steps |
| 55 | + - `output_window` (int): Number of output time steps |
| 56 | +- Returns: List of (input_sequence, target_sequence) tuples |
| 57 | + |
| 58 | +## TransAm (Model) |
| 59 | + |
| 60 | +The `TransAm` class implements the Transformer model for stock price prediction. |
| 61 | + |
| 62 | +### Methods |
| 63 | + |
| 64 | +#### `__init__(self, feature_size=30, num_layers=2, dropout=0.2)` |
| 65 | +Initializes the TransAm model. |
| 66 | + |
| 67 | +- Parameters: |
| 68 | + - `feature_size` (int, optional): Size of input features (default: 30) |
| 69 | + - `num_layers` (int, optional): Number of transformer layers (default: 2) |
| 70 | + - `dropout` (float, optional): Dropout rate (default: 0.2) |
| 71 | + |
| 72 | +#### `forward(self, src)` |
| 73 | +Performs a forward pass through the model. |
| 74 | + |
| 75 | +- Parameters: |
| 76 | + - `src` (torch.Tensor): Input tensor |
| 77 | +- Returns: Output tensor |
| 78 | + |
| 79 | +## Trainer |
| 80 | + |
| 81 | +The `Trainer` class handles model training and validation. |
| 82 | + |
| 83 | +### Methods |
| 84 | + |
| 85 | +#### `__init__(self, model, device, learning_rate=0.001)` |
| 86 | +Initializes the Trainer object. |
| 87 | + |
| 88 | +- Parameters: |
| 89 | + - `model` (TransAm): The model to train |
| 90 | + - `device` (torch.device): Device to use for training |
| 91 | + - `learning_rate` (float, optional): Learning rate (default: 0.001) |
| 92 | + |
| 93 | +#### `train(self, train_data, val_data, epochs, batch_size)` |
| 94 | +Trains the model. |
| 95 | + |
| 96 | +- Parameters: |
| 97 | + - `train_data` (list): Training data sequences |
| 98 | + - `val_data` (list): Validation data sequences |
| 99 | + - `epochs` (int): Number of training epochs |
| 100 | + - `batch_size` (int): Batch size for training |
| 101 | +- Returns: Lists of training and validation losses |
| 102 | + |
| 103 | +#### `validate(self, val_data, batch_size)` |
| 104 | +Validates the model on the provided data. |
| 105 | + |
| 106 | +- Parameters: |
| 107 | + - `val_data` (list): Validation data sequences |
| 108 | + - `batch_size` (int): Batch size for validation |
| 109 | +- Returns: Validation loss |
| 110 | + |
| 111 | +## Inferencer |
| 112 | + |
| 113 | +The `Inferencer` class handles model inference and evaluation. |
| 114 | + |
| 115 | +### Methods |
| 116 | + |
| 117 | +#### `__init__(self, model, device, data_loader)` |
| 118 | +Initializes the Inferencer object. |
| 119 | + |
| 120 | +- Parameters: |
| 121 | + - `model` (TransAm): Trained model |
| 122 | + - `device` (torch.device): Device to use for inference |
| 123 | + - `data_loader` (DataLoader): DataLoader instance |
| 124 | + |
| 125 | +#### `predict(self, symbol, from_date, to_date, input_window, future_steps)` |
| 126 | +Makes predictions for future stock prices. |
| 127 | + |
| 128 | +- Parameters: |
| 129 | + - `symbol` (str): Stock symbol |
| 130 | + - `from_date` (str): Start date for historical data |
| 131 | + - `to_date` (str): End date for historical data |
| 132 | + - `input_window` (int): Number of input time steps |
| 133 | + - `future_steps` (int): Number of future time steps to predict |
| 134 | +- Returns: DataFrame with predicted prices |
| 135 | + |
| 136 | +#### `evaluate(self, test_data, batch_size)` |
| 137 | +Evaluates the model on test data. |
| 138 | + |
| 139 | +- Parameters: |
| 140 | + - `test_data` (list): Test data sequences |
| 141 | + - `batch_size` (int): Batch size for evaluation |
| 142 | +- Returns: Test loss |
| 143 | + |
| 144 | +## Utility Functions |
| 145 | + |
| 146 | +### `set_seed(seed)` |
| 147 | +Sets random seed for reproducibility. |
| 148 | + |
| 149 | +- Parameters: |
| 150 | + - `seed` (int): Random seed |
| 151 | + |
| 152 | +### `plot_results(train_losses, val_losses)` |
| 153 | +Plots training and validation losses. |
| 154 | + |
| 155 | +- Parameters: |
| 156 | + - `train_losses` (list): Training losses |
| 157 | + - `val_losses` (list): Validation losses |
| 158 | + |
| 159 | +### `plot_predictions(actual, predictions)` |
| 160 | +Plots actual vs predicted stock prices. |
| 161 | + |
| 162 | +- Parameters: |
| 163 | + - `actual` (pandas.Series): Actual stock prices |
| 164 | + - `predictions` (pandas.DataFrame): Predicted stock prices |
0 commit comments