Linear Regression in Artificial Intelligence
26 Mar 2024
Linear regression is the "hello world" of machine learning. It predicts a number based on input data by finding the best straight line through your data points.
Think of it like this: if you plot house sizes on the X axis and prices on the Y axis, linear regression draws the line that best fits those dots. Then you can feed it a new size and it predicts the price.
The math (kept simple)
The equation:
y = mx + b
- y -- the thing you're predicting (price, temperature, sales)
- x -- the input feature (size, time, ad spend)
- m -- the slope. How much y changes when x changes by 1.
- b -- the y-intercept: the value of y when x = 0, which is where the line crosses the y-axis.
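To make the roles of m and b concrete, here is a tiny plain-JavaScript sketch. The function name `predict` and the values m = 2, b = 5 are made up for illustration. It evaluates y = mx + b and shows that the intercept is the output at x = 0, and that raising x by 1 changes the prediction by exactly m.

```javascript
// Hypothetical helper: evaluate the line y = m*x + b
function predict(x, m, b) {
  return m * x + b;
}

const m = 2; // slope: y rises by 2 for every +1 in x
const b = 5; // intercept: the value of y when x = 0

console.log(predict(0, m, b));                    // 5 (the intercept)
console.log(predict(4, m, b) - predict(3, m, b)); // 2 (the slope)
```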
The model's job: find the values of m and b that minimize the gap between predicted and actual values. That gap is called the "loss," and the most common measure is mean squared error -- the average, over all data points, of (predicted - actual) squared.
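The loss is easy to compute by hand. The sketch below is plain JavaScript. `meanSquaredError` is a hypothetical helper, not a TensorFlow.js function. It scores two candidate lines against the same five points used in the TensorFlow.js example below. The line y = x + 1 fits them perfectly, while y = x misses every point by exactly 1.

```javascript
// Hypothetical helper: average of (predicted - actual) squared
function meanSquaredError(predicted, actual) {
  let sum = 0;
  for (let i = 0; i < actual.length; i++) {
    const diff = predicted[i] - actual[i];
    sum += diff * diff;
  }
  return sum / actual.length;
}

const xs = [1, 2, 3, 4, 5];
const ys = [2, 3, 4, 5, 6]; // these points lie on y = x + 1

// Score two candidate lines (m, b)
const goodLine = xs.map((x) => 1 * x + 1); // m = 1, b = 1
const offLine = xs.map((x) => 1 * x + 0);  // m = 1, b = 0

console.log(meanSquaredError(goodLine, ys)); // 0
console.log(meanSquaredError(offLine, ys));  // 1
```

Training is the search for the (m, b) pair that drives this number as low as it can go.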
Building it with TensorFlow.js
You don't need Python for this. TensorFlow.js runs machine learning directly in Node.js or the browser.
Here's a complete example:
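```javascript
// Runs in Node.js after `npm install @tensorflow/tfjs`
const tf = require('@tensorflow/tfjs');

// Training data: five points on the line y = x + 1
const X = tf.tensor2d([[1], [2], [3], [4], [5]]);
const Y = tf.tensor2d([[2], [3], [4], [5], [6]]);

// One dense unit with one input computes y = m*x + b:
// its kernel is the slope m and its bias is the intercept b.
const model = tf.sequential();
model.add(tf.layers.dense({ units: 1, inputShape: [1] }));

// Plain SGD with an explicit learning rate. The string 'sgd' falls back to a
// small default rate, and with it 100 epochs leaves the intercept well short of 1.
model.compile({ loss: 'meanSquaredError', optimizer: tf.train.sgd(0.05) });

async function trainModel() {
  await model.fit(X, Y, { epochs: 500, verbose: 0 });
  console.log('Training complete.');

  // getWeights() returns [kernel, bias] for the dense layer
  const weights = model.getWeights();
  const slope = weights[0].dataSync()[0];
  const yIntercept = weights[1].dataSync()[0];
  console.log(`Slope: ${slope}, Y-intercept: ${yIntercept}`);

  // Predict y for an unseen input x = 6
  const newX = tf.tensor2d([[6]]);
  const predictedY = model.predict(newX);
  predictedY.print();
}

trainModel();

// Output (exact digits vary with the random initial weights):
// Training complete.
// Slope: ~1.0, Y-intercept: ~1.0
// Tensor [[~7.0]]
```
What's happening: the model trains for 500 epochs (passes through the data). On each pass, the SGD optimizer nudges m (the layer's kernel) and b (its bias) in the direction that reduces the mean squared error. After training, it predicts that when x = 6, y should be approximately 7. That makes sense: our data follows y = x + 1.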
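Conceptually, each training step is a gradient update on m and b. The following is a from-scratch sketch in plain JavaScript, not TensorFlow.js's actual implementation. It runs full-batch gradient descent on the mean squared error, starting from m = 0 and b = 0. The learning rate of 0.05, the 500 epochs, and the name `fitLineByGradientDescent` are assumptions chosen for this small dataset.

```javascript
// Full-batch gradient descent for y = m*x + b under mean squared error.
// Hypothetical helper; learning rate and epoch count are assumptions tuned for this data.
function fitLineByGradientDescent(xs, ys, learningRate = 0.05, epochs = 500) {
  let m = 0; // start from a flat line through the origin
  let b = 0;
  const n = xs.length;

  for (let epoch = 0; epoch < epochs; epoch++) {
    // Gradients of MSE = (1/n) * sum((m*x + b - y)^2) with respect to m and b
    let gradM = 0;
    let gradB = 0;
    for (let i = 0; i < n; i++) {
      const error = m * xs[i] + b - ys[i];
      gradM += (2 / n) * error * xs[i];
      gradB += (2 / n) * error;
    }
    // Move both parameters a small step against their gradients
    m -= learningRate * gradM;
    b -= learningRate * gradB;
  }
  return { m, b };
}

const { m, b } = fitLineByGradientDescent([1, 2, 3, 4, 5], [2, 3, 4, 5, 6]);
console.log(m.toFixed(3), b.toFixed(3)); // both close to 1
```

For this data, a learning rate above roughly 0.085 makes the updates overshoot and diverge. A rate much smaller than 0.05 needs far more epochs for b to settle. The step size matters as much as the number of epochs.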
When to use linear regression
Good for: Predicting continuous values where the relationship between input and output is roughly linear. Sales forecasting, trend analysis, simple pricing models.
Bad for: Non-linear relationships, classification problems, or anything where the input-output relationship has curves, jumps, or complex interactions. For those, you need different tools.
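One quick way to see the "bad for" case is to fit the best possible line to data that is not linear. For a single input, the least-squares line has a well-known closed form. The slope is m = sum((x - mean x)(y - mean y)) / sum((x - mean x)^2), and the intercept is b = mean y - m * mean x. This minimizes the same loss that the TensorFlow.js model approaches with SGD, but does it exactly. The sketch below is plain JavaScript; `fitLeastSquares` and `mseOfLine` are illustrative names. It applies the formula to points on y = x + 1 and to points on y = x squared.

```javascript
// Hypothetical helper: closed-form least-squares line for one input feature
function fitLeastSquares(xs, ys) {
  const n = xs.length;
  const meanX = xs.reduce((s, x) => s + x, 0) / n;
  const meanY = ys.reduce((s, y) => s + y, 0) / n;
  let num = 0;
  let den = 0;
  for (let i = 0; i < n; i++) {
    num += (xs[i] - meanX) * (ys[i] - meanY);
    den += (xs[i] - meanX) ** 2;
  }
  const m = num / den;
  return { m, b: meanY - m * meanX };
}

// Hypothetical helper: mean squared error of the line { m, b } on the data
function mseOfLine(xs, ys, { m, b }) {
  let sum = 0;
  for (let i = 0; i < xs.length; i++) sum += (m * xs[i] + b - ys[i]) ** 2;
  return sum / xs.length;
}

const xs = [1, 2, 3, 4, 5];
const linearYs = xs.map((x) => x + 1); // straight line
const curvedYs = xs.map((x) => x * x); // parabola

const linearFit = fitLeastSquares(xs, linearYs);
const curvedFit = fitLeastSquares(xs, curvedYs);

console.log(linearFit, mseOfLine(xs, linearYs, linearFit)); // { m: 1, b: 1 } 0
console.log(curvedFit, mseOfLine(xs, curvedYs, curvedFit)); // { m: 6, b: -7 } 2.8
```

Even the best line leaves an error of 2.8 on the parabola, and the misses are systematic. The line undershoots at both ends and overshoots in the middle. That pattern signals a relationship a straight line cannot capture.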
The benefit of linear regression: it's fast, interpretable, and, with only a few input features, hard to overfit. You can explain the model to a non-technical stakeholder by showing them the line.
The cost: it assumes linearity. Real-world data is rarely that clean. It's a starting point, not a destination.