v-search/etc/test-data/tfjs-training.txt
2025-10-06 21:43:08 +09:00

Training a model in TensorFlow.js involves the same basic steps as training one with the Keras API in Python TensorFlow. Here's a step-by-step guide:
### 1. **Set Up Your Environment**
- Include TensorFlow.js in your project. In a web browser, you can load it from a CDN; the script defines a global `tf` object:
```html
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
```
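- If you are using Node.js, install TensorFlow.js using npm. The `@tensorflow/tfjs-node` package, installed the same way, adds bindings to the native TensorFlow library and usually trains considerably faster than the pure-JavaScript package: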
```bash
npm install @tensorflow/tfjs
```
### 2. **Prepare Your Data**
- Data in TensorFlow.js is represented as `tf.Tensor` objects. You can create tensors manually, load data from files, or use existing datasets.
Example of creating tensors:
```javascript
const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);
```
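- Alternatively, you can stream data from an external source such as a CSV file. `tf.data.csv` returns a `tf.data.Dataset` that reads rows lazily as they are consumed: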
```javascript
const data = tf.data.csv('path/to/your/csvfile.csv');
```
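`tf.tensor2d(values, shape)` fills a flat array into the given shape in row-major order. So `[1, 2, 3, 4]` with shape `[4, 1]` becomes four examples with one feature each. Note also that the toy labels follow y = 2x - 1. The hypothetical `reshape2d` helper below is a plain-JavaScript sketch of that layout and needs no TensorFlow.js. Like `tf.tensor2d`, it rejects a shape whose size does not match the number of values.

```javascript
// Hypothetical helper: lay out a flat array as nested rows (row-major),
// mirroring how tf.tensor2d(values, [rows, cols]) interprets its input.
function reshape2d(values, [rows, cols]) {
  if (values.length !== rows * cols) {
    throw new Error(
      `Shape [${rows}, ${cols}] needs ${rows * cols} values, got ${values.length}`
    );
  }
  const result = [];
  for (let r = 0; r < rows; r++) {
    // Row r takes the next `cols` consecutive values.
    result.push(values.slice(r * cols, (r + 1) * cols));
  }
  return result;
}

const xsRows = reshape2d([1, 2, 3, 4], [4, 1]); // four rows, one feature each
const ysRows = reshape2d([1, 3, 5, 7], [4, 1]);
// Check that each label equals 2 * feature - 1.
console.log(xsRows.every(([x], i) => ysRows[i][0] === 2 * x - 1));
```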
### 3. **Define Your Model**
- Create a sequential model and add layers to it. The `tf.layers` API closely mirrors Keras:
```javascript
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
```
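- For harder problems you can stack more layers. Only the first layer needs `inputShape`, because later layers infer their input size from the previous layer:
```javascript
// An alternative, deeper model: a hidden ReLU layer followed by a one-unit output.
const model = tf.sequential();
model.add(tf.layers.dense({units: 10, activation: 'relu', inputShape: [1]}));
model.add(tf.layers.dense({units: 1}));
```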
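A `tf.layers.dense` layer computes `activation(input · kernel + bias)`. With `units: 1` and `inputShape: [1]`, that reduces to y = w * x + b. A `'relu'` activation replaces negative outputs with 0.

The hypothetical `denseForward` function below is a plain-JavaScript sketch of that forward pass for a single example. The weight values are made up for illustration; a real layer initializes its weights randomly and learns them during training.

```javascript
// Hypothetical sketch of a dense layer's forward pass for one input vector.
// kernel has shape [inputDim][units]; bias has length `units`.
const activations = {
  linear: (z) => z, // the dense layer's default (no activation)
  relu: (z) => Math.max(0, z),
};

function denseForward(input, kernel, bias, activation = "linear") {
  const units = bias.length;
  const output = new Array(units).fill(0);
  for (let u = 0; u < units; u++) {
    let z = bias[u];
    for (let i = 0; i < input.length; i++) {
      z += input[i] * kernel[i][u]; // dot product of the input with column u
    }
    output[u] = activations[activation](z);
  }
  return output;
}

// One unit, one input, illustrative weights w = 2, b = -1: y = 2 * 5 - 1 = 9.
console.log(denseForward([5], [[2]], [-1]));

// Two ReLU units: the second pre-activation is 1 - 2 = -1, so it is clamped to 0.
console.log(denseForward([1], [[3, -2]], [0, 1], "relu"));
```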
### 4. **Compile the Model**
- After defining the model, you need to compile it by specifying the optimizer, loss function, and optionally, metrics:
```javascript
model.compile({
  optimizer: 'sgd',
  loss: 'meanSquaredError',
  metrics: ['mse']  // reported alongside the loss; here it duplicates the loss
});
```
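The `'meanSquaredError'` loss is the average of the squared differences between predictions and targets. The function below is a plain-JavaScript sketch of that formula for illustration only. It is not the TensorFlow.js implementation, which operates on tensors.

```javascript
// Sketch of mean squared error over paired predictions and targets.
function meanSquaredError(predictions, targets) {
  if (predictions.length !== targets.length || predictions.length === 0) {
    throw new Error("predictions and targets must be non-empty and equal length");
  }
  let sum = 0;
  for (let i = 0; i < predictions.length; i++) {
    const diff = predictions[i] - targets[i];
    sum += diff * diff;
  }
  return sum / predictions.length;
}

// A model that predicts y = x on the toy data misses by 0, 1, 2 and 3:
// (0 + 1 + 4 + 9) / 4 = 3.5
console.log(meanSquaredError([1, 2, 3, 4], [1, 3, 5, 7]));
```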
### 5. **Train the Model**
- Now, train the model with the `fit` method, which mirrors Keras's `model.fit` in Python. It runs asynchronously and returns a Promise that resolves once training finishes:
```javascript
model.fit(xs, ys, {
  epochs: 100,
  callbacks: {
    onEpochEnd: (epoch, logs) => {
      console.log(`Epoch: ${epoch}, Loss: ${logs.loss}`);
    }
  }
}).then(() => {
  console.log('Training complete');
});
```
- Here, `xs` is your input data (features) and `ys` is the target data (labels). The `epochs` parameter sets how many full passes `fit` makes over the training data.
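Conceptually, each training step with the `'sgd'` optimizer and mean squared error does two things. It computes the gradient of the loss with respect to each weight, then moves each weight a small step against that gradient. The toy dataset has only four examples, fewer than `fit`'s default batch size of 32, so each epoch is a single full-batch step.

The sketch below reproduces that loop in plain JavaScript for the one-unit model y = w * x + b. The learning rate of 0.05, the 1000 epochs and the zero initialization are illustrative assumptions, not TensorFlow.js defaults. To choose the learning rate in TensorFlow.js, pass an optimizer object such as `tf.train.sgd(0.05)` to `compile` instead of the string `'sgd'`.

```javascript
// Plain-JavaScript sketch of full-batch gradient descent on mean squared error
// for a one-input, one-unit linear model y = w * x + b.
function trainLinear(xs, ys, { learningRate = 0.05, epochs = 1000 } = {}) {
  let w = 0; // real layers start from random weights; zero keeps this deterministic
  let b = 0;
  const n = xs.length;
  for (let epoch = 0; epoch < epochs; epoch++) {
    let gradW = 0;
    let gradB = 0;
    for (let i = 0; i < n; i++) {
      const error = w * xs[i] + b - ys[i];
      // Partial derivatives of mean((w * x + b - y)^2) with respect to w and b.
      gradW += (2 / n) * error * xs[i];
      gradB += (2 / n) * error;
    }
    // Step both parameters against the gradient.
    w -= learningRate * gradW;
    b -= learningRate * gradB;
  }
  return { w, b };
}

const { w, b } = trainLinear([1, 2, 3, 4], [1, 3, 5, 7]);
console.log(w, b, w * 5 + b);
```

With these settings, `w` and `b` should end up very close to 2 and -1, the line the toy data lies on. A learning rate that is too large makes this loop diverge instead, which is one reason the learning rate usually needs tuning.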
### 6. **Evaluate the Model**
- After training, you can make predictions on new inputs with `predict`, which returns a tensor:
```javascript
const output = model.predict(tf.tensor2d([5], [1, 1]));
output.print(); // Display the prediction for input 5
```
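- To evaluate the model on held-out test data, use `evaluate`. Here `xs_test` and `ys_test` are tensors you prepare the same way as `xs` and `ys`. `evaluate` returns a scalar tensor holding the loss, or an array `[loss, ...metrics]` when the model was compiled with `metrics`:
```javascript
const result = model.evaluate(xs_test, ys_test);
const testLoss = Array.isArray(result) ? result[0] : result;
testLoss.print(); // Print the loss on the test data
```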
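`evaluate` is only informative on examples the model did not train on. One simple way to get such examples is to hold out part of the data before calling `fit`.

The hypothetical `splitTrainTest` helper below is a plain-JavaScript sketch that keeps the last `testFraction` of the examples for testing. It assumes the rows are already shuffled; otherwise, shuffle them first. Its outputs could then be turned into `xs_test` and `ys_test` tensors with `tf.tensor2d`.

```javascript
// Hypothetical helper: hold out the last `testFraction` of the examples.
// Assumes xs and ys are parallel arrays that are already shuffled.
function splitTrainTest(xs, ys, testFraction = 0.25) {
  if (xs.length !== ys.length) {
    throw new Error("xs and ys must have the same number of examples");
  }
  const testCount = Math.round(xs.length * testFraction);
  const trainCount = xs.length - testCount;
  return {
    xsTrain: xs.slice(0, trainCount),
    ysTrain: ys.slice(0, trainCount),
    xsTest: xs.slice(trainCount),
    ysTest: ys.slice(trainCount),
  };
}

// Eight examples on y = 2x - 1; a 25% split holds out the last two.
const split = splitTrainTest([1, 2, 3, 4, 5, 6, 7, 8], [1, 3, 5, 7, 9, 11, 13, 15]);
console.log(split.xsTest, split.ysTest);
```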
### 7. **Save or Load a Model**
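- After training, you might want to save your model. The `localstorage://` URL scheme stores it in the browser's localStorage. `await` must run inside an async function or at the top level of an ES module: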
```javascript
await model.save('localstorage://my-model');
```
- You can load a saved model later:
```javascript
const model = await tf.loadLayersModel('localstorage://my-model');
```
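`model.save` and `tf.loadLayersModel` decide where to write and read based on the URL scheme. Besides `localstorage://`, TensorFlow.js also handles the following schemes:
- `indexeddb://` stores the model in the browser's IndexedDB, which usually has more room than localStorage for larger models.
- `downloads://` saves the model as a browser file download.
- `file://` paths are added by `@tensorflow/tfjs-node` in Node.js. Saving writes a directory containing `model.json` and the weight files, and loading points at that `model.json`.

The hypothetical `modelStorageUrl` helper below is a sketch of picking a scheme per runtime. The `runtime` argument and the `./` directory are illustrative assumptions.

```javascript
// Hypothetical helper: build a save or load URL for a model name per runtime.
// "browser" -> IndexedDB; "node" -> a relative directory on disk
// (file:// URLs require @tensorflow/tfjs-node).
function modelStorageUrl(name, runtime, { forLoad = false } = {}) {
  if (!/^[\w-]+$/.test(name)) {
    throw new Error(`Unsupported model name: ${name}`);
  }
  switch (runtime) {
    case "browser":
      // The same IndexedDB URL works for both saving and loading.
      return `indexeddb://${name}`;
    case "node":
      // Saving targets the directory; loading targets its model.json.
      return forLoad ? `file://./${name}/model.json` : `file://./${name}`;
    default:
      throw new Error(`Unknown runtime: ${runtime}`);
  }
}

// Usage (inside an async function, with TensorFlow.js loaded):
//   await model.save(modelStorageUrl("my-model", "browser"));
//   const restored = await tf.loadLayersModel(modelStorageUrl("my-model", "browser"));
console.log(modelStorageUrl("my-model", "node", { forLoad: true }));
```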
### 8. **Deploy the Model**
- Once trained, you can deploy your model in a web application or a Node.js environment, making predictions in real-time.
### Example: Simple Linear Regression in TensorFlow.js
Here's a small, complete example for Node.js. In the browser, drop the `require` line and use the global `tf` from the script tag:
```javascript
const tf = require('@tensorflow/tfjs');
// Define a model
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
// Compile the model
model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
// Prepare the training data
const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);
// Train the model
model.fit(xs, ys, {epochs: 500}).then(() => {
  // Use the model to make predictions.
  // The data follow y = 2x - 1, so this should print a value close to 9;
  // the exact value depends on the random initial weights.
  model.predict(tf.tensor2d([5], [1, 1])).print();
});
```
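The four training points lie exactly on y = 2x - 1, so the best-fitting line through them has slope 2 and intercept -1. That is why the prediction for x = 5 should approach 9.

The sketch below checks this in plain JavaScript with ordinary least squares. This is a closed-form fit, not what TensorFlow.js does internally. A trained one-unit dense model should approach the same line.

```javascript
// Ordinary least-squares fit of y = slope * x + intercept (closed form).
function fitLine(xs, ys) {
  const n = xs.length;
  const meanX = xs.reduce((acc, x) => acc + x, 0) / n;
  const meanY = ys.reduce((acc, y) => acc + y, 0) / n;
  let covXY = 0;
  let varX = 0;
  for (let i = 0; i < n; i++) {
    covXY += (xs[i] - meanX) * (ys[i] - meanY);
    varX += (xs[i] - meanX) ** 2;
  }
  const slope = covXY / varX;
  return { slope, intercept: meanY - slope * meanX };
}

const { slope, intercept } = fitLine([1, 2, 3, 4], [1, 3, 5, 7]);
// slope = 10 / 5 = 2 and intercept = 4 - 2 * 2.5 = -1, so x = 5 maps to 9.
console.log(slope * 5 + intercept);
```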
### Key Considerations
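- **Data Handling**: Normalize or otherwise preprocess features before training, for example by scaling them to similar ranges. Unscaled inputs can make SGD slow or unstable.
- **Model Complexity**: Training in the browser is limited by the device's memory and GPU. For large models or datasets, TensorFlow in Python or `@tensorflow/tfjs-node` (which binds to the native TensorFlow C library) is usually faster.
- **WebGL Acceleration**: In the browser, TensorFlow.js uses its WebGL backend for GPU acceleration when WebGL is available. Otherwise it falls back to a slower CPU backend. `tf.getBackend()` reports which backend is active.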
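Min-max scaling is one common form of normalization. It maps each feature into [0, 1] using the minimum and maximum of the training data. The same statistics must then be reused for test data and at prediction time.

The plain-JavaScript sketch below shows the idea on a single feature; `fitMinMax` and `applyMinMax` are hypothetical helper names. With TensorFlow.js, the same arithmetic can be done on tensors with `tf.min`, `tf.max`, `sub` and `div`.

```javascript
// Hypothetical helpers: min-max scaling for a single numeric feature.
function fitMinMax(values) {
  const min = Math.min(...values);
  const max = Math.max(...values);
  if (max === min) {
    throw new Error("Feature is constant; min-max scaling is undefined");
  }
  return { min, max };
}

function applyMinMax(values, { min, max }) {
  // Values outside the training range map outside [0, 1]; that is expected.
  return values.map((v) => (v - min) / (max - min));
}

const stats = fitMinMax([1, 2, 3, 4]); // statistics come from training data only
const trainScaled = applyMinMax([1, 2, 3, 4], stats);
const newScaled = applyMinMax([5], stats); // reuse the same stats for new inputs
console.log(trainScaled, newScaled);
```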
By following these steps, you can effectively train machine learning models directly in JavaScript using TensorFlow.js.